// middleware provides a small proof-of-work puzzle that users solve before
// accessing protected pages or APIs, plus transparent reverse-proxy support.
// It issues HMAC-signed tokens bound to IP/browser, stores them in BadgerDB,
// and automatically cleans up expired data.
package middleware

import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/gob"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"html/template"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/dgraph-io/badger/v4"
	"github.com/gofiber/fiber/v2"
	"github.com/mileusna/useragent"
)

// --- Configuration ---

// Config holds all configurable parameters for the Checkpoint middleware.
type Config struct {
	// General settings
	Difficulty      int           // Number of leading zeros required in the PoW hash
	TokenExpiration time.Duration // Validity period for issued tokens
	CookieName      string        // Name of the cookie used to store tokens
	CookieDomain    string        // Domain scope for the cookie (e.g., ".example.com" for subdomains)
	SaltLength      int           // Length in bytes of the salt used in challenges

	// Rate limiting & expiration
	MaxAttemptsPerHour  int           // Max PoW verification attempts per IP per hour
	MaxNonceAge         time.Duration // Max age for used nonces before cleanup
	ChallengeExpiration time.Duration // Time limit for solving a challenge

	// File paths
	SecretConfigPath  string   // Path to the persistent HMAC secret file
	TokenStoreDBPath  string   // Directory path for the BadgerDB token store
	InterstitialPaths []string // Paths to search for the interstitial HTML page

	// Security settings
	CheckPoSTimes                    bool             // Enable Proof-of-Space-Time consistency checks
	PoSTimeConsistencyRatio          float64          // Allowed ratio between fastest and slowest PoS runs
	HTMLCheckpointExclusions         []string         // Path prefixes to exclude from the HTML checkpoint
	HTMLCheckpointExcludedExtensions map[string]bool  // File extensions to exclude (lowercase, with leading '.')
	DangerousQueryPatterns           []*regexp.Regexp // Regex patterns to block in query strings
	BlockDangerousPathChars          bool             // Block queries containing potentially dangerous characters (;, \, `)

	// User-Agent validation settings
	UserAgentValidationExclusions []string          // Path prefixes that skip UA validation
	UserAgentRequiredPrefixes     map[string]string // Path prefix -> required UA prefix
	// Note: binding tokens to IP, User-Agent, and browser hint is always enabled.

	// Reverse proxy settings
	ReverseProxyMappings map[string]string // Hostname -> backend URL (e.g., "app.example.com": "http://127.0.0.1:8080")
}
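
// A checkpoint.toml sketch with illustrative values. The key names and
// duration syntax here are assumptions: the exact on-disk format depends on
// the LoadConfig implementation, which lives elsewhere in this package.
//
//	Difficulty = 4
//	TokenExpiration = "12h"
//	CookieName = "checkpoint"
//	SaltLength = 16
//	MaxAttemptsPerHour = 60
//	[ReverseProxyMappings]
//	"app.example.com" = "http://127.0.0.1:8080"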

var (
	// Global configuration instance
	checkpointConfig Config

	// Secret key used for HMAC signing/verification - loaded or generated at startup
	hmacSecret []byte
	// Used nonces to prevent replay attacks - sync.Map for concurrent access
	usedNonces sync.Map // map[string]time.Time
	// IP-based rate limiting for token generation - sync.Map for concurrent access
	ipRateLimit sync.Map // map[string]*atomic.Int64
	// Challenge parameters keyed by request ID - sync.Map for concurrent access
	challengeStore sync.Map // map[string]ChallengeParams
	// Global token store (BadgerDB backed)
	tokenStore *TokenStore
	// In-memory cache for the interstitial HTML to avoid repeated disk reads
	interstitialContent string
	interstitialOnce    sync.Once
	interstitialLoadErr error
	// Parsed template for the interstitial page
	interstitialTmpl     *template.Template
	interstitialTmplOnce sync.Once
	interstitialTmplErr  error
	// Pool of gob encoding buffers to reduce allocations
	gobBufferPool = sync.Pool{
		New: func() interface{} { return new(bytes.Buffer) },
	}
)

func init() {
	// Load the complete configuration from checkpoint.toml (required)
	var cfg Config
	if err := LoadConfig("checkpoint", &cfg); err != nil {
		log.Fatalf("Failed to load checkpoint config: %v", err)
	}
	SetConfig(cfg)
	// Register the sanitization plugin (cleans up URLs/queries before the checkpoint)
	RegisterPlugin("sanitize", RequestSanitizationMiddleware)
	// Register the checkpoint plugin
	RegisterPlugin("checkpoint", New)

	// Initialize stores AFTER the config has been loaded.
	var err error
	tokenStore, err = NewTokenStore(checkpointConfig.TokenStoreDBPath)
	if err != nil {
		log.Fatalf("CRITICAL: Failed to initialize TokenStore database: %v", err)
	}

	// Initialize the HMAC secret (loaded from disk or freshly generated)
	initSecret()

	// Start the cleanup timer for nonces and IP rate limits
	// (token cleanup is handled by BadgerDB TTLs)
	startCleanupTimer()
}

// SecretConfig is the on-disk format for the persistent HMAC secret.
type SecretConfig struct {
	HmacSecret []byte    `json:"hmac_secret"`
	CreatedAt  time.Time `json:"created_at"`
	UpdatedAt  time.Time `json:"updated_at"`
}

// --- End Configuration ---

// SetConfig swaps in your custom Config (usually loaded from TOML).
// Do this before using the middleware, ideally at startup.
func SetConfig(cfg Config) {
	checkpointConfig = cfg
	// Re-initializing the token store path is complex with BadgerDB; a restart
	// is recommended when it changes. Other fields can be applied dynamically.
}
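
// A minimal sketch of manual configuration (illustrative values; in this
// package, init() normally populates the config from checkpoint.toml):
//
//	middleware.SetConfig(middleware.Config{
//		Difficulty:          4,
//		TokenExpiration:     12 * time.Hour,
//		CookieName:          "checkpoint",
//		SaltLength:          16,
//		MaxAttemptsPerHour:  60,
//		MaxNonceAge:         2 * time.Hour,
//		ChallengeExpiration: 5 * time.Minute,
//		SecretConfigPath:    "data/checkpoint_secret.json",
//		TokenStoreDBPath:    "data/checkpoint_tokens",
//	})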

// --- Token Store (BadgerDB Implementation) ---

// StoredTokenData holds the information persisted for each token hash,
// including the binding information needed for verification.
type StoredTokenData struct {
	ClientIPHash  string    // Hash of the IP used during issuance
	UserAgentHash string    // Hash of the User-Agent used during issuance
	BrowserHint   string    // Browser hint used during issuance
	LastVerified  time.Time // Last time this token was successfully validated
	ExpiresAt     time.Time // Original expiration time of the token (for reference; TTL enforces it)
}

// TokenStore manages persistent storage of verified tokens using BadgerDB.
type TokenStore struct {
	DB *badger.DB
}

// NewTokenStore initializes and returns a new TokenStore backed by BadgerDB.
func NewTokenStore(dbPath string) (*TokenStore, error) {
	if err := os.MkdirAll(dbPath, 0755); err != nil {
		return nil, fmt.Errorf("failed to create token store directory %s: %w", dbPath, err)
	}
	opts := badger.DefaultOptions(dbPath)
	// Tune options for performance if needed (e.g., memory usage)
	opts.Logger = nil // Disable the default Badger logger unless debugging
	db, err := badger.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("failed to open token store database at %s: %w", dbPath, err)
	}
	store := &TokenStore{DB: db}
	// Start BadgerDB's value log GC routine (optional but recommended)
	go store.runValueLogGC()
	return store, nil
}

// Close closes the BadgerDB database.
// It should be called during graceful shutdown.
func (store *TokenStore) Close() error {
	if store.DB != nil {
		log.Println("Closing TokenStore database...")
		return store.DB.Close()
	}
	return nil
}

// runValueLogGC runs BadgerDB's value log garbage collection periodically.
func (store *TokenStore) runValueLogGC() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
	again:
		err := store.DB.RunValueLogGC(0.7) // Run GC if at least 70% of a value log file can be reclaimed
		if err == nil {
			goto again // Badger's documented pattern: keep collecting while GC succeeds
		}
		if err != badger.ErrNoRewrite {
			log.Printf("WARNING: BadgerDB RunValueLogGC error: %v", err)
		}
	}
}

// encodeTokenData serializes StoredTokenData using gob.
func encodeTokenData(data *StoredTokenData) ([]byte, error) {
	// Get a buffer from the pool
	buf := gobBufferPool.Get().(*bytes.Buffer)
	buf.Reset()
	enc := gob.NewEncoder(buf)
	if err := enc.Encode(data); err != nil {
		gobBufferPool.Put(buf)
		return nil, fmt.Errorf("failed to gob encode token data: %w", err)
	}
	// Copy the bytes out so the pooled buffer is not retained
	out := make([]byte, buf.Len())
	copy(out, buf.Bytes())
	buf.Reset()
	gobBufferPool.Put(buf)
	return out, nil
}

// decodeTokenData deserializes StoredTokenData using gob.
func decodeTokenData(encoded []byte) (*StoredTokenData, error) {
	var data StoredTokenData
	// Use a reader to avoid an extra buffer allocation
	dec := gob.NewDecoder(bytes.NewReader(encoded))
	if err := dec.Decode(&data); err != nil {
		return nil, fmt.Errorf("failed to gob decode token data: %w", err)
	}
	return &data, nil
}

// addToken stores the token data in BadgerDB with a TTL.
func (store *TokenStore) addToken(tokenHash string, data *StoredTokenData) error {
	encodedData, err := encodeTokenData(data)
	if err != nil {
		return err // Error already wrapped
	}

	// Calculate the TTL from the token's specific expiration
	ttl := time.Until(data.ExpiresAt)
	if ttl <= 0 {
		log.Printf("Attempted to add already expired token hash %s", tokenHash)
		return nil // Don't add already expired tokens
	}

	err = store.DB.Update(func(txn *badger.Txn) error {
		e := badger.NewEntry([]byte(tokenHash), encodedData).WithTTL(ttl)
		return txn.SetEntry(e)
	})
	if err != nil {
		return fmt.Errorf("failed to add token hash %s to DB: %w", tokenHash, err)
	}
	return nil
}

// updateTokenVerification updates the LastVerified time for an existing token.
func (store *TokenStore) updateTokenVerification(tokenHash string) error {
	return store.DB.Update(func(txn *badger.Txn) error {
		item, err := txn.Get([]byte(tokenHash))
		if err != nil {
			// If the token expired or was deleted between check and update, log and ignore.
			if err == badger.ErrKeyNotFound {
				log.Printf("Token hash %s not found during update verification (likely expired/deleted)", tokenHash)
				return nil // Not a critical error in this context
			}
			return fmt.Errorf("failed to get token %s for update: %w", tokenHash, err)
		}

		var storedData *StoredTokenData
		err = item.Value(func(val []byte) error {
			storedData, err = decodeTokenData(val)
			return err
		})
		if err != nil {
			return fmt.Errorf("failed to decode token %s value for update: %w", tokenHash, err)
		}

		// Update LastVerified and re-encode
		storedData.LastVerified = time.Now()
		encodedData, err := encodeTokenData(storedData)
		if err != nil {
			return err
		}

		// Set the entry again; the TTL is still derived from the original ExpiresAt
		ttl := time.Until(storedData.ExpiresAt)
		if ttl <= 0 {
			return nil // Don't update an already expired token
		}
		e := badger.NewEntry([]byte(tokenHash), encodedData).WithTTL(ttl)
		return txn.SetEntry(e)
	})
}

// lookupTokenData retrieves token data from BadgerDB.
// It returns the data and true if found and not expired, or false otherwise.
// The context allows callers to cancel slow lookups.
func (store *TokenStore) lookupTokenData(ctx context.Context, tokenHash string) (*StoredTokenData, bool, error) {
	var storedData *StoredTokenData
	var found bool

	err := store.DB.View(func(txn *badger.Txn) error {
		// Check context cancellation within the transaction
		if ctx.Err() != nil {
			return ctx.Err()
		}
		item, err := txn.Get([]byte(tokenHash))
		if err != nil {
			if err == badger.ErrKeyNotFound {
				return nil // Not found; not an error for a lookup
			}
			return fmt.Errorf("failed to get token hash %s from DB: %w", tokenHash, err)
		}

		// Key exists, decode the value
		err = item.Value(func(val []byte) error {
			// Check context cancellation before decoding
			if ctx.Err() != nil {
				return ctx.Err()
			}
			var decodeErr error
			storedData, decodeErr = decodeTokenData(val)
			return decodeErr
		})
		if err != nil {
			// If the context was cancelled, return that error
			if ctx.Err() != nil {
				return ctx.Err()
			}
			// Return the actual decoding error
			return fmt.Errorf("failed to decode StoredTokenData for hash %s: %w", tokenHash, err)
		}

		// Check expiration explicitly in case the TTL mechanism has latency
		if time.Now().After(storedData.ExpiresAt) {
			log.Printf("Token hash %s found but expired (ExpiresAt: %v)", tokenHash, storedData.ExpiresAt)
			storedData = nil // Treat as not found if expired
			return nil
		}

		found = true
		return nil
	})
	if err != nil {
		// Don't log here; return the error to the caller (validateToken)
		return nil, false, err
	}

	return storedData, found, nil
}

// --- End Token Store ---

// CloseTokenStore closes the global token store.
// It should be called during application shutdown.
func CloseTokenStore() error {
	if tokenStore != nil {
		return tokenStore.Close()
	}
	return nil
}
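
// Graceful-shutdown sketch (assumes a Fiber app and OS signal handling in the
// caller's main; the identifiers below are illustrative, not part of this
// package):
//
//	sig := make(chan os.Signal, 1)
//	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
//	<-sig
//	_ = app.Shutdown()
//	if err := middleware.CloseTokenStore(); err != nil {
//		log.Printf("Error closing token store: %v", err)
//	}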

// loadInterstitialHTML returns the cached interstitial HTML (loaded once from disk).
func loadInterstitialHTML() (string, error) {
	interstitialOnce.Do(func() {
		for _, path := range checkpointConfig.InterstitialPaths {
			if data, err := os.ReadFile(path); err == nil {
				interstitialContent = string(data)
				return
			}
		}
		interstitialLoadErr = fmt.Errorf("could not find checkpoint interstitial HTML at any configured path")
	})
	return interstitialContent, interstitialLoadErr
}

// getInterstitialTemplate parses the cached HTML as a Go template (once).
func getInterstitialTemplate() (*template.Template, error) {
	interstitialTmplOnce.Do(func() {
		raw, err := loadInterstitialHTML()
		if err != nil {
			interstitialTmplErr = err
			return
		}
		interstitialTmpl, interstitialTmplErr = template.New("interstitial").Parse(raw)
	})
	return interstitialTmpl, interstitialTmplErr
}

// serveInterstitial serves the challenge page, using a Go template for safe interpolation.
func serveInterstitial(c *fiber.Ctx) error {
	requestID := generateRequestID(c)
	c.Status(fiber.StatusOK)
	c.Set("Content-Type", "text/html; charset=utf-8")
	tmpl, err := getInterstitialTemplate()
	if err != nil {
		log.Printf("WARNING: %v", err)
		return c.SendString("Security verification required. Please refresh the page.")
	}
	// Prepare data for the template
	host := c.Hostname()
	originalURL, _ := c.Locals("originalURL").(string)
	targetPath := c.Path()
	if originalURL != "" {
		targetPath = originalURL
	}
	data := struct {
		TargetPath string
		RequestID  string
		Host       string
		FullURL    string
	}{
		TargetPath: targetPath,
		RequestID:  requestID,
		Host:       host,
		FullURL:    c.BaseURL() + targetPath,
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		log.Printf("ERROR: Interstitial template execution failed: %v", err)
		return c.SendString("Security verification required. Please refresh the page.")
	}
	return c.SendString(buf.String())
}

// checkPoSTimes ensures that the memory-proof run times are within the allowed ratio.
func checkPoSTimes(times []int64) error {
	if len(times) != 3 {
		return fmt.Errorf("invalid PoS run times length")
	}
	minT, maxT := times[0], times[0]
	for _, t := range times[1:] {
		if t < minT {
			minT = t
		}
		if t > maxT {
			maxT = t
		}
	}
	if checkpointConfig.CheckPoSTimes && float64(maxT) > float64(minT)*checkpointConfig.PoSTimeConsistencyRatio {
		return fmt.Errorf("PoS run times ('i') are not consistent (ratio %.2f > %.2f)",
			float64(maxT)/float64(minT), checkpointConfig.PoSTimeConsistencyRatio)
	}
	return nil
}
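
// For example, with PoSTimeConsistencyRatio = 3.0, run times of {40, 55, 90}
// pass (90/40 = 2.25 <= 3.0) while {40, 55, 130} fail (130/40 = 3.25 > 3.0).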

// getDomainFromHost returns the base domain for a hostname, so cookies are
// scoped correctly in both production and development.
func getDomainFromHost(hostname string) string {
	// Handle localhost development
	if hostname == "localhost" || strings.HasPrefix(hostname, "localhost:") ||
		hostname == "127.0.0.1" || strings.HasPrefix(hostname, "127.0.0.1:") {
		return "" // Use host-only cookies for localhost
	}

	// For IP addresses, use host-only cookies
	if net.ParseIP(strings.Split(hostname, ":")[0]) != nil {
		return "" // IP address - use host-only
	}

	parts := strings.Split(hostname, ".")
	if len(parts) <= 1 {
		return hostname // Single-label host - unlikely in practice
	}

	// For standard domains, return the parent domain with a leading dot so the
	// cookie is shared across subdomains. Note: this naive "last two labels"
	// rule is wrong for multi-part public suffixes such as .co.uk; set
	// CookieDomain explicitly for those.
	return "." + parts[len(parts)-2] + "." + parts[len(parts)-1]
}
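
// Examples: "app.example.com" -> ".example.com"; "example.com" ->
// ".example.com"; "localhost:3000" and "10.0.0.5" -> "" (host-only cookie).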

// issueToken handles token generation, cookie setting, and the JSON response.
func issueToken(c *fiber.Ctx, token CheckpointToken) error {
	// 1. Generate the token hash
	tokenHash := calculateTokenHash(token)

	// 2. Create the data to store in the DB
	storedData := &StoredTokenData{
		ClientIPHash:  token.ClientIP, // The token struct already holds the hashed values
		UserAgentHash: token.UserAgent,
		BrowserHint:   token.BrowserHint,
		LastVerified:  token.LastVerified,
		ExpiresAt:     token.ExpiresAt, // Store the original expiration
	}

	// 3. Add to the database
	if err := tokenStore.addToken(tokenHash, storedData); err != nil {
		log.Printf("ERROR: Failed to store token in DB for hash %s: %v", tokenHash, err)
		// Not treated as fatal: the cookie is still issued, but validation will
		// fail its DB lookup until a token is successfully stored.
	}

	// 4. Sign the token (computeTokenSignature clears the Signature field on a
	// copy before marshalling, so the signature covers every other field)
	token.Signature = computeTokenSignature(token)

	// 5. Prepare the final token for the cookie
	finalBytes, err := json.Marshal(token)
	if err != nil {
		log.Printf("ERROR: Failed to marshal final token: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to prepare token"})
	}
	tokenStr := base64.StdEncoding.EncodeToString(finalBytes)

	// 6. Set the cookie. Only mark it Secure when serving over HTTPS, so
	// plain-HTTP development setups still receive it.
	isSecure := c.Protocol() == "https"

	// Get the cookie domain - either from config or auto-detected
	cookieDomain := checkpointConfig.CookieDomain
	if cookieDomain == "" {
		// Auto-detect - for development convenience
		cookieDomain = getDomainFromHost(c.Hostname())
	}

	// Set SameSite based on the domain - Lax allows cross-subdomain sharing
	// better than Strict
	sameSite := "Strict"
	if cookieDomain != "" {
		sameSite = "Lax"
	}

	c.Cookie(&fiber.Cookie{
		Name:     checkpointConfig.CookieName,
		Value:    tokenStr,
		Expires:  token.ExpiresAt, // The cookie expires when the token expires
		Path:     "/",
		Domain:   cookieDomain,
		HTTPOnly: true,
		SameSite: sameSite,
		Secure:   isSecure, // Only set Secure in HTTPS environments
	})

	return c.JSON(fiber.Map{"token": tokenStr, "expires_at": token.ExpiresAt})
}

// initSecret loads the HMAC secret from persistent storage, or generates and
// persists a new one.
func initSecret() {
	if _, err := os.Stat(checkpointConfig.SecretConfigPath); err == nil {
		// The config file exists; try to load it
		if loadedSecret := loadSecretFromFile(); loadedSecret != nil {
			hmacSecret = loadedSecret
			log.Printf("Loaded existing HMAC secret from %s", checkpointConfig.SecretConfigPath)
			return
		}
	}

	// No config file, or loading failed: generate a new secret
	hmacSecret = make([]byte, 32)
	if _, err := cryptorand.Read(hmacSecret); err != nil {
		// Critical security error - don't continue with insecure randomness
		log.Fatalf("CRITICAL: Could not generate secure random secret: %v", err)
	}

	// Ensure the data directory exists
	if err := os.MkdirAll(filepath.Dir(checkpointConfig.SecretConfigPath), 0755); err != nil {
		log.Printf("WARNING: Could not create data directory: %v", err)
		return
	}

	// Save the new secret to file
	config := SecretConfig{
		HmacSecret: hmacSecret,
		CreatedAt:  time.Now(),
		UpdatedAt:  time.Now(),
	}

	if configBytes, err := json.Marshal(config); err == nil {
		if err := os.WriteFile(checkpointConfig.SecretConfigPath, configBytes, 0600); err != nil {
			log.Printf("WARNING: Could not save HMAC secret to file: %v", err)
		} else {
			log.Printf("Created and saved new HMAC secret to %s", checkpointConfig.SecretConfigPath)
		}
	}
}

// loadSecretFromFile loads the HMAC secret from persistent storage.
func loadSecretFromFile() []byte {
	configBytes, err := os.ReadFile(checkpointConfig.SecretConfigPath)
	if err != nil {
		log.Printf("ERROR: Could not read secret config file: %v", err)
		return nil
	}

	var config SecretConfig
	if err := json.Unmarshal(configBytes, &config); err != nil {
		log.Printf("ERROR: Could not parse secret config file: %v", err)
		return nil
	}

	if len(config.HmacSecret) < 16 {
		log.Printf("ERROR: Secret from file is too short, generating a new one")
		return nil
	}

	// Record the time the secret was last loaded
	config.UpdatedAt = time.Now()
	if configBytes, err := json.Marshal(config); err == nil {
		if err := os.WriteFile(checkpointConfig.SecretConfigPath, configBytes, 0600); err != nil {
			log.Printf("WARNING: Could not update HMAC secret file: %v", err)
		}
	}

	return config.HmacSecret
}

// startCleanupTimer periodically cleans up the nonce and rate-limit maps.
func startCleanupTimer() {
	ticker := time.NewTicker(1 * time.Hour)
	go func() {
		for range ticker.C {
			cleanupExpiredData()
			cleanupExpiredChallenges()
		}
	}()
}

// cleanupExpiredData removes expired nonces and resets the IP rate limits.
func cleanupExpiredData() {
	// Clean up used nonces
	now := time.Now()
	expiredNonceCount := 0
	usedNonces.Range(func(key, value interface{}) bool {
		nonce := key.(string)
		timestamp := value.(time.Time)
		if now.Sub(timestamp) > checkpointConfig.MaxNonceAge {
			usedNonces.Delete(nonce)
			expiredNonceCount++
		}
		return true // continue iteration
	})
	if expiredNonceCount > 0 {
		log.Printf("Checkpoint: Cleaned up %d expired nonces.", expiredNonceCount)
	}

	// Reset IP rate limits every hour by deleting all entries
	ipRateLimit.Range(func(key, value interface{}) bool {
		ipRateLimit.Delete(key)
		return true
	})
	log.Println("Checkpoint: IP rate limits reset.")
}

// CheckpointToken represents a validated token. The JSON tags use the same
// short field codes as the client-side payload.
type CheckpointToken struct {
	Nonce        string    `json:"g"` // PoW nonce
	Challenge    string    `json:"-"` // Derived server-side, not in the token
	Salt         string    `json:"-"` // Derived server-side, not in the token
	Difficulty   int       `json:"-"` // Derived server-side, not in the token
	ExpiresAt    time.Time `json:"exp"`
	ClientIP     string    `json:"cip,omitempty"`
	UserAgent    string    `json:"ua,omitempty"`
	BrowserHint  string    `json:"bh,omitempty"`
	Entropy      string    `json:"ent,omitempty"`
	Created      time.Time `json:"crt"`
	LastVerified time.Time `json:"lvf,omitempty"`
	Signature    string    `json:"sig,omitempty"`
	TokenFormat  int       `json:"fmt"`
}
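
// A decoded token (the cookie value is the base64 of this JSON) looks roughly
// like the following; "cip"/"ua"/"bh" are truncated SHA-256 hex digests and
// all values here are illustrative:
//
//	{"g":"48291","exp":"2025-01-02T15:04:05Z","cip":"a1b2c3d4e5f60718",
//	 "ua":"0f1e2d3c4b5a6978","bh":"aabbccddeeff001122334455",
//	 "ent":"8899aabbccddeeff","crt":"2025-01-01T15:04:05Z",
//	 "lvf":"2025-01-01T15:04:05Z","sig":"<hmac-sha256 hex>","fmt":2}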

// ChallengeParams stores the parameters for one issued challenge.
type ChallengeParams struct {
	Challenge  string    `json:"challenge"` // Base64 encoded
	Salt       string    `json:"salt"`      // Base64 encoded
	Difficulty int       `json:"difficulty"`
	ExpiresAt  time.Time `json:"expires_at"`
	ClientIP   string    `json:"-"`
	PoSSeed    string    `json:"pos_seed"` // Hex encoded
}

// isExcludedHTMLPath checks whether a path should be excluded from the HTML
// checkpoint, based on configured path prefixes or file extensions.
func isExcludedHTMLPath(path string) bool {
	// 1. Check path prefixes
	for _, prefix := range checkpointConfig.HTMLCheckpointExclusions {
		if strings.HasPrefix(path, prefix) {
			return true // Excluded by prefix
		}
	}

	// 2. Check the file extension against the configured set
	ext := strings.ToLower(filepath.Ext(path))
	if ext != "" {
		if _, exists := checkpointConfig.HTMLCheckpointExcludedExtensions[ext]; exists {
			return true // Excluded by file extension
		}
	}

	// 3. Not excluded by prefix or extension: the path needs the checkpoint
	return false
}
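
// For example, with HTMLCheckpointExclusions = ["/static/"] and
// HTMLCheckpointExcludedExtensions = {".css": true, ".js": true},
// "/static/logo.png" and "/app.js" skip the checkpoint while "/account"
// does not (illustrative configuration).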

// DirectProxy returns a handler that simply forwards the request/response to targetURL.
// Headers, status codes, and body are passed through without modification.
func DirectProxy(targetURL string) fiber.Handler {
	target, err := url.Parse(targetURL)
	if err != nil {
		return func(c *fiber.Ctx) error {
			log.Printf("ERROR: Invalid target URL %s: %v", targetURL, err)
			return fiber.ErrBadGateway
		}
	}

	return func(c *fiber.Ctx) error {
		// Build the backend URL, preserving the path and query string
		backendURL := target.String() + c.Path()
		if qs := string(c.Request().URI().QueryString()); qs != "" {
			backendURL += "?" + qs
		}

		proxyReq, err := http.NewRequest(c.Method(), backendURL, bytes.NewReader(c.Body()))
		if err != nil {
			log.Printf("ERROR: Failed to create proxy request: %v", err)
			return fiber.ErrBadGateway
		}

		// Copy all headers from the Fiber context to the proxy request
		c.Request().Header.VisitAll(func(key, value []byte) {
			proxyReq.Header.Set(string(key), string(value))
		})

		// Add X-Forwarded headers so the backend sees the original client info
		proxyReq.Header.Set("X-Forwarded-Host", c.Hostname())
		proxyReq.Header.Set("X-Forwarded-Proto", c.Protocol())
		if v := proxyReq.Header.Get("X-Forwarded-For"); v != "" {
			proxyReq.Header.Set("X-Forwarded-For", v+", "+c.IP())
		} else {
			proxyReq.Header.Set("X-Forwarded-For", c.IP())
		}

		// Execute the proxy request
		proxyRes, err := http.DefaultClient.Do(proxyReq)
		if err != nil {
			log.Printf("ERROR: Proxy request failed: %v", err)
			return fiber.ErrBadGateway
		}
		defer proxyRes.Body.Close()

		// Copy all headers from the proxy response to Fiber's response
		for key, values := range proxyRes.Header {
			for _, value := range values {
				c.Response().Header.Add(key, value)
			}
		}

		// Set the status code
		c.Status(proxyRes.StatusCode)

		// Copy the body
		body, err := io.ReadAll(proxyRes.Body)
		if err != nil {
			log.Printf("ERROR: Failed to read proxy response body: %v", err)
			return fiber.ErrBadGateway
		}
		return c.Send(body)
	}
}
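
// Usage sketch (illustrative backend address; the backend receives the full
// original path and query string):
//
//	app.All("/*", middleware.DirectProxy("http://127.0.0.1:8080"))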

// isBlockedBot checks concurrently whether the User-Agent indicates a known
// bot or lacks a standard browser prefix. It returns true as soon as either
// check decides to block.
func isBlockedBot(userAgent string) bool {
	if userAgent == "" {
		// An empty User-Agent is suspicious; block it
		log.Printf("INFO: UA blocked - empty user agent")
		return true
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // Ensure the context is cancelled eventually

	resultChan := make(chan bool, 2) // Buffered channel for results

	// Goroutine 1: library-based bot check
	go func() {
		ua := useragent.Parse(userAgent)
		shouldBlock := ua.Bot
		if shouldBlock {
			log.Printf("INFO: UA blocked by library (Bot detected: %s): %s", ua.Name, userAgent)
		}
		select {
		case resultChan <- shouldBlock:
		case <-ctx.Done(): // Don't send if the context is cancelled
		}
	}()

	// Goroutine 2: prefix check
	go func() {
		// Standard browser User-Agent prefixes
		standardPrefixes := []string{"Mozilla/", "Opera/", "DuckDuckGo/", "Dart/"}
		hasStandardPrefix := false
		for _, prefix := range standardPrefixes {
			if strings.HasPrefix(userAgent, prefix) {
				hasStandardPrefix = true
				break
			}
		}

		// Block if it does NOT have a standard prefix
		shouldBlock := !hasStandardPrefix
		if shouldBlock {
			log.Printf("INFO: UA blocked by prefix check (no standard prefix): %s", userAgent)
		}
		select {
		case resultChan <- shouldBlock:
		case <-ctx.Done(): // Don't send if the context is cancelled
		}
	}()

	// Wait for results and decide
	result1 := <-resultChan
	if result1 {
		cancel() // Found a reason to block; cancel the other check
		return true
	}

	// The first check didn't block; wait for the second result
	result2 := <-resultChan
	// cancel() is deferred, so cleanup happens regardless
	return result2
}
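
// For example, "curl/8.5.0" and "python-requests/2.31.0" fail the prefix
// check, while "Mozilla/5.0 (X11; Linux x86_64) ... Firefox/115.0" passes
// both checks unless the useragent library classifies it as a bot.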

// New returns a Fiber handler that runs the PoW challenge (HTML/API) or
// proxies requests for hosts configured in ReverseProxyMappings.
func New() fiber.Handler {
	return func(c *fiber.Ctx) error {
		host := c.Hostname()
		targetURL, useProxy := checkpointConfig.ReverseProxyMappings[host]
		path := c.Path()

		// --- User-Agent validation ---
		// Only check the User-Agent if the path is not in the exclusion list
		skipUA := false
		for _, prefix := range checkpointConfig.UserAgentValidationExclusions {
			if strings.HasPrefix(path, prefix) {
				skipUA = true
				break
			}
		}

		if !skipUA {
			// First check required UA prefixes for specific paths
			for p, required := range checkpointConfig.UserAgentRequiredPrefixes {
				if strings.HasPrefix(path, p) {
					ua := c.Get("User-Agent")
					if !strings.HasPrefix(ua, required) {
						log.Printf("INFO: UA blocked by required prefix %s: %s", required, ua)
						if strings.HasPrefix(path, "/api") {
							return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
								"error":  "Access denied for automated clients.",
								"reason": "useragent",
							})
						}
						return c.Status(fiber.StatusForbidden).SendString("Access denied for automated clients.")
					}
					break
				}
			}

			// Then run the general bot check for all non-excluded paths
			userAgent := c.Get("User-Agent")
			if isBlockedBot(userAgent) {
				if strings.HasPrefix(path, "/api") {
					return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
						"error":  "Access denied for automated clients.",
						"reason": "useragent",
					})
				}
				return c.Status(fiber.StatusForbidden).SendString("Access denied for automated clients.")
			}
		}

		// Handle API endpoints
		if strings.HasPrefix(path, "/api") {
			// Always serve PoW endpoints locally (challenge & verify)
			if strings.HasPrefix(path, "/api/pow/") || strings.HasPrefix(path, "/api/verify") {
				log.Printf("API checkpoint endpoint %s - handling locally", path)
				return c.Next()
			}
			// Other API paths skip the checkpoint
			if useProxy {
				// Proxy to the backend for proxied hosts
				log.Printf("API proxying endpoint %s to %s", path, targetURL)
				return DirectProxy(targetURL)(c)
			}
			log.Printf("API endpoint %s - bypassing checkpoint", path)
			return c.Next()
		}

		// --- Reverse proxy logic ---
		if useProxy {
			// Check for an existing valid token cookie
			tokenCookie := c.Cookies(checkpointConfig.CookieName)
			log.Printf("Proxy: Checking token for host %s, path %s, cookie present: %v",
				host, path, tokenCookie != "")

			// Excluded paths (static assets, etc.) are proxied without a token check
			if isExcludedHTMLPath(path) {
				log.Printf("Excluded path %s for proxied host %s - proxying without token check", path, host)
				// Direct transparent proxy (preserves all headers/content types)
				return DirectProxy(targetURL)(c)
			}

			valid, err := validateToken(tokenCookie, c)
			if err != nil {
				// Log validation errors but treat the token as invalid for proxying
				log.Printf("Error validating token for proxied host %s, path %s: %v", host, path, err)
			}

			if valid {
				log.Printf("Valid token found for proxied host %s, path %s - forwarding request", host, path)
				// Direct transparent proxy (preserves all headers/content types)
				return DirectProxy(targetURL)(c)
			}

			log.Printf("No valid token for proxied host %s, path %s - serving interstitial", host, path)
			// Save the original full URL for redirection after verification
			c.Locals("originalURL", c.OriginalURL())
			// No valid token: serve the interstitial challenge page
			return serveInterstitial(c)
		}

		// --- Standard HTML/static logic (no proxy mapping) ---
		// Skip the checkpoint for excluded paths (e.g., static assets)
		if isExcludedHTMLPath(path) {
			return c.Next()
		}

		// --- Path needs the checkpoint (potential HTML page) ---
		tokenCookie := c.Cookies(checkpointConfig.CookieName)
		if tokenCookie != "" {
			valid, err := validateToken(tokenCookie, c)
			if err != nil {
				// Log validation errors but still serve the interstitial for safety
				log.Printf("Error validating token for path %s: %v", path, err)
			} else if valid {
				// Token is valid; proceed to the requested page/handler
				return c.Next()
			}
			// A present but invalid/expired token falls through to the interstitial
		}

		// No valid token found: serve the interstitial challenge page
		return serveInterstitial(c)
	}
}
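
// Wiring sketch: init() registers both plugins via RegisterPlugin (defined
// elsewhere in this package). With a plain Fiber app, the equivalent manual
// setup would be (illustrative only):
//
//	app := fiber.New()
//	app.Use(middleware.RequestSanitizationMiddleware())
//	app.Use(middleware.New())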

// generateRequestID creates a unique ID for this verification request and
// stores the associated challenge parameters.
func generateRequestID(c *fiber.Ctx) string {
	challenge, salt := generateChallenge()
	// Generate the PoS seed
	posSeedBytes := make([]byte, 32)
	if n, err := cryptorand.Read(posSeedBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate PoS seed: %v", err)
	} else if n != len(posSeedBytes) {
		log.Fatalf("CRITICAL: Short read generating PoS seed: read %d bytes", n)
	}
	posSeed := hex.EncodeToString(posSeedBytes)
	// Generate the request ID
	randBytes := make([]byte, 16)
	if n, err := cryptorand.Read(randBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate request ID: %v", err)
	} else if n != len(randBytes) {
		log.Fatalf("CRITICAL: Short read generating request ID: read %d bytes", n)
	}
	requestID := hex.EncodeToString(randBytes)

	// Base64-encode the hex challenge and salt for storage
	encodedChallenge := base64.StdEncoding.EncodeToString([]byte(challenge))
	encodedSalt := base64.StdEncoding.EncodeToString([]byte(salt))
	params := ChallengeParams{
		Challenge:  encodedChallenge,
		Salt:       encodedSalt,
		Difficulty: checkpointConfig.Difficulty,
		ExpiresAt:  time.Now().Add(checkpointConfig.ChallengeExpiration),
		ClientIP:   getRealIP(c),
		PoSSeed:    posSeed,
	}
	challengeStore.Store(requestID, params)
	return requestID
}

// cleanupExpiredChallenges removes challenges whose solve window has passed.
func cleanupExpiredChallenges() {
	now := time.Now()
	expiredChallengeCount := 0
	challengeStore.Range(func(key, value interface{}) bool {
		id := key.(string)
		params := value.(ChallengeParams)
		if now.After(params.ExpiresAt) {
			challengeStore.Delete(id)
			expiredChallengeCount++
		}
		return true // continue iteration
	})
	if expiredChallengeCount > 0 {
		log.Printf("Checkpoint: Cleaned up %d expired challenges.", expiredChallengeCount)
	}
}

// GetCheckpointChallengeHandler serves challenge parameters via the API.
func GetCheckpointChallengeHandler(c *fiber.Ctx) error {
	requestID := c.Query("id")
	if requestID == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Missing request ID"})
	}

	// Apply rate limiting to challenge generation
	clientIP := getRealIP(c)
	counterVal, _ := ipRateLimit.LoadOrStore(clientIP, new(atomic.Int64))
	ipCounter := counterVal.(*atomic.Int64)
	attempts := ipCounter.Add(1) // Increment and get the new value

	// Limit challenge requests per hour (reusing the MaxAttemptsPerHour setting)
	if attempts > int64(checkpointConfig.MaxAttemptsPerHour) {
		return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{"error": "Too many challenge requests. Please try again later."})
	}

	challengeVal, exists := challengeStore.Load(requestID)
	if !exists {
		return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "Challenge not found or expired"})
	}
	params := challengeVal.(ChallengeParams)

	if clientIP != params.ClientIP {
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	// Generate decoy fields to obscure the shape of the real payload
	decoySeedBytes := make([]byte, 8)
	if _, err := cryptorand.Read(decoySeedBytes); err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate challenge data"})
	}
	decoySeed := hex.EncodeToString(decoySeedBytes)
	decoyFields := make([]map[string]interface{}, 0)
	decoyFieldCount := 2 + int(decoySeedBytes[0])%3
	for i := 0; i < decoyFieldCount; i++ {
		nameLen := 5 + int(decoySeedBytes[i%8])%8
		valueLen := 8 + int(decoySeedBytes[(i+1)%8])%24
		name := randomHexString(nameLen)
		value := randomHexString(valueLen)
		decoyFields = append(decoyFields, map[string]interface{}{name: value})
	}
	return c.JSON(fiber.Map{
		"a": params.Challenge,  // challenge
		"b": params.Salt,       // salt
		"c": params.Difficulty, // difficulty
		"d": params.PoSSeed,    // pos_seed
		"e": decoySeed,         // decoy_seed
		"f": decoyFields,       // decoy_fields
	})
}

// randomHexString returns a cryptographically random hex string of length n.
func randomHexString(n int) string {
	b := make([]byte, (n+1)/2)
	if m, err := cryptorand.Read(b); err != nil {
		log.Fatalf("CRITICAL: Failed to generate random hex string: %v", err)
	} else if m != len(b) {
		log.Fatalf("CRITICAL: Short read generating random hex string: read %d bytes", m)
	}
	s := hex.EncodeToString(b)
	if len(s) < n {
		log.Fatalf("CRITICAL: Random hex string too short: got %d hex chars, want %d", len(s), n)
	}
	return s[:n]
}

// getFullClientIP returns a truncated SHA-256 hash of the client IP, used as
// the IP-binding value in tokens.
func getFullClientIP(c *fiber.Ctx) string {
	ip := getRealIP(c)
	if ip == "" {
		return "unknown"
	}
	h := sha256.Sum256([]byte(ip))
	return hex.EncodeToString(h[:8])
}

// hashUserAgent returns a truncated SHA-256 hash of the User-Agent string.
func hashUserAgent(userAgent string) string {
	if userAgent == "" {
		return ""
	}
	hash := sha256.Sum256([]byte(userAgent))
	return hex.EncodeToString(hash[:8])
}

// extractBrowserFingerprint hashes the client-hint headers into a short
// browser fingerprint, or returns "" when no hints are present.
func extractBrowserFingerprint(c *fiber.Ctx) string {
	headers := []string{
		c.Get("Sec-CH-UA"), c.Get("Sec-CH-UA-Platform"), c.Get("Sec-CH-UA-Mobile"),
		c.Get("Sec-CH-UA-Platform-Version"), c.Get("Sec-CH-UA-Arch"), c.Get("Sec-CH-UA-Model"),
	}
	var validHeaders []string
	for _, h := range headers {
		if h != "" {
			validHeaders = append(validHeaders, h)
		}
	}
	if len(validHeaders) == 0 {
		return ""
	}
	fingerprint := strings.Join(validHeaders, "|")
	hash := sha256.Sum256([]byte(fingerprint))
	return hex.EncodeToString(hash[:12])
}

// validateToken checks a token cookie end to end: decoding, signature,
// expiration, DB lookup, and binding verification. It returns (false, nil)
// for any merely invalid token, and a non-nil error only for system failures.
func validateToken(tokenStr string, c *fiber.Ctx) (bool, error) {
	// Explicitly handle the missing-token case first.
	if tokenStr == "" {
		return false, nil // No token cookie found; definitely not valid
	}

	// 1. Decode the token string from the cookie
	tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
	if err != nil {
		// Invalid Base64 encoding: an invalid token, not a system error
		return false, nil
	}

	// Check for an empty byte slice after decoding
	if len(tokenBytes) == 0 {
		return false, nil
	}

	// 2. Unmarshal
	var token CheckpointToken
	if err := json.Unmarshal(tokenBytes, &token); err != nil {
		// Invalid JSON structure: an invalid token, not a system error
		return false, nil
	}

	// 3. Basic expiration check based on the token's own ExpiresAt field
	if time.Now().After(token.ExpiresAt) {
		return false, nil // The token itself says it's expired
	}

	// 4. Check the token signature (format 2+ only)
	if token.TokenFormat < 2 {
		return false, nil // Old formats are not supported/secure
	}
	if !verifyTokenSignature(token) {
		return false, nil // Invalid signature
	}

	// 5. Calculate the token hash to look up in the database
	tokenHash := calculateTokenHash(token)

	// 6. Look up the token data in BadgerDB
	storedData, found, dbErr := tokenStore.lookupTokenData(c.Context(), tokenHash)
	if dbErr != nil {
		// An actual DB error during lookup IS a real error to return
		return false, fmt.Errorf("token DB lookup failed: %w", dbErr)
	}
	if !found {
		// Token hash not in the DB, or expired according to the DB record
		return false, nil
	}

	// 7. *** CRITICAL: verify bindings against stored data and the current request ***
	// Compare the client IP hash
	currentPartialIP := getFullClientIP(c)
	if storedData.ClientIPHash != currentPartialIP {
		return false, nil // IP mismatch
	}

	// Compare the User-Agent hash
	currentUserAgent := hashUserAgent(c.Get("User-Agent"))
	if storedData.UserAgentHash != currentUserAgent {
		return false, nil // User-Agent mismatch
	}

	// Compare the browser hint; only enforced when a hint was stored AND the
	// current request provides one
	currentBrowserHint := extractBrowserFingerprint(c)
	if storedData.BrowserHint != "" && currentBrowserHint != "" && storedData.BrowserHint != currentBrowserHint {
		return false, nil // Browser hint mismatch
	}

	// 8. All checks passed: the token is valid and bound correctly.
	// Update the LastVerified time in the database (best effort; log errors)
	if err := tokenStore.updateTokenVerification(tokenHash); err != nil {
		log.Printf("WARNING: Failed to update token verification time for hash %s: %v", tokenHash, err)
	}

	// Refresh the cookie. A sliding expiration window could update ExpiresAt
	// here; for simplicity we keep the original ExpiresAt and only bump
	// LastVerified.
	token.LastVerified = time.Now()
	updateTokenCookie(c, token) // Re-sign and set the cookie

	return true, nil
}

// updateTokenCookie re-signs the token and refreshes the cookie.
func updateTokenCookie(c *fiber.Ctx, token CheckpointToken) {
	// Only mark the cookie Secure when serving over HTTPS, so plain-HTTP
	// development setups still receive it.
	isSecure := c.Protocol() == "https"

	// Get the cookie domain - either from config or auto-detected
	cookieDomain := checkpointConfig.CookieDomain
	if cookieDomain == "" {
		// Auto-detect - for development convenience
		cookieDomain = getDomainFromHost(c.Hostname())
	}

	// Set SameSite based on the domain - Lax allows cross-subdomain sharing
	// better than Strict
	sameSite := "Strict"
	if cookieDomain != "" {
		sameSite = "Lax"
	}

	// Recompute the signature because LastVerified may have changed
	// (computeTokenSignature clears the old signature on a copy before signing)
	token.Signature = computeTokenSignature(token)

	finalTokenBytes, err := json.Marshal(token) // Marshal with the new signature
	if err != nil {
		log.Printf("Error marshaling token for cookie update: %v", err)
		return
	}
	tokenStr := base64.StdEncoding.EncodeToString(finalTokenBytes)
	c.Cookie(&fiber.Cookie{
		Name:     checkpointConfig.CookieName,
		Value:    tokenStr,
		Expires:  token.ExpiresAt, // Keep the original expiration
		Path:     "/",
		Domain:   cookieDomain,
		HTTPOnly: true,
		SameSite: sameSite,
		Secure:   isSecure, // Only set Secure in HTTPS environments
	})
}

// verifyProofOfWork checks that SHA-256(challenge+salt+nonce) has the
// required number of leading zero hex digits.
func verifyProofOfWork(challenge, salt, nonce string, difficulty int) bool {
	inputStr := challenge + salt + nonce
	hash := calculateHash(inputStr)
	prefix := strings.Repeat("0", difficulty)
	return strings.HasPrefix(hash, prefix)
}

// calculateHash returns the hex-encoded SHA-256 digest of the input.
func calculateHash(input string) string {
	hash := sha256.Sum256([]byte(input))
	return hex.EncodeToString(hash[:])
}
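
// Client-side solving sketch: the interstitial page brute-forces a nonce
// until the SHA-256 hex digest of challenge+salt+nonce has Difficulty leading
// zeros. Expressed in Go for illustration (the real solver runs in the
// browser):
//
//	func solve(challenge, salt string, difficulty int) string {
//		prefix := strings.Repeat("0", difficulty)
//		for i := 0; ; i++ {
//			nonce := strconv.Itoa(i)
//			if strings.HasPrefix(calculateHash(challenge+salt+nonce), prefix) {
//				return nonce
//			}
//		}
//	}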

// computeTokenSignature returns the hex-encoded HMAC-SHA256 of the token's
// JSON form with the Signature field cleared. The token is passed by value,
// so clearing the field does not affect the caller's copy.
func computeTokenSignature(token CheckpointToken) string {
	token.Signature = "" // Ensure the signature field is empty for signing
	tokenToSign, _ := json.Marshal(token)
	h := hmac.New(sha256.New, hmacSecret)
	h.Write(tokenToSign)
	return hex.EncodeToString(h.Sum(nil))
}

// verifyTokenSignature checks the token's HMAC signature in constant time.
func verifyTokenSignature(token CheckpointToken) bool {
	if token.Signature == "" {
		return false
	}
	expectedSignature := computeTokenSignature(token)
	return hmac.Equal([]byte(token.Signature), []byte(expectedSignature))
}

// VerifyCheckpointHandler verifies the challenge solution and issues a token.
func VerifyCheckpointHandler(c *fiber.Ctx) error {
	clientIP := getRealIP(c)

	var req CheckpointVerifyRequest
	if err := c.BodyParser(&req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request format"})
	}

	// Challenge lookup
	challengeVal, challengeExists := challengeStore.Load(req.RequestID)
	if !challengeExists {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}
	params := challengeVal.(ChallengeParams)

	if clientIP != params.ClientIP { // Check against the IP stored with the challenge
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	decodedChallenge := ""
	if decoded, err := base64.StdEncoding.DecodeString(params.Challenge); err == nil {
		decodedChallenge = string(decoded)
	} else {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode challenge"})
	}
	decodedSalt := ""
	if decoded, err := base64.StdEncoding.DecodeString(params.Salt); err == nil {
		decodedSalt = string(decoded)
	} else {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode salt"})
	}

	if req.Nonce == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Nonce ('g') required"})
	}

	// --- Nonce replay check ---
	nonceKey := req.Nonce + decodedChallenge
	if _, nonceExists := usedNonces.Load(nonceKey); nonceExists {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "This solution has already been used"})
	}
	// --- End nonce replay check ---

	if !verifyProofOfWork(decodedChallenge, decodedSalt, req.Nonce, params.Difficulty) {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid proof-of-work solution"})
	}

	// Store the used nonce (only after the PoW is verified)
	usedNonces.Store(nonceKey, time.Now())

	// Validate PoS hashes and times if provided
	if len(req.PoSHashes) == 3 && len(req.PoSTimes) == 3 {
		if req.PoSHashes[0] != req.PoSHashes[1] || req.PoSHashes[1] != req.PoSHashes[2] {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "PoS hashes ('h') do not match"})
		}
		if len(req.PoSHashes[0]) != 64 {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS hash ('h') length"})
		}
		if err := checkPoSTimes(req.PoSTimes); err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
		}
	} else if checkpointConfig.CheckPoSTimes && (len(req.PoSHashes) != 0 || len(req.PoSTimes) != 0) {
		// PoS checking is enabled but the wrong number of hashes/times was provided
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS data provided"})
	}

	// The challenge is solved; remove it from the store
	challengeStore.Delete(req.RequestID)

	entropyBytes := make([]byte, 8)
	if _, err := cryptorand.Read(entropyBytes); err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate secure token entropy"})
	}
	entropy := hex.EncodeToString(entropyBytes)

	// *** Gather current binding info for the new token ***
	now := time.Now()
	expiresAt := now.Add(checkpointConfig.TokenExpiration)
	browserHint := extractBrowserFingerprint(c)
	clientIPHash := getFullClientIP(c)
	userAgentHash := hashUserAgent(c.Get("User-Agent"))

	token := CheckpointToken{
		Nonce:        req.Nonce,
		ExpiresAt:    expiresAt,
		ClientIP:     clientIPHash,
		UserAgent:    userAgentHash,
		BrowserHint:  browserHint,
		Entropy:      entropy,
		Created:      now,
		LastVerified: now,
		TokenFormat:  2,
	}

	// Add a response header indicating success for the proxy
	c.Set("X-Checkpoint-Status", "success")
	log.Printf("Successfully verified challenge for IP %s, issuing token", clientIP)

	// Issue the token (handles DB storage, signing, and cookie setting)
	return issueToken(c, token)
}
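
// Route wiring sketch (the exact paths are assumptions; New() serves anything
// under /api/pow/ or /api/verify locally, so the handlers must be mounted
// there):
//
//	app.Get("/api/pow/challenge", middleware.GetCheckpointChallengeHandler)
//	app.Post("/api/pow/verify", middleware.VerifyCheckpointHandler)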

// CheckpointVerifyRequest is the request body for PoW verification. The JSON
// tags use the same short field codes as the client-side payload.
type CheckpointVerifyRequest struct {
	RequestID   string                   `json:"request_id"`
	Nonce       string                   `json:"g"`
	PoSHashes   []string                 `json:"h"`
	PoSTimes    []int64                  `json:"i"`
	DecoyHashes []string                 `json:"j"`
	DecoyTimes  []int64                  `json:"k"`
	DecoyFields []map[string]interface{} `json:"l"`
}

// generateChallenge returns a hex-encoded random challenge and salt.
func generateChallenge() (string, string) {
	randomBytes := make([]byte, 16)
	if _, err := cryptorand.Read(randomBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random challenge: %v", err)
	}
	saltBytes := make([]byte, checkpointConfig.SaltLength)
	if _, err := cryptorand.Read(saltBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random salt: %v", err)
	}
	return hex.EncodeToString(randomBytes), hex.EncodeToString(saltBytes)
}

// calculateTokenHash calculates a unique hash for storing the token status.
// IMPORTANT: this hash is used as the key in the database.
func calculateTokenHash(token CheckpointToken) string {
	// Hash the fields that identify this specific verification instance.
	// Nonce, Entropy, and creation time together ensure uniqueness per issuance.
	data := fmt.Sprintf("%s:%s:%d",
		token.Nonce,
		token.Entropy,
		token.Created.UnixNano())
	hash := sha256.Sum256([]byte(data))
	return hex.EncodeToString(hash[:])
}

// RequestSanitizationMiddleware blocks malicious patterns (SQLi, XSS, path
// traversal) with an immediate 403 before they reach the application.
func RequestSanitizationMiddleware() fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Check the URL path for directory traversal
		path := c.Path()
		if strings.Contains(path, "../") || strings.Contains(path, "..\\") {
			log.Printf("Security block: Directory traversal attempt in path: %s from IP: %s", path, getRealIP(c))
			return c.Status(fiber.StatusForbidden).SendString("Forbidden")
		}

		// Check query parameters for malicious patterns
		query := c.Request().URI().QueryString()
		if len(query) > 0 {
			queryStr := string(query)

			// Check for dangerous characters if configured
			if checkpointConfig.BlockDangerousPathChars {
				if strings.Contains(queryStr, ";") || strings.Contains(queryStr, "\\") || strings.Contains(queryStr, "`") {
					log.Printf("Security block: Dangerous character in query from IP: %s, Query: %s", getRealIP(c), queryStr)
					return c.Status(fiber.StatusForbidden).SendString("Forbidden")
				}
			}

			// Check the configured attack patterns
			for _, pattern := range checkpointConfig.DangerousQueryPatterns {
				if pattern.MatchString(queryStr) {
					log.Printf("Security block: Malicious pattern match in query from IP: %s, Pattern: %s, Query: %s",
						getRealIP(c), pattern.String(), queryStr)
					return c.Status(fiber.StatusForbidden).SendString("Forbidden")
				}
			}
		}

		return c.Next()
	}
}