自定义中间件 #

一、中间件基础结构 #

1.1 最简单的中间件 #

go
func simpleMiddleware(c *fiber.Ctx) error {
    fmt.Println("Before handler")
    err := c.Next()
    fmt.Println("After handler")
    return err
}

// 使用
app.Use(simpleMiddleware)

1.2 标准中间件模式 #

go
func myMiddleware() fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 前置处理
        
        err := c.Next()
        
        // 后置处理
        
        return err
    }
}

// 使用
app.Use(myMiddleware())

二、带配置的中间件 #

2.1 配置结构体 #

go
type AuthConfig struct {
    TokenHeader string
    Secret      string
    SkipPaths   []string
}

func AuthMiddleware(config AuthConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 检查是否跳过
        for _, path := range config.SkipPaths {
            if c.Path() == path {
                return c.Next()
            }
        }
        
        // 获取Token
        token := c.Get(config.TokenHeader)
        if token == "" {
            return c.Status(401).JSON(fiber.Map{
                "error": "Missing token",
            })
        }
        
        // 验证Token
        if !validateToken(token, config.Secret) {
            return c.Status(401).JSON(fiber.Map{
                "error": "Invalid token",
            })
        }
        
        return c.Next()
    }
}

// 使用
app.Use(AuthMiddleware(AuthConfig{
    TokenHeader: "Authorization",
    Secret:      "my-secret",
    SkipPaths:   []string{"/health", "/login"},
}))

2.2 默认配置 #

go
type LoggerConfig struct {
    Format      string
    TimeFormat  string
    Output      io.Writer
    SkipPaths   []string
}

var LoggerConfigDefault = LoggerConfig{
    Format:     "[${time}] ${status} - ${method} ${path}\n",
    TimeFormat: "2006-01-02 15:04:05",
    Output:     os.Stdout,
    SkipPaths:  []string{},
}

func LoggerMiddleware(config ...LoggerConfig) fiber.Handler {
    cfg := LoggerConfigDefault
    if len(config) > 0 {
        cfg = config[0]
    }
    
    return func(c *fiber.Ctx) error {
        // 检查是否跳过
        for _, path := range cfg.SkipPaths {
            if c.Path() == path {
                return c.Next()
            }
        }
        
        start := time.Now()
        err := c.Next()
        
        // 格式化日志
        log := cfg.Format
        log = strings.Replace(log, "${time}", time.Now().Format(cfg.TimeFormat), -1)
        log = strings.Replace(log, "${status}", strconv.Itoa(c.Response().StatusCode()), -1)
        log = strings.Replace(log, "${method}", c.Method(), -1)
        log = strings.Replace(log, "${path}", c.Path(), -1)
        log = strings.Replace(log, "${latency}", time.Since(start).String(), -1)
        
        cfg.Output.Write([]byte(log))
        
        return err
    }
}

三、中间件设计模式 #

3.1 认证中间件 #

go
type User struct {
    ID   string
    Name string
    Role string
}

func AuthMiddleware() fiber.Handler {
    return func(c *fiber.Ctx) error {
        token := c.Get("Authorization")
        if token == "" {
            return c.Status(401).JSON(fiber.Map{
                "error": "Unauthorized",
            })
        }
        
        // 解析Token
        user, err := parseToken(token)
        if err != nil {
            return c.Status(401).JSON(fiber.Map{
                "error": "Invalid token",
            })
        }
        
        // 存储用户信息
        c.Locals("user", user)
        
        return c.Next()
    }
}

// 获取当前用户
func GetCurrentUser(c *fiber.Ctx) *User {
    if user, ok := c.Locals("user").(*User); ok {
        return user
    }
    return nil
}

3.2 权限中间件 #

go
func RoleMiddleware(roles ...string) fiber.Handler {
    return func(c *fiber.Ctx) error {
        user := GetCurrentUser(c)
        if user == nil {
            return c.Status(401).JSON(fiber.Map{
                "error": "Unauthorized",
            })
        }
        
        // 检查角色
        for _, role := range roles {
            if user.Role == role {
                return c.Next()
            }
        }
        
        return c.Status(403).JSON(fiber.Map{
            "error": "Forbidden",
        })
    }
}

// 使用
app.Get("/admin", AuthMiddleware(), RoleMiddleware("admin"), adminHandler)

3.3 缓存中间件 #

go
type CacheConfig struct {
    Expiration time.Duration
    KeyFunc    func(c *fiber.Ctx) string
    Store      CacheStore
}

func CacheMiddleware(config CacheConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 只缓存GET请求
        if c.Method() != fiber.MethodGet {
            return c.Next()
        }
        
        // 生成缓存键
        key := config.KeyFunc(c)
        
        // 尝试从缓存获取
        cached, found := config.Store.Get(key)
        if found {
            return c.Send(cached.([]byte))
        }
        
        // 执行请求
        err := c.Next()
        if err != nil {
            return err
        }
        
        // 缓存响应
        body := c.Response().Body()
        config.Store.Set(key, body, config.Expiration)
        
        return nil
    }
}

3.4 请求验证中间件 #

go
type ValidatorConfig struct {
    Rules map[string]string
}

func ValidatorMiddleware(config ValidatorConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        errors := make(map[string]string)
        
        for field, rule := range config.Rules {
            value := c.Query(field)
            
            switch rule {
            case "required":
                if value == "" {
                    errors[field] = field + " is required"
                }
            case "email":
                if !isValidEmail(value) {
                    errors[field] = field + " is not a valid email"
                }
            }
        }
        
        if len(errors) > 0 {
            return c.Status(400).JSON(fiber.Map{
                "errors": errors,
            })
        }
        
        return c.Next()
    }
}

// 使用
app.Get("/users", ValidatorMiddleware(ValidatorConfig{
    Rules: map[string]string{
        "email": "required,email",
        "name":  "required",
    },
}), getUsers)

四、高级中间件 #

4.1 请求追踪中间件 #

go
type TraceConfig struct {
    HeaderName string
    Generator  func() string
}

func TraceMiddleware(config TraceConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 获取或生成TraceID
        traceID := c.Get(config.HeaderName)
        if traceID == "" {
            traceID = config.Generator()
        }
        
        // 设置响应头
        c.Set(config.HeaderName, traceID)
        
        // 存储到上下文
        c.Locals("traceID", traceID)
        
        // 记录开始时间
        start := time.Now()
        
        err := c.Next()
        
        // 记录请求信息
        log.Printf("[%s] %s %s %d %v",
            traceID,
            c.Method(),
            c.Path(),
            c.Response().StatusCode(),
            time.Since(start),
        )
        
        return err
    }
}

4.2 请求重试中间件 #

go
type RetryConfig struct {
    MaxAttempts int
    Delay       time.Duration
    RetryIf     func(c *fiber.Ctx, err error) bool
}

func RetryMiddleware(config RetryConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        var err error
        
        for i := 0; i < config.MaxAttempts; i++ {
            // 重置请求体
            if i > 0 {
                c.Request().ResetBody()
            }
            
            err = c.Next()
            
            // 检查是否需要重试
            if !config.RetryIf(c, err) {
                return err
            }
            
            // 等待后重试
            if i < config.MaxAttempts-1 {
                time.Sleep(config.Delay)
            }
        }
        
        return err
    }
}

4.3 请求签名中间件 #

go
type SignConfig struct {
    Secret    string
    HeaderKey string
}

func SignMiddleware(config SignConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 获取签名
        signature := c.Get(config.HeaderKey)
        if signature == "" {
            return c.Status(401).JSON(fiber.Map{
                "error": "Missing signature",
            })
        }
        
        // 计算预期签名
        body := c.Body()
        expected := hmacSHA256(body, config.Secret)
        
        // 验证签名
        if !hmac.Equal([]byte(signature), []byte(expected)) {
            return c.Status(401).JSON(fiber.Map{
                "error": "Invalid signature",
            })
        }
        
        return c.Next()
    }
}

五、中间件工具函数 #

5.1 跳过路径 #

go
func SkipPaths(paths []string) func(c *fiber.Ctx) bool {
    return func(c *fiber.Ctx) bool {
        for _, path := range paths {
            if c.Path() == path {
                return true
            }
        }
        return false
    }
}

func MyMiddleware(skipFunc func(c *fiber.Ctx) bool) fiber.Handler {
    return func(c *fiber.Ctx) error {
        if skipFunc != nil && skipFunc(c) {
            return c.Next()
        }
        
        // 中间件逻辑
        return c.Next()
    }
}

// 使用
app.Use(MyMiddleware(SkipPaths([]string{"/health", "/metrics"})))

5.2 条件执行 #

go
func ConditionalMiddleware(condition func(c *fiber.Ctx) bool, middleware fiber.Handler) fiber.Handler {
    return func(c *fiber.Ctx) error {
        if condition(c) {
            return middleware(c)
        }
        return c.Next()
    }
}

// 使用
app.Use(ConditionalMiddleware(
    func(c *fiber.Ctx) bool {
        return c.Path() == "/api"
    },
    authMiddleware,
))

六、中间件测试 #

6.1 单元测试 #

go
func TestAuthMiddleware(t *testing.T) {
    app := fiber.New()
    
    app.Use(AuthMiddleware(AuthConfig{
        Secret: "test-secret",
    }))
    
    app.Get("/", func(c *fiber.Ctx) error {
        return c.SendString("OK")
    })
    
    // 测试无Token
    req := httptest.NewRequest("GET", "/", nil)
    resp, err := app.Test(req)
    assert.NoError(t, err)
    assert.Equal(t, 401, resp.StatusCode)
    
    // 测试有效Token
    token := generateToken("test-secret")
    req = httptest.NewRequest("GET", "/", nil)
    req.Header.Set("Authorization", token)
    resp, err = app.Test(req)
    assert.NoError(t, err)
    assert.Equal(t, 200, resp.StatusCode)
}

6.2 集成测试 #

go
func TestMiddlewareChain(t *testing.T) {
    app := fiber.New()
    
    app.Use(requestid.New())
    app.Use(logger.New())
    app.Use(AuthMiddleware(AuthConfig{
        Secret: "test-secret",
    }))
    
    app.Get("/protected", func(c *fiber.Ctx) error {
        return c.SendString("Protected")
    })
    
    // 测试完整中间件链
    token := generateToken("test-secret")
    req := httptest.NewRequest("GET", "/protected", nil)
    req.Header.Set("Authorization", token)
    
    resp, err := app.Test(req)
    assert.NoError(t, err)
    assert.Equal(t, 200, resp.StatusCode)
    assert.NotEmpty(t, resp.Header.Get("X-Request-ID"))
}

七、最佳实践 #

7.1 设计原则 #

text
1. 单一职责:每个中间件只做一件事
2. 可配置:提供配置选项
3. 可测试:易于编写测试
4. 文档化:提供清晰的文档
5. 错误处理:正确处理和传递错误

7.2 命名约定 #

go
// 中间件函数以Middleware结尾
func AuthMiddleware() fiber.Handler { ... }
func LoggerMiddleware() fiber.Handler { ... }

// 配置结构体以Config结尾
type AuthConfig struct { ... }
type LoggerConfig struct { ... }

// 默认配置以Default结尾
var AuthConfigDefault = AuthConfig{ ... }

7.3 错误处理 #

go
func MyMiddleware(config MyConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
        // 使用fiber.Error返回错误
        if someCondition {
            return fiber.NewError(fiber.StatusBadRequest, "Invalid request")
        }
        
        // 或者返回自定义错误
        if anotherCondition {
            return &MyError{Code: 1001, Message: "Custom error"}
        }
        
        return c.Next()
    }
}

八、完整示例 #

8.1 JWT认证中间件 #

go
package middleware

import (
    "strings"
    "time"
    
    "github.com/gofiber/fiber/v2"
    "github.com/golang-jwt/jwt/v4"
)

type JWTConfig struct {
    Secret      string
    TokenHeader string
    ContextKey  string
    SkipPaths   []string
}

var JWTConfigDefault = JWTConfig{
    TokenHeader: "Authorization",
    ContextKey:  "user",
}

func JWTProtected(config ...JWTConfig) fiber.Handler {
    cfg := JWTConfigDefault
    if len(config) > 0 {
        cfg = config[0]
    }
    
    return func(c *fiber.Ctx) error {
        // 检查是否跳过
        for _, path := range cfg.SkipPaths {
            if c.Path() == path {
                return c.Next()
            }
        }
        
        // 获取Token
        authHeader := c.Get(cfg.TokenHeader)
        if authHeader == "" {
            return c.Status(401).JSON(fiber.Map{
                "error": "Missing authorization header",
            })
        }
        
        // 解析Bearer Token
        parts := strings.Split(authHeader, " ")
        if len(parts) != 2 || parts[0] != "Bearer" {
            return c.Status(401).JSON(fiber.Map{
                "error": "Invalid authorization format",
            })
        }
        
        tokenString := parts[1]
        
        // 验证Token
        token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
            return []byte(cfg.Secret), nil
        })
        
        if err != nil || !token.Valid {
            return c.Status(401).JSON(fiber.Map{
                "error": "Invalid token",
            })
        }
        
        // 存储用户信息
        claims := token.Claims.(jwt.MapClaims)
        c.Locals(cfg.ContextKey, claims)
        
        return c.Next()
    }
}

// 获取当前用户
func GetCurrentUser(c *fiber.Ctx) jwt.MapClaims {
    if claims, ok := c.Locals("user").(jwt.MapClaims); ok {
        return claims
    }
    return nil
}

九、总结 #

9.1 核心要点 #

要点 说明
函数签名 func(c *fiber.Ctx) error
配置模式 返回闭包的工厂函数
c.Next() 调用下一个处理函数
c.Locals() 上下文数据传递
错误处理 返回error终止请求

9.2 下一步 #

现在你已经掌握了自定义中间件,接下来让我们学习 常用中间件,了解实际开发中常用的中间件实现!

最后更新:2026-03-28