常用中间件 #

一、JWT认证中间件 #

1.1 基本实现 #

go
package middleware

import (
    "strings"
    "time"
    
    "github.com/gofiber/fiber/v2"
    "github.com/golang-jwt/jwt/v4"
)

// JWTConfig configures the JWTProtected middleware.
type JWTConfig struct {
    // Secret is the shared key handed to jwt.Parse (as []byte) to verify
    // token signatures; GenerateToken signs with the same kind of key (HS256).
    Secret      string
    // TokenHeader names the request header that carries "Bearer <token>",
    // e.g. "Authorization".
    TokenHeader string
    // ContextKey is the c.Locals key under which the verified token claims
    // are stored for later handlers.
    ContextKey  string
    // SkipPaths lists request paths that bypass authentication; each entry is
    // compared for exact equality with c.Path().
    SkipPaths   []string
}

// JWTProtected returns a middleware that authenticates requests with an
// HMAC-signed JWT sent as "Bearer <token>".
//
//   - Requests whose c.Path() exactly equals an entry in config.SkipPaths are
//     passed through. c.Path() is the full request path, so entries must
//     include any group prefix (e.g. "/api/health").
//   - A zero-valued TokenHeader defaults to "Authorization" and a zero-valued
//     ContextKey defaults to "user" (the key RoleRequired and the other
//     helpers read claims from).
//   - Only HMAC signing methods are accepted. An empty Secret rejects every
//     token instead of accepting tokens signed with an empty key.
//   - On success the parsed claims (jwt.MapClaims) are stored in
//     c.Locals(config.ContextKey). Every failure answers 401 with a JSON
//     {"error": ...} body.
func JWTProtected(config JWTConfig) fiber.Handler {
    if config.TokenHeader == "" {
        config.TokenHeader = fiber.HeaderAuthorization
    }
    if config.ContextKey == "" {
        config.ContextKey = "user"
    }
    secret := []byte(config.Secret)

    keyFunc := func(token *jwt.Token) (interface{}, error) {
        // Pin the algorithm family: the token's own header must never decide
        // how it gets verified.
        if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
            return nil, jwt.ErrSignatureInvalid
        }
        // Fail closed on a missing secret; an empty HMAC key is trivially forgeable.
        if len(secret) == 0 {
            return nil, jwt.ErrInvalidKey
        }
        return secret, nil
    }

    return func(c *fiber.Ctx) error {
        path := c.Path()
        for _, skip := range config.SkipPaths {
            if path == skip {
                return c.Next()
            }
        }

        authHeader := c.Get(config.TokenHeader)
        if authHeader == "" {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Missing token",
            })
        }

        // Expect exactly "<scheme> <token>". HTTP auth schemes are
        // case-insensitive (RFC 7235), so "bearer" is accepted as well.
        scheme, raw, found := strings.Cut(authHeader, " ")
        if !found || !strings.EqualFold(scheme, "Bearer") || raw == "" || strings.Contains(raw, " ") {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Invalid token format",
            })
        }

        token, err := jwt.Parse(raw, keyFunc)
        if err != nil || !token.Valid {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Invalid token",
            })
        }

        c.Locals(config.ContextKey, token.Claims)

        return c.Next()
    }
}

1.2 生成Token #

go
// UserClaims is the payload GenerateToken signs: the application's user ID
// and role plus the standard registered claims (exp, iat, ...). Verifying
// middlewares parse tokens back into jwt.MapClaims, so these fields are read
// there via the JSON keys "user_id" and "role".
type UserClaims struct {
    // UserID identifies the authenticated user (JSON key "user_id").
    UserID string `json:"user_id"`
    // Role is the single role string checked by RoleRequired (JSON key "role").
    Role   string `json:"role"`
    jwt.RegisteredClaims
}

func GenerateToken(userID, role, secret string) (string, error) {
    claims := UserClaims{
        UserID: userID,
        Role:   role,
        RegisteredClaims: jwt.RegisteredClaims{
            ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
            IssuedAt:  jwt.NewNumericDate(time.Now()),
        },
    }
    
    token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
    return token.SignedString([]byte(secret))
}

1.3 使用示例 #

go
func main() {
    app := fiber.New()

    // Public route.
    app.Post("/login", loginHandler)

    // Protected routes: everything under /api requires a valid JWT.
    // Load the secret from configuration rather than a literal in real deployments.
    protected := app.Group("/api", JWTProtected(JWTConfig{
        Secret:      "your-secret-key",
        TokenHeader: "Authorization",
        ContextKey:  "user",
        // SkipPaths is compared with the full request path (c.Path()), so an
        // exemption inside this group must carry the "/api" prefix; a bare
        // "/health" would never match here.
        SkipPaths: []string{"/api/health"},
    }))

    protected.Get("/profile", getProfile)

    if err := app.Listen(":3000"); err != nil {
        panic(err)
    }
}

// getProfile returns the caller's user_id and role claims as JSON.
//
// It reads the jwt.MapClaims stored under c.Locals("user"), so it expects a
// JWT middleware with ContextKey "user" to have run first. When no such
// claims are present it answers 401 instead of panicking.
func getProfile(c *fiber.Ctx) error {
    claims, ok := c.Locals("user").(jwt.MapClaims)
    if !ok {
        return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
            "error": "Unauthorized",
        })
    }

    return c.JSON(fiber.Map{
        "user_id": claims["user_id"],
        "role":    claims["role"],
    })
}

二、权限控制中间件 #

2.1 角色权限 #

go
// RoleRequired returns a middleware that lets the request through only when
// the authenticated user's "role" claim equals one of roles.
//
// It reads the jwt.MapClaims stored under c.Locals("user"), so it must run
// after a JWT middleware configured with ContextKey "user". Missing claims
// yield 401. A role that is absent, not a string, or not listed yields 403.
func RoleRequired(roles ...string) fiber.Handler {
    return func(c *fiber.Ctx) error {
        claims, ok := c.Locals("user").(jwt.MapClaims)
        if !ok {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Unauthorized",
            })
        }

        // Checked assertion: a token without a string "role" claim is refused
        // with 403 rather than crashing the request with a panic.
        if userRole, ok := claims["role"].(string); ok {
            for _, role := range roles {
                if userRole == role {
                    return c.Next()
                }
            }
        }

        return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
            "error": "Forbidden",
        })
    }
}

// Usage: authenticate first, then require the "admin" role for every route in
// the group. Order matters: RoleRequired reads the claims from
// c.Locals("user"), so `config` is assumed to use ContextKey "user".
admin := app.Group("/admin", 
    JWTProtected(config),
    RoleRequired("admin"),
)

2.2 资源权限 #

go
// ResourcePermission returns a middleware that allows the request only if
// checkPermission (defined elsewhere) grants the authenticated user `action`
// on `resource`.
//
// It reads jwt.MapClaims from c.Locals("user"), so it must run after a JWT
// middleware with ContextKey "user". Missing claims, or a missing or
// non-string "user_id" claim, yield 401 instead of a panic. A denied
// permission yields 403.
func ResourcePermission(resource, action string) fiber.Handler {
    return func(c *fiber.Ctx) error {
        claims, ok := c.Locals("user").(jwt.MapClaims)
        if !ok {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Unauthorized",
            })
        }

        userID, ok := claims["user_id"].(string)
        if !ok {
            return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
                "error": "Unauthorized",
            })
        }

        if !checkPermission(userID, resource, action) {
            return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
                "error": "Permission denied",
            })
        }

        return c.Next()
    }
}

// Usage: authenticate, then require permission to "delete" on "users" before
// deleteUser runs. Middlewares execute in the order they are listed.
app.Delete("/users/:id", 
    JWTProtected(config),
    ResourcePermission("users", "delete"),
    deleteUser,
)

三、请求日志中间件 #

3.1 结构化日志 #

go
package middleware

import (
    "time"
    
    "github.com/gofiber/fiber/v2"
    "go.uber.org/zap"
)

// StructuredLogger returns a middleware that emits one zap Info entry per
// request once the downstream handlers have run. The entry contains method,
// path, status, latency, client IP, User-Agent and, if any, the returned error.
//
// In Fiber v2 an error returned from c.Next() is turned into a response by the
// app's ErrorHandler only after the whole chain has returned, so the response
// status seen here can still be 200 for a failed request. For such requests
// the logged status is taken from a *fiber.Error. Otherwise it is reported as
// 500 unless the handler had already set an error status itself.
func StructuredLogger(logger *zap.Logger) fiber.Handler {
    return func(c *fiber.Ctx) error {
        start := time.Now()

        err := c.Next()
        latency := time.Since(start)

        status := c.Response().StatusCode()
        if err != nil {
            if fe, ok := err.(*fiber.Error); ok {
                status = fe.Code
            } else if status < fiber.StatusBadRequest {
                status = fiber.StatusInternalServerError
            }
        }

        logger.Info("HTTP Request",
            zap.String("method", c.Method()),
            zap.String("path", c.Path()),
            zap.Int("status", status),
            zap.Duration("latency", latency),
            zap.String("ip", c.IP()),
            zap.String("user_agent", c.Get("User-Agent")),
            zap.Error(err), // omitted from the entry when err is nil
        )

        return err
    }
}

3.2 请求追踪日志 #

go
// TraceLogger returns a middleware that logs a "-->" line when a request
// starts and a "<--" line with status and latency when it finishes. Both lines
// are tagged with the request ID.
//
// The ID is read from c.Locals("requestid"), which an earlier request-ID
// middleware is expected to set. If it is missing or not a string, "-" is
// logged instead of panicking. A non-nil error from the chain is logged on its
// own line. It is rendered by the app's ErrorHandler only after this middleware
// returns, so the "<--" status may still read 200 for such requests.
func TraceLogger() fiber.Handler {
    return func(c *fiber.Ctx) error {
        requestID, ok := c.Locals("requestid").(string)
        if !ok || requestID == "" {
            requestID = "-"
        }

        start := time.Now()

        // Request start.
        log.Printf("[%s] --> %s %s", requestID, c.Method(), c.Path())

        err := c.Next()

        // Request end.
        log.Printf("[%s] <-- %d %v",
            requestID,
            c.Response().StatusCode(),
            time.Since(start),
        )
        if err != nil {
            log.Printf("[%s] <-- error: %v", requestID, err)
        }

        return err
    }
}

四、性能监控中间件 #

4.1 响应时间监控 #

go
// ResponseTime returns a middleware that times the downstream handlers,
// reports the elapsed time in the X-Response-Time header, and logs any
// request that took longer than one second.
func ResponseTime() fiber.Handler {
    const slowThreshold = time.Second

    return func(c *fiber.Ctx) error {
        began := time.Now()
        nextErr := c.Next()
        elapsed := time.Since(began)

        c.Set("X-Response-Time", elapsed.String())

        if elapsed > slowThreshold {
            log.Printf("Slow request: %s %s took %v",
                c.Method(), c.Path(), elapsed)
        }

        return nextErr
    }
}

4.2 性能指标收集 #

go
// Metrics holds aggregate request counters. MetricsMiddleware updates every
// field with sync/atomic adds, so readers should use atomic loads too;
// plain reads would race with the middleware.
type Metrics struct {
    // RequestCount is the total number of requests seen.
    RequestCount    int64
    // ResponseTime is the cumulative (summed) latency of all requests;
    // divide by RequestCount for the mean.
    ResponseTime    time.Duration
    // ErrorCount is the number of requests MetricsMiddleware classified as errors.
    ErrorCount      int64
}

// MetricsMiddleware returns a middleware that atomically updates metrics for
// every request. It increments RequestCount, adds the request latency to
// ResponseTime, and increments ErrorCount when the request failed.
//
// A request counts as failed when the chain returned an error or the response
// status is >= 400. The error case must be checked explicitly: in Fiber v2 the
// app's ErrorHandler renders a returned error only after the middleware chain
// has unwound, so the status here can still read 200.
func MetricsMiddleware(metrics *Metrics) fiber.Handler {
    return func(c *fiber.Ctx) error {
        atomic.AddInt64(&metrics.RequestCount, 1)

        start := time.Now()

        err := c.Next()

        // time.Duration's underlying type is int64, so the field can be
        // updated atomically through an *int64 view of it.
        atomic.AddInt64((*int64)(&metrics.ResponseTime), int64(time.Since(start)))

        if err != nil || c.Response().StatusCode() >= fiber.StatusBadRequest {
            atomic.AddInt64(&metrics.ErrorCount, 1)
        }

        return err
    }
}

五、请求验证中间件 #

5.1 Body验证 #

go
import (
    "reflect"

    "github.com/go-playground/validator/v10"
)

// Validator wraps a go-playground validator instance so that one instance can
// be created (via NewValidator) and reused by the middlewares built from it.
type Validator struct {
    validate *validator.Validate
}

// NewValidator creates a Validator backed by a newly constructed
// go-playground validator.
func NewValidator() *Validator {
    v := new(Validator)
    v.validate = validator.New()
    return v
}

// ValidateBody returns a middleware that parses the request body into a new
// value of s's type and validates it using the struct's `validate` tags.
//
// s is only a type template (typically &MyRequest{}) and is never written to.
// A fresh value is allocated per request, so concurrent requests cannot race
// on one shared instance or inherit stale fields from an earlier body. The
// parsed and validated value (a pointer to the struct) is stored in
// c.Locals("body") for the next handler.
//
// Parse failures yield 400 "Invalid request body". Validation failures yield
// 400 "Validation failed" with the validator's message under "fields".
// It panics at construction time if s is nil, because no type can be derived
// from it.
func (v *Validator) ValidateBody(s interface{}) fiber.Handler {
    typ := reflect.TypeOf(s)
    if typ == nil {
        panic("middleware: ValidateBody requires a non-nil struct pointer template")
    }
    if typ.Kind() == reflect.Ptr {
        typ = typ.Elem()
    }

    return func(c *fiber.Ctx) error {
        body := reflect.New(typ).Interface()

        if err := c.BodyParser(body); err != nil {
            return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
                "error": "Invalid request body",
            })
        }

        if err := v.validate.Struct(body); err != nil {
            return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
                "error":  "Validation failed",
                "fields": err.Error(),
            })
        }

        c.Locals("body", body)
        return c.Next()
    }
}

// Usage example: a request DTO whose `validate` tags are checked by the
// body-validation middleware before the handler runs.
type CreateUserRequest struct {
    // Name: required, 2–50 characters.
    Name  string `json:"name" validate:"required,min=2,max=50"`
    // Email: required, must be a syntactically valid e-mail address.
    Email string `json:"email" validate:"required,email"`
    // Age: optional, 0–150.
    Age   int    `json:"age" validate:"min=0,max=150"`
}

validator := NewValidator()

app.Post("/users", 
    validator.ValidateBody(&CreateUserRequest{}),
    createUser,
)

5.2 Query验证 #

go
// ValidateQuery returns a middleware that parses the query string into a new
// value of s's type and validates it using the struct's `validate` tags.
//
// s is only a type template (typically &MyQuery{}) and is never written to.
// A fresh value is allocated per request, so concurrent requests cannot race
// on, or inherit stale fields from, a shared instance. The parsed value (a
// pointer to the struct) is stored in c.Locals("query") for the next handler.
//
// Parse failures yield 400 "Invalid query parameters". Validation failures
// yield 400 with the validator's message. It panics at construction time if
// s is nil.
func ValidateQuery(s interface{}) fiber.Handler {
    typ := reflect.TypeOf(s)
    if typ == nil {
        panic("middleware: ValidateQuery requires a non-nil struct pointer template")
    }
    if typ.Kind() == reflect.Ptr {
        typ = typ.Elem()
    }

    // Build the validator once and share it. Per the go-playground docs it is
    // safe for concurrent use and caches per-type metadata, so constructing
    // one per request only wastes work.
    validate := validator.New()

    return func(c *fiber.Ctx) error {
        query := reflect.New(typ).Interface()

        if err := c.QueryParser(query); err != nil {
            return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
                "error": "Invalid query parameters",
            })
        }

        if err := validate.Struct(query); err != nil {
            return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
                "error": err.Error(),
            })
        }

        c.Locals("query", query)
        return c.Next()
    }
}

六、跨域处理中间件 #

6.1 自定义CORS #

go
// CustomCORS returns a middleware that emits CORS headers for the given
// origins.
//
// Each allowedOrigins entry is compared exactly with the request's Origin
// header, and the entry "*" allows any origin. For an allowed origin, the
// origin is echoed in Access-Control-Allow-Origin together with the allowed
// methods and headers.
//
// Access-Control-Allow-Credentials is sent only for origins listed explicitly.
// Echoing an arbitrary origin together with credentials would let any website
// make cookie/Authorization-bearing requests on the user's behalf. Requests
// without an Origin header get no CORS headers.
//
// "Vary: Origin" is always added because the response depends on the Origin
// header; this keeps shared caches from replaying one origin's CORS headers to
// another. Every OPTIONS request is answered with 204 without reaching later
// handlers.
func CustomCORS(allowedOrigins []string) fiber.Handler {
    return func(c *fiber.Ctx) error {
        origin := c.Get("Origin")
        c.Vary("Origin")

        allowed, credentialed := false, false
        if origin != "" {
            for _, o := range allowedOrigins {
                if o == origin {
                    allowed, credentialed = true, true
                    break
                }
                if o == "*" {
                    allowed = true
                }
            }
        }

        if allowed {
            c.Set("Access-Control-Allow-Origin", origin)
            c.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
            c.Set("Access-Control-Allow-Headers", "Content-Type,Authorization")
            if credentialed {
                c.Set("Access-Control-Allow-Credentials", "true")
            }
        }

        // Preflight (and any other OPTIONS) requests are answered here.
        if c.Method() == fiber.MethodOptions {
            return c.SendStatus(fiber.StatusNoContent)
        }

        return c.Next()
    }
}

七、限流中间件 #

7.1 基于IP限流 #

go
import "golang.org/x/time/rate"

// RateLimit returns a middleware that applies a token-bucket limit per client
// IP (c.IP()). Each IP may make rps requests per second on average, with
// bursts of up to burst requests. Over-limit requests receive 429.
//
// Limiters live in an in-memory map. Without pruning, that map would keep one
// entry for every IP ever seen and grow without bound. Whenever the map
// reaches a size threshold, it is swept of limiters whose bucket has refilled
// to burst. Such a limiter behaves exactly like a freshly created one, so
// eviction never changes which requests are allowed. The next threshold is
// twice the number of surviving entries (but at least minSweepSize), which
// keeps the sweep cost amortised per newly seen IP.
func RateLimit(rps int, burst int) fiber.Handler {
    const minSweepSize = 1024

    var mu sync.Mutex
    limiters := make(map[string]*rate.Limiter)
    sweepAt := minSweepSize

    return func(c *fiber.Ctx) error {
        ip := c.IP()

        mu.Lock()
        limiter, exists := limiters[ip]
        if !exists {
            if len(limiters) >= sweepAt {
                for key, l := range limiters {
                    if l.Tokens() >= float64(burst) {
                        delete(limiters, key)
                    }
                }
                sweepAt = 2 * len(limiters)
                if sweepAt < minSweepSize {
                    sweepAt = minSweepSize
                }
            }
            limiter = rate.NewLimiter(rate.Limit(rps), burst)
            limiters[ip] = limiter
        }
        mu.Unlock()

        if !limiter.Allow() {
            return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
                "error": "Too many requests",
            })
        }

        return c.Next()
    }
}

7.2 基于用户限流 #

go
// UserRateLimit returns a middleware that applies a token-bucket limit per
// authenticated user. Each user_id may make rps requests per second on
// average, with bursts of up to burst. Over-limit requests receive 429.
//
// It reads jwt.MapClaims from c.Locals("user"), so it must run after a JWT
// middleware with ContextKey "user". Requests without usable claims are passed
// through unlimited instead of panicking. This covers three cases: no claims,
// claims of another type, and a missing or non-string user_id. Pair it with an
// IP-based limiter to cover those requests.
func UserRateLimit(rps int, burst int) fiber.Handler {
    var mu sync.Mutex
    buckets := make(map[string]*rate.Limiter)

    return func(c *fiber.Ctx) error {
        claims, ok := c.Locals("user").(jwt.MapClaims)
        if !ok {
            return c.Next()
        }
        userID, ok := claims["user_id"].(string)
        if !ok {
            return c.Next()
        }

        mu.Lock()
        limiter, exists := buckets[userID]
        if !exists {
            limiter = rate.NewLimiter(rate.Limit(rps), burst)
            buckets[userID] = limiter
        }
        mu.Unlock()

        if !limiter.Allow() {
            return c.Status(fiber.StatusTooManyRequests).JSON(fiber.Map{
                "error": "Rate limit exceeded",
            })
        }

        return c.Next()
    }
}

八、缓存中间件 #

8.1 响应缓存 #

go
// CacheStore is the storage backend used by CacheMiddleware.
// NOTE(review): the middleware calls it from concurrent request handlers, so
// implementations presumably need to be safe for concurrent use — confirm.
type CacheStore interface {
    // Get returns the cached value for key and whether it was found.
    Get(key string) ([]byte, bool)
    // Set stores value under key with the requested lifetime ttl.
    Set(key string, value []byte, ttl time.Duration)
}

// CacheMiddleware returns a middleware that caches successful GET responses in
// store for ttl. The cache key is the path plus the raw query string.
//
// On a hit, the cached body is sent with "X-Cache: HIT" and later handlers are
// skipped. On a miss, the chain runs, the body of a 200 response is stored,
// and "X-Cache: MISS" is set. Error responses are never cached.
//
// Only the body is cached, so other response headers (e.g. Content-Type) are
// not replayed on a hit. The key does not include the caller's identity, so do
// not mount this on routes whose responses differ per user.
func CacheMiddleware(store CacheStore, ttl time.Duration) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // Only GET requests are cached.
        if c.Method() != fiber.MethodGet {
            return c.Next()
        }

        key := c.Path() + "?" + c.Request().URI().Query().String()

        if cached, found := store.Get(key); found {
            c.Set("X-Cache", "HIT")
            return c.Send(cached)
        }

        if err := c.Next(); err != nil {
            return err
        }

        if c.Response().StatusCode() == fiber.StatusOK {
            // Response().Body() points into fasthttp's pooled response buffer,
            // which is reused once this request completes; store a private copy.
            body := append([]byte(nil), c.Response().Body()...)
            store.Set(key, body, ttl)
        }
        c.Set("X-Cache", "MISS")

        return nil
    }
}

九、错误处理中间件 #

9.1 统一错误处理 #

go
// AppError is an application error that carries the HTTP status code and the
// client-facing message; ErrorHandler renders it as the JSON response.
type AppError struct {
    // Code is the HTTP status code to respond with.
    Code    int    `json:"code"`
    // Message is the text sent to the client.
    Message string `json:"message"`
}

// Error implements the error interface by returning Message.
func (e *AppError) Error() string {
    return e.Message
}

// ErrorHandler is a fiber.Config.ErrorHandler that renders every error as JSON
// {"error": <message>, "code": <status>}. *AppError and *fiber.Error supply
// their own status and message. Any other error becomes a 500 with a generic
// message rather than err.Error().
func ErrorHandler(c *fiber.Ctx, err error) error {
    status, msg := fiber.StatusInternalServerError, "Internal Server Error"

    switch e := err.(type) {
    case *AppError:
        status, msg = e.Code, e.Message
    case *fiber.Error:
        status, msg = e.Code, e.Message
    }

    return c.Status(status).JSON(fiber.Map{
        "error": msg,
        "code":  status,
    })
}

// Configuration: install ErrorHandler as the app-wide handler for errors
// returned by routes and middlewares.
app := fiber.New(fiber.Config{
    ErrorHandler: ErrorHandler,
})

十、中间件组合示例 #

10.1 完整中间件链 #

go
func main() {
    app := fiber.New()

    // 1. Panic recovery.
    app.Use(recover.New())

    // 2. Request ID.
    app.Use(requestid.New())

    // 3. Access logging.
    app.Use(logger.New())

    // 4. CORS.
    app.Use(cors.New())

    // 5. Security headers.
    app.Use(helmet.New())

    // 6. Compression.
    app.Use(compress.New())

    // 7. Per-IP rate limiting.
    app.Use(RateLimit(100, 20))

    // Public routes.
    app.Post("/login", loginHandler)
    app.Get("/health", healthHandler)

    // API routes.
    api := app.Group("/api")

    // 8. JWT authentication. TokenHeader and ContextKey are set explicitly:
    // c.Get("") returns "", so an empty TokenHeader would reject every request,
    // and RoleRequired/getProfile read the claims from c.Locals("user").
    // Load the secret from configuration rather than a literal in production.
    api.Use(JWTProtected(JWTConfig{
        Secret:      "secret",
        TokenHeader: "Authorization",
        ContextKey:  "user",
    }))

    // 9. Role check. A token carries a single role, so "admin" must be
    // accepted here as well; otherwise admins would be rejected before ever
    // reaching the /api/admin group below.
    api.Use(RoleRequired("user", "admin"))

    // Business routes.
    api.Get("/profile", getProfile)
    api.Get("/users", getUsers)

    // Admin routes (additionally restricted to the "admin" role).
    admin := api.Group("/admin")
    admin.Use(RoleRequired("admin"))
    admin.Get("/stats", getStats)

    if err := app.Listen(":3000"); err != nil {
        panic(err)
    }
}

十一、总结 #

11.1 常用中间件清单 #

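| 中间件 | 用途 | 优先级 |
| --- | --- | --- |
| Recover | 错误恢复 | 最高 |
| RequestID | 请求追踪 | |
| Logger | 日志记录 | |
| CORS | 跨域处理 | |
| Helmet | 安全头 | |
| RateLimit | 请求限流 | |
| JWT | 身份认证 | |
| Permission | 权限控制 | |
| Validator | 数据验证 | |
| Cache | 响应缓存 | |

表中自上而下大致即推荐的注册顺序（参见 10.1 的完整中间件链）。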

11.2 下一步 #

现在你已经了解了常用中间件，接下来让我们学习 Context 上下文，深入了解 Fiber 的请求处理机制！

最后更新:2026-03-28