diff --git a/ws-server/connection.go b/ws-server/connection.go index 84488dd..b366815 100644 --- a/ws-server/connection.go +++ b/ws-server/connection.go @@ -3,12 +3,14 @@ package main import ( "log" "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" ) type WSConnection struct { + id int32 conn *websocket.Conn writeLock sync.Mutex alive bool @@ -21,6 +23,7 @@ type WSConnection struct { func NewConn(conn *websocket.Conn, server *WSServer) *WSConnection { wsconn := &WSConnection{ + id: server.clientId, conn: conn, alive: true, IdleTimeout: server.IdleTimeout, @@ -29,6 +32,7 @@ func NewConn(conn *websocket.Conn, server *WSServer) *WSConnection { WriteChan: make(chan string, 1024), ErrorChan: make(chan error, 1), } + atomic.AddInt32(&server.clientId, 1) return wsconn } @@ -36,11 +40,16 @@ func NewConn(conn *websocket.Conn, server *WSServer) *WSConnection { func (ws *WSConnection) Open() { go ws.messageReader() go ws.messageSender() - // go ws.pinger() + go ws.pinger() + ws.conn.SetPongHandler(func(string) error { + // log.Printf("Client %d: Pong OK", ws.id) + ws.conn.SetReadDeadline(time.Now().Add(ws.IdleTimeout)) + return nil + }) } func (ws *WSConnection) messageReader() { - log.Printf("Reading messages") + log.Printf("Client %d: Reading messages", ws.id) for { _, message, err := ws.conn.ReadMessage() if !ws.alive { @@ -51,48 +60,52 @@ func (ws *WSConnection) messageReader() { ws.ErrorChan <- err return } - log.Printf("Received: %s, %d in output channel", message, len(ws.ReadChan)) + log.Printf("Client %d: Received: %s, %d in output channel", ws.id, message, len(ws.ReadChan)) ws.ReadChan <- string(message) } } func (ws *WSConnection) messageSender() { - log.Printf("Sending messages") + log.Printf("Client %d: Sending messages", ws.id) for { - msg := <-ws.WriteChan - if !ws.alive { - return - } - ws.writeLock.Lock() - - ws.conn.SetWriteDeadline(time.Now().Add(ws.IdleTimeout)) - log.Printf("Sending: %s, %d in input channel", msg, len(ws.WriteChan)) - err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) - if err != nil { - log.Printf("Error during message writing: %v", err) - ws.ErrorChan <- err - return - } - ws.writeLock.Unlock() + func() { + msg, ok := <-ws.WriteChan + if !ok || !ws.alive { + return + } + ws.writeLock.Lock() + defer ws.writeLock.Unlock() + + ws.conn.SetWriteDeadline(time.Now().Add(ws.IdleTimeout)) + log.Printf("Client %d: Sending: %s, %d in input channel", ws.id, msg, len(ws.WriteChan)) + err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Printf("Client %d: Error during message writing: %v", ws.id, err) + ws.ErrorChan <- err + return + } + }() } } -// func (ws *WSConnection) pinger() { -// log.Printf("Starting pinger, sleeping for %v", ws.PingInterval) -// for { -// time.Sleep(ws.PingInterval) -// if !ws.alive { -// return -// } +func (ws *WSConnection) pinger() { + log.Printf("Client %d: Starting pinger, sleeping for %v", ws.id, ws.PingInterval) + for { + time.Sleep(ws.PingInterval) + if !ws.alive { + return + } -// log.Printf("Ping") -// ws.writeLock.Lock() -// err := ws.conn.WriteMessage(websocket.PingMessage, nil) -// if err != nil { -// log.Println("Error during ping:", err) -// ws.ErrorChan <- err -// return -// } -// ws.writeLock.Unlock() -// } -// } + // log.Printf("Client %d: Ping", ws.id) + ws.writeLock.Lock() + err := ws.conn.WriteMessage(websocket.PingMessage, nil) + if err != nil { + log.Printf("Client %d: Error during ping: %+v", ws.id, err) + ws.ErrorChan <- err + return + } + ws.conn.SetWriteDeadline(time.Now().Add(ws.IdleTimeout)) + ws.writeLock.Unlock() + // log.Printf("Client %d: Ping OK", ws.id) + } +} diff --git a/ws-server/main.go b/ws-server/main.go index 0eb84b4..af888ef 100644 --- a/ws-server/main.go +++ b/ws-server/main.go @@ -43,9 +43,11 @@ func handleDownload(responseWriter http.ResponseWriter, request *http.Request) { server.Broadcast <- req.Link } -func main() { - log.SetFlags(log.Lmicroseconds) +func init() { + log.SetFlags(log.Lmicroseconds | log.Lshortfile) +} +func main() { http.HandleFunc("/ws", wsHandler) http.HandleFunc("/download", handleDownload) log.Println("Server starting on :8080") diff --git a/ws-server/server.go b/ws-server/server.go index b990b3f..f6deaea 100644 --- a/ws-server/server.go +++ b/ws-server/server.go @@ -9,6 +9,7 @@ import ( type WSServer struct { connections map[*WSConnection]bool + clientId int32 Upgrader websocket.Upgrader Broadcast chan string IdleTimeout time.Duration @@ -42,9 +43,12 @@ func (server *WSServer) HandleNew(conn *websocket.Conn) { server.connections[wsconn] = true go func() { - <-wsconn.ErrorChan + err := <-wsconn.ErrorChan wsconn.alive = false + close(wsconn.ReadChan) + close(wsconn.WriteChan) + close(wsconn.ErrorChan) + log.Printf("Client %d: disconnected due to %+v, now %d clients", wsconn.id, err, len(server.connections)) delete(server.connections, wsconn) - log.Printf("Client disconnected, now %d clients", len(server.connections)) }() }