Import existing project

parent 7887817595
commit 80b0cc4939

125 changed files with 16980 additions and 0 deletions

main.go (604 lines, new file)

@@ -0,0 +1,604 @@
package main

import (
	"archive/tar"
	"bytes"
	"compress/gzip"
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"caileb/middleware"
	"caileb/utils"

	"github.com/andybalholm/brotli"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/compress"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gofiber/template/html/v2"
)

// getEnvInt tries to read an integer from the environment.
// If it's missing or invalid, it returns the fallback.
func getEnvInt(key string, fallback int) int {
	if value, exists := os.LookupEnv(key); exists {
		if intVal, err := strconv.Atoi(value); err == nil {
			return intVal
		}
	}
	return fallback
}

// getEnvBool reads a boolean from the environment.
// Returns fallback if missing; accepts "true" or "1" as true.
func getEnvBool(key string, fallback bool) bool {
	if value, exists := os.LookupEnv(key); exists {
		return strings.ToLower(value) == "true" || value == "1"
	}
	return fallback
}

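// Example (hypothetical values): with MINIFY_WORKERS=8 set in the environment,
// getEnvInt("MINIFY_WORKERS", 4) returns 8; if the variable is unset or not an
// integer, it returns 4. Likewise getEnvBool("MINIFY_SKIP_UNCHANGED", true)
// returns true only for "true" or "1", false for any other set value, and the
// fallback when the variable is unset.
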
// validatePathParam rejects path params with unsafe characters (../, slashes, etc.).
func validatePathParam(paramName string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		param := c.Params(paramName)

		// Clean and validate path
		cleanedParam := filepath.Clean(param)

		// Security checks
		if cleanedParam != param || strings.Contains(param, "..") ||
			strings.Contains(param, "/") || strings.Contains(param, "\\") {
			return c.Status(fiber.StatusForbidden).SendString("Forbidden")
		}

		// Alphanumeric validation
		validChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-."
		for _, char := range param {
			if !strings.ContainsRune(validChars, char) {
				return c.Status(fiber.StatusForbidden).SendString("Forbidden")
			}
		}

		return c.Next()
	}
}

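// For illustration (hypothetical requests): params like "about" or "blog-post_1"
// pass the checks above, while "../secrets", "a/b", or anything containing a
// character outside the whitelist is rejected with 403 Forbidden.
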
// customCompression adds Brotli or gzip compression depending on client support.
func customCompression() fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Check Accept-Encoding header
		acceptEncoding := c.Get("Accept-Encoding")

		// Set Vary header
		c.Append("Vary", "Accept-Encoding")

		// Process after the handler has been executed
		if err := c.Next(); err != nil {
			return err
		}

		// Only compress successful responses with a body of at least 1 KiB
		if c.Response().StatusCode() != 200 || len(c.Response().Body()) < 1024 {
			return nil
		}

		// Check content type for compressibility
		contentType := string(c.Response().Header.ContentType())
		if !isCompressible(contentType) {
			return nil
		}

		body := c.Response().Body()

		// Apply compression based on client support
		if strings.Contains(acceptEncoding, "br") {
			// Brotli compression
			var buf bytes.Buffer
			writer := brotli.NewWriterLevel(&buf, 7)

			if _, err := writer.Write(body); err != nil {
				return nil // Skip compression on error
			}

			if err := writer.Close(); err != nil {
				return nil // Skip compression on error
			}

			compressed := buf.Bytes()

			// Only use compression if it's actually smaller
			if len(compressed) < len(body) {
				c.Response().Header.Set("Content-Encoding", "br")
				c.Response().SetBodyRaw(compressed)
			}
		} else if strings.Contains(acceptEncoding, "gzip") {
			// Let the built-in gzip middleware handle it
			c.Response().Header.Set("Content-Encoding", "gzip")
		}

		return nil
	}
}

// isCompressible returns true for common types that benefit from compression.
func isCompressible(contentType string) bool {
	compressibleTypes := []string{
		"text/", "application/json", "application/javascript",
		"application/xml", "image/svg", "font/",
		"application/wasm", "application/xhtml", "application/rss",
	}

	for _, t := range compressibleTypes {
		if strings.Contains(contentType, t) {
			return true
		}
	}

	return false
}

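// For example, isCompressible("application/json; charset=utf-8") and
// isCompressible("text/html") return true, while "image/png" and
// "application/octet-stream" match none of the prefixes and fall through to false.
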
// main parses flags, sets up middleware/plugins, and starts the server.
// It also handles graceful shutdown signals.
func main() {
	// Parse command line flags
	prodMode := flag.Bool("p", false, "Run in production mode")
	devMode := flag.Bool("d", false, "Run in development mode")
	daemonMode := flag.Bool("b", false, "Run as a daemon (background process)")
	port := flag.String("port", "1488", "Port to listen on")
	skipPOW := flag.Bool("skip-pow", false, "Skip proof-of-work protection")
	flag.Parse()

	// Handle daemon mode - fork a new process and exit the parent
	if *daemonMode && os.Getenv("_DAEMON_CHILD") != "1" {
		// Prepare the command to run this program again as a child
		cmd := exec.Command(os.Args[0], os.Args[1:]...)
		cmd.Env = append(os.Environ(), "_DAEMON_CHILD=1")
		if err := cmd.Start(); err != nil {
			log.Fatalf("Failed to start daemon process: %v", err)
		}

		log.Printf("Server started in daemon mode with PID: %d\n", cmd.Process.Pid)
		// Write the PID to a file for later reference
		pidFile := "server.pid"
		if err := os.WriteFile(pidFile, []byte(strconv.Itoa(cmd.Process.Pid)), 0644); err != nil {
			log.Printf("Warning: Could not write PID file: %v\n", err)
		}

		// Exit the parent process
		os.Exit(0)
	}

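	// Typical invocations (illustrative; the binary name and ports are examples only):
	//
	//	./server -d -port 3000    // development mode on port 3000
	//	./server -p -b            // production mode, forked into the background
	//	./server -p -skip-pow     // production mode without proof-of-work protection
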
	// Set environment based on flags
	if *prodMode {
		os.Setenv("APP_ENV", "production")
	} else if *devMode {
		os.Setenv("APP_ENV", "development")
	} else {
		// Default to production mode if no mode is specified
		os.Setenv("APP_ENV", "production")
	}

	// Configure minification options from environment variables
	opts := utils.DefaultMinifierOptions()
	opts.MaxWorkers = getEnvInt("MINIFY_WORKERS", opts.MaxWorkers)
	opts.SkipUnchanged = getEnvBool("MINIFY_SKIP_UNCHANGED", opts.SkipUnchanged)
	opts.RemoveComments = getEnvBool("MINIFY_REMOVE_COMMENTS", opts.RemoveComments)
	opts.KeepConditionalComments = getEnvBool("MINIFY_KEEP_CONDITIONAL_COMMENTS", opts.KeepConditionalComments)
	opts.KeepSpecialComments = getEnvBool("MINIFY_KEEP_SPECIAL_COMMENTS", opts.KeepSpecialComments)

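	// The minifier can be tuned through these variables, e.g. (hypothetical values):
	//
	//	MINIFY_WORKERS=8 MINIFY_REMOVE_COMMENTS=true ./server -p
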
	// Minify assets from develop to public
	log.Println("Minifying assets from /develop to /public directories...")
	if err := utils.MinifyAssetsWithOptions(opts); err != nil {
		log.Fatalf("Failed to minify assets: %v", err)
	}

	// Setup the template engine
	engine := html.New("./public/static", ".html")
	engine.Reload(os.Getenv("APP_ENV") != "production") // Enable reloading in development mode

	// Create a new Fiber app with a custom error handler (serving error.html).
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c *fiber.Ctx, err error) error {
			code := fiber.StatusInternalServerError
			if e, ok := err.(*fiber.Error); ok {
				code = e.Code
			}
			return c.Status(code).SendFile(filepath.Join("public", "html", "error.html"))
		},
		StrictRouting:           true,              // Enable strict routing for better path validation
		Views:                   engine,            // Set the template engine
		ProxyHeader:             "X-Forwarded-For", // Trust X-Forwarded-For header
		EnableTrustedProxyCheck: true,              // Enable proxy checking
		TrustedProxies:          []string{"127.0.0.1", "::1"}, // Add your NAS IP here
	})

	// Logger middleware only in development mode
	if os.Getenv("APP_ENV") != "production" {
		// API routes: log method, path and latency only (no status)
		app.Use(logger.New(logger.Config{
			Format:     "${time} ${method} ${path} - ${latency}",
			TimeFormat: "2006-01-02T15:04:05",
			TimeZone:   "Local",
			Next: func(c *fiber.Ctx) bool {
				// skip this logger for non-API paths
				return !strings.HasPrefix(c.Path(), "/api")
			},
		}))
		// Non-API routes: log full details including status
		app.Use(logger.New(logger.Config{
			Format:     "${time} ${status} | ${method} ${path} - ${latency}",
			TimeFormat: "2006-01-02T15:04:05",
			TimeZone:   "Local",
			Next: func(c *fiber.Ctx) bool {
				// skip this logger for API paths
				return strings.HasPrefix(c.Path(), "/api")
			},
		}))
		log.Printf("Logger middleware enabled (%s mode)\n", os.Getenv("APP_ENV"))
	}

	// Force text/html content type for HTML files for better compression detection
	app.Use(func(c *fiber.Ctx) error {
		path := c.Path()
		// Only set content type for HTML files
		if strings.HasSuffix(path, ".html") || path == "/" || (len(path) > 0 && !strings.Contains(path, ".")) {
			c.Set("Content-Type", "text/html; charset=utf-8")
		} else if strings.HasSuffix(path, ".json") {
			c.Set("Content-Type", "application/json")
		}
		return c.Next()
	})

	// Built-in compression for non-Brotli clients
	app.Use(compress.New(compress.Config{
		Level: 7,
		Next: func(c *fiber.Ctx) bool {
			// Skip if client accepts Brotli
			return strings.Contains(c.Get("Accept-Encoding"), "br")
		},
	}))

	// Custom Brotli compression for supported clients
	app.Use(customCompression())

	// Security headers middleware (improves site security)
	app.Use(func(c *fiber.Ctx) error {
		c.Set("X-Frame-Options", "SAMEORIGIN")
		c.Set("X-Content-Type-Options", "nosniff")
		c.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		c.Set("X-Permitted-Cross-Domain-Policies", "none")
		c.Set("Cross-Origin-Opener-Policy", "same-origin")
		c.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
		//c.Set("Content-Security-Policy", "base-uri 'self'; default-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline' blob: *.caileb.com; img-src * data:; font-src 'self'; worker-src 'self' blob:; frame-src *.youtube.com *.youtube-nocookie.com *.bitchute.com *.rumble.com rumble.com; connect-src 'self' *.youtube.com *.youtube-nocookie.com *.ytimg.com *.bitchute.com *.rumble.com *.caileb.com;")
		return c.Next()
	})

	// Serve static files from the public directory with optimized caching
	// This should be before HTML POW middleware to correctly handle static files
	app.Static("/", "./public", fiber.Static{
		Compress:      true, // Enable compression for static files
		ByteRange:     true, // Enable byte range requests
		CacheDuration: 24 * time.Hour,
		MaxAge:        86400,
	})

	// Special handler for favicon.ico to ensure it's properly served
	app.Get("/favicon.ico", func(c *fiber.Ctx) error {
		return c.SendFile("./public/favicon.ico", false)
	})

	// Load and apply registered middleware plugins
	for _, handler := range middleware.LoadPlugins(*skipPOW) {
		app.Use(handler)
	}
	log.Println("Loaded middleware plugins")

	// API group with POW protection
	api := app.Group("/api")

	// Endpoint to verify POW solutions and issue tokens
	api.Post("/pow/verify", middleware.VerifyCheckpointHandler)

	// Challenge endpoint for secure POW parameters
	api.Get("/pow/challenge", middleware.GetCheckpointChallengeHandler)

	// Backwards compatibility for existing clients
	api.Get("/verify", middleware.VerifyCheckpointHandler)

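	// Rough shape of the POW flow (hypothetical; the exact request and response
	// bodies are defined by the middleware package, not shown here):
	//
	//	curl https://example.com/api/pow/challenge          // fetch challenge parameters
	//	curl -X POST https://example.com/api/pow/verify ... // submit a solution, receive a token
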
	// Homepage route: serve index.html from public/html/ with compression
	app.Get("/", func(c *fiber.Ctx) error {
		c.Set("Content-Type", "text/html; charset=utf-8")
		c.Response().Header.Add("Vary", "Accept-Encoding")
		return c.SendFile(filepath.Join("public", "html", "index.html"))
	})

	// Dynamic page route using the validation middleware
	app.Get("/:page", validatePathParam("page"), func(c *fiber.Ctx) error {
		page := c.Params("page")
		c.Set("Content-Type", "text/html; charset=utf-8")
		c.Response().Header.Add("Vary", "Accept-Encoding")
		return c.SendFile(filepath.Join("public", "html", page+".html"))
	})

	// Catch-all: serve a 404 error page for unmatched routes
	app.Use(func(c *fiber.Ctx) error {
		c.Set("Content-Type", "text/html; charset=utf-8")
		c.Response().Header.Add("Vary", "Accept-Encoding")
		return c.Status(404).SendFile(filepath.Join("public", "html", "error.html"))
	})

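	// With the routes above, "/" serves public/html/index.html, a request such as
	// "/about" serves public/html/about.html (assuming that file exists, subject to
	// validatePathParam), and anything else falls through to the 404 handler
	// serving public/html/error.html.
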
	// Start the server
	go func() {
		addr := ":" + *port
		log.Printf("Server starting on %s in %s mode\n", addr, os.Getenv("APP_ENV"))
		if err := app.Listen(addr); err != nil {
			log.Fatalf("Server error: %v", err)
		}
	}()

	// Start the GeoIP database update routine
	go startGeoIPUpdateRoutine()

	// When running as the daemon child, detach from the terminal by redirecting
	// stdout/stderr to /dev/null; the signal handling below still applies.
	if os.Getenv("_DAEMON_CHILD") == "1" {
		if f, err := os.OpenFile("/dev/null", os.O_RDWR, 0); err == nil {
			// Redirect stdout/stderr to /dev/null for true daemon behavior
			os.Stdout = f
			os.Stderr = f
			// Don't close f as it's now used by os.Stdout and os.Stderr
		}
	}

	// Graceful shutdown handling
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit

	log.Println("Shutting down server...")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := app.ShutdownWithContext(ctx); err != nil {
		log.Fatalf("Server forced to shutdown: %v", err)
	}

	// Close the token store database
	if err := middleware.CloseTokenStore(); err != nil {
		log.Printf("Error closing token store: %v", err)
	}

	log.Println("Server exiting")
}

// startGeoIPUpdateRoutine starts a goroutine that updates GeoIP databases daily
func startGeoIPUpdateRoutine() {
	// Start immediately after server startup to ensure databases are fresh
	updateGeoIPDatabases()

	// Then schedule daily updates
	ticker := time.NewTicker(24 * time.Hour)
	go func() {
		for range ticker.C {
			updateGeoIPDatabases()
		}
	}()
}

// updateGeoIPDatabases downloads the latest GeoLite2 Country and ASN databases
func updateGeoIPDatabases() {
	// MaxMind account credentials
	accountID := "1015174"
	licenseKey := "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"

	// Database paths and URLs
	databases := []struct {
		name     string
		url      string
		destFile string
	}{
		{
			name:     "GeoLite2-Country",
			url:      "https://download.maxmind.com/geoip/databases/GeoLite2-Country/download?suffix=tar.gz",
			destFile: "./data/GeoLite2-Country.mmdb",
		},
		{
			name:     "GeoLite2-ASN",
			url:      "https://download.maxmind.com/geoip/databases/GeoLite2-ASN/download?suffix=tar.gz",
			destFile: "./data/GeoLite2-ASN.mmdb",
		},
	}

	// Ensure data directory exists
	if err := os.MkdirAll("./data", 0755); err != nil {
		log.Printf("ERROR: Failed to create data directory: %v", err)
		return
	}

	// Create HTTP client that follows redirects
	client := &http.Client{
		Timeout: 5 * time.Minute,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// MaxMind uses Cloudflare R2 for redirects, follow them
			if len(via) >= 10 {
				return fmt.Errorf("too many redirects")
			}
			// Add basic auth to the redirected request if needed
			if req.URL.Host != "mm-prod-geoip-databases.a2649acb697e2c09b632799562c076f2.r2.cloudflarestorage.com" {
				req.SetBasicAuth(accountID, licenseKey)
			}
			return nil
		},
	}

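	// Each download is roughly equivalent to (illustrative only, with placeholder credentials):
	//
	//	curl -L -u <ACCOUNT_ID>:<LICENSE_KEY> -o GeoLite2-Country.tar.gz \
	//	  "https://download.maxmind.com/geoip/databases/GeoLite2-Country/download?suffix=tar.gz"
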
	// Download and process each database
	for _, db := range databases {
		log.Printf("Checking for updates to %s...", db.name)

		// First, check if an update is needed via HEAD request
		headReq, err := http.NewRequest("HEAD", db.url, nil)
		if err != nil {
			log.Printf("ERROR: Failed to create HEAD request for %s: %v", db.name, err)
			continue
		}
		headReq.SetBasicAuth(accountID, licenseKey)

		headResp, err := client.Do(headReq)
		if err != nil {
			log.Printf("ERROR: Failed to make HEAD request for %s: %v", db.name, err)
			continue
		}
		headResp.Body.Close()

		// Check if file exists and get its modification time
		updateNeeded := true
		if fileInfo, err := os.Stat(db.destFile); err == nil {
			lastModified := headResp.Header.Get("Last-Modified")
			if lastModified != "" {
				remoteTime, err := time.Parse(time.RFC1123, lastModified)
				if err == nil {
					// Only update if remote file is newer
					if !remoteTime.After(fileInfo.ModTime()) {
						log.Printf("No update needed for %s, local copy is current", db.name)
						updateNeeded = false
					}
				}
			}
		}

		if !updateNeeded {
			continue
		}

		// Download the database
		log.Printf("Downloading %s...", db.name)
		req, err := http.NewRequest("GET", db.url, nil)
		if err != nil {
			log.Printf("ERROR: Failed to create request for %s: %v", db.name, err)
			continue
		}
		req.SetBasicAuth(accountID, licenseKey)

		resp, err := client.Do(req)
		if err != nil {
			log.Printf("ERROR: Failed to download %s: %v", db.name, err)
			continue
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			log.Printf("ERROR: Failed to download %s: HTTP %d", db.name, resp.StatusCode)
			continue
		}

		// Create a temporary file to store the downloaded archive
		tempFile, err := os.CreateTemp("", "geoip-*.tar.gz")
		if err != nil {
			log.Printf("ERROR: Failed to create temp file for %s: %v", db.name, err)
			continue
		}
		defer os.Remove(tempFile.Name())

		// Copy the response body to the temporary file
		_, err = io.Copy(tempFile, resp.Body)
		if err != nil {
			log.Printf("ERROR: Failed to save downloaded %s: %v", db.name, err)
			tempFile.Close()
			continue
		}
		tempFile.Close()

		// Extract the .mmdb file from the tar.gz archive
		extracted, err := extractMMDBFromTarGz(tempFile.Name(), db.name)
		if err != nil {
			log.Printf("ERROR: Failed to extract %s: %v", db.name, err)
			continue
		}

		// Move the extracted file to the destination
		err = os.Rename(extracted, db.destFile)
		if err != nil {
			log.Printf("ERROR: Failed to move %s to destination: %v", db.name, err)
			os.Remove(extracted) // Clean up
			continue
		}

		log.Printf("Successfully updated %s", db.name)
	}

	// Reload the databases in the middleware
	middleware.ReloadGeoIPDatabases()
}

// extractMMDBFromTarGz extracts the .mmdb file from a tar.gz archive
func extractMMDBFromTarGz(tarGzPath, dbName string) (string, error) {
	file, err := os.Open(tarGzPath)
	if err != nil {
		return "", err
	}
	defer file.Close()

	gzr, err := gzip.NewReader(file)
	if err != nil {
		return "", err
	}
	defer gzr.Close()

	tr := tar.NewReader(gzr)

	// Create a temporary directory for extraction
	tempDir, err := os.MkdirTemp("", "geoip-extract-")
	if err != nil {
		return "", err
	}

	// Find and extract the .mmdb file
	var mmdbPath string
	for {
		header, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			os.RemoveAll(tempDir)
			return "", err
		}

		// Look for the .mmdb file in the archive
		if strings.HasSuffix(header.Name, ".mmdb") && strings.Contains(header.Name, dbName) {
			// Extract to temporary directory
			mmdbPath = filepath.Join(tempDir, filepath.Base(header.Name))
			outFile, err := os.Create(mmdbPath)
			if err != nil {
				os.RemoveAll(tempDir)
				return "", err
			}

			if _, err := io.Copy(outFile, tr); err != nil {
				outFile.Close()
				os.RemoveAll(tempDir)
				return "", err
			}
			outFile.Close()
			break
		}
	}

	if mmdbPath == "" {
		os.RemoveAll(tempDir)
		return "", fmt.Errorf("no .mmdb file found in archive for %s", dbName)
	}

	return mmdbPath, nil
}