// Package esi implements the EVE Online SSO login flow (OAuth 2.0
// authorization code grant with PKCE) and keeps character tokens fresh.
package esi

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"go-eve-pi/repositories"
	"go-eve-pi/types"

	logger "git.site.quack-lab.dev/dave/cylogger"
	"github.com/fasthttp/router"
	"github.com/valyala/fasthttp"
	"gorm.io/gorm"
)

const (
	issuerAuthorizeURL = "https://login.eveonline.com/v2/oauth/authorize"
	issuerTokenURL     = "https://login.eveonline.com/v2/oauth/token"

	// callbackTimeout is how long an authentication flow waits for the
	// browser redirect before giving up.
	callbackTimeout = 30 * time.Second
	// tokenRequestTimeout bounds a single request to the token endpoint even
	// when the caller's context has no deadline.
	tokenRequestTimeout = 30 * time.Second
	// maxErrorBodyBytes caps how much of an error response body is read for
	// diagnostics.
	maxErrorBodyBytes = 4 << 10
)

// CharacterRepositoryInterface defines the interface for character operations.
type CharacterRepositoryInterface = repositories.CharacterRepositoryInterface

// callbackResult is the message handed from the HTTP callback handler to the
// goroutine waiting in the authentication flow. It is an alias so the channel
// type stays identical to the anonymous struct channel used previously.
type callbackResult = struct {
	code  string
	state string
	err   error
}

// SSO drives the OAuth flow for a single client application and stores the
// resulting tokens through characterRepo.
type SSO struct {
	clientID      string
	redirectURI   string
	scopes        []string
	characterRepo CharacterRepositoryInterface

	// mu serializes GetCharacter calls; it is held for the whole duration of
	// an authentication flow.
	mu     sync.Mutex
	router *router.Router

	// flowMu guards state and callbackChan, which are written by the flow and
	// read by the HTTP callback handler on another goroutine. mu cannot be
	// used for this because it is held while the flow waits for the callback.
	flowMu       sync.Mutex
	state        string
	callbackChan chan callbackResult
}

// NewSSO creates a new SSO instance. The returned error is currently always nil.
func NewSSO(clientID, redirectURI string, scopes []string, characterRepo CharacterRepositoryInterface) (*SSO, error) {
	logger.Info("Creating new SSO instance for clientID %s with redirectURI %s and scopes %v", clientID, redirectURI, scopes)
	s := &SSO{
		clientID:      clientID,
		redirectURI:   redirectURI,
		scopes:        scopes,
		characterRepo: characterRepo,
	}
	logger.Info("SSO instance created successfully")
	return s, nil
}

// SetRouter allows the SSO to use an existing fasthttp router and registers
// the OAuth callback handler on it.
func (s *SSO) SetRouter(r *router.Router) {
	s.router = r
	logger.Debug("SSO configured to use existing fasthttp router")
	s.setupCallbackHandler()
}

// initDB is no longer needed as migrations are handled by the repository.

// GetCharacter returns a valid character object for the given character name.
// If no token exists, it starts the OAuth flow. If the token is expired or
// expires within a minute, it is refreshed; if the refresh fails, the OAuth
// flow is run again.
func (s *SSO) GetCharacter(ctx context.Context, characterName string) (types.Character, error) {
	logger.Debug("Getting token for character %s", characterName)
	s.mu.Lock()
	defer s.mu.Unlock()

	// Try to get an existing token from the DB.
	char, err := s.characterRepo.GetCharacterByName(characterName)
	switch {
	case err == nil:
		logger.Debug("Found existing token for character %s, expires at %v", characterName, char.ExpiresAt)
		s.backfillCharacterID(char, characterName)
	case errors.Is(err, gorm.ErrRecordNotFound):
		logger.Info("No existing token found for character %s, starting authentication flow", characterName)
		char, err = s.authenticateAndLoad(ctx, characterName)
		if err != nil {
			return types.Character{}, err
		}
		logger.Info("Successfully authenticated character %s", characterName)
	default:
		return types.Character{}, err
	}

	// Refresh one minute early so callers never receive a token that is about
	// to expire.
	if !time.Now().After(char.ExpiresAt.Add(-1 * time.Minute)) {
		logger.Debug("Token for character %s is still valid", characterName)
	} else {
		logger.Info("Token for character %s is expired or expiring soon, refreshing", characterName)
		if err := s.refreshToken(ctx, char); err != nil {
			logger.Warning("Token refresh failed for character %s, re-authenticating: %v", characterName, err)
			char, err = s.authenticateAndLoad(ctx, characterName)
			if err != nil {
				return types.Character{}, err
			}
			logger.Info("Successfully re-authenticated character %s", characterName)
		} else {
			logger.Debug("Token refreshed successfully for character %s", characterName)
		}
	}

	logger.Debug("Returning valid token for character %s", characterName)
	return *char, nil
}

// authenticateAndLoad runs the interactive authentication flow and then
// re-reads the freshly stored character from the repository.
func (s *SSO) authenticateAndLoad(ctx context.Context, characterName string) (*types.Character, error) {
	if err := s.startAuthFlow(ctx, characterName); err != nil {
		return nil, err
	}
	return s.characterRepo.GetCharacterByName(characterName)
}

// backfillCharacterID fills in a zero character ID (old record) from the
// stored access token and persists it. Failures are logged, not returned.
func (s *SSO) backfillCharacterID(char *types.Character, characterName string) {
	if char.ID != 0 {
		return
	}
	logger.Info("Character ID missing for %s, extracting from token", characterName)
	_, eveCharID := parseTokenCharacter(char.AccessToken)
	if eveCharID <= 0 {
		logger.Warning("Failed to extract character ID from token for %s", characterName)
		return
	}
	char.ID = eveCharID
	logger.Debug("Updating character %s with ID: %d", characterName, eveCharID)
	if err := s.characterRepo.SaveCharacter(char); err != nil {
		logger.Warning("Failed to update character %s with ID: %v", characterName, err)
	}
}

// startAuthFlow prints a login URL, waits for the browser callback, exchanges
// the authorization code for tokens and stores them under characterName.
func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
	logger.Info("Starting authentication flow for character %s", characterName)

	// Without a router the callback can never arrive, so fail before asking
	// the user to log in.
	if s.router == nil {
		logger.Error("No router configured for callback handling")
		return errors.New("no router configured for callback handling")
	}
	logger.Debug("Using fasthttp router for callback handling")

	logger.Debug("Generating PKCE parameters")
	verifier, challenge, err := generatePKCE()
	if err != nil {
		return err
	}

	state := randString(24)
	if state == "" {
		// A missing state would disable CSRF protection for this flow.
		return errors.New("generating state parameter failed")
	}
	s.flowMu.Lock()
	s.state = state
	s.callbackChan = make(chan callbackResult, 1)
	s.flowMu.Unlock()

	authURL := s.buildAuthURL(challenge, state)
	logger.Info("Generated authentication URL for character %s", characterName)
	logger.Info("Please visit this URL to authenticate: \n%s", authURL)
	logger.Info("Waiting for authentication...")

	logger.Debug("Waiting for authentication callback")
	code, receivedState, err := s.waitForCallbackContext(ctx)
	if err != nil {
		return err
	}
	if receivedState != state {
		// The state values are deliberately not logged.
		logger.Error("Invalid state parameter received")
		return errors.New("invalid state parameter")
	}
	logger.Debug("Received valid callback, exchanging code for token")

	char, err := s.exchangeCodeForToken(ctx, code, verifier)
	if err != nil {
		return err
	}

	// Store under the requested name, keyed by the EVE character ID taken
	// from the token.
	char.CharacterName = characterName
	_, eveCharID := parseTokenCharacter(char.AccessToken)
	char.ID = eveCharID
	logger.Debug("Saving token to database for character %s (EVE ID: %d)", characterName, eveCharID)
	if err := s.characterRepo.SaveCharacter(char); err != nil {
		return err
	}

	logger.Info("Authentication flow completed successfully for character %s", characterName)
	return nil
}

// buildAuthURL returns the authorize URL carrying the client ID, redirect
// URI, scopes, state and the S256 PKCE challenge.
func (s *SSO) buildAuthURL(challenge, state string) string {
	q := url.Values{}
	q.Set("response_type", "code")
	q.Set("client_id", s.clientID)
	q.Set("redirect_uri", s.redirectURI)
	if len(s.scopes) > 0 {
		q.Set("scope", strings.Join(s.scopes, " "))
	}
	q.Set("state", state)
	q.Set("code_challenge", challenge)
	q.Set("code_challenge_method", "S256")
	return issuerAuthorizeURL + "?" + q.Encode()
}

// setupCallbackHandler registers handleCallback as a GET route on the path
// component of the redirect URI. Problems are logged and the handler is not
// registered.
func (s *SSO) setupCallbackHandler() {
	if s.router == nil {
		logger.Error("No router configured for callback handling")
		return
	}
	u, err := url.Parse(s.redirectURI)
	if err != nil {
		logger.Error("Failed to parse redirect URI for callback handler: %v", err)
		return
	}
	if u.Path == "" {
		logger.Error("Redirect URI %s has no path to register a callback handler on", s.redirectURI)
		return
	}
	logger.Debug("Setting up callback handler on path: %s", u.Path)
	s.router.GET(u.Path, s.handleCallback)
}

// handleCallback adapts a fasthttp request to processCallback.
func (s *SSO) handleCallback(ctx *fasthttp.RequestCtx) {
	s.processCallback(
		ctx.IsGet(),
		string(ctx.QueryArgs().Peek("code")),
		string(ctx.QueryArgs().Peek("state")),
		func(status int, body string) {
			ctx.SetStatusCode(status)
			ctx.WriteString(body)
		},
		func(contentType string) {
			ctx.SetContentType(contentType)
		},
	)
}

// currentFlow returns the state and result channel of the most recently
// started authentication flow.
func (s *SSO) currentFlow() (string, chan callbackResult) {
	s.flowMu.Lock()
	defer s.flowMu.Unlock()
	return s.state, s.callbackChan
}

// processCallback validates an OAuth callback, writes the HTTP response via
// the supplied closures, and hands the outcome to the waiting flow.
//
// The hand-off never blocks: when no flow is active or a result is already
// queued, the outcome is dropped with a warning so the handler cannot hang.
func (s *SSO) processCallback(isGet bool, code, state string, writeResponse func(int, string), setContentType func(string)) {
	expectedState, results := s.currentFlow()
	deliver := func(r callbackResult) {
		select {
		case results <- r:
		default:
			logger.Warning("Dropping SSO callback result: no authentication flow is waiting for it")
		}
	}

	if !isGet {
		logger.Warning("Invalid callback method")
		writeResponse(http.StatusMethodNotAllowed, "Method not allowed")
		deliver(callbackResult{err: errors.New("method not allowed")})
		return
	}

	if code == "" || state == "" || state != expectedState {
		// Only presence/match is logged; the code and state are secrets.
		logger.Error("Invalid SSO response: code_present=%t, state_present=%t, state_match=%t",
			code != "", state != "", state == expectedState)
		writeResponse(http.StatusBadRequest, "Invalid SSO response")
		deliver(callbackResult{err: errors.New("invalid state")})
		return
	}

	logger.Info("Received valid callback with authorization code")
	setContentType("text/html")
	writeResponse(http.StatusOK, "\nYou can close this window.\n")
	deliver(callbackResult{code: code, state: state})
}

// waitForCallback waits for the callback result or the 30 second timeout.
// It is kept for callers that have no context to pass.
func (s *SSO) waitForCallback() (code, state string, err error) {
	return s.waitForCallbackContext(context.Background())
}

// waitForCallbackContext waits for the callback result, the 30 second
// timeout, or cancellation of ctx, whichever comes first.
func (s *SSO) waitForCallbackContext(ctx context.Context) (code, state string, err error) {
	logger.Debug("Waiting for authentication callback (timeout: 30s)")
	if ctx == nil {
		// Tolerate a nil context here so the behaviour matches the
		// context-less waitForCallback instead of panicking.
		ctx = context.Background()
	}
	_, results := s.currentFlow()

	timer := time.NewTimer(callbackTimeout)
	defer timer.Stop()

	select {
	case result := <-results:
		if result.err != nil {
			logger.Error("Callback received with error: %v", result.err)
			return "", "", result.err
		}
		logger.Debug("Callback received successfully")
		return result.code, result.state, nil
	case <-timer.C:
		logger.Error("Callback timeout after 30 seconds")
		return "", "", errors.New("callback timeout")
	case <-ctx.Done():
		logger.Error("Context done while waiting for callback: %v", ctx.Err())
		return "", "", ctx.Err()
	}
}

// requestToken posts form to the token endpoint and decodes the response.
// op ("exchange" or "refresh") only labels log and error messages. A non-2xx
// status or a response without an access token is returned as an error.
func requestToken(ctx context.Context, op string, form url.Values) (*tokenResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, issuerTokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Bound the request even if the caller's context has no deadline; the
	// caller holds s.mu while this runs.
	reqCtx, cancel := context.WithTimeout(ctx, tokenRequestTimeout)
	defer cancel()
	req = req.WithContext(reqCtx)

	logger.Debug("Sending token %s request to %s", op, issuerTokenURL)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// The body is diagnostic only, so a read error is ignored.
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		logger.Error("Token %s failed with status %d: %s", op, resp.StatusCode, string(b))
		return nil, fmt.Errorf("token %s failed: %s: %s", op, resp.Status, string(b))
	}

	var tr tokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		return nil, err
	}
	if tr.AccessToken == "" {
		return nil, fmt.Errorf("token %s response did not contain an access token", op)
	}
	return &tr, nil
}

// exchangeCodeForToken trades an authorization code and PKCE verifier for
// tokens. The returned character carries the name from the access token and
// an expiry 30 seconds earlier than the one reported by the server.
func (s *SSO) exchangeCodeForToken(ctx context.Context, code, verifier string) (*types.Character, error) {
	logger.Debug("Exchanging authorization code for access token")
	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("client_id", s.clientID)
	form.Set("code_verifier", verifier)

	tr, err := requestToken(ctx, "exchange", form)
	if err != nil {
		return nil, err
	}

	name, _ := parseTokenCharacter(tr.AccessToken)
	logger.Info("Successfully exchanged code for token, character: %s", name)

	return &types.Character{
		CharacterName: name,
		AccessToken:   tr.AccessToken,
		RefreshToken:  tr.RefreshToken,
		ExpiresAt:     time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second),
	}, nil
}

// refreshToken obtains a new access token using char's refresh token, updates
// char in place and saves it. char is left untouched when the token request
// itself fails.
func (s *SSO) refreshToken(ctx context.Context, char *types.Character) error {
	logger.Debug("Refreshing token for character %s", char.CharacterName)
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", char.RefreshToken)
	form.Set("client_id", s.clientID)

	tr, err := requestToken(ctx, "refresh", form)
	if err != nil {
		return err
	}

	char.AccessToken = tr.AccessToken
	// Keep the old refresh token / expiry when the response omits them.
	if tr.RefreshToken != "" {
		char.RefreshToken = tr.RefreshToken
	}
	if tr.ExpiresIn > 0 {
		char.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second)
	}

	logger.Debug("Saving refreshed token to database for character %s", char.CharacterName)
	if err := s.characterRepo.SaveCharacter(char); err != nil {
		return err
	}
	logger.Info("Token refreshed successfully for character %s", char.CharacterName)
	return nil
}

// Utility functions

// generatePKCE returns a random PKCE verifier (32 random bytes, base64url
// without padding) and its S256 challenge (base64url of the verifier's
// SHA-256 digest).
func generatePKCE() (verifier string, challenge string, err error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", "", err
	}
	verifier = base64.RawURLEncoding.EncodeToString(buf)
	sum := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(sum[:])
	return verifier, challenge, nil
}

// randString returns n random bytes encoded as unpadded base64url. It returns
// the empty string if the random source fails, so callers never receive a
// predictable value.
func randString(n int) string {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		logger.Error("Failed to read random bytes: %v", err)
		return ""
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}

// parseTokenCharacter decodes the payload of a JWT and returns its "name"
// claim and the integer after the last ':' of its "sub" claim. The signature
// is not verified. Anything that cannot be parsed yields "" and/or 0.
func parseTokenCharacter(jwt string) (name string, id int64) {
	parts := strings.Split(jwt, ".")
	if len(parts) != 3 {
		return "", 0
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return "", 0
	}
	var claims map[string]any
	if err := json.Unmarshal(payload, &claims); err != nil {
		return "", 0
	}

	if v, ok := claims["name"].(string); ok {
		name = v
	}
	sub, ok := claims["sub"].(string)
	if !ok {
		return name, 0
	}
	idx := strings.LastIndexByte(sub, ':')
	if idx < 0 {
		return name, 0
	}
	parsed, err := strconv.ParseInt(sub[idx+1:], 10, 64)
	if err != nil {
		return name, 0
	}
	return name, parsed
}

// tokenResponse mirrors the JSON body returned by the token endpoint.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
}