REST API Development
Building REST APIs in Rust combines the language's performance and safety with modern web development practices. This guide covers comprehensive REST API development from design to deployment.
REST API Design Principles
RESTful Resource Design
// Example API structure for a blog application
/*
Resources and Endpoints:
Posts:
- GET /api/posts - List all posts
- GET /api/posts/:id - Get specific post
- POST /api/posts - Create new post
- PUT /api/posts/:id - Update entire post
- PATCH /api/posts/:id - Partial update post
- DELETE /api/posts/:id - Delete post
Users:
- GET /api/users - List users
- GET /api/users/:id - Get user
- POST /api/users - Create user
- PUT /api/users/:id - Update user
- DELETE /api/users/:id - Delete user
Comments:
- GET /api/posts/:id/comments - List post comments
- POST /api/posts/:id/comments - Create comment
- GET /api/comments/:id - Get comment
- PUT /api/comments/:id - Update comment
- DELETE /api/comments/:id - Delete comment
Authentication:
- POST /api/auth/register - Register new user
- POST /api/auth/login - User login
- POST /api/auth/logout - User logout
- POST /api/auth/refresh - Refresh token
- GET /api/auth/me - Get current user
*/
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post, put, patch, delete},
Router,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use chrono::{DateTime, Utc};
// Core data models
/// A blog post as stored in `Database::posts` and returned by the post handlers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Post {
pub id: Uuid,
/// Stored trimmed by `create_post`.
pub title: String,
/// Stored trimmed by `create_post`.
pub content: String,
/// Id of the authenticated user who created the post (set in `create_post`).
pub author_id: Uuid,
/// Draft/published flag; `create_post` defaults it to `false`.
pub published: bool,
pub created_at: DateTime<Utc>,
/// Bumped by `update_post` on every successful edit.
pub updated_at: DateTime<Utc>,
/// Free-form labels; `list_posts` filters on exact membership.
pub tags: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: Uuid,
pub username: String,
pub email: String,
pub full_name: String,
pub avatar_url: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// A comment on a post; `delete_post` removes all comments of the deleted post.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Comment {
pub id: Uuid,
/// The post this comment belongs to.
pub post_id: Uuid,
/// The authenticated user who wrote it (set in `create_comment`).
pub author_id: Uuid,
/// Stored trimmed by `create_comment`.
pub content: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
// API Response wrappers
/// Uniform JSON envelope for single-item responses.
///
/// Built with `ApiResponse::success` (data set, no message) or
/// `ApiResponse::error` (message set, no data). `timestamp` records when the
/// envelope was created.
#[derive(Serialize)]
pub struct ApiResponse<T> {
pub success: bool,
pub data: Option<T>,
pub message: Option<String>,
pub timestamp: DateTime<Utc>,
}
/// One page of a list endpoint's results plus the metadata to request other pages.
#[derive(Serialize)]
pub struct PaginatedResponse<T> {
pub data: Vec<T>,
pub pagination: PaginationInfo,
}
/// Paging metadata echoed back by the list handlers.
#[derive(Serialize)]
pub struct PaginationInfo {
/// 1-based page number that was served.
pub page: u32,
/// Page size actually used (the list handlers cap it at 100).
pub per_page: u32,
/// Number of items matching the filters, across all pages.
pub total: u64,
/// `ceil(total / per_page)`.
pub total_pages: u32,
}
impl<T> ApiResponse<T> {
pub fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
message: None,
timestamp: Utc::now(),
}
}
pub fn error(message: String) -> Self {
Self {
success: false,
data: None,
message: Some(message),
timestamp: Utc::now(),
}
}
}
// Request DTOs (Data Transfer Objects)
/// Body of a create-post request.
#[derive(Debug, Deserialize)]
pub struct CreatePostRequest {
pub title: String,
pub content: String,
/// Defaults to `false` in `create_post` when absent.
pub published: Option<bool>,
/// Defaults to no tags in `create_post` when absent.
pub tags: Option<Vec<String>>,
}
/// Partial-update body for a post: `update_post` changes only the fields that are
/// present and leaves `None` fields untouched.
#[derive(Debug, Deserialize)]
pub struct UpdatePostRequest {
pub title: Option<String>,
pub content: Option<String>,
pub published: Option<bool>,
/// When present, replaces the whole tag list.
pub tags: Option<Vec<String>>,
}
/// Registration payload consumed by `register`.
///
/// `Debug` is hand-written so the plaintext password is never printed. A derived
/// `Debug` would expose it to any `{:?}` request logging.
#[derive(Deserialize)]
pub struct CreateUserRequest {
    pub username: String,
    pub email: String,
    pub full_name: String,
    /// Plaintext password as submitted; hashed by `register` before storage.
    pub password: String,
}

impl std::fmt::Debug for CreateUserRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CreateUserRequest")
            .field("username", &self.username)
            .field("email", &self.email)
            .field("full_name", &self.full_name)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Partial-update body for a user; every field is optional.
/// NOTE(review): presumably `None` means "leave unchanged" as in
/// `UpdatePostRequest`; no user-update handler is visible here to confirm.
#[derive(Debug, Deserialize)]
pub struct UpdateUserRequest {
pub username: Option<String>,
pub email: Option<String>,
pub full_name: Option<String>,
pub avatar_url: Option<String>,
}
// Query parameters
/// Query-string parameters accepted by `list_posts`; all are optional filters or
/// paging controls.
#[derive(Debug, Deserialize)]
pub struct PostQuery {
/// 1-based page number.
pub page: Option<u32>,
/// Page size (capped at 100 by `list_posts`).
pub per_page: Option<u32>,
pub published: Option<bool>,
pub author_id: Option<Uuid>,
/// Exact tag to match.
pub tag: Option<String>,
/// Case-insensitive substring searched in title and content.
pub search: Option<String>,
}
/// Query-string parameters for listing users (paging plus an optional search term).
#[derive(Debug, Deserialize)]
pub struct UserQuery {
pub page: Option<u32>,
pub per_page: Option<u32>,
pub search: Option<String>,
}
CRUD Operations Implementation
Complete CRUD for Posts
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;
// Application state and database abstraction
/// Shared state handed to handlers through axum's `State` extractor.
/// Cloning is cheap: every field is an `Arc`.
#[derive(Clone)]
pub struct AppState {
pub db: Arc<Database>,
pub auth: Arc<AuthService>,
pub config: Arc<Config>,
}
// Simplified in-memory database for demo
/// In-memory tables, one `HashMap` per entity keyed by its id, each behind its
/// own tokio `RwLock`.
///
/// The visible handlers acquire these in the order posts -> comments -> users
/// when they hold more than one; keep that order to avoid lock-order deadlocks.
#[derive(Default)]
pub struct Database {
pub posts: RwLock<HashMap<Uuid, Post>>,
pub users: RwLock<HashMap<Uuid, User>>,
pub comments: RwLock<HashMap<Uuid, Comment>>,
}
// Posts CRUD handlers
/// Lists posts, filtered by the optional query parameters and paginated
/// newest-first.
///
/// Filters are all optional and combined with AND: `published`, `author_id`,
/// exact `tag`, and a case-insensitive `search` over title and content.
///
/// Pagination rules:
/// - `page` is 1-based; it defaults to 1 and values below 1 are treated as 1.
/// - `per_page` defaults to 10 and is clamped to `1..=100`.
/// - A page past the end yields an empty `data` array instead of panicking.
pub async fn list_posts(
    Query(params): Query<PostQuery>,
    State(state): State<AppState>,
) -> Result<Json<PaginatedResponse<Post>>, ApiError> {
    const DEFAULT_PER_PAGE: u32 = 10;
    const MAX_PER_PAGE: u32 = 100;

    // Lowercase the search needle once rather than once per post.
    let search_lower = params.search.as_deref().map(str::to_lowercase);

    // Filter under the read lock and clone only the matches; the guard is dropped
    // at the end of this block so sorting and paging run without holding it.
    let mut filtered_posts: Vec<Post> = {
        let posts = state.db.posts.read().await;
        posts
            .values()
            .filter(|p| params.published.map_or(true, |published| p.published == published))
            .filter(|p| params.author_id.map_or(true, |author_id| p.author_id == author_id))
            .filter(|p| params.tag.as_ref().map_or(true, |tag| p.tags.contains(tag)))
            .filter(|p| {
                search_lower.as_ref().map_or(true, |needle| {
                    p.title.to_lowercase().contains(needle.as_str())
                        || p.content.to_lowercase().contains(needle.as_str())
                })
            })
            .cloned()
            .collect()
    };

    // Newest first. Ties are broken by id so pages stay stable between requests.
    filtered_posts.sort_by(|a, b| {
        b.created_at
            .cmp(&a.created_at)
            .then_with(|| a.id.cmp(&b.id))
    });

    // Clamping prevents `page - 1` underflow (page=0) and division by zero (per_page=0).
    let page = params.page.unwrap_or(1).max(1);
    let per_page = params
        .per_page
        .unwrap_or(DEFAULT_PER_PAGE)
        .clamp(1, MAX_PER_PAGE);
    let total = filtered_posts.len() as u64;
    let per_page_u64 = u64::from(per_page);
    let total_pages = u32::try_from((total + per_page_u64 - 1) / per_page_u64).unwrap_or(u32::MAX);

    // `skip`/`take` cannot panic: an out-of-range page is simply empty.
    let skip = (page as usize - 1).saturating_mul(per_page as usize);
    let data: Vec<Post> = filtered_posts
        .into_iter()
        .skip(skip)
        .take(per_page as usize)
        .collect();

    Ok(Json(PaginatedResponse {
        data,
        pagination: PaginationInfo {
            page,
            per_page,
            total,
            total_pages,
        },
    }))
}
/// Returns the post with the given id wrapped in an `ApiResponse`, or
/// `ApiError::NotFound` when no such post exists.
pub async fn get_post(
    Path(id): Path<Uuid>,
    State(state): State<AppState>,
) -> Result<Json<ApiResponse<Post>>, ApiError> {
    // The read guard is a temporary of this statement; it is released once the
    // post has been cloned out.
    let found = state.db.posts.read().await.get(&id).cloned();
    let post = found.ok_or_else(|| ApiError::NotFound("Post not found".to_string()))?;
    Ok(Json(ApiResponse::success(post)))
}
pub async fn create_post(
State(state): State<AppState>,
Json(payload): Json<CreatePostRequest>,
current_user: CurrentUser,
) -> Result<Json<ApiResponse<Post>>, ApiError> {
// Validate request
if payload.title.trim().is_empty() {
return Err(ApiError::BadRequest("Title cannot be empty".to_string()));
}
if payload.content.trim().is_empty() {
return Err(ApiError::BadRequest("Content cannot be empty".to_string()));
}
let mut posts = state.db.posts.write().await;
let now = Utc::now();
let post = Post {
id: Uuid::new_v4(),
title: payload.title.trim().to_string(),
content: payload.content.trim().to_string(),
author_id: current_user.id,
published: payload.published.unwrap_or(false),
created_at: now,
updated_at: now,
tags: payload.tags.unwrap_or_default(),
};
posts.insert(post.id, post.clone());
Ok(Json(ApiResponse::success(post)))
}
/// Partially updates a post: only fields present in the payload change.
///
/// Errors, checked in this order:
/// - 404 if the post does not exist.
/// - 403 unless the caller is the author or an admin.
/// - 400 if a supplied title or content is blank after trimming.
///
/// On success `updated_at` is bumped and the updated post is returned.
///
/// `Json` is the last argument because axum requires the body-consuming
/// extractor to come last.
pub async fn update_post(
    Path(id): Path<Uuid>,
    State(state): State<AppState>,
    current_user: CurrentUser,
    Json(payload): Json<UpdatePostRequest>,
) -> Result<Json<ApiResponse<Post>>, ApiError> {
    let mut posts = state.db.posts.write().await;
    let post = posts
        .get_mut(&id)
        .ok_or_else(|| ApiError::NotFound("Post not found".to_string()))?;

    // Check ownership
    if post.author_id != current_user.id && !current_user.is_admin {
        return Err(ApiError::Forbidden("You can only edit your own posts".to_string()));
    }

    // Validate every supplied field before touching `post`. Otherwise a request
    // with a good title but blank content would apply the title and then fail,
    // leaving the post half-updated with a stale `updated_at`.
    let title = payload.title.map(|t| t.trim().to_string());
    if matches!(title.as_deref(), Some("")) {
        return Err(ApiError::BadRequest("Title cannot be empty".to_string()));
    }
    let content = payload.content.map(|c| c.trim().to_string());
    if matches!(content.as_deref(), Some("")) {
        return Err(ApiError::BadRequest("Content cannot be empty".to_string()));
    }

    if let Some(title) = title {
        post.title = title;
    }
    if let Some(content) = content {
        post.content = content;
    }
    if let Some(published) = payload.published {
        post.published = published;
    }
    if let Some(tags) = payload.tags {
        post.tags = tags;
    }
    post.updated_at = Utc::now();

    Ok(Json(ApiResponse::success(post.clone())))
}
/// Deletes a post and cascades the deletion to its comments.
///
/// Only the author or an admin may delete (403 otherwise); unknown ids yield 404.
/// Succeeds with `204 No Content`. The posts write guard stays held while the
/// comments are purged, so removal and cascade happen under one lock of the
/// posts table.
pub async fn delete_post(
    Path(id): Path<Uuid>,
    State(state): State<AppState>,
    current_user: CurrentUser,
) -> Result<StatusCode, ApiError> {
    let mut posts = state.db.posts.write().await;

    let owner_id = match posts.get(&id) {
        Some(existing) => existing.author_id,
        None => return Err(ApiError::NotFound("Post not found".to_string())),
    };

    let allowed = current_user.is_admin || owner_id == current_user.id;
    if !allowed {
        return Err(ApiError::Forbidden("You can only delete your own posts".to_string()));
    }

    posts.remove(&id);
    state
        .db
        .comments
        .write()
        .await
        .retain(|_, comment| comment.post_id != id);

    Ok(StatusCode::NO_CONTENT)
}
Comments as Nested Resources
/// Lists the comments of one post, oldest first, each joined with a summary of
/// its author.
///
/// Behavior:
/// - Returns 404 if the post does not exist.
/// - Comments whose author is no longer present in `users` are omitted.
/// - `page` is 1-based; values below 1 are treated as 1.
/// - `per_page` defaults to 20 and is clamped to `1..=100`.
/// - A page past the end yields an empty `data` array.
pub async fn list_post_comments(
    Path(post_id): Path<Uuid>,
    Query(params): Query<CommentQuery>,
    State(state): State<AppState>,
) -> Result<Json<PaginatedResponse<CommentWithAuthor>>, ApiError> {
    const DEFAULT_PER_PAGE: u32 = 20;
    const MAX_PER_PAGE: u32 = 100;

    // Verify the post exists. The read guard is a temporary of the `if`
    // condition, so it is released before the comment tables are locked.
    if !state.db.posts.read().await.contains_key(&post_id) {
        return Err(ApiError::NotFound("Post not found".to_string()));
    }

    let mut post_comments: Vec<CommentWithAuthor> = {
        let comments = state.db.comments.read().await;
        let users = state.db.users.read().await;
        comments
            .values()
            .filter(|c| c.post_id == post_id)
            .filter_map(|comment| {
                users.get(&comment.author_id).map(|author| CommentWithAuthor {
                    id: comment.id,
                    post_id: comment.post_id,
                    content: comment.content.clone(),
                    author: UserSummary {
                        id: author.id,
                        username: author.username.clone(),
                        full_name: author.full_name.clone(),
                        avatar_url: author.avatar_url.clone(),
                    },
                    created_at: comment.created_at,
                    updated_at: comment.updated_at,
                })
            })
            .collect()
    };

    // Oldest first. Ties are broken by id so pages stay stable between requests.
    post_comments.sort_by(|a, b| {
        a.created_at
            .cmp(&b.created_at)
            .then_with(|| a.id.cmp(&b.id))
    });

    // Clamping prevents `page - 1` underflow (page=0) and division by zero (per_page=0).
    let page = params.page.unwrap_or(1).max(1);
    let per_page = params
        .per_page
        .unwrap_or(DEFAULT_PER_PAGE)
        .clamp(1, MAX_PER_PAGE);
    let total = post_comments.len() as u64;
    let per_page_u64 = u64::from(per_page);
    let total_pages = u32::try_from((total + per_page_u64 - 1) / per_page_u64).unwrap_or(u32::MAX);

    // `skip`/`take` moves items out without slicing: no out-of-range panic and
    // no `Clone` needed.
    let skip = (page as usize - 1).saturating_mul(per_page as usize);
    let data: Vec<CommentWithAuthor> = post_comments
        .into_iter()
        .skip(skip)
        .take(per_page as usize)
        .collect();

    Ok(Json(PaginatedResponse {
        data,
        pagination: PaginationInfo {
            page,
            per_page,
            total,
            total_pages,
        },
    }))
}
/// Adds a comment by the authenticated caller to an existing post.
///
/// Returns 404 if the post does not exist and 400 if the content is blank after
/// trimming. Content is stored trimmed.
///
/// `Json` is the last argument because axum requires the body-consuming
/// extractor to come last.
pub async fn create_comment(
    Path(post_id): Path<Uuid>,
    State(state): State<AppState>,
    current_user: CurrentUser,
    Json(payload): Json<CreateCommentRequest>,
) -> Result<Json<ApiResponse<Comment>>, ApiError> {
    // Keep the posts read guard alive until the comment is inserted. Releasing it
    // right after the existence check would let a concurrent `delete_post` remove
    // the post and cascade its comments in between, leaving an orphaned comment.
    // Locks are taken posts -> comments, the same order `delete_post` uses.
    let posts = state.db.posts.read().await;
    if !posts.contains_key(&post_id) {
        return Err(ApiError::NotFound("Post not found".to_string()));
    }

    let content = payload.content.trim();
    if content.is_empty() {
        return Err(ApiError::BadRequest("Comment content cannot be empty".to_string()));
    }

    let now = Utc::now();
    let comment = Comment {
        id: Uuid::new_v4(),
        post_id,
        author_id: current_user.id,
        content: content.to_string(),
        created_at: now,
        updated_at: now,
    };
    state
        .db
        .comments
        .write()
        .await
        .insert(comment.id, comment.clone());
    drop(posts);

    Ok(Json(ApiResponse::success(comment)))
}
/// A comment joined with a summary of its author, as returned by
/// `list_post_comments`.
///
/// Derives `Clone`, as does `UserSummary`, so page slices can be copied
/// (e.g. `slice.to_vec()`).
#[derive(Debug, Clone, Serialize)]
pub struct CommentWithAuthor {
    pub id: Uuid,
    pub post_id: Uuid,
    pub content: String,
    pub author: UserSummary,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

/// Public subset of a user embedded in other responses. It carries no email and
/// no credentials.
#[derive(Debug, Clone, Serialize)]
pub struct UserSummary {
    pub id: Uuid,
    pub username: String,
    pub full_name: String,
    pub avatar_url: Option<String>,
}
/// Body of a create-comment request; `create_comment` rejects blank content.
#[derive(Debug, Deserialize)]
pub struct CreateCommentRequest {
pub content: String,
}
/// Paging parameters for `list_post_comments` (1-based `page`; `per_page` capped at 100).
#[derive(Debug, Deserialize)]
pub struct CommentQuery {
pub page: Option<u32>,
pub per_page: Option<u32>,
}
Authentication & Authorization
JWT-based Authentication
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation, Algorithm};
use axum::{
async_trait,
extract::{FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
http::request::Parts,
RequestPartsExt,
};
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
/// JWT payload written by `AuthService::create_token` and read back by
/// `AuthService::verify_token`. Times are seconds since the Unix epoch.
///
/// NOTE(review): `create_token` currently always sets `is_admin` to `false`, so no
/// issued token grants admin rights yet.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // Subject (user ID)
pub username: String, // Username
pub email: String, // Email
pub is_admin: bool, // Admin flag
pub exp: u64, // Expiration time
pub iat: u64, // Issued at
pub jti: String, // JWT ID
}
/// Issues and verifies JWTs signed with one shared secret (HS256; see `AuthService::new`).
#[derive(Clone)]
pub struct AuthService {
encoding_key: EncodingKey,
decoding_key: DecodingKey,
/// Decoding rules; `new` makes the `exp` and `sub` claims mandatory.
validation: Validation,
}
impl AuthService {
    /// Lifetime of issued tokens, in seconds (24 hours).
    pub const TOKEN_TTL_SECS: u64 = 24 * 60 * 60;

    /// Creates a service that signs and verifies HS256 tokens with `secret`.
    ///
    /// Decoding additionally requires the `exp` and `sub` claims to be present.
    pub fn new(secret: &str) -> Self {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims.insert("exp".to_string());
        validation.required_spec_claims.insert("sub".to_string());
        Self {
            encoding_key: EncodingKey::from_secret(secret.as_ref()),
            decoding_key: DecodingKey::from_secret(secret.as_ref()),
            validation,
        }
    }

    /// Issues a signed token for `user` that expires `TOKEN_TTL_SECS` after issue.
    ///
    /// # Errors
    /// `AuthError::TokenCreation` if the system clock reads earlier than the Unix
    /// epoch, or if encoding the token fails.
    pub fn create_token(&self, user: &User) -> Result<String, AuthError> {
        // A clock set before 1970 is an environment fault, not an invariant
        // violation, so report it instead of panicking the request task.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|_| AuthError::TokenCreation)?
            .as_secs();
        let claims = Claims {
            sub: user.id.to_string(),
            username: user.username.clone(),
            email: user.email.clone(),
            // TODO: `User` has no admin flag yet; take this from user data once it does.
            is_admin: false,
            exp: now + Self::TOKEN_TTL_SECS,
            iat: now,
            jti: Uuid::new_v4().to_string(),
        };
        encode(&Header::default(), &claims, &self.encoding_key)
            .map_err(|_| AuthError::TokenCreation)
    }

    /// Decodes `token` with this service's key and validation rules and returns
    /// its claims.
    ///
    /// # Errors
    /// `AuthError::InvalidToken` for any decoding or validation failure.
    pub fn verify_token(&self, token: &str) -> Result<Claims, AuthError> {
        decode::<Claims>(token, &self.decoding_key, &self.validation)
            .map(|data| data.claims)
            .map_err(|_| AuthError::InvalidToken)
    }
}
// Current user extractor
#[derive(Debug, Clone)]
pub struct CurrentUser {
pub id: Uuid,
pub username: String,
pub email: String,
pub is_admin: bool,
}
#[async_trait]
impl<S> FromRequestParts<S> for CurrentUser
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    /// Authenticates the request from its `Authorization: Bearer <jwt>` header.
    ///
    /// Rejections:
    /// - `MissingToken` when the header cannot be extracted as a bearer credential.
    /// - `ConfigError` when no `Arc<AuthService>` is in the request extensions.
    /// - `InvalidToken` when the JWT fails verification or its `sub` claim is not a UUID.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let bearer = match parts.extract::<TypedHeader<Authorization<Bearer>>>().await {
            Ok(TypedHeader(Authorization(bearer))) => bearer,
            Err(_) => return Err(AuthError::MissingToken),
        };

        // NOTE(review): the service is looked up in request extensions, not in the
        // router state `S`, so something must insert an `Arc<AuthService>` there.
        // `AppState::auth` alone is not consulted. Confirm how this is wired.
        let Some(auth_service) = parts.extensions.get::<Arc<AuthService>>() else {
            return Err(AuthError::ConfigError);
        };

        let claims = auth_service.verify_token(bearer.token())?;
        let id = Uuid::parse_str(&claims.sub).map_err(|_| AuthError::InvalidToken)?;

        Ok(Self {
            id,
            username: claims.username,
            email: claims.email,
            is_admin: claims.is_admin,
        })
    }
}
// Optional user extractor (for endpoints that work with or without auth)
/// Caller identity for endpoints that work with or without authentication:
/// `Some(user)` when a valid bearer token was supplied, `None` otherwise.
#[derive(Debug, Clone)]
pub struct OptionalUser(pub Option<CurrentUser>);

#[async_trait]
impl<S> FromRequestParts<S> for OptionalUser
where
    S: Send + Sync,
{
    /// Never rejects: any authentication failure simply yields `None`.
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let maybe_user = CurrentUser::from_request_parts(parts, state).await.ok();
        Ok(OptionalUser(maybe_user))
    }
}
// Authentication endpoints
/// Exchanges credentials for a JWT plus the caller's public profile.
///
/// `credentials.username` may be a username or an email address. Surrounding
/// whitespace is ignored and emails are matched case-insensitively, because
/// `register` stores usernames trimmed and emails trimmed and lowercased.
/// An unknown identifier and a wrong password produce the identical 401 message.
pub async fn login(
    State(state): State<AppState>,
    Json(credentials): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, ApiError> {
    let identifier = credentials.username.trim();
    let email_key = identifier.to_lowercase();

    let users = state.db.users.read().await;
    let user = users
        .values()
        .find(|u| u.username == identifier || u.email == email_key)
        .ok_or_else(|| ApiError::Unauthorized("Invalid credentials".to_string()))?;

    // Verify password (in real app, use proper password hashing)
    if !verify_password(&credentials.password, &user.password_hash) {
        return Err(ApiError::Unauthorized("Invalid credentials".to_string()));
    }

    let token = state
        .auth
        .create_token(user)
        .map_err(|_| ApiError::InternalError("Failed to create token".to_string()))?;

    Ok(Json(LoginResponse {
        token,
        user: UserProfile {
            id: user.id,
            username: user.username.clone(),
            email: user.email.clone(),
            full_name: user.full_name.clone(),
            avatar_url: user.avatar_url.clone(),
            created_at: user.created_at,
        },
    }))
}
/// Creates a new account and returns its public profile.
///
/// Validation (400):
/// - non-empty username;
/// - syntactically valid email;
/// - password of at least 8 characters (Unicode scalar values, not bytes).
///
/// Uniqueness (409): username and email must not already exist.
///
/// Stored values are normalized: username and full name trimmed; email trimmed
/// and lowercased.
pub async fn register(
    State(state): State<AppState>,
    Json(payload): Json<CreateUserRequest>,
) -> Result<Json<ApiResponse<UserProfile>>, ApiError> {
    // Normalize exactly as the values will be stored, so the uniqueness checks
    // compare like with like. Comparing raw input with stored values would let
    // " alice" or "Bob@Example.com" bypass the checks.
    let username = payload.username.trim().to_string();
    let email = payload.email.trim().to_lowercase();
    let full_name = payload.full_name.trim().to_string();

    if username.is_empty() {
        return Err(ApiError::BadRequest("Username cannot be empty".to_string()));
    }
    if email.is_empty() || !is_valid_email(&email) {
        return Err(ApiError::BadRequest("Invalid email address".to_string()));
    }
    // Count characters, matching the message; `len()` would count UTF-8 bytes.
    if payload.password.chars().count() < 8 {
        return Err(ApiError::BadRequest("Password must be at least 8 characters".to_string()));
    }

    // Hash before taking the write lock. Real hashers (bcrypt/argon2) are
    // deliberately slow and would otherwise block every reader of `users`.
    let password_hash = hash_password(&payload.password)?;

    let mut users = state.db.users.write().await;
    if users.values().any(|u| u.username == username) {
        return Err(ApiError::Conflict("Username already exists".to_string()));
    }
    if users.values().any(|u| u.email == email) {
        return Err(ApiError::Conflict("Email already exists".to_string()));
    }

    let now = Utc::now();
    let user = User {
        id: Uuid::new_v4(),
        username,
        email,
        full_name,
        password_hash,
        avatar_url: None,
        created_at: now,
        updated_at: now,
    };
    let user_profile = UserProfile {
        id: user.id,
        username: user.username.clone(),
        email: user.email.clone(),
        full_name: user.full_name.clone(),
        avatar_url: user.avatar_url.clone(),
        created_at: user.created_at,
    };
    users.insert(user.id, user);

    Ok(Json(ApiResponse::success(user_profile)))
}
/// Echoes the authenticated caller back inside an `ApiResponse`.
///
/// Authentication is enforced by the `CurrentUser` extractor; the handler body
/// itself never fails.
pub async fn get_current_user(
    current_user: CurrentUser,
) -> Result<Json<ApiResponse<CurrentUser>>, ApiError> {
    let envelope = ApiResponse::success(current_user);
    Ok(Json(envelope))
}
/// Credentials posted to `login`.
///
/// `Debug` is hand-written so the plaintext password is never printed. A derived
/// `Debug` would expose it to any `{:?}` request logging.
#[derive(Deserialize)]
pub struct LoginRequest {
    pub username: String, // Can be username or email
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Successful `login` result: a signed JWT and the caller's public profile.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
pub token: String,
pub user: UserProfile,
}
/// Client-facing view of a `User`; it deliberately has no password field.
#[derive(Debug, Serialize)]
pub struct UserProfile {
pub id: Uuid,
pub username: String,
pub email: String,
pub full_name: String,
pub avatar_url: Option<String>,
pub created_at: DateTime<Utc>,
}
// Error types
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("Missing authentication token")]
MissingToken,
#[error("Invalid authentication token")]
InvalidToken,
#[error("Token creation failed")]
TokenCreation,
#[error("Authentication configuration error")]
ConfigError,
}
// Helper functions (simplified for demo)
/// Demo stand-in for password verification.
///
/// It compares `password` with the stored `hash` directly, because
/// `hash_password` stores the plaintext. In a real app, use bcrypt or argon2 —
/// NEVER do this in production!
fn verify_password(password: &str, hash: &str) -> bool {
    hash == password
}
/// Demo stand-in for password hashing: returns the password unchanged.
///
/// The `Result` signature lets a real hasher report failures without changing
/// callers. In a real app, use bcrypt or argon2 — NEVER do this in production!
fn hash_password(password: &str) -> Result<String, ApiError> {
    let stored = password.to_owned();
    Ok(stored)
}
/// Minimal syntactic email check.
///
/// Requirements, after ignoring surrounding whitespace:
/// - exactly one `@`;
/// - a non-empty local part and a non-empty domain;
/// - no internal whitespace.
///
/// The check is deliberately permissive. It does not require a dot in the domain,
/// so `user@localhost` passes. Only sending mail can prove an address is
/// deliverable.
fn is_valid_email(email: &str) -> bool {
    let email = email.trim();
    match email.split_once('@') {
        Some((local, domain)) => {
            !local.is_empty()
                && !domain.is_empty()
                && !domain.contains('@')
                && !email.chars().any(char::is_whitespace)
        }
        None => false,
    }
}
Error Handling
Comprehensive Error System
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use thiserror::Error;
/// Application-level error returned by handlers.
///
/// `impl IntoResponse for ApiError` maps each variant to an HTTP status:
/// - 400 `BadRequest`, 401 `Unauthorized`, 403 `Forbidden`, 404 `NotFound`, 409 `Conflict`;
/// - 422 `ValidationError`;
/// - 500 `InternalError` and `DatabaseError`;
/// - 503 `ServiceUnavailable`;
/// - `AuthError`: 401 for a missing/invalid token, 500 otherwise.
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Not found: {0}")]
NotFound(String),
#[error("Conflict: {0}")]
Conflict(String),
/// Request was well-formed but failed field validation.
#[error("Validation failed: {0}")]
ValidationError(String),
#[error("Internal server error: {0}")]
InternalError(String),
#[error("Service unavailable: {0}")]
ServiceUnavailable(String),
#[error("Database error: {0}")]
DatabaseError(String),
/// Wraps an `AuthError`; produced automatically by `?` via `#[from]`.
#[error("Authentication error: {0}")]
AuthError(#[from] AuthError),
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let (status_code, error_code, message) = match self {
ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "BAD_REQUEST", msg),
ApiError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg),
ApiError::Forbidden(msg) => (StatusCode::FORBIDDEN, "FORBIDDEN", msg),
ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, "NOT_FOUND", msg),
ApiError::Conflict(msg) => (StatusCode::CONFLICT, "CONFLICT", msg),
ApiError::ValidationError(msg) => (StatusCode::UNPROCESSABLE_ENTITY, "VALIDATION_ERROR", msg),
ApiError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, "INTERNAL_ERROR", msg),
ApiError::ServiceUnavailable(msg) => (StatusCode::SERVICE_UNAVAILABLE, "SERVICE_UNAVAILABLE", msg),
ApiError::DatabaseError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, "DATABASE_ERROR", msg),
ApiError::AuthError(auth_err) => {
match auth_err {
AuthError::MissingToken | AuthError::InvalidToken => {
(StatusCode::UNAUTHORIZED, "INVALID_TOKEN", auth_err.to_string())
}
_ => (StatusCode::INTERNAL_SERVER_ERROR, "AUTH_ERROR", auth_err.to_string())
}
}
};
let error_response = ErrorResponse {
success: false,
error: ErrorDetail {
code: error_code.to_string(),
message,
timestamp: Utc::now(),
},
};
(status_code, Json(error_response)).into_response()
}
}
/// JSON body sent for every `ApiError` response.
#[derive(Serialize)]
struct ErrorResponse {
/// Always `false`; mirrors `ApiResponse::success` so clients can branch on one field.
success: bool,
error: ErrorDetail,
}
/// Description of a failure inside `ErrorResponse`.
#[derive(Serialize)]
struct ErrorDetail {
/// Stable machine-readable code such as `"NOT_FOUND"` or `"VALIDATION_ERROR"`.
code: String,
/// Human-readable explanation.
message: String,
/// When the error response was built.
timestamp: DateTime<Utc>,
}
// Result type alias for convenience
/// Shorthand for results that fail with `ApiError`.
pub type ApiResult<T> = Result<T, ApiError>;
// Validation error handling
/// Checks a `CreatePostRequest` and reports every violated rule at once.
///
/// Rules:
/// - title: non-blank, at most 200 characters;
/// - content: non-blank, at most 50,000 characters;
/// - tags: at most 10; the first tag that is blank or longer than 50 characters
///   is reported.
///
/// Limits count Unicode scalar values (`chars()`), matching the "characters" in
/// the messages and the `validator` length rules on `CreatePostRequestValidated`.
/// Byte length would reject non-ASCII text early.
///
/// # Errors
/// `ApiError::ValidationError` whose message joins all violations with ", ".
pub fn validate_post_request(req: &CreatePostRequest) -> Result<(), ApiError> {
    const MAX_TITLE_CHARS: usize = 200;
    const MAX_CONTENT_CHARS: usize = 50_000;
    const MAX_TAGS: usize = 10;
    const MAX_TAG_CHARS: usize = 50;

    let mut errors: Vec<&str> = Vec::new();

    if req.title.trim().is_empty() {
        errors.push("Title cannot be empty");
    } else if req.title.chars().count() > MAX_TITLE_CHARS {
        errors.push("Title cannot exceed 200 characters");
    }

    if req.content.trim().is_empty() {
        errors.push("Content cannot be empty");
    } else if req.content.chars().count() > MAX_CONTENT_CHARS {
        errors.push("Content cannot exceed 50,000 characters");
    }

    if let Some(tags) = &req.tags {
        if tags.len() > MAX_TAGS {
            errors.push("Cannot have more than 10 tags");
        }
        // Report at most one per-tag problem: the first offending tag.
        for tag in tags {
            if tag.trim().is_empty() {
                errors.push("Tags cannot be empty");
                break;
            }
            if tag.chars().count() > MAX_TAG_CHARS {
                errors.push("Tags cannot exceed 50 characters");
                break;
            }
        }
    }

    if errors.is_empty() {
        Ok(())
    } else {
        Err(ApiError::ValidationError(errors.join(", ")))
    }
}
// Middleware for error logging
pub async fn error_handler(
uri: axum::http::Uri,
method: axum::http::Method,
error: ApiError,
) -> Response {
// Log error details
match &error {
ApiError::InternalError(_) | ApiError::DatabaseError(_) => {
tracing::error!(
method = %method,
uri = %uri,
error = %error,
"Internal server error occurred"
);
}
_ => {
tracing::warn!(
method = %method,
uri = %uri,
error = %error,
"Client error occurred"
);
}
}
error.into_response()
}
Input Validation
Request Validation with Validator
use validator::{Validate, ValidationError, ValidationErrors};
use axum::{
async_trait,
extract::{FromRequest, Request},
Json,
};
// Validated JSON extractor
pub struct ValidatedJson<T>(pub T);
#[async_trait]
impl<T, S> FromRequest<S> for ValidatedJson<T>
where
T: DeserializeOwned + Validate,
S: Send + Sync,
Json<T>: FromRequest<S>,
{
type Rejection = ApiError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let Json(value) = Json::<T>::from_request(req, state)
.await
.map_err(|err| ApiError::BadRequest(format!("Invalid JSON: {}", err)))?;
value.validate()
.map_err(|errors| ApiError::ValidationError(format_validation_errors(errors)))?;
Ok(ValidatedJson(value))
}
}
// Enhanced request DTOs with validation
/// Create-post payload whose rules are declared with `validator` attributes and
/// enforced by the `ValidatedJson` extractor.
///
/// NOTE(review): the length rules apply to the raw strings, while
/// `create_post_validated` stores them trimmed, so whitespace-only input needs a
/// separate check.
#[derive(Debug, Deserialize, Validate)]
pub struct CreatePostRequestValidated {
#[validate(length(min = 1, max = 200, message = "Title must be between 1 and 200 characters"))]
pub title: String,
#[validate(length(min = 1, max = 50000, message = "Content must be between 1 and 50,000 characters"))]
pub content: String,
pub published: Option<bool>,
/// At most 10 tags; each checked by `validate_tags`.
#[validate(length(max = 10, message = "Cannot have more than 10 tags"))]
#[validate(custom = "validate_tags")]
pub tags: Option<Vec<String>>,
}
#[derive(Debug, Deserialize, Validate)]
pub struct CreateUserRequestValidated {
#[validate(length(min = 3, max = 30, message = "Username must be between 3 and 30 characters"))]
#[validate(regex = "USERNAME_REGEX", message = "Username can only contain letters, numbers, and underscores"))]
pub username: String,
#[validate(email(message = "Invalid email address"))]
pub email: String,
#[validate(length(min = 1, max = 100, message = "Full name must be between 1 and 100 characters"))]
pub full_name: String,
#[validate(length(min = 8, message = "Password must be at least 8 characters"))]
#[validate(custom = "validate_password")]
pub password: String,
}
// Custom validators
fn validate_tags(tags: &[String]) -> Result<(), ValidationError> {
for tag in tags {
if tag.trim().is_empty() {
return Err(ValidationError::new("Tags cannot be empty"));
}
if tag.len() > 50 {
return Err(ValidationError::new("Tags cannot exceed 50 characters"));
}
if !tag.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_') {
return Err(ValidationError::new("Tags can only contain letters, numbers, hyphens, and underscores"));
}
}
Ok(())
}
fn validate_password(password: &str) -> Result<(), ValidationError> {
let has_lowercase = password.chars().any(|c| c.is_lowercase());
let has_uppercase = password.chars().any(|c| c.is_uppercase());
let has_digit = password.chars().any(|c| c.is_numeric());
let has_special = password.chars().any(|c| !c.is_alphanumeric());
if !has_lowercase {
return Err(ValidationError::new("Password must contain at least one lowercase letter"));
}
if !has_uppercase {
return Err(ValidationError::new("Password must contain at least one uppercase letter"));
}
if !has_digit {
return Err(ValidationError::new("Password must contain at least one digit"));
}
if !has_special {
return Err(ValidationError::new("Password must contain at least one special character"));
}
Ok(())
}
// Regex for username validation
// Compiled once on first use; referenced by name from the `regex` rule on
// `CreateUserRequestValidated::username`. Allows ASCII letters, digits and `_` only.
// The `unwrap` can only fail if this literal pattern is invalid.
lazy_static::lazy_static! {
static ref USERNAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_]+$").unwrap();
}
// Format validation errors into user-friendly messages
/// Flattens field-level validation errors into one user-facing string.
///
/// Fields are visited in name order and their messages joined with ", ". An error
/// without a message is reported as "Invalid value for field '<name>'".
fn format_validation_errors(errors: ValidationErrors) -> String {
    // `field_errors()` returns a `HashMap` whose iteration order is unspecified.
    // Sort by field name so the same invalid payload always yields the same text.
    let mut fields: Vec<_> = errors.field_errors().into_iter().collect();
    fields.sort_unstable_by(|a, b| a.0.cmp(&b.0));

    let mut messages = Vec::new();
    for (field, field_errors) in fields {
        for error in field_errors {
            let message = error
                .message
                .as_ref()
                .map(|m| m.to_string())
                .unwrap_or_else(|| format!("Invalid value for field '{}'", field));
            messages.push(message);
        }
    }
    messages.join(", ")
}
// Updated handler using validated input
pub async fn create_post_validated(
State(state): State<AppState>,
ValidatedJson(payload): ValidatedJson<CreatePostRequestValidated>,
current_user: CurrentUser,
) -> Result<Json<ApiResponse<Post>>, ApiError> {
let mut posts = state.db.posts.write().await;
let now = Utc::now();
let post = Post {
id: Uuid::new_v4(),
title: payload.title.trim().to_string(),
content: payload.content.trim().to_string(),
author_id: current_user.id,
published: payload.published.unwrap_or(false),
created_at: now,
updated_at: now,
tags: payload.tags.unwrap_or_default(),
};
posts.insert(post.id, post.clone());
Ok(Json(ApiResponse::success(post)))
}
API Testing
Unit and Integration Tests
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
http::{Request, StatusCode},
};
use tower::ServiceExt; // for `oneshot`
use serde_json::json;
// Test setup helpers
async fn create_test_app() -> Router {
let db = Arc::new(Database::default());
let auth = Arc::new(AuthService::new("test-secret"));
let config = Arc::new(Config::default());
let state = AppState { db, auth, config };
create_router(state)
}
async fn create_test_user(app_state: &AppState) -> (Uuid, String) {
let user = User {
id: Uuid::new_v4(),
username: "testuser".to_string(),
email: "[email protected]".to_string(),
full_name: "Test User".to_string(),
password_hash: "password".to_string(), // Not secure, just for testing
avatar_url: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let token = app_state.auth.create_token(&user).unwrap();
let user_id = user.id;
app_state.db.users.write().await.insert(user.id, user);
(user_id, token)
}
#[tokio::test]
async fn test_create_post() {
let app = create_test_app().await;
let app_state = app.layer_stack().get::<AppState>().unwrap();
let (_user_id, token) = create_test_user(app_state).await;
let post_data = json!({
"title": "Test Post",
"content": "This is a test post content",
"published": true,
"tags": ["rust", "web"]
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/posts")
.header("authorization", format!("Bearer {}", token))
.header("content-type", "application/json")
.body(Body::from(post_data.to_string()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let response_data: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(response_data["success"], true);
assert_eq!(response_data["data"]["title"], "Test Post");
assert_eq!(response_data["data"]["published"], true);
}
#[tokio::test]
async fn test_get_posts_pagination() {
let app = create_test_app().await;
let app_state = app.layer_stack().get::<AppState>().unwrap();
let (user_id, _token) = create_test_user(app_state).await;
// Create test posts
let mut posts = app_state.db.posts.write().await;
for i in 0..25 {
let post = Post {
id: Uuid::new_v4(),
title: format!("Test Post {}", i),
content: format!("Content for post {}", i),
author_id: user_id,
published: true,
created_at: Utc::now(),
updated_at: Utc::now(),
tags: vec![],
};
posts.insert(post.id, post);
}
drop(posts);
// Test pagination
let response = app
.oneshot(
Request::builder()
.method("GET")
.uri("/api/posts?page=2&per_page=10")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let response_data: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(response_data["data"].as_array().unwrap().len(), 10);
assert_eq!(response_data["pagination"]["page"], 2);
assert_eq!(response_data["pagination"]["per_page"], 10);
assert_eq!(response_data["pagination"]["total"], 25);
assert_eq!(response_data["pagination"]["total_pages"], 3);
}
#[tokio::test]
async fn test_authentication_required() {
let app = create_test_app().await;
let post_data = json!({
"title": "Test Post",
"content": "This should fail without auth"
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/posts")
.header("content-type", "application/json")
.body(Body::from(post_data.to_string()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
#[tokio::test]
async fn test_validation_errors() {
let app = create_test_app().await;
let app_state = app.layer_stack().get::<AppState>().unwrap();
let (_user_id, token) = create_test_user(app_state).await;
// Test with empty title
let invalid_post = json!({
"title": "",
"content": "This has an empty title"
});
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/posts")
.header("authorization", format!("Bearer {}", token))
.header("content-type", "application/json")
.body(Body::from(invalid_post.to_string()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let response_data: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(response_data["success"], false);
assert!(response_data["error"]["message"].as_str().unwrap().contains("Title"));
}
/// Registering a new, unique user succeeds with 200 and echoes the created
/// user's `username` and `email` back under `data`.
#[tokio::test]
async fn test_user_registration() {
    let app = create_test_app().await;
    // The email must be a syntactically valid address (contain `@`); the
    // register endpoint presumably validates it, and an invalid address would
    // make this test exercise the validation-error path instead.
    let user_data = json!({
        "username": "newuser",
        "email": "newuser@example.com",
        "full_name": "New User",
        "password": "SecurePass123!"
    });
    let response = app
        .oneshot(
            Request::builder()
                .method("POST")
                .uri("/api/auth/register")
                .header("content-type", "application/json")
                .body(Body::from(user_data.to_string()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_data: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(response_data["success"], true);
    assert_eq!(response_data["data"]["username"], "newuser");
    assert_eq!(response_data["data"]["email"], "newuser@example.com");
}
/// Registering with a username that already exists (created by
/// `create_test_user` as "testuser") must fail with 409 and an
/// "Username already exists" message.
#[tokio::test]
async fn test_duplicate_username_registration() {
    let app = create_test_app().await;
    let app_state = app.layer_stack().get::<AppState>().unwrap();
    let (_user_id, _token) = create_test_user(app_state).await;
    // Try to register with the existing username but a different, valid email,
    // so the request reaches the username-conflict check rather than failing
    // email validation.
    let duplicate_user = json!({
        "username": "testuser", // This already exists
        "email": "different@example.com",
        "full_name": "Different User",
        "password": "SecurePass123!"
    });
    let response = app
        .oneshot(
            Request::builder()
                .method("POST")
                .uri("/api/auth/register")
                .header("content-type", "application/json")
                .body(Body::from(duplicate_user.to_string()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(response.status(), StatusCode::CONFLICT);
    let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let response_data: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(response_data["success"], false);
    assert!(response_data["error"]["message"]
        .as_str()
        .unwrap()
        .contains("Username already exists"));
}
// Performance test example
/// Spawns 100 authenticated create-post requests concurrently, then checks
/// that every one returned 200 and that exactly 100 posts were stored.
#[tokio::test]
async fn test_concurrent_post_creation() {
    let app = create_test_app().await;
    let app_state = app.layer_stack().get::<AppState>().unwrap();
    let (_user_id, token) = create_test_user(app_state).await;

    // `oneshot` consumes the service, so each task gets its own router clone.
    let mut handles = Vec::with_capacity(100);
    for i in 0..100 {
        let app = app.clone();
        let auth_header = format!("Bearer {}", token);
        handles.push(tokio::spawn(async move {
            let payload = json!({
                "title": format!("Concurrent Post {}", i),
                "content": format!("Content for concurrent post {}", i),
                "published": true
            });
            let request = Request::builder()
                .method("POST")
                .uri("/api/posts")
                .header("authorization", auth_header)
                .header("content-type", "application/json")
                .body(Body::from(payload.to_string()))
                .unwrap();
            app.oneshot(request).await.unwrap().status()
        }));
    }

    // Every task must have completed without panicking and returned 200.
    for joined in futures::future::join_all(handles).await {
        assert_eq!(joined.unwrap(), StatusCode::OK);
    }

    // Every request must have produced a stored post.
    let stored = app_state.db.posts.read().await;
    assert_eq!(stored.len(), 100);
}
}
// Load testing with criterion
//
// NOTE(review): criterion benchmarks are normally placed under `benches/` with
// `harness = false` in Cargo.toml; inside a `#[cfg(test)]` module the `main`
// generated by `criterion_main!` is not what `cargo test` runs. The helpers used
// below (`create_test_app`, `create_test_user`, `json!`, `Request`, `Body`,
// `oneshot`) must be reachable through `super::*` — confirm.
#[cfg(test)]
mod bench {
    use super::*;
    use criterion::{criterion_group, criterion_main, Criterion};

    /// Measures a single authenticated `POST /api/posts` round-trip.
    ///
    /// The app and the test user are created once, outside the timed closure.
    /// Building them on every iteration would make the benchmark mostly measure
    /// app construction and user creation rather than post creation.
    fn bench_post_creation(c: &mut Criterion) {
        let rt = tokio::runtime::Runtime::new().unwrap();

        // One-time, unmeasured setup.
        let (app, token) = rt.block_on(async {
            let app = create_test_app().await;
            let app_state = app.layer_stack().get::<AppState>().unwrap();
            let (_user_id, token) = create_test_user(app_state).await;
            (app, token)
        });

        c.bench_function("create_post", |b| {
            b.to_async(&rt).iter(|| {
                // `oneshot` consumes the service, so each iteration works on
                // its own clone of the prepared router.
                let app = app.clone();
                let auth_header = format!("Bearer {}", token);
                async move {
                    let post_data = json!({
                        "title": "Benchmark Post",
                        "content": "Benchmark content",
                        "published": true
                    });
                    app.oneshot(
                        Request::builder()
                            .method("POST")
                            .uri("/api/posts")
                            .header("authorization", auth_header)
                            .header("content-type", "application/json")
                            .body(Body::from(post_data.to_string()))
                            .unwrap(),
                    )
                    .await
                    .unwrap()
                }
            });
        });
    }

    criterion_group!(benches, bench_post_creation);
    criterion_main!(benches);
}
API Documentation
OpenAPI/Swagger Documentation
use utoipa::{OpenApi, ToSchema};
use utoipa_swagger_ui::SwaggerUi;
// OpenAPI document for the blog API, generated by utoipa's `OpenApi` derive.
// The `bearer_auth` scheme named in `security(...)` is not declared here; it is
// registered at runtime by `ApiDoc::with_security`.
//
// NOTE(review): the top-level `security(("bearer_auth" = []))` makes bearer auth
// the default requirement for every documented operation, including `login` and
// `register`, which presumably should be public — confirm and override per
// operation if so.
// NOTE(review): the contact email was reconstructed from an email-obfuscation
// artifact in the source text; confirm the real support address.
#[derive(OpenApi)]
#[openapi(
    paths(
        list_posts,
        get_post,
        create_post,
        update_post,
        delete_post,
        login,
        register,
        get_current_user
    ),
    components(
        schemas(
            Post,
            User,
            CreatePostRequest,
            UpdatePostRequest,
            LoginRequest,
            LoginResponse,
            ApiResponse<Post>,
            PaginatedResponse<Post>,
            ErrorResponse
        )
    ),
    tags(
        (name = "posts", description = "Blog post management"),
        (name = "auth", description = "Authentication endpoints"),
        (name = "users", description = "User management")
    ),
    info(
        title = "Blog API",
        description = "A REST API for a blog application built with Rust and Axum",
        version = "1.0.0",
        contact(
            name = "API Support",
            email = "support@blogapi.com"
        ),
        license(
            name = "MIT",
            url = "https://opensource.org/licenses/MIT"
        )
    ),
    servers(
        (url = "http://localhost:3000", description = "Local development server"),
        (url = "https://api.blogapi.com", description = "Production server")
    ),
    security(
        ("bearer_auth" = [])
    )
)]
pub struct ApiDoc;
// Add schemas for documentation
//
// `ToSchema` lets this type appear in the OpenAPI `components(schemas(...))`
// list; each `#[schema(example = ...)]` value is shown as the field's example in
// the generated document. `Debug` and `Clone` are derived to match the core
// `Post` model, so the documented type is interchangeable with it in handlers
// and tests.
#[derive(Debug, Clone, ToSchema, Serialize, Deserialize)]
pub struct Post {
    #[schema(example = "01234567-89ab-cdef-0123-456789abcdef")]
    pub id: Uuid,
    #[schema(example = "My First Blog Post")]
    pub title: String,
    #[schema(example = "This is the content of my first blog post...")]
    pub content: String,
    #[schema(example = "01234567-89ab-cdef-0123-456789abcdef")]
    pub author_id: Uuid,
    #[schema(example = true)]
    pub published: bool,
    #[schema(example = "2023-01-01T00:00:00Z")]
    pub created_at: DateTime<Utc>,
    #[schema(example = "2023-01-01T00:00:00Z")]
    pub updated_at: DateTime<Utc>,
    #[schema(example = json!(["rust", "web", "api"]))]
    pub tags: Vec<String>,
}
// Document API endpoints
//
// Handler for `GET /api/posts` (wired up in `create_router`). The
// `#[utoipa::path]` attribute only describes the endpoint for the OpenAPI
// document; it does not affect routing.
#[utoipa::path(
    get,
    path = "/api/posts",
    tag = "posts",
    summary = "List blog posts",
    description = "Retrieve a paginated list of blog posts with optional filtering",
    params(
        ("page" = Option<u32>, Query, description = "Page number (default: 1)"),
        ("per_page" = Option<u32>, Query, description = "Items per page (default: 10, max: 100)"),
        ("published" = Option<bool>, Query, description = "Filter by published status"),
        ("author_id" = Option<Uuid>, Query, description = "Filter by author ID"),
        ("tag" = Option<String>, Query, description = "Filter by tag"),
        ("search" = Option<String>, Query, description = "Search in title and content")
    ),
    responses(
        (status = 200, description = "List of posts retrieved successfully", body = PaginatedResponse<Post>),
        (status = 400, description = "Invalid query parameters", body = ErrorResponse)
    )
)]
pub async fn list_posts(
    Query(params): Query<PostQuery>,
    State(state): State<AppState>,
) -> Result<Json<PaginatedResponse<Post>>, ApiError> {
    // Documentation skeleton only. `todo!()` keeps it type-checking: a body
    // consisting solely of a comment evaluates to `()`, which does not match
    // the declared `Result` return type.
    todo!("list posts filtered by `params` using `state`")
}
// Handler for `POST /api/posts` (wired up in `create_router`); requires a
// bearer token, as declared by the `security` entry below.
#[utoipa::path(
    post,
    path = "/api/posts",
    tag = "posts",
    summary = "Create a new blog post",
    description = "Create a new blog post. Requires authentication.",
    request_body = CreatePostRequest,
    responses(
        (status = 200, description = "Post created successfully", body = ApiResponse<Post>),
        (status = 400, description = "Invalid input", body = ErrorResponse),
        (status = 401, description = "Authentication required", body = ErrorResponse),
        (status = 422, description = "Validation failed", body = ErrorResponse)
    ),
    security(
        ("bearer_auth" = [])
    )
)]
pub async fn create_post(
    State(state): State<AppState>,
    current_user: CurrentUser,
    // axum allows at most one body-consuming extractor and requires it to be
    // the LAST handler argument; with `Json` before `CurrentUser` the function
    // does not implement `Handler` and cannot be passed to `post(...)`.
    Json(payload): Json<CreatePostRequest>,
) -> Result<Json<ApiResponse<Post>>, ApiError> {
    // Documentation skeleton only. `todo!()` keeps it type-checking: a body
    // consisting solely of a comment evaluates to `()`, which does not match
    // the declared `Result` return type.
    todo!("create a post from `payload` for `current_user` using `state`")
}
// Add security scheme
impl ApiDoc {
    /// Returns the generated OpenAPI document with an HTTP bearer (JWT)
    /// security scheme registered under the name `bearer_auth`.
    ///
    /// The scheme is added to the document's *existing* components, so the
    /// schemas collected by `#[openapi(components(schemas(...)))]` stay in the
    /// document. Replacing `components` wholesale would drop them and leave every
    /// `$ref` to `Post`, `User`, etc. dangling in Swagger UI.
    pub fn with_security() -> utoipa::openapi::OpenApi {
        use utoipa::openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme};

        let mut openapi = Self::openapi();
        openapi
            .components
            // Normally present (schemas are declared); create it if not.
            .get_or_insert_with(|| utoipa::openapi::ComponentsBuilder::new().build())
            .add_security_scheme(
                "bearer_auth",
                SecurityScheme::Http(
                    HttpBuilder::new()
                        .scheme(HttpAuthScheme::Bearer)
                        .bearer_format("JWT")
                        .build(),
                ),
            );
        openapi
    }
}
// Router setup with Swagger UI
/// Builds the application router: blog-post and authentication routes plus
/// Swagger UI at `/swagger-ui` serving the spec from `/api-docs/openapi.json`,
/// all sharing `state`.
pub fn create_router(state: AppState) -> Router {
    let docs = SwaggerUi::new("/swagger-ui")
        .url("/api-docs/openapi.json", ApiDoc::with_security());

    Router::new()
        // Blog posts
        .route("/api/posts", get(list_posts).post(create_post))
        .route(
            "/api/posts/:id",
            get(get_post).put(update_post).delete(delete_post),
        )
        // Authentication
        .route("/api/auth/login", post(login))
        .route("/api/auth/register", post(register))
        .route("/api/auth/me", get(get_current_user))
        // API documentation
        .merge(docs)
        .with_state(state)
}
Performance and Deployment
Production Configuration
// Configuration management
/// Top-level application configuration, produced by `Config::from_env`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// HTTP server settings (bind address, request timeout).
    pub server: ServerConfig,
    /// Settings passed to `create_database_pool` in `main`.
    pub database: DatabaseConfig,
    /// Authentication settings; `main` passes `jwt_secret` to `AuthService::new`.
    pub auth: AuthConfig,
    /// Log filter and output-format settings used by `setup_logging`.
    pub logging: LoggingConfig,
    /// CORS policy used by `setup_cors`.
    pub cors: CorsConfig,
}
/// HTTP server settings.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ServerConfig {
    /// Host part of the bind address (`main` binds to `"{host}:{port}"`).
    pub host: String,
    /// Port part of the bind address.
    pub port: u16,
    // NOTE(review): `workers` and `max_connections` are not read by `main` in
    // this section (`#[tokio::main]` is used without arguments) — confirm they
    // are consumed elsewhere.
    pub workers: Option<usize>,
    pub max_connections: usize,
    /// Request timeout in seconds, applied by `main` through `TimeoutLayer`.
    pub timeout_seconds: u64,
}
/// Database settings, passed as a whole to `create_database_pool` by `main`.
///
/// NOTE(review): field semantics below are inferred from names; confirm
/// against `create_database_pool`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct DatabaseConfig {
    /// Database connection URL (presumably including credentials — keep it out of logs).
    pub url: String,
    /// Presumably the upper bound on pooled connections.
    pub max_connections: u32,
    /// Presumably the number of connections kept open at minimum.
    pub min_connections: u32,
    /// Presumably the timeout, in seconds, for establishing a connection.
    pub connect_timeout_seconds: u64,
    /// Presumably how long, in seconds, an idle connection is kept before closing.
    pub idle_timeout_seconds: u64,
}
/// Authentication settings.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthConfig {
    /// Secret handed to `AuthService::new` by `main` (presumably the JWT signing key).
    pub jwt_secret: String,
    // NOTE(review): the two expiry settings are not passed to `AuthService::new`
    // in `main` — confirm where they are applied.
    pub token_expiry_hours: u64,
    pub refresh_token_expiry_days: u64,
}
/// Logging settings consumed by `setup_logging`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LoggingConfig {
    /// `EnvFilter` directive string used when the default filter environment
    /// variable (`RUST_LOG`) is unset or invalid.
    pub level: String,
    /// When `true`, logs are emitted as JSON; otherwise in the default text format.
    pub json_format: bool,
}
/// CORS policy consumed by `setup_cors`.
///
/// A literal `"*"` entry in `allowed_origins` or `allowed_headers` is treated
/// as "allow any". Otherwise every entry in a list must parse; if any entry
/// fails to parse, `setup_cors` leaves that whole list unset.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct CorsConfig {
    /// Allowed request origins, e.g. `https://app.example.com`, or `"*"`.
    pub allowed_origins: Vec<String>,
    /// Allowed HTTP method names, parsed as `http::Method`.
    pub allowed_methods: Vec<String>,
    /// Allowed request header names, or `"*"`.
    pub allowed_headers: Vec<String>,
    /// Preflight cache lifetime in seconds (sent as the CORS max-age).
    pub max_age_seconds: usize,
}
impl Config {
    /// Loads configuration from layered sources; later sources override earlier ones:
    ///
    /// 1. `config/default` (required),
    /// 2. `config/production` (optional),
    /// 3. environment variables prefixed with `API`.
    ///
    /// Nested keys are addressed in environment variables with a double
    /// underscore after the `API_` prefix. For example, `API_SERVER__PORT=8080`
    /// overrides `server.port`, and `API_AUTH__JWT_SECRET` overrides
    /// `auth.jwt_secret`. A single `_` cannot serve as the nesting separator
    /// because field names such as `jwt_secret` already contain underscores.
    ///
    /// # Errors
    /// Returns `config::ConfigError` if the required default file is missing,
    /// a source fails to parse, or the merged values do not match `Config`.
    pub fn from_env() -> Result<Self, config::ConfigError> {
        let settings = config::Config::builder()
            .add_source(config::File::with_name("config/default"))
            .add_source(config::File::with_name("config/production").required(false))
            // Without a key separator every env var becomes a flat top-level key
            // (`API_SERVER_PORT` -> `server_port`), which no field reads, so
            // overrides of nested settings would be silently ignored.
            .add_source(
                config::Environment::with_prefix("API")
                    .prefix_separator("_")
                    .separator("__"),
            )
            .build()?;
        settings.try_deserialize()
    }
}
// Production server setup
/// Application entry point.
///
/// Loads configuration, installs logging, builds the database pool and auth
/// service, assembles the router with its middleware stack, and serves it
/// until `shutdown_signal` resolves.
///
/// # Errors
/// Propagates failures from configuration loading, logging setup, database
/// pool creation, binding the listener, or the server itself.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load configuration
    let config = Config::from_env()?;
    // Setup logging
    setup_logging(&config.logging)?;
    // Setup database
    let db_pool = create_database_pool(&config.database).await?;
    // Setup authentication
    let auth_service = Arc::new(AuthService::new(&config.auth.jwt_secret));
    // Create application state
    let app_state = AppState {
        db: Arc::new(Database::new(db_pool)),
        auth: auth_service,
        config: Arc::new(config.clone()),
    };
    // Setup CORS
    let cors = setup_cors(&config.cors);
    // Middleware stack. Each `.layer()` wraps everything added before it, so the
    // layer added last is the outermost and sees each request first:
    //
    //   request -> Trace -> CORS -> Timeout -> Compression -> routes
    //
    // Timeout sits inside CORS and tracing so the 408 responses it generates
    // still get CORS headers (otherwise browsers report an opaque CORS failure
    // instead of the timeout) and are recorded by the trace layer.
    let app = create_router(app_state)
        .layer(CompressionLayer::new())
        .layer(TimeoutLayer::new(Duration::from_secs(
            config.server.timeout_seconds,
        )))
        .layer(cors)
        .layer(TraceLayer::new_for_http());
    // Start server
    let bind_address = format!("{}:{}", config.server.host, config.server.port);
    let listener = tokio::net::TcpListener::bind(&bind_address).await?;
    tracing::info!("Server starting on {}", bind_address);
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}
/// Installs the global `tracing` subscriber.
///
/// The filter is taken from the default filter environment variable
/// (`RUST_LOG`) when it is set and valid; otherwise from `config.level`.
/// `config.json_format` selects JSON or human-readable output.
///
/// # Errors
/// Returns an error if a global subscriber is already installed. `try_init`
/// is used so this surfaces through the returned `Result` rather than as a
/// panic from `init`.
fn setup_logging(config: &LoggingConfig) -> Result<(), Box<dyn std::error::Error>> {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(&config.level));
    let registry = tracing_subscriber::registry().with(env_filter);

    // The two formatters have different concrete types, so each branch
    // finishes building and installing its own subscriber.
    if config.json_format {
        registry
            .with(tracing_subscriber::fmt::layer().json())
            .try_init()?;
    } else {
        registry.with(tracing_subscriber::fmt::layer()).try_init()?;
    }
    Ok(())
}
fn setup_cors(config: &CorsConfig) -> CorsLayer {
use tower_http::cors::{Any, CorsLayer};
use http::{HeaderValue, Method};
let mut cors = CorsLayer::new();
// Setup allowed origins
if config.allowed_origins.contains(&"*".to_string()) {
cors = cors.allow_origin(Any);
} else {
let origins: Result<Vec<HeaderValue>, _> = config
.allowed_origins
.iter()
.map(|origin| origin.parse())
.collect();
if let Ok(origins) = origins {
cors = cors.allow_origin(origins);
}
}
// Setup allowed methods
let methods: Result<Vec<Method>, _> = config
.allowed_methods
.iter()
.map(|method| method.parse())
.collect();
if let Ok(methods) = methods {
cors = cors.allow_methods(methods);
}
// Setup allowed headers
if config.allowed_headers.contains(&"*".to_string()) {
cors = cors.allow_headers(Any);
} else {
let headers: Result<Vec<HeaderValue>, _> = config
.allowed_headers
.iter()
.map(|header| header.parse())
.collect();
if let Ok(headers) = headers {
cors = cors.allow_headers(headers);
}
}
cors.max_age(Duration::from_secs(config.max_age_seconds as u64))
}
// Graceful shutdown
/// Completes when the process receives Ctrl+C or, on Unix, SIGTERM.
///
/// `main` passes this future to `with_graceful_shutdown` as the trigger for
/// stopping the server.
///
/// # Panics
/// Panics if the corresponding signal handler cannot be installed.
async fn shutdown_signal() {
    use tokio::signal;

    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    // There is no SIGTERM outside Unix: substitute a future that never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first wins; the other future is dropped.
    tokio::select! {
        () = interrupt => tracing::info!("Received Ctrl+C signal"),
        () = terminate => tracing::info!("Received terminate signal"),
    }

    tracing::info!("Starting graceful shutdown");
}
Best Practices Summary
- RESTful Design: Follow REST principles for resource design and HTTP methods
- Input Validation: Validate all input data with clear error messages
- Authentication: Implement secure JWT-based authentication
- Error Handling: Provide consistent, informative error responses
- Documentation: Use OpenAPI/Swagger for API documentation
- Testing: Write comprehensive unit and integration tests
- Pagination: Implement pagination for list endpoints
- Rate Limiting: Protect against abuse with rate limiting
- Logging: Implement structured logging for monitoring
- Security: Follow security best practices (HTTPS, CORS, validation)
- Performance: Use connection pooling and async processing
- Monitoring: Implement health checks and metrics
Building REST APIs in Rust provides excellent performance, safety, and maintainability. The strong type system helps catch errors at compile time, while the async ecosystem enables handling many concurrent connections efficiently. Always prioritize security, testing, and documentation for production APIs.