路由守卫 #

路由守卫是Rocket中实现请求前置检查和权限控制的重要机制。通过为类型实现 FromParam trait,可以在路由匹配阶段对路径参数进行验证和转换;需要检查整个请求(例如请求头中的认证信息)时,则实现 FromRequest trait。

FromParam Trait #

FromParam 是Rocket中用于路径参数验证和转换的核心trait。

基本定义 #

rust
/// Converts one dynamic path segment (e.g. `<id>` in `/user/<id>`) into a
/// typed value, or rejects it with `Self::Error`.
///
/// NOTE(review): upstream Rocket 0.5 declares the associated type as
/// `type Error: std::fmt::Debug;` — confirm against the docs for the Rocket
/// version in use before relying on this simplified form.
pub trait FromParam<'a>: Sized {
    /// Error produced when the segment cannot be converted.
    type Error;

    /// Parses `param`, the text of a single path segment, into `Self`.
    fn from_param(param: &'a str) -> Result<Self, Self::Error>;
}

内置实现 #

Rocket为以下类型提供了内置的 FromParam 实现:

类型 说明
&str 字符串切片
String 拥有所有权的字符串
i8, i16, i32, i64, i128 有符号整数
u8, u16, u32, u64, u128 无符号整数
f32, f64 浮点数
bool 布尔值
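PathBuf 路径(注意:PathBuf 实现的是 FromSegments,用于 `<path..>` 这类多段参数,而不是单段的 FromParam)
Option<T> 可选值(T 解析失败时得到 None,路由仍然匹配)
Result<T, T::Error> 结果类型(T 解析失败时得到 Err,并保留具体错误)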

自定义参数守卫 #

验证正整数 #

rust
use rocket::request::FromParam;
use rocket::http::Status;

/// Path parameter that must be a base-10 `u32` strictly greater than zero.
struct PositiveId(u32);

impl<'a> FromParam<'a> for PositiveId {
    type Error = &'a str;

    /// Accepts any `u32` except `0`. On failure (not a number, out of range,
    /// or zero) the original segment is returned unchanged as the error.
    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        param
            .parse::<u32>()
            .ok()
            .filter(|&value| value != 0)
            .map(PositiveId)
            .ok_or(param)
    }
}

/// Returns the user ID, or `400 Bad Request` when the segment is not a
/// positive integer.
#[get("/user/<id>")]
fn get_user(id: Result<PositiveId, &str>) -> Result<String, Status> {
    let PositiveId(user_id) = id.map_err(|_| Status::BadRequest)?;
    Ok(format!("User ID: {}", user_id))
}

验证邮箱格式 #

rust
use rocket::request::FromParam;
use regex::Regex;

struct Email(String);

impl<'a> FromParam<'a> for Email {
    type Error = &'a str;

    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        let email_regex = Regex::new(r"^[^@]+@[^@]+\.[^@]+$").unwrap();
        if email_regex.is_match(param) {
            Ok(Email(param.to_string()))
        } else {
            Err(param)
        }
    }
}

/// Echoes the address if it passed the `Email` guard, otherwise reports the
/// rejected segment.
#[get("/verify/<email>")]
fn verify_email(email: Result<Email, &str>) -> String {
    email.map_or_else(
        |rejected| format!("Invalid email: {}", rejected),
        |Email(address)| format!("Valid email: {}", address),
    )
}

枚举参数 #

rust
use rocket::request::FromParam;

/// Role named in a path segment.
#[derive(Debug)]
enum UserRole {
    Admin,
    Moderator,
    User,
}

impl<'a> FromParam<'a> for UserRole {
    type Error = &'a str;

    /// Maps `admin` / `moderator` / `user` to a role, ignoring letter case.
    /// Any other segment is returned unchanged as the error.
    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        let normalized = param.to_lowercase();
        let role = match normalized.as_str() {
            "user" => UserRole::User,
            "moderator" => UserRole::Moderator,
            "admin" => UserRole::Admin,
            _ => return Err(param),
        };
        Ok(role)
    }
}

/// Describes the parsed role, or reports the segment that was not a role.
#[get("/role/<role>")]
fn get_by_role(role: Result<UserRole, &str>) -> String {
    role.map(|parsed| format!("Role: {:?}", parsed))
        .unwrap_or_else(|unknown| format!("Invalid role: {}", unknown))
}

参数守卫与错误处理 #

使用Option处理 #

rust
/// `Option<u32>` never fails the route: a non-numeric segment arrives as `None`.
#[get("/user/<id>")]
fn get_user(id: Option<u32>) -> String {
    id.map(|user_id| format!("User ID: {}", user_id))
        .unwrap_or_else(|| "Invalid ID format".to_string())
}

使用Result处理 #

rust
use rocket::http::Status;

/// Turns a parse failure of the `id` segment into `400 Bad Request`.
#[get("/user/<id>")]
fn get_user(id: Result<u32, &str>) -> Result<String, Status> {
    match id {
        Ok(user_id) => Ok(format!("User ID: {}", user_id)),
        Err(_) => Err(Status::BadRequest),
    }
}

自定义错误类型 #

rust
use rocket::request::FromParam;
use rocket::http::Status;
use std::fmt;

/// Validation failure for a path parameter, carrying a human-readable message.
#[derive(Debug)]
struct ValidationError {
    message: String,
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

/// Username path segment of 3 to 20 characters (inclusive).
struct Username(String);

impl<'a> FromParam<'a> for Username {
    type Error = ValidationError;

    /// Accepts `param` when its length in characters is within `3..=20`.
    /// Otherwise returns a `ValidationError` describing the limit.
    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        // Count characters, not bytes: `str::len` is the UTF-8 byte length,
        // so a 2-character Chinese name (6 bytes) would pass and a
        // 7-character one (21 bytes) would be rejected.
        let char_count = param.chars().count();
        if (3..=20).contains(&char_count) {
            Ok(Username(param.to_string()))
        } else {
            Err(ValidationError {
                message: "Username must be 3-20 characters".to_string(),
            })
        }
    }
}

查询参数守卫 #

自定义查询守卫 #

rust
use rocket::request::FromQuery;
use rocket::request::Query;
use rocket::FromForm;

#[derive(FromForm)]
struct DateRange {
    start: String,
    end: String,
}

impl<'a> FromQuery<'a> for DateRange {
    type Error = String;

    fn from_query(query: Query<'a>) -> Result<Self, Self::Error> {
        let start = query.get("start").map(|v| v.to_string()).unwrap_or_default();
        let end = query.get("end").map(|v| v.to_string()).unwrap_or_default();
        
        if start.is_empty() || end.is_empty() {
            Err("Both start and end dates are required".to_string())
        } else {
            Ok(DateRange { start, end })
        }
    }
}

路由级别验证 #

结合FromRequest #

对于更复杂的验证(例如检查请求头中的认证令牌),可以实现 FromRequest trait,编写请求守卫:

rust
use rocket::request::{self, FromRequest, Request, Outcome};
use rocket::http::Status;

struct AuthenticatedUser {
    id: u32,
    username: String,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for AuthenticatedUser {
    type Error = AuthError;

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let auth_header = request.headers().get_one("Authorization");
        
        match auth_header {
            Some(token) if token.starts_with("Bearer ") => {
                let token = token.trim_start_matches("Bearer ");
                match validate_token(token) {
                    Ok(user) => Outcome::Success(user),
                    Err(_) => Outcome::Error((Status::Unauthorized, AuthError::Invalid)),
                }
            }
            _ => Outcome::Error((Status::Unauthorized, AuthError::Missing)),
        }
    }
}

/// Reasons the `AuthenticatedUser` guard can reject a request (always paired
/// with `401 Unauthorized`).
#[derive(Debug)]
enum AuthError {
    /// No `Authorization` header, or one that does not start with `Bearer `.
    Missing,
    /// A bearer token was present but `validate_token` returned an error.
    Invalid,
}

/// Placeholder token check used by the `AuthenticatedUser` guard.
///
/// WARNING: this demo ignores the token entirely and always returns the same
/// hard-coded user, so every bearer token is accepted. Replace it with real
/// verification before reusing this pattern.
fn validate_token(_token: &str) -> Result<AuthenticatedUser, ()> {
    let demo_user = AuthenticatedUser {
        username: "admin".to_string(),
        id: 1,
    };
    Ok(demo_user)
}

#[get("/protected")]
fn protected(user: AuthenticatedUser) -> String {
    format!("Welcome, {}!", user.username)
}

管理员权限守卫 #

rust
use rocket::request::{self, FromRequest, Request, Outcome};
use rocket::http::Status;

/// Caller that presented the admin key in the `X-Admin-Key` header.
struct AdminUser {
    username: String,
}

/// Request-local cache entry for the admin-key check.
///
/// `Request::local_cache` stores one value per *type*. Caching a bare `bool`
/// would therefore share a slot with any other guard that caches a `bool` in
/// the same request, and whichever ran first would decide the result for both.
/// A private newtype gives this guard its own slot.
struct AdminKeyChecked(bool);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminUser {
    type Error = ();

    /// Succeeds when `X-Admin-Key` equals the expected key. Otherwise fails
    /// with `403 Forbidden`.
    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        // NOTE: demo only. A real deployment should load the key from
        // configuration and compare it in constant time, instead of
        // hard-coding it and using `==`.
        let checked = request.local_cache(|| {
            AdminKeyChecked(request.headers().get_one("X-Admin-Key") == Some("secret-admin-key"))
        });

        if checked.0 {
            Outcome::Success(AdminUser {
                username: "admin".to_string(),
            })
        } else {
            Outcome::Error((Status::Forbidden, ()))
        }
    }
}

#[get("/admin/dashboard")]
fn admin_dashboard(admin: AdminUser) -> String {
    format!("Admin Dashboard - Welcome {}", admin.username)
}

组合守卫 #

多重验证 #

rust
#[get("/api/users/<user_id>/posts/<post_id>")]
fn get_user_post(
    user_id: Result<PositiveId, &str>,
    post_id: Result<PositiveId, &str>,
    auth: AuthenticatedUser,
) -> Result<String, Status> {
    let user_id = user_id.map_err(|_| Status::BadRequest)?;
    let post_id = post_id.map_err(|_| Status::BadRequest)?;
    
    if auth.id != user_id.0 {
        return Err(Status::Forbidden);
    }
    
    Ok(format!("User {}'s post {}", user_id.0, post_id.0))
}

完整示例 #

rust
#[macro_use] extern crate rocket;

use rocket::request::{self, FromParam, FromRequest, Request, Outcome};
use rocket::http::Status;
use rocket::serde::json::Json;

/// Path parameter that must be a base-10 `u32` strictly greater than zero.
struct PositiveId(u32);

impl<'a> FromParam<'a> for PositiveId {
    type Error = &'a str;

    /// Parses through `NonZeroU32`, which rejects `0` as well as anything that
    /// is not a valid `u32`. The original segment is returned on failure.
    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        match param.parse::<std::num::NonZeroU32>() {
            Ok(non_zero) => Ok(PositiveId(non_zero.get())),
            Err(_) => Err(param),
        }
    }
}

struct AuthenticatedUser {
    id: u32,
    username: String,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for AuthenticatedUser {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        let auth = request.headers().get_one("Authorization");
        
        match auth {
            Some(token) if token == "valid-token" => {
                Outcome::Success(AuthenticatedUser {
                    id: 1,
                    username: "alice".to_string(),
                })
            }
            _ => Outcome::Error((Status::Unauthorized, ())),
        }
    }
}

/// No authentication required. Reports the rejected segment when the ID is
/// not a positive integer.
#[get("/public/<id>")]
fn public_endpoint(id: Result<PositiveId, &str>) -> String {
    id.map(|PositiveId(value)| format!("Public data for ID: {}", value))
        .unwrap_or_else(|bad| format!("Invalid ID: {}", bad))
}

/// Requires the `AuthenticatedUser` guard. A malformed ID yields
/// `400 Bad Request`.
#[get("/protected/<id>")]
fn protected_endpoint(
    id: Result<PositiveId, &str>,
    user: AuthenticatedUser,
) -> Result<String, Status> {
    match id {
        Ok(PositiveId(resource_id)) => Ok(format!(
            "Protected data for user {} (ID: {})",
            user.username, resource_id
        )),
        Err(_) => Err(Status::BadRequest),
    }
}

/// Builds the application with both endpoints mounted under `/api`.
#[launch]
fn rocket() -> _ {
    let api_routes = routes![public_endpoint, protected_endpoint];
    rocket::build().mount("/api", api_routes)
}

测试守卫 #

bash
# Public endpoint (routes are mounted under /api; no authentication needed)
curl http://127.0.0.1:8000/api/public/123

# Protected endpoint without credentials: the AuthenticatedUser guard fails,
# so the request is rejected with 401 Unauthorized
curl http://127.0.0.1:8000/api/protected/123

# Protected endpoint with credentials: the guard accepts exactly the
# header value "valid-token"
curl -H "Authorization: valid-token" http://127.0.0.1:8000/api/protected/123

下一步 #

掌握了路由守卫后,让我们继续学习「请求参数」一节,深入了解请求处理的各种方式。

最后更新:2026-03-28