Files
go-eve-pi/esi_sso.go

498 lines
14 KiB
Go

package main
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"gorm.io/gorm"
logger "git.site.quack-lab.dev/dave/cylogger"
)
// EVE Online SSO OAuth2 (v2) endpoints used by the login and refresh flows.
const (
	// issuerAuthorizeURL is the page the user opens to log in and grant scopes.
	issuerAuthorizeURL = "https://login.eveonline.com/v2/oauth/authorize"

	// issuerTokenURL is POSTed to when exchanging an authorization code or a
	// refresh token for a new access token.
	issuerTokenURL = "https://login.eveonline.com/v2/oauth/token"
)
// Token is the persisted OAuth token set for a single character.
// CharacterName carries a unique index, so the table holds at most one row
// per character name.
type Token struct {
	ID            uint   `gorm:"primaryKey"`
	CharacterName string `gorm:"uniqueIndex"`
	AccessToken   string
	RefreshToken  string
	// ExpiresAt is computed by this file slightly ahead of the lifetime the
	// token endpoint reports, so refreshes happen before the real expiry.
	ExpiresAt time.Time
	UpdatedAt time.Time
	CreatedAt time.Time
}
// callbackResult is what the OAuth callback handler passes to the waiting
// authentication flow. It is a type alias, not a new named type, so it is
// identical to the anonymous struct literal used elsewhere with callbackChan.
type callbackResult = struct {
	code  string
	state string
	err   error
}

// SSO drives EVE Online SSO logins (authorization code + PKCE) and keeps the
// resulting per-character tokens in the database.
type SSO struct {
	clientID    string
	redirectURI string
	scopes      []string
	db          DB

	// mu is held by GetToken for its whole duration, so token lookups,
	// refreshes and interactive logins run one at a time.
	mu sync.Mutex

	// server records the dedicated callback server started by the most recent
	// flow that ran without a shared muxer; mux is the muxer set by SetMuxer.
	server *http.Server
	mux    *http.ServeMux

	// state and callbackChan are recreated at the start of every
	// authentication flow and read by the callback handler.
	state        string
	callbackChan chan callbackResult
}
// NewSSO builds an SSO client for the given application client ID, redirect
// URI and requested scopes. It takes its database handle from GetDB and runs
// the Token schema migration (initDB) before handing the instance back.
func NewSSO(clientID, redirectURI string, scopes []string) (*SSO, error) {
	logger.Info("Creating new SSO instance for clientID %s with redirectURI %s and scopes %v", clientID, redirectURI, scopes)

	db, err := GetDB()
	if err != nil {
		logger.Error("Failed to get database connection %v", err)
		return nil, err
	}

	sso := &SSO{clientID: clientID, redirectURI: redirectURI, scopes: scopes, db: db}
	if err = sso.initDB(); err != nil {
		logger.Error("Failed to initialize SSO database %v", err)
		return nil, err
	}

	logger.Info("SSO instance created successfully")
	return sso, nil
}
// SetMuxer makes subsequent authentication flows register their callback
// handler on mux rather than starting a dedicated callback server.
//
// NOTE(review): nothing here listens on mux; presumably the caller already
// serves it on the redirect URI's host/port — confirm at the call site.
func (s *SSO) SetMuxer(mux *http.ServeMux) {
	s.mux = mux
	logger.Debug("SSO configured to use existing HTTP muxer")
}
// initDB runs the database auto-migration for the Token model and returns the
// migration error, if any.
func (s *SSO) initDB() error {
	logger.Debug("Initializing SSO database schema")
	if migrateErr := s.db.AutoMigrate(&Token{}); migrateErr != nil {
		logger.Error("Failed to migrate Token table %v", migrateErr)
		return migrateErr
	}
	logger.Debug("SSO database schema initialized successfully")
	return nil
}
// GetToken returns an access token for characterName.
//
//   - If no token is stored, it runs the interactive login (startAuthFlow) and
//     then loads the token that flow saved.
//   - If the stored token expires within 60 seconds (or already has), it is
//     refreshed; when the refresh fails, a fresh interactive login is run.
//
// The entire call holds s.mu, so concurrent callers are serialised, including
// while an interactive login waits for the browser callback.
func (s *SSO) GetToken(ctx context.Context, characterName string) (string, error) {
	logger.Debug("Getting token for character %s", characterName)
	s.mu.Lock()
	defer s.mu.Unlock()

	token, lookupErr := s.db.GetTokenForCharacter(characterName)
	switch {
	case lookupErr == nil:
		logger.Debug("Found existing token for character %s, expires at %v", characterName, token.ExpiresAt)
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		logger.Info("No existing token found for character %s, starting authentication flow", characterName)
		if flowErr := s.startAuthFlow(ctx, characterName); flowErr != nil {
			logger.Error("Authentication flow failed for character %s: %v", characterName, flowErr)
			return "", flowErr
		}
		reloaded, reloadErr := s.db.GetTokenForCharacter(characterName)
		if reloadErr != nil {
			logger.Error("Failed to fetch token after authentication for character %s: %v", characterName, reloadErr)
			return "", reloadErr
		}
		token = reloaded
		logger.Info("Successfully authenticated character %s", characterName)
	default:
		logger.Error("Database error when fetching token for character %s: %v", characterName, lookupErr)
		return "", lookupErr
	}

	// Refresh a minute early so the returned token is not about to lapse.
	needsRefresh := time.Now().After(token.ExpiresAt.Add(-60 * time.Second))
	if needsRefresh {
		logger.Info("Token for character %s is expired or expiring soon, refreshing", characterName)
		refreshErr := s.refreshToken(ctx, token)
		if refreshErr == nil {
			logger.Debug("Token refreshed successfully for character %s", characterName)
		} else {
			logger.Warning("Token refresh failed for character %s, re-authenticating: %v", characterName, refreshErr)
			if flowErr := s.startAuthFlow(ctx, characterName); flowErr != nil {
				logger.Error("Re-authentication failed for character %s: %v", characterName, flowErr)
				return "", flowErr
			}
			reloaded, reloadErr := s.db.GetTokenForCharacter(characterName)
			if reloadErr != nil {
				logger.Error("Failed to fetch token after re-authentication for character %s: %v", characterName, reloadErr)
				return "", reloadErr
			}
			token = reloaded
			logger.Info("Successfully re-authenticated character %s", characterName)
		}
	} else {
		logger.Debug("Token for character %s is still valid", characterName)
	}

	logger.Debug("Returning valid token for character %s", characterName)
	return token.AccessToken, nil
}
// startAuthFlow runs one interactive OAuth2 authorization-code + PKCE login
// and stores the resulting token in the database under characterName.
//
// Steps:
//  1. Generate PKCE values and a fresh state.
//  2. Log the authorization URL for the user.
//  3. Receive the redirect, either on the muxer set by SetMuxer or on a
//     temporary callback server.
//  4. Check the returned state.
//  5. Exchange the code and save the token.
//
// GetToken calls this while holding s.mu. s.state and s.callbackChan are
// per-flow fields on the shared SSO value.
func (s *SSO) startAuthFlow(ctx context.Context, characterName string) error {
	logger.Info("Starting authentication flow for character %s", characterName)

	// Generate PKCE
	logger.Debug("Generating PKCE parameters")
	verifier, challenge, err := generatePKCE()
	if err != nil {
		logger.Error("Failed to generate PKCE parameters: %v", err)
		return err
	}

	// Fresh per-flow state and a one-slot result channel; handleCallback reads
	// both when the browser is redirected back.
	s.state = randString(24)
	s.callbackChan = make(chan struct {
		code  string
		state string
		err   error
	}, 1)

	authURL := s.buildAuthURL(challenge, s.state)
	logger.Info("Generated authentication URL for character %s", characterName)
	logger.Info("Please visit this URL to authenticate: \n%s", authURL)
	logger.Info("Waiting for authentication...")

	// Setup callback handling
	if s.mux != nil {
		logger.Debug("Using existing HTTP muxer for callback handling")
		s.setupCallbackHandler()
	} else {
		logger.Debug("Starting dedicated callback server")
		server, err := s.startCallbackServer()
		if err != nil {
			logger.Error("Failed to start callback server: %v", err)
			return err
		}
		s.server = server
		defer func() {
			// Use an independent, bounded context. Shutdown must not wait
			// forever on a connection that never goes idle when the caller's
			// ctx is never cancelled. It also must not depend on whether that
			// ctx was already cancelled.
			shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			if err := server.Shutdown(shutdownCtx); err != nil {
				logger.Warning("Callback server shutdown did not complete cleanly: %v", err)
			}
			s.server = nil
		}()
	}

	// Wait for callback
	logger.Debug("Waiting for authentication callback")
	code, receivedState, err := s.waitForCallback()
	if err != nil {
		logger.Error("Callback wait failed: %v", err)
		return err
	}
	if receivedState != s.state {
		logger.Error("Invalid state parameter received: %s, expected: %s", receivedState, s.state)
		return errors.New("invalid state parameter")
	}

	logger.Debug("Received valid callback, exchanging code for token")
	// Exchange code for token
	token, err := s.exchangeCodeForToken(ctx, code, verifier)
	if err != nil {
		logger.Error("Failed to exchange code for token: %v", err)
		return err
	}

	// exchangeCodeForToken fills CharacterName from the token's own claims.
	// The account that logged in may not be the requested character. The token
	// is still stored under characterName below, so make a mismatch visible.
	if token.CharacterName != "" && !strings.EqualFold(token.CharacterName, characterName) {
		logger.Warning("Logged in as character %s but storing the token for requested character %s", token.CharacterName, characterName)
	}

	// Save token to DB
	token.CharacterName = characterName
	logger.Debug("Saving token to database for character %s", characterName)
	if err := s.db.SaveTokenForCharacter(token); err != nil {
		logger.Error("Failed to save token to database: %v", err)
		return err
	}
	logger.Info("Authentication flow completed successfully for character %s", characterName)
	return nil
}
// buildAuthURL returns the SSO authorize URL for this client. It carries the
// PKCE S256 challenge and the anti-CSRF state. The scope parameter is omitted
// entirely when no scopes are configured.
func (s *SSO) buildAuthURL(challenge, state string) string {
	params := url.Values{
		"response_type":         {"code"},
		"client_id":             {s.clientID},
		"redirect_uri":          {s.redirectURI},
		"state":                 {state},
		"code_challenge":        {challenge},
		"code_challenge_method": {"S256"},
	}
	if len(s.scopes) > 0 {
		params.Set("scope", strings.Join(s.scopes, " "))
	}
	// Encode sorts by key, so insertion order above does not affect the URL.
	return issuerAuthorizeURL + "?" + params.Encode()
}
// callbackRegistrations remembers which (muxer, path) pairs already carry the
// SSO callback handler. setupCallbackHandler runs on every authentication
// flow, and http.ServeMux panics when the same pattern is registered twice.
var (
	callbackRegistrationsMu sync.Mutex
	callbackRegistrations   = make(map[*http.ServeMux]map[string]struct{})
)

// setupCallbackHandler registers s.handleCallback on the shared muxer (s.mux)
// at the path of s.redirectURI. It does this at most once per muxer/path pair.
// Re-registration is unnecessary: handleCallback reads s.state and
// s.callbackChan at request time, so it always serves the current flow.
//
// An empty redirect path is registered as "/", since "" is not a valid
// ServeMux pattern.
//
// NOTE(review): if two different SSO values share one muxer and path, the
// first one's handler stays registered; the second value's flows will then
// time out instead of panicking as they would without this guard.
func (s *SSO) setupCallbackHandler() {
	u, err := url.Parse(s.redirectURI)
	if err != nil {
		logger.Error("Failed to parse redirect URI for callback handler: %v", err)
		return
	}
	path := u.Path
	if path == "" {
		path = "/"
	}

	callbackRegistrationsMu.Lock()
	defer callbackRegistrationsMu.Unlock()

	paths := callbackRegistrations[s.mux]
	if paths == nil {
		paths = make(map[string]struct{})
		callbackRegistrations[s.mux] = paths
	}
	if _, done := paths[path]; done {
		logger.Debug("Callback handler already registered on path: %s", path)
		return
	}

	logger.Debug("Setting up callback handler on path: %s", path)
	s.mux.HandleFunc(path, s.handleCallback)
	paths[path] = struct{}{}
}
// handleCallback serves the OAuth redirect. It validates the request method
// and the code/state query parameters, then writes a short reply to the
// browser. Finally it forwards the outcome to the flow waiting on
// s.callbackChan.
//
// Forwarding is non-blocking. callbackChan holds a single result and is nil
// before the first flow, so a blocking send would hang this handler forever
// in two cases:
//   - a repeat hit on the callback URL (refresh, retry, or a request after
//     the flow has finished);
//   - a hit before any flow has started.
// A stuck handler also keeps a dedicated callback server's Shutdown waiting.
//
// The authorization code is a credential, so neither it nor the full request
// URL is logged.
func (s *SSO) handleCallback(w http.ResponseWriter, r *http.Request) {
	logger.Debug("Received callback request: %s %s", r.Method, r.URL.Path)

	deliver := func(code, state string, err error) {
		select {
		case s.callbackChan <- struct {
			code  string
			state string
			err   error
		}{code, state, err}:
		default:
			logger.Warning("Dropping SSO callback result: a result is already pending or no authentication flow is active")
		}
	}

	if r.Method != http.MethodGet {
		logger.Warning("Invalid callback method: %s", r.Method)
		w.Header().Set("Allow", http.MethodGet)
		w.WriteHeader(http.StatusMethodNotAllowed)
		deliver("", "", errors.New("method not allowed"))
		return
	}

	q := r.URL.Query()
	code := q.Get("code")
	st := q.Get("state")
	if code == "" || st == "" || st != s.state {
		logger.Error("Invalid SSO response: code_present=%t, state=%s, expected_state=%s", code != "", st, s.state)
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte("Invalid SSO response"))
		deliver("", "", errors.New("invalid state"))
		return
	}

	logger.Info("Received valid SSO callback")
	w.Header().Set("Content-Type", "text/html")
	_, _ = w.Write([]byte("<html><body><h1>Login successful!</h1><p>You can close this window.</p></body></html>"))
	deliver(code, st, nil)
}
// startCallbackServer listens on the host:port of s.redirectURI and serves
// s.handleCallback at its path from a background goroutine. The caller owns
// the returned server and must shut it down.
//
// Port and path handling:
//   - A missing port defaults to 443 for https and 80 for http.
//   - The address is built with net.JoinHostPort, so IPv6 literals such as
//     "[::1]" work.
//   - An empty path is served at "/".
//
// NOTE(review): an https redirect URI only changes the default port. The
// server below speaks plain HTTP, so a browser following an https:// redirect
// to it will fail the TLS handshake.
func (s *SSO) startCallbackServer() (*http.Server, error) {
	logger.Debug("Starting dedicated callback server for redirect URI: %s", s.redirectURI)
	u, err := url.Parse(s.redirectURI)
	if err != nil {
		logger.Error("Failed to parse redirect URI: %v", err)
		return nil, err
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		logger.Error("Invalid redirect URI scheme: %s", u.Scheme)
		return nil, errors.New("redirect URI must be http(s)")
	}

	port := u.Port()
	if port == "" {
		if u.Scheme == "https" {
			port = "443"
		} else {
			port = "80"
		}
	}
	hostPort := net.JoinHostPort(u.Hostname(), port)

	path := u.Path
	if path == "" {
		path = "/"
	}

	logger.Debug("Callback server will listen on %s", hostPort)
	mux := http.NewServeMux()
	mux.HandleFunc(path, s.handleCallback)
	ln, err := net.Listen("tcp", hostPort)
	if err != nil {
		logger.Error("Failed to listen on %s: %v", hostPort, err)
		return nil, err
	}

	server := &http.Server{
		Handler: mux,
		// Bound how long a client may take to send request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		logger.Debug("Callback server started successfully")
		if err := server.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error("Callback server stopped unexpectedly: %v", err)
		}
	}()
	return server, nil
}
// waitForCallback blocks until the callback handler delivers a result on
// s.callbackChan or 30 seconds elapse, whichever comes first. It returns the
// delivered code and state, the delivered error, or a timeout error.
func (s *SSO) waitForCallback() (code, state string, err error) {
	logger.Debug("Waiting for authentication callback (timeout: 30s)")
	deadline := time.NewTimer(30 * time.Second)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		logger.Error("Callback timeout after 30 seconds")
		return "", "", errors.New("callback timeout")
	case got := <-s.callbackChan:
		if got.err != nil {
			logger.Error("Callback received with error: %v", got.err)
			return "", "", got.err
		}
		logger.Debug("Callback received successfully")
		return got.code, got.state, nil
	}
}
// exchangeCodeForToken redeems an authorization code, together with its PKCE
// verifier, at the SSO token endpoint.
//
// On success it returns a Token whose CharacterName comes from the access
// token's "name" claim, on a best-effort basis ("" if it cannot be parsed).
// ExpiresAt is set 30 seconds before the reported lifetime.
//
// It returns an error in these cases:
//   - the request cannot be built or sent;
//   - the status is not 2xx (the error includes at most 4 KiB of the body);
//   - the JSON cannot be decoded;
//   - the response carries no access token.
//
// The request is bounded by a 30-second timeout on top of ctx, because
// http.DefaultClient has none of its own.
func (s *SSO) exchangeCodeForToken(ctx context.Context, code, verifier string) (*Token, error) {
	const tokenRequestTimeout = 30 * time.Second
	const maxErrorBody = 4 << 10

	logger.Debug("Exchanging authorization code for access token")
	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("client_id", s.clientID)
	form.Set("code_verifier", verifier)

	ctx, cancel := context.WithTimeout(ctx, tokenRequestTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, issuerTokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		logger.Error("Failed to create token exchange request: %v", err)
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	logger.Debug("Sending token exchange request to %s", issuerTokenURL)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		logger.Error("Token exchange request failed: %v", err)
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody)) // body is diagnostic only
		logger.Error("Token exchange failed with status %d: %s", resp.StatusCode, string(b))
		return nil, fmt.Errorf("token exchange failed: %s: %s", resp.Status, string(b))
	}

	var tr tokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		logger.Error("Failed to decode token response: %v", err)
		return nil, err
	}
	if tr.AccessToken == "" {
		logger.Error("Token exchange response did not contain an access token")
		return nil, errors.New("token exchange failed: empty access token")
	}

	// Parse character info from token
	name, _ := parseTokenCharacter(tr.AccessToken)
	logger.Info("Successfully exchanged code for token, character: %s", name)
	return &Token{
		CharacterName: name,
		AccessToken:   tr.AccessToken,
		RefreshToken:  tr.RefreshToken,
		ExpiresAt:     time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second),
	}, nil
}
// refreshToken uses token.RefreshToken to obtain a new access token, updates
// *token in place, and saves it.
//
// Update rules:
//   - The refresh token is replaced only if the server returns a new one.
//   - ExpiresAt changes only when a positive lifetime is reported; it is then
//     set 30 seconds early.
//   - *token is mutated only after the response is known to contain an access
//     token, so a malformed reply cannot blank out the stored credentials.
//
// The request is bounded by a 30-second timeout on top of ctx, because
// http.DefaultClient has none of its own.
func (s *SSO) refreshToken(ctx context.Context, token *Token) error {
	const tokenRequestTimeout = 30 * time.Second
	const maxErrorBody = 4 << 10

	logger.Debug("Refreshing token for character %s", token.CharacterName)
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", token.RefreshToken)
	form.Set("client_id", s.clientID)

	ctx, cancel := context.WithTimeout(ctx, tokenRequestTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, issuerTokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		logger.Error("Failed to create token refresh request: %v", err)
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	logger.Debug("Sending token refresh request to %s", issuerTokenURL)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		logger.Error("Token refresh request failed: %v", err)
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody)) // body is diagnostic only
		logger.Error("Token refresh failed with status %d: %s", resp.StatusCode, string(b))
		return fmt.Errorf("token refresh failed: %s: %s", resp.Status, string(b))
	}

	var tr tokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil {
		logger.Error("Failed to decode refresh token response: %v", err)
		return err
	}
	if tr.AccessToken == "" {
		logger.Error("Token refresh response for character %s did not contain an access token", token.CharacterName)
		return errors.New("token refresh failed: empty access token")
	}

	// Update token
	token.AccessToken = tr.AccessToken
	if tr.RefreshToken != "" {
		token.RefreshToken = tr.RefreshToken
	}
	if tr.ExpiresIn > 0 {
		token.ExpiresAt = time.Now().Add(time.Duration(tr.ExpiresIn-30) * time.Second)
	}

	logger.Debug("Saving refreshed token to database for character %s", token.CharacterName)
	if err := s.db.SaveTokenForCharacter(token); err != nil {
		logger.Error("Failed to save refreshed token: %v", err)
		return err
	}
	logger.Info("Token refreshed successfully for character %s", token.CharacterName)
	return nil
}
// Utility functions

// generatePKCE creates a PKCE pair for the S256 method.
//   - verifier: 32 random bytes as unpadded base64url (43 characters).
//   - challenge: the unpadded base64url SHA-256 digest of the verifier.
//
// Both strings are empty when the random source fails.
func generatePKCE() (verifier string, challenge string, err error) {
	entropy := make([]byte, 32)
	if _, readErr := rand.Read(entropy); readErr != nil {
		return "", "", readErr
	}
	verifier = base64.RawURLEncoding.EncodeToString(entropy)
	digest := sha256.Sum256([]byte(verifier))
	challenge = base64.RawURLEncoding.EncodeToString(digest[:])
	return verifier, challenge, nil
}
// randString returns n cryptographically random bytes encoded as unpadded
// base64url. The result is therefore ceil(4n/3) characters long, not n.
//
// It panics if the system random source fails. The previous behaviour, which
// ignored the error, could silently hand out a predictable all-zero OAuth
// state value. Current Go releases make crypto/rand.Read abort the program on
// failure anyway.
func randString(n int) string {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("crypto/rand unavailable: %v", err))
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}
// parseTokenCharacter reads the JWT payload without verifying the signature.
//   - name: the "name" claim.
//   - id: the integer after the last ':' of the "sub" claim.
//
// Each value falls back to its zero value when it is missing or malformed.
// A token that is not three dot-separated segments, or whose payload is not
// unpadded base64url JSON, yields ("", 0).
func parseTokenCharacter(jwt string) (name string, id int64) {
	segments := strings.Split(jwt, ".")
	if len(segments) != 3 {
		return "", 0
	}
	raw, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return "", 0
	}
	var claims map[string]any
	if json.Unmarshal(raw, &claims) != nil {
		return "", 0
	}

	name, _ = claims["name"].(string)

	sub, isString := claims["sub"].(string)
	if !isString {
		return name, 0
	}
	sep := strings.LastIndexByte(sub, ':')
	if sep < 0 {
		return name, 0
	}
	parsed, convErr := strconv.ParseInt(sub[sep+1:], 10, 64)
	if convErr != nil {
		return name, 0
	}
	return name, parsed
}
// tokenResponse is the JSON body returned by the SSO token endpoint. This file
// decodes it for both the authorization_code and refresh_token grants.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"` // access-token lifetime, in seconds
}