diff --git a/downloader/ws-client.go b/downloader/ws-client.go index ca9357f..ba982a9 100644 --- a/downloader/ws-client.go +++ b/downloader/ws-client.go @@ -2,22 +2,30 @@ package main import ( "log" + "sync" "time" "github.com/gorilla/websocket" ) +const TIMEOUT = 6 +const IDLE_TIMEOUT = TIMEOUT * time.Second +const PING_INTERVAL = (TIMEOUT - 1) * time.Second + type WSConnection struct { - url string - conn *websocket.Conn - errChan chan error - ReadChan chan string + url string + conn *websocket.Conn + errChan chan error + writeLock sync.Mutex + ReadChan chan string + WriteChan chan string } -func (ws *WSConnection) readMessage() { +func (ws *WSConnection) messageReader() { log.Printf("Reading messages") for { _, message, err := ws.conn.ReadMessage() + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) if err != nil { ws.errChan <- err return @@ -27,12 +35,38 @@ func (ws *WSConnection) readMessage() { } } -func (ws *WSConnection) writeMessage(message string) { - err := ws.conn.WriteMessage(websocket.TextMessage, []byte(message)) - if err != nil { - log.Printf("Error during message writing: %v", err) - ws.errChan <- err - return +func (ws *WSConnection) messageSender() { + log.Printf("Sending messages") + for { + msg := <-ws.WriteChan + ws.writeLock.Lock() + + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + log.Printf("Sending: %s", msg) + err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Printf("Error during message writing: %v", err) + ws.errChan <- err + return + } + ws.writeLock.Unlock() + } +} + +func (ws *WSConnection) pinger() { + log.Printf("Starting pinger, sleeping for %v", PING_INTERVAL) + for { + time.Sleep(PING_INTERVAL) + + // log.Printf("Ping") + ws.writeLock.Lock() + err := ws.conn.WriteMessage(websocket.PingMessage, nil) + if err != nil { + log.Println("Error during ping:", err) + ws.errChan <- err + return + } + ws.writeLock.Unlock() } } @@ -58,6 +92,20 @@ func (ws *WSConnection) Open() { ws.conn = conn ws.errChan = make(chan error) ws.ReadChan = make(chan string, 1024) - go ws.readMessage() + ws.WriteChan = make(chan string, 1024) + + ws.conn.SetReadLimit(1024) + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetPongHandler(func(string) error { + // log.Println("Pong") + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + return nil + }) + + go ws.messageReader() + go ws.messageSender() go ws.handleError() + go ws.pinger() } diff --git a/ws-server/main.go b/ws-server/main.go index ef4bb4a..b39dd8c 100644 --- a/ws-server/main.go +++ b/ws-server/main.go @@ -6,6 +6,8 @@ import ( "io" "log" "net/http" + "sync" + "time" "github.com/gorilla/websocket" ) @@ -13,6 +15,83 @@ import ( var upgrader = websocket.Upgrader{} var wsBroadcast = make(chan []byte, 100) +const TIMEOUT = 6 +const IDLE_TIMEOUT = TIMEOUT * time.Second +const PING_INTERVAL = (TIMEOUT - 1) * time.Second + +type WSConnection struct { + conn *websocket.Conn + writeLock sync.Mutex + ReadChan chan string + WriteChan chan string +} + +func (ws *WSConnection) messageReader() { + log.Printf("Reading messages") + for { + _, message, err := ws.conn.ReadMessage() + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) + if err != nil { + return + } + log.Printf("Received: %s", message) + ws.ReadChan <- string(message) + } +} + +func (ws *WSConnection) messageSender() { + log.Printf("Sending messages") + for { + msg := <-ws.WriteChan + ws.writeLock.Lock() + + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + log.Printf("Sending: %s", msg) + err := ws.conn.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Printf("Error during message writing: %v", err) + return + } + ws.writeLock.Unlock() + } +} + +func (ws *WSConnection) pinger() { + log.Printf("Starting pinger, sleeping for %v", PING_INTERVAL) + for { + time.Sleep(PING_INTERVAL) + + // log.Printf("Ping") + ws.writeLock.Lock() + err := ws.conn.WriteMessage(websocket.PingMessage, nil) + if err != nil { + log.Println("Error during ping:", err) + return + } + ws.writeLock.Unlock() + } +} + +func (ws *WSConnection) Open() { + log.Printf("Client connected") + ws.ReadChan = make(chan string, 1024) + ws.WriteChan = make(chan string, 1024) + + ws.conn.SetReadLimit(1024) + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetPongHandler(func(string) error { + // log.Println("Pong") + ws.conn.SetReadDeadline(time.Now().Add(IDLE_TIMEOUT)) + ws.conn.SetWriteDeadline(time.Now().Add(IDLE_TIMEOUT)) + return nil + }) + + go ws.messageReader() + go ws.messageSender() + go ws.pinger() +} + func wsHandler(responseWriter http.ResponseWriter, request *http.Request) { conn, err := upgrader.Upgrade(responseWriter, request, nil) if err != nil { @@ -20,33 +99,9 @@ func wsHandler(responseWriter http.ResponseWriter, request *http.Request) { return } - go wsHandleRead(conn) - wsHandleWrite(conn) -} - -func wsHandleRead(conn *websocket.Conn) { - log.Printf("Starting read handler") - for { - messageType, packet, err := conn.ReadMessage() - if err != nil { - fmt.Println("Error during message reading:", err) - return - } - log.Printf("Received: %v %s", messageType, packet) - } -} - -func wsHandleWrite(conn *websocket.Conn) { - log.Println("Starting write handler") - for { - packet := <-wsBroadcast - log.Printf("Broadcasting: %s", packet) - err := conn.WriteMessage(websocket.TextMessage, packet) - if err != nil { - fmt.Println("Error during message writing:", err) - return - } - } + ws := new(WSConnection) + ws.conn = conn + ws.Open() } type DownloadReq struct {