自定义中间件 #

一、中间件基础结构 #

1.1 基本格式 #

go
// MyMiddleware is the minimal middleware skeleton. It returns a closure that
// gin invokes once per request: statements placed before ctx.Next() run on the
// way in, statements placed after it run once the downstream handlers return.
func MyMiddleware() gin.HandlerFunc {
    handler := func(ctx *gin.Context) {
        // Pre-processing goes here (runs before ctx.Next()).

        ctx.Next() // hand control to the next handler in the chain

        // Post-processing goes here (runs after ctx.Next()).
    }
    return handler
}

1.2 带参数的中间件 #

go
// RateLimitMiddleware throttles every request passing through it with one
// shared rate.Limiter allowing rps events per second and bursts of up to rps.
// A request that finds no capacity is rejected with HTTP 429 and a JSON body.
//
// The limiter is created once, when the middleware is constructed, so all
// requests handled by this middleware instance draw from the same budget.
func RateLimitMiddleware(rps int) gin.HandlerFunc {
    bucket := rate.NewLimiter(rate.Limit(rps), rps)

    return func(c *gin.Context) {
        if bucket.Allow() {
            c.Next()
            return
        }
        c.AbortWithStatusJSON(429, gin.H{
            "code":    429,
            "message": "请求过于频繁",
        })
    }
}

// main registers RateLimitMiddleware(100) on every route of a bare engine and
// starts serving on Run's default listen address.
func main() {
    engine := gin.New()
    engine.Use(RateLimitMiddleware(100))
    engine.Run()
}

1.3 带配置的中间件 #

go
// CORSConfig holds the settings CORSWithConfig uses to build CORS response
// headers.
type CORSConfig struct {
    // AllowOrigins lists origins whose Origin header is echoed back in
    // Access-Control-Allow-Origin; the entry "*" matches any origin.
    AllowOrigins     []string
    // AllowMethods is joined with ", " into Access-Control-Allow-Methods.
    AllowMethods     []string
    // AllowHeaders is joined with ", " into Access-Control-Allow-Headers.
    AllowHeaders     []string
    // ExposeHeaders is meant for Access-Control-Expose-Headers.
    // NOTE(review): confirm CORSWithConfig actually emits it.
    ExposeHeaders    []string
    // AllowCredentials, when true, adds "Access-Control-Allow-Credentials: true".
    AllowCredentials bool
    // MaxAge is sent as Access-Control-Max-Age (seconds per the CORS spec);
    // values <= 0 omit the header.
    MaxAge           int
}

// CORSWithConfig returns a middleware that adds CORS response headers built
// from config and answers every OPTIONS request with 204 without running the
// rest of the chain (so preflights never reach authentication or handlers).
//
// The request's Origin is echoed in Access-Control-Allow-Origin only when it
// is non-empty and listed in config.AllowOrigins (or the list contains "*").
// Echoing the concrete origin instead of a literal "*" keeps the wildcard
// usable together with AllowCredentials, which browsers reject alongside "*".
// Because the response therefore depends on the Origin request header,
// "Vary: Origin" is always added so shared caches do not serve one origin's
// CORS headers to another.
func CORSWithConfig(config CORSConfig) gin.HandlerFunc {
    // Request-independent header values are computed once, not per request.
    allowMethods := strings.Join(config.AllowMethods, ", ")
    allowHeaders := strings.Join(config.AllowHeaders, ", ")
    exposeHeaders := strings.Join(config.ExposeHeaders, ", ")
    maxAge := ""
    if config.MaxAge > 0 {
        maxAge = strconv.Itoa(config.MaxAge)
    }

    return func(c *gin.Context) {
        // Add (not set) so a Vary value written by other middleware survives.
        c.Writer.Header().Add("Vary", "Origin")

        // A request without Origin is not cross-origin; emitting an empty
        // Access-Control-Allow-Origin for it would be meaningless.
        if origin := c.GetHeader("Origin"); origin != "" {
            for _, o := range config.AllowOrigins {
                if o == "*" || o == origin {
                    c.Header("Access-Control-Allow-Origin", origin)
                    break
                }
            }
        }

        if allowMethods != "" {
            c.Header("Access-Control-Allow-Methods", allowMethods)
        }
        if allowHeaders != "" {
            c.Header("Access-Control-Allow-Headers", allowHeaders)
        }
        if exposeHeaders != "" {
            c.Header("Access-Control-Expose-Headers", exposeHeaders)
        }
        if config.AllowCredentials {
            c.Header("Access-Control-Allow-Credentials", "true")
        }
        if maxAge != "" {
            c.Header("Access-Control-Max-Age", maxAge)
        }

        if c.Request.Method == "OPTIONS" {
            c.AbortWithStatus(204)
            return
        }

        c.Next()
    }
}

func main() {
    r := gin.New()
    
    config := CORSConfig{
        AllowOrigins:     []string{"http://localhost:3000", "https://example.com"},
        AllowMethods:     []string{"GET", "POST", "PUT", "DELETE"},
        AllowHeaders:     []string{"Content-Type", "Authorization"},
        AllowCredentials: true,
        MaxAge:           86400,
    }
    
    r.Use(CORSWithConfig(config))
    r.Run()
}

二、常用自定义中间件 #

2.1 认证中间件 #

go
// JWTConfig configures JWTAuth.
type JWTConfig struct {
    Secret      string // secret passed to validateJWT to verify tokens
    TokenLookup string // "<source>:<name>", source is "header", "query" or "cookie", e.g. "header:Authorization"
}

// JWTAuth returns a middleware that authenticates each request with a JWT.
//
// The token is located according to config.TokenLookup: "header:<name>",
// "query:<name>" or "cookie:<name>". An empty TokenLookup falls back to
// "header:Authorization" instead of rejecting every request. For header
// lookups an optional "Bearer " scheme prefix, matched case-insensitively, is
// stripped. A missing token, or one validateJWT rejects, aborts with 401.
// On success the claims' UserID and Role are stored in the context under
// "userId" and "role" for downstream handlers (e.g. RequireRoles).
func JWTAuth(config JWTConfig) gin.HandlerFunc {
    // Parse the lookup spec once rather than on every request.
    lookup := config.TokenLookup
    if lookup == "" {
        lookup = "header:Authorization"
    }
    source, name, _ := strings.Cut(lookup, ":")

    return func(c *gin.Context) {
        token := extractJWTToken(c, source, name)
        if token == "" {
            c.AbortWithStatusJSON(401, gin.H{
                "code":    401,
                "message": "未提供认证令牌",
            })
            return
        }

        claims, err := validateJWT(token, config.Secret)
        if err != nil {
            c.AbortWithStatusJSON(401, gin.H{
                "code":    401,
                "message": "无效的认证令牌",
            })
            return
        }

        // Expose the authenticated identity to later handlers.
        c.Set("userId", claims.UserID)
        c.Set("role", claims.Role)

        c.Next()
    }
}

// extractJWTToken reads the raw token from the location named by source
// ("header", "query" or "cookie") and name. It returns "" when that location
// is empty or missing, or when source is not recognised.
func extractJWTToken(c *gin.Context, source, name string) string {
    switch source {
    case "header":
        token := c.GetHeader(name)
        // Authorization schemes are case-insensitive (RFC 7235), so accept
        // "bearer " as well as "Bearer ".
        const scheme = "Bearer "
        if len(token) >= len(scheme) && strings.EqualFold(token[:len(scheme)], scheme) {
            token = token[len(scheme):]
        }
        return token
    case "query":
        return c.Query(name)
    case "cookie":
        token, _ := c.Cookie(name) // a missing cookie simply means "no token"
        return token
    }
    return ""
}

2.2 权限中间件 #

go
// RequireRoles returns a middleware that admits a request only when the role
// stored in the context under "role" (JWTAuth stores it there) equals one of
// roles.
//
// When no role is stored, or the stored value is not a plain string, the
// request is aborted with 403 "无权限访问". Previously a non-string value made
// the type assertion panic. A string role not in the list gets 403 "权限不足".
// NOTE(review): if the stored role uses a named string type, it is treated as
// missing here. Store it as a plain string.
func RequireRoles(roles ...string) gin.HandlerFunc {
    return func(c *gin.Context) {
        value, exists := c.Get("role")
        role, isString := value.(string)
        if !exists || !isString {
            c.AbortWithStatusJSON(403, gin.H{
                "code":    403,
                "message": "无权限访问",
            })
            return
        }

        for _, allowed := range roles {
            if role == allowed {
                c.Next()
                return
            }
        }

        c.AbortWithStatusJSON(403, gin.H{
            "code":    403,
            "message": "权限不足",
        })
    }
}

// main applies JWT authentication to every route and further restricts the
// /admin group to the "admin" and "super_admin" roles.
// NOTE(review): config and listUsers are not declared in this snippet; they
// are assumed to exist elsewhere.
func main() {
    router := gin.New()
    router.Use(JWTAuth(config))

    adminGroup := router.Group("/admin")
    adminGroup.Use(RequireRoles("admin", "super_admin"))
    adminGroup.GET("/users", listUsers)

    router.Run()
}

2.3 请求日志中间件 #

go
// LoggerConfig configures LoggerWithConfig.
type LoggerConfig struct {
    // Output is the writer log lines are meant to go to.
    // NOTE(review): confirm LoggerWithConfig actually honours it.
    Output    io.Writer
    // SkipPaths lists request paths (compared exactly against URL.Path) that
    // are passed through without being logged.
    SkipPaths []string
    // Format presumably selects a log-line layout; LoggerWithConfig does not
    // read it. TODO: confirm the intended semantics.
    Format    string
}

// LoggerWithConfig returns a middleware that logs one line per request after
// the downstream handlers have run. Each line has the method, path, client IP,
// status code and latency.
//
// Requests whose URL path exactly matches an entry in config.SkipPaths pass
// through without logging. Lines are written to config.Output when it is set,
// otherwise to the standard logger. config.Format is not consulted.
func LoggerWithConfig(config LoggerConfig) gin.HandlerFunc {
    skipPaths := make(map[string]bool, len(config.SkipPaths))
    for _, path := range config.SkipPaths {
        skipPaths[path] = true
    }

    // Respect config.Output rather than silently ignoring it; fall back to
    // the package-level logger when no writer is configured.
    logf := log.Printf
    if config.Output != nil {
        logf = log.New(config.Output, "", log.LstdFlags).Printf
    }

    return func(c *gin.Context) {
        path := c.Request.URL.Path

        if skipPaths[path] {
            c.Next()
            return
        }

        start := time.Now()

        c.Next()

        logf("[%s] %s %s %d %v",
            c.Request.Method,
            path,
            c.ClientIP(),
            c.Writer.Status(),
            time.Since(start),
        )
    }
}

2.4 请求体缓存 #

go
// BodyCache returns a middleware that reads the entire request body and
// stores the raw bytes in the context under "rawBody". It then replaces
// c.Request.Body with a fresh reader over the same bytes, so later handlers
// (e.g. binding) can still consume it. A read error aborts with 400.
//
// A nil Body is treated as empty instead of panicking inside io.ReadAll. Such
// requests are common in tests built with http.NewRequest(..., nil).
//
// NOTE(review): the body is read with no size limit. On untrusted traffic,
// cap it upstream (e.g. http.MaxBytesReader).
func BodyCache() gin.HandlerFunc {
    return func(c *gin.Context) {
        var body []byte
        if c.Request.Body != nil {
            var err error
            body, err = io.ReadAll(c.Request.Body)
            if err != nil {
                c.AbortWithStatusJSON(400, gin.H{"error": err.Error()})
                return
            }
        }

        // Give downstream handlers an unread copy of the body.
        c.Request.Body = io.NopCloser(bytes.NewReader(body))

        c.Set("rawBody", body)

        c.Next()
    }
}

2.5 响应包装 #

go
// ResponseWrapper is the JSON envelope ResponseWrapperMiddleware places around
// successful response bodies.
type ResponseWrapper struct {
    // Code is the envelope status code; the middleware sets 0 on success.
    Code    int         `json:"code"`
    // Message is a human-readable status; the middleware sets "success".
    Message string      `json:"message"`
    // Data holds the handler's decoded JSON body; omitted from the output when nil.
    Data    interface{} `json:"data,omitempty"`
}

// ResponseWrapperMiddleware buffers what downstream handlers write. For 200
// responses whose body is valid JSON, it re-emits the body inside a
// ResponseWrapper envelope ({"code":0,"message":"success","data":...}).
//
// Every other response is flushed to the client unchanged instead of being
// dropped: a non-200 status, an empty body, or a body that is not JSON (e.g.
// c.String).
//
// NOTE(review): this relies on responseWriter (declared elsewhere) capturing
// Write calls into body, without forwarding them, and recording the status
// code in status. Confirm against its definition.
func ResponseWrapperMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        writer := &responseWriter{
            ResponseWriter: c.Writer,
            body:           &bytes.Buffer{},
        }
        c.Writer = writer

        c.Next()

        // Hand the real writer back on every path before emitting output.
        c.Writer = writer.ResponseWriter
        raw := writer.body.Bytes()

        if writer.status == 200 && len(raw) > 0 {
            var data interface{}
            if err := json.Unmarshal(raw, &data); err == nil {
                wrapped, err := json.Marshal(ResponseWrapper{
                    Code:    0,
                    Message: "success",
                    Data:    data,
                })
                if err == nil {
                    _, _ = c.Writer.Write(wrapped) // client write errors are not actionable here
                    return
                }
            }
        }

        // Not wrappable: pass the handler's original bytes through untouched.
        if len(raw) > 0 {
            _, _ = c.Writer.Write(raw)
        }
    }
}

2.6 超时控制 #

go
// Timeout returns a middleware that gives each request a deadline: it
// attaches a context derived with context.WithTimeout to c.Request. If that
// deadline has expired when the downstream handlers return, and none of them
// has written a response, it answers 504.
//
// The handler chain runs synchronously on the request's own goroutine. Running
// c.Next() in a separate goroutine while this one writes the 504 would be
// unsafe for three reasons:
//   - Both goroutines would use the same *gin.Context and ResponseWriter
//     concurrently, which is a data race.
//   - gin may recycle the Context for another request while the handler
//     goroutine still holds it.
//   - A panic in that goroutine bypasses gin's recovery middleware.
//
// The trade-off: handlers must honour c.Request.Context() (e.g. pass it to
// database or HTTP calls) for the deadline to cut work short.
func Timeout(timeout time.Duration) gin.HandlerFunc {
    return func(c *gin.Context) {
        ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
        defer cancel()

        c.Request = c.Request.WithContext(ctx)

        c.Next()

        // Only our deadline (not client cancellation) produces a 504, and only
        // if no handler has already started the response.
        if ctx.Err() == context.DeadlineExceeded && !c.Writer.Written() {
            c.AbortWithStatusJSON(504, gin.H{
                "code":    504,
                "message": "请求超时",
            })
        }
    }
}

三、中间件设计模式 #

3.1 工厂模式 #

go
// MiddlewareFactory builds the application's standard middlewares from one
// shared *Config, so every call site gets consistently configured instances.
type MiddlewareFactory struct {
    config *Config
}

// NewMiddlewareFactory returns a factory that reads its settings from config.
// config must be non-nil: every factory method dereferences it.
func NewMiddlewareFactory(config *Config) *MiddlewareFactory {
    return &MiddlewareFactory{config: config}
}

// Auth returns JWT authentication that reads the token from the Authorization
// header and verifies it with config.JWTSecret.
func (f *MiddlewareFactory) Auth() gin.HandlerFunc {
    return JWTAuth(JWTConfig{
        Secret:      f.config.JWTSecret,
        TokenLookup: "header:Authorization",
    })
}

// CORS returns a CORS middleware for config.CORSOrigins.
//
// Auth reads the token from the Authorization header, so cross-origin calls
// to protected routes must send that header. Browsers then issue a preflight
// whose response has to list the header, and any non-simple method such as
// PUT or DELETE. Without AllowHeaders/AllowMethods those preflights fail, so
// both are set here, together with Content-Type for JSON request bodies.
func (f *MiddlewareFactory) CORS() gin.HandlerFunc {
    return CORSWithConfig(CORSConfig{
        AllowOrigins: f.config.CORSOrigins,
        AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
        AllowHeaders: []string{"Content-Type", "Authorization"},
    })
}

// RateLimit returns a rate limiter allowing config.RateLimit requests per
// second. Each call builds a new, independent limiter.
func (f *MiddlewareFactory) RateLimit() gin.HandlerFunc {
    return RateLimitMiddleware(f.config.RateLimit)
}

func main() {
    config := LoadConfig()
    factory := NewMiddlewareFactory(config)
    
    r := gin.New()
    r.Use(factory.CORS())
    r.Use(factory.Auth())
    r.Use(factory.RateLimit())
    
    r.Run()
}

3.2 装饰器模式 #

go
// WithMetrics decorates middleware so the wall-clock time spent in it is
// reported through metrics.RecordRequestDuration. If the wrapped middleware
// calls c.Next(), the measured span also includes every handler after it.
func WithMetrics(middleware gin.HandlerFunc) gin.HandlerFunc {
    return func(c *gin.Context) {
        began := time.Now()
        middleware(c)
        metrics.RecordRequestDuration(time.Since(began))
    }
}

// main installs AuthMiddleware wrapped by WithMetrics, so its execution time
// is recorded for every request.
func main() {
    engine := gin.New()
    timedAuth := WithMetrics(AuthMiddleware())
    engine.Use(timedAuth)
    engine.Run()
}

3.3 责任链模式 #

go
// MiddlewareChain collects middlewares to attach to individual routes.
type MiddlewareChain struct {
    middlewares []gin.HandlerFunc
}

// NewChain returns an empty chain.
func NewChain() *MiddlewareChain {
    return &MiddlewareChain{}
}

// Use appends middleware to the chain and returns the chain for fluent calls.
func (c *MiddlewareChain) Use(middleware gin.HandlerFunc) *MiddlewareChain {
    c.middlewares = append(c.middlewares, middleware)
    return c
}

// Handlers returns the chain's middlewares followed by handler as a fresh
// gin.HandlersChain, to be registered as a route's own handler list:
//
//     r.GET("/protected", chain.Handlers(handler)...)
//
// gin then drives these handlers itself, so c.Next() and post-processing
// after c.Next() behave exactly as for middlewares added with r.Use.
func (c *MiddlewareChain) Handlers(handler gin.HandlerFunc) gin.HandlersChain {
    // A new backing array, so appending handler never writes into spare
    // capacity of c.middlewares shared with other routes or later Use calls.
    handlers := make(gin.HandlersChain, 0, len(c.middlewares)+1)
    handlers = append(handlers, c.middlewares...)
    return append(handlers, handler)
}

// Then folds the chain and handler into a single gin.HandlerFunc. It runs the
// middlewares and then handler in order, and stops as soon as one of them
// aborts the request.
//
// gin.Context's handler list is unexported and cannot be replaced from
// outside the gin package, so the handlers are invoked directly. As a result,
// c.Next() called inside one of them advances gin's own route handler list,
// not this chain. Any work a middleware does after c.Next() therefore runs
// before handler. Use Handlers when a middleware must wrap the handler.
func (c *MiddlewareChain) Then(handler gin.HandlerFunc) gin.HandlerFunc {
    handlers := c.Handlers(handler) // snapshot taken at registration time
    return func(ctx *gin.Context) {
        for _, h := range handlers {
            if ctx.IsAborted() {
                return
            }
            h(ctx)
        }
    }
}

// main protects GET /protected with authentication plus a 100 req/s rate
// limit, composed per route through MiddlewareChain rather than globally.
func main() {
    router := gin.New()

    protected := NewChain().
        Use(AuthMiddleware()).
        Use(RateLimitMiddleware(100))
    router.GET("/protected", protected.Then(handler))

    router.Run()
}

四、中间件测试 #

4.1 单元测试 #

go
// TestAuthMiddleware checks that AuthMiddleware rejects a request without a
// token with 401 and lets a request carrying "Bearer valid_token" reach the
// route (200).
// NOTE(review): assumes AuthMiddleware accepts the literal "valid_token".
func TestAuthMiddleware(t *testing.T) {
    router := gin.New()
    router.Use(AuthMiddleware())
    router.GET("/test", func(c *gin.Context) {
        c.String(200, "ok")
    })

    cases := []struct {
        authorization string
        wantStatus    int
    }{
        {"", 401},                   // no token
        {"Bearer valid_token", 200}, // valid token
    }
    for _, tc := range cases {
        req, _ := http.NewRequest("GET", "/test", nil)
        if tc.authorization != "" {
            req.Header.Set("Authorization", tc.authorization)
        }
        rec := httptest.NewRecorder()
        router.ServeHTTP(rec, req)

        assert.Equal(t, tc.wantStatus, rec.Code)
    }
}

4.2 集成测试 #

go
// TestMiddlewareChain sends a request with a valid token through the logger,
// auth and rate-limit middlewares together and expects the route to answer 200.
func TestMiddlewareChain(t *testing.T) {
    engine := gin.New()
    engine.Use(LoggerMiddleware(), AuthMiddleware(), RateLimitMiddleware(100))
    engine.GET("/protected", func(c *gin.Context) {
        c.JSON(200, gin.H{"message": "ok"})
    })

    request, _ := http.NewRequest("GET", "/protected", nil)
    request.Header.Set("Authorization", "Bearer valid_token")
    recorder := httptest.NewRecorder()
    engine.ServeHTTP(recorder, request)

    assert.Equal(t, 200, recorder.Code)
}

4.3 Mock测试 #

go
// MockAuthService is a test double for token validation that accepts exactly
// one token, validToken.
type MockAuthService struct {
    validToken string
}

// ValidateToken returns fixed claims (UserID "123", Role "user") when token
// equals m.validToken, and an "invalid token" error otherwise.
func (m *MockAuthService) ValidateToken(token string) (*Claims, error) {
    if token != m.validToken {
        return nil, errors.New("invalid token")
    }
    return &Claims{UserID: "123", Role: "user"}, nil
}

// TestAuthMiddlewareWithMock injects a MockAuthService that accepts only
// "test_token" and checks that a request bearing that token gets 200.
func TestAuthMiddlewareWithMock(t *testing.T) {
    service := &MockAuthService{validToken: "test_token"}

    router := gin.New()
    router.Use(AuthMiddlewareWithService(service))
    router.GET("/test", func(c *gin.Context) {
        c.String(200, "ok")
    })

    req, _ := http.NewRequest("GET", "/test", nil)
    req.Header.Set("Authorization", "Bearer test_token")
    rec := httptest.NewRecorder()
    router.ServeHTTP(rec, req)

    assert.Equal(t, 200, rec.Code)
}

五、中间件最佳实践 #

5.1 错误处理 #

go
// SafeMiddleware recovers from panics raised by downstream handlers and logs
// the panic value with a stack trace.
//
// If no response has been written yet, it answers 500 with a JSON error body.
// If the handler had already started writing, the status line and headers are
// gone, so the request is only aborted. Appending a second JSON body would
// corrupt the response.
//
// recover only catches panics on the request goroutine; panics in goroutines
// that handlers start themselves are not caught here.
func SafeMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        defer func() {
            if err := recover(); err != nil {
                log.Printf("[PANIC] %v\n%s", err, debug.Stack())
                if c.Writer.Written() {
                    c.Abort()
                    return
                }
                c.AbortWithStatusJSON(500, gin.H{
                    "code":    500,
                    "message": "服务器内部错误",
                })
            }
        }()

        c.Next()
    }
}

5.2 性能优化 #

go
// OptimizedMiddleware illustrates reusing per-request scratch buffers through
// a sync.Pool instead of allocating a new one for every request.
func OptimizedMiddleware() gin.HandlerFunc {
    // Buffers that grew past this size are dropped rather than returned to the
    // pool. Otherwise one unusually large request would keep a large buffer
    // circulating for many small ones. 64 KiB is a tunable guess.
    const maxPooledBufferSize = 64 << 10

    // Held by pointer so the pool is never copied.
    pool := &sync.Pool{
        New: func() interface{} {
            return new(bytes.Buffer)
        },
    }

    return func(c *gin.Context) {
        buf := pool.Get().(*bytes.Buffer)
        buf.Reset() // do not rely on the state of a recycled buffer
        defer func() {
            if buf.Cap() <= maxPooledBufferSize {
                pool.Put(buf)
            }
        }()

        // Use buf while handling the request.
        c.Next()
    }
}

5.3 可配置性 #

go
// ConfigurableMiddleware is a middleware that can be switched off. When
// enabled is false its handler only passes the request on.
type ConfigurableMiddleware struct {
    enabled bool
    config  *Config
}

// Handler returns the gin handler. The enabled flag is read through the
// receiver on every request, not captured once at construction.
func (m *ConfigurableMiddleware) Handler() gin.HandlerFunc {
    return func(c *gin.Context) {
        switch {
        case !m.enabled:
            // Disabled: transparent pass-through.
            c.Next()
        default:
            // Middleware logic goes here.
            c.Next()
        }
    }
}

六、总结 #

6.1 核心要点 #

| 要点 | 说明 |
| --- | --- |
| 基本结构 | 返回 HandlerFunc 的函数 |
| 参数传递 | 闭包捕获参数 |
| 配置化 | 使用配置结构体 |
| 错误处理 | defer + recover |

6.2 最佳实践 #

| 实践 | 说明 |
| --- | --- |
| 单一职责 | 每个中间件只做一件事 |
| 可配置 | 支持配置参数 |
| 可测试 | 编写单元测试 |
| 错误恢复 | 处理 panic |

6.3 下一步 #

现在你已经掌握了自定义中间件，接下来让我们学习「常用中间件」，了解 Gin 生态中有哪些现成的中间件可供使用！

最后更新：2026-03-28