From 4acb89cdb107108f4c57aedc1c08f95bc8651211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Majdand=C5=BEi=C4=87?= Date: Thu, 27 Jun 2024 10:42:17 +0200 Subject: [PATCH] Rework WS server --- ws-server/connection.go | 98 ++++++++++++++++++++++++++++++++ ws-server/main.go | 121 ++-------------------------------------- ws-server/server.go | 50 +++++++++++++++++ 3 files changed, 152 insertions(+), 117 deletions(-) create mode 100644 ws-server/connection.go create mode 100644 ws-server/server.go diff --git a/ws-server/connection.go b/ws-server/connection.go new file mode 100644 index 0000000..84488dd --- /dev/null +++ b/ws-server/connection.go @@ -0,0 +1,98 @@ +package main + +import ( + "log" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type WSConnection struct { + conn *websocket.Conn + writeLock sync.Mutex + alive bool + ReadChan chan string + WriteChan chan string + ErrorChan chan error + IdleTimeout time.Duration + PingInterval time.Duration +} + +func NewConn(conn *websocket.Conn, server *WSServer) *WSConnection { + wsconn := &WSConnection{ + conn: conn, + alive: true, + IdleTimeout: server.IdleTimeout, + PingInterval: server.PingInterval, + ReadChan: make(chan string, 1024), + WriteChan: make(chan string, 1024), + ErrorChan: make(chan error, 1), + } + + return wsconn +} + +func (ws *WSConnection) Open() { + go ws.messageReader() + go ws.messageSender() + // go ws.pinger() +} + +func (ws *WSConnection) messageReader() { + log.Printf("Reading messages") + for { + _, message, err := ws.conn.ReadMessage() + if !ws.alive { + return + } + ws.conn.SetReadDeadline(time.Now().Add(ws.IdleTimeout)) + if err != nil { + ws.ErrorChan <- err + return + } + log.Printf("Received: %s, %d in output channel", message, len(ws.ReadChan)) + ws.ReadChan <- string(message) + } +} + +func (ws *WSConnection) messageSender() { + log.Printf("Sending messages") + for { + msg := <-ws.WriteChan + if !ws.alive { + return + } + ws.writeLock.Lock() + + ws.conn.SetWriteDeadline(time.Now().Add(ws.IdleTimeout)) + log.Printf("Sending: %s, %d in input channel", msg, len(ws.WriteChan)) + err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Printf("Error during message writing: %v", err) + ws.ErrorChan <- err + return + } + ws.writeLock.Unlock() + } +} + +// func (ws *WSConnection) pinger() { +// log.Printf("Starting pinger, sleeping for %v", ws.PingInterval) +// for { +// time.Sleep(ws.PingInterval) +// if !ws.alive { +// return +// } + +// log.Printf("Ping") +// ws.writeLock.Lock() +// err := ws.conn.WriteMessage(websocket.PingMessage, nil) +// if err != nil { +// log.Println("Error during ping:", err) +// ws.ErrorChan <- err +// return +// } +// ws.writeLock.Unlock() +// } +// } diff --git a/ws-server/main.go b/ws-server/main.go index 06f8be9..0eb84b4 100644 --- a/ws-server/main.go +++ b/ws-server/main.go @@ -6,124 +6,18 @@ import ( "io" "log" "net/http" - "sync" "time" - - "github.com/gorilla/websocket" ) -var upgrader = websocket.Upgrader{} -var wsBroadcast = make(chan string, 128) -var connections = make(map[*WSConnection]bool) - -const TIMEOUT = 6 -const IDLE_TIMEOUT = TIMEOUT * time.Second -const PING_INTERVAL = (TIMEOUT / 2) * time.Second - -type WSConnection struct { - conn *websocket.Conn - writeLock sync.Mutex - ReadChan chan string - WriteChan chan string - ErrorChan chan error -} - -func (ws *WSConnection) messageReader() { - log.Printf("Reading messages") - for { - _, message, err := ws.conn.ReadMessage() - ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) - if err != nil { - ws.ErrorChan <- err - return - } - log.Printf("Received: %s, %d in output channel", message, len(ws.ReadChan)) - ws.ReadChan <- string(message) - } -} - -func (ws *WSConnection) messageSender() { - log.Printf("Sending messages") - for { - msg := <-ws.WriteChan - ws.writeLock.Lock() - - ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) - log.Printf("Sending: %s, %d in input channel", msg, len(ws.WriteChan)) - err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) - if err != nil { - log.Printf("Error during message writing: %v", err) - ws.ErrorChan <- err - return - } - ws.writeLock.Unlock() - } -} - -func (ws *WSConnection) pinger() { - log.Printf("Starting pinger, sleeping for %v", PING_INTERVAL) - for { - time.Sleep(PING_INTERVAL) - - log.Printf("Ping") - ws.writeLock.Lock() - err := ws.conn.WriteMessage(websocket.PingMessage, nil) - if err != nil { - log.Println("Error during ping:", err) - ws.ErrorChan <- err - return - } - ws.writeLock.Unlock() - } -} - -func (ws *WSConnection) Open() { - log.Printf("Client connected") - ws.ReadChan = make(chan string, 1024) - ws.WriteChan = make(chan string, 1024) - ws.ErrorChan = make(chan error, 16) - - ws.conn.SetReadLimit(1024) - ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) - ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) - ws.conn.SetPongHandler(func(string) error { - log.Println("Pong") - ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) - ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) - return nil - }) - - connections[ws] = true - - go ws.messageReader() - go ws.messageSender() - go ws.pinger() - go func() { - for { - select { - case err := <-ws.ErrorChan: - log.Printf("Error: %v", err) - ws.conn.Close() - log.Printf("Client disconnected") - connections[ws] = false - return - // case msg := <-wsBroadcast: - // ws.WriteChan <- msg - } - } - }() -} +var server = New(10 * time.Second) func wsHandler(responseWriter http.ResponseWriter, request *http.Request) { - conn, err := upgrader.Upgrade(responseWriter, request, nil) + conn, err := server.Upgrader.Upgrade(responseWriter, request, nil) if err != nil { fmt.Println("Error during connection upgrade:", err) return } - - ws := new(WSConnection) - ws.conn = conn - ws.Open() + server.HandleNew(conn) } type DownloadReq struct { @@ -146,14 +40,7 @@ func handleDownload(responseWriter http.ResponseWriter, request *http.Request) { http.Error(responseWriter, "Error parsing JSON", http.StatusBadRequest) return } - - log.Printf("Received download request: %s, %d in channel", req.Link, len(wsBroadcast)) - go func() { - for ws := range connections { - ws.WriteChan <- req.Link - } - }() - // wsBroadcast <- req.Link + server.Broadcast <- req.Link } func main() { diff --git a/ws-server/server.go b/ws-server/server.go new file mode 100644 index 0000000..b990b3f --- /dev/null +++ b/ws-server/server.go @@ -0,0 +1,50 @@ +package main + +import ( + "log" + "time" + + "github.com/gorilla/websocket" +) + +type WSServer struct { + connections map[*WSConnection]bool + Upgrader websocket.Upgrader + Broadcast chan string + IdleTimeout time.Duration + PingInterval time.Duration +} + +func New(timeout time.Duration) *WSServer { + server := &WSServer{ + connections: make(map[*WSConnection]bool), + Upgrader: websocket.Upgrader{}, + Broadcast: make(chan string, 128), + IdleTimeout: timeout, + PingInterval: timeout / 2, + } + go func() { + for { + msg := <-server.Broadcast + for conn := range server.connections { + conn.WriteChan <- msg + } + } + }() + return server +} + +func (server *WSServer) HandleNew(conn *websocket.Conn) { + log.Printf("Client connected, now %d clients", len(server.connections)+1) + + wsconn := NewConn(conn, server) + go wsconn.Open() + server.connections[wsconn] = true + + go func() { + <-wsconn.ErrorChan + wsconn.alive = false + delete(server.connections, wsconn) + log.Printf("Client disconnected, now %d clients", len(server.connections)) + }() +}