tRPC Middleware

Middleware concepts

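Middleware is tRPC's core mechanism for processing requests and responses. It lets you insert custom logic before and after a procedure runs, for concerns such as authentication, logging, and error handling. Each middleware calls `next()` to hand control to the next middleware (or, at the end of the chain, to the procedure). Code after `await next()` runs once everything downstream has finished, so middlewares unwind in reverse order.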

text
┌─────────────────────────────────────────────────────────────┐
│                  Middleware execution flow                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│       Request                                               │
│          │                                                  │
│          ▼                                                  │
│   ┌─────────────┐                                           │
│   │ Middleware 1│                                           │
│   │  (logging)  │                                           │
│   └──────┬──────┘                                           │
│          │                                                  │
│          ▼                                                  │
│   ┌─────────────┐                                           │
│   │ Middleware 2│                                           │
│   │   (auth)    │                                           │
│   └──────┬──────┘                                           │
│          │                                                  │
│          ▼                                                  │
│   ┌─────────────┐                                           │
│   │ Middleware 3│                                           │
│   │ (validation)│                                           │
│   └──────┬──────┘                                           │
│          │                                                  │
│          ▼                                                  │
│   ┌─────────────┐                                           │
│   │  Procedure  │                                           │
│   │  (handler)  │                                           │
│   └──────┬──────┘                                           │
│          │                                                  │
│          ▼                                                  │
│      Response                                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘
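The diagram shows only the way down. Because each middleware awaits `next()`, control comes back up through the same middlewares in reverse order once the procedure returns, and that is what lets a single middleware time or log the whole call. Below is a minimal, dependency-free model of this onion pattern. It is not tRPC's implementation, and `runChain`, `tracing`, and `demo` are illustrative names.

```typescript
// Simplified model of a middleware chain (not tRPC's implementation)
type Next = () => Promise<string>;
type Middleware = (next: Next) => Promise<string>;

// Run middlewares in order; the procedure runs once the last one calls next()
function runChain(middlewares: Middleware[], procedure: Next): Promise<string> {
  const dispatch = (i: number): Promise<string> => {
    const middleware = middlewares[i];
    return middleware ? middleware(() => dispatch(i + 1)) : procedure();
  };
  return dispatch(0);
}

const log: string[] = [];

// Build a middleware that records when its before/after code runs
function tracing(name: string): Middleware {
  return async (next) => {
    log.push(`${name}:before`);
    const result = await next();
    log.push(`${name}:after`);
    return result;
  };
}

function demo(): Promise<string> {
  return runChain([tracing('logging'), tracing('auth'), tracing('validation')], async () => {
    log.push('procedure');
    return 'Hello World!';
  });
}
```

Awaiting `demo()` yields `'Hello World!'`. After that, `log` reads `logging:before`, `auth:before`, `validation:before`, `procedure`, `validation:after`, `auth:after`, `logging:after`.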

Creating middleware

Basic syntax

typescript
import { initTRPC } from '@trpc/server';

const t = initTRPC.create();

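const myMiddleware = t.middleware(async ({ ctx, next }) => {
  // Runs before the procedure (and before any middleware added after this one)
  console.log('Before procedure');

  const result = await next();

  // Runs after the procedure has finished, whether it succeeded or failed
  console.log('After procedure');

  // Return the result of next() (or a replacement) so it travels back upstream
  return result;
});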

Middleware parameters

typescript
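// Simplified view of the options object a middleware receives (tRPC v10 field names;
// in v11, rawInput is replaced by getRawInput: () => Promise<unknown>)
type MiddlewareOptions<TContext, TMeta> = {
  ctx: TContext;
  type: 'query' | 'mutation' | 'subscription';
  path: string;
  input: unknown;
  rawInput: unknown;
  meta: TMeta | undefined;
  // next() continues unchanged; next({ ctx }) merges extra fields into the context
  next: (opts?: { ctx?: object }) => Promise<MiddlewareResult<TContext>>;
};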
text
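┌─────────────────────────────────────────────────────────────┐
│                    Middleware parameters                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ctx                                                        │
│  - Request context                                          │
│  - Holds user info, the request object, etc.                │
│                                                             │
│  type                                                       │
│  - Procedure type: query / mutation / subscription          │
│                                                             │
│  path                                                       │
│  - Procedure path, e.g. 'user.getById'                      │
│                                                             │
│  input                                                      │
│  - Parsed input data                                        │
│  - undefined in middlewares added before .input()           │
│                                                             │
│  rawInput                                                   │
│  - Raw, unparsed input (tRPC v10)                           │
│  - v11 replaces it with getRawInput(): Promise<unknown>     │
│                                                             │
│  meta                                                       │
│  - Procedure metadata, set with .meta()                     │
│                                                             │
│  next()                                                     │
│  - Calls the next middleware, or the procedure itself       │
│  - Returns Promise<MiddlewareResult>; errors come back as   │
│    { ok: false, error } instead of being thrown             │
│                                                             │
└─────────────────────────────────────────────────────────────┘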

Applying middleware

Single procedure

typescript
const publicProcedure = t.procedure
  .use(loggingMiddleware)
  .use(timingMiddleware);

const router = t.router({
  hello: publicProcedure.query(() => {
    return 'Hello World!';
  }),
});

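Router level

tRPC routers have no `.use()` method of their own. To apply middleware to a whole router, build each of its procedures from a base procedure that already carries it. In this example, `create` uses `protectedProcedure` and the read-only procedures use `publicProcedure`.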

typescript
const userRouter = t.router({
  list: publicProcedure.query(() => { /* ... */ }),
  getById: publicProcedure.query(() => { /* ... */ }),
  create: protectedProcedure.mutation(() => { /* ... */ }),
});

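Global middleware

There is no separate global registration either. Instead, put shared middleware on the base procedure that every other procedure is built from.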

typescript
const t = initTRPC.create();

const publicProcedure = t.procedure
  .use(loggingMiddleware)
  .use(errorHandlerMiddleware)
  .use(timingMiddleware);
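A shared base procedure works because `.use()` never modifies the procedure it is called on. Instead, it returns a new builder with the middleware appended, so a base can be extended in several directions without the extensions leaking into each other. The model below imitates that behavior with a hypothetical `Builder` class that stores middleware names as strings. It is a sketch, not tRPC's builder.

```typescript
// Hypothetical model of an immutable procedure builder (names stand in for middlewares)
class Builder {
  readonly middlewares: readonly string[];

  constructor(middlewares: readonly string[] = []) {
    this.middlewares = middlewares;
  }

  // Return a new builder with the middleware appended; `this` is left unchanged
  use(name: string): Builder {
    return new Builder([...this.middlewares, name]);
  }
}

const base = new Builder();
const publicProcedure = base.use('logging').use('errorHandler').use('timing');
const protectedProcedure = publicProcedure.use('auth');
```

`protectedProcedure.middlewares` ends with `'auth'`. Meanwhile, `publicProcedure.middlewares` is still `['logging', 'errorHandler', 'timing']` and `base` still has none.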

Common middleware

Logging middleware

typescript
const loggerMiddleware = t.middleware(async ({ path, type, next }) => {
  const start = Date.now();
  const result = await next();
  const duration = Date.now() - start;

  console.log({
    type,
    path,
    duration: `${duration}ms`,
    ok: result.ok,
  });

  return result;
});

export const publicProcedure = t.procedure.use(loggerMiddleware);

Detailed logging

typescript
const detailedLoggerMiddleware = t.middleware(async ({ path, type, input, next }) => {
  // Short random ID for correlating the start and end log lines of one call
  const requestId = Math.random().toString(36).slice(2, 10);
  const startTime = Date.now();

  console.log(`[${requestId}] >>> ${type} ${path}`, {
    input,
    timestamp: new Date().toISOString(),
  });

  // next() does not throw: downstream errors come back as { ok: false, error },
  // so check result.ok instead of wrapping the call in try/catch
  const result = await next();
  const duration = `${Date.now() - startTime}ms`;

  if (result.ok) {
    console.log(`[${requestId}] <<< ${type} ${path}`, { ok: true, duration });
  } else {
    console.error(`[${requestId}] !!! ${type} ${path}`, { error: result.error, duration });
  }

  return result;
});

Authentication middleware

typescript
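import { initTRPC, TRPCError } from '@trpc/server';

// Illustrative user shape; use your application's own type
type User = { id: string; role: 'user' | 'admin' | 'superadmin' };

interface Context {
  user: User | null;
}

const t = initTRPC.context<Context>().create();

const isAuthed = t.middleware(async ({ ctx, next }) => {
  if (!ctx.user) {
    throw new TRPCError({
      code: 'UNAUTHORIZED',
      message: 'You must be logged in',
    });
  }

  // Passing ctx to next() narrows the type for later steps:
  // procedures built on protectedProcedure see ctx.user as User, not User | null
  return next({
    ctx: {
      user: ctx.user,
    },
  });
});

export const protectedProcedure = t.procedure.use(isAuthed);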

Role checks

typescript
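import { initTRPC, TRPCError } from '@trpc/server';
import { z } from 'zod';

// Context (whose user has a role) is the one defined above; db stands for your database client
interface Meta {
  requiredRole?: 'user' | 'admin' | 'superadmin';
}

const t = initTRPC.context<Context>().meta<Meta>().create();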

const roleMiddleware = t.middleware(async ({ ctx, meta, next }) => {
  if (!ctx.user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }

  const requiredRole = meta?.requiredRole;
  
  if (requiredRole) {
    const roleHierarchy = {
      user: 1,
      admin: 2,
      superadmin: 3,
    };

    const userLevel = roleHierarchy[ctx.user.role] || 0;
    const requiredLevel = roleHierarchy[requiredRole] || 0;

    if (userLevel < requiredLevel) {
      throw new TRPCError({
        code: 'FORBIDDEN',
        message: `Requires ${requiredRole} role`,
      });
    }
  }

  return next({
    ctx: {
      user: ctx.user,
    },
  });
});

export const protectedProcedure = t.procedure.use(roleMiddleware);

const router = t.router({
  admin: t.router({
    deleteUser: protectedProcedure
      .meta({ requiredRole: 'admin' })
      .input(z.object({ id: z.string() }))
      .mutation(async ({ input }) => {
        return db.user.delete({ where: { id: input.id } });
      }),
  }),
});
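The hierarchy comparison is easier to test on its own than inside a middleware. Here is a minimal sketch that assumes the same three roles as `Meta` above; `Role` and `hasRequiredRole` are hypothetical names. Typing the table as `Record<Role, number>` makes the compiler reject a missing role, so the `|| 0` fallbacks are no longer needed.

```typescript
type Role = 'user' | 'admin' | 'superadmin';

// Higher number = more privileges; Record<Role, number> forces every role to be listed
const roleHierarchy: Record<Role, number> = {
  user: 1,
  admin: 2,
  superadmin: 3,
};

// True when userRole is at least as privileged as requiredRole;
// no requiredRole means any authenticated user may proceed
function hasRequiredRole(userRole: Role, requiredRole?: Role): boolean {
  if (!requiredRole) {
    return true;
  }
  return roleHierarchy[userRole] >= roleHierarchy[requiredRole];
}
```

Inside `roleMiddleware`, the check would become `if (!hasRequiredRole(ctx.user.role, meta?.requiredRole)) throw new TRPCError({ code: 'FORBIDDEN' })`.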

Rate-limiting middleware

typescript
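import { TRPCError } from '@trpc/server';

// express-rate-limit returns Express middleware (req, res, next); creating one inside
// a tRPC middleware never runs it. Either mount it on the Express app in front of the
// tRPC adapter, or count requests yourself, as in this in-memory fixed-window sketch.
const hits = new Map<string, { count: number; resetAt: number }>();

// Returns false once `key` has exceeded `max` requests in the current window
function consume(key: string, windowMs: number, max: number): boolean {
  const now = Date.now();
  const entry = hits.get(key);
  if (!entry || now >= entry.resetAt) {
    hits.set(key, { count: 1, resetAt: now + windowMs });
    return true;
  }
  entry.count += 1;
  return entry.count <= max;
}

// ctx.ip is an assumed context field (e.g. set in createContext from the request)
const rateLimitMiddleware = t.middleware(async ({ ctx, next }) => {
  if (!consume(`default:${ctx.user?.id ?? ctx.ip}`, 15 * 60 * 1000, 100)) {
    throw new TRPCError({ code: 'TOO_MANY_REQUESTS' });
  }
  return next();
});

const strictRateLimitMiddleware = t.middleware(async ({ ctx, next }) => {
  if (!consume(`strict:${ctx.user?.id ?? ctx.ip}`, 60 * 60 * 1000, 5)) {
    throw new TRPCError({ code: 'TOO_MANY_REQUESTS' });
  }
  return next();
});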
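Whichever store backs it, a rate limiter boils down to a per-key counter. Below is a dependency-free fixed-window counter with an injectable clock, so the window reset can be tested without waiting. `FixedWindowLimiter` and `hit` are hypothetical names. Counts live in process memory, so they are not shared between server instances and are lost on restart.

```typescript
// Fixed-window rate limiter: at most `max` hits per key within each window
class FixedWindowLimiter {
  private readonly windows = new Map<string, { count: number; resetAt: number }>();
  private readonly windowMs: number;
  private readonly max: number;
  private readonly now: () => number;

  constructor(windowMs: number, max: number, now: () => number = Date.now) {
    this.windowMs = windowMs;
    this.max = max;
    this.now = now;
  }

  // Record one hit for `key`; returns false once the key exceeds `max` in the current window
  hit(key: string): boolean {
    const time = this.now();
    const entry = this.windows.get(key);
    if (!entry || time >= entry.resetAt) {
      // First hit, or the previous window has expired: start a new window
      this.windows.set(key, { count: 1, resetAt: time + this.windowMs });
      return true;
    }
    entry.count += 1;
    return entry.count <= this.max;
  }
}
```

In a tRPC middleware, you would call `limiter.hit(key)` and throw `new TRPCError({ code: 'TOO_MANY_REQUESTS' })` when it returns `false`.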

Caching middleware

typescript
import { LRUCache } from 'lru-cache';

const cache = new LRUCache<string, any>({
  max: 1000,
  ttl: 1000 * 60 * 5,
});

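const cacheMiddleware = t.middleware(async ({ path, input, type, next }) => {
  // Only cache reads; mutations must always run
  if (type !== 'query') {
    return next();
  }

  // Caveats: `input` is only populated in middlewares added after .input(),
  // and this key is shared by all users, so don't cache per-user data this way
  const cacheKey = `${path}:${JSON.stringify(input)}`;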
  const cached = cache.get(cacheKey);

  if (cached) {
    console.log(`Cache hit: ${cacheKey}`);
    return cached;
  }

  console.log(`Cache miss: ${cacheKey}`);
  const result = await next();
  
  if (result.ok) {
    cache.set(cacheKey, result);
  }

  return result;
});
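`JSON.stringify` preserves property insertion order, so `{ a: 1, b: 2 }` and `{ b: 2, a: 1 }` produce different cache keys even though they describe the same input, and the second request misses the cache. The sketch below builds an order-independent key. It assumes inputs are plain JSON-compatible values (no `Date`, `Map`, or class instances); `stableStringify` and `cacheKey` are hypothetical helper names. For per-user data, add the user ID to the key as well.

```typescript
// Serialize a JSON-compatible value with object keys sorted at every level,
// so inputs that differ only in key order produce the same string
function stableStringify(value: unknown): string {
  if (value === undefined) {
    return 'undefined';
  }
  if (Array.isArray(value)) {
    // Array order is meaningful, so elements keep their positions
    return `[${value.map(stableStringify).join(',')}]`;
  }
  if (value !== null && typeof value === 'object') {
    const record = value as Record<string, unknown>;
    const entries = Object.keys(record)
      .sort()
      .map((key) => `${JSON.stringify(key)}:${stableStringify(record[key])}`);
    return `{${entries.join(',')}}`;
  }
  // Strings, numbers, booleans, and null
  return JSON.stringify(value);
}

function cacheKey(path: string, input: unknown): string {
  return `${path}:${stableStringify(input)}`;
}
```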

Performance monitoring

typescript
const performanceMiddleware = t.middleware(async ({ path, type, next }) => {
  const start = performance.now();
  // Rough signal only: heap usage also shifts with GC and concurrent requests
  const startMemory = process.memoryUsage();

  const result = await next();

  const endMemory = process.memoryUsage();
  const duration = performance.now() - start;

  const metrics = {
    path,
    type,
    duration: `${duration.toFixed(2)}ms`,
    memoryDelta: {
      heapUsed: `${((endMemory.heapUsed - startMemory.heapUsed) / 1024).toFixed(2)}KB`,
      external: `${((endMemory.external - startMemory.external) / 1024).toFixed(2)}KB`,
    },
  };

  if (duration > 1000) {
    console.warn('Slow query detected:', metrics);
  }

  return result;
});

Input sanitization middleware

typescript
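// Trims every string in the input. Attach it after .input(): middlewares that
// run before .input() receive `input` as undefined.
const validationMiddleware = t.middleware(async ({ input, next }) => {
  if (input && typeof input === 'object') {
    const sanitized = sanitizeInput(input);
    // Downstream steps, including the procedure, receive the sanitized value
    return next({ input: sanitized });
  }
  return next();
});

// Recursively trim strings in objects and arrays; other values are returned as-is
function sanitizeInput(value: unknown): unknown {
  if (typeof value === 'string') {
    return value.trim();
  }

  if (Array.isArray(value)) {
    return value.map(sanitizeInput);
  }

  if (value === null || typeof value !== 'object') {
    return value;
  }

  // Object.entries visits only the object's own enumerable keys
  const result: Record<string, unknown> = {};
  for (const [key, entry] of Object.entries(value)) {
    result[key] = sanitizeInput(entry);
  }
  return result;
}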

Composing middleware

Chaining

typescript
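// Caution: cache runs before auth here, so in protectedProcedure a cache hit returns
// without auth or rateLimit running; move cache after auth for private data
const publicProcedure = t.procedure
  .use(loggingMiddleware)
  .use(performanceMiddleware)
  .use(cacheMiddleware);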

const protectedProcedure = publicProcedure
  .use(authMiddleware)
  .use(rateLimitMiddleware);

const adminProcedure = protectedProcedure
  .use(adminMiddleware);
text
┌─────────────────────────────────────────────────────────────┐
│              Middleware chain execution order               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  publicProcedure:                                           │
│  logging → performance → cache → handler                    │
│                                                             │
│  protectedProcedure:                                        │
│  logging → performance → cache → auth → rateLimit → handler │
│                                                             │
│  adminProcedure:                                            │
│  logging → performance → cache → auth → rateLimit → admin   │
│    → handler                                                │
│                                                             │
│  Execution order (adminProcedure):                          │
│  1. logging (before)                                        │
│  2. performance (before)                                    │
│  3. cache (before)                                          │
│  4. auth (before)                                           │
│  5. rateLimit (before)                                      │
│  6. admin (before)                                          │
│  7. handler                                                 │
│  8. admin (after)                                           │
│  9. rateLimit (after)                                       │
│  10. auth (after)                                           │
│  11. cache (after)                                          │
│  12. performance (after)                                    │
│  13. logging (after)                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘
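The "after" steps only run because every middleware calls `next()`. A middleware that returns early, such as `cache` on a hit, skips everything after it, including `auth` and `rateLimit` in `protectedProcedure`. The dependency-free model below is not tRPC code, and `runChain`, `cacheHit`, and `auth` are illustrative names. It shows an anonymous caller getting a cached answer without the auth check ever running.

```typescript
type Next = () => Promise<string>;
type Middleware = (next: Next) => Promise<string>;

// Run middlewares in order; the procedure runs only if every middleware calls next()
function runChain(middlewares: Middleware[], procedure: Next): Promise<string> {
  const dispatch = (i: number): Promise<string> => {
    const middleware = middlewares[i];
    return middleware ? middleware(() => dispatch(i + 1)) : procedure();
  };
  return dispatch(0);
}

const ran: string[] = [];

// Hypothetical cache that always hits: it answers without calling next()
const cacheHit: Middleware = async () => {
  ran.push('cache');
  return 'cached profile';
};

// Hypothetical auth check that rejects the (anonymous) caller
const auth: Middleware = async () => {
  ran.push('auth');
  throw new Error('UNAUTHORIZED');
};
```

`await runChain([cacheHit, auth], async () => 'fresh profile')` resolves to `'cached profile'`, and `ran` equals `['cache']`. With the order reversed, `[auth, cacheHit]`, the call rejects and only `auth` runs.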

Conditional middleware

typescript
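import { initTRPC, TRPCError } from '@trpc/server';

interface Meta {
  skipAuth?: boolean;
}

const t = initTRPC.context<Context>().meta<Meta>().create();

// t.middleware() returns a builder object, not a callable function,
// so the auth check is written inline instead of calling authMiddleware
const conditionalMiddleware = t.middleware(async ({ ctx, meta, next }) => {
  // Procedures opt out with .meta({ skipAuth: true })
  if (!meta?.skipAuth && !ctx.user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }
  return next();
});

// The procedure must actually use the middleware for the meta flag to matter
const baseProcedure = t.procedure.use(conditionalMiddleware);

const router = t.router({
  public: t.router({
    list: baseProcedure
      .meta({ skipAuth: true })
      .query(() => { /* ... */ }),
  }),
});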

Dynamic middleware

typescript
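// Hypothetical hooks: plain async functions, not tRPC middlewares
// (t.middleware() returns a builder object, which cannot be called directly)
type PathHook = (opts: { ctx: Context; path: string }, next: () => Promise<void>) => Promise<void>;
declare function getMiddlewaresForPath(path: string): PathHook[]; // hypothetical lookup

const dynamicMiddleware = t.middleware(async ({ path, ctx, next }) => {
  const callNext = () => next();
  let downstream: ReturnType<typeof callNext> | undefined;

  // Compose right to left so hooks[0] runs first; the real next() runs once, innermost
  const chain = getMiddlewaresForPath(path).reduceRight<() => Promise<void>>(
    (inner, hook) => () => hook({ ctx, path }, inner),
    async () => { downstream = callNext(); await downstream; },
  );

  await chain();
  if (!downstream) throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: 'A hook did not call next()' });
  return downstream;
});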

Extending the context

Adding to the context

typescript
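// getCurrentUser and createDbConnection stand for app-specific helpers (not shown)
const userMiddleware = t.middleware(async ({ ctx, next }) => {
  const user = await getCurrentUser(ctx.req);

  // next({ ctx }) is merged into the existing context, so `...ctx` is optional here
  return next({
    ctx: {
      ...ctx,
      user,
    },
  });
});

// Runs on every call: prefer handing out a pooled client over opening a new connection each time
const dbMiddleware = t.middleware(async ({ ctx, next }) => {
  const db = createDbConnection();

  return next({
    ctx: {
      ...ctx,
      db,
    },
  });
});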
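When a middleware passes `{ ctx }` to `next()`, tRPC spreads it over the current context. Existing fields therefore survive, which makes `...ctx` redundant though harmless. A nested object you pass, however, replaces the old one instead of being deep-merged. The sketch below models that shallow merge in plain TypeScript; `mergeCtx` and the sample fields are illustrative and are not tRPC exports.

```typescript
type Ctx = Record<string, unknown>;

// Model of how a middleware's ctx override combines with the incoming context:
// top-level keys from `override` win, nested objects are replaced rather than merged
function mergeCtx(base: Ctx, override: Ctx): Ctx {
  return { ...base, ...override };
}

const baseCtx: Ctx = { req: { ip: '127.0.0.1' }, settings: { locale: 'en', theme: 'dark' } };

// Adding a new key keeps everything that was already there
const withUser = mergeCtx(baseCtx, { user: { id: 'u1' } });

// Overriding a nested object replaces it: `theme` is gone from `settings`
const withLocale = mergeCtx(withUser, { settings: { locale: 'zh' } });
```

To change one nested field and keep the rest, spread the nested object yourself, e.g. `next({ ctx: { settings: { ...ctx.settings, locale: 'zh' } } })`.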

Context type inference

typescript
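import { initTRPC, TRPCError } from '@trpc/server';

interface BaseContext {
  req: Request;
}

// Hypothetical helper: resolves your app's User from the request, or null
declare function authenticate(req: Request): Promise<User | null>;

const t = initTRPC.context<BaseContext>().create();

const isAuthed = t.middleware(async ({ ctx, next }) => {
  const user = await authenticate(ctx.req);

  if (!user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }

  // No cast needed: tRPC infers the new context from this object,
  // and `user` has already been narrowed from User | null to User
  return next({
    ctx: {
      user,
    },
  });
});

const protectedProcedure = t.procedure.use(isAuthed);

protectedProcedure.query(({ ctx }) => {
  ctx.user; // typed as User
  ctx.req; // still available: next() merges the new ctx into the existing one
});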

Error-handling middleware

Global error handling

typescript
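import { TRPCError } from '@trpc/server';

// ValidationError and DatabaseError stand for your application's own error classes
const errorHandlerMiddleware = t.middleware(async ({ next }) => {
  // next() never throws: errors from later middlewares and the procedure come back
  // as { ok: false, error }. error is always a TRPCError; anything else that was
  // thrown is wrapped as INTERNAL_SERVER_ERROR with the original in error.cause.
  const result = await next();
  if (result.ok) {
    return result;
  }

  const cause = result.error.cause;

  if (cause instanceof ValidationError) {
    throw new TRPCError({
      code: 'BAD_REQUEST',
      message: cause.message,
      cause,
    });
  }

  if (cause instanceof DatabaseError) {
    console.error('Database error:', cause);
    throw new TRPCError({
      code: 'INTERNAL_SERVER_ERROR',
      message: 'Database operation failed',
    });
  }

  // Hide internal details of any remaining INTERNAL_SERVER_ERROR; other codes pass through
  if (result.error.code === 'INTERNAL_SERVER_ERROR') {
    console.error('Unexpected error:', cause ?? result.error);
    throw new TRPCError({
      code: 'INTERNAL_SERVER_ERROR',
      message: 'An unexpected error occurred',
    });
  }

  return result;
});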
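The model below shows why checking `result.ok` is necessary. Like tRPC's `next()`, its `callNext` catches whatever the downstream step throws and returns `{ ok: false, error }`. Values that are not `AppError`s are wrapped as `INTERNAL_SERVER_ERROR`, and the original is kept. This is a simplified, dependency-free sketch, not tRPC's source: `AppError` stands in for `TRPCError`, and all names are illustrative.

```typescript
// Simplified stand-in for TRPCError: an error code plus the original thrown value
class AppError extends Error {
  readonly code: string;
  readonly original: unknown;

  constructor(code: string, message: string, original?: unknown) {
    super(message);
    this.code = code;
    this.original = original;
  }
}

type Result<T> = { ok: true; data: T } | { ok: false; error: AppError };

// Model of next(): run the downstream step and turn anything it throws into a result
async function callNext<T>(downstream: () => Promise<T>): Promise<Result<T>> {
  try {
    return { ok: true, data: await downstream() };
  } catch (thrown) {
    const error =
      thrown instanceof AppError
        ? thrown
        : new AppError('INTERNAL_SERVER_ERROR', thrown instanceof Error ? thrown.message : String(thrown), thrown);
    return { ok: false, error };
  }
}

async function failingProcedure(): Promise<string> {
  throw new Error('connection refused');
}

// The catch block here never runs: the failure arrives as a value
async function middleware(): Promise<string> {
  try {
    const result = await callNext(failingProcedure);
    return result.ok ? 'ok' : `failed with ${result.error.code}`;
  } catch {
    return 'caught';
  }
}
```

`await middleware()` resolves to `'failed with INTERNAL_SERVER_ERROR'`, never `'caught'`.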

Error transformation

typescript
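const errorTransformMiddleware = t.middleware(async ({ next }) => {
  const result = await next();

  // Throw the replacement; tRPC turns it back into a failed result for upstream steps
  if (!result.ok) {
    throw transformError(result.error);
  }

  return result;
});

// Map error codes to user-facing messages (adjust or localize as needed)
function transformError(error: TRPCError): TRPCError {
  const errorMessages: Record<string, string> = {
    UNAUTHORIZED: 'Please log in first',
    FORBIDDEN: 'You do not have permission to perform this action',
    NOT_FOUND: 'The requested resource does not exist',
    BAD_REQUEST: 'Invalid request parameters',
    INTERNAL_SERVER_ERROR: 'Internal server error',
  };

  return new TRPCError({
    code: error.code,
    message: errorMessages[error.code] || error.message,
    cause: error.cause,
  });
}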

Middleware best practices

1. Single responsibility

typescript
const loggingMiddleware = t.middleware(async ({ next }) => {
  console.log('Request received');
  return next();
});

const authMiddleware = t.middleware(async ({ ctx, next }) => {
  if (!ctx.user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }
  return next();
});

2. Configurability

typescript
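import { TRPCError } from '@trpc/server';

function createRateLimitMiddleware(options: {
  windowMs: number;
  max: number;
}) {
  // Each call gets its own counters (in-memory, single process); ctx.ip is an assumed context field
  const hits = new Map<string, { count: number; resetAt: number }>();

  return t.middleware(async ({ ctx, next }) => {
    const key = ctx.user?.id ?? ctx.ip;
    const now = Date.now();
    const entry = hits.get(key);

    if (!entry || now >= entry.resetAt) {
      hits.set(key, { count: 1, resetAt: now + options.windowMs });
    } else if (++entry.count > options.max) {
      throw new TRPCError({ code: 'TOO_MANY_REQUESTS' });
    }

    return next();
  });
}

const strictRateLimit = createRateLimitMiddleware({
  windowMs: 60 * 1000,
  max: 10,
});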

3. Type safety

typescript
interface AuthContext {
  user: User;
}

const isAuthed = t.middleware(async ({ ctx, next }) => {
  if (!ctx.user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }

  return next({
    ctx: {
      user: ctx.user,
    } as AuthContext,
  });
});

4. Error handling

typescript
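const safeMiddleware = t.middleware(async ({ next }) => {
  // Downstream errors arrive as { ok: false, error } rather than as exceptions,
  // so check the result instead of relying on try/catch
  const result = await next();
  if (!result.ok) {
    console.error('Middleware error:', result.error);
  }
  return result;
});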

5. Documentation

typescript
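/**
 * Rejects callers that are not logged in.
 *
 * Requires: ctx.user, which is null for anonymous callers.
 * Throws: TRPCError UNAUTHORIZED when ctx.user is missing.
 */
const authMiddleware = t.middleware(async ({ ctx, next }) => {
  if (!ctx.user) {
    throw new TRPCError({
      code: 'UNAUTHORIZED',
      message: 'Authentication required',
    });
  }
  return next();
});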

Complete example

typescript
import { initTRPC, TRPCError } from '@trpc/server';
import { z } from 'zod';
import { LRUCache } from 'lru-cache';

// Illustrative user shape; `db` below stands for your database client (not shown)
type User = { id: string; role: 'user' | 'admin' };

interface Context {
  user: User | null;
  req: Request;
}

interface Meta {
  requiresAuth?: boolean;
  requiredRole?: 'user' | 'admin';
  cache?: boolean;
  rateLimit?: { max: number; windowMs: number };
}

const t = initTRPC.context<Context>().meta<Meta>().create();

const loggerMiddleware = t.middleware(async ({ path, type, next }) => {
  const start = Date.now();
  const result = await next();
  console.log(`[${type}] ${path} - ${Date.now() - start}ms`);
  return result;
});

const authMiddleware = t.middleware(async ({ ctx, meta, next }) => {
  if (meta?.requiresAuth !== false && !ctx.user) {
    throw new TRPCError({ code: 'UNAUTHORIZED' });
  }

  if (meta?.requiredRole && ctx.user?.role !== meta.requiredRole) {
    throw new TRPCError({ code: 'FORBIDDEN' });
  }

  // ctx.user! assumes requiresAuth was not set to false; otherwise user may be null here
  return next({
    ctx: {
      user: ctx.user!,
    },
  });
});

const cacheStore = new LRUCache<string, any>({ max: 1000, ttl: 60000 });

const cacheMiddleware = t.middleware(async ({ path, input, type, meta, next }) => {
  if (type !== 'query' || !meta?.cache) {
    return next();
  }

  const key = `${path}:${JSON.stringify(input)}`;
  const cached = cacheStore.get(key);
  
  if (cached) {
    return cached;
  }

  const result = await next();
  if (result.ok) {
    cacheStore.set(key, result);
  }
  
  return result;
});

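// Cache runs before auth in every procedure derived from this base,
// so don't enable meta.cache on protected or per-user queries
const publicProcedure = t.procedure
  .use(loggerMiddleware)
  .use(cacheMiddleware);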

const protectedProcedure = publicProcedure
  .use(authMiddleware);

const adminProcedure = protectedProcedure
  .meta({ requiredRole: 'admin' });

export const appRouter = t.router({
  user: t.router({
    list: publicProcedure
      .meta({ cache: true })
      .query(async () => {
        return db.user.findMany();
      }),

    getProfile: protectedProcedure
      .query(({ ctx }) => {
        return ctx.user;
      }),

    delete: adminProcedure
      .input(z.object({ id: z.string() }))
      .mutation(async ({ input }) => {
        return db.user.delete({ where: { id: input.id } });
      }),
  }),
});

export type AppRouter = typeof appRouter;

Next steps

Now that you know how to use tRPC middleware, continue with Error Handling to learn how to handle the various errors in a tRPC application gracefully.

Last updated: 2026-03-29