Hypercode/alex/hypercodePublic

Code

  1. hypercode
  2. controllers
  3. device_auth_controller.go
device_auth_controller.go274 lines
package controllers

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math/big"
	"net/http"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/hypercodehq/hypercode/database/repositories"
	"github.com/hypercodehq/hypercode/httperror"
	"github.com/hypercodehq/hypercode/middleware"
	"github.com/hypercodehq/hypercode/views/pages"
)

type DeviceAuthController interface {
	InitiateDeviceAuth(w http.ResponseWriter, r *http.Request) error
	PollDeviceAuth(w http.ResponseWriter, r *http.Request) error
	ShowDeviceAuthPage(w http.ResponseWriter, r *http.Request) error
	ConfirmDeviceAuth(w http.ResponseWriter, r *http.Request) error
}

type deviceAuthController struct {
	sessions     repositories.DeviceAuthSessionsRepository
	accessTokens repositories.AccessTokensRepository
	users        repositories.UsersRepository
}

func NewDeviceAuthController(
	sessions repositories.DeviceAuthSessionsRepository,
	accessTokens repositories.AccessTokensRepository,
	users repositories.UsersRepository,
) DeviceAuthController {
	return &deviceAuthController{
		sessions:     sessions,
		accessTokens: accessTokens,
		users:        users,
	}
}

// generateDeviceCode generates a user-friendly device code
func (c *deviceAuthController) generateDeviceCode() (string, error) {
	const charset = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" // Exclude ambiguous characters
	const codeLength = 8

	code := make([]byte, codeLength)
	for i := range code {
		num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
		if err != nil {
			return "", err
		}
		code[i] = charset[num.Int64()]
	}

	// Format as XXXX-XXXX
	return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
}

// generateAccessToken generates a cryptographically secure access token
func (c *deviceAuthController) generateAccessToken() (string, string, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", "", err
	}

	rawToken := base64.URLEncoding.EncodeToString(b)
	hash := sha256.Sum256([]byte(rawToken))
	tokenHash := fmt.Sprintf("%x", hash)

	return rawToken, tokenHash, nil
}

// InitiateDeviceAuth creates a new device auth session
func (c *deviceAuthController) InitiateDeviceAuth(w http.ResponseWriter, r *http.Request) error {
	sessionID := uuid.New().String()
	code, err := c.generateDeviceCode()
	if err != nil {
		return err
	}

	// Session expires in 10 minutes
	expiresAt := time.Now().Add(10 * time.Minute).Unix()

	session, err := c.sessions.Create(sessionID, code, expiresAt)
	if err != nil {
		return err
	}

	// Return JSON response with session details
	verificationURL := fmt.Sprintf("%s://%s/auth/device?code=%s",
		func() string {
			if r.TLS != nil {
				return "https"
			}
			return "http"
		}(),
		r.Host,
		code,
	)

	response := map[string]interface{}{
		"session_id":        session.ID,
		"user_code":         session.Code,
		"verification_url":  verificationURL,
		"expires_at":        session.ExpiresAt,
		"interval":          1, // Poll every 1 second
	}

	w.Header().Set("Content-Type", "application/json")
	return json.NewEncoder(w).Encode(response)
}

// PollDeviceAuth checks the status of a device auth session
func (c *deviceAuthController) PollDeviceAuth(w http.ResponseWriter, r *http.Request) error {
	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" {
		return httperror.New(http.StatusBadRequest, "session_id is required")
	}

	session, err := c.sessions.FindByID(sessionID)
	if err != nil {
		return err
	}

	if session == nil {
		return httperror.New(http.StatusNotFound, "Session not found")
	}

	// Check if expired
	if time.Now().Unix() > session.ExpiresAt {
		_ = c.sessions.UpdateStatus(sessionID, "expired")
		response := map[string]interface{}{
			"status": "expired",
			"error":  "The device code has expired. Please start a new authentication flow.",
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusGone)
		return json.NewEncoder(w).Encode(response)
	}

	// Return current status
	response := map[string]interface{}{
		"status": session.Status,
	}

	if session.Status == "confirmed" && session.AccessToken != nil && session.UserID != nil {
		// Fetch username
		user, err := c.users.FindByID(*session.UserID)
		if err != nil {
			return err
		}

		response["access_token"] = *session.AccessToken
		response["username"] = user.Username
	}

	w.Header().Set("Content-Type", "application/json")
	return json.NewEncoder(w).Encode(response)
}

// ShowDeviceAuthPage shows the device confirmation page
func (c *deviceAuthController) ShowDeviceAuthPage(w http.ResponseWriter, r *http.Request) error {
	code := r.URL.Query().Get("code")
	if code != "" {
		// Normalize code (remove hyphens and convert to uppercase)
		code = strings.ToUpper(strings.ReplaceAll(code, "-", ""))
		// Add hyphen back in the middle for display
		if len(code) == 8 {
			code = fmt.Sprintf("%s-%s", code[:4], code[4:])
		}
	}

	user := middleware.GetUserFromContext(r)

	data := &pages.DeviceAuthData{
		User: user,
		Code: code,
	}

	return pages.DeviceAuth(r, data).Render(w, r)
}

// ConfirmDeviceAuth confirms a device auth session
func (c *deviceAuthController) ConfirmDeviceAuth(w http.ResponseWriter, r *http.Request) error {
	user := middleware.GetUserFromContext(r)
	if user == nil {
		http.Redirect(w, r, "/auth/sign-in", http.StatusSeeOther)
		return nil
	}

	if err := r.ParseForm(); err != nil {
		return httperror.New(http.StatusBadRequest, "Invalid form data")
	}

	code := r.FormValue("code")
	if code == "" {
		return httperror.New(http.StatusBadRequest, "Code is required")
	}

	// Normalize code
	code = strings.ToUpper(strings.ReplaceAll(code, "-", ""))
	if len(code) == 8 {
		code = fmt.Sprintf("%s-%s", code[:4], code[4:])
	}

	// Find session by code
	session, err := c.sessions.FindByCode(code)
	if err != nil {
		return err
	}

	if session == nil {
		data := &pages.DeviceAuthData{
			User:  user,
			Code:  code,
			Error: "Invalid or expired code. Please try again.",
		}
		return pages.DeviceAuth(r, data).Render(w, r)
	}

	// Check if expired
	if time.Now().Unix() > session.ExpiresAt {
		_ = c.sessions.UpdateStatus(session.ID, "expired")
		data := &pages.DeviceAuthData{
			User:  user,
			Code:  code,
			Error: "This code has expired. Please start a new authentication flow.",
		}
		return pages.DeviceAuth(r, data).Render(w, r)
	}

	// Check if already confirmed
	if session.Status == "confirmed" {
		data := &pages.DeviceAuthData{
			User:    user,
			Code:    code,
			Success: true,
		}
		return pages.DeviceAuth(r, data).Render(w, r)
	}

	// Generate access token
	rawToken, tokenHash, err := c.generateAccessToken()
	if err != nil {
		return err
	}

	// Create access token in database
	tokenName := fmt.Sprintf("CLI Device Auth - %s", time.Now().Format("2006-01-02 15:04:05"))
	_, err = c.accessTokens.Create(user.ID, tokenName, tokenHash)
	if err != nil {
		return err
	}

	// Confirm the session
	err = c.sessions.Confirm(session.ID, user.ID, rawToken)
	if err != nil {
		return err
	}

	// Show success page
	data := &pages.DeviceAuthData{
		User:    user,
		Code:    code,
		Success: true,
	}
	return pages.DeviceAuth(r, data).Render(w, r)
}