路由中间件 #

一、路由中间件概述 #

1.1 什么是路由中间件 #

路由中间件是在特定路由或路由分组上应用的中间件,用于处理该路由的请求:

text
请求 → 路由中间件 → 处理函数 → 响应

1.2 中间件层级 #

text
全局中间件
    ↓
分组中间件
    ↓
路由中间件
    ↓
处理函数

二、路由级中间件 #

2.1 单路由中间件 #

go
func main() {
    r := gin.Default()
    
    // 单个路由添加中间件
    r.GET("/profile", AuthMiddleware(), func(c *gin.Context) {
        c.JSON(200, gin.H{"message": "profile"})
    })
    
    r.Run()
}

func AuthMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := c.GetHeader("Authorization")
        if token == "" {
            c.AbortWithStatusJSON(401, gin.H{"error": "unauthorized"})
            return
        }
        c.Next()
    }
}

2.2 多个中间件 #

go
func main() {
    r := gin.Default()
    
    // 多个中间件按顺序执行
    r.GET("/admin",
        AuthMiddleware(),
        AdminMiddleware(),
        func(c *gin.Context) {
            c.JSON(200, gin.H{"message": "admin dashboard"})
        },
    )
    
    r.Run()
}

func AdminMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        role, _ := c.Get("role")
        if role != "admin" {
            c.AbortWithStatusJSON(403, gin.H{"error": "forbidden"})
            return
        }
        c.Next()
    }
}

2.3 中间件执行顺序 #

go
func main() {
    r := gin.Default()
    
    r.GET("/test",
        func(c *gin.Context) {
            fmt.Println("中间件1 - Before")
            c.Next()
            fmt.Println("中间件1 - After")
        },
        func(c *gin.Context) {
            fmt.Println("中间件2 - Before")
            c.Next()
            fmt.Println("中间件2 - After")
        },
        func(c *gin.Context) {
            fmt.Println("处理函数")
            c.String(200, "ok")
        },
    )
    
    r.Run()
}

// 输出顺序:
// 中间件1 - Before
// 中间件2 - Before
// 处理函数
// 中间件2 - After
// 中间件1 - After

三、分组中间件 #

3.1 分组级别中间件 #

go
func main() {
    r := gin.Default()
    
    // API分组 - 需要认证
    api := r.Group("/api")
    api.Use(AuthMiddleware())
    {
        api.GET("/users", listUsers)
        api.GET("/posts", listPosts)
    }
    
    // 公开分组 - 不需要认证
    public := r.Group("/public")
    {
        public.GET("/info", getPublicInfo)
    }
    
    r.Run()
}

3.2 嵌套分组中间件 #

go
func main() {
    r := gin.Default()
    
    api := r.Group("/api")
    api.Use(LoggerMiddleware())
    {
        // 用户分组 - 需要认证
        users := api.Group("/users")
        users.Use(AuthMiddleware())
        {
            users.GET("", listUsers)
            users.GET("/:id", getUser)
        }
        
        // 管理员分组 - 需要认证和管理员权限
        admin := api.Group("/admin")
        admin.Use(AuthMiddleware())
        admin.Use(AdminMiddleware())
        {
            admin.GET("/dashboard", adminDashboard)
            admin.GET("/stats", adminStats)
        }
        
        // 公开分组 - 不需要认证
        public := api.Group("/public")
        {
            public.GET("/info", getPublicInfo)
        }
    }
    
    r.Run()
}

3.3 条件中间件 #

go
func main() {
    r := gin.Default()
    
    api := r.Group("/api")
    
    // 根据环境添加中间件
    if gin.Mode() == gin.ReleaseMode {
        api.Use(RateLimitMiddleware())
    }
    
    api.GET("/users", listUsers)
    
    r.Run()
}

四、常用路由中间件 #

4.1 认证中间件 #

go
func AuthMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        token := c.GetHeader("Authorization")
        
        if token == "" {
            c.AbortWithStatusJSON(401, gin.H{
                "code":    401,
                "message": "未提供认证令牌",
            })
            return
        }
        
        // 验证Token
        claims, err := validateToken(token)
        if err != nil {
            c.AbortWithStatusJSON(401, gin.H{
                "code":    401,
                "message": "无效的认证令牌",
            })
            return
        }
        
        // 存储用户信息到上下文
        c.Set("userId", claims.UserID)
        c.Set("role", claims.Role)
        
        c.Next()
    }
}

4.2 权限中间件 #

go
func RoleMiddleware(roles ...string) gin.HandlerFunc {
    return func(c *gin.Context) {
        userRole, exists := c.Get("role")
        if !exists {
            c.AbortWithStatusJSON(403, gin.H{
                "code":    403,
                "message": "无权限访问",
            })
            return
        }
        
        role := userRole.(string)
        for _, r := range roles {
            if role == r {
                c.Next()
                return
            }
        }
        
        c.AbortWithStatusJSON(403, gin.H{
            "code":    403,
            "message": "权限不足",
        })
    }
}

func main() {
    r := gin.Default()
    
    admin := r.Group("/admin")
    admin.Use(AuthMiddleware())
    admin.Use(RoleMiddleware("admin", "super_admin"))
    {
        admin.GET("/users", listUsers)
    }
    
    r.Run()
}

4.3 限流中间件 #

go
func RateLimitMiddleware(limit int, window time.Duration) gin.HandlerFunc {
    limiter := make(map[string]*rateLimiter)
    mu := sync.RWMutex{}
    
    return func(c *gin.Context) {
        ip := c.ClientIP()
        
        mu.Lock()
        if _, exists := limiter[ip]; !exists {
            limiter[ip] = newRateLimiter(limit, window)
        }
        mu.Unlock()
        
        if !limiter[ip].Allow() {
            c.AbortWithStatusJSON(429, gin.H{
                "code":    429,
                "message": "请求过于频繁,请稍后再试",
            })
            return
        }
        
        c.Next()
    }
}

func main() {
    r := gin.Default()
    
    api := r.Group("/api")
    api.Use(RateLimitMiddleware(100, time.Minute))
    {
        api.GET("/users", listUsers)
    }
    
    r.Run()
}

4.4 日志中间件 #

go
func LoggerMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        path := c.Request.URL.Path
        method := c.Request.Method
        
        c.Next()
        
        latency := time.Since(start)
        statusCode := c.Writer.Status()
        clientIP := c.ClientIP()
        
        log.Printf("[%s] %s %s %d %v %s",
            method,
            path,
            clientIP,
            statusCode,
            latency,
            c.Errors.String(),
        )
    }
}

4.5 CORS中间件 #

go
func CORSMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Header("Access-Control-Allow-Origin", "*")
        c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        c.Header("Access-Control-Allow-Headers", "Content-Type, Authorization")
        c.Header("Access-Control-Expose-Headers", "Content-Length, Content-Type")
        c.Header("Access-Control-Max-Age", "86400")
        
        if c.Request.Method == "OPTIONS" {
            c.AbortWithStatus(204)
            return
        }
        
        c.Next()
    }
}

func main() {
    r := gin.Default()
    
    api := r.Group("/api")
    api.Use(CORSMiddleware())
    {
        api.GET("/users", listUsers)
    }
    
    r.Run()
}

4.6 请求ID中间件 #

go
func RequestIDMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        requestID := c.GetHeader("X-Request-ID")
        if requestID == "" {
            requestID = uuid.New().String()
        }
        
        c.Set("requestId", requestID)
        c.Header("X-Request-ID", requestID)
        
        c.Next()
    }
}

4.7 超时中间件 #

go
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
    return func(c *gin.Context) {
        ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
        defer cancel()
        
        c.Request = c.Request.WithContext(ctx)
        
        done := make(chan struct{})
        go func() {
            c.Next()
            close(done)
        }()
        
        select {
        case <-done:
            return
        case <-ctx.Done():
            c.AbortWithStatusJSON(504, gin.H{
                "code":    504,
                "message": "请求超时",
            })
            return
        }
    }
}

五、中间件组合 #

5.1 链式中间件 #

go
func main() {
    r := gin.Default()
    
    // 链式组合中间件
    protected := r.Group("/protected")
    protected.Use(
        RequestIDMiddleware(),
        LoggerMiddleware(),
        CORSMiddleware(),
        AuthMiddleware(),
    )
    {
        protected.GET("/profile", getProfile)
        protected.GET("/settings", getSettings)
    }
    
    r.Run()
}

5.2 中间件工厂 #

go
type MiddlewareFactory struct {
    config *Config
}

func NewMiddlewareFactory(config *Config) *MiddlewareFactory {
    return &MiddlewareFactory{config: config}
}

func (f *MiddlewareFactory) Auth() gin.HandlerFunc {
    return AuthMiddleware(f.config.JWTSecret)
}

func (f *MiddlewareFactory) RateLimit(limit int) gin.HandlerFunc {
    return RateLimitMiddleware(limit, f.config.RateLimitWindow)
}

func (f *MiddlewareFactory) CORS() gin.HandlerFunc {
    return CORSMiddleware(f.config.AllowedOrigins)
}

func main() {
    config := LoadConfig()
    factory := NewMiddlewareFactory(config)
    
    r := gin.Default()
    
    api := r.Group("/api")
    api.Use(factory.CORS())
    api.Use(factory.Auth())
    api.Use(factory.RateLimit(100))
    {
        api.GET("/users", listUsers)
    }
    
    r.Run()
}

六、中间件最佳实践 #

6.1 中间件组织 #

text
middleware/
├── auth.go         # 认证中间件
├── cors.go         # CORS中间件
├── logger.go       # 日志中间件
├── ratelimit.go    # 限流中间件
├── recovery.go     # 恢复中间件
└── middleware.go   # 中间件工厂

6.2 中间件配置 #

go
type MiddlewareConfig struct {
    EnableAuth      bool
    EnableCORS      bool
    EnableRateLimit bool
    RateLimit       int
}

func SetupMiddleware(r *gin.Engine, config *MiddlewareConfig) {
    // 全局中间件
    r.Use(gin.Recovery())
    r.Use(LoggerMiddleware())
    
    if config.EnableCORS {
        r.Use(CORSMiddleware())
    }
    
    // API分组中间件
    api := r.Group("/api")
    if config.EnableAuth {
        api.Use(AuthMiddleware())
    }
    
    if config.EnableRateLimit {
        api.Use(RateLimitMiddleware(config.RateLimit, time.Minute))
    }
}

6.3 错误处理 #

go
func ErrorHandlerMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        c.Next()
        
        // 处理中间件或处理函数中的错误
        if len(c.Errors) > 0 {
            err := c.Errors.Last()
            
            // 根据错误类型返回不同状态码
            switch err.Type {
            case gin.ErrorTypeBind:
                c.JSON(400, gin.H{
                    "code":    400,
                    "message": "请求参数错误",
                    "error":   err.Error(),
                })
            case gin.ErrorTypePrivate:
                c.JSON(500, gin.H{
                    "code":    500,
                    "message": "服务器内部错误",
                })
            default:
                c.JSON(500, gin.H{
                    "code":    500,
                    "message": err.Error(),
                })
            }
        }
    }
}

七、总结 #

7.1 核心要点 #

要点 说明
路由中间件 针对特定路由的中间件
分组中间件 针对路由分组的中间件
执行顺序 先注册先执行(Before),后注册后执行(After)
中断请求 c.Abort()

7.2 最佳实践 #

实践 说明
分层设计 全局、分组、路由级别
单一职责 每个中间件只做一件事
配置化 通过配置控制中间件开关
错误处理 统一错误处理机制

7.3 下一步 #

现在你已经掌握了路由中间件,接下来让我们学习 中间件概念,深入了解中间件的原理!

最后更新:2026-03-28