自定义中间件 #
一、中间件基础结构 #
1.1 最简单的中间件 #
go
func simpleMiddleware(c *fiber.Ctx) error {
fmt.Println("Before handler")
err := c.Next()
fmt.Println("After handler")
return err
}
// 使用
app.Use(simpleMiddleware)
1.2 标准中间件模式 #
go
func myMiddleware() fiber.Handler {
return func(c *fiber.Ctx) error {
// 前置处理
err := c.Next()
// 后置处理
return err
}
}
// 使用
app.Use(myMiddleware())
二、带配置的中间件 #
2.1 配置结构体 #
go
type AuthConfig struct {
TokenHeader string
Secret string
SkipPaths []string
}
func AuthMiddleware(config AuthConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 检查是否跳过
for _, path := range config.SkipPaths {
if c.Path() == path {
return c.Next()
}
}
// 获取Token
token := c.Get(config.TokenHeader)
if token == "" {
return c.Status(401).JSON(fiber.Map{
"error": "Missing token",
})
}
// 验证Token
if !validateToken(token, config.Secret) {
return c.Status(401).JSON(fiber.Map{
"error": "Invalid token",
})
}
return c.Next()
}
}
// 使用
app.Use(AuthMiddleware(AuthConfig{
TokenHeader: "Authorization",
Secret: "my-secret",
SkipPaths: []string{"/health", "/login"},
}))
2.2 默认配置 #
go
type LoggerConfig struct {
Format string
TimeFormat string
Output io.Writer
SkipPaths []string
}
var LoggerConfigDefault = LoggerConfig{
Format: "[${time}] ${status} - ${method} ${path}\n",
TimeFormat: "2006-01-02 15:04:05",
Output: os.Stdout,
SkipPaths: []string{},
}
func LoggerMiddleware(config ...LoggerConfig) fiber.Handler {
cfg := LoggerConfigDefault
if len(config) > 0 {
cfg = config[0]
}
return func(c *fiber.Ctx) error {
// 检查是否跳过
for _, path := range cfg.SkipPaths {
if c.Path() == path {
return c.Next()
}
}
start := time.Now()
err := c.Next()
// 格式化日志
log := cfg.Format
log = strings.Replace(log, "${time}", time.Now().Format(cfg.TimeFormat), -1)
log = strings.Replace(log, "${status}", strconv.Itoa(c.Response().StatusCode()), -1)
log = strings.Replace(log, "${method}", c.Method(), -1)
log = strings.Replace(log, "${path}", c.Path(), -1)
log = strings.Replace(log, "${latency}", time.Since(start).String(), -1)
cfg.Output.Write([]byte(log))
return err
}
}
三、中间件设计模式 #
3.1 认证中间件 #
go
type User struct {
ID string
Name string
Role string
}
func AuthMiddleware() fiber.Handler {
return func(c *fiber.Ctx) error {
token := c.Get("Authorization")
if token == "" {
return c.Status(401).JSON(fiber.Map{
"error": "Unauthorized",
})
}
// 解析Token
user, err := parseToken(token)
if err != nil {
return c.Status(401).JSON(fiber.Map{
"error": "Invalid token",
})
}
// 存储用户信息
c.Locals("user", user)
return c.Next()
}
}
// 获取当前用户
func GetCurrentUser(c *fiber.Ctx) *User {
if user, ok := c.Locals("user").(*User); ok {
return user
}
return nil
}
3.2 权限中间件 #
go
func RoleMiddleware(roles ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
user := GetCurrentUser(c)
if user == nil {
return c.Status(401).JSON(fiber.Map{
"error": "Unauthorized",
})
}
// 检查角色
for _, role := range roles {
if user.Role == role {
return c.Next()
}
}
return c.Status(403).JSON(fiber.Map{
"error": "Forbidden",
})
}
}
// 使用
app.Get("/admin", AuthMiddleware(), RoleMiddleware("admin"), adminHandler)
3.3 缓存中间件 #
go
type CacheConfig struct {
Expiration time.Duration
KeyFunc func(c *fiber.Ctx) string
Store CacheStore
}
func CacheMiddleware(config CacheConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 只缓存GET请求
if c.Method() != fiber.MethodGet {
return c.Next()
}
// 生成缓存键
key := config.KeyFunc(c)
// 尝试从缓存获取
cached, found := config.Store.Get(key)
if found {
return c.Send(cached.([]byte))
}
// 执行请求
err := c.Next()
if err != nil {
return err
}
// 缓存响应
body := c.Response().Body()
config.Store.Set(key, body, config.Expiration)
return nil
}
}
3.4 请求验证中间件 #
go
type ValidatorConfig struct {
Rules map[string]string
}
func ValidatorMiddleware(config ValidatorConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
errors := make(map[string]string)
for field, rule := range config.Rules {
value := c.Query(field)
switch rule {
case "required":
if value == "" {
errors[field] = field + " is required"
}
case "email":
if !isValidEmail(value) {
errors[field] = field + " is not a valid email"
}
}
}
if len(errors) > 0 {
return c.Status(400).JSON(fiber.Map{
"errors": errors,
})
}
return c.Next()
}
}
// 使用
app.Get("/users", ValidatorMiddleware(ValidatorConfig{
Rules: map[string]string{
"email": "required,email",
"name": "required",
},
}), getUsers)
四、高级中间件 #
4.1 请求追踪中间件 #
go
type TraceConfig struct {
HeaderName string
Generator func() string
}
func TraceMiddleware(config TraceConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取或生成TraceID
traceID := c.Get(config.HeaderName)
if traceID == "" {
traceID = config.Generator()
}
// 设置响应头
c.Set(config.HeaderName, traceID)
// 存储到上下文
c.Locals("traceID", traceID)
// 记录开始时间
start := time.Now()
err := c.Next()
// 记录请求信息
log.Printf("[%s] %s %s %d %v",
traceID,
c.Method(),
c.Path(),
c.Response().StatusCode(),
time.Since(start),
)
return err
}
}
4.2 请求重试中间件 #
go
type RetryConfig struct {
MaxAttempts int
Delay time.Duration
RetryIf func(c *fiber.Ctx, err error) bool
}
func RetryMiddleware(config RetryConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
var err error
for i := 0; i < config.MaxAttempts; i++ {
// 重置请求体
if i > 0 {
c.Request().ResetBody()
}
err = c.Next()
// 检查是否需要重试
if !config.RetryIf(c, err) {
return err
}
// 等待后重试
if i < config.MaxAttempts-1 {
time.Sleep(config.Delay)
}
}
return err
}
}
4.3 请求签名中间件 #
go
type SignConfig struct {
Secret string
HeaderKey string
}
func SignMiddleware(config SignConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取签名
signature := c.Get(config.HeaderKey)
if signature == "" {
return c.Status(401).JSON(fiber.Map{
"error": "Missing signature",
})
}
// 计算预期签名
body := c.Body()
expected := hmacSHA256(body, config.Secret)
// 验证签名
if !hmac.Equal([]byte(signature), []byte(expected)) {
return c.Status(401).JSON(fiber.Map{
"error": "Invalid signature",
})
}
return c.Next()
}
}
五、中间件工具函数 #
5.1 跳过路径 #
go
func SkipPaths(paths []string) func(c *fiber.Ctx) bool {
return func(c *fiber.Ctx) bool {
for _, path := range paths {
if c.Path() == path {
return true
}
}
return false
}
}
func MyMiddleware(skipFunc func(c *fiber.Ctx) bool) fiber.Handler {
return func(c *fiber.Ctx) error {
if skipFunc != nil && skipFunc(c) {
return c.Next()
}
// 中间件逻辑
return c.Next()
}
}
// 使用
app.Use(MyMiddleware(SkipPaths([]string{"/health", "/metrics"})))
5.2 条件执行 #
go
func ConditionalMiddleware(condition func(c *fiber.Ctx) bool, middleware fiber.Handler) fiber.Handler {
return func(c *fiber.Ctx) error {
if condition(c) {
return middleware(c)
}
return c.Next()
}
}
// 使用
app.Use(ConditionalMiddleware(
func(c *fiber.Ctx) bool {
return c.Path() == "/api"
},
authMiddleware,
))
六、中间件测试 #
6.1 单元测试 #
go
func TestAuthMiddleware(t *testing.T) {
app := fiber.New()
app.Use(AuthMiddleware(AuthConfig{
Secret: "test-secret",
}))
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("OK")
})
// 测试无Token
req := httptest.NewRequest("GET", "/", nil)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, 401, resp.StatusCode)
// 测试有效Token
token := generateToken("test-secret")
req = httptest.NewRequest("GET", "/", nil)
req.Header.Set("Authorization", token)
resp, err = app.Test(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
}
6.2 集成测试 #
go
func TestMiddlewareChain(t *testing.T) {
app := fiber.New()
app.Use(requestid.New())
app.Use(logger.New())
app.Use(AuthMiddleware(AuthConfig{
Secret: "test-secret",
}))
app.Get("/protected", func(c *fiber.Ctx) error {
return c.SendString("Protected")
})
// 测试完整中间件链
token := generateToken("test-secret")
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("Authorization", token)
resp, err := app.Test(req)
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
assert.NotEmpty(t, resp.Header.Get("X-Request-ID"))
}
七、最佳实践 #
7.1 设计原则 #
text
1. 单一职责:每个中间件只做一件事
2. 可配置:提供配置选项
3. 可测试:易于编写测试
4. 文档化:提供清晰的文档
5. 错误处理:正确处理和传递错误
7.2 命名约定 #
go
// 中间件函数以Middleware结尾
func AuthMiddleware() fiber.Handler { ... }
func LoggerMiddleware() fiber.Handler { ... }
// 配置结构体以Config结尾
type AuthConfig struct { ... }
type LoggerConfig struct { ... }
// 默认配置以Default结尾
var AuthConfigDefault = AuthConfig{ ... }
7.3 错误处理 #
go
func MyMiddleware(config MyConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// 使用fiber.Error返回错误
if someCondition {
return fiber.NewError(fiber.StatusBadRequest, "Invalid request")
}
// 或者返回自定义错误
if anotherCondition {
return &MyError{Code: 1001, Message: "Custom error"}
}
return c.Next()
}
}
八、完整示例 #
8.1 JWT认证中间件 #
go
package middleware
import (
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/golang-jwt/jwt/v4"
)
type JWTConfig struct {
Secret string
TokenHeader string
ContextKey string
SkipPaths []string
}
var JWTConfigDefault = JWTConfig{
TokenHeader: "Authorization",
ContextKey: "user",
}
func JWTProtected(config ...JWTConfig) fiber.Handler {
cfg := JWTConfigDefault
if len(config) > 0 {
cfg = config[0]
}
return func(c *fiber.Ctx) error {
// 检查是否跳过
for _, path := range cfg.SkipPaths {
if c.Path() == path {
return c.Next()
}
}
// 获取Token
authHeader := c.Get(cfg.TokenHeader)
if authHeader == "" {
return c.Status(401).JSON(fiber.Map{
"error": "Missing authorization header",
})
}
// 解析Bearer Token
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
return c.Status(401).JSON(fiber.Map{
"error": "Invalid authorization format",
})
}
tokenString := parts[1]
// 验证Token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(cfg.Secret), nil
})
if err != nil || !token.Valid {
return c.Status(401).JSON(fiber.Map{
"error": "Invalid token",
})
}
// 存储用户信息
claims := token.Claims.(jwt.MapClaims)
c.Locals(cfg.ContextKey, claims)
return c.Next()
}
}
// 获取当前用户
func GetCurrentUser(c *fiber.Ctx) jwt.MapClaims {
if claims, ok := c.Locals("user").(jwt.MapClaims); ok {
return claims
}
return nil
}
九、总结 #
9.1 核心要点 #
| 要点 | 说明 |
|---|---|
| 函数签名 | func(c *fiber.Ctx) error |
| 配置模式 | 返回闭包的工厂函数 |
| c.Next() | 调用下一个处理函数 |
| c.Locals() | 上下文数据传递 |
| 错误处理 | 返回error终止请求 |
9.2 下一步 #
现在你已经掌握了自定义中间件,接下来让我们学习 常用中间件,了解实际开发中常用的中间件实现!
最后更新:2026-03-28