自定义中间件 #
一、中间件基础结构 #
1.1 基本格式 #
go
// MyMiddleware returns a skeleton Gin middleware that shows where
// pre-processing and post-processing code sits relative to c.Next().
func MyMiddleware() gin.HandlerFunc {
	handler := func(c *gin.Context) {
		// Pre-processing: code placed before c.Next().
		c.Next() // hand control to the next handler in the chain
		// Post-processing: code placed after c.Next().
	}
	return handler
}
1.2 带参数的中间件 #
go
// RateLimitMiddleware returns a middleware that rejects requests with HTTP 429
// once the limiter has no tokens left.
//
// The limiter is created once, outside the returned handler, so it is shared
// by every request that passes through this middleware instance. rps is used
// both as the limiter's rate and as its burst size.
func RateLimitMiddleware(rps int) gin.HandlerFunc {
	bucket := rate.NewLimiter(rate.Limit(rps), rps)
	return func(c *gin.Context) {
		if bucket.Allow() {
			c.Next()
			return
		}
		// No token available: stop the chain and report the rejection.
		c.AbortWithStatusJSON(429, gin.H{
			"code":    429,
			"message": "请求过于频繁",
		})
	}
}
// main registers the rate limiter globally and starts the server.
func main() {
	r := gin.New()
	r.Use(RateLimitMiddleware(100))
	// Run returns an error when the server cannot start; do not let the
	// process exit silently in that case.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
1.3 带配置的中间件 #
go
// CORSConfig holds the settings consumed by CORSWithConfig.
type CORSConfig struct {
	// AllowOrigins lists the origins that may receive an
	// Access-Control-Allow-Origin header; the entry "*" matches any origin.
	AllowOrigins []string
	// AllowMethods is joined with ", " into Access-Control-Allow-Methods.
	AllowMethods []string
	// AllowHeaders is joined with ", " into Access-Control-Allow-Headers.
	AllowHeaders []string
	// ExposeHeaders lists response headers intended to be exposed to browser
	// scripts (Access-Control-Expose-Headers).
	ExposeHeaders []string
	// AllowCredentials, when true, sends Access-Control-Allow-Credentials: true.
	AllowCredentials bool
	// MaxAge is sent as Access-Control-Max-Age when greater than zero.
	MaxAge int
}
// CORSWithConfig returns a middleware that sets CORS response headers from
// config and answers every OPTIONS request with 204 without running later
// handlers.
//
// Compared with a minimal implementation it also:
//   - emits Access-Control-Expose-Headers from config.ExposeHeaders,
//   - adds "Vary: Origin", because Access-Control-Allow-Origin echoes the
//     request's Origin and caches must not reuse a response across origins,
//   - builds the config-derived header values once instead of per request.
//
// NOTE(review): with AllowOrigins containing "*" and AllowCredentials set, any
// origin is echoed back together with credentials; confirm that is intended.
func CORSWithConfig(config CORSConfig) gin.HandlerFunc {
	// These values depend only on config, so compute them a single time.
	allowMethods := strings.Join(config.AllowMethods, ", ")
	allowHeaders := strings.Join(config.AllowHeaders, ", ")
	exposeHeaders := strings.Join(config.ExposeHeaders, ", ")
	maxAge := strconv.Itoa(config.MaxAge)

	return func(c *gin.Context) {
		origin := c.GetHeader("Origin")
		c.Writer.Header().Add("Vary", "Origin")

		// Check whether the request origin is allowed.
		allowed := false
		for _, o := range config.AllowOrigins {
			if o == "*" || o == origin {
				allowed = true
				break
			}
		}
		if allowed {
			c.Header("Access-Control-Allow-Origin", origin)
		}
		if len(config.AllowMethods) > 0 {
			c.Header("Access-Control-Allow-Methods", allowMethods)
		}
		if len(config.AllowHeaders) > 0 {
			c.Header("Access-Control-Allow-Headers", allowHeaders)
		}
		if len(config.ExposeHeaders) > 0 {
			c.Header("Access-Control-Expose-Headers", exposeHeaders)
		}
		if config.AllowCredentials {
			c.Header("Access-Control-Allow-Credentials", "true")
		}
		if config.MaxAge > 0 {
			c.Header("Access-Control-Max-Age", maxAge)
		}

		// Preflight requests end here.
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	}
}
// main builds a CORSConfig, installs the CORS middleware globally and starts
// the server.
func main() {
	r := gin.New()
	config := CORSConfig{
		AllowOrigins:     []string{"http://localhost:3000", "https://example.com"},
		AllowMethods:     []string{"GET", "POST", "PUT", "DELETE"},
		AllowHeaders:     []string{"Content-Type", "Authorization"},
		AllowCredentials: true,
		MaxAge:           86400,
	}
	r.Use(CORSWithConfig(config))
	// Run returns an error when the server cannot start; surface it instead
	// of exiting silently.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
二、常用自定义中间件 #
2.1 认证中间件 #
go
// JWTConfig holds the settings consumed by JWTAuth.
type JWTConfig struct {
	// Secret is passed to validateJWT together with the extracted token.
	Secret string
	// TokenLookup selects where the token is read from. JWTAuth recognises
	// the forms "header:<name>", "query:<name>" and "cookie:<name>".
	TokenLookup string
}
// JWTAuth returns a middleware that extracts a token according to
// config.TokenLookup, validates it with validateJWT and, on success, stores
// the claims' UserID and Role in the context under "userId" and "role".
//
// A missing token or a validation error aborts the request with HTTP 401.
func JWTAuth(config JWTConfig) gin.HandlerFunc {
	return func(c *gin.Context) {
		// reject stops the chain with a 401 carrying the given message.
		reject := func(message string) {
			c.AbortWithStatusJSON(401, gin.H{
				"code":    401,
				"message": message,
			})
		}

		// Locate the token.
		lookup := config.TokenLookup
		token := ""
		if strings.HasPrefix(lookup, "header:") {
			raw := c.GetHeader(strings.TrimPrefix(lookup, "header:"))
			// TrimPrefix leaves the value untouched when there is no
			// "Bearer " prefix.
			token = strings.TrimPrefix(raw, "Bearer ")
		} else if strings.HasPrefix(lookup, "query:") {
			token = c.Query(strings.TrimPrefix(lookup, "query:"))
		} else if strings.HasPrefix(lookup, "cookie:") {
			// The error is dropped on purpose: the empty-token check below
			// covers the failure case.
			token, _ = c.Cookie(strings.TrimPrefix(lookup, "cookie:"))
		}

		if token == "" {
			reject("未提供认证令牌")
			return
		}

		// Validate the token.
		claims, err := validateJWT(token, config.Secret)
		if err != nil {
			reject("无效的认证令牌")
			return
		}

		// Make the user information available to later handlers.
		c.Set("userId", claims.UserID)
		c.Set("role", claims.Role)
		c.Next()
	}
}
2.2 权限中间件 #
go
// RequireRoles returns a middleware that lets a request continue only when
// the "role" value stored in the context equals one of roles.
//
// A missing "role" value, or one that is not a string, aborts with HTTP 403
// "无权限访问". A role that matches none of roles aborts with HTTP 403
// "权限不足".
func RequireRoles(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userRole, exists := c.Get("role")
		// Use the comma-ok form: an unchecked assertion would panic if some
		// other handler stored a non-string value under "role".
		role, isString := userRole.(string)
		if !exists || !isString {
			c.AbortWithStatusJSON(403, gin.H{
				"code":    403,
				"message": "无权限访问",
			})
			return
		}
		for _, r := range roles {
			if role == r {
				c.Next()
				return
			}
		}
		c.AbortWithStatusJSON(403, gin.H{
			"code":    403,
			"message": "权限不足",
		})
	}
}
// main installs JWTAuth globally, restricts the /admin group to the "admin"
// and "super_admin" roles, and starts the server.
func main() {
	r := gin.New()
	// NOTE(review): config is not defined in this snippet; presumably a
	// JWTConfig declared elsewhere.
	r.Use(JWTAuth(config))
	admin := r.Group("/admin")
	admin.Use(RequireRoles("admin", "super_admin"))
	{
		admin.GET("/users", listUsers)
	}
	// Run returns an error when the server cannot start; surface it instead
	// of exiting silently.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
2.3 请求日志中间件 #
go
// LoggerConfig holds the settings consumed by LoggerWithConfig.
type LoggerConfig struct {
	// Output is the intended destination for log lines.
	Output io.Writer
	// SkipPaths lists request paths that are not logged.
	SkipPaths []string
	// Format is not read by LoggerWithConfig in this snippet.
	// NOTE(review): presumably reserved for a custom line format - confirm.
	Format string
}
// LoggerWithConfig returns a middleware that logs one line per request
// (method, path, client IP, status code, latency) after the rest of the chain
// has run. Requests whose path appears in config.SkipPaths are not logged.
//
// Log lines go to config.Output when it is set; otherwise they go to the
// standard logger, as log.Printf does. config.Format is not used.
func LoggerWithConfig(config LoggerConfig) gin.HandlerFunc {
	// Build the skip set once so each request does a single map lookup.
	skipPaths := make(map[string]bool, len(config.SkipPaths))
	for _, path := range config.SkipPaths {
		skipPaths[path] = true
	}

	// Honour the configured writer instead of always using the global logger.
	logf := log.Printf
	if config.Output != nil {
		logf = log.New(config.Output, "", log.LstdFlags).Printf
	}

	return func(c *gin.Context) {
		path := c.Request.URL.Path
		// Skip the configured paths.
		if skipPaths[path] {
			c.Next()
			return
		}
		start := time.Now()
		c.Next()
		latency := time.Since(start)
		statusCode := c.Writer.Status()
		clientIP := c.ClientIP()
		method := c.Request.Method
		logf("[%s] %s %s %d %v",
			method,
			path,
			clientIP,
			statusCode,
			latency,
		)
	}
}
2.4 请求体缓存 #
go
// BodyCache returns a middleware that reads the whole request body, stores
// the bytes in the context under "rawBody", and replaces the request body
// with a fresh reader over the same bytes so later handlers can read it.
//
// A read error aborts the request with HTTP 400 and the error text.
func BodyCache() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Drain the request body.
		raw, readErr := io.ReadAll(c.Request.Body)
		if readErr != nil {
			c.AbortWithStatusJSON(400, gin.H{"error": readErr.Error()})
			return
		}
		// Cache the bytes for later handlers.
		c.Set("rawBody", raw)
		// Put a readable body back on the request.
		c.Request.Body = io.NopCloser(bytes.NewReader(raw))
		c.Next()
	}
}
2.5 响应包装 #
go
// ResponseWrapper is the JSON envelope that ResponseWrapperMiddleware puts
// around a handler's response body.
type ResponseWrapper struct {
	// Code is serialized as "code"; the middleware sets it to 0.
	Code int `json:"code"`
	// Message is serialized as "message"; the middleware sets it to "success".
	Message string `json:"message"`
	// Data carries the decoded original body and is omitted when empty.
	Data interface{} `json:"data,omitempty"`
}
// ResponseWrapperMiddleware returns a middleware that captures the response
// body through a responseWriter and, when the captured status is 200 and the
// body is non-empty, writes it back wrapped in a ResponseWrapper envelope.
//
// Error handling:
//   - a body that is not valid JSON is kept as a string in Data rather than
//     being dropped from the envelope,
//   - if the envelope cannot be encoded, the captured body is written as-is,
//   - a failed write is recorded on the context with c.Error.
//
// NOTE(review): responseWriter is defined elsewhere. This code assumes its
// Write only buffers; if so, bodies of non-200 responses are never flushed to
// the client here - confirm against the responseWriter implementation.
func ResponseWrapperMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Install the capturing writer.
		writer := &responseWriter{
			ResponseWriter: c.Writer,
			body:           &bytes.Buffer{},
		}
		c.Writer = writer
		c.Next()

		raw := writer.body.Bytes()
		if writer.status != 200 || len(raw) == 0 {
			return
		}

		// Wrap the response.
		var data interface{}
		if err := json.Unmarshal(raw, &data); err != nil {
			// Not JSON (for example plain text): keep the payload.
			data = string(raw)
		}
		out, err := json.Marshal(ResponseWrapper{
			Code:    0,
			Message: "success",
			Data:    data,
		})
		if err != nil {
			// Fall back to the unwrapped body rather than sending nothing.
			out = raw
		}
		c.Writer = writer.ResponseWriter
		if _, err := c.Writer.Write(out); err != nil {
			c.Error(err)
		}
	}
}
2.6 超时控制 #
go
// Timeout returns a middleware that attaches a deadline of timeout to the
// request context and answers HTTP 504 when the deadline passed before any
// response was written.
//
// The rest of the chain runs on the request's own goroutine. Running c.Next()
// on a second goroutine while this one also writes to c is a data race on the
// gin.Context and its writer, lets the handler keep using the context after
// the request has finished, and takes panics out of reach of any recovery
// middleware. The timeout is therefore cooperative: handlers must watch
// c.Request.Context() (for example by passing it to database or HTTP calls)
// and return once it is done.
func Timeout(timeout time.Duration) gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
		defer cancel()
		c.Request = c.Request.WithContext(ctx)

		c.Next()

		// Report the timeout only if nothing has been sent yet; otherwise the
		// status line and headers are already on the wire.
		if ctx.Err() == context.DeadlineExceeded && !c.Writer.Written() {
			c.AbortWithStatusJSON(504, gin.H{
				"code":    504,
				"message": "请求超时",
			})
		}
	}
}
三、中间件设计模式 #
3.1 工厂模式 #
go
// MiddlewareFactory builds middleware from one shared configuration.
type MiddlewareFactory struct {
	// config supplies the values each factory method passes to its middleware.
	config *Config
}
// NewMiddlewareFactory returns a factory that reads its settings from config.
// The pointer is stored as given; no copy is made.
func NewMiddlewareFactory(config *Config) *MiddlewareFactory {
	factory := MiddlewareFactory{config: config}
	return &factory
}
// Auth returns a JWTAuth middleware that uses the factory's JWTSecret and
// reads the token from the Authorization header.
func (f *MiddlewareFactory) Auth() gin.HandlerFunc {
	jwtCfg := JWTConfig{
		TokenLookup: "header:Authorization",
		Secret:      f.config.JWTSecret,
	}
	return JWTAuth(jwtCfg)
}
// CORS returns a CORSWithConfig middleware whose allowed origins come from
// the factory's CORSOrigins; every other CORSConfig field keeps its zero
// value.
func (f *MiddlewareFactory) CORS() gin.HandlerFunc {
	var corsCfg CORSConfig
	corsCfg.AllowOrigins = f.config.CORSOrigins
	return CORSWithConfig(corsCfg)
}
// RateLimit returns a RateLimitMiddleware configured with the factory's
// RateLimit value.
func (f *MiddlewareFactory) RateLimit() gin.HandlerFunc {
	rps := f.config.RateLimit
	return RateLimitMiddleware(rps)
}
// main loads the configuration, installs the factory-built middleware in the
// order CORS, Auth, RateLimit, and starts the server.
func main() {
	config := LoadConfig()
	factory := NewMiddlewareFactory(config)
	r := gin.New()
	r.Use(factory.CORS())
	r.Use(factory.Auth())
	r.Use(factory.RateLimit())
	// Run returns an error when the server cannot start; surface it instead
	// of exiting silently.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
3.2 装饰器模式 #
go
// WithMetrics decorates middleware so that the time spent in each call to it
// is reported through metrics.RecordRequestDuration.
func WithMetrics(middleware gin.HandlerFunc) gin.HandlerFunc {
	timed := func(c *gin.Context) {
		begin := time.Now()
		middleware(c)
		metrics.RecordRequestDuration(time.Since(begin))
	}
	return timed
}
// main installs the metrics-decorated auth middleware and starts the server.
func main() {
	r := gin.New()
	r.Use(WithMetrics(AuthMiddleware()))
	// Run returns an error when the server cannot start; surface it instead
	// of exiting silently.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
3.3 责任链模式 #
go
// MiddlewareChain collects middleware to be combined with a final handler.
type MiddlewareChain struct {
	// middlewares holds the handlers in the order they were added with Use.
	middlewares []gin.HandlerFunc
}
// NewChain returns an empty MiddlewareChain.
func NewChain() *MiddlewareChain {
	return new(MiddlewareChain)
}
// Use appends middleware to the chain and returns the same chain so calls can
// be strung together.
func (c *MiddlewareChain) Use(middleware gin.HandlerFunc) *MiddlewareChain {
	extended := append(c.middlewares, middleware)
	c.middlewares = extended
	return c
}
// Then returns a single handler that runs the chain's middleware in order
// followed by handler, stopping as soon as one of them aborts the context.
//
// The handler list is copied once, when Then is called, for two reasons:
// gin.Context's handlers field is unexported and cannot be assigned from
// outside package gin, and appending to the shared middlewares slice on each
// request could write into a backing array that other handlers built from the
// same chain also use.
//
// Limitation: the handlers run one after another, so code a middleware places
// after its own c.Next() call runs before the following handlers, not after
// them. Register the handlers directly on the route when that order matters.
func (c *MiddlewareChain) Then(handler gin.HandlerFunc) gin.HandlerFunc {
	handlers := make([]gin.HandlerFunc, 0, len(c.middlewares)+1)
	handlers = append(handlers, c.middlewares...)
	handlers = append(handlers, handler)

	return func(ctx *gin.Context) {
		for _, h := range handlers {
			h(ctx)
			if ctx.IsAborted() {
				return
			}
		}
	}
}
// main builds a chain of auth and rate-limit middleware, attaches it to the
// /protected route and starts the server.
func main() {
	r := gin.New()
	chain := NewChain().
		Use(AuthMiddleware()).
		Use(RateLimitMiddleware(100))
	r.GET("/protected", chain.Then(handler))
	// Run returns an error when the server cannot start; surface it instead
	// of exiting silently.
	if err := r.Run(); err != nil {
		panic(err)
	}
}
四、中间件测试 #
4.1 单元测试 #
go
// TestAuthMiddleware checks that AuthMiddleware answers 401 without a token
// and 200 with "Bearer valid_token".
//
// Requests are built with httptest.NewRequest, which is meant for requests
// passed straight to a handler and has no error return to discard.
func TestAuthMiddleware(t *testing.T) {
	r := gin.New()
	r.Use(AuthMiddleware())
	r.GET("/test", func(c *gin.Context) {
		c.String(200, "ok")
	})

	cases := []struct {
		name       string
		authHeader string
		wantStatus int
	}{
		{name: "no token", wantStatus: 401},
		{name: "valid token", authHeader: "Bearer valid_token", wantStatus: 200},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			if tc.authHeader != "" {
				req.Header.Set("Authorization", tc.authHeader)
			}
			r.ServeHTTP(w, req)
			assert.Equal(t, tc.wantStatus, w.Code)
		})
	}
}
4.2 集成测试 #
go
// TestMiddlewareChain checks that a request carrying "Bearer valid_token"
// passes the logger, auth and rate-limit middleware and reaches the handler
// with status 200.
//
// The request is built with httptest.NewRequest, which has no error return to
// discard.
func TestMiddlewareChain(t *testing.T) {
	r := gin.New()
	r.Use(LoggerMiddleware())
	r.Use(AuthMiddleware())
	r.Use(RateLimitMiddleware(100))
	r.GET("/protected", func(c *gin.Context) {
		c.JSON(200, gin.H{"message": "ok"})
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer valid_token")
	r.ServeHTTP(w, req)

	assert.Equal(t, 200, w.Code)
}
4.3 Mock测试 #
go
// MockAuthService is a test double whose ValidateToken accepts exactly one
// token.
type MockAuthService struct {
	// validToken is the only token ValidateToken accepts.
	validToken string
}
// ValidateToken returns fixed claims (UserID "123", Role "user") when token
// equals the mock's validToken, and an "invalid token" error otherwise.
func (m *MockAuthService) ValidateToken(token string) (*Claims, error) {
	if token != m.validToken {
		return nil, errors.New("invalid token")
	}
	claims := Claims{UserID: "123", Role: "user"}
	return &claims, nil
}
// TestAuthMiddlewareWithMock checks that AuthMiddlewareWithService lets a
// request through with status 200 when the mock service accepts its token.
//
// The request is built with httptest.NewRequest, which has no error return to
// discard.
func TestAuthMiddlewareWithMock(t *testing.T) {
	mockAuth := &MockAuthService{validToken: "test_token"}
	r := gin.New()
	r.Use(AuthMiddlewareWithService(mockAuth))
	r.GET("/test", func(c *gin.Context) {
		c.String(200, "ok")
	})

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "Bearer test_token")
	r.ServeHTTP(w, req)

	assert.Equal(t, 200, w.Code)
}
五、中间件最佳实践 #
5.1 错误处理 #
go
// SafeMiddleware returns a middleware that recovers from panics raised by
// later handlers, logs the panic value with a stack trace, and answers with
// HTTP 500.
func SafeMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			log.Printf("[PANIC] %v\n%s", recovered, debug.Stack())
			c.AbortWithStatusJSON(500, gin.H{
				"code":    500,
				"message": "服务器内部错误",
			})
		}()
		c.Next()
	}
}
5.2 性能优化 #
go
// OptimizedMiddleware returns a middleware that borrows a *bytes.Buffer from
// a pool for each request and resets and returns it when the handler exits.
//
// The pool is created once, outside the returned handler, so all requests
// through this middleware instance share it.
func OptimizedMiddleware() gin.HandlerFunc {
	// Allocate the shared resource up front.
	bufPool := sync.Pool{
		New: func() interface{} {
			return new(bytes.Buffer)
		},
	}
	return func(c *gin.Context) {
		scratch := bufPool.Get().(*bytes.Buffer)
		// Reset before returning the buffer so the next borrower starts empty.
		defer func() {
			scratch.Reset()
			bufPool.Put(scratch)
		}()
		// Use scratch to process the request here.
		c.Next()
	}
}
5.3 可配置性 #
go
// ConfigurableMiddleware is a middleware that can be switched on or off.
type ConfigurableMiddleware struct {
	// enabled is checked by Handler on every request; when false the request
	// is passed straight to the next handler.
	enabled bool
	// config is not read by Handler in this snippet.
	config *Config
}
// Handler returns the middleware function. m.enabled is read on every
// request, so a disabled middleware simply forwards to the next handler.
func (m *ConfigurableMiddleware) Handler() gin.HandlerFunc {
	passThrough := func(c *gin.Context) { c.Next() }
	return func(c *gin.Context) {
		if !m.enabled {
			passThrough(c)
			return
		}
		// Middleware logic goes here.
		c.Next()
	}
}
六、总结 #
6.1 核心要点 #
| 要点 | 说明 |
|---|---|
| 基本结构 | 返回HandlerFunc的函数 |
| 参数传递 | 闭包捕获参数 |
| 配置化 | 使用配置结构体 |
| 错误处理 | defer recover |
6.2 最佳实践 #
| 实践 | 说明 |
|---|---|
| 单一职责 | 每个中间件只做一件事 |
| 可配置 | 支持配置参数 |
| 可测试 | 编写单元测试 |
| 错误恢复 | 处理panic |
6.3 下一步 #
现在你已经掌握了自定义中间件的编写方法,接下来让我们学习「常用中间件」,了解 Gin 生态中现成可用的中间件!
最后更新:2026-03-28