Security Best Practices
This guide outlines practices for securing FastAPI applications with FastAPI Shield, illustrated with example shields for authentication, authorization, rate limiting, input validation, file uploads, and security monitoring.
General Security Principles
When implementing security with FastAPI Shield, follow these core principles:
- Defense in Depth: Apply multiple layers of security shields
- Least Privilege: Restrict access to only what's necessary
- Secure by Default: Start with maximum security and relax as needed
- Fail Securely: When a check errors or is inconclusive, deny the request (fail closed) rather than allow it (fail open)
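To make "fail closed" concrete: a check that hits an unexpected exception should deny the request, not let it through. The sketch below is a hypothetical helper, fail_closed, which is not part of FastAPI Shield. It wraps a plain check function so that any exception is logged and turned into None, the value the shields in this guide return to signal failure.

```python
import functools
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger("security")


def fail_closed(check: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Hypothetical helper: turn any exception raised by `check` into a denial (None)."""

    @functools.wraps(check)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        try:
            return check(*args, **kwargs)
        except Exception:
            # Keep the details for operators, but never let an error grant access
            logger.exception("Security check %s failed; denying", check.__name__)
            return None

    return wrapper


@fail_closed
def is_admin(user: dict) -> Optional[dict]:
    # A fragile check: raises KeyError when "roles" is missing
    return user if "admin" in user["roles"] else None
```

Whether such a wrapper composes with the shield decorator depends on how FastAPI Shield inspects signatures, so treat it as a pattern; the same try/except can also be written directly inside a shield body.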
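Secure Authentication
Authentication is the process of verifying who a user is. With FastAPI Shield, an authentication check is written as a shield: a decorator that runs before the endpoint, blocks the request when it returns None, and passes its return value to the endpoint through ShieldedDepends.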
JWT Authentication
from typing import Optional

import jwt
from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends
from jwt.exceptions import PyJWTError

app = FastAPI()

# Configuration
JWT_SECRET = "your-secret-key"  # Store this securely (e.g. in an environment variable)!
JWT_ALGORITHM = "HS256"


@shield(
    name="JWT Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    ),
)
def jwt_auth_shield(authorization: str = Header()) -> Optional[dict]:
    """Validate the JWT and return its decoded payload, or None to block."""
    if not authorization.startswith("Bearer "):
        return None
    # Strip only the leading scheme, not every occurrence of "Bearer "
    token = authorization[len("Bearer "):]
    try:
        # Pin the accepted algorithms and require "exp" and "sub" claims;
        # PyJWT rejects expired tokens when "exp" is present
        return jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            options={"require": ["exp", "sub"]},
        )
    except PyJWTError:
        return None


# Protected endpoint
@app.get("/protected")
@jwt_auth_shield
async def protected_endpoint(
    payload: dict = ShieldedDepends(lambda payload: payload),
):
    return {
        "message": "Access granted",
        "user": payload.get("sub"),
        "roles": payload.get("roles", []),
    }
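The shield above accepts only the exact prefix "Bearer ". RFC 7235 treats the authentication scheme name as case-insensitive, so a parser that is tolerant of the scheme's case but strict about the token might look like the sketch below. parse_bearer_token is a hypothetical helper, not a FastAPI Shield API; a shield could call it and return None whenever it does.

```python
from typing import Optional


def parse_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' header, or None."""
    if not authorization:
        return None
    # Split into scheme and credentials on the first run of whitespace
    parts = authorization.strip().split(None, 1)
    if len(parts) != 2:
        return None
    scheme, token = parts
    # The scheme is case-insensitive; the token itself must not contain whitespace
    if scheme.lower() != "bearer" or any(ch.isspace() for ch in token):
        return None
    return token
```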
Basic Authentication
import base64
import secrets
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

# Mock user database. In production, use a real database and store salted
# password hashes (e.g. bcrypt, scrypt, or Argon2), never plaintext passwords.
USERS = {
    "admin": "strongpassword",
    "user": "userpassword",
}

# Compared against when the username is unknown, so both paths do similar work
_DUMMY_PASSWORD = "x" * 16


@shield(
    name="Basic Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Basic"},
    ),
)
def basic_auth_shield(authorization: str = Header()) -> Optional[str]:
    """Validate HTTP Basic credentials and return the username, or None to block."""
    if not authorization or not authorization.startswith("Basic "):
        return None
    auth_data = authorization[len("Basic "):]
    try:
        decoded = base64.b64decode(auth_data, validate=True).decode("utf-8")
    except ValueError:  # Invalid base64 or non-UTF-8 bytes
        return None
    # Split on the first colon only: passwords may themselves contain ":"
    username, sep, password = decoded.partition(":")
    if not sep:
        return None
    # Reject malformed credentials with the same generic 401 as wrong ones
    if len(username) < 3 or len(password) < 8:
        return None
    # Constant-time comparison on bytes (str arguments must be ASCII-only).
    # Unknown users and wrong passwords get the same response, so the API
    # does not reveal which usernames exist.
    stored = USERS.get(username, _DUMMY_PASSWORD)
    password_ok = secrets.compare_digest(
        password.encode("utf-8"), stored.encode("utf-8")
    )
    if username not in USERS or not password_ok:
        return None
    return username


@app.get("/secure-data")
@basic_auth_shield
async def get_secure_data(
    username: str = ShieldedDepends(lambda username: username),
):
    return {"message": f"Hello {username}", "data": "Sensitive information"}
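The mock USERS dictionary stores plaintext passwords. A common alternative is to store a salted, deliberately slow hash and compare hashes in constant time. The sketch below uses PBKDF2 from Python's standard library; the helper names, the storage format, and the iteration count are illustrative assumptions, and dedicated bcrypt or Argon2 libraries are also widely used.

```python
import hashlib
import hmac
import secrets

# Illustrative work factor; tune it to your hardware and latency budget
ITERATIONS = 600_000


def hash_password(password: str, iterations: int = ITERATIONS) -> str:
    """Return 'iterations$salt_hex$hash_hex' for storage."""
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return f"{iterations}${salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored salt and compare in constant time."""
    try:
        iterations_str, salt_hex, digest_hex = stored.split("$")
        iterations = int(iterations_str)
        salt = bytes.fromhex(salt_hex)
        expected = bytes.fromhex(digest_hex)
    except ValueError:
        return False  # Malformed record: fail closed
    candidate = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return hmac.compare_digest(candidate, expected)
```

With this scheme, USERS would map usernames to hash_password output, and the shield would call verify_password instead of comparing raw strings.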
Role-Based Access Control (RBAC)
Build authorization on top of authentication by chaining shields: an authentication shield identifies the user, and shields created by a factory function check that user's roles.
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

# User database with roles, keyed by API token
USERS = {
    "admin_token": {"user_id": "admin", "roles": ["admin", "user"]},
    "editor_token": {"user_id": "editor", "roles": ["editor", "user"]},
    "user_token": {"user_id": "user1", "roles": ["user"]},
}


@shield(
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or missing API token",
    ),
)
def auth_shield(api_token: str = Header()) -> Optional[dict]:
    """Authenticate the user and return user data, or None to block."""
    return USERS.get(api_token)


def role_shield(required_roles: list[str]):
    """Factory that creates a shield allowing users with any of required_roles."""

    @shield(
        name=f"Role Check ({', '.join(required_roles)})",
        exception_to_raise_if_fail=HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied. Required roles: {', '.join(required_roles)}",
        ),
    )
    def check_role(user_data: dict = ShieldedDepends(lambda user: user)) -> Optional[dict]:
        """Check whether the authenticated user has any of the required roles."""
        user_roles = user_data.get("roles", [])
        if any(role in required_roles for role in user_roles):
            return user_data
        return None

    return check_role


# Create specific role shields
admin_shield = role_shield(["admin"])
editor_shield = role_shield(["admin", "editor"])
user_shield = role_shield(["admin", "editor", "user"])


# auth_shield, listed first, authenticates the user; the role shield then
# receives that user through ShieldedDepends
@app.get("/admin")
@auth_shield
@admin_shield
async def admin_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    return {"message": "Admin endpoint", "user": user["user_id"]}


@app.get("/editor")
@auth_shield
@editor_shield
async def editor_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    return {"message": "Editor endpoint", "user": user["user_id"]}


@app.get("/user")
@auth_shield
@user_shield
async def user_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    return {"message": "User endpoint", "user": user["user_id"]}
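In the example above, each user lists every role they hold (admins also carry "user"). An alternative is a role hierarchy in which higher roles imply lower ones, so the check expands the user's roles before comparing. The hierarchy below is a hypothetical assumption mirroring the admin, editor, and user roles; check_role could call has_any_role in place of its any(...) expression.

```python
from typing import Iterable

# Hypothetical hierarchy: each role implies the roles listed for it
ROLE_IMPLIES = {
    "admin": {"editor"},
    "editor": {"user"},
    "user": set(),
}


def effective_roles(roles: Iterable[str]) -> set[str]:
    """Expand roles transitively using ROLE_IMPLIES."""
    result: set[str] = set()
    stack = list(roles)
    while stack:
        role = stack.pop()
        if role in result:
            continue  # Already expanded (also guards against cycles)
        result.add(role)
        stack.extend(ROLE_IMPLIES.get(role, ()))
    return result


def has_any_role(user_roles: Iterable[str], required_roles: Iterable[str]) -> bool:
    """True if the expanded user roles overlap the required roles."""
    return not effective_roles(user_roles).isdisjoint(required_roles)
```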
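Rate Limiting
Protect against abuse and brute-force attempts by limiting how many requests each client can make in a time window. The example below keeps a sliding window of request timestamps per client IP.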
import time
from collections import defaultdict
from typing import Optional

from fastapi import FastAPI, Request, HTTPException, status
from fastapi_shield import shield

app = FastAPI()

WINDOW_SECONDS = 1
MAX_REQUESTS = 5

# In-memory, per-process storage: counts are not shared across workers or
# servers, and keys are never evicted. Use Redis (or similar) in production.
request_counts = defaultdict(list)


@shield(
    name="Rate Limiter",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail="Rate limit exceeded. Maximum 5 requests per second.",
        headers={"Retry-After": "1"},
    ),
)
def rate_limit_shield(request: Request) -> Optional[bool]:
    """Limit requests per client IP using a sliding window of timestamps."""
    # request.client can be None; behind a proxy, host is the proxy's address
    client_ip = request.client.host if request.client else "unknown"
    now = time.monotonic()  # Monotonic clock: unaffected by system clock changes

    # Keep only timestamps that are still inside the window
    request_counts[client_ip] = [
        ts for ts in request_counts[client_ip] if now - ts < WINDOW_SECONDS
    ]

    # Block if the client has used up its quota
    if len(request_counts[client_ip]) >= MAX_REQUESTS:
        return None

    # Record this request and allow it
    request_counts[client_ip].append(now)
    return True


@app.get("/api/data")
@rate_limit_shield
async def get_data():
    return {"message": "Data retrieved successfully"}
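The shield stores timestamps in a module-level dictionary whose keys are never removed. A sketch of the same sliding-window log as a small class, with an injectable clock for testing and eviction of idle clients, follows. SlidingWindowLimiter is a hypothetical name, and like the original it only counts requests within one process.

```python
import time
from collections import deque
from typing import Callable, Deque, Dict


class SlidingWindowLimiter:
    """Allow at most max_requests per key within any window_seconds interval."""

    def __init__(self, max_requests: int, window_seconds: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.clock = clock
        self._hits: Dict[str, Deque[float]] = {}

    def allow(self, key: str) -> bool:
        now = self.clock()
        hits = self._hits.setdefault(key, deque())
        # Drop timestamps that have left the window (oldest are on the left)
        while hits and now - hits[0] >= self.window_seconds:
            hits.popleft()
        if len(hits) >= self.max_requests:
            return False
        hits.append(now)
        return True

    def evict_idle(self) -> None:
        """Forget keys with no requests inside the current window."""
        now = self.clock()
        for key in list(self._hits):
            hits = self._hits[key]
            if not hits or now - hits[-1] >= self.window_seconds:
                del self._hits[key]
```

Inside a shield, the check would reduce to returning True when limiter.allow(client_ip) succeeds and None otherwise, with evict_idle called periodically.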
IP Restriction
Restrict access based on client IP addresses for internal APIs.
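import ipaddress
from typing import Optional

from fastapi import FastAPI, Request, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

# Allowed addresses and networks, parsed with the ipaddress module:
# a plain string comparison would never match the "192.168.1.0/24" entry
ALLOWED_NETWORKS = [
    ipaddress.ip_network(net) for net in ["127.0.0.1", "::1", "192.168.1.0/24"]
]


@shield(
    name="IP Restriction",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied by IP restriction",
    ),
)
def ip_restriction_shield(request: Request) -> Optional[dict]:
    """Allow only clients whose address falls inside ALLOWED_NETWORKS."""
    if request.client is None:
        return None
    client_ip = request.client.host
    try:
        address = ipaddress.ip_address(client_ip)
    except ValueError:
        # Not a valid IP address (e.g. a test client's placeholder host): deny
        return None
    if any(address in network for network in ALLOWED_NETWORKS):
        return {"client_ip": client_ip}
    return None


@app.get("/internal-api")
@ip_restriction_shield
async def internal_api(
    ip_info: dict = ShieldedDepends(lambda info: info),
):
    return {
        "message": "Internal API endpoint",
        "client_ip": ip_info["client_ip"],
    }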
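Both the IP restriction and the rate limiter use request.client.host, which is the address of whatever connected directly, often a reverse proxy. A common approach is to trust X-Forwarded-For only when the direct peer is a known proxy, and then take the rightmost address that is not itself a trusted proxy, since entries further left can be forged by the client. The sketch below assumes a hypothetical TRUSTED_PROXIES list; Uvicorn's --proxy-headers and --forwarded-allow-ips options address the same problem at the server level.

```python
import ipaddress
from typing import Optional, Sequence, Union

Network = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]

# Hypothetical: networks of the reverse proxies you operate
TRUSTED_PROXIES: Sequence[Network] = [ipaddress.ip_network("10.0.0.0/8")]


def _in_networks(ip: str, networks: Sequence[Network]) -> bool:
    try:
        address = ipaddress.ip_address(ip.strip())
    except ValueError:
        return False
    return any(address in net for net in networks)


def client_ip_from(peer_ip: str, forwarded_for: Optional[str],
                   trusted: Sequence[Network] = TRUSTED_PROXIES) -> str:
    """Return the client address, honoring X-Forwarded-For only via trusted proxies."""
    if not forwarded_for or not _in_networks(peer_ip, trusted):
        return peer_ip  # Direct connection or untrusted peer: ignore the header
    hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
    # Walk from the right: each entry was appended by the proxy in front of it
    for hop in reversed(hops):
        if not _in_networks(hop, trusted):
            return hop
    # Every hop is a trusted proxy: the leftmost entry is the origin
    return hops[0] if hops else peer_ip
```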
Input Validation and Sanitization
Validate and sanitize all inputs to prevent injection attacks.
import html
import re
from typing import Optional

from fastapi import FastAPI, HTTPException, status
from fastapi_shield import shield, ShieldedDepends
from pydantic import BaseModel, EmailStr, field_validator  # EmailStr needs email-validator

app = FastAPI()

USERNAME_PATTERN = re.compile(r"[a-zA-Z0-9_-]{3,16}")
RESERVED_USERNAMES = {"admin", "root", "system"}
MAX_COMMENT_LENGTH = 1000


class UserRegistration(BaseModel):
    username: str
    email: EmailStr
    password: str

    # Pydantic v2 API; on Pydantic v1 use @validator instead
    @field_validator("username")
    @classmethod
    def username_must_be_valid(cls, v: str) -> str:
        # fullmatch: unlike re.match(r"^...$"), it rejects a trailing newline
        if not USERNAME_PATTERN.fullmatch(v):
            raise ValueError("Invalid username format")
        return v

    @field_validator("password")
    @classmethod
    def password_must_be_strong(cls, v: str) -> str:
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        if not re.search(r"[A-Z]", v):
            raise ValueError("Password must contain an uppercase letter")
        if not re.search(r"[a-z]", v):
            raise ValueError("Password must contain a lowercase letter")
        if not re.search(r"[0-9]", v):
            raise ValueError("Password must contain a number")
        return v


@shield(
    name="Input Validation",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Invalid input data",
    ),
)
def validate_registration_shield(registration: UserRegistration) -> UserRegistration:
    """Apply business rules on top of the model's field validation."""
    if registration.username.lower() in RESERVED_USERNAMES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username not allowed",
        )
    return registration


@shield(
    name="Content Sanitization",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Comment contains disallowed content",
    ),
)
def sanitize_comment_shield(comment: str) -> Optional[str]:
    """Sanitize user-generated content before it is stored or rendered."""
    # Coarse blocklist: escaping below is the actual XSS defense, and this
    # pattern also rejects harmless words such as "description"
    if re.search(r"(script|javascript|eval\(|<iframe)", comment, re.IGNORECASE):
        return None
    # Truncate before escaping so an entity such as "&amp;" is never cut in half
    if len(comment) > MAX_COMMENT_LENGTH:
        comment = comment[:MAX_COMMENT_LENGTH] + "..."
    # Escape HTML special characters (&, <, >, quotes)
    return html.escape(comment)


@app.post("/register")
@validate_registration_shield
async def register_user(
    registration: UserRegistration = ShieldedDepends(lambda reg: reg),
):
    # Process validated registration
    return {"status": "User registered successfully"}


@app.post("/comments")
@sanitize_comment_shield
async def create_comment(
    clean_comment: str = ShieldedDepends(lambda comment: comment),
):
    # Store sanitized comment
    return {"status": "Comment added", "comment": clean_comment}
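Order matters when combining HTML escaping with a length limit. Escaping first and then truncating can cut an entity in half, leaving a fragment such as "&a" in the output; truncating the raw text first and escaping afterwards keeps every entity intact. A small demonstration with two illustrative functions:

```python
import html


def escape_then_truncate(text: str, limit: int) -> str:
    # Problematic order: the limit may fall inside an entity such as "&amp;"
    return html.escape(text)[:limit]


def truncate_then_escape(text: str, limit: int) -> str:
    # Safer order: cut the raw text, then escape what remains
    return html.escape(text[:limit])
```

Note that with the safer order the limit applies to the raw text, so the escaped output can be longer than the limit.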
Secure File Handling
Validate uploads before processing them: cap the size, detect the real content type from the file's bytes, and check that the extension matches it.
import os

import magic  # python-magic; requires the libmagic system library
from fastapi import FastAPI, UploadFile, File, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

MAX_FILE_SIZE = 5 * 1024 * 1024  # 5 MB

# Allowed MIME types (as detected from the content) and their extensions
VALID_EXTENSIONS = {
    "image/jpeg": [".jpg", ".jpeg"],
    "image/png": [".png"],
    "application/pdf": [".pdf"],
    "text/plain": [".txt"],
}


@shield(
    name="File Validation",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Invalid file upload",
    ),
)
async def validate_file_shield(file: UploadFile = File(...)) -> UploadFile:
    """Validate an uploaded file's size, content type, and extension."""
    # Read at most one byte past the limit rather than the whole file
    # (also cap the request body size at the server or proxy)
    content = await file.read(MAX_FILE_SIZE + 1)
    await file.seek(0)  # Reset file position for the endpoint
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail="File too large. Maximum size is 5MB",
        )

    # Detect the MIME type from the content, not the client-supplied header
    file_type = magic.Magic(mime=True).from_buffer(content)
    if file_type not in VALID_EXTENSIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File type {file_type} not allowed",
        )

    # Check that the file extension matches the detected content type
    ext = os.path.splitext(file.filename or "")[1].lower()
    if ext not in VALID_EXTENSIONS[file_type]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="File extension doesn't match content type",
        )
    return file


@app.post("/upload")
@validate_file_shield
async def upload_file(
    file: UploadFile = ShieldedDepends(lambda file: file),
):
    # Process the validated file. Note: file.content_type is the
    # client-supplied header, not the type detected by the shield.
    return {
        "filename": file.filename,
        "content_type": file.content_type,
        "status": "File uploaded successfully",
    }
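python-magic depends on the libmagic system library. Where that is unavailable, a narrower check is to compare the first bytes of the upload with the known signatures of the few formats you accept. The sketch below covers the binary types from the example (JPEG, PNG, PDF); text/plain has no signature, and sniff_type and is_acceptable are hypothetical helper names. A signature check only shows that a file starts like the claimed format, not that the rest of it is safe.

```python
import os
from typing import Optional

# Leading "magic" bytes of the accepted binary formats
SIGNATURES = {
    b"\x89PNG\r\n\x1a\n": "image/png",
    b"\xff\xd8\xff": "image/jpeg",
    b"%PDF-": "application/pdf",
}

VALID_EXTENSIONS = {
    "image/jpeg": {".jpg", ".jpeg"},
    "image/png": {".png"},
    "application/pdf": {".pdf"},
}


def sniff_type(head: bytes) -> Optional[str]:
    """Return the MIME type whose signature starts `head`, or None."""
    for signature, mime_type in SIGNATURES.items():
        if head.startswith(signature):
            return mime_type
    return None


def is_acceptable(filename: str, head: bytes) -> bool:
    """Accept only a known signature whose expected extension matches the name."""
    mime_type = sniff_type(head)
    ext = os.path.splitext(filename)[1].lower()
    return mime_type is not None and ext in VALID_EXTENSIONS[mime_type]
```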
API Key Management
Validate API keys with constant-time comparison and enforce per-key permissions by chaining a permission shield after the key shield.
import secrets
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

# Mock API key database (use a proper secret store in production)
API_KEYS = {
    "sk_test_abcdefghijklmnopqrstuvwxyz123456": {
        "client_id": "client1",
        "permissions": ["read", "write"],
        "active": True,
    },
    "sk_test_zyxwvutsrqponmlkjihgfedcba654321": {
        "client_id": "client2",
        "permissions": ["read", "write", "admin"],
        "active": True,
    },
}


@shield(
    name="API Key Validation",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
        headers={"WWW-Authenticate": "ApiKey"},
    ),
)
def validate_api_key_shield(x_api_key: str = Header()) -> Optional[dict]:
    """Validate the X-API-Key header and return client information."""
    presented = x_api_key.encode("utf-8")
    for key, data in API_KEYS.items():
        # Constant-time comparison on bytes: str arguments to compare_digest
        # must be ASCII-only, and header values are not guaranteed to be
        if secrets.compare_digest(presented, key.encode("utf-8")):
            if not data.get("active", False):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="API key is inactive",
                )
            return data
    return None


def require_permission(permission: str):
    """Factory for shields that require a specific permission."""

    @shield(
        name=f"Permission Check ({permission})",
        exception_to_raise_if_fail=HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Permission '{permission}' required",
        ),
    )
    def permission_shield(
        client_data: dict = ShieldedDepends(lambda data: data),
    ) -> Optional[dict]:
        """Check whether the validated client has the required permission."""
        if permission in client_data.get("permissions", []):
            return client_data
        return None

    return permission_shield


# Create permission shields
read_permission = require_permission("read")
write_permission = require_permission("write")
admin_permission = require_permission("admin")


@app.get("/api/data")
@validate_api_key_shield
@read_permission
async def get_data(client: dict = ShieldedDepends(lambda data: data)):
    return {"data": "Sensitive information", "client": client["client_id"]}


@app.post("/api/data")
@validate_api_key_shield
@write_permission
async def create_data(client: dict = ShieldedDepends(lambda data: data)):
    return {"status": "Data created", "client": client["client_id"]}
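Looping over every stored key works for a handful of clients, but it keeps plaintext keys around and scales linearly. An alternative is to store only a SHA-256 digest of each key and look up the digest of the presented key; because API keys are long random strings, a fast hash is generally considered adequate here, unlike for passwords. The helpers below are a hypothetical sketch; the "sk_" prefix simply follows the example keys.

```python
import hashlib
import secrets
from typing import Dict, List, Optional


def generate_api_key(prefix: str = "sk_") -> str:
    """Create a new random API key (shown to the client once, never stored)."""
    return prefix + secrets.token_urlsafe(32)


def key_digest(api_key: str) -> str:
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()


# Storage maps digest -> client record; the plaintext key is not kept
API_KEY_HASHES: Dict[str, dict] = {}


def register_key(client_id: str, permissions: List[str]) -> str:
    api_key = generate_api_key()
    API_KEY_HASHES[key_digest(api_key)] = {
        "client_id": client_id, "permissions": permissions, "active": True,
    }
    return api_key


def lookup_api_key(api_key: str) -> Optional[dict]:
    """Return the client record for an active key, or None."""
    record = API_KEY_HASHES.get(key_digest(api_key))
    if record is None or not record.get("active", False):
        return None
    return record
```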
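Multi-Factor Authentication (MFA)
Require a second authentication factor for sensitive endpoints. In this example, the token issuer sets an mfa_verified claim after the user completes the second factor, and the shield rejects tokens without it for users who have MFA enabled.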
from typing import Optional

import jwt
from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

SECRET_KEY = "your-secret-key"  # Load from a secure store in production
ALGORITHM = "HS256"

# Mock user database with MFA settings
USERS_DB = {
    "user1": {
        "username": "user1",
        "mfa_enabled": True,
        "roles": ["user"],
    },
    "admin1": {
        "username": "admin1",
        "mfa_enabled": True,
        "roles": ["admin", "user"],
    },
}


@shield(
    name="JWT MFA Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    ),
)
def jwt_mfa_shield(authorization: str = Header()) -> Optional[dict]:
    """Validate the JWT and require MFA for users who have it enabled."""
    if not authorization or not authorization.startswith("Bearer "):
        return None
    token = authorization[len("Bearer "):]
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        return None

    username = payload.get("sub")
    if username is None or username not in USERS_DB:
        return None
    user = USERS_DB[username]

    # Accept only a literal true claim, set by the issuer after the second factor
    mfa_verified = payload.get("mfa_verified") is True
    # If MFA is enabled but not verified, require MFA
    if user["mfa_enabled"] and not mfa_verified:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="MFA verification required",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Pass the real verification status on instead of assuming it
    return {**user, "mfa_verified": mfa_verified}


@app.get("/secure-endpoint")
@jwt_mfa_shield
async def secure_endpoint(
    user: dict = ShieldedDepends(lambda user: user),
):
    return {
        "message": "Access granted to secure endpoint",
        "user": user["username"],
        "mfa_verified": user["mfa_verified"],
    }
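The shield trusts an mfa_verified claim, so the second factor has to be checked before that claim is issued. One common second factor is a time-based one-time password (TOTP, RFC 6238, built on HOTP from RFC 4226). The sketch below implements TOTP verification with the standard library; the function names are hypothetical, and it assumes the shared secret is available as raw bytes (authenticator apps usually exchange it base32-encoded, which base64.b32decode can convert).

```python
import hashlib
import hmac
import struct
import time
from typing import Optional


def hotp(key: bytes, counter: int, digits: int = 6) -> str:
    """HMAC-based one-time password (RFC 4226) using HMAC-SHA1."""
    digest = hmac.new(key, struct.pack(">Q", counter), hashlib.sha1).digest()
    offset = digest[-1] & 0x0F  # Dynamic truncation
    code = struct.unpack(">I", digest[offset:offset + 4])[0] & 0x7FFFFFFF
    return str(code % 10 ** digits).zfill(digits)


def totp(key: bytes, at: float, step: int = 30, digits: int = 6) -> str:
    """Time-based one-time password (RFC 6238): HOTP over the time-step counter."""
    return hotp(key, int(at // step), digits)


def verify_totp(key: bytes, submitted: str, at: Optional[float] = None,
                step: int = 30, window: int = 1) -> bool:
    """Accept codes from the current step and `window` steps either side."""
    now = time.time() if at is None else at
    for drift in range(-window, window + 1):
        expected = totp(key, now + drift * step, step)
        # Compare bytes in constant time; str input must be ASCII-only otherwise
        if hmac.compare_digest(expected.encode(), submitted.encode()):
            return True
    return False
```

Only after verify_totp succeeds would the issuer mint a token with mfa_verified set to true; a production system would also refuse to accept the same code twice within its window.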
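Security Monitoring and Logging
A shield that always returns data never blocks a request, which makes it a convenient place to log security-relevant details and flag suspicious requests for review.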
import logging
import time
import uuid
from urllib.parse import unquote_plus

from fastapi import FastAPI, Request
from fastapi_shield import shield, ShieldedDepends

app = FastAPI()

# Configure security logging. Without a configured handler, INFO records
# are not printed (Python's fallback handler only shows WARNING and above).
logging.basicConfig(level=logging.INFO)
security_logger = logging.getLogger("security")
security_logger.setLevel(logging.INFO)

SUSPICIOUS_PATTERNS = [
    "script", "javascript", "eval(", "<iframe", "union select",
    "../", "etc/passwd", "cmd.exe",
]


@shield(name="Security Monitor")
async def security_monitor_shield(request: Request) -> dict:
    """Log security-relevant request details; always returns data, so never blocks."""
    security_info = {
        "request_id": str(uuid.uuid4()),  # Unique per request, unlike a timestamp
        "timestamp": time.time(),
        "method": request.method,
        "path": request.url.path,
        "client_ip": request.client.host if request.client else "unknown",
        "user_agent": request.headers.get("user-agent", "unknown"),
        "referer": request.headers.get("referer", ""),
        "content_length": request.headers.get("content-length", "0"),
    }

    # Log the request; %r escapes newlines so attackers cannot forge log lines
    security_logger.info(
        "Request %s: %s %r from %s",
        security_info["request_id"], security_info["method"],
        security_info["path"], security_info["client_ip"],
    )

    # Check the path and the decoded query string for suspicious patterns
    target = (request.url.path + "?" + unquote_plus(request.url.query)).lower()
    if any(pattern in target for pattern in SUSPICIOUS_PATTERNS):
        security_logger.warning(
            "Suspicious request %s detected: %r from %s",
            security_info["request_id"], target, security_info["client_ip"],
        )

    return security_info


@app.get("/monitored-endpoint")
@security_monitor_shield
async def monitored_endpoint(
    security_info: dict = ShieldedDepends(lambda info: info),
):
    return {
        "message": "Request processed",
        "request_id": security_info["request_id"],
    }
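The suspicious patterns are matched against lowercased text, but payloads often arrive percent-encoded (for example %2e%2e%2f for ../), sometimes more than once. A hypothetical helper that decodes the path and query a bounded number of times before matching, and reports which patterns fired so the log entry can include them, might look like this. Matching of this kind is a monitoring signal, not a defense: it misses many encodings and flags some harmless requests.

```python
from typing import List
from urllib.parse import unquote_plus

SUSPICIOUS_PATTERNS = [
    "script", "javascript", "eval(", "<iframe", "union select",
    "../", "etc/passwd", "cmd.exe",
]


def matched_patterns(path: str, query: str = "", max_rounds: int = 3) -> List[str]:
    """Return suspicious patterns found in the decoded, lowercased path and query."""
    text = f"{path}?{query}"
    # Decode repeatedly (bounded) to catch double encoding such as %252e
    for _ in range(max_rounds):
        decoded = unquote_plus(text)
        if decoded == text:
            break
        text = decoded
    text = text.lower()
    return [pattern for pattern in SUSPICIOUS_PATTERNS if pattern in text]
```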
Security Best Practices Checklist
When implementing security with FastAPI Shield, follow this checklist:
Shield Implementation
- ✅ Use shields as decorators, not as function calls
- ✅ Give descriptive names to your shields for better debugging
- ✅ Return None from shields that fail validation
- ✅ Use specific HTTP exceptions with clear error messages
- ✅ Implement proper error handling within shields
Authentication & Authorization
- ✅ Use constant-time comparison for credentials (secrets.compare_digest)
- ✅ Implement proper JWT validation with expiration checks
- ✅ Use shield chaining for complex authorization requirements
- ✅ Implement role-based access control with factory patterns
- ✅ Store secrets securely using environment variables
Input Validation
- ✅ Validate all user inputs with appropriate shields
- ✅ Sanitize user-generated content to prevent XSS
- ✅ Use Pydantic models for structured data validation
- ✅ Implement file upload validation for security
Rate Limiting & Protection
- ✅ Implement rate limiting for public endpoints
- ✅ Use IP restrictions for internal APIs
- ✅ Monitor and log security events
- ✅ Implement proper CORS policies
Dependency Injection
- ✅ Use ShieldedDepends(lambda data: data) to access the data a shield returns
- ✅ Layer multiple shields for defense in depth
- ✅ Keep shields focused and single-purpose
- ✅ Use shield factories for reusability
Error Handling
- ✅ Configure appropriate HTTP status codes for different failures
- ✅ Provide clear error messages without exposing sensitive information
- ✅ Log security events for monitoring and auditing
- ✅ Fail securely by default
By following these patterns, you can keep each security check in a small, focused shield that is easy to test, reuse, and combine with others on the endpoints that need it.