Rewrite to use fasthttp

This commit is contained in:
2025-10-10 20:19:50 +02:00
parent 2c727632b8
commit a1f568cbe6
4 changed files with 108 additions and 125 deletions

View File

@@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
@@ -20,6 +19,8 @@ import (
"gorm.io/gorm"
logger "git.site.quack-lab.dev/dave/cylogger"
"github.com/fasthttp/router"
"github.com/valyala/fasthttp"
)
const (
@@ -43,8 +44,7 @@ type SSO struct {
scopes []string
db DB
mu sync.Mutex
server *http.Server
mux *http.ServeMux
router *router.Router
state string
callbackChan chan struct {
code string
@@ -79,10 +79,10 @@ func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
return s, nil
}
// SetMuxer allows the SSO to use an existing HTTP muxer instead of creating its own server
func (s *SSO) SetMuxer(mux *http.ServeMux) {
s.mux = mux
logger.Debug("SSO configured to use existing HTTP muxer")
// SetRouter allows the SSO to use an existing fasthttp router
func (s *SSO) SetRouter(r *router.Router) {
s.router = r
logger.Debug("SSO configured to use existing fasthttp router")
}
func (s *SSO) initDB() error {
@@ -181,19 +181,12 @@ func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
logger.Info("Waiting for authentication...")
// Setup callback handling
if s.mux != nil {
logger.Debug("Using existing HTTP muxer for callback handling")
s.setupCallbackHandler()
} else {
logger.Debug("Starting dedicated callback server")
server, err := s.startCallbackServer()
if err != nil {
logger.Error("Failed to start callback server: %v", err)
return err
}
s.server = server
defer server.Shutdown(ctx)
if s.router == nil {
logger.Error("No router configured for callback handling")
return errors.New("no router configured for callback handling")
}
logger.Debug("Using fasthttp router for callback handling")
s.setupCallbackHandler()
// Wait for callback
logger.Debug("Waiting for authentication callback")
@@ -251,14 +244,28 @@ func (s *SSO) setupCallbackHandler() {
}
logger.Debug("Setting up callback handler on path: %s", u.Path)
s.mux.HandleFunc(u.Path, s.handleCallback)
s.router.GET(u.Path, s.handleCallback)
}
func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
logger.Debug("Received callback request: %s %s", r.Method, r.URL.String())
if r.Method != http.MethodGet {
logger.Warning("Invalid callback method: %s", r.Method)
w.WriteHeader(http.StatusMethodNotAllowed)
func (s *SSO) handleCallback(ctx *fasthttp.RequestCtx) {
s.processCallback(
ctx.IsGet(),
string(ctx.QueryArgs().Peek("code")),
string(ctx.QueryArgs().Peek("state")),
func(status int, body string) {
ctx.SetStatusCode(status)
ctx.WriteString(body)
},
func(contentType string) {
ctx.SetContentType(contentType)
},
)
}
func (s *SSO) processCallback(isGet bool, code, state string, writeResponse func(int, string), setContentType func(string)) {
if !isGet {
logger.Warning("Invalid callback method")
writeResponse(http.StatusMethodNotAllowed, "Method not allowed")
s.callbackChan <- struct {
code string
state string
@@ -266,13 +273,10 @@ func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
}{"", "", errors.New("method not allowed")}
return
}
q := r.URL.Query()
code := q.Get("code")
st := q.Get("state")
if code == "" || st == "" || st != s.state {
logger.Error("Invalid SSO response: code=%s, state=%s, expected_state=%s", code, st, s.state)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("Invalid SSO response"))
if code == "" || state == "" || state != s.state {
logger.Error("Invalid SSO response: code=%s, state=%s, expected_state=%s", code, state, s.state)
writeResponse(http.StatusBadRequest, "Invalid SSO response")
s.callbackChan <- struct {
code string
state string
@@ -280,52 +284,15 @@ func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
}{"", "", errors.New("invalid state")}
return
}
logger.Info("Received valid callback, exchanging token for code: %s", code)
w.Header().Set("Content-Type", "text/html")
_, _ = w.Write([]byte("<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>"))
setContentType("text/html")
writeResponse(http.StatusOK, "<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>")
s.callbackChan <- struct {
code string
state string
err error
}{code, st, nil}
}
func (s *SSO) startCallbackServer() (*http.Server, error) {
logger.Debug("Starting dedicated callback server for redirect URI: %s", s.redirectURI)
u, err := url.Parse(s.redirectURI)
if err != nil {
logger.Error("Failed to parse redirect URI: %v", err)
return nil, err
}
if u.Scheme != "http" && u.Scheme != "https" {
logger.Error("Invalid redirect URI scheme: %s", u.Scheme)
return nil, errors.New("redirect URI must be http(s)")
}
hostPort := u.Host
if !strings.Contains(hostPort, ":") {
if u.Scheme == "https" {
hostPort += ":443"
} else {
hostPort += ":80"
}
}
logger.Debug("Callback server will listen on %s", hostPort)
mux := http.NewServeMux()
mux.HandleFunc(u.Path, s.handleCallback)
ln, err := net.Listen("tcp", hostPort)
if err != nil {
logger.Error("Failed to listen on %s: %v", hostPort, err)
return nil, err
}
server := &http.Server{Handler: mux}
go func() {
logger.Debug("Callback server started successfully")
_ = server.Serve(ln)
}()
return server, nil
}{code, state, nil}
}
func (s *SSO) waitForCallback() (code, state string, err error) {