1
0
Fork 0
This repository has been archived on 2025-05-26. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
Checkpoint-Golang/checkpoint_service/middleware/checkpoint.go
2025-05-26 12:42:36 -05:00

1482 lines
49 KiB
Go

// Package middleware provides a small proof-of-work puzzle that users solve before
// accessing protected pages or APIs, plus transparent reverse-proxy support.
// It issues HMAC-signed tokens bound to IP/browser, stores them in BadgerDB,
// and automatically cleans up expired data.
package middleware
import (
"context"
"crypto/hmac"
cryptorand "crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"bytes"
"encoding/gob"
"html/template"
"github.com/dgraph-io/badger/v4"
"github.com/gofiber/fiber/v2"
"github.com/mileusna/useragent"
)
// --- Configuration ---
// Config holds every tunable parameter of the Checkpoint middleware. It is
// loaded from checkpoint.toml in init and installed with SetConfig.
type Config struct {
// General Settings
Difficulty int // PoW difficulty copied into each challenge (presumably leading zeros required in the hash — confirm against the verifier)
TokenExpiration time.Duration // Validity period for issued tokens
CookieName string // Name of the cookie that carries the token
CookieDomain string // Cookie Domain (e.g. ".example.com" to share across subdomains); empty means auto-detect via getDomainFromHost
SaltLength int // Length of the salt used in challenges
// Rate Limiting & Expiration
MaxAttemptsPerHour int // Per-IP request budget; the challenge endpoint enforces it and the counters are cleared hourly by cleanupExpiredData
MaxNonceAge time.Duration // Used nonces older than this are purged by cleanupExpiredData
ChallengeExpiration time.Duration // Lifetime of an issued challenge (ChallengeParams.ExpiresAt = now + this)
// File Paths
SecretConfigPath string // JSON file persisting the HMAC secret (see SecretConfig)
TokenStoreDBPath string // Directory of the BadgerDB token store (created if missing)
InterstitialPaths []string // Candidate interstitial HTML files, tried in order; the first readable one is used
// Security Settings
CheckPoSTimes bool // Enable Proof-of-Space-Time run-time consistency checks
PoSTimeConsistencyRatio float64 // Maximum allowed slowest/fastest PoS run-time ratio
HTMLCheckpointExclusions []string // Path prefixes that bypass the HTML checkpoint
HTMLCheckpointExcludedExtensions map[string]bool // Extensions that bypass the checkpoint; keys are lowercase with a leading '.' (key presence is what counts)
DangerousQueryPatterns []*regexp.Regexp // Regex patterns to block in query strings
BlockDangerousPathChars bool // Block paths containing potentially dangerous characters (;, `)
// User Agent validation settings
UserAgentValidationExclusions []string // Path prefixes that skip all User-Agent validation
UserAgentRequiredPrefixes map[string]string // Path prefix -> User-Agent prefix required on matching paths
// Note: Binding to IP, User Agent, and Browser Hint is always enabled.
// Reverse Proxy Settings
ReverseProxyMappings map[string]string // Request hostname -> backend base URL (e.g., "app.example.com": "http://127.0.0.1:8080")
}
// Active configuration and signing key.
var (
	// checkpointConfig is the configuration installed by SetConfig.
	checkpointConfig Config

	// hmacSecret is the HMAC key, loaded from or generated into
	// checkpointConfig.SecretConfigPath by initSecret.
	hmacSecret []byte
)

// Concurrency-safe in-memory stores.
var (
	usedNonces     sync.Map // nonce (string) -> time.Time first used; pruned by cleanupExpiredData (replay protection)
	ipRateLimit    sync.Map // client IP (string) -> *atomic.Int64 request counter; cleared hourly
	challengeStore sync.Map // request ID (string) -> ChallengeParams; pruned by cleanupExpiredChallenges
)

// tokenStore is the BadgerDB-backed store of issued tokens, opened in init.
var tokenStore *TokenStore

// Interstitial page cache: the raw HTML and its parsed template are each
// loaded exactly once to avoid repeated disk reads and parsing.
var (
	interstitialContent string
	interstitialOnce    sync.Once
	interstitialLoadErr error

	interstitialTmpl     *template.Template
	interstitialTmplOnce sync.Once
	interstitialTmplErr  error
)

// gobBufferPool recycles scratch buffers used when gob-encoding token data.
var gobBufferPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}
// init wires up the package at import time:
//  1. loads checkpoint.toml and installs it with SetConfig (exits on failure);
//  2. registers the "sanitize" and "checkpoint" plugins;
//  3. opens the BadgerDB token store at the configured path (exits on failure);
//  4. loads or creates the HMAC secret;
//  5. starts the hourly nonce/rate-limit/challenge cleanup.
//
// The store and the secret read their paths from the configuration, so the
// configuration must be installed first.
func init() {
	var loaded Config
	if err := LoadConfig("checkpoint", &loaded); err != nil {
		log.Fatalf("Failed to load checkpoint config: %v", err)
	}
	SetConfig(loaded)

	RegisterPlugin("sanitize", RequestSanitizationMiddleware)
	RegisterPlugin("checkpoint", New)

	store, err := NewTokenStore(checkpointConfig.TokenStoreDBPath)
	if err != nil {
		log.Fatalf("CRITICAL: Failed to initialize TokenStore database: %v", err)
	}
	tokenStore = store

	// Both helpers always report true; their results carry no information.
	initSecret()
	startCleanupTimer()
}
// SecretConfig is the JSON document persisted at Config.SecretConfigPath that
// keeps the HMAC secret stable across restarts. initSecret writes it, and
// loadSecretFromFile reads it and saves it back with a refreshed UpdatedAt.
type SecretConfig struct {
HmacSecret []byte `json:"hmac_secret"` // Raw secret bytes (encoding/json stores []byte as base64); secrets shorter than 16 bytes are rejected on load
CreatedAt time.Time `json:"created_at"` // When the secret was generated
UpdatedAt time.Time `json:"updated_at"` // Last time the secret was generated or loaded
}
// --- End Configuration ---
// SetConfig installs cfg as the active Checkpoint configuration, typically
// right after loading it from TOML and before any request is handled.
//
// The token store is opened once with the TokenStoreDBPath in effect at
// startup and is not reopened here, so changing that path requires a restart.
// Other fields take effect on their next read.
func SetConfig(cfg Config) {
	checkpointConfig = cfg
}
// --- Token Store (BadgerDB Implementation) ---
// StoredTokenData is the record persisted in BadgerDB for each issued token,
// keyed by the token hash (see addToken). It holds the binding information
// recorded at issuance.
//
// Records are gob-encoded, and gob matches fields by name. Renaming or
// retyping a field breaks compatibility with records already on disk.
type StoredTokenData struct {
ClientIPHash string // Client IP binding; issueToken copies CheckpointToken.ClientIP here (hashed or raw depends on the caller — TODO confirm)
UserAgentHash string // User-Agent binding; copied from CheckpointToken.UserAgent (hashed or raw — TODO confirm)
BrowserHint string // Browser hint recorded at issuance
LastVerified time.Time // Last successful validation; refreshed by updateTokenVerification
ExpiresAt time.Time // Token expiry; drives the Badger TTL, and lookupTokenData also rejects records past it
}
// TokenStore manages persistent storage of verified tokens using BadgerDB.
//
// Create it with NewTokenStore, which also starts a background value-log GC
// loop. Call Close to stop that loop and close the database.
type TokenStore struct {
	DB *badger.DB

	// done is closed by Close to ask the GC goroutine to exit; stopped is
	// closed by the GC goroutine once it has exited. Both are nil for stores
	// not built by NewTokenStore, in which case no GC goroutine exists.
	done      chan struct{}
	stopped   chan struct{}
	closeOnce sync.Once
}

// NewTokenStore opens (creating the directory if needed) a BadgerDB database
// in dbPath and starts its periodic value-log garbage collection.
func NewTokenStore(dbPath string) (*TokenStore, error) {
	if err := os.MkdirAll(dbPath, 0755); err != nil {
		return nil, fmt.Errorf("failed to create token store directory %s: %w", dbPath, err)
	}
	opts := badger.DefaultOptions(dbPath)
	// Badger's default logger is chatty; silence it unless debugging.
	opts.Logger = nil
	db, err := badger.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("failed to open token store database at %s: %w", dbPath, err)
	}
	store := &TokenStore{
		DB:      db,
		done:    make(chan struct{}),
		stopped: make(chan struct{}),
	}
	go store.runValueLogGC()
	return store, nil
}

// Close stops the background GC goroutine, waits for it to exit, and then
// closes the BadgerDB database. Call it during graceful shutdown.
//
// Stopping the GC first ensures RunValueLogGC is never invoked on (or racing
// with) a closed database.
func (store *TokenStore) Close() error {
	store.closeOnce.Do(func() {
		if store.done != nil {
			close(store.done)
			if store.stopped != nil {
				<-store.stopped
			}
		}
	})
	if store.DB != nil {
		log.Println("Closing TokenStore database...")
		return store.DB.Close()
	}
	return nil
}

// runValueLogGC reclaims space in Badger's value log. Every five minutes it
// calls RunValueLogGC(0.7) repeatedly until Badger returns ErrNoRewrite
// (nothing more worth rewriting). Any other error is logged. The goroutine
// returns when store.done is closed and signals its exit by closing
// store.stopped.
func (store *TokenStore) runValueLogGC() {
	if store.stopped != nil {
		defer close(store.stopped)
	}
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-store.done: // nil channel (no NewTokenStore) never fires
			return
		case <-ticker.C:
		}
		for {
			err := store.DB.RunValueLogGC(0.7)
			if err == nil {
				continue // a file was rewritten; there may be more to reclaim
			}
			if err != badger.ErrNoRewrite {
				log.Printf("WARNING: BadgerDB RunValueLogGC error: %v", err)
			}
			break
		}
	}
}
// encodeTokenData gob-encodes data and returns the bytes in a freshly
// allocated slice. It encodes into a pooled scratch buffer and copies the
// result out, so the buffer can be returned to the pool and safely reused.
func encodeTokenData(data *StoredTokenData) ([]byte, error) {
	scratch := gobBufferPool.Get().(*bytes.Buffer)
	scratch.Reset()
	defer func() {
		scratch.Reset()
		gobBufferPool.Put(scratch)
	}()

	if err := gob.NewEncoder(scratch).Encode(data); err != nil {
		return nil, fmt.Errorf("failed to gob encode token data: %w", err)
	}

	encoded := make([]byte, scratch.Len())
	copy(encoded, scratch.Bytes())
	return encoded, nil
}
// decodeTokenData is the inverse of encodeTokenData: it gob-decodes encoded
// into a new StoredTokenData. On failure it returns nil and a wrapped error.
func decodeTokenData(encoded []byte) (*StoredTokenData, error) {
	out := new(StoredTokenData)
	if err := gob.NewDecoder(bytes.NewReader(encoded)).Decode(out); err != nil {
		return nil, fmt.Errorf("failed to gob decode token data: %w", err)
	}
	return out, nil
}
// addToken persists data under tokenHash with a Badger TTL equal to the time
// remaining until data.ExpiresAt. A token that has already expired is logged
// and skipped, and nil is returned.
func (store *TokenStore) addToken(tokenHash string, data *StoredTokenData) error {
	payload, err := encodeTokenData(data)
	if err != nil {
		return err // already wrapped by encodeTokenData
	}

	remaining := time.Until(data.ExpiresAt)
	if remaining <= 0 {
		log.Printf("Attempted to add already expired token hash %s", tokenHash)
		return nil
	}

	entry := badger.NewEntry([]byte(tokenHash), payload).WithTTL(remaining)
	if err := store.DB.Update(func(txn *badger.Txn) error {
		return txn.SetEntry(entry)
	}); err != nil {
		return fmt.Errorf("failed to add token hash %s to DB: %w", tokenHash, err)
	}
	return nil
}
// updateTokenVerification sets LastVerified to now on the stored record for
// tokenHash. It rewrites the record with a TTL that still ends at the original
// ExpiresAt.
//
// A missing key (expired or deleted since it was checked) is logged and
// treated as success. A record whose ExpiresAt has already passed is left
// untouched.
func (store *TokenStore) updateTokenVerification(tokenHash string) error {
	return store.DB.Update(func(txn *badger.Txn) error {
		item, getErr := txn.Get([]byte(tokenHash))
		switch {
		case getErr == badger.ErrKeyNotFound:
			log.Printf("Token hash %s not found during update verification (likely expired/deleted)", tokenHash)
			return nil
		case getErr != nil:
			return fmt.Errorf("failed to get token %s for update: %w", tokenHash, getErr)
		}

		var record *StoredTokenData
		valErr := item.Value(func(val []byte) error {
			var decErr error
			record, decErr = decodeTokenData(val)
			return decErr
		})
		if valErr != nil {
			return fmt.Errorf("failed to decode token %s value for update: %w", tokenHash, valErr)
		}

		record.LastVerified = time.Now()
		payload, encErr := encodeTokenData(record)
		if encErr != nil {
			return encErr
		}

		remaining := time.Until(record.ExpiresAt)
		if remaining <= 0 {
			return nil // expired: nothing worth rewriting
		}
		return txn.SetEntry(badger.NewEntry([]byte(tokenHash), payload).WithTTL(remaining))
	})
}
// lookupTokenData fetches and decodes the stored record for tokenHash. It
// returns:
//   - (record, true, nil) for a live record;
//   - (nil, false, nil) when the key is absent or the record's ExpiresAt has passed;
//   - (nil, false, err) on a database or decoding error, or when ctx is cancelled.
func (store *TokenStore) lookupTokenData(ctx context.Context, tokenHash string) (*StoredTokenData, bool, error) {
	var (
		record *StoredTokenData
		live   bool
	)
	viewErr := store.DB.View(func(txn *badger.Txn) error {
		if err := ctx.Err(); err != nil {
			return err
		}
		item, err := txn.Get([]byte(tokenHash))
		if err == badger.ErrKeyNotFound {
			return nil // absence is not an error for a lookup
		}
		if err != nil {
			return fmt.Errorf("failed to get token hash %s from DB: %w", tokenHash, err)
		}

		decodeValue := func(val []byte) error {
			if err := ctx.Err(); err != nil {
				return err
			}
			var decErr error
			record, decErr = decodeTokenData(val)
			return decErr
		}
		if err := item.Value(decodeValue); err != nil {
			// Report cancellation as such rather than as a decode failure.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return fmt.Errorf("failed to decode StoredTokenData for hash %s: %w", tokenHash, err)
		}

		// Badger's TTL expiry may lag, so enforce ExpiresAt explicitly.
		if time.Now().After(record.ExpiresAt) {
			log.Printf("Token hash %s found but expired (ExpiresAt: %v)", tokenHash, record.ExpiresAt)
			record = nil
			return nil
		}
		live = true
		return nil
	})
	if viewErr != nil {
		// Logging is left to the caller.
		return nil, false, viewErr
	}
	return record, live, nil
}
// --- End Token Store ---
// CloseTokenStore closes the package-level token store, if one was opened.
// Call it during application shutdown. It returns nil when there is no store.
func CloseTokenStore() error {
	if tokenStore == nil {
		return nil
	}
	return tokenStore.Close()
}
// loadInterstitialHTML returns the interstitial page HTML. On first use it
// reads the first readable file in checkpointConfig.InterstitialPaths. The
// outcome (content or error) is cached for the life of the process, so a
// missing file is not retried.
func loadInterstitialHTML() (string, error) {
	interstitialOnce.Do(func() {
		interstitialLoadErr = fmt.Errorf("could not find checkpoint interstitial HTML at any configured path")
		for _, candidate := range checkpointConfig.InterstitialPaths {
			data, err := os.ReadFile(candidate)
			if err != nil {
				continue
			}
			interstitialContent, interstitialLoadErr = string(data), nil
			break
		}
	})
	return interstitialContent, interstitialLoadErr
}
// getInterstitialTemplate parses the interstitial HTML as an html/template
// exactly once. It returns the cached template, or the cached load/parse error.
func getInterstitialTemplate() (*template.Template, error) {
	interstitialTmplOnce.Do(func() {
		raw, err := loadInterstitialHTML()
		if err == nil {
			interstitialTmpl, err = template.New("interstitial").Parse(raw)
		}
		interstitialTmplErr = err
	})
	return interstitialTmpl, interstitialTmplErr
}
// serveInterstitial renders the challenge page with status 200.
//
// Each call registers a fresh challenge via generateRequestID. The
// interstitial template is then executed with the request ID, host and target
// URL; html/template escapes these values for HTML context. The target is the
// "originalURL" local when set, otherwise the request path. If the template
// cannot be loaded or executed, a plain-text fallback message is sent instead.
func serveInterstitial(c *fiber.Ctx) error {
	const fallback = "Security verification required. Please refresh the page."

	requestID := generateRequestID(c)
	c.Status(200)
	c.Set("Content-Type", "text/html; charset=utf-8")

	tmpl, err := getInterstitialTemplate()
	if err != nil {
		log.Printf("WARNING: %v", err)
		return c.SendString(fallback)
	}

	targetPath := c.Path()
	if originalURL, _ := c.Locals("originalURL").(string); originalURL != "" {
		targetPath = originalURL
	}

	view := struct {
		TargetPath string
		RequestID  string
		Host       string
		FullURL    string
	}{
		TargetPath: targetPath,
		RequestID:  requestID,
		Host:       c.Hostname(),
		FullURL:    c.BaseURL() + targetPath,
	}

	var page bytes.Buffer
	if err := tmpl.Execute(&page, view); err != nil {
		log.Printf("ERROR: Interstitial template execution failed: %v", err)
		return c.SendString(fallback)
	}
	return c.SendString(page.String())
}
// checkPoSTimes validates the proof-of-space run times reported by a client.
// The slice must contain exactly three entries. When
// checkpointConfig.CheckPoSTimes is enabled, the slowest run may take at most
// PoSTimeConsistencyRatio times as long as the fastest one.
func checkPoSTimes(times []int64) error {
	if len(times) != 3 {
		return fmt.Errorf("invalid PoS run times length")
	}
	if !checkpointConfig.CheckPoSTimes {
		return nil
	}

	fastest, slowest := times[0], times[0]
	for _, t := range times[1:] {
		switch {
		case t < fastest:
			fastest = t
		case t > slowest:
			slowest = t
		}
	}

	limit := float64(fastest) * checkpointConfig.PoSTimeConsistencyRatio
	if float64(slowest) > limit {
		return fmt.Errorf("PoS run times ('i') are not consistent (ratio %.2f > %.2f)",
			float64(slowest)/float64(fastest), checkpointConfig.PoSTimeConsistencyRatio)
	}
	return nil
}
// getDomainFromHost derives a cookie Domain attribute from a request host.
//
// hostname may include a port ("app.example.com:3000") or be a bracketed IPv6
// literal ("[::1]:8080"); the port and brackets are stripped first, along with
// a trailing FQDN dot. Results:
//   - "" (host-only cookie) for localhost and IP addresses;
//   - the host itself for single-label names (e.g. "intranet");
//   - "." + the last two labels otherwise ("app.example.com" -> ".example.com").
//
// Limitation: without a public-suffix list, hosts under multi-label suffixes
// (e.g. "shop.example.co.uk" -> ".co.uk") yield a domain browsers will reject.
// Set Config.CookieDomain explicitly for such deployments.
func getDomainFromHost(hostname string) string {
	host := hostname
	if h, _, err := net.SplitHostPort(hostname); err == nil {
		host = h
	}
	// Bare bracketed IPv6 without a port, and trailing-dot FQDNs.
	host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	host = strings.TrimSuffix(host, ".")

	if host == "localhost" || net.ParseIP(host) != nil {
		return ""
	}

	labels := strings.Split(host, ".")
	if len(labels) < 2 {
		return host // single-label name: nothing broader to share with
	}
	// Leading dot: share the cookie with the parent domain and its subdomains.
	return "." + labels[len(labels)-2] + "." + labels[len(labels)-1]
}
// issueToken records a verified token, signs it, sets it as the checkpoint
// cookie and replies with JSON {"token": <base64>, "expires_at": <time>}.
//
// Steps:
//  1. Store the token's binding data in the token store under its hash. A
//     storage failure is logged but does not stop issuance.
//  2. Sign the JSON encoding of the token with Signature cleared.
//  3. Base64-encode the signed JSON.
//  4. Set an HTTP-only cookie. Secure is set unless the request is plain
//     http. Domain comes from config or getDomainFromHost. SameSite is Lax
//     when a Domain is set and Strict for host-only cookies.
func issueToken(c *fiber.Ctx, token CheckpointToken) error {
	hash := calculateTokenHash(token)
	record := &StoredTokenData{
		ClientIPHash:  token.ClientIP, // token binding fields are populated by the caller
		UserAgentHash: token.UserAgent,
		BrowserHint:   token.BrowserHint,
		LastVerified:  token.LastVerified,
		ExpiresAt:     token.ExpiresAt,
	}
	if err := tokenStore.addToken(hash, record); err != nil {
		// Deliberately non-fatal: the signed cookie is still issued.
		log.Printf("ERROR: Failed to store token in DB for hash %s: %v", hash, err)
	}

	// The signature covers the JSON form with an empty Signature field.
	token.Signature = ""
	unsigned, _ := json.Marshal(token)
	token.Signature = computeTokenSignature(token, unsigned)

	signed, err := json.Marshal(token)
	if err != nil {
		log.Printf("ERROR: Failed to marshal final token: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to prepare token"})
	}
	encoded := base64.StdEncoding.EncodeToString(signed)

	plainHTTP := strings.HasPrefix(c.Protocol(), "http") && !strings.HasPrefix(c.BaseURL(), "https")

	domain := checkpointConfig.CookieDomain
	if domain == "" {
		domain = getDomainFromHost(c.Hostname())
	}
	// Lax lets the cookie travel across subdomains more reliably than Strict.
	sameSite := "Lax"
	if domain == "" {
		sameSite = "Strict"
	}

	c.Cookie(&fiber.Cookie{
		Name:     checkpointConfig.CookieName,
		Value:    encoded,
		Expires:  token.ExpiresAt,
		Path:     "/",
		Domain:   domain,
		HTTPOnly: true,
		SameSite: sameSite,
		Secure:   !plainHTTP,
	})
	return c.JSON(fiber.Map{"token": encoded, "expires_at": token.ExpiresAt})
}
// initSecret sets hmacSecret. It prefers the secret persisted at
// checkpointConfig.SecretConfigPath.
//
// If that file is absent or unusable, a new 32-byte secret is drawn from
// crypto/rand; the process exits if that fails. The new secret is then saved
// best-effort as JSON with mode 0600, creating the parent directory if
// needed. The function always returns true.
func initSecret() bool {
	path := checkpointConfig.SecretConfigPath

	if _, statErr := os.Stat(path); statErr == nil {
		if existing := loadSecretFromFile(); existing != nil {
			hmacSecret = existing
			log.Printf("Loaded existing HMAC secret from %s", path)
			return true
		}
	}

	fresh := make([]byte, 32)
	hmacSecret = fresh
	if _, err := cryptorand.Read(fresh); err != nil {
		// Never continue with a secret that is not securely random.
		log.Fatalf("CRITICAL: Could not generate secure random secret: %v", err)
	}

	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		log.Printf("WARNING: Could not create data directory: %v", err)
		return true
	}

	payload, err := json.Marshal(SecretConfig{
		HmacSecret: hmacSecret,
		CreatedAt:  time.Now(),
		UpdatedAt:  time.Now(),
	})
	if err != nil {
		return true
	}
	if err := os.WriteFile(path, payload, 0600); err != nil {
		log.Printf("WARNING: Could not save HMAC secret to file: %v", err)
		return true
	}
	log.Printf("Created and saved new HMAC secret to %s", path)
	return true
}
// loadSecretFromFile reads the SecretConfig JSON at
// checkpointConfig.SecretConfigPath and returns its HMAC secret.
//
// It returns nil (after logging) if the file cannot be read or parsed, or if
// the secret is shorter than 16 bytes. On success it also rewrites the file,
// best-effort, with UpdatedAt set to now.
func loadSecretFromFile() []byte {
	path := checkpointConfig.SecretConfigPath

	raw, err := os.ReadFile(path)
	if err != nil {
		log.Printf("ERROR: Could not read secret config file: %v", err)
		return nil
	}

	var stored SecretConfig
	if err := json.Unmarshal(raw, &stored); err != nil {
		log.Printf("ERROR: Could not parse secret config file: %v", err)
		return nil
	}
	if len(stored.HmacSecret) < 16 {
		log.Printf("ERROR: Secret from file is too short, generating a new one")
		return nil
	}

	stored.UpdatedAt = time.Now()
	if refreshed, err := json.Marshal(stored); err == nil {
		if writeErr := os.WriteFile(path, refreshed, 0600); writeErr != nil {
			log.Printf("WARNING: Could not update HMAC secret file: %v", writeErr)
		}
	}
	return stored.HmacSecret
}
// startCleanupTimer launches a background goroutine that, once an hour,
// purges stale nonces, resets per-IP rate limits and drops expired
// challenges. The goroutine (and its ticker) live for the rest of the
// process. The function always returns true.
func startCleanupTimer() bool {
	go func(tick <-chan time.Time) {
		for range tick {
			cleanupExpiredData()
			cleanupExpiredChallenges()
		}
	}(time.NewTicker(time.Hour).C)
	return true
}
// cleanupExpiredData removes nonces older than checkpointConfig.MaxNonceAge
// from usedNonces. It then deletes every per-IP counter in ipRateLimit, which
// is what turns the rate limit into an hourly budget when run by
// startCleanupTimer.
func cleanupExpiredData() {
	now := time.Now()
	removed := 0
	usedNonces.Range(func(k, v any) bool {
		nonce := k.(string)
		usedAt := v.(time.Time)
		if now.Sub(usedAt) > checkpointConfig.MaxNonceAge {
			usedNonces.Delete(nonce)
			removed++
		}
		return true
	})
	if removed > 0 {
		log.Printf("Checkpoint: Cleaned up %d expired nonces.", removed)
	}

	ipRateLimit.Range(func(k, _ any) bool {
		ipRateLimit.Delete(k)
		return true
	})
	log.Println("Checkpoint: IP rate limits reset.")
}
// CheckpointToken is the token handed to a client after a solved challenge.
// issueToken serializes it as base64(JSON) into the checkpoint cookie. The
// JSON tag names are the wire format, so changing them invalidates existing
// tokens.
//
// Note: encoding/json's omitempty has no effect on struct types such as
// time.Time, so "lvf" is always emitted even when LastVerified is zero.
type CheckpointToken struct {
Nonce string `json:"g"` // Nonce
Challenge string `json:"-"` // Derived server-side, not in token
Salt string `json:"-"` // Derived server-side, not in token
Difficulty int `json:"-"` // Derived server-side, not in token
ExpiresAt time.Time `json:"exp"` // Token expiry; also used as the cookie expiry and the DB TTL
ClientIP string `json:"cip,omitempty"` // IP binding (issueToken stores it as StoredTokenData.ClientIPHash; hashed or raw — TODO confirm)
UserAgent string `json:"ua,omitempty"` // User-Agent binding (stored as StoredTokenData.UserAgentHash)
BrowserHint string `json:"bh,omitempty"` // Browser hint binding
Entropy string `json:"ent,omitempty"`
Created time.Time `json:"crt"` // Issuance time
LastVerified time.Time `json:"lvf,omitempty"` // Last successful validation
Signature string `json:"sig,omitempty"` // Signature over the JSON encoding with this field empty (see issueToken)
TokenFormat int `json:"fmt"` // Token format version
}
// ChallengeParams is the server-side record of an issued challenge, stored in
// challengeStore under its request ID by generateRequestID.
type ChallengeParams struct {
Challenge string `json:"challenge"` // Base64 of the challenge string from generateChallenge
Salt string `json:"salt"` // Base64 of the salt string from generateChallenge
Difficulty int `json:"difficulty"` // Copied from Config.Difficulty at issuance
ExpiresAt time.Time `json:"expires_at"` // Issuance time + Config.ChallengeExpiration; pruned by cleanupExpiredChallenges
ClientIP string `json:"-"` // Requester IP (getRealIP); never serialized, compared against the IP fetching the challenge
PoSSeed string `json:"pos_seed"` // Hex encoding of 32 random bytes seeding the proof-of-space
}
// isExcludedHTMLPath reports whether path bypasses the HTML checkpoint. A path
// is excluded if it starts with any of
// checkpointConfig.HTMLCheckpointExclusions, or if its lowercased extension
// (with the leading dot) is a key of HTMLCheckpointExcludedExtensions. Key
// presence is what counts; the bool value is not consulted.
func isExcludedHTMLPath(path string) bool {
	for _, prefix := range checkpointConfig.HTMLCheckpointExclusions {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}

	ext := strings.ToLower(filepath.Ext(path))
	if ext == "" {
		return false
	}
	_, excluded := checkpointConfig.HTMLCheckpointExcludedExtensions[ext]
	return excluded
}
// DirectProxy returns a Fiber handler that forwards each request to targetURL
// and relays the backend's status code, headers and body back unchanged.
//
// Forwarding rules:
//   - The request URI is forwarded intact (path and query string), joined onto
//     any path in targetURL.
//   - Every header value is copied, including repeated headers.
//   - X-Forwarded-Host, X-Forwarded-Proto and X-Forwarded-For are set so the
//     backend can see the original request.
//   - The Host sent upstream is the backend's own host.
//   - Backend redirects are passed back to the client rather than followed
//     server-side, and each round trip is bounded by a timeout.
//
// The response body is buffered in memory, so this is not suitable for
// streaming responses. An unparsable targetURL yields a handler that always
// returns 502.
func DirectProxy(targetURL string) fiber.Handler {
	target, err := url.Parse(targetURL)
	if err != nil {
		return func(c *fiber.Ctx) error {
			log.Printf("ERROR: Invalid target URL %s: %v", targetURL, err)
			return fiber.ErrBadGateway
		}
	}

	// The ReverseProxy is used only for its Director, which rewrites a request
	// onto target (scheme/host, path joining, query merging) and adds the
	// X-Forwarded headers.
	proxy := httputil.NewSingleHostReverseProxy(target)
	originalDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		originalDirector(req)
		req.Header.Set("X-Forwarded-Host", req.Host)
		if v := req.Header.Get("X-Forwarded-For"); v != "" {
			req.Header.Set("X-Forwarded-For", v+", "+req.RemoteAddr)
		} else {
			req.Header.Set("X-Forwarded-For", req.RemoteAddr)
		}
	}

	// A reverse proxy must hand redirects back to the browser instead of
	// following them itself, and must not hang forever on a stuck backend.
	// A nil Transport shares http.DefaultTransport's connection pool.
	client := &http.Client{
		Timeout: 60 * time.Second,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	return func(c *fiber.Ctx) error {
		// OriginalURL is the raw request URI (path + query); the Director
		// turns it into an absolute URL on target.
		proxyReq, err := http.NewRequest(c.Method(), c.OriginalURL(), bytes.NewReader(c.Body()))
		if err != nil {
			log.Printf("ERROR: Failed to create proxy request: %v", err)
			return fiber.ErrBadGateway
		}

		// Add, not Set, so repeated headers keep every value.
		c.Request().Header.VisitAll(func(key, value []byte) {
			proxyReq.Header.Add(string(key), string(value))
		})

		// Populate the fields the Director reads for X-Forwarded-*.
		proxyReq.Host = string(c.Request().Host())
		proxyReq.RemoteAddr = c.IP()
		proxy.Director(proxyReq)
		proxyReq.Header.Set("X-Forwarded-Proto", c.Protocol())
		// Clear Host so the outgoing Host header is the backend's own host.
		proxyReq.Host = ""

		proxyRes, err := client.Do(proxyReq)
		if err != nil {
			log.Printf("ERROR: Proxy request failed: %v", err)
			return fiber.ErrBadGateway
		}
		defer proxyRes.Body.Close()

		for key, values := range proxyRes.Header {
			for _, value := range values {
				c.Response().Header.Add(key, value)
			}
		}
		c.Status(proxyRes.StatusCode)

		body, err := io.ReadAll(proxyRes.Body)
		if err != nil {
			log.Printf("ERROR: Failed to read proxy response body: %v", err)
			return fiber.ErrBadGateway
		}
		return c.Send(body)
	}
}
// isBlockedBot checks concurrently if the User-Agent indicates a known bot
// or doesn't have a standard browser prefix.
// It returns true as soon as one check decides to block.
func isBlockedBot(userAgent string) bool {
if userAgent == "" {
// Empty User-Agent is suspicious, block it
log.Printf("INFO: UA blocked - empty user agent")
return true
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Ensure context is cancelled eventually
resultChan := make(chan bool, 2) // Buffered channel for results
// Goroutine 1: Library-based bot check
go func() {
ua := useragent.Parse(userAgent)
shouldBlock := ua.Bot
if shouldBlock {
log.Printf("INFO: UA blocked by library (Bot detected: %s): %s", ua.Name, userAgent)
}
select {
case resultChan <- shouldBlock:
case <-ctx.Done(): // Don't send if context is cancelled
}
}()
// Goroutine 2: Prefix check
go func() {
// Standard browser User-Agent prefixes
standardPrefixes := []string{"Mozilla/", "Opera/", "DuckDuckGo/", "Dart/"}
hasStandardPrefix := false
for _, prefix := range standardPrefixes {
if strings.HasPrefix(userAgent, prefix) {
hasStandardPrefix = true
break
}
}
// Block if it does NOT have a standard prefix
shouldBlock := !hasStandardPrefix
if shouldBlock {
log.Printf("INFO: UA blocked by prefix check (doesn't have standard prefix): %s", userAgent)
}
select {
case resultChan <- shouldBlock:
case <-ctx.Done(): // Don't send if context is cancelled
}
}()
// Wait for results and decide
result1 := <-resultChan
if result1 {
cancel() // Found a reason to block, cancel the other check
return true
}
// First check didn't block, wait for the second result
result2 := <-resultChan
// cancel() is deferred, so it will run anyway, ensuring cleanup
return result2 // Block if the second check decided to block
}
// New returns the Checkpoint Fiber handler. Requests are processed in this
// order:
//
//  1. User-Agent validation. This is skipped for paths under
//     UserAgentValidationExclusions. The User-Agent must carry the prefix
//     required for the longest matching UserAgentRequiredPrefixes entry, and
//     must pass isBlockedBot. Failures get 403 (JSON under /api).
//  2. /api paths. PoW endpoints (/api/pow/, /api/verify) are handled locally.
//     Other API paths skip the checkpoint and are proxied when the host has a
//     ReverseProxyMappings entry.
//  3. Proxied hosts. Excluded paths and requests with a valid token are
//     forwarded with DirectProxy. Everything else gets the interstitial, with
//     the original URL remembered.
//  4. Other hosts. Excluded paths pass through. A valid token cookie passes
//     through. Otherwise the interstitial is served.
func New() fiber.Handler {
	// forbidAutomated sends the 403 used for rejected user agents.
	forbidAutomated := func(c *fiber.Ctx, path string) error {
		if strings.HasPrefix(path, "/api") {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"error":  "Access denied for automated clients.",
				"reason": "useragent",
			})
		}
		return c.Status(fiber.StatusForbidden).SendString("Access denied for automated clients.")
	}

	return func(c *fiber.Ctx) error {
		host := c.Hostname()
		targetURL, useProxy := checkpointConfig.ReverseProxyMappings[host]
		path := c.Path()

		// --- User-Agent Validation ---
		skipUA := false
		for _, prefix := range checkpointConfig.UserAgentValidationExclusions {
			if strings.HasPrefix(path, prefix) {
				skipUA = true
				break
			}
		}
		if !skipUA {
			userAgent := c.Get("User-Agent")

			// Map iteration order is random, so pick the longest matching path
			// prefix; otherwise overlapping rules would be enforced
			// nondeterministically.
			var matchedPath, required string
			matched := false
			for p, req := range checkpointConfig.UserAgentRequiredPrefixes {
				if strings.HasPrefix(path, p) && (!matched || len(p) > len(matchedPath)) {
					matchedPath, required, matched = p, req, true
				}
			}
			if matched && !strings.HasPrefix(userAgent, required) {
				log.Printf("INFO: UA blocked by required prefix %s: %s", required, userAgent)
				return forbidAutomated(c, path)
			}

			// General bot check for all non-excluded paths.
			if isBlockedBot(userAgent) {
				return forbidAutomated(c, path)
			}
		}

		// --- API endpoints ---
		if strings.HasPrefix(path, "/api") {
			// PoW endpoints (challenge & verify) are always served locally.
			if strings.HasPrefix(path, "/api/pow/") || strings.HasPrefix(path, "/api/verify") {
				log.Printf("API checkpoint endpoint %s - handling locally", path)
				return c.Next()
			}
			// Other API paths skip the checkpoint; proxied hosts forward them.
			if useProxy {
				log.Printf("API proxying endpoint %s to %s", path, targetURL)
				return DirectProxy(targetURL)(c)
			}
			log.Printf("API endpoint %s - bypassing checkpoint", path)
			return c.Next()
		}

		// --- Reverse Proxy Logic ---
		if useProxy {
			tokenCookie := c.Cookies(checkpointConfig.CookieName)
			log.Printf("Proxy: Checking token for host %s, path %s, cookie present: %v",
				host, path, tokenCookie != "")

			if isExcludedHTMLPath(path) {
				log.Printf("Excluded path %s for proxied host %s - proxying without token check", path, host)
				return DirectProxy(targetURL)(c)
			}

			valid, err := validateToken(tokenCookie, c)
			if err != nil {
				// Validation errors are treated as "not valid".
				log.Printf("Error validating token for proxied host %s, path %s: %v", host, path, err)
			}
			if valid {
				log.Printf("Valid token found for proxied host %s, path %s - forwarding request", host, path)
				return DirectProxy(targetURL)(c)
			}

			log.Printf("No valid token for proxied host %s, path %s - serving interstitial", host, path)
			// Remember the full original URL so the user can return after verification.
			c.Locals("originalURL", c.OriginalURL())
			return serveInterstitial(c)
		}

		// --- Standard (non-proxied) pages ---
		if isExcludedHTMLPath(path) {
			return c.Next()
		}
		if tokenCookie := c.Cookies(checkpointConfig.CookieName); tokenCookie != "" {
			valid, err := validateToken(tokenCookie, c)
			if err != nil {
				// Fail closed: fall through to the interstitial.
				log.Printf("Error validating token for path %s: %v", path, err)
			} else if valid {
				return c.Next()
			}
		}
		return serveInterstitial(c)
	}
}
// generateRequestID registers a new challenge for the requesting client and
// returns its random request ID (32 hex characters).
//
// The stored ChallengeParams contain:
//   - the base64-encoded challenge and salt from generateChallenge;
//   - a PoS seed of 32 random bytes, hex-encoded;
//   - the configured difficulty;
//   - an expiry of now + ChallengeExpiration;
//   - the client IP from getRealIP.
//
// If the system CSPRNG fails, the process terminates via log.Fatalf.
func generateRequestID(c *fiber.Ctx) string {
	challenge, salt := generateChallenge()

	mustRandomHex := func(size int, purpose string) string {
		buf := make([]byte, size)
		n, err := cryptorand.Read(buf)
		if err != nil {
			log.Fatalf("CRITICAL: Failed to generate %s: %v", purpose, err)
		} else if n != size {
			log.Fatalf("CRITICAL: Short read generating %s: read %d bytes", purpose, n)
		}
		return hex.EncodeToString(buf)
	}
	posSeed := mustRandomHex(32, "PoS seed")
	requestID := mustRandomHex(16, "request ID")

	challengeStore.Store(requestID, ChallengeParams{
		Challenge:  base64.StdEncoding.EncodeToString([]byte(challenge)),
		Salt:       base64.StdEncoding.EncodeToString([]byte(salt)),
		Difficulty: checkpointConfig.Difficulty,
		ExpiresAt:  time.Now().Add(checkpointConfig.ChallengeExpiration),
		ClientIP:   getRealIP(c),
		PoSSeed:    posSeed,
	})
	return requestID
}
// cleanupExpiredChallenges removes from challengeStore every challenge whose
// ExpiresAt lies before the moment the sweep starts. It logs the number of
// removed entries, but only when at least one was removed.
func cleanupExpiredChallenges() {
	cutoff := time.Now()
	removed := 0
	challengeStore.Range(func(k, v interface{}) bool {
		id, params := k.(string), v.(ChallengeParams)
		if !cutoff.After(params.ExpiresAt) {
			return true // still live; keep scanning
		}
		challengeStore.Delete(id)
		removed++
		return true
	})
	if removed == 0 {
		return
	}
	log.Printf("Checkpoint: Cleaned up %d expired challenges.", removed)
}
// GetCheckpointChallengeHandler serves the parameters of a previously issued
// challenge, looked up by the "id" query parameter, as JSON.
//
// Responses:
//   - 400 if "id" is missing;
//   - 429 once the caller's counter in ipRateLimit exceeds MaxAttemptsPerHour.
//     Every call increments the counter, including rejected ones; resetting it
//     happens elsewhere;
//   - 404 if the challenge is unknown or already past its ExpiresAt;
//   - 403 if the challenge was issued to a different client IP;
//   - 500 if the system RNG fails while generating decoy data;
//   - otherwise 200 with single-letter keys: a=challenge, b=salt, c=difficulty,
//     d=PoS seed, e=decoy seed, f=2-4 random decoy fields.
func GetCheckpointChallengeHandler(c *fiber.Ctx) error {
	requestID := c.Query("id")
	if requestID == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Missing request ID"})
	}

	// Apply rate limiting to challenge retrieval, using the same
	// MaxAttemptsPerHour budget as verification.
	clientIP := getRealIP(c)
	counterVal, _ := ipRateLimit.LoadOrStore(clientIP, new(atomic.Int64))
	attempts := counterVal.(*atomic.Int64).Add(1)
	if attempts > int64(checkpointConfig.MaxAttemptsPerHour) {
		return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{"error": "Too many challenge requests. Please try again later."})
	}

	stored, exists := challengeStore.Load(requestID)
	if !exists {
		return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "Challenge not found or expired"})
	}
	params := stored.(ChallengeParams)
	// Expired challenges linger until the periodic sweep runs, so enforce
	// ExpiresAt here instead of relying on that sweep.
	if time.Now().After(params.ExpiresAt) {
		return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "Challenge not found or expired"})
	}
	if clientIP != params.ClientIP {
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	// The decoy seed both goes to the client and drives the decoy layout, so
	// an RNG failure must not silently yield an all-zero seed.
	decoySeedBytes := make([]byte, 8)
	if _, err := cryptorand.Read(decoySeedBytes); err != nil {
		log.Printf("ERROR: Failed to generate decoy seed: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate challenge data"})
	}
	decoySeed := hex.EncodeToString(decoySeedBytes)

	// 2-4 decoy fields with random hex names (5-12 chars) and values (8-31
	// chars). Non-nil even though it is always non-empty, so JSON is never null.
	decoyFieldCount := 2 + int(decoySeedBytes[0])%3
	decoyFields := make([]map[string]interface{}, 0, decoyFieldCount)
	for i := 0; i < decoyFieldCount; i++ {
		nameLen := 5 + int(decoySeedBytes[i%8])%8
		valueLen := 8 + int(decoySeedBytes[(i+1)%8])%24
		name := randomHexString(nameLen)
		value := randomHexString(valueLen)
		decoyFields = append(decoyFields, map[string]interface{}{name: value})
	}

	return c.JSON(fiber.Map{
		"a": params.Challenge,  // challenge
		"b": params.Salt,       // salt
		"c": params.Difficulty, // difficulty
		"d": params.PoSSeed,    // pos_seed
		"e": decoySeed,         // decoy_seed
		"f": decoyFields,       // decoy_fields
	})
}
// randomHexString returns n lowercase hex characters drawn from crypto/rand.
//
// It reads ceil(n/2) random bytes, hex-encodes them, and truncates the result
// to n characters (which drops the last nibble when n is odd). n <= 0 yields
// "" instead of panicking on a negative make/slice bound. A failing system RNG
// is fatal (log.Fatalf).
func randomHexString(n int) string {
	if n <= 0 {
		return ""
	}
	b := make([]byte, (n+1)/2)
	// crypto/rand.Read only returns without error once b is completely
	// filled, so a separate short-read check is unnecessary.
	if _, err := cryptorand.Read(b); err != nil {
		log.Fatalf("CRITICAL: Failed to generate random hex string: %v", err)
	}
	// 2*ceil(n/2) >= n, so the encoding is always long enough to slice.
	return hex.EncodeToString(b)[:n]
}
// getFullClientIP returns a pseudonymous identifier for the client address
// reported by getRealIP. The identifier is the first 8 bytes of the address's
// SHA-256 digest, hex-encoded (16 chars). Despite the name, the raw IP is not
// returned. If no address is available, it returns the literal "unknown".
func getFullClientIP(c *fiber.Ctx) string {
	switch ip := getRealIP(c); ip {
	case "":
		return "unknown"
	default:
		digest := sha256.Sum256([]byte(ip))
		return hex.EncodeToString(digest[:8])
	}
}
// hashUserAgent condenses a User-Agent header into a short fingerprint: the
// first 8 bytes of its SHA-256 digest as 16 lowercase hex characters. An empty
// header maps to "" rather than to the hash of the empty string.
func hashUserAgent(userAgent string) string {
	const fingerprintBytes = 8
	if len(userAgent) > 0 {
		digest := sha256.Sum256([]byte(userAgent))
		return hex.EncodeToString(digest[:fingerprintBytes])
	}
	return ""
}
// extractBrowserFingerprint hashes the User-Agent Client Hints sent with the
// request into a 24-hex-character hint.
//
// The non-empty values among the six Sec-CH-UA* headers are joined in a fixed
// order with "|". The first 12 bytes of the SHA-256 of that string are
// returned hex-encoded. If none of the headers is present, it returns "".
func extractBrowserFingerprint(c *fiber.Ctx) string {
	hintHeaders := [...]string{
		"Sec-CH-UA", "Sec-CH-UA-Platform", "Sec-CH-UA-Mobile",
		"Sec-CH-UA-Platform-Version", "Sec-CH-UA-Arch", "Sec-CH-UA-Model",
	}
	present := make([]string, 0, len(hintHeaders))
	for _, name := range hintHeaders {
		if v := c.Get(name); v != "" {
			present = append(present, v)
		}
	}
	if len(present) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(strings.Join(present, "|")))
	return hex.EncodeToString(digest[:12])
}
// validateToken reports whether tokenStr, the raw checkpoint cookie value, is a
// well-formed, correctly signed, unexpired token that is bound to the current
// request.
//
// Missing, malformed, expired, legacy-format (< 2), mis-signed, unknown, or
// mis-bound tokens all yield (false, nil): they are simply not valid. A
// non-nil error is returned only when the token-store lookup itself fails.
//
// On success this has side effects. It calls
// tokenStore.updateTokenVerification; failures there are only logged. It then
// re-issues the cookie through updateTokenCookie with a fresh LastVerified and
// the unchanged ExpiresAt.
func validateToken(tokenStr string, c *fiber.Ctx) (bool, error) {
	// Explicitly handle missing token case first.
	if tokenStr == "" {
		return false, nil // No token cookie found, definitely not valid.
	}
	// 1. Decode the token string from the cookie (standard padded base64).
	tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
	if err != nil {
		// Invalid Base64 encoding - treat as invalid token, not a system error
		return false, nil
	}
	// Check for empty byte slice after decoding
	if len(tokenBytes) == 0 {
		// Decoded to empty - treat as invalid token
		return false, nil
	}
	// 2. Unmarshal the JSON token body.
	var token CheckpointToken
	if err := json.Unmarshal(tokenBytes, &token); err != nil {
		// Invalid JSON structure - treat as invalid token
		return false, nil // Client-supplied garbage is "invalid", not an error worth logging.
	}
	// 3. Cheap early rejection based on the token's own ExpiresAt. The field is
	// still unauthenticated at this point (the signature is checked in step 4),
	// so this only saves work; it is not a security check by itself.
	// Note: Return nil error for expired token, it's just invalid.
	if time.Now().After(token.ExpiresAt) {
		return false, nil // Token itself says it's expired
	}
	// 4. Reject legacy formats (< 2), then verify the HMAC signature.
	if token.TokenFormat < 2 {
		return false, nil // Old format not supported/secure - invalid
	}
	if !verifyTokenSignature(token, tokenBytes) {
		return false, nil // Invalid signature - invalid
	}
	// 5. Derive the token-store key from the (now authenticated) token fields.
	tokenHash := calculateTokenHash(token)
	// 6. Look up the server-side record in BadgerDB.
	storedData, found, dbErr := tokenStore.lookupTokenData(c.Context(), tokenHash)
	if dbErr != nil {
		// Actual DB error during lookup - THIS is a real error to return
		return false, fmt.Errorf("token DB lookup failed: %w", dbErr)
	}
	if !found {
		// Token hash not found in DB or explicitly expired according to DB record
		return false, nil
	}
	// 7. *** CRITICAL: Verify bindings against stored data and current request ***
	// The stored hashes are the authority here, not the ClientIP/UserAgent copies
	// inside the cookie. NOTE(review): presumably issueToken populates
	// ClientIPHash/UserAgentHash with getFullClientIP/hashUserAgent, the same
	// helpers used below; confirm.
	// Compare Client IP Hash. Despite the variable name, getFullClientIP returns
	// a hash of the full client IP.
	currentPartialIP := getFullClientIP(c)
	if storedData.ClientIPHash != currentPartialIP {
		return false, nil // IP mismatch - invalid
	}
	// Compare User Agent Hash
	currentUserAgent := hashUserAgent(c.Get("User-Agent"))
	if storedData.UserAgentHash != currentUserAgent {
		return false, nil // User agent mismatch - invalid
	}
	// Compare Browser Hint
	currentBrowserHint := extractBrowserFingerprint(c)
	// Only enforced if a hint was stored AND the current request has one.
	// extractBrowserFingerprint returns "" when no Sec-CH-UA* headers are sent,
	// so a client that omits them skips this check.
	if storedData.BrowserHint != "" && currentBrowserHint != "" && storedData.BrowserHint != currentBrowserHint {
		return false, nil // Browser hint mismatch - invalid
	}
	// 8. All checks passed! Token is valid and bound correctly.
	// Update LastVerified time in the database (best effort, log errors)
	if err := tokenStore.updateTokenVerification(tokenHash); err != nil {
		log.Printf("WARNING: Failed to update token verification time for hash %s: %v", tokenHash, err)
	}
	// Refresh the cookie with the new LastVerified. ExpiresAt is left unchanged,
	// so this is not a sliding expiration window.
	token.LastVerified = time.Now()
	updateTokenCookie(c, token) // Resign and set cookie
	return true, nil
}
// updateTokenCookie re-signs token and writes it to the client as the
// checkpoint cookie: CookieName, Path "/", HttpOnly, expiring at
// token.ExpiresAt.
//
// Secure is cleared only when c.Protocol() reports a plain "http" scheme
// (e.g. local development). "https" and any other reported scheme keep it set.
// Domain comes from checkpointConfig.CookieDomain, or from getDomainFromHost
// applied to the request host when that is empty. SameSite is Lax when a
// Domain is set and Strict otherwise.
//
// A marshal failure is logged and no cookie is written.
func updateTokenCookie(c *fiber.Ctx, token CheckpointToken) {
	// Read the scheme once. BaseURL() is built from Protocol(), so testing its
	// "https" prefix is the same as testing proto directly.
	proto := c.Protocol()
	isSecure := !strings.HasPrefix(proto, "http") || strings.HasPrefix(proto, "https")

	// Get domain for cookie - either from config or auto-detect
	cookieDomain := checkpointConfig.CookieDomain
	if cookieDomain == "" {
		cookieDomain = getDomainFromHost(c.Hostname())
	}

	// A Domain attribute means the cookie is meant to be shared across
	// subdomains. Lax is used there (presumably so cross-site top-level
	// navigations still carry the token; confirm intent) and Strict otherwise.
	// fiber's constants are used rather than hand-written spellings.
	sameSite := fiber.CookieSameSiteStrictMode
	if cookieDomain != "" {
		sameSite = fiber.CookieSameSiteLaxMode
	}

	// LastVerified may have changed, so re-sign. computeTokenSignature signs
	// its own re-marshal of the token with Signature cleared and does not read
	// the tokenBytes argument, so no pre-marshaled bytes are produced here.
	token.Signature = ""
	token.Signature = computeTokenSignature(token, nil)

	finalTokenBytes, err := json.Marshal(token)
	if err != nil {
		log.Printf("Error marshaling token for cookie update: %v", err)
		return
	}

	c.Cookie(&fiber.Cookie{
		Name:     checkpointConfig.CookieName,
		Value:    base64.StdEncoding.EncodeToString(finalTokenBytes),
		Expires:  token.ExpiresAt, // Use original expiration
		Path:     "/",
		Domain:   cookieDomain,
		HTTPOnly: true,
		SameSite: sameSite,
		Secure:   isSecure, // Only set Secure in HTTPS environments
	})
}
// verifyProofOfWork reports whether SHA-256(challenge+salt+nonce), in lowercase
// hex, starts with at least difficulty '0' characters.
//
// difficulty <= 0 requires no leading zeros and always succeeds. Previously a
// negative value made strings.Repeat panic and took down the handler. A
// difficulty above 64 (the hex digest length) can never be satisfied.
func verifyProofOfWork(challenge, salt, nonce string, difficulty int) bool {
	if difficulty <= 0 {
		return true
	}
	hash := calculateHash(challenge + salt + nonce)
	return strings.HasPrefix(hash, strings.Repeat("0", difficulty))
}

// calculateHash returns the SHA-256 digest of input as 64 lowercase hex chars.
func calculateHash(input string) string {
	digest := sha256.Sum256([]byte(input))
	return hex.EncodeToString(digest[:])
}
// computeTokenSignature returns the hex HMAC-SHA256 (keyed with hmacSecret) of
// the JSON encoding of token with its Signature field cleared.
//
// tokenBytes is not used; the signature always covers a fresh re-marshal of
// the token. The parameter is kept so existing callers stay source-compatible.
//
// If marshaling fails, it returns "". Previously the error was ignored and an
// HMAC over an empty payload was returned, which is identical for every token.
// verifyTokenSignature rejects an empty expected signature against any
// non-empty supplied one, so this fails closed.
func computeTokenSignature(token CheckpointToken, tokenBytes []byte) string {
	unsigned := token // value copy; the caller's token is not modified
	unsigned.Signature = ""
	payload, err := json.Marshal(unsigned)
	if err != nil {
		log.Printf("ERROR: Failed to marshal token for signing: %v", err)
		return ""
	}
	mac := hmac.New(sha256.New, hmacSecret)
	mac.Write(payload) // hash.Hash writes never return an error
	return hex.EncodeToString(mac.Sum(nil))
}
// verifyTokenSignature reports whether token.Signature equals the signature
// computeTokenSignature produces for the same token. The comparison uses
// hmac.Equal, which runs in constant time. A token with no signature is
// always rejected.
func verifyTokenSignature(token CheckpointToken, tokenBytes []byte) bool {
	supplied := token.Signature
	if len(supplied) == 0 {
		return false
	}
	want := computeTokenSignature(token, tokenBytes)
	return hmac.Equal([]byte(supplied), []byte(want))
}
// VerifyCheckpointHandler verifies a submitted proof-of-work solution and, on
// success, issues a checkpoint token via issueToken.
//
// The body is parsed into CheckpointVerifyRequest. Requirements:
//   - the referenced challenge must exist, be unexpired, and belong to the
//     caller's IP;
//   - the nonce must meet the challenge difficulty;
//   - if PoS data is sent, it must be consistent.
//
// Each (nonce, challenge) pair and each challenge ID can be redeemed at most
// once. Both are claimed atomically (LoadOrStore / LoadAndDelete), so
// concurrent submissions cannot mint more than one token per challenge.
func VerifyCheckpointHandler(c *fiber.Ctx) error {
	clientIP := getRealIP(c)
	var req CheckpointVerifyRequest
	if err := c.BodyParser(&req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid request format"})
	}

	// Challenge lookup. Expired entries may still be present until the
	// periodic sweep runs, so ExpiresAt is enforced here too.
	challengeVal, ok := challengeStore.Load(req.RequestID)
	if !ok {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}
	params := challengeVal.(ChallengeParams)
	if time.Now().After(params.ExpiresAt) {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}
	if clientIP != params.ClientIP { // Check against IP stored with challenge
		return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "IP address mismatch for challenge"})
	}

	rawChallenge, err := base64.StdEncoding.DecodeString(params.Challenge)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode challenge"})
	}
	rawSalt, err := base64.StdEncoding.DecodeString(params.Salt)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to decode salt"})
	}
	decodedChallenge, decodedSalt := string(rawChallenge), string(rawSalt)

	if req.Nonce == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Nonce ('g') required"})
	}

	// Cheap early rejection of an already-redeemed solution before hashing.
	nonceKey := req.Nonce + decodedChallenge
	if _, used := usedNonces.Load(nonceKey); used {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "This solution has already been used"})
	}
	if !verifyProofOfWork(decodedChallenge, decodedSalt, req.Nonce, params.Difficulty) {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid proof-of-work solution"})
	}
	// Claim the nonce atomically, and only once the PoW is verified. Of several
	// concurrent requests with the same valid solution, only the first wins.
	if _, used := usedNonces.LoadOrStore(nonceKey, time.Now()); used {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "This solution has already been used"})
	}

	// Validate PoS hashes and times if provided
	if len(req.PoSHashes) == 3 && len(req.PoSTimes) == 3 {
		if req.PoSHashes[0] != req.PoSHashes[1] || req.PoSHashes[1] != req.PoSHashes[2] {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "PoS hashes ('h') do not match"})
		}
		if len(req.PoSHashes[0]) != 64 {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS hash ('h') length"})
		}
		if err := checkPoSTimes(req.PoSTimes); err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
		}
	} else if checkpointConfig.CheckPoSTimes && (len(req.PoSHashes) != 0 || len(req.PoSTimes) != 0) {
		// PoS checking is enabled but an incorrect number of hashes/times was
		// provided. Sending no PoS data at all is still accepted.
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid PoS data provided"})
	}

	// Consume the challenge atomically so that two different valid nonces for
	// the same request ID cannot both be redeemed.
	if _, ok := challengeStore.LoadAndDelete(req.RequestID); !ok {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid or expired request ID"})
	}

	entropyBytes := make([]byte, 8)
	if _, err := cryptorand.Read(entropyBytes); err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to generate secure token entropy"})
	}

	// *** Gather current binding info for the new token ***
	now := time.Now()
	token := CheckpointToken{
		Nonce:        req.Nonce,
		ExpiresAt:    now.Add(checkpointConfig.TokenExpiration),
		ClientIP:     getFullClientIP(c),
		UserAgent:    hashUserAgent(c.Get("User-Agent")),
		BrowserHint:  extractBrowserFingerprint(c),
		Entropy:      hex.EncodeToString(entropyBytes),
		Created:      now,
		LastVerified: now,
		TokenFormat:  2,
	}

	// Add a response header indicating success for the proxy
	c.Set("X-Checkpoint-Status", "success")
	log.Printf("Successfully verified challenge for IP %s, issuing token", clientIP)
	// Issue token (handles DB storage, signing, cookie setting)
	return issueToken(c, token)
}
// CheckpointVerifyRequest is the request body parsed by VerifyCheckpointHandler.
// Apart from request_id it uses single-letter JSON keys ("g" through "l"),
// matching the single-letter keys GetCheckpointChallengeHandler returns.
type CheckpointVerifyRequest struct {
	RequestID   string                   `json:"request_id"` // ID of the challenge being solved
	Nonce       string                   `json:"g"`          // PoW nonce appended to challenge+salt before hashing
	PoSHashes   []string                 `json:"h"`          // optional PoS hashes; when 3 are sent they must be identical and 64 chars long
	PoSTimes    []int64                  `json:"i"`          // optional PoS timings; when 3 are sent they are checked by checkPoSTimes
	DecoyHashes []string                 `json:"j"`          // decoy data; not read by VerifyCheckpointHandler
	DecoyTimes  []int64                  `json:"k"`          // decoy data; not read by VerifyCheckpointHandler
	DecoyFields []map[string]interface{} `json:"l"`          // decoy data; not read by VerifyCheckpointHandler
}
// generateChallenge returns a fresh PoW challenge and salt, both hex-encoded:
// 16 crypto-random bytes for the challenge and checkpointConfig.SaltLength
// bytes for the salt. A failing system RNG is fatal (log.Fatalf).
func generateChallenge() (string, string) {
	challengeBytes := make([]byte, 16)
	if _, err := cryptorand.Read(challengeBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random challenge: %v", err)
	}

	saltBytes := make([]byte, checkpointConfig.SaltLength)
	if _, err := cryptorand.Read(saltBytes); err != nil {
		log.Fatalf("CRITICAL: Failed to generate secure random salt: %v", err)
	}

	return hex.EncodeToString(challengeBytes), hex.EncodeToString(saltBytes)
}
// calculateTokenHash derives the token-store (BadgerDB) key for an issued
// token. The key is the hex SHA-256 of "<Nonce>:<Entropy>:<Created in Unix
// nanoseconds>".
//
// No other fields contribute. In particular LastVerified and Signature do not,
// so re-signing and refreshing the cookie keeps the same key.
func calculateTokenHash(token CheckpointToken) string {
	digest := sha256.New()
	// Writes to a hash.Hash never fail, so Fprintf's results can be ignored.
	_, _ = fmt.Fprintf(digest, "%s:%s:%d", token.Nonce, token.Entropy, token.Created.UnixNano())
	return hex.EncodeToString(digest.Sum(nil))
}
// RequestSanitizationMiddleware rejects requests that look like path-traversal,
// injection or XSS probes with 403 "Forbidden". Other requests go to the next
// handler.
//
// Checks:
//   - path: "../" or "..\" in either the raw or the percent-decoded path, so
//     "%2e%2e%2f" cannot bypass the literal check;
//   - query, when BlockDangerousPathChars is set: a literal ';', '\' or '`' in
//     the raw query string. This check is deliberately limited to the raw
//     form, because browsers percent-encode these characters in ordinary form
//     input and checking the decoded text would reject legitimate submissions;
//   - query: each DangerousQueryPatterns regexp against the raw query and its
//     percent-decoded form, so URL-encoded payloads are also caught.
//
// Malformed percent-encoding leaves only the raw text checked. Logged values
// are always the raw path/query.
func RequestSanitizationMiddleware() fiber.Handler {
	// containsTraversal reports whether s holds a "../" or "..\" sequence.
	containsTraversal := func(s string) bool {
		return strings.Contains(s, "../") || strings.Contains(s, "..\\")
	}

	return func(c *fiber.Ctx) error {
		// Check URL path for directory traversal
		path := c.Path()
		decodedPath := path
		if unescaped, err := url.PathUnescape(path); err == nil {
			decodedPath = unescaped
		}
		if containsTraversal(path) || containsTraversal(decodedPath) {
			log.Printf("Security block: Directory traversal attempt in path: %s from IP: %s", path, getRealIP(c))
			return c.Status(fiber.StatusForbidden).SendString("Forbidden")
		}

		// Check query parameters for malicious patterns
		query := c.Request().URI().QueryString()
		if len(query) == 0 {
			return c.Next()
		}
		queryStr := string(query)

		// Check for dangerous characters if configured
		if checkpointConfig.BlockDangerousPathChars && strings.ContainsAny(queryStr, ";\\`") {
			log.Printf("Security block: Dangerous character in query from IP: %s, Query: %s", getRealIP(c), queryStr)
			return c.Status(fiber.StatusForbidden).SendString("Forbidden")
		}

		// Check for configured attack patterns on both raw and decoded forms.
		decodedQuery := queryStr
		if unescaped, err := url.QueryUnescape(queryStr); err == nil {
			decodedQuery = unescaped
		}
		for _, pattern := range checkpointConfig.DangerousQueryPatterns {
			if pattern.MatchString(queryStr) || (decodedQuery != queryStr && pattern.MatchString(decodedQuery)) {
				log.Printf("Security block: Malicious pattern match in query from IP: %s, Pattern: %s, Query: %s",
					getRealIP(c), pattern.String(), queryStr)
				return c.Status(fiber.StatusForbidden).SendString("Forbidden")
			}
		}
		return c.Next()
	}
}