常用中间件 #
一、JWT认证中间件 #
1.1 基本实现 #
go
package middleware
import (
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v4"
)
// JWTConfig carries the settings consumed by JWTProtected.
type JWTConfig struct {
	// Secret is the signing secret; it is converted to bytes and handed to
	// the JWT parser as the verification key.
	Secret string

	// TokenHeader is the name of the request header the token is read from.
	TokenHeader string

	// ContextKey is the c.Locals key under which the parsed claims are stored.
	ContextKey string

	// SkipPaths lists request paths that bypass authentication; a request is
	// skipped when c.Path() equals one of the entries exactly.
	SkipPaths []string
}
// JWTProtected returns a Fiber middleware that authenticates requests with a
// bearer JWT verified against config.Secret.
//
// Requests whose path equals an entry of config.SkipPaths are passed straight
// to the next handler. For every other request the token is read from the
// configured header in the form "Bearer <token>", parsed and verified. On
// success the token's claims are stored in c.Locals under the configured
// context key and the chain continues; on any failure a 401 JSON response is
// sent and the chain stops.
//
// An empty config.TokenHeader falls back to "Authorization" and an empty
// config.ContextKey falls back to "user", so a config that only sets Secret
// is still usable.
func JWTProtected(config JWTConfig) fiber.Handler {
	tokenHeader := config.TokenHeader
	if tokenHeader == "" {
		tokenHeader = "Authorization"
	}
	contextKey := config.ContextKey
	if contextKey == "" {
		contextKey = "user"
	}
	// Converted once; the slice is only ever read by the key function.
	secret := []byte(config.Secret)

	return func(c *fiber.Ctx) error {
		// Skip the configured paths.
		for _, path := range config.SkipPaths {
			if c.Path() == path {
				return c.Next()
			}
		}

		// Fetch the token header.
		authHeader := c.Get(tokenHeader)
		if authHeader == "" {
			return c.Status(401).JSON(fiber.Map{
				"error": "Missing token",
			})
		}

		// Split "Bearer <token>"; the scheme name is matched case-insensitively.
		parts := strings.Split(authHeader, " ")
		if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
			return c.Status(401).JSON(fiber.Map{
				"error": "Invalid token format",
			})
		}

		// Verify the token.
		token, err := jwt.Parse(parts[1], func(token *jwt.Token) (interface{}, error) {
			// The secret is an HMAC key, so refuse any token that announces a
			// different signing algorithm family.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, jwt.ErrSignatureInvalid
			}
			return secret, nil
		})
		if err != nil || !token.Valid {
			return c.Status(401).JSON(fiber.Map{
				"error": "Invalid token",
			})
		}

		// Expose the claims to downstream handlers.
		c.Locals(contextKey, token.Claims)
		return c.Next()
	}
}
1.2 生成Token #
go
// UserClaims is the claim set written into tokens by GenerateToken: two
// application fields plus the embedded registered JWT claims.
type UserClaims struct {
	// UserID is serialized as the "user_id" claim.
	UserID string `json:"user_id"`

	// Role is serialized as the "role" claim.
	Role string `json:"role"`

	jwt.RegisteredClaims
}
func GenerateToken(userID, role, secret string) (string, error) {
claims := UserClaims{
UserID: userID,
Role: role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(secret))
}
1.3 使用示例 #
go
// main wires one public login route and a JWT-protected /api group, then
// serves on port 3000.
func main() {
	app := fiber.New()

	// Public route.
	app.Post("/login", loginHandler)

	// Protected routes.
	protected := app.Group("/api", JWTProtected(JWTConfig{
		// Placeholder secret for the example; supply a real one in practice.
		Secret:      "your-secret-key",
		TokenHeader: "Authorization",
		ContextKey:  "user",
		// JWTProtected compares SkipPaths against c.Path(), which is the full
		// request path, so entries must include the group prefix.
		SkipPaths: []string{"/api/health"},
	}))
	protected.Get("/profile", getProfile)

	// Listen only returns on failure (e.g. the port is taken); do not let the
	// process exit silently in that case.
	if err := app.Listen(":3000"); err != nil {
		panic(err)
	}
}
// getProfile responds with the "user_id" and "role" claims of the
// authenticated user.
//
// The claims are read from c.Locals("user"). If no jwt.MapClaims value is
// stored there, a 401 JSON response is returned instead of panicking on the
// type assertion.
func getProfile(c *fiber.Ctx) error {
	claims, ok := c.Locals("user").(jwt.MapClaims)
	if !ok {
		return c.Status(401).JSON(fiber.Map{
			"error": "Unauthorized",
		})
	}
	return c.JSON(fiber.Map{
		"user_id": claims["user_id"],
		"role":    claims["role"],
	})
}
二、权限控制中间件 #
2.1 角色权限 #
go
// RoleRequired returns a middleware that lets a request through only when the
// "role" claim of the authenticated user equals one of roles.
//
// Claims are read from c.Locals("user") and must be jwt.MapClaims; otherwise
// the response is 401. A missing or non-string "role" claim, or a role that is
// not in roles, yields 403.
func RoleRequired(roles ...string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		claims, ok := c.Locals("user").(jwt.MapClaims)
		if !ok {
			return c.Status(401).JSON(fiber.Map{
				"error": "Unauthorized",
			})
		}

		// Checked assertion: a token without a string role must be rejected,
		// not crash the handler.
		userRole, ok := claims["role"].(string)
		if ok {
			for _, role := range roles {
				if userRole == role {
					return c.Next()
				}
			}
		}

		return c.Status(403).JSON(fiber.Map{
			"error": "Forbidden",
		})
	}
}
// Usage: create the /admin group with JWTProtected and RoleRequired("admin")
// attached as its handlers. `app` and `config` are defined outside this
// snippet.
admin := app.Group("/admin",
JWTProtected(config),
RoleRequired("admin"),
)
2.2 资源权限 #
go
// ResourcePermission returns a middleware that allows a request only when
// checkPermission reports that the authenticated user may perform action on
// resource.
//
// The user is identified by the string "user_id" claim found in
// c.Locals("user"). If the claims are absent, are not jwt.MapClaims, or carry
// no string "user_id", the response is 401; if the permission check fails the
// response is 403.
func ResourcePermission(resource, action string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		claims, ok := c.Locals("user").(jwt.MapClaims)
		if !ok {
			return c.Status(401).JSON(fiber.Map{
				"error": "Unauthorized",
			})
		}
		userID, ok := claims["user_id"].(string)
		if !ok {
			return c.Status(401).JSON(fiber.Map{
				"error": "Unauthorized",
			})
		}

		// Check the permission.
		if !checkPermission(userID, resource, action) {
			return c.Status(403).JSON(fiber.Map{
				"error": "Permission denied",
			})
		}
		return c.Next()
	}
}
// Usage: register DELETE /users/:id with JWTProtected and
// ResourcePermission("users", "delete") listed ahead of the deleteUser
// handler. `app`, `config` and `deleteUser` are defined outside this snippet.
app.Delete("/users/:id",
JWTProtected(config),
ResourcePermission("users", "delete"),
deleteUser,
)
三、请求日志中间件 #
3.1 结构化日志 #
go
package middleware
import (
"time"
"github.com/gofiber/fiber/v2"
"go.uber.org/zap"
)
// StructuredLogger returns a middleware that writes one structured zap entry
// per request after the rest of the chain has run.
//
// The entry records method, path, response status, latency, client IP and
// user agent. When the chain returns an error it is attached to the entry as
// well, and the same error is returned unchanged to the caller.
func StructuredLogger(logger *zap.Logger) fiber.Handler {
	return func(c *fiber.Ctx) error {
		start := time.Now()
		err := c.Next()

		// Describe the request.
		fields := []zap.Field{
			zap.String("method", c.Method()),
			zap.String("path", c.Path()),
			zap.Int("status", c.Response().StatusCode()),
			zap.Duration("latency", time.Since(start)),
			zap.String("ip", c.IP()),
			zap.String("user_agent", c.Get("User-Agent")),
		}
		// The status above is read before Fiber's ErrorHandler has processed a
		// returned error, so include the error itself to keep failures visible.
		if err != nil {
			fields = append(fields, zap.Error(err))
		}
		logger.Info("HTTP Request", fields...)

		return err
	}
}
3.2 请求追踪日志 #
go
// TraceLogger returns a middleware that logs a line when a request starts and
// another when it finishes, both prefixed with the request ID.
//
// The ID is read from c.Locals("requestid"). If no string is stored there, the
// placeholder "unknown" is used instead of panicking on the type assertion.
// The error returned by the rest of the chain is passed through unchanged.
func TraceLogger() fiber.Handler {
	return func(c *fiber.Ctx) error {
		requestID, ok := c.Locals("requestid").(string)
		if !ok {
			requestID = "unknown"
		}
		start := time.Now()

		// Request start.
		log.Printf("[%s] --> %s %s", requestID, c.Method(), c.Path())

		err := c.Next()

		// Request end.
		log.Printf("[%s] <-- %d %v",
			requestID,
			c.Response().StatusCode(),
			time.Since(start),
		)
		return err
	}
}
四、性能监控中间件 #
4.1 响应时间监控 #
go
// ResponseTime returns a middleware that measures how long the rest of the
// chain takes, reports it in the X-Response-Time response header, and logs a
// warning line for requests slower than one second. The chain's error is
// returned unchanged.
func ResponseTime() fiber.Handler {
	const slowThreshold = 1 * time.Second

	return func(c *fiber.Ctx) error {
		began := time.Now()
		chainErr := c.Next()
		elapsed := time.Since(began)

		// Publish the measured time on the response.
		c.Set("X-Response-Time", elapsed.String())

		// Slow-request warning.
		if elapsed > slowThreshold {
			log.Printf("Slow request: %s %s took %v",
				c.Method(), c.Path(), elapsed)
		}
		return chainErr
	}
}
4.2 性能指标收集 #
go
// Metrics holds the counters updated by MetricsMiddleware. The middleware
// modifies every field through sync/atomic.
type Metrics struct {
	// RequestCount is incremented once per request.
	RequestCount int64

	// ResponseTime accumulates the duration of every request.
	ResponseTime time.Duration

	// ErrorCount is incremented for each request counted as an error.
	ErrorCount int64
}
// MetricsMiddleware returns a middleware that updates metrics atomically for
// every request: it increments RequestCount, adds the request's duration to
// ResponseTime, and increments ErrorCount when the request failed.
//
// A request counts as failed when the rest of the chain returned an error or
// when the response status is 400 or above. The chain's error is returned
// unchanged.
func MetricsMiddleware(metrics *Metrics) fiber.Handler {
	return func(c *fiber.Ctx) error {
		atomic.AddInt64(&metrics.RequestCount, 1)

		start := time.Now()
		err := c.Next()
		duration := time.Since(start)

		// time.Duration has int64 as its underlying type, so the field can be
		// updated through an *int64 with the atomic helpers.
		atomic.AddInt64((*int64)(&metrics.ResponseTime), int64(duration))

		// A returned error has not been turned into a status code yet at this
		// point, so it has to be checked in addition to the status.
		if err != nil || c.Response().StatusCode() >= 400 {
			atomic.AddInt64(&metrics.ErrorCount, 1)
		}
		return err
	}
}
五、请求验证中间件 #
5.1 Body验证 #
go
import "github.com/go-playground/validator/v10"
// Validator wraps a validator.Validate instance used by the request
// validation middleware.
type Validator struct {
	validate *validator.Validate
}

// NewValidator returns a Validator backed by a freshly created
// validator.Validate.
func NewValidator() *Validator {
	v := new(Validator)
	v.validate = validator.New()
	return v
}
// ValidateBody returns a middleware that parses the request body into s and
// validates the result with the wrapped validator.
//
// A body that cannot be parsed, or a value that fails validation, produces a
// 400 JSON response; in the validation case the "fields" entry carries the
// validator's error text. Otherwise the next handler runs.
//
// NOTE(review): s is captured once and reused as the parse target by every
// request served by the returned handler — confirm callers do not depend on
// it under concurrent requests.
func (v *Validator) ValidateBody(s interface{}) fiber.Handler {
	return func(c *fiber.Ctx) error {
		parseErr := c.BodyParser(s)
		if parseErr != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
				"error": "Invalid request body",
			})
		}

		checkErr := v.validate.Struct(s)
		if checkErr == nil {
			return c.Next()
		}
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
			"error":  "Validation failed",
			"fields": checkErr.Error(),
		})
	}
}
// Usage example.

// CreateUserRequest is the payload of a create-user request. The validate
// tags require a name of 2 to 50 characters, a well-formed email address, and
// an age between 0 and 150.
type CreateUserRequest struct {
	Name  string `json:"name" validate:"required,min=2,max=50"`
	Email string `json:"email" validate:"required,email"`
	Age   int    `json:"age" validate:"min=0,max=150"`
}
validator := NewValidator()
app.Post("/users",
validator.ValidateBody(&CreateUserRequest{}),
createUser,
)
5.2 Query验证 #
go
// ValidateQuery returns a middleware that parses the request's query string
// into s and validates the result.
//
// Query parameters that cannot be parsed, or a value that fails validation,
// produce a 400 JSON response; otherwise the next handler runs.
//
// NOTE(review): s is captured once and reused as the parse target by every
// request served by the returned handler — confirm callers do not depend on
// it under concurrent requests.
func ValidateQuery(s interface{}) fiber.Handler {
	// Built once per middleware rather than once per request: the validator
	// caches struct metadata and is meant to be reused.
	validate := validator.New()

	return func(c *fiber.Ctx) error {
		if err := c.QueryParser(s); err != nil {
			return c.Status(400).JSON(fiber.Map{
				"error": "Invalid query parameters",
			})
		}
		if err := validate.Struct(s); err != nil {
			return c.Status(400).JSON(fiber.Map{
				"error": err.Error(),
			})
		}
		return c.Next()
	}
}
六、跨域处理中间件 #
6.1 自定义CORS #
go
// CustomCORS returns a middleware that adds CORS response headers for
// requests whose Origin header is accepted by allowedOrigins.
//
// An origin is accepted when allowedOrigins contains it verbatim or contains
// "*". Accepted requests get the Allow-Origin (echoing the request origin),
// Allow-Methods, Allow-Headers and Allow-Credentials headers. Requests without
// an Origin header receive no CORS headers. Every response is marked with
// "Vary: Origin" because its headers depend on the request origin. OPTIONS
// requests are answered with 204 and do not reach later handlers.
//
// NOTE(review): with "*" configured, any origin is echoed back together with
// Allow-Credentials: true — confirm that credentialed access from arbitrary
// origins is intended.
func CustomCORS(allowedOrigins []string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		origin := c.Get("Origin")

		// Keep shared caches from serving one origin's headers to another.
		c.Vary("Origin")

		// Check whether the origin is allowed; an absent Origin never is, so
		// that no empty Allow-Origin header is emitted.
		allowed := false
		if origin != "" {
			for _, o := range allowedOrigins {
				if o == "*" || o == origin {
					allowed = true
					break
				}
			}
		}

		if allowed {
			c.Set("Access-Control-Allow-Origin", origin)
			c.Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS")
			c.Set("Access-Control-Allow-Headers", "Content-Type,Authorization")
			c.Set("Access-Control-Allow-Credentials", "true")
		}

		// Answer preflight requests directly.
		if c.Method() == "OPTIONS" {
			return c.SendStatus(204)
		}
		return c.Next()
	}
}
七、限流中间件 #
7.1 基于IP限流 #
go
import "golang.org/x/time/rate"
// RateLimit returns a middleware that throttles requests per client IP.
//
// Each distinct c.IP() value gets its own rate.Limiter allowing rps events per
// second with the given burst; limiters are created on first use and kept in
// a mutex-guarded map. A request rejected by its limiter receives a 429 JSON
// response; all others continue down the chain.
func RateLimit(rps int, burst int) fiber.Handler {
	var (
		guard   sync.Mutex
		buckets = make(map[string]*rate.Limiter)
	)

	// bucketFor returns the limiter registered for key, creating it if needed.
	bucketFor := func(key string) *rate.Limiter {
		guard.Lock()
		defer guard.Unlock()

		if existing, ok := buckets[key]; ok {
			return existing
		}
		created := rate.NewLimiter(rate.Limit(rps), burst)
		buckets[key] = created
		return created
	}

	return func(c *fiber.Ctx) error {
		if bucketFor(c.IP()).Allow() {
			return c.Next()
		}
		return c.Status(429).JSON(fiber.Map{
			"error": "Too many requests",
		})
	}
}
7.2 基于用户限流 #
go
// UserRateLimit returns a middleware that throttles requests per
// authenticated user.
//
// The user is identified by the string "user_id" claim in c.Locals("user").
// Requests for which no such claim can be found (no claims stored, claims of
// another type, or a missing/non-string "user_id") are passed through without
// limiting. Each user ID gets its own rate.Limiter allowing rps events per
// second with the given burst; a rejected request receives a 429 JSON
// response.
func UserRateLimit(rps int, burst int) fiber.Handler {
	limiters := make(map[string]*rate.Limiter)
	var mu sync.Mutex

	return func(c *fiber.Ctx) error {
		// Resolve the user ID with checked assertions so that unexpected claim
		// shapes cannot panic the handler.
		claims, ok := c.Locals("user").(jwt.MapClaims)
		if !ok {
			return c.Next()
		}
		userID, ok := claims["user_id"].(string)
		if !ok {
			return c.Next()
		}

		mu.Lock()
		limiter, exists := limiters[userID]
		if !exists {
			limiter = rate.NewLimiter(rate.Limit(rps), burst)
			limiters[userID] = limiter
		}
		mu.Unlock()

		if !limiter.Allow() {
			return c.Status(429).JSON(fiber.Map{
				"error": "Rate limit exceeded",
			})
		}
		return c.Next()
	}
}
八、缓存中间件 #
8.1 响应缓存 #
go
// CacheStore is the storage backend used by CacheMiddleware.
type CacheStore interface {
	// Get returns the bytes stored under key and whether an entry was found.
	Get(key string) (value []byte, found bool)

	// Set stores value under key; ttl is the lifetime requested for the entry.
	Set(key string, value []byte, ttl time.Duration)
}
// CacheMiddleware returns a middleware that caches GET responses in store for
// ttl.
//
// The cache key is the request path plus the raw query string. On a hit the
// cached bytes are sent with "X-Cache: HIT" and later handlers are skipped.
// On a miss the chain runs; if it succeeds with status 200, a copy of the
// response body is stored and the response is marked "X-Cache: MISS".
// Non-GET requests bypass the cache and an error from the chain is returned
// as-is.
//
// NOTE(review): only the body is cached, so a hit is served without the
// original Content-Type or other headers — confirm this is acceptable.
func CacheMiddleware(store CacheStore, ttl time.Duration) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Only GET requests are cached.
		if c.Method() != fiber.MethodGet {
			return c.Next()
		}

		// Build the cache key.
		key := c.Path() + "?" + string(c.Request().URI().QueryString())

		// Try the cache first.
		if cached, found := store.Get(key); found {
			c.Set("X-Cache", "HIT")
			return c.Send(cached)
		}

		// Run the request.
		if err := c.Next(); err != nil {
			return err
		}

		// Cache successful responses only, so that error pages are not
		// replayed for the whole ttl.
		if c.Response().StatusCode() == fiber.StatusOK {
			// The response buffer is reused once the request is finished, so
			// the store must receive its own copy.
			body := append([]byte(nil), c.Response().Body()...)
			store.Set(key, body, ttl)
		}
		c.Set("X-Cache", "MISS")
		return nil
	}
}
九、错误处理中间件 #
9.1 统一错误处理 #
go
// AppError is an application error that pairs a numeric code with a message.
// Both fields are exported to JSON as "code" and "message".
type AppError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// Error implements the error interface by returning the message.
func (appErr *AppError) Error() string {
	msg := appErr.Message
	return msg
}
// ErrorHandler converts an error into a JSON error response.
//
// The status code and message are taken from err when it is an *AppError or a
// *fiber.Error; any other error is reported as 500 "Internal Server Error".
// The response body has the keys "error" (message) and "code" (status).
func ErrorHandler(c *fiber.Ctx, err error) error {
	status, text := fiber.StatusInternalServerError, "Internal Server Error"

	switch typed := err.(type) {
	case *AppError:
		status, text = typed.Code, typed.Message
	case *fiber.Error:
		status, text = typed.Code, typed.Message
	}

	payload := fiber.Map{
		"error": text,
		"code":  status,
	}
	return c.Status(status).JSON(payload)
}
// Configuration: create the app with ErrorHandler set as the ErrorHandler
// field of fiber.Config.
app := fiber.New(fiber.Config{
ErrorHandler: ErrorHandler,
})
十、中间件组合示例 #
10.1 完整中间件链 #
go
// main assembles the complete middleware chain: global middleware first, then
// public routes, then the authenticated /api group with its admin-only
// sub-group, and finally starts the server on port 3000.
func main() {
	app := fiber.New()

	// 1. Panic recovery.
	app.Use(recover.New())
	// 2. Request ID.
	app.Use(requestid.New())
	// 3. Request logging.
	app.Use(logger.New())
	// 4. CORS.
	app.Use(cors.New())
	// 5. Security headers.
	app.Use(helmet.New())
	// 6. Compression.
	app.Use(compress.New())
	// 7. Rate limiting.
	app.Use(RateLimit(100, 20))

	// Public routes.
	app.Post("/login", loginHandler)
	app.Get("/health", healthHandler)

	// API routes.
	api := app.Group("/api")

	// 8. JWT authentication. TokenHeader and ContextKey are set explicitly:
	// JWTProtected reads the token from the TokenHeader header, and
	// RoleRequired looks the claims up under the "user" key.
	api.Use(JWTProtected(JWTConfig{
		// Placeholder secret for the example; supply a real one in practice.
		Secret:      "secret",
		TokenHeader: "Authorization",
		ContextKey:  "user",
	}))

	// 9. Role check. This runs for the whole /api group, including the admin
	// sub-group below, so "admin" must be accepted here too or admins would
	// be rejected before reaching their own routes.
	api.Use(RoleRequired("user", "admin"))

	// Business routes.
	api.Get("/profile", getProfile)
	api.Get("/users", getUsers)

	// Admin routes.
	admin := api.Group("/admin")
	admin.Use(RoleRequired("admin"))
	admin.Get("/stats", getStats)

	// Listen only returns on failure (e.g. the port is taken); do not let the
	// process exit silently in that case.
	if err := app.Listen(":3000"); err != nil {
		panic(err)
	}
}
十一、总结 #
11.1 常用中间件清单 #
| 中间件 | 用途 | 优先级 |
|---|---|---|
| Recover | 错误恢复 | 最高 |
| RequestID | 请求追踪 | 高 |
| Logger | 日志记录 | 高 |
| CORS | 跨域处理 | 高 |
| Helmet | 安全头 | 高 |
| RateLimit | 请求限流 | 中 |
| JWT | 身份认证 | 中 |
| Permission | 权限控制 | 中 |
| Validator | 数据验证 | 中 |
| Cache | 响应缓存 | 低 |
11.2 下一步 #
现在你已经了解了常用中间件，接下来让我们学习 Context 上下文，深入了解 Fiber 的请求处理机制！
最后更新:2026-03-28