Rewrite to use fasthttp
This commit is contained in:
113
esi_sso.go
113
esi_sso.go
@@ -9,7 +9,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -20,6 +19,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
logger "git.site.quack-lab.dev/dave/cylogger"
|
||||
"github.com/fasthttp/router"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,8 +44,7 @@ type SSO struct {
|
||||
scopes []string
|
||||
db DB
|
||||
mu sync.Mutex
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
router *router.Router
|
||||
state string
|
||||
callbackChan chan struct {
|
||||
code string
|
||||
@@ -79,10 +79,10 @@ func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SetMuxer allows the SSO to use an existing HTTP muxer instead of creating its own server
|
||||
func (s *SSO) SetMuxer(mux *http.ServeMux) {
|
||||
s.mux = mux
|
||||
logger.Debug("SSO configured to use existing HTTP muxer")
|
||||
// SetRouter allows the SSO to use an existing fasthttp router
|
||||
func (s *SSO) SetRouter(r *router.Router) {
|
||||
s.router = r
|
||||
logger.Debug("SSO configured to use existing fasthttp router")
|
||||
}
|
||||
|
||||
func (s *SSO) initDB() error {
|
||||
@@ -181,19 +181,12 @@ func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
|
||||
logger.Info("Waiting for authentication...")
|
||||
|
||||
// Setup callback handling
|
||||
if s.mux != nil {
|
||||
logger.Debug("Using existing HTTP muxer for callback handling")
|
||||
s.setupCallbackHandler()
|
||||
} else {
|
||||
logger.Debug("Starting dedicated callback server")
|
||||
server, err := s.startCallbackServer()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start callback server: %v", err)
|
||||
return err
|
||||
}
|
||||
s.server = server
|
||||
defer server.Shutdown(ctx)
|
||||
if s.router == nil {
|
||||
logger.Error("No router configured for callback handling")
|
||||
return errors.New("no router configured for callback handling")
|
||||
}
|
||||
logger.Debug("Using fasthttp router for callback handling")
|
||||
s.setupCallbackHandler()
|
||||
|
||||
// Wait for callback
|
||||
logger.Debug("Waiting for authentication callback")
|
||||
@@ -251,14 +244,28 @@ func (s *SSO) setupCallbackHandler() {
|
||||
}
|
||||
|
||||
logger.Debug("Setting up callback handler on path: %s", u.Path)
|
||||
s.mux.HandleFunc(u.Path, s.handleCallback)
|
||||
s.router.GET(u.Path, s.handleCallback)
|
||||
}
|
||||
|
||||
func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug("Received callback request: %s %s", r.Method, r.URL.String())
|
||||
if r.Method != http.MethodGet {
|
||||
logger.Warning("Invalid callback method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
func (s *SSO) handleCallback(ctx *fasthttp.RequestCtx) {
|
||||
s.processCallback(
|
||||
ctx.IsGet(),
|
||||
string(ctx.QueryArgs().Peek("code")),
|
||||
string(ctx.QueryArgs().Peek("state")),
|
||||
func(status int, body string) {
|
||||
ctx.SetStatusCode(status)
|
||||
ctx.WriteString(body)
|
||||
},
|
||||
func(contentType string) {
|
||||
ctx.SetContentType(contentType)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (s *SSO) processCallback(isGet bool, code, state string, writeResponse func(int, string), setContentType func(string)) {
|
||||
if !isGet {
|
||||
logger.Warning("Invalid callback method")
|
||||
writeResponse(http.StatusMethodNotAllowed, "Method not allowed")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
@@ -266,13 +273,10 @@ func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}{"", "", errors.New("method not allowed")}
|
||||
return
|
||||
}
|
||||
q := r.URL.Query()
|
||||
code := q.Get("code")
|
||||
st := q.Get("state")
|
||||
if code == "" || st == "" || st != s.state {
|
||||
logger.Error("Invalid SSO response: code=%s, state=%s, expected_state=%s", code, st, s.state)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("Invalid SSO response"))
|
||||
|
||||
if code == "" || state == "" || state != s.state {
|
||||
logger.Error("Invalid SSO response: code=%s, state=%s, expected_state=%s", code, state, s.state)
|
||||
writeResponse(http.StatusBadRequest, "Invalid SSO response")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
@@ -280,52 +284,15 @@ func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}{"", "", errors.New("invalid state")}
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received valid callback, exchanging token for code: %s", code)
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte("<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>"))
|
||||
setContentType("text/html")
|
||||
writeResponse(http.StatusOK, "<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>")
|
||||
s.callbackChan <- struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}{code, st, nil}
|
||||
}
|
||||
|
||||
func (s *SSO) startCallbackServer() (*http.Server, error) {
|
||||
logger.Debug("Starting dedicated callback server for redirect URI: %s", s.redirectURI)
|
||||
u, err := url.Parse(s.redirectURI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse redirect URI: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
logger.Error("Invalid redirect URI scheme: %s", u.Scheme)
|
||||
return nil, errors.New("redirect URI must be http(s)")
|
||||
}
|
||||
hostPort := u.Host
|
||||
if !strings.Contains(hostPort, ":") {
|
||||
if u.Scheme == "https" {
|
||||
hostPort += ":443"
|
||||
} else {
|
||||
hostPort += ":80"
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Callback server will listen on %s", hostPort)
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(u.Path, s.handleCallback)
|
||||
|
||||
ln, err := net.Listen("tcp", hostPort)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on %s: %v", hostPort, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := &http.Server{Handler: mux}
|
||||
go func() {
|
||||
logger.Debug("Callback server started successfully")
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
return server, nil
|
||||
}{code, state, nil}
|
||||
}
|
||||
|
||||
func (s *SSO) waitForCallback() (code, state string, err error) {
|
||||
|
||||
Reference in New Issue
Block a user