自定义中间件 #

中间件实现方式 #

Actix Web 提供两种实现中间件的方式:

  1. 函数式中间件:基于 `actix_web::middleware::from_fn`(actix-web 4.9+),写法简洁,适合简单场景
  2. 结构体中间件:手动实现 `Transform` 与 `Service` trait,功能完整,适合复杂场景

函数式中间件 #

基本实现 #

rust
use actix_web::{
    body::MessageBody,
    dev::{ServiceRequest, ServiceResponse},
    middleware::{from_fn, Next},
    web, App, Error, HttpResponse, HttpServer,
};

/// Logs the method and path of every request, then the status of its response.
///
/// This is a function middleware: register it with `App::wrap(from_fn(middleware))`
/// (available since actix-web 4.9). `next.call(req)` forwards the request to the
/// rest of the chain (inner middleware and finally the handler).
async fn middleware(
    req: ServiceRequest,
    next: Next<impl MessageBody>,
) -> Result<ServiceResponse<impl MessageBody>, Error> {
    // Code before `next.call` runs before the handler ...
    println!("Request: {} {}", req.method(), req.path());

    let res = next.call(req).await?;

    // ... and code after it runs once the response has been produced.
    println!("Response: {}", res.status());

    Ok(res)
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    HttpServer::new(|| {
        App::new()
            // `from_fn` adapts the async fn into a middleware `Transform`;
            // `wrap_fn` expects a different closure signature and cannot take it.
            .wrap(from_fn(middleware))
            // Handlers must return a future, hence the `async` block.
            .route("/", web::get().to(|| async { HttpResponse::Ok().body("Hello") }))
    })
    .bind("127.0.0.1:8080")?
    .run()
    .await
}

计时中间件 #

rust
use actix_web::{
    body::MessageBody,
    dev::{ServiceRequest, ServiceResponse},
    middleware::Next,
    Error,
};
use std::time::Instant;

/// Prints how long the rest of the middleware chain took to produce a response.
///
/// Register with `App::new().wrap(actix_web::middleware::from_fn(timing_middleware))`.
/// The measurement ends when the response object is returned; any body
/// streaming to the client that happens afterwards is not included.
async fn timing_middleware(
    req: ServiceRequest,
    next: Next<impl MessageBody>,
) -> Result<ServiceResponse<impl MessageBody>, Error> {
    let start = Instant::now();

    // Await without `?` so the duration is logged even when the inner service
    // fails; the result is passed through unchanged either way.
    let res = next.call(req).await;

    println!("Request took: {:?}", start.elapsed());

    res
}

添加响应头 #

rust
use actix_web::http::header;

async fn add_headers(
    req: ServiceRequest,
    next: actix_web::dev::Next<impl actix_web::body::MessageBody>,
) -> Result<ServiceResponse<impl actix_web::body::MessageBody>, actix_web::Error> {
    let mut res = next.call(req).await?;
    
    res.headers_mut().insert(
        header::HeaderName::try_from("X-Response-Time").unwrap(),
        header::HeaderValue::from_static("100ms"),
    );
    
    Ok(res)
}

结构体中间件 #

基本结构 #

rust
use actix_web::{
    body::MessageBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use futures::future::{ok, LocalBoxFuture, Ready};

/// Middleware factory, registered with `App::wrap(MyMiddleware)`.
///
/// When actix-web builds the service chain it calls `new_transform`, handing
/// over the next service so it can be wrapped in a [`MyMiddlewareService`].
pub struct MyMiddleware;

impl<S, B> Transform<S, ServiceRequest> for MyMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = MyMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(MyMiddlewareService { service })
    }
}

/// The middleware instance that actually handles requests.
pub struct MyMiddlewareService<S> {
    /// The next service in the chain (inner middleware or the handler).
    service: S,
}

impl<S, B> Service<ServiceRequest> for MyMiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    // Readiness is delegated unchanged to the wrapped service.
    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Request pre-processing goes here, before the call is forwarded.
        let fut = self.service.call(req);

        Box::pin(async move {
            let res = fut.await?;
            // Response post-processing goes here.
            Ok(res)
        })
    }
}

认证中间件 #

rust
use actix_web::{
    body::{EitherBody, MessageBody},
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header,
    Error, HttpResponse,
};
use futures::future::{ok, LocalBoxFuture, Ready};

/// Middleware factory that requires an `Authorization: Bearer <api_key>` header.
///
/// Requests carrying the expected key are forwarded unchanged. All others get a
/// `401 Unauthorized` JSON response (`{"error": "Unauthorized"}`) and never
/// reach the handler.
pub struct AuthMiddleware {
    api_key: String,
}

impl AuthMiddleware {
    /// Creates the middleware with the API key clients must present.
    ///
    /// Accepts `&str`, `String`, or anything else convertible into `String`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
        }
    }
}

impl<S, B> Transform<S, ServiceRequest> for AuthMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    // The body is either the inner service's body `B` (left variant) or the
    // body of the 401 response built by this middleware (right variant).
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type InitError = ();
    type Transform = AuthService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(AuthService {
            service,
            api_key: self.api_key.clone(),
        })
    }
}

/// Request-handling half of [`AuthMiddleware`].
pub struct AuthService<S> {
    /// The next service in the chain.
    service: S,
    /// The bearer token that requests must present.
    api_key: String,
}

impl<S, B> Service<ServiceRequest> for AuthService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Decide up front, so the borrow of `req`'s headers ends before `req`
        // is moved. A missing header, non-ASCII value, wrong scheme or wrong key
        // all count as unauthorized.
        // NOTE(review): `==` on strings is not constant-time; for real secrets
        // consider a constant-time comparison to avoid timing side channels.
        let authorized = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.strip_prefix("Bearer "))
            == Some(self.api_key.as_str());

        if authorized {
            let fut = self.service.call(req);
            Box::pin(async move {
                // Forwarded responses use the "left" body variant.
                fut.await.map(ServiceResponse::map_into_left_body)
            })
        } else {
            let (req, _payload) = req.into_parts();
            let response = HttpResponse::Unauthorized()
                .json(serde_json::json!({
                    "error": "Unauthorized"
                }))
                // Responses built here use the "right" body variant.
                .map_into_right_body();

            Box::pin(async move { Ok(ServiceResponse::new(req, response)) })
        }
    }
}

带状态的中间件 #

rust
use actix_web::{web, App, HttpServer, HttpResponse};
use std::sync::Arc;

struct AppState {
    request_count: std::sync::atomic::AtomicU64,
}

pub struct RequestCounter;

impl<S, B> Transform<S, ServiceRequest> for RequestCounter
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type InitError = ();
    type Transform = RequestCounterService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(RequestCounterService { service })
    }
}

pub struct RequestCounterService<S> {
    service: S,
}

impl<S, B> Service<ServiceRequest> for RequestCounterService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = actix_web::Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(
        &self,
        ctx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service.poll_ready(ctx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let state = req.app_data::<web::Data<Arc<AppState>>>().unwrap();
        let count = state.request_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        
        println!("Request #{}", count + 1);
        
        let fut = self.service.call(req);
        
        Box::pin(async move {
            let res = fut.await?;
            Ok(res)
        })
    }
}

中间件配置 #

条件性应用 #

rust
use actix_web::middleware::Condition;

// `Condition::new(enable, middleware)` wraps the app with `middleware` only when
// `enable` is true. The flag is a plain `bool`, evaluated once when the `App` is
// built, not per request. In real code derive it from configuration rather than
// the literal `true` used here for illustration.
App::new()
    .wrap(Condition::new(true, MyMiddleware))

范围级别中间件 #

rust
// Middleware registered with `.wrap` on a scope applies only to that scope's
// services: here only requests under `/api` (e.g. `/api/users`) pass through
// `AuthMiddleware`. Routes registered directly on the `App` are unaffected.
// `get_users` is assumed to be a handler defined elsewhere.
App::new()
    .service(
        web::scope("/api")
            .wrap(AuthMiddleware::new("secret"))
            .route("/users", web::get().to(get_users))
    )

完整示例 #

rust
use actix_web::{
    body::MessageBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header::{HeaderName, HeaderValue},
    web, App, Error, HttpResponse, HttpServer, Responder,
};
use futures::future::{ok, LocalBoxFuture, Ready};
use std::time::Instant;

/// Middleware factory that adds an `X-Response-Time` header to every response.
pub struct TimingMiddleware;

impl<S, B> Transform<S, ServiceRequest> for TimingMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = TimingService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(TimingService { service })
    }
}

/// Request-handling half of [`TimingMiddleware`].
pub struct TimingService<S> {
    /// The next service in the chain.
    service: S,
}

impl<S, B> Service<ServiceRequest> for TimingService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let start = Instant::now();
        let fut = self.service.call(req);

        Box::pin(async move {
            let mut res = fut.await?;

            // Report a fixed unit (milliseconds). `Duration`'s `Debug` output
            // switches units and can emit the non-ASCII `µs`, which is awkward
            // to parse and not plain ASCII in an HTTP header.
            let millis = start.elapsed().as_secs_f64() * 1000.0;
            if let Ok(value) = HeaderValue::from_str(&format!("{:.3}ms", millis)) {
                // `from_static` requires a lowercase name; header names are
                // case-insensitive on the wire.
                res.headers_mut()
                    .insert(HeaderName::from_static("x-response-time"), value);
            }

            Ok(res)
        })
    }
}

/// Returns a fixed JSON greeting.
async fn index() -> impl Responder {
    HttpResponse::Ok().json(serde_json::json!({
        "message": "Hello, World!"
    }))
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    println!("Server running at http://127.0.0.1:8080");

    HttpServer::new(|| {
        App::new()
            .wrap(TimingMiddleware)
            .route("/", web::get().to(index))
    })
    .bind("127.0.0.1:8080")?
    .run()
    .await
}

下一步 #

现在你已经掌握了自定义中间件,接下来请继续学习「错误处理」一章,深入了解 Actix Web 的错误处理机制!

最后更新:2026-03-29