// Package middleware provides a small proof-of-work puzzle that users solve
// before accessing protected pages or APIs, plus transparent reverse-proxy
// support. It issues HMAC-signed tokens bound to IP/browser, stores them in
// BadgerDB, and automatically cleans up expired data.
package middleware

import (
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"bytes"
	"encoding/gob"
	"html/template"

	"github.com/dgraph-io/badger/v4"
	"github.com/gofiber/fiber/v2"
	"github.com/mileusna/useragent"
)

// --- Configuration ---

// Config holds all configurable parameters for the Checkpoint middleware.
// init loads it via LoadConfig("checkpoint", ...) and installs it with SetConfig.
type Config struct {
	// General Settings
	Difficulty      int           // Number of leading zeros for PoW hash
	TokenExpiration time.Duration // Validity period for issued tokens
	CookieName      string        // Name of the cookie used to store tokens
	CookieDomain    string        // Domain scope for the cookie (e.g., ".example.com" for subdomains); empty = auto-detect from the request host
	SaltLength      int           // Length of the salt used in challenges

	// Rate Limiting & Expiration
	MaxAttemptsPerHour  int           // Max requests counted in ipRateLimit per IP before the hourly reset
	MaxNonceAge         time.Duration // Max age for used nonces before cleanup
	ChallengeExpiration time.Duration // Time limit for solving a challenge

	// File Paths
	SecretConfigPath  string   // Path to the persistent HMAC secret file
	TokenStoreDBPath  string   // Directory path for the BadgerDB token store (read once, in init)
	InterstitialPaths []string // Paths to search, in order, for the interstitial HTML page

	// Security Settings
	CheckPoSTimes                    bool            // Enable Proof-of-Space-Time consistency checks
	PoSTimeConsistencyRatio          float64         // Allowed ratio between slowest and fastest PoS runs
	HTMLCheckpointExclusions         []string        // Path prefixes to exclude from HTML checkpoint
	HTMLCheckpointExcludedExtensions map[string]bool // File extensions to exclude (lowercase, including the leading '.')
	DangerousQueryPatterns           []*regexp.Regexp // Regex patterns to block in query strings
	BlockDangerousPathChars          bool             // Block paths containing potentially dangerous characters (;, `)

	// User Agent validation settings
	UserAgentValidationExclusions []string          // Path prefixes to skip UA validation
	UserAgentRequiredPrefixes     map[string]string // Path prefix -> required UA prefix
	// Note: Binding to IP, User Agent, and Browser Hint is always enabled.

	// Reverse Proxy Settings
	ReverseProxyMappings map[string]string // Map of hostname to backend URL (e.g., "app.example.com": "http://127.0.0.1:8080")
}

var (
	// Global configuration instance; written by SetConfig and read by the
	// handlers without locking.
	checkpointConfig Config
	// Secret key used for HMAC verification - loaded from SecretConfigPath or
	// generated by initSecret during init.
	hmacSecret []byte
	// Used nonces to prevent replay attacks (nonce -> timestamp); entries older
	// than Config.MaxNonceAge are purged by cleanupExpiredData.
	usedNonces sync.Map // map[string]time.Time
	// Per-client-IP request counters. GetCheckpointChallengeHandler increments
	// them; cleanupExpiredData deletes every entry once an hour.
	ipRateLimit sync.Map // map[string]*atomic.Int64
	// Outstanding challenges keyed by request ID; written by generateRequestID,
	// expired entries purged hourly by cleanupExpiredChallenges.
	challengeStore sync.Map // map[string]ChallengeParams
	// Global BadgerDB-backed token store, opened in init.
	tokenStore *TokenStore
	// in-memory cache for the interstitial HTML to avoid repeated disk reads
	interstitialContent string
	interstitialOnce    sync.Once
	interstitialLoadErr error
	// parsed template for interstitial page (built once from interstitialContent)
	interstitialTmpl     *template.Template
	interstitialTmplOnce sync.Once
	interstitialTmplErr  error
	// pool of *bytes.Buffer used by encodeTokenData to reduce allocations
	gobBufferPool = sync.Pool{
		New: func() interface{} {
			return new(bytes.Buffer)
		},
	}
)

// init wires the package up at import time: it loads the "checkpoint"
// configuration (fatal on failure), registers the "sanitize" and "checkpoint"
// plugins, opens the BadgerDB token store (fatal on failure), loads or creates
// the HMAC secret, and starts the hourly cleanup goroutine. The order matters:
// the store path and secret path come from the configuration set first.
func init() {
	// Load complete configuration from checkpoint.toml (required)
	var cfg Config
	if err := LoadConfig("checkpoint", &cfg); err != nil {
		log.Fatalf("Failed to load checkpoint config: %v", err)
	}
	SetConfig(cfg)
	// Register sanitization plugin (cleanup URLs/queries before checkpoint)
	RegisterPlugin("sanitize", RequestSanitizationMiddleware)
	// Register checkpoint plugin
	RegisterPlugin("checkpoint", New)
	// Initialize stores AFTER config is set, because the DB path comes from it.
	// tokenStore must be non-nil before any handler runs.
	var err error
	tokenStore, err = NewTokenStore(checkpointConfig.TokenStoreDBPath)
	if err != nil {
		log.Fatalf("CRITICAL: Failed to initialize TokenStore database: %v", err)
	}
	// Initialize secret (the bool result is always true; fatal errors exit inside)
	_ = initSecret()
	// Start cleanup timer for nonces/ip rates/challenges (token cleanup handled by DB TTL)
	_ = startCleanupTimer()
}

// SecretConfig is the JSON document persisted at Config.SecretConfigPath to
// keep the HMAC secret stable across restarts.
type SecretConfig struct {
	HmacSecret []byte    `json:"hmac_secret"`
	CreatedAt  time.Time `json:"created_at"`
	UpdatedAt  time.Time `json:"updated_at"`
}

// --- End Configuration ---

// SetConfig replaces the package-wide configuration (usually loaded from TOML).
// checkpointConfig is a plain variable that handlers read without a lock, so
// call this at startup, before the middleware serves requests.
func SetConfig(cfg Config) {
	checkpointConfig = cfg
	// The token store is opened once in init from TokenStoreDBPath; changing
	// that path here does not reopen it — restart to apply.
	// Other fields are read on each request and take effect immediately.
}

// --- Token Store (BadgerDB Implementation) ---

// StoredTokenData holds the information persisted for each token hash,
// including the binding values needed for verification. It is stored
// gob-encoded (see encodeTokenData).
type StoredTokenData struct {
	ClientIPHash  string    // Hash of IP used during issuance
	UserAgentHash string    // Hash of User Agent used during issuance
	BrowserHint   string    // Browser Hint used during issuance
	LastVerified  time.Time // Last time this token was successfully validated
	ExpiresAt     time.Time // Original expiration time of the token (also drives the Badger TTL)
}

// TokenStore manages persistent storage of verified tokens using BadgerDB.
type TokenStore struct {
	DB *badger.DB
}

// NewTokenStore creates dbPath if needed, opens a BadgerDB database there,
// and starts a background value-log GC goroutine for it.
func NewTokenStore(dbPath string) (*TokenStore, error) { if err := os.MkdirAll(dbPath, 0755); err != nil { return nil, fmt.Errorf("failed to create token store directory %s: %w", dbPath, err) } opts := badger.DefaultOptions(dbPath) // Tune options for performance if needed (e.g., memory usage) opts.Logger = nil // Disable default Badger logger unless debugging db, err := badger.Open(opts) if err != nil { return nil, fmt.Errorf("failed to open token store database at %s: %w", dbPath, err) } store := &TokenStore{DB: db} // Start BadgerDB's own value log GC routine (optional but recommended) go store.runValueLogGC() return store, nil } // Close closes the BadgerDB database. // Should be called during graceful shutdown. func (store *TokenStore) Close() error { if store.DB != nil { log.Println("Closing TokenStore database...") return store.DB.Close() } return nil } // runValueLogGC runs BadgerDB's value log garbage collection periodically. func (store *TokenStore) runValueLogGC() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { again: err := store.DB.RunValueLogGC(0.7) // Run GC if 70% space can be reclaimed if err == nil { goto again // Run GC multiple times if needed } if err != badger.ErrNoRewrite { log.Printf("WARNING: BadgerDB RunValueLogGC error: %v", err) } } } // encodeTokenData serializes StoredTokenData using gob. func encodeTokenData(data *StoredTokenData) ([]byte, error) { // get a buffer from pool buf := gobBufferPool.Get().(*bytes.Buffer) buf.Reset() enc := gob.NewEncoder(buf) if err := enc.Encode(data); err != nil { gobBufferPool.Put(buf) return nil, fmt.Errorf("failed to gob encode token data: %w", err) } // copy out the bytes to avoid retaining large buffer out := make([]byte, buf.Len()) copy(out, buf.Bytes()) buf.Reset() gobBufferPool.Put(buf) return out, nil } // decodeTokenData deserializes StoredTokenData using gob. 
func decodeTokenData(encoded []byte) (*StoredTokenData, error) { var data StoredTokenData // use a reader to avoid extra buffer allocation reader := bytes.NewReader(encoded) dec := gob.NewDecoder(reader) if err := dec.Decode(&data); err != nil { return nil, fmt.Errorf("failed to gob decode token data: %w", err) } return &data, nil } // addToken stores the token data in BadgerDB with a TTL. func (store *TokenStore) addToken(tokenHash string, data *StoredTokenData) error { encodedData, err := encodeTokenData(data) if err != nil { return err // Error already wrapped } // Calculate TTL based on the token's specific expiration ttl := time.Until(data.ExpiresAt) if ttl <= 0 { log.Printf("Attempted to add already expired token hash %s", tokenHash) return nil // Don't add already expired tokens } err = store.DB.Update(func(txn *badger.Txn) error { e := badger.NewEntry([]byte(tokenHash), encodedData).WithTTL(ttl) return txn.SetEntry(e) }) if err != nil { return fmt.Errorf("failed to add token hash %s to DB: %w", tokenHash, err) } return nil } // updateTokenVerification updates the LastVerified time for an existing token. func (store *TokenStore) updateTokenVerification(tokenHash string) error { return store.DB.Update(func(txn *badger.Txn) error { item, err := txn.Get([]byte(tokenHash)) if err != nil { // If token expired or was deleted between check and update, log and ignore. 
if err == badger.ErrKeyNotFound { log.Printf("Token hash %s not found during update verification (likely expired/deleted)", tokenHash) return nil // Not a critical error in this context } return fmt.Errorf("failed to get token %s for update: %w", tokenHash, err) } var storedData *StoredTokenData err = item.Value(func(val []byte) error { storedData, err = decodeTokenData(val) return err }) if err != nil { return fmt.Errorf("failed to decode token %s value for update: %w", tokenHash, err) } // Update LastVerified and re-encode storedData.LastVerified = time.Now() encodedData, err := encodeTokenData(storedData) if err != nil { return err } // Set the entry again (TTL remains the same based on original ExpiresAt) ttl := time.Until(storedData.ExpiresAt) if ttl <= 0 { return nil } // Don't update if expired e := badger.NewEntry([]byte(tokenHash), encodedData).WithTTL(ttl) return txn.SetEntry(e) }) } // lookupTokenData retrieves token data from BadgerDB. // Returns the data, true if found and not expired, or false otherwise. 
// Added context parameter func (store *TokenStore) lookupTokenData(ctx context.Context, tokenHash string) (*StoredTokenData, bool, error) { var storedData *StoredTokenData var found bool err := store.DB.View(func(txn *badger.Txn) error { // Check context cancellation within the transaction if ctx.Err() != nil { return ctx.Err() } item, err := txn.Get([]byte(tokenHash)) if err != nil { if err == badger.ErrKeyNotFound { return nil // Not found, not an error for lookup } return fmt.Errorf("failed to get token hash %s from DB: %w", tokenHash, err) } // Key exists, decode the value err = item.Value(func(val []byte) error { // Check context cancellation before decoding if ctx.Err() != nil { return ctx.Err() } var decodeErr error storedData, decodeErr = decodeTokenData(val) return decodeErr }) if err != nil { // If context was cancelled, return that error if ctx.Err() != nil { return ctx.Err() } // Return actual decoding error return fmt.Errorf("failed to decode StoredTokenData for hash %s: %w", tokenHash, err) } // Check expiration explicitly just in case TTL mechanism has latency if time.Now().After(storedData.ExpiresAt) { log.Printf("Token hash %s found but expired (ExpiresAt: %v)", tokenHash, storedData.ExpiresAt) storedData = nil // Treat as not found if expired return nil } found = true return nil }) if err != nil { // Don't log here, return the error to the caller (validateToken) return nil, false, err // Return the actual error } return storedData, found, nil // Success } // --- End Token Store --- // CloseTokenStore provides a package-level function to close the global token store. // This should be called during application shutdown. 
func CloseTokenStore() error { if tokenStore != nil { return tokenStore.Close() } return nil } // loadInterstitialHTML returns the cached interstitial HTML (loads once from disk) func loadInterstitialHTML() (string, error) { interstitialOnce.Do(func() { for _, path := range checkpointConfig.InterstitialPaths { if data, err := os.ReadFile(path); err == nil { interstitialContent = string(data) return } } interstitialLoadErr = fmt.Errorf("could not find checkpoint interstitial HTML at any configured path") }) return interstitialContent, interstitialLoadErr } // getInterstitialTemplate parses the cached HTML as a Go template (once) func getInterstitialTemplate() (*template.Template, error) { interstitialTmplOnce.Do(func() { raw, err := loadInterstitialHTML() if err != nil { interstitialTmplErr = err return } interstitialTmpl, interstitialTmplErr = template.New("interstitial").Parse(raw) }) return interstitialTmpl, interstitialTmplErr } // serveInterstitial serves the challenge page using a Go template for safe interpolation func serveInterstitial(c *fiber.Ctx) error { requestID := generateRequestID(c) c.Status(200) c.Set("Content-Type", "text/html; charset=utf-8") tmpl, err := getInterstitialTemplate() if err != nil { log.Printf("WARNING: %v", err) return c.SendString("Security verification required. Please refresh the page.") } // prepare data for template host := c.Hostname() originalURL, _ := c.Locals("originalURL").(string) targetPath := c.Path() if originalURL != "" { targetPath = originalURL } data := struct { TargetPath string RequestID string Host string FullURL string }{ TargetPath: targetPath, RequestID: requestID, Host: host, FullURL: c.BaseURL() + targetPath, } var buf bytes.Buffer if err := tmpl.Execute(&buf, data); err != nil { log.Printf("ERROR: Interstitial template execution failed: %v", err) return c.SendString("Security verification required. 
Please refresh the page.") } return c.SendString(buf.String()) } // checkPoSTimes ensures that memory proof run times are within the allowed ratio func checkPoSTimes(times []int64) error { if len(times) != 3 { return fmt.Errorf("invalid PoS run times length") } minT, maxT := times[0], times[0] for _, t := range times[1:] { if t < minT { minT = t } if t > maxT { maxT = t } } if checkpointConfig.CheckPoSTimes && float64(maxT) > float64(minT)*checkpointConfig.PoSTimeConsistencyRatio { return fmt.Errorf("PoS run times ('i') are not consistent (ratio %.2f > %.2f)", float64(maxT)/float64(minT), checkpointConfig.PoSTimeConsistencyRatio) } return nil } // getDomainFromHost returns the base domain from a hostname // For proper cookie sharing in both production and development func getDomainFromHost(hostname string) string { // Handle localhost development if hostname == "localhost" || strings.HasPrefix(hostname, "localhost:") || hostname == "127.0.0.1" || strings.HasPrefix(hostname, "127.0.0.1:") { return "" // Use host-only cookies for localhost } // For IP addresses, use host-only cookies if net.ParseIP(strings.Split(hostname, ":")[0]) != nil { return "" // IP address - use host-only } parts := strings.Split(hostname, ".") if len(parts) <= 1 { return hostname // single word domain - unlikely } // For standard domains, return domain with leading dot if len(parts) >= 2 { // Return parent domain for proper cookie sharing domain := parts[len(parts)-2] + "." + parts[len(parts)-1] return "." + domain // Leading dot is important } return "" // Fallback to host-only cookie } // issueToken handles token generation, cookie setting, and JSON response func issueToken(c *fiber.Ctx, token CheckpointToken) error { // 1. Generate the token hash tokenHash := calculateTokenHash(token) // 2. 
Create the data to store in the DB storedData := &StoredTokenData{ ClientIPHash: token.ClientIP, // Assumes token struct is already populated UserAgentHash: token.UserAgent, BrowserHint: token.BrowserHint, LastVerified: token.LastVerified, ExpiresAt: token.ExpiresAt, // Store original expiration } // 3. Add to the database if err := tokenStore.addToken(tokenHash, storedData); err != nil { log.Printf("ERROR: Failed to store token in DB for hash %s: %v", tokenHash, err) // Decide if this is fatal or just a warning. For now, log and continue. // return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to store verification proof"}) } // 4. Sign the token (as before) token.Signature = "" // Clear signature before marshalling for signing tokenBytesForSig, _ := json.Marshal(token) token.Signature = computeTokenSignature(token, tokenBytesForSig) // 5. Prepare final token for cookie finalBytes, err := json.Marshal(token) if err != nil { log.Printf("ERROR: Failed to marshal final token: %v", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to prepare token"}) } tokenStr := base64.StdEncoding.EncodeToString(finalBytes) // 6. 
Set cookie // Determine if we're serving on HTTPS or HTTP isSecure := true // Check if we're in development mode using non-secure connection if strings.HasPrefix(c.Protocol(), "http") && !strings.HasPrefix(c.BaseURL(), "https") { isSecure = false // Running on http:// (dev mode) } // Get domain for cookie - either from config or auto-detect cookieDomain := checkpointConfig.CookieDomain if cookieDomain == "" { // Auto-detect - for development convenience cookieDomain = getDomainFromHost(c.Hostname()) } // Set SameSite based on domain - use Lax for cross-subdomain sameSite := "Strict" if cookieDomain != "" { sameSite = "Lax" // Lax allows subdomain sharing better than Strict } c.Cookie(&fiber.Cookie{ Name: checkpointConfig.CookieName, Value: tokenStr, Expires: token.ExpiresAt, // Cookie expires when token expires Path: "/", Domain: cookieDomain, HTTPOnly: true, SameSite: sameSite, Secure: isSecure, // Only set Secure in HTTPS environments }) return c.JSON(fiber.Map{"token": tokenStr, "expires_at": token.ExpiresAt}) } // Initialize a secure random secret key or load from persistent storage func initSecret() bool { if _, err := os.Stat(checkpointConfig.SecretConfigPath); err == nil { // Config file exists, try to load it if loadedSecret := loadSecretFromFile(); loadedSecret != nil { hmacSecret = loadedSecret log.Printf("Loaded existing HMAC secret from %s", checkpointConfig.SecretConfigPath) return true } } // No config file or loading failed, generate a new secret hmacSecret = make([]byte, 32) _, err := cryptorand.Read(hmacSecret) if err != nil { // Critical security error - don't continue with insecure random numbers log.Fatalf("CRITICAL: Could not generate secure random secret: %v", err) } // Ensure data directory exists if err := os.MkdirAll(filepath.Dir(checkpointConfig.SecretConfigPath), 0755); err != nil { log.Printf("WARNING: Could not create data directory: %v", err) return true } // Save the new secret to file config := SecretConfig{ HmacSecret: hmacSecret, 
CreatedAt: time.Now(), UpdatedAt: time.Now(), } if configBytes, err := json.Marshal(config); err == nil { if err := os.WriteFile(checkpointConfig.SecretConfigPath, configBytes, 0600); err != nil { log.Printf("WARNING: Could not save HMAC secret to file: %v", err) } else { log.Printf("Created and saved new HMAC secret to %s", checkpointConfig.SecretConfigPath) } } return true } // loadSecretFromFile loads the HMAC secret from persistent storage func loadSecretFromFile() []byte { configBytes, err := os.ReadFile(checkpointConfig.SecretConfigPath) if err != nil { log.Printf("ERROR: Could not read secret config file: %v", err) return nil } var config SecretConfig if err := json.Unmarshal(configBytes, &config); err != nil { log.Printf("ERROR: Could not parse secret config file: %v", err) return nil } if len(config.HmacSecret) < 16 { log.Printf("ERROR: Secret from file is too short, generating a new one") return nil } // Update the last loaded time config.UpdatedAt = time.Now() if configBytes, err := json.Marshal(config); err == nil { if err := os.WriteFile(checkpointConfig.SecretConfigPath, configBytes, 0600); err != nil { log.Printf("WARNING: Could not update HMAC secret file: %v", err) } } return config.HmacSecret } // Start a timer to periodically clean up the nonce and rate limit maps func startCleanupTimer() bool { ticker := time.NewTicker(1 * time.Hour) go func() { for range ticker.C { cleanupExpiredData() cleanupExpiredChallenges() } }() return true } // Clean up expired nonces and rate limit data func cleanupExpiredData() { // Clean up used nonces now := time.Now() expiredNonceCount := 0 usedNonces.Range(func(key, value interface{}) bool { nonce := key.(string) timestamp := value.(time.Time) if now.Sub(timestamp) > checkpointConfig.MaxNonceAge { usedNonces.Delete(nonce) expiredNonceCount++ } return true // continue iteration }) if expiredNonceCount > 0 { log.Printf("Checkpoint: Cleaned up %d expired nonces.", expiredNonceCount) } // Reset IP rate limits every 
hour by deleting all entries ipRateLimit.Range(func(key, value interface{}) bool { ipRateLimit.Delete(key) return true }) log.Println("Checkpoint: IP rate limits reset.") } // CheckpointToken represents a validated token type CheckpointToken struct { Nonce string `json:"g"` // Nonce Challenge string `json:"-"` // Derived server-side, not in token Salt string `json:"-"` // Derived server-side, not in token Difficulty int `json:"-"` // Derived server-side, not in token ExpiresAt time.Time `json:"exp"` ClientIP string `json:"cip,omitempty"` UserAgent string `json:"ua,omitempty"` BrowserHint string `json:"bh,omitempty"` Entropy string `json:"ent,omitempty"` Created time.Time `json:"crt"` LastVerified time.Time `json:"lvf,omitempty"` Signature string `json:"sig,omitempty"` TokenFormat int `json:"fmt"` } // ChallengeParams stores parameters for a challenge type ChallengeParams struct { Challenge string `json:"challenge"` // Base64 encoded Salt string `json:"salt"` // Base64 encoded Difficulty int `json:"difficulty"` ExpiresAt time.Time `json:"expires_at"` ClientIP string `json:"-"` PoSSeed string `json:"pos_seed"` // Hex encoded } // isExcludedHTMLPath checks if a path should be excluded from the HTML checkpoint. // Exclusions happen based on configured prefixes or file extensions. func isExcludedHTMLPath(path string) bool { // 1. Check path prefixes for _, prefix := range checkpointConfig.HTMLCheckpointExclusions { if strings.HasPrefix(path, prefix) { return true // Excluded by prefix } } // 2. Check file extension using the set ext := strings.ToLower(filepath.Ext(path)) if ext != "" { if _, exists := checkpointConfig.HTMLCheckpointExcludedExtensions[ext]; exists { return true // Excluded by file extension } } // 3. If not excluded by prefix or extension, it needs the checkpoint return false } // DirectProxy returns a handler that simply forwards the request/response to targetURL. // Headers, status codes, and body are passed through without modification. 
func DirectProxy(targetURL string) fiber.Handler { target, err := url.Parse(targetURL) if err != nil { return func(c *fiber.Ctx) error { log.Printf("ERROR: Invalid target URL %s: %v", targetURL, err) return fiber.ErrBadGateway } } proxy := httputil.NewSingleHostReverseProxy(target) // Set up custom director to properly map headers originalDirector := proxy.Director proxy.Director = func(req *http.Request) { originalDirector(req) // Add X-Forwarded headers req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("X-Forwarded-Proto", "http") // Update to https when needed if v := req.Header.Get("X-Forwarded-For"); v != "" { req.Header.Set("X-Forwarded-For", v+", "+req.RemoteAddr) } else { req.Header.Set("X-Forwarded-For", req.RemoteAddr) } } return func(c *fiber.Ctx) error { // Create proxy request proxyReq, err := http.NewRequest( string(c.Method()), target.String()+c.Path(), bytes.NewReader(c.Body()), ) if err != nil { log.Printf("ERROR: Failed to create proxy request: %v", err) return fiber.ErrBadGateway } // Copy all headers from the Fiber context to the proxy request c.Request().Header.VisitAll(func(key, value []byte) { proxyReq.Header.Set(string(key), string(value)) }) // Execute the proxy request proxyRes, err := http.DefaultClient.Do(proxyReq) if err != nil { log.Printf("ERROR: Proxy request failed: %v", err) return fiber.ErrBadGateway } defer proxyRes.Body.Close() // Copy all headers from the proxy response to Fiber's response for key, values := range proxyRes.Header { for _, value := range values { c.Response().Header.Add(key, value) } } // Set the status code c.Status(proxyRes.StatusCode) // Copy the body body, err := io.ReadAll(proxyRes.Body) if err != nil { log.Printf("ERROR: Failed to read proxy response body: %v", err) return fiber.ErrBadGateway } return c.Send(body) } } // isBlockedBot checks concurrently if the User-Agent indicates a known bot // or doesn't have a standard browser prefix. // It returns true as soon as one check decides to block. 
func isBlockedBot(userAgent string) bool { if userAgent == "" { // Empty User-Agent is suspicious, block it log.Printf("INFO: UA blocked - empty user agent") return true } ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Ensure context is cancelled eventually resultChan := make(chan bool, 2) // Buffered channel for results // Goroutine 1: Library-based bot check go func() { ua := useragent.Parse(userAgent) shouldBlock := ua.Bot if shouldBlock { log.Printf("INFO: UA blocked by library (Bot detected: %s): %s", ua.Name, userAgent) } select { case resultChan <- shouldBlock: case <-ctx.Done(): // Don't send if context is cancelled } }() // Goroutine 2: Prefix check go func() { // Standard browser User-Agent prefixes standardPrefixes := []string{"Mozilla/", "Opera/", "DuckDuckGo/", "Dart/"} hasStandardPrefix := false for _, prefix := range standardPrefixes { if strings.HasPrefix(userAgent, prefix) { hasStandardPrefix = true break } } // Block if it does NOT have a standard prefix shouldBlock := !hasStandardPrefix if shouldBlock { log.Printf("INFO: UA blocked by prefix check (doesn't have standard prefix): %s", userAgent) } select { case resultChan <- shouldBlock: case <-ctx.Done(): // Don't send if context is cancelled } }() // Wait for results and decide result1 := <-resultChan if result1 { cancel() // Found a reason to block, cancel the other check return true } // First check didn't block, wait for the second result result2 := <-resultChan // cancel() is deferred, so it will run anyway, ensuring cleanup return result2 // Block if the second check decided to block } // New gives you a Fiber handler that does the POW challenge (HTML/API) or proxies requests. 
func New() fiber.Handler { return func(c *fiber.Ctx) error { host := c.Hostname() targetURL, useProxy := checkpointConfig.ReverseProxyMappings[host] path := c.Path() // --- User-Agent Validation --- // Only check User-Agent if path is not in exclusion list skipUA := false for _, prefix := range checkpointConfig.UserAgentValidationExclusions { if strings.HasPrefix(path, prefix) { skipUA = true break } } if !skipUA { // First check required UA prefixes for specific paths for p, required := range checkpointConfig.UserAgentRequiredPrefixes { if strings.HasPrefix(path, p) { ua := c.Get("User-Agent") if !strings.HasPrefix(ua, required) { log.Printf("INFO: UA blocked by required prefix %s: %s", required, ua) if strings.HasPrefix(path, "/api") { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "Access denied for automated clients.", "reason": "useragent", }) } return c.Status(fiber.StatusForbidden).SendString("Access denied for automated clients.") } break } } // Then do general bot check for all non-excluded paths userAgent := c.Get("User-Agent") if isBlockedBot(userAgent) { if strings.HasPrefix(path, "/api") { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "Access denied for automated clients.", "reason": "useragent", }) } return c.Status(fiber.StatusForbidden).SendString("Access denied for automated clients.") } } // Handle any API endpoints if strings.HasPrefix(path, "/api") { // Always serve PoW endpoints locally (challenge & verify) if strings.HasPrefix(path, "/api/pow/") || strings.HasPrefix(path, "/api/verify") { log.Printf("API checkpoint endpoint %s - handling locally", path) return c.Next() } // Other API paths: skip checkpoint if useProxy { // Proxy to backend for proxied hosts log.Printf("API proxying endpoint %s to %s", path, targetURL) return DirectProxy(targetURL)(c) } log.Printf("API endpoint %s - bypassing checkpoint", path) return c.Next() } // --- Reverse Proxy Logic --- if useProxy { // Check for existing valid token 
cookie tokenCookie := c.Cookies(checkpointConfig.CookieName) log.Printf("Proxy: Checking token for host %s, path %s, cookie present: %v", host, path, tokenCookie != "") // Check if this is an excluded path (API endpoints, etc) if isExcludedHTMLPath(path) { log.Printf("Excluded path %s for proxied host %s - proxying without token check", path, host) // Direct transparent proxy (preserves all headers/content types) return DirectProxy(targetURL)(c) } valid, err := validateToken(tokenCookie, c) if err != nil { // Log validation errors but treat as invalid for proxying log.Printf("Error validating token for proxied host %s, path %s: %v", host, path, err) } if valid { log.Printf("Valid token found for proxied host %s, path %s - forwarding request", host, path) // Token is valid, proxy the request // Direct transparent proxy (preserves all headers/content types) return DirectProxy(targetURL)(c) } else { // Add debug logging log.Printf("No valid token for proxied host %s, path %s - serving interstitial", host, path) // Save the original full URL for potential redirection after verification c.Locals("originalURL", c.OriginalURL()) // No valid token, serve the interstitial challenge page. 
return serveInterstitial(c)
			}
		}

		// --- Standard HTML/Static/API Logic (No Proxy Mapping) ---
		// Skip checkpoint for excluded paths (e.g., static assets, API endpoints handled separately)
		if isExcludedHTMLPath(path) {
			return c.Next()
		}

		// --- Path needs checkpoint (potential HTML page) ---
		tokenCookie := c.Cookies(checkpointConfig.CookieName)
		if tokenCookie != "" {
			valid, err := validateToken(tokenCookie, c)
			if err != nil {
				// Log validation errors but still serve interstitial for safety
				log.Printf("Error validating token for path %s: %v", path, err)
				// Fall through to serve interstitial
			} else if valid {
				// Token is valid, proceed to the requested page/handler
				return c.Next()
			}
			// If token was present but invalid/expired, fall through to serve interstitial
		}

		// No valid token found, serve the interstitial challenge page.
		return serveInterstitial(c)
	}
}

// generateRequestID creates a unique ID for this verification request.
//
// It generates a fresh PoW challenge and salt (generateChallenge), a 32-byte
// PoS seed and a 16-byte request ID. The resulting ChallengeParams are stored
// in challengeStore under the hex request ID. They are bound to the client IP
// (getRealIP) and expire after checkpointConfig.ChallengeExpiration.
// The process exits via log.Fatalf if the system CSPRNG fails.
func generateRequestID(c *fiber.Ctx) string {
	challenge, salt := generateChallenge()

	// Generate PoS seed
	posSeedBytes := make([]byte, 32)
	if n, err := cryptorand.Read(posSeedBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate PoS seed: %v", err)
	} else if n != len(posSeedBytes) {
		log.Fatalf("CRITICAL: Short read generating PoS seed: read %d bytes", n)
	}
	posSeed := hex.EncodeToString(posSeedBytes)

	// Generate request ID
	randBytes := make([]byte, 16)
	if n, err := cryptorand.Read(randBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate request ID: %v", err)
	} else if n != len(randBytes) {
		log.Fatalf("CRITICAL: Short read generating request ID: read %d bytes", n)
	}
	requestID := hex.EncodeToString(randBytes)

	// The hex challenge and salt are stored base64-encoded; the handlers
	// decode them again before use.
	params := ChallengeParams{
		Challenge:  base64.StdEncoding.EncodeToString([]byte(challenge)),
		Salt:       base64.StdEncoding.EncodeToString([]byte(salt)),
		Difficulty: checkpointConfig.Difficulty,
		ExpiresAt:  time.Now().Add(checkpointConfig.ChallengeExpiration),
		ClientIP:   getRealIP(c),
		PoSSeed:    posSeed,
	}
	challengeStore.Store(requestID, params)
	return requestID
}

// cleanupExpiredChallenges deletes every challengeStore entry whose ExpiresAt
// is in the past and logs how many were removed. sync.Map allows Delete to be
// called from inside Range. The caller is not visible in this chunk; it is
// presumably run periodically.
func cleanupExpiredChallenges() {
	now := time.Now()
	expiredChallengeCount := 0
	challengeStore.Range(func(key, value interface{}) bool {
		id := key.(string)
		params := value.(ChallengeParams)
		if now.After(params.ExpiresAt) {
			challengeStore.Delete(id)
			expiredChallengeCount++
		}
		return true // continue iteration
	})
	if expiredChallengeCount > 0 {
		log.Printf("Checkpoint: Cleaned up %d expired challenges.", expiredChallengeCount)
	}
}

// GetCheckpointChallengeHandler serves challenge parameters via API.
//
// The request ID issued by generateRequestID is read from the "id" query
// parameter. Each call increments the caller's per-IP counter in ipRateLimit.
// Once MaxAttemptsPerHour is exceeded the handler responds 429.
//
// Other responses:
//   - 404 when the challenge is unknown or expired.
//   - 403 when it was issued to a different IP.
//   - 500 when decoy randomness cannot be generated.
//
// On success the real parameters are returned under short keys ("a"-"d"),
// together with a random decoy seed ("e") and 2-4 random decoy fields ("f").
func GetCheckpointChallengeHandler(c *fiber.Ctx) error {
	requestID := c.Query("id")
	if requestID == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Missing request ID"})
	}

	// Apply rate limiting to challenge generation (using the same
	// MaxAttemptsPerHour config). NOTE(review): the counter is only incremented
	// here; resetting it is assumed to happen elsewhere in the file - confirm.
	clientIP := getRealIP(c)
	counterVal, _ := ipRateLimit.LoadOrStore(clientIP, new(atomic.Int64))
	attempts := counterVal.(*atomic.Int64).Add(1) // Increment and get new value
	if attempts > int64(checkpointConfig.MaxAttemptsPerHour) {
		return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{"error": "Too many challenge requests. Please try again later."})
	}

	storedVal, exists := challengeStore.Load(requestID)
	if !exists {
		return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "Challenge not found or expired"})
	}
	params := storedVal.(ChallengeParams)
	// Expired challenges stay in the store until the cleanup runs; enforce
	// the expiry here, at the moment of use.
	if time.Now().After(params.ExpiresAt) {
		challengeStore.Delete(requestID)
		return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "Challenge not found or expired"})
	}
	if clientIP != params.ClientIP {
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	// The decoy seed also drives the decoy field count and sizes. If the
	// CSPRNG fails, the bytes would stay zero, so fail the request instead.
	decoySeedBytes := make([]byte, 8)
	if _, err := cryptorand.Read(decoySeedBytes); err != nil {
		log.Printf("Checkpoint: failed to generate decoy seed: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate challenge data"})
	}
	decoySeed := hex.EncodeToString(decoySeedBytes)

	// Generate 2-4 decoy fields. Names are 5-12 random hex chars and values
	// are 8-31 random hex chars.
	decoyFieldCount := 2 + int(decoySeedBytes[0])%3
	decoyFields := make([]map[string]interface{}, 0, decoyFieldCount)
	for i := 0; i < decoyFieldCount; i++ {
		nameLen := 5 + int(decoySeedBytes[i%8])%8
		valLen := 8 + int(decoySeedBytes[(i+1)%8])%24
		fieldName := randomHexString(nameLen)
		fieldValue := randomHexString(valLen)
		decoyFields = append(decoyFields, map[string]interface{}{fieldName: fieldValue})
	}

	return c.JSON(fiber.Map{
		"a": params.Challenge,  // challenge
		"b": params.Salt,       // salt
		"c": params.Difficulty, // difficulty
		"d": params.PoSSeed,    // pos_seed
		"e": decoySeed,         // decoy_seed
		"f": decoyFields,       // decoy_fields
	})
}

// randomHexString returns n random lowercase hex characters from the system
// CSPRNG. It returns "" when n <= 0 and exits via log.Fatalf if the CSPRNG
// fails. crypto/rand.Read fills the whole buffer whenever err == nil, so no
// separate short-read check is needed.
func randomHexString(n int) string {
	if n <= 0 {
		return ""
	}
	b := make([]byte, (n+1)/2) // 2 hex chars per byte, rounded up
	if _, err := cryptorand.Read(b); err != nil {
		log.Fatalf("CRITICAL: Failed to generate random hex string: %v", err)
	}
	return hex.EncodeToString(b)[:n]
}

// getFullClientIP returns a 16-hex-char digest (the first 8 bytes of SHA-256)
// of the client IP reported by getRealIP, or "unknown" if no IP is available.
// Despite the name, it returns a hash, not the raw address.
func getFullClientIP(c *fiber.Ctx) string {
	ip := getRealIP(c)
	if ip == "" {
		return "unknown"
	}
	h := sha256.Sum256([]byte(ip))
	return hex.EncodeToString(h[:8])
}

// hashUserAgent returns a 16-hex-char digest (the first 8 bytes of SHA-256) of
// userAgent, or "" for an empty User-Agent.
func hashUserAgent(userAgent string) string {
	if userAgent == "" {
		return ""
	}
	hash := sha256.Sum256([]byte(userAgent))
	return hex.EncodeToString(hash[:8])
}

// extractBrowserFingerprint joins the non-empty Sec-CH-UA client-hint headers
// with "|". It returns a 24-hex-char digest (the first 12 bytes of SHA-256),
// or "" when the client sent none of the hints.
func extractBrowserFingerprint(c *fiber.Ctx) string {
	headers := []string{
		c.Get("Sec-CH-UA"),
		c.Get("Sec-CH-UA-Platform"),
		c.Get("Sec-CH-UA-Mobile"),
		c.Get("Sec-CH-UA-Platform-Version"),
		c.Get("Sec-CH-UA-Arch"),
		c.Get("Sec-CH-UA-Model"),
	}
	validHeaders := make([]string, 0, len(headers))
	for _, h := range headers {
		if h != "" {
			validHeaders = append(validHeaders, h)
		}
	}
	if len(validHeaders) == 0 {
		return ""
	}
	fingerprint := strings.Join(validHeaders, "|")
	hash := sha256.Sum256([]byte(fingerprint))
	return hex.EncodeToString(hash[:12])
}

// validateToken reports whether tokenStr, the base64 cookie value, is a valid
// checkpoint token for the current request.
//
// The token is valid only if all of the following hold:
//   - It decodes (base64, then JSON) into a CheckpointToken.
//   - It has not expired and has TokenFormat >= 2.
//   - It carries a correct HMAC signature and has a record in tokenStore.
//   - The stored IP hash and User-Agent hash match the current request.
//   - The stored browser hint matches, when both sides have one.
//
// Malformed, expired, forged or mismatched tokens yield (false, nil). Only a
// tokenStore lookup failure is returned as an error. On success the stored
// verification time is updated (best effort), and the cookie is re-issued
// with a fresh LastVerified and signature.
func validateToken(tokenStr string, c *fiber.Ctx) (bool, error) {
	// Explicitly handle missing token case first.
	if tokenStr == "" {
		return false, nil
	}

	// 1. Decode the cookie. Bad base64 or an empty payload is an invalid
	// token, not a system error.
	tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
	if err != nil || len(tokenBytes) == 0 {
		return false, nil
	}

	// 2. Unmarshal; malformed JSON is likewise just an invalid token.
	var token CheckpointToken
	if err := json.Unmarshal(tokenBytes, &token); err != nil {
		return false, nil
	}

	// 3. Expiration according to the token's own ExpiresAt field.
	if time.Now().After(token.ExpiresAt) {
		return false, nil
	}

	// 4. Only signed formats (2+) are accepted, and the signature must match.
	if token.TokenFormat < 2 || !verifyTokenSignature(token, tokenBytes) {
		return false, nil
	}

	// 5-6. Look up the server-side record keyed by the token hash.
	tokenHash := calculateTokenHash(token)
	storedData, found, dbErr := tokenStore.lookupTokenData(c.Context(), tokenHash)
	if dbErr != nil {
		// Actual DB error during lookup - the only case reported as an error.
		return false, fmt.Errorf("token DB lookup failed: %w", dbErr)
	}
	if !found {
		return false, nil
	}

	// 7. Verify bindings against stored data and the current request.
	if storedData.ClientIPHash != getFullClientIP(c) {
		return false, nil // IP mismatch
	}
	if storedData.UserAgentHash != hashUserAgent(c.Get("User-Agent")) {
		return false, nil // User agent mismatch
	}
	// Browser hints are enforced only when both the stored and current hints exist.
	currentBrowserHint := extractBrowserFingerprint(c)
	if storedData.BrowserHint != "" && currentBrowserHint != "" && storedData.BrowserHint != currentBrowserHint {
		return false, nil // Browser hint mismatch
	}

	// 8. All checks passed. Record the verification (best effort, log errors).
	if err := tokenStore.updateTokenVerification(tokenHash); err != nil {
		log.Printf("WARNING: Failed to update token verification time for hash %s: %v", tokenHash, err)
	}

	// Refresh the cookie with an updated LastVerified; ExpiresAt is unchanged.
	token.LastVerified = time.Now()
	updateTokenCookie(c, token) // Resign and set cookie
	return true, nil
}

// updateTokenCookie re-signs token and writes it back as the checkpoint cookie.
//
// Cookie attributes:
//   - It keeps the token's original ExpiresAt, uses Path "/" and is HTTP-only.
//   - Its domain is checkpointConfig.CookieDomain, or, when that is unset,
//     the value getDomainFromHost derives from the request host.
//   - SameSite is Strict when no domain is set and Lax when one is.
//   - Secure is dropped only when the request protocol is plain http.
//
// Marshal failures are logged, and the cookie is then left untouched.
func updateTokenCookie(c *fiber.Ctx, token CheckpointToken) {
	// Determine if we're serving on HTTPS or HTTP (http:// means dev mode).
	isSecure := true
	if strings.HasPrefix(c.Protocol(), "http") && !strings.HasPrefix(c.BaseURL(), "https") {
		isSecure = false
	}

	// Get domain for cookie - either from config or auto-detect
	cookieDomain := checkpointConfig.CookieDomain
	if cookieDomain == "" {
		cookieDomain = getDomainFromHost(c.Hostname())
	}

	// Use Lax when a domain is set, so the cookie survives cross-subdomain
	// navigation. Fiber's own constants are used because some Fiber v2
	// versions match SameSite values case-sensitively against them.
	sameSite := fiber.CookieSameSiteStrictMode
	if cookieDomain != "" {
		sameSite = fiber.CookieSameSiteLaxMode
	}

	// Recompute the signature because LastVerified may have changed.
	// computeTokenSignature clears Signature itself and ignores its bytes argument.
	token.Signature = computeTokenSignature(token, nil)
	finalTokenBytes, err := json.Marshal(token)
	if err != nil {
		log.Printf("Error marshaling token for cookie update: %v", err)
		return
	}

	c.Cookie(&fiber.Cookie{
		Name:     checkpointConfig.CookieName,
		Value:    base64.StdEncoding.EncodeToString(finalTokenBytes),
		Expires:  token.ExpiresAt, // Use original expiration
		Path:     "/",
		Domain:   cookieDomain,
		HTTPOnly: true,
		SameSite: sameSite,
		Secure:   isSecure, // Only set Secure in HTTPS environments
	})
}

// verifyProofOfWork reports whether the hex SHA-256 of challenge+salt+nonce
// begins with difficulty '0' characters. A non-positive difficulty accepts
// any nonce; negative values are clamped to 0 because strings.Repeat panics
// on a negative count.
func verifyProofOfWork(challenge, salt, nonce string, difficulty int) bool {
	if difficulty < 0 {
		difficulty = 0
	}
	hash := calculateHash(challenge + salt + nonce)
	return strings.HasPrefix(hash, strings.Repeat("0", difficulty))
}

// calculateHash returns the lowercase hex SHA-256 digest of input.
func calculateHash(input string) string {
	hash := sha256.Sum256([]byte(input))
	return hex.EncodeToString(hash[:])
}

// computeTokenSignature returns the hex HMAC-SHA256, keyed with hmacSecret,
// of token's JSON encoding with the Signature field cleared. tokenBytes is
// unused and is kept only for call-site compatibility; the signature always
// covers the re-marshaled token. It returns "" if the token cannot be
// marshaled, and an empty signature never verifies.
func computeTokenSignature(token CheckpointToken, tokenBytes []byte) string {
	token.Signature = "" // token is a copy; the signature must not sign itself
	tokenToSign, err := json.Marshal(token)
	if err != nil {
		log.Printf("Error marshaling token for signing: %v", err)
		return ""
	}
	h := hmac.New(sha256.New, hmacSecret)
	h.Write(tokenToSign)
	return hex.EncodeToString(h.Sum(nil))
}

// verifyTokenSignature reports whether token carries a non-empty signature
// equal to the one computeTokenSignature produces. It uses hmac.Equal for a
// constant-time comparison.
func verifyTokenSignature(token CheckpointToken, tokenBytes []byte) bool {
	if token.Signature == "" {
		return false
	}
	expectedSignature := computeTokenSignature(token, tokenBytes)
	if expectedSignature == "" {
		return false
	}
	return hmac.Equal([]byte(token.Signature), []byte(expectedSignature))
}

// VerifyCheckpointHandler verifies the challenge solution.
//
// The body (CheckpointVerifyRequest) names the challenge by request ID and
// carries the PoW nonce ("g") plus optional proof-of-space data ("h", "i").
// On success the challenge is consumed, and a CheckpointToken bound to the
// client's IP hash, User-Agent hash and browser hint is built. issueToken then
// stores it and sets the cookie.
//
// Responses:
//   - 400 for a malformed body, an unknown or expired request ID, a missing or
//     reused nonce, an invalid PoW, or invalid PoS data.
//   - 403 when the challenge was issued to another IP.
//   - 500 when the stored challenge cannot be decoded or entropy generation fails.
func VerifyCheckpointHandler(c *fiber.Ctx) error {
	clientIP := getRealIP(c)
	var req CheckpointVerifyRequest
	if err := c.BodyParser(&req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request format"})
	}

	// Challenge lookup
	challengeVal, challengeExists := challengeStore.Load(req.RequestID)
	if !challengeExists {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}
	params := challengeVal.(ChallengeParams)
	// Enforce expiry at use time; the periodic cleanup may not have run yet.
	if time.Now().After(params.ExpiresAt) {
		challengeStore.Delete(req.RequestID)
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}
	if clientIP != params.ClientIP { // Check against IP stored with challenge
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	challengeBytes, err := base64.StdEncoding.DecodeString(params.Challenge)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode challenge"})
	}
	decodedChallenge := string(challengeBytes)

	saltBytes, err := base64.StdEncoding.DecodeString(params.Salt)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode salt"})
	}
	decodedSalt := string(saltBytes)

	if req.Nonce == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Nonce ('g') required"})
	}

	// --- Nonce Check --- (cheap early rejection of an already-redeemed solution)
	nonceKey := req.Nonce + decodedChallenge
	if _, nonceExists := usedNonces.Load(nonceKey); nonceExists {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "This solution has already been used"})
	}

	if !verifyProofOfWork(decodedChallenge, decodedSalt, req.Nonce, params.Difficulty) {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid proof-of-work solution"})
	}

	// --- Store Used Nonce (only after PoW is verified) ---
	// LoadOrStore makes check-and-mark atomic, so two concurrent requests with
	// the same solution cannot both pass.
	if _, alreadyUsed := usedNonces.LoadOrStore(nonceKey, time.Now()); alreadyUsed {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "This solution has already been used"})
	}

	// Validate PoS hashes and times if provided.
	// NOTE(review): a request that omits PoS data entirely passes this block
	// even when CheckPoSTimes is enabled - confirm this is intended.
	if len(req.PoSHashes) == 3 && len(req.PoSTimes) == 3 {
		if req.PoSHashes[0] != req.PoSHashes[1] || req.PoSHashes[1] != req.PoSHashes[2] {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "PoS hashes ('h') do not match"})
		}
		if len(req.PoSHashes[0]) != 64 {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS hash ('h') length"})
		}
		if err := checkPoSTimes(req.PoSTimes); err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
		}
	} else if checkpointConfig.CheckPoSTimes && (len(req.PoSHashes) != 0 || len(req.PoSTimes) != 0) {
		// PoS checking is enabled, but the wrong number of hashes/times was provided.
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS data provided"})
	}

	// Challenge is valid. Consume it atomically, so concurrent requests with
	// different valid nonces for the same challenge cannot each mint a token.
	if _, consumed := challengeStore.LoadAndDelete(req.RequestID); !consumed {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}

	entropyBytes := make([]byte, 8)
	if _, err := cryptorand.Read(entropyBytes); err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate secure token entropy"})
	}

	// Gather current binding info for the new token.
	now := time.Now()
	token := CheckpointToken{
		Nonce:        req.Nonce,
		ExpiresAt:    now.Add(checkpointConfig.TokenExpiration),
		ClientIP:     getFullClientIP(c),
		UserAgent:    hashUserAgent(c.Get("User-Agent")),
		BrowserHint:  extractBrowserFingerprint(c),
		Entropy:      hex.EncodeToString(entropyBytes),
		Created:      now,
		LastVerified: now,
		TokenFormat:  2,
	}

	// Add a response header indicating success for the proxy
	c.Set("X-Checkpoint-Status", "success")
	log.Printf("Successfully verified challenge for IP %s, issuing token", clientIP)

	// Issue token (handles DB storage, signing, cookie setting)
	return issueToken(c, token)
}

// CheckpointVerifyRequest is the body accepted by VerifyCheckpointHandler.
// Keys are deliberately short and opaque. The decoy fields (j, k, l) are
// parsed, but VerifyCheckpointHandler does not inspect them.
type CheckpointVerifyRequest struct {
	RequestID   string                   `json:"request_id"`
	Nonce       string                   `json:"g"`
	PoSHashes   []string                 `json:"h"`
	PoSTimes    []int64                  `json:"i"`
	DecoyHashes []string                 `json:"j"`
	DecoyTimes  []int64                  `json:"k"`
	DecoyFields []map[string]interface{} `json:"l"`
}

// generateChallenge returns a random 32-hex-char challenge (16 bytes) and a
// hex salt of checkpointConfig.SaltLength random bytes. The process exits
// via log.Fatalf if the system CSPRNG fails.
func generateChallenge() (string, string) {
	randomBytes := make([]byte, 16)
	if _, err := cryptorand.Read(randomBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random challenge: %v", err)
	}
	saltBytes := make([]byte, checkpointConfig.SaltLength)
	if _, err := cryptorand.Read(saltBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random salt: %v", err)
	}
	return hex.EncodeToString(randomBytes), hex.EncodeToString(saltBytes)
}

// calculateTokenHash returns the hex SHA-256 of "nonce:entropy:createdUnixNano".
// The result is the token's key in the token database; combining Nonce,
// Entropy and the creation time makes it unique per issuance.
func calculateTokenHash(token CheckpointToken) string {
	data := fmt.Sprintf("%s:%s:%d", token.Nonce, token.Entropy, token.Created.UnixNano())
	hash := sha256.Sum256([]byte(data))
	return hex.EncodeToString(hash[:])
}

// RequestSanitizationMiddleware spots malicious patterns (SQLi, XSS, path traversal)
// and returns 403 immediately to keep your app safe.
func RequestSanitizationMiddleware() fiber.Handler { return func(c *fiber.Ctx) error { // Check URL path for directory traversal path := c.Path() if strings.Contains(path, "../") || strings.Contains(path, "..\\") { log.Printf("Security block: Directory traversal attempt in path: %s from IP: %s", path, getRealIP(c)) return c.Status(fiber.StatusForbidden).SendString("Forbidden") } // Check query parameters for malicious patterns query := c.Request().URI().QueryString() if len(query) > 0 { queryStr := string(query) // Check for dangerous characters if configured if checkpointConfig.BlockDangerousPathChars { if strings.Contains(queryStr, ";") || strings.Contains(queryStr, "\\") || strings.Contains(queryStr, "`") { log.Printf("Security block: Dangerous character in query from IP: %s, Query: %s", getRealIP(c), queryStr) return c.Status(fiber.StatusForbidden).SendString("Forbidden") } } // Check for configured attack patterns for _, pattern := range checkpointConfig.DangerousQueryPatterns { if pattern.MatchString(queryStr) { log.Printf("Security block: Malicious pattern match in query from IP: %s, Pattern: %s, Query: %s", getRealIP(c), pattern.String(), queryStr) return c.Status(fiber.StatusForbidden).SendString("Forbidden") } } } return c.Next() } }