路由守卫 #
路由守卫是 Rocket 中实现请求前置检查和权限控制的重要机制。通过实现 FromParam trait,可以在路由匹配阶段对路径参数进行验证和类型转换;涉及整个请求的检查(如认证、权限)则通过 FromRequest 实现,见下文。
FromParam Trait #
FromParam 是Rocket中用于路径参数验证和转换的核心trait。
基本定义 #
rust
/// Converts one dynamic path segment (such as `<id>` in `/user/<id>`)
/// into a typed value.
///
/// Implementors return `Ok(value)` when the segment is acceptable and
/// `Err(Self::Error)` otherwise.
pub trait FromParam<'p>: Sized {
    /// The error produced when the segment cannot be converted.
    type Error;

    /// Attempts to build `Self` from the raw segment text `param`.
    fn from_param(param: &'p str) -> Result<Self, Self::Error>;
}
内置实现 #
Rocket为以下类型提供了内置的 FromParam 实现:
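| 类型 | 说明 |
|---|---|
| `&str` | 字符串切片 |
| `String` | 拥有所有权的字符串 |
| `i8`, `i16`, `i32`, `i64`, `i128` | 有符号整数 |
| `u8`, `u16`, `u32`, `u64`, `u128` | 无符号整数 |
| `f32`, `f64` | 浮点数 |
| `bool` | 布尔值 |
| `PathBuf` | 路径(注意:`PathBuf` 实现的是 `FromSegments`,用于 `<path..>` 形式的多段参数,而不是单段参数的 `FromParam`) |
| `Option<T>` | 可选值(解析失败时得到 `None`) |
| `Result<T, T::Error>` | 结果类型(解析失败时得到 `Err`,其中携带 `T` 的错误值) |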
自定义参数守卫 #
验证正整数 #
rust
use rocket::request::FromParam;
use rocket::http::Status;
/// A path parameter that must parse as a strictly positive `u32`.
struct PositiveId(u32);

impl<'a> FromParam<'a> for PositiveId {
    /// On failure the original segment is handed back unchanged.
    type Error = &'a str;

    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        // Non-numeric text and zero are rejected alike, echoing the raw input.
        param
            .parse::<u32>()
            .ok()
            .filter(|&value| value != 0)
            .map(PositiveId)
            .ok_or(param)
    }
}
/// Returns the user id, or `400 Bad Request` when the segment is not a positive integer.
#[get("/user/<id>")]
fn get_user(id: Result<PositiveId, &str>) -> Result<String, Status> {
    let PositiveId(user_id) = id.map_err(|_| Status::BadRequest)?;
    Ok(format!("User ID: {}", user_id))
}
验证邮箱格式 #
rust
use rocket::request::FromParam;
use regex::Regex;
struct Email(String);
impl<'a> FromParam<'a> for Email {
type Error = &'a str;
fn from_param(param: &'a str) -> Result<Self, Self::Error> {
let email_regex = Regex::new(r"^[^@]+@[^@]+\.[^@]+$").unwrap();
if email_regex.is_match(param) {
Ok(Email(param.to_string()))
} else {
Err(param)
}
}
}
/// Reports whether the `<email>` segment passed the `Email` guard, echoing it either way.
#[get("/verify/<email>")]
fn verify_email(email: Result<Email, &str>) -> String {
    email.map_or_else(
        |rejected| format!("Invalid email: {}", rejected),
        |Email(address)| format!("Valid email: {}", address),
    )
}
枚举参数 #
rust
use rocket::request::FromParam;
/// The roles accepted by the `<role>` path parameter.
#[derive(Debug)]
enum UserRole {
    Admin,
    Moderator,
    User,
}

impl<'a> FromParam<'a> for UserRole {
    /// On failure the original (not lowercased) segment is handed back.
    type Error = &'a str;

    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        // Matching is case-insensitive: compare against a lowercased copy.
        let normalized = param.to_lowercase();
        let role = match normalized.as_str() {
            "admin" => Some(UserRole::Admin),
            "moderator" => Some(UserRole::Moderator),
            "user" => Some(UserRole::User),
            _ => None,
        };
        role.ok_or(param)
    }
}
/// Reports the parsed role, or the rejected text when it is not a known role.
#[get("/role/<role>")]
fn get_by_role(role: Result<UserRole, &str>) -> String {
    role.map(|parsed| format!("Role: {:?}", parsed))
        .unwrap_or_else(|raw| format!("Invalid role: {}", raw))
}
参数守卫与错误处理 #
使用Option处理 #
rust
/// Returns the user id; a `None` parameter (the segment did not parse as `u32`)
/// produces a fixed error text instead.
#[get("/user/<id>")]
fn get_user(id: Option<u32>) -> String {
    id.map_or_else(
        || "Invalid ID format".to_string(),
        |user_id| format!("User ID: {}", user_id),
    )
}
使用Result处理 #
rust
use rocket::http::Status;
/// Returns the user id, or `400 Bad Request` when the segment is not a valid `u32`.
#[get("/user/<id>")]
fn get_user(id: Result<u32, &str>) -> Result<String, Status> {
    match id {
        Ok(user_id) => Ok(format!("User ID: {}", user_id)),
        Err(_) => Err(Status::BadRequest),
    }
}
自定义错误类型 #
rust
use rocket::request::FromParam;
use rocket::http::Status;
use std::fmt;
/// Error returned by parameter guards, carrying a human-readable message.
#[derive(Debug)]
struct ValidationError {
    message: String,
}

impl fmt::Display for ValidationError {
    /// Writes the message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.message)
    }
}

// Implementing `std::error::Error` lets this type be boxed as
// `Box<dyn std::error::Error>` and composed with other error types.
impl std::error::Error for ValidationError {}
/// A path parameter holding a username of 3 to 20 characters
/// (counted as Unicode scalar values).
struct Username(String);

impl<'a> FromParam<'a> for Username {
    type Error = ValidationError;

    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        // Count characters, not bytes: `str::len` is the UTF-8 byte length,
        // so a 2-character CJK name (6 bytes) would pass the check and a
        // 7-character one (21 bytes) would be rejected.
        let char_count = param.chars().count();
        if (3..=20).contains(&char_count) {
            Ok(Username(param.to_string()))
        } else {
            Err(ValidationError {
                message: "Username must be 3-20 characters".to_string(),
            })
        }
    }
}
查询参数守卫 #
自定义查询守卫 #
rust
use rocket::request::FromQuery;
use rocket::request::Query;
use rocket::FromForm;
#[derive(FromForm)]
struct DateRange {
start: String,
end: String,
}
impl<'a> FromQuery<'a> for DateRange {
type Error = String;
fn from_query(query: Query<'a>) -> Result<Self, Self::Error> {
let start = query.get("start").map(|v| v.to_string()).unwrap_or_default();
let end = query.get("end").map(|v| v.to_string()).unwrap_or_default();
if start.is_empty() || end.is_empty() {
Err("Both start and end dates are required".to_string())
} else {
Ok(DateRange { start, end })
}
}
}
路由级别验证 #
结合FromRequest #
对于更复杂的验证,可以使用 FromRequest:
rust
use rocket::request::{self, FromRequest, Request, Outcome};
use rocket::http::Status;
struct AuthenticatedUser {
id: u32,
username: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AuthenticatedUser {
type Error = AuthError;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let auth_header = request.headers().get_one("Authorization");
match auth_header {
Some(token) if token.starts_with("Bearer ") => {
let token = token.trim_start_matches("Bearer ");
match validate_token(token) {
Ok(user) => Outcome::Success(user),
Err(_) => Outcome::Error((Status::Unauthorized, AuthError::Invalid)),
}
}
_ => Outcome::Error((Status::Unauthorized, AuthError::Missing)),
}
}
}
/// Why request authentication failed.
#[derive(Debug)]
enum AuthError {
    /// No `Authorization` header, or it did not use the `Bearer` scheme.
    Missing,
    /// A bearer token was supplied but `validate_token` rejected it.
    Invalid,
}

/// Demo token check: only the literal token `valid-token` is accepted.
///
/// This is a placeholder for a real verifier (signature/expiry check,
/// session lookup, ...). Unknown tokens must be rejected; a stub that
/// accepts every token lets any `Authorization: Bearer <anything>` request
/// through the guard.
fn validate_token(token: &str) -> Result<AuthenticatedUser, ()> {
    if token != "valid-token" {
        return Err(());
    }
    Ok(AuthenticatedUser {
        id: 1,
        username: "admin".to_string(),
    })
}
#[get("/protected")]
fn protected(user: AuthenticatedUser) -> String {
format!("Welcome, {}!", user.username)
}
管理员权限守卫 #
rust
use rocket::request::{self, FromRequest, Request, Outcome};
use rocket::http::Status;
/// A request that presented the expected `X-Admin-Key` header.
struct AdminUser {
    username: String,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminUser {
    type Error = ();

    /// Succeeds when `X-Admin-Key` equals the demo key; otherwise fails with
    /// `403 Forbidden`. The check result is cached for the rest of the request.
    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        // `local_cache` stores one value *per type* for the whole request.
        // Caching a bare `bool` would collide with any other guard that also
        // caches a `bool`, so the result gets its own dedicated type.
        struct AdminKeyChecked(bool);

        let AdminKeyChecked(is_admin) = request.local_cache(|| {
            // Demo only: a real key belongs in configuration, not in source.
            AdminKeyChecked(request.headers().get_one("X-Admin-Key") == Some("secret-admin-key"))
        });

        if *is_admin {
            Outcome::Success(AdminUser {
                username: "admin".to_string(),
            })
        } else {
            Outcome::Error((Status::Forbidden, ()))
        }
    }
}
#[get("/admin/dashboard")]
fn admin_dashboard(admin: AdminUser) -> String {
format!("Admin Dashboard - Welcome {}", admin.username)
}
组合守卫 #
多重验证 #
rust
#[get("/api/users/<user_id>/posts/<post_id>")]
fn get_user_post(
user_id: Result<PositiveId, &str>,
post_id: Result<PositiveId, &str>,
auth: AuthenticatedUser,
) -> Result<String, Status> {
let user_id = user_id.map_err(|_| Status::BadRequest)?;
let post_id = post_id.map_err(|_| Status::BadRequest)?;
if auth.id != user_id.0 {
return Err(Status::Forbidden);
}
Ok(format!("User {}'s post {}", user_id.0, post_id.0))
}
完整示例 #
rust
#[macro_use] extern crate rocket;
use rocket::request::{self, FromParam, FromRequest, Request, Outcome};
use rocket::http::Status;
use rocket::serde::json::Json;
/// A path parameter that must parse as a strictly positive `u32`.
struct PositiveId(u32);

impl<'a> FromParam<'a> for PositiveId {
    /// On failure the raw segment is returned so handlers can echo it.
    type Error = &'a str;

    fn from_param(param: &'a str) -> Result<Self, Self::Error> {
        match param.parse::<u32>() {
            Ok(0) | Err(_) => Err(param),
            Ok(value) => Ok(PositiveId(value)),
        }
    }
}
struct AuthenticatedUser {
id: u32,
username: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AuthenticatedUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let auth = request.headers().get_one("Authorization");
match auth {
Some(token) if token == "valid-token" => {
Outcome::Success(AuthenticatedUser {
id: 1,
username: "alice".to_string(),
})
}
_ => Outcome::Error((Status::Unauthorized, ())),
}
}
}
/// Public route: echoes the id, or the rejected text if it is not a positive integer.
#[get("/public/<id>")]
fn public_endpoint(id: Result<PositiveId, &str>) -> String {
    let PositiveId(value) = match id {
        Ok(parsed) => parsed,
        Err(raw) => return format!("Invalid ID: {}", raw),
    };
    format!("Public data for ID: {}", value)
}
/// Protected route: requires the `AuthenticatedUser` guard and a positive id
/// (`400 Bad Request` otherwise).
#[get("/protected/<id>")]
fn protected_endpoint(
    id: Result<PositiveId, &str>,
    user: AuthenticatedUser,
) -> Result<String, Status> {
    match id {
        Ok(PositiveId(value)) => Ok(format!(
            "Protected data for user {} (ID: {})",
            user.username, value
        )),
        Err(_) => Err(Status::BadRequest),
    }
}
/// Builds the application, serving both demo endpoints under `/api`.
#[launch]
fn rocket() -> _ {
    let api_routes = routes![public_endpoint, protected_endpoint];
    rocket::build().mount("/api", api_routes)
}
测试守卫 #
bash
# Base URL of the routes mounted under /api by the complete example.
BASE_URL="http://127.0.0.1:8000/api"
# Public endpoint: expected body "Public data for ID: 123".
curl "$BASE_URL/public/123"
# Protected endpoint without credentials: the guard fails with 401 Unauthorized.
curl "$BASE_URL/protected/123"
# Protected endpoint with the demo token: expected body "Protected data for user alice (ID: 123)".
curl -H "Authorization: valid-token" "$BASE_URL/protected/123"
下一步 #
掌握了路由守卫后,让我们继续学习 请求参数,深入了解请求处理的各种方式。
最后更新:2026-03-28