# Request Validation
FastAPI Shield provides powerful mechanisms to validate incoming requests beyond what Pydantic offers, allowing for complex validation rules, custom error messaging, and request transformation.
## Basic Request Validation
from fastapi import FastAPI, Depends, HTTPException
from typing import NewType, Annotated, List, Dict, Any
from pydantic import BaseModel, Field, EmailStr, validator
from fastapi_shield import shield, Shield
app = FastAPI()


# Request body for a new user sign-up. Field-level constraints (lengths,
# email syntax) are enforced by Pydantic before any shield runs.
class UserRegistration(BaseModel):
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr
    password: str = Field(..., min_length=8)
    confirm_password: str

    @validator('confirm_password')
    def passwords_match(cls, v, values):
        """Reject the model when ``confirm_password`` differs from ``password``.

        ``values`` holds the fields validated so far. ``password`` may be
        missing from it (e.g. when it failed its own constraints); in that
        case the comparison is skipped.
        """
        # Defaulting to ``v`` makes the comparison a no-op when no password is present.
        expected = values.get('password', v)
        if v != expected:
            raise ValueError('Passwords do not match')
        return v


# Marker type: a UserRegistration that passed the registration shield.
ValidatedRegistration = NewType("ValidatedRegistration", UserRegistration)
# Shield for checks Pydantic cannot express by itself, such as uniqueness
# against stored data.
@shield
def validate_registration(registration: UserRegistration) -> ValidatedRegistration:
    """Reject registrations whose username or email is already taken.

    The two lookups below are mock stand-ins for database queries.

    Raises:
        HTTPException: 400 if the username or the email is already in use.
    """
    # Mock check: username already exists.
    if registration.username == "existing_user":
        raise HTTPException(status_code=400, detail="Username already exists")
    # Mock check: email already registered. Compare case-insensitively so that
    # "Existing@Example.com" cannot register a second account for the same mailbox
    # (domain names are case-insensitive, and mail providers ignore local-part case).
    if str(registration.email).lower() == "existing@example.com":
        raise HTTPException(status_code=400, detail="Email already registered")
    # Additional complex validation can be performed here
    return ValidatedRegistration(registration)
# Endpoint: the dependency runs validate_registration before this body.
@app.post("/register")
def register_user(registration: ValidatedRegistration = Depends(validate_registration)):
    """Accept a registration that passed ``validate_registration``."""
    # Persisting the new user would happen here.
    response = {"message": "User registered successfully"}
    return response
## Input Sanitization
Use shields to sanitize input data before processing:
from fastapi import FastAPI, Depends
from typing import NewType, Annotated, Dict, Any
from pydantic import BaseModel
from fastapi_shield import shield, Shield
import re
import html
app = FastAPI()


# Incoming comment payload.
class Comment(BaseModel):
    user_id: int
    content: str


# Marker type: a Comment whose content was escaped and whitespace-normalized.
SanitizedComment = NewType("SanitizedComment", Comment)
# Shield that cleans comment content before the endpoint sees it.
@shield
def sanitize_comment(comment: Comment) -> SanitizedComment:
    """Return a new Comment whose content is safe to render.

    HTML special characters are escaped (XSS mitigation). Every run of
    whitespace, newlines included, is then collapsed to a single space, and
    the ends are trimmed. ``user_id`` is carried over unchanged.
    """
    safe_text = html.escape(comment.content)
    safe_text = re.sub(r'\s+', ' ', safe_text).strip()
    return SanitizedComment(Comment(user_id=comment.user_id, content=safe_text))
# Endpoint: receives the comment only after sanitize_comment has run.
@app.post("/comments")
def create_comment(comment: SanitizedComment = Depends(sanitize_comment)):
    """Store a sanitized comment and echo it back."""
    # Storage is omitted in this example.
    return dict(message="Comment created", comment=comment)
## Request Transformation
Shields can transform requests before they reach your business logic:
from fastapi import FastAPI, Depends, Query
from typing import NewType, Annotated, List, Dict, Any, Optional
from pydantic import BaseModel
from fastapi_shield import shield, Shield
app = FastAPI()


# Normalized search parameters handed to the endpoint. The shield below builds
# them from raw query-string values; this is a request-side model, not an API
# response model.
class SearchParams(BaseModel):
    query: str
    page: int = 1
    per_page: int = 10
    sort_by: Optional[str] = None
    # Pydantic copies mutable field defaults per instance, so this {} is not shared.
    filters: Dict[str, Any] = {}


# Marker type: SearchParams produced by normalize_search_params.
NormalizedSearchParams = NewType("NormalizedSearchParams", SearchParams)
# Shield that turns loose query-string parameters into one SearchParams object.
@shield
def normalize_search_params(
    query: str,
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    sort_by: Optional[str] = None,
    category: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
) -> NormalizedSearchParams:
    """Fold raw search parameters into a single ``SearchParams`` object.

    * ``query`` is lower-cased and trimmed.
    * ``page`` and ``per_page`` are range-checked by FastAPI via ``Query``.
    * A non-empty ``category`` and any supplied price bounds are gathered into
      ``filters``, e.g. ``{"category": "books", "price": {"min": 5.0}}``.
    """
    filters = {}
    if category:
        filters["category"] = category

    # Keep only the bounds the caller actually supplied, in min/max order.
    bounds = {
        key: value
        for key, value in (("min", min_price), ("max", max_price))
        if value is not None
    }
    if bounds:
        filters["price"] = bounds

    return NormalizedSearchParams(
        SearchParams(
            query=query.lower().strip(),
            page=page,
            per_page=per_page,
            sort_by=sort_by,
            filters=filters,
        )
    )
# Endpoint: receives parameters already normalized by the shield.
@app.get("/search")
def search(params: NormalizedSearchParams = Depends(normalize_search_params)):
    """Echo the normalized parameters alongside placeholder results."""
    # In a real application, you would query a database
    placeholder_results = [
        {"id": 1, "name": "Example result 1"},
        {"id": 2, "name": "Example result 2"},
    ]
    response = {
        field: getattr(params, field)
        for field in ("query", "page", "per_page", "sort_by", "filters")
    }
    response["results"] = placeholder_results
    return response
## Contextual Validation
Validation that depends on the current state of the application or external resources:
from fastapi import FastAPI, Depends, HTTPException, Header
from typing import NewType, Annotated, Optional
from pydantic import BaseModel
from fastapi_shield import shield, Shield
import time
from datetime import datetime, timedelta
app = FastAPI()


# Payload for an operation that must be submitted promptly and only once.
class TimeSensitiveOperation(BaseModel):
    operation_id: str
    timestamp: int  # seconds since the Unix epoch


# Marker type: an operation that passed the freshness and replay checks.
ValidatedOperation = NewType("ValidatedOperation", TimeSensitiveOperation)
# Mock replay guard: maps operation_id -> that operation's own timestamp.
# (In a real app, use a shared store with TTLs instead of process memory.)
recent_operations = {}


# Shield to validate the operation
@shield
def validate_operation(
    operation: TimeSensitiveOperation,
    x_request_timestamp: Optional[str] = Header(None)
) -> ValidatedOperation:
    """Reject stale, future-dated, or replayed operations.

    Args:
        operation: The submitted operation.
        x_request_timestamp: Optional ``X-Request-Timestamp`` header. It is
            accepted for illustration but not used by the checks below.

    Raises:
        HTTPException: 400 if the timestamp is outside the accepted window,
            or if the operation id was already seen within that window.
    """
    current_time = int(time.time())
    max_age = 300  # 5 minutes in seconds
    max_clock_skew = 30  # tolerate small client clock drift into the future

    if operation.timestamp < current_time - max_age:
        raise HTTPException(
            status_code=400,
            detail="Operation timestamp too old"
        )
    # Without an upper bound, a far-future timestamp would stay "fresh"
    # indefinitely and defeat the age limit above.
    if operation.timestamp > current_time + max_clock_skew:
        raise HTTPException(
            status_code=400,
            detail="Operation timestamp is in the future"
        )

    # Validate operation hasn't been processed before (prevent replay)
    if operation.operation_id in recent_operations:
        raise HTTPException(
            status_code=400,
            detail="Operation already processed"
        )

    # Bound memory by forgetting only ids whose operations can no longer pass
    # the age check. Clearing everything would let any still-fresh operation
    # be replayed right after the clear.
    if len(recent_operations) > 1000:
        cutoff = current_time - max_age
        expired = [op_id for op_id, ts in recent_operations.items() if ts < cutoff]
        for op_id in expired:
            del recent_operations[op_id]

    recent_operations[operation.operation_id] = operation.timestamp
    return ValidatedOperation(operation)
# Endpoint: only reached for operations accepted by validate_operation.
@app.post("/operations")
def perform_operation(operation: ValidatedOperation = Depends(validate_operation)):
    """Run an operation that passed the freshness and replay checks."""
    op_id = operation.operation_id
    return {"status": "success", "operation_id": op_id}
## Combining Multiple Validations
You can chain shields to apply multiple validations:
from fastapi import FastAPI, Depends, HTTPException
from typing import NewType, Annotated, List, Dict, Any
from pydantic import BaseModel, Field
from fastapi_shield import shield, Shield
app = FastAPI()


# Incoming product payload.
class Product(BaseModel):
    name: str
    description: str
    price: float
    stock: int
    categories: List[str]


# One marker type per pipeline stage. Each wraps the previous stage's type, so
# an annotation records how far a product has been processed.
ValidatedProduct = NewType("ValidatedProduct", Product)
SanitizedProduct = NewType("SanitizedProduct", ValidatedProduct)
EnrichedProduct = NewType("EnrichedProduct", SanitizedProduct)
# Shield to validate product data
@shield
def validate_product(product: Product) -> ValidatedProduct:
    """Enforce business rules on a product payload.

    Raises:
        HTTPException: 400 if the price is not a positive finite number, the
            stock is negative, no category is given, or a category is blank.
    """
    import math

    # Python's JSON parser accepts NaN/Infinity, and ``NaN <= 0`` is False,
    # so check finiteness explicitly rather than relying on the sign test.
    if not math.isfinite(product.price):
        raise HTTPException(status_code=400, detail="Price must be a finite number")
    if product.price <= 0:
        raise HTTPException(status_code=400, detail="Price must be positive")
    if product.stock < 0:
        raise HTTPException(status_code=400, detail="Stock cannot be negative")
    if len(product.categories) == 0:
        raise HTTPException(status_code=400, detail="At least one category is required")
    # A whitespace-only category passes the checks above but turns into ""
    # once sanitize_product strips it, so reject it here.
    if any(not category.strip() for category in product.categories):
        raise HTTPException(status_code=400, detail="Categories must not be blank")
    return ValidatedProduct(product)
# Shield to sanitize product data
@shield
def sanitize_product(product: ValidatedProduct) -> SanitizedProduct:
    """Return a cleaned copy of ``product``.

    The name and description lose leading/trailing whitespace, and each
    category is lower-cased and trimmed. Price and stock pass through unchanged.
    """
    cleaned = Product(
        name=product.name.strip(),
        description=product.description.strip(),
        price=product.price,
        stock=product.stock,
        categories=[category.lower().strip() for category in product.categories],
    )
    return SanitizedProduct(cleaned)
# Shield to enrich product data
@shield
def enrich_product(product: SanitizedProduct) -> EnrichedProduct:
    """Final pipeline stage; currently re-tags the product without changing it.

    A real application might attach derived data here, such as tax
    information or availability.
    """
    enriched = EnrichedProduct(product)
    return enriched
# Combine all shields into one dependency
def process_product(product: Product) -> EnrichedProduct:
    """Run a product through validate -> sanitize -> enrich, in that order.

    Used as a single FastAPI dependency, so the endpoint receives the fully
    processed product.
    """
    # NOTE(review): this calls the @shield-decorated functions directly;
    # confirm against the fastapi_shield docs that a shield object can be
    # called this way rather than only applied as an endpoint decorator.
    return enrich_product(sanitize_product(validate_product(product)))
# Endpoint: receives the product after the whole processing pipeline.
@app.post("/products")
def create_product(product: EnrichedProduct = Depends(process_product)):
    """Create a fully processed product and echo it back."""
    # Storage is omitted in this example.
    return dict(message="Product created", product=product)
## Form Validation
Validate form data with shields:
from fastapi import FastAPI, Depends, Form, HTTPException
from typing import NewType, Annotated, Optional
from fastapi_shield import shield, Shield
import re
app = FastAPI()

# Marker type: a plain dict of contact-form fields that passed validation
# (built with the keys name, email, message, phone by the shield below).
ValidatedContactForm = NewType("ValidatedContactForm", dict)
# Shield to validate contact form
@shield
def validate_contact_form(
    name: str = Form(...),
    email: str = Form(...),
    message: str = Form(...),
    phone: Optional[str] = Form(None)
) -> ValidatedContactForm:
    """Validate contact-form fields and bundle them into a dict.

    Returns:
        A ``ValidatedContactForm`` dict with keys ``name``, ``email``,
        ``message`` and ``phone`` (``phone`` is ``None`` when not submitted).

    Raises:
        HTTPException: 400 naming the first field that failed validation.
    """
    # Validate name
    if len(name) < 2:
        raise HTTPException(status_code=400, detail="Name is too short")

    # Validate email. Use fullmatch instead of match + "$": "$" also matches
    # just before a trailing "\n", and a newline smuggled into an address can
    # enable mail-header injection if this form is later sent by email.
    email_pattern = re.compile(r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9.-]+')
    if not email_pattern.fullmatch(email):
        raise HTTPException(status_code=400, detail="Invalid email format")

    # Validate message
    if len(message) < 10:
        raise HTTPException(status_code=400, detail="Message is too short")

    # Validate phone if provided (an empty string counts as not provided)
    if phone:
        # Simple validation - in a real app, use a proper phone validation library.
        # A literal space is used rather than \s, which would also accept
        # newlines and tabs.
        phone_pattern = re.compile(r'\+?[0-9 \-\(\)]{8,20}')
        if not phone_pattern.fullmatch(phone):
            raise HTTPException(status_code=400, detail="Invalid phone format")

    # Return validated form data
    form_data = {
        "name": name,
        "email": email,
        "message": message,
        "phone": phone
    }
    return ValidatedContactForm(form_data)
# Endpoint: only reached when validate_contact_form accepted the submission.
@app.post("/contact")
def submit_contact_form(form: ValidatedContactForm = Depends(validate_contact_form)):
    """Handle a validated contact-form submission."""
    # Delivery of the message is omitted in this example.
    return {"message": "Form submitted successfully"}
## Validation with External Services
Perform validation using external services or APIs:
from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks
from typing import NewType, Annotated
from pydantic import BaseModel, EmailStr
from fastapi_shield import shield, Shield
import httpx
app = FastAPI()


# Newsletter sign-up payload; EmailStr enforces address syntax.
class EmailSubscription(BaseModel):
    email: EmailStr
    name: str


# Marker type: a subscription that passed the disposable-address check.
ValidatedSubscription = NewType("ValidatedSubscription", EmailSubscription)
# Domains treated as disposable by the mock checker below.
_DISPOSABLE_DOMAINS = frozenset({"tempmail.com", "throwaway.com", "mailinator.com"})


async def is_disposable_email(email: str) -> bool:
    """Return True if ``email`` belongs to a known disposable-mail domain.

    This is a mock implementation that checks the domain against a small
    hard-coded set.
    - Matching is case-insensitive, because domain names are.
    - Subdomains are also caught, so ``x@Sub.Mailinator.com`` matches.
    - The domain is taken after the last ``@``.
    - An address with no ``@`` returns False instead of raising.

    In a real application, you would call an API like::

        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"https://disposable-email-checker.example.com/api/check/{email}"
            )
            return response.json().get("disposable", False)
    """
    _, at_sign, domain = email.rpartition("@")
    if not at_sign:
        return False
    labels = domain.lower().rstrip(".").split(".")
    # Try "a.b.c", then "b.c", then "c", so any parent domain on the list matches.
    return any(".".join(labels[i:]) in _DISPOSABLE_DOMAINS for i in range(len(labels)))
# Shield to validate the subscription
@shield
async def validate_subscription(
    subscription: EmailSubscription,
    background_tasks: BackgroundTasks
) -> ValidatedSubscription:
    """Reject disposable addresses and queue a verification email.

    Raises:
        HTTPException: 400 if the address is on the disposable-domain list.
    """
    disposable = await is_disposable_email(subscription.email)
    if disposable:
        raise HTTPException(
            status_code=400,
            detail="Disposable email addresses are not allowed"
        )

    def send_verification_email(email: str, name: str):
        # Stand-in for a real mail send, scheduled as a background task.
        print(f"Sending verification email to {email}")

    # Extra positional arguments are forwarded to the task when it runs.
    background_tasks.add_task(send_verification_email, subscription.email, subscription.name)
    return ValidatedSubscription(subscription)
# Endpoint: the shield has already vetted the address and queued verification.
@app.post("/subscribe")
async def subscribe(
    subscription: ValidatedSubscription = Depends(validate_subscription)
):
    """Acknowledge a subscription that is awaiting email verification."""
    return {"message": "Subscription pending verification"}
## Best Practices
- Separate concerns: Keep validation logic separate from business logic.
- Reuse shields: Create reusable validation shields for common patterns.
- Type safety: Use `NewType` to create distinct types for validated data.
- Fail fast: Validate as early as possible in the request lifecycle.
- Provide clear errors: Return descriptive error messages to help clients.
- Sanitize input: Always sanitize user input to prevent security issues.
- Validate extensively: Don't trust client-side validation alone.
- Layer validations: Chain shields to build complex validation pipelines.
- Performance: For expensive validations, consider caching or background processing.
- Testing: Write tests for your validation shields to ensure they work as expected.