Real World Examples¶
This section provides real-world examples of using FastAPI Shield in production scenarios. These examples demonstrate practical implementations of security patterns that have been thoroughly tested and validated.
E-commerce API Security¶
A comprehensive example of securing an e-commerce API with multiple layers of protection:
from fastapi import Body, FastAPI, Header, HTTPException, status, Depends, Query, Path
from fastapi_shield import shield, ShieldedDepends
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import jwt
import time
from uuid import UUID, uuid4
import logging
from datetime import datetime, timedelta
# Configuration
JWT_SECRET = "your-production-secret-key"
JWT_ALGORITHM = "HS256"
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_MAX_REQUESTS = 100
app = FastAPI(title="E-commerce API")
# Models
class Product(BaseModel):
id: UUID
name: str
price: float
description: str
category: str
class OrderItem(BaseModel):
product_id: UUID
quantity: int = Field(gt=0)
price: float
class Order(BaseModel):
id: UUID
user_id: UUID
items: List[OrderItem]
total: float
status: str
class UserRole:
ADMIN = "admin"
MANAGER = "manager"
CUSTOMER = "customer"
# Mock databases (in production, use proper databases)
PRODUCTS_DB = {}
ORDERS_DB = {}
# Rate limiting storage (in production, use Redis)
rate_limits = {}
# Shield System
@shield(name="Rate Limiter")
async def rate_limit_shield(client_ip: str = Header(None, alias="X-Forwarded-For")):
"""Rate limit based on client IP"""
if not client_ip:
client_ip = "127.0.0.1"
rate_key = f"rate_limit:{client_ip}"
current_time = int(time.time())
window_start = current_time - RATE_LIMIT_WINDOW
# Simple in-memory rate limiting (use Redis in production)
if rate_key not in rate_limits:
rate_limits[rate_key] = []
# Remove old entries
rate_limits[rate_key] = [
t for t in rate_limits[rate_key]
if t > window_start
]
# Add current request
rate_limits[rate_key].append(current_time)
# Check if rate limit exceeded
if len(rate_limits[rate_key]) > RATE_LIMIT_MAX_REQUESTS:
return None
return {"client_ip": client_ip, "request_count": len(rate_limits[rate_key])}
@shield(name="JWT Auth")
async def jwt_auth_shield(
authorization: str = Header(None),
x_refresh_token: Optional[str] = Header(None)
):
"""Validate JWT token with refresh token support"""
if not authorization:
return None
if not authorization.startswith("Bearer "):
return None
token = authorization.replace("Bearer ", "")
try:
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return {
"user_id": payload.get("sub"),
"roles": payload.get("roles", []),
"permissions": payload.get("permissions", [])
}
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
def require_permissions(required_permissions: List[str]):
"""Create a shield that requires specific permissions"""
@shield(
name=f"Permission Check ({', '.join(required_permissions)})",
exception_to_raise_if_fail=HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Insufficient permissions"
)
)
async def permission_shield(auth_data = ShieldedDepends(lambda auth_data: auth_data)):
user_permissions = auth_data.get("permissions", [])
if any(perm in user_permissions for perm in required_permissions):
return auth_data
return None
return permission_shield
@shield(name="Product Exists")
async def product_exists_shield(product_id: UUID = Path(...)):
"""Check if a product exists"""
product = PRODUCTS_DB.get(product_id)
if not product:
return None
return product
@shield(name="Order Owner")
async def order_owner_shield(
order_id: UUID = Path(...),
auth_data = ShieldedDepends(lambda payload: payload)
):
"""Check if the user owns the order"""
order = ORDERS_DB.get(order_id)
if not order:
return None
user_id = auth_data.get("user_id")
user_roles = auth_data.get("roles", [])
# Admin or manager can access any order
if UserRole.ADMIN in user_roles or UserRole.MANAGER in user_roles:
return order
# Owner can access their own order
if str(order.user_id) == user_id:
return order
return None
# API Endpoints
@app.get("/products", response_model=List[Product])
@rate_limit_shield
async def list_products(
category: Optional[str] = Query(None),
rate_data = ShieldedDepends(lambda rate_data: rate_data)
):
"""List all products with optional filtering - Public endpoint"""
products = list(PRODUCTS_DB.values())
if category:
products = [p for p in products if p.category == category]
return products
@app.get("/products/{product_id}", response_model=Product)
@rate_limit_shield
@product_exists_shield
async def get_product(product: Product = ShieldedDepends(lambda p: p)):
"""Get a specific product by ID - Public endpoint"""
return product
@app.post("/products", response_model=Product)
@rate_limit_shield
@jwt_auth_shield
@require_permissions(["product:create"])
async def create_product(product_data: dict = Body()):
"""Create a new product - Requires authentication and product:create permission"""
product_id = uuid4()
product = Product(
id=product_id,
name=product_data["name"],
price=product_data["price"],
description=product_data["description"],
category=product_data["category"]
)
PRODUCTS_DB[product_id] = product
return product
@app.put("/products/{product_id}", response_model=Product)
@rate_limit_shield
@jwt_auth_shield
@require_permissions(["product:update"])
@product_exists_shield
async def update_product(
updated_data: dict,
product: Product = ShieldedDepends(lambda p: p),
):
"""Update an existing product - Requires authentication and product:update permission"""
for key, value in updated_data.items():
if hasattr(product, key):
setattr(product, key, value)
PRODUCTS_DB[product.id] = product
return product
@app.delete("/products/{product_id}")
@rate_limit_shield
@jwt_auth_shield
@require_permissions(["product:delete"])
@product_exists_shield
async def delete_product(product: Product = ShieldedDepends(lambda p: p)):
"""Delete a product - Requires authentication and product:delete permission"""
del PRODUCTS_DB[product.id]
return {"message": "Product deleted successfully"}
@app.get("/orders", response_model=List[Order])
@rate_limit_shield
@jwt_auth_shield
async def list_orders(auth_data = ShieldedDepends(lambda d: d)):
"""List orders - Users see their own, admins/managers see all"""
user_id = auth_data.get("user_id")
user_roles = auth_data.get("roles", [])
if UserRole.ADMIN in user_roles or UserRole.MANAGER in user_roles:
return list(ORDERS_DB.values())
return [o for o in ORDERS_DB.values() if str(o.user_id) == user_id]
@app.get("/orders/{order_id}", response_model=Order)
@rate_limit_shield
@jwt_auth_shield
@order_owner_shield
async def get_order(order: Order = ShieldedDepends(lambda o: o)):
"""Get a specific order - Only owner or admin/manager can access"""
return order
@app.post("/orders", response_model=Order)
@rate_limit_shield
@jwt_auth_shield
async def create_order(
order_data: dict,
auth_data = ShieldedDepends(lambda d: d)
):
"""Create a new order - Requires authentication"""
user_id = UUID(auth_data.get("user_id"))
order_id = uuid4()
order = Order(
id=order_id,
user_id=user_id,
items=[OrderItem(**item) for item in order_data["items"]],
total=order_data["total"],
status="created"
)
ORDERS_DB[order_id] = order
return order
@app.delete("/orders/{order_id}")
@rate_limit_shield
@jwt_auth_shield
@require_permissions(["order:delete"])
@order_owner_shield
async def cancel_order(order: Order = ShieldedDepends(lambda o: o)):
"""Cancel an order - Requires order:delete permission and ownership"""
order.status = "cancelled"
ORDERS_DB[order.id] = order
return {"message": "Order cancelled successfully"}
# Helper function to create test tokens
def create_test_token(user_id: str, roles: list = None, permissions: list = None):
"""Create a JWT token for testing"""
payload = {
"sub": user_id,
"roles": roles or ["user"],
"permissions": permissions or [],
"exp": datetime.utcnow() + timedelta(hours=1)
}
return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
Key Features¶
- Public Endpoints: Product listing and viewing don't require authentication
- Authenticated Endpoints: Order management requires valid JWT tokens
- Permission-Based Access: Product management requires specific permissions
- Role-Based Access: Admins and managers have broader access to orders
- Resource Ownership: Users can only access their own orders (unless admin/manager)
- Rate Limiting: All endpoints are rate-limited by client IP
- Shield Composition: Multiple shields can be stacked for layered security
Shield Order Matters¶
Notice the order of shields in the decorators:
@rate_limit_shield # Applied first
@jwt_auth_shield # Applied second
@require_permissions(...) # Applied third
@product_exists_shield # Applied last
The shields are applied in reverse order of decoration, so the bottom shield runs first.
Microservice API Gateway¶
An example of using FastAPI Shield as part of an API gateway for a microservices architecture:
from fastapi import FastAPI, Depends, Header, HTTPException, status, Request, Response
from fastapi_shield import shield, ShieldedDepends
import httpx
import jwt
import time
import uuid
from typing import Dict, Any, Optional
import logging
from datetime import datetime, timedelta
# Configuration
JWT_SECRET = "your-production-secret-key"
JWT_ALGORITHM = "HS256"
app = FastAPI(title="API Gateway")
# Service registry (in production, this would be dynamic)
SERVICES = {
"user-service": "http://user-service:8001",
"product-service": "http://product-service:8002",
"order-service": "http://order-service:8003",
"payment-service": "http://payment-service:8004"
}
# HTTP client for forwarding requests
http_client = httpx.AsyncClient()
# Gateway Shields
@shield(name="Gateway Auth")
async def gateway_auth_shield(authorization: str = Header(None)):
"""Validate JWT token for the gateway"""
if not authorization:
return None
if not authorization.startswith("Bearer "):
return None
token = authorization.replace("Bearer ", "")
try:
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return {
"user_id": payload.get("sub"),
"roles": payload.get("roles", []),
"permissions": payload.get("permissions", []),
"token": token # Keep the token to pass to microservices
}
except jwt.ExpiredSignatureError:
return None
except jwt.InvalidTokenError:
return None
@shield(name="Rate Limiter")
async def gateway_rate_limit(request: Request):
"""Simple rate limiting for the gateway"""
client_ip = getattr(request.client, 'host', '127.0.0.1')
# Generate a key for rate limiting
rate_key = f"rate:{client_ip}:{int(time.time() / 60)}" # Per minute
# Simple in-memory rate limiting (use Redis in production)
if not hasattr(app, "rate_limits"):
app.rate_limits = {}
current_count = app.rate_limits.get(rate_key, 0)
max_requests = 100 # 100 requests per minute
if current_count >= max_requests:
return None
app.rate_limits[rate_key] = current_count + 1
return {"client_ip": client_ip}
# Middleware for request logging
@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()
request_id = str(uuid.uuid4())
# Add request ID to request state
request.state.request_id = request_id
response = await call_next(request)
process_time = time.time() - start_time
# Add request ID to response headers
response.headers["X-Request-ID"] = request_id
return response
# Health check endpoint
@app.get("/gateway/health")
async def health_check():
"""Health check endpoint for the gateway"""
service_status = {}
for service_name, service_url in SERVICES.items():
try:
# In production, you would actually check service health
service_status[service_name] = {
"status": "up",
"status_code": 200
}
except Exception as e:
service_status[service_name] = {
"status": "down",
"error": str(e)
}
return {
"status": "ok",
"timestamp": time.time(),
"services": service_status
}
# Main proxy route
@app.api_route("/{service}/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
@gateway_rate_limit
@gateway_auth_shield
async def proxy_to_service(
service: str,
path: str,
request: Request,
auth_data = ShieldedDepends(lambda auth_data: auth_data),
):
"""
Main proxy handler that forwards requests to the appropriate microservice.
Applies authentication and authorization before forwarding.
"""
# Check if service exists
if service not in SERVICES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Service '{service}' not found"
)
# Get service URL
service_url = SERVICES[service]
target_url = f"{service_url}/{path}"
# Get request body if present
body = await request.body()
# Extract headers and filter out host-specific headers
headers = dict(request.headers)
headers_to_remove = ["host", "content-length", "connection"]
for header in headers_to_remove:
if header in headers:
del headers[header]
# Add custom gateway headers
headers["X-Gateway-Request-ID"] = request.state.request_id
if auth_data:
# Add user context for the microservice
headers["X-User-ID"] = auth_data.get("user_id", "")
headers["X-User-Roles"] = ",".join(auth_data.get("roles", []))
# Keep the original auth token
if "Authorization" not in headers and "token" in auth_data:
headers["Authorization"] = f"Bearer {auth_data['token']}"
# Forward the request to the target service
try:
method = request.method.lower()
request_kwargs = {
"headers": headers,
"params": dict(request.query_params),
"timeout": 30.0 # 30 seconds timeout
}
if body:
request_kwargs["content"] = body
# Make the request to the microservice
if method == "get":
response = await http_client.get(target_url, **request_kwargs)
elif method == "post":
response = await http_client.post(target_url, **request_kwargs)
elif method == "put":
response = await http_client.put(target_url, **request_kwargs)
elif method == "delete":
response = await http_client.delete(target_url, **request_kwargs)
elif method == "patch":
response = await http_client.patch(target_url, **request_kwargs)
else:
raise HTTPException(
status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
detail=f"Method {method} not allowed"
)
# Create a FastAPI response from the microservice response
return Response(
content=response.content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type")
)
except httpx.RequestError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Error communicating with service: {str(e)}"
)
# Helper function to create test tokens
def create_test_token(user_id: str, roles: list = None, permissions: list = None):
"""Create a JWT token for testing"""
payload = {
"sub": user_id,
"roles": roles or ["user"],
"permissions": permissions or [],
"exp": datetime.utcnow() + timedelta(hours=1)
}
return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
The API Gateway implements several critical features essential for modern microservice architectures. It provides service discovery capabilities, intelligently routing incoming requests to their appropriate microservice destinations. Authentication is handled through JWT token validation, ensuring secure request forwarding. The gateway maintains comprehensive request logging with unique identifiers for each request, while also managing headers to provide necessary user context to downstream microservices. Health monitoring capabilities allow for real-time status tracking of all services, and built-in rate limiting protects the system from potential abuse. The gateway also implements robust error handling to gracefully manage service communication failures.
The request flow through the gateway follows a well-defined sequence. Initially, each request undergoes rate limiting checks to prevent abuse. This is followed by JWT token validation for authentication. The gateway then determines the appropriate target microservice through service routing. During header processing, it enriches the request with gateway headers and user context information. The request is then forwarded to the target microservice, and finally, the microservice's response is processed and returned to the client.
The implementation includes thorough testing suites that validate various aspects of the system. These tests cover authentication flows, including scenarios with valid, invalid, and expired tokens. Authorization testing ensures proper handling of permissions and roles, while rate limiting tests verify behavior both within and exceeding limits. The test suite also includes comprehensive error handling scenarios, covering network errors and invalid requests, as well as edge cases involving missing resources and malformed data. These tests not only validate the system but also serve as living documentation of expected behavior.
When deploying to production, several important considerations must be addressed. Redis should be implemented for rate limiting and token blacklisting functionality. Proper logging with structured formats is essential for system monitoring and debugging. The system should include monitoring and alerting mechanisms for shield failures, while configuration should be managed through environment variables. Circuit breakers should be implemented for external service calls to prevent cascading failures. Comprehensive error handling with appropriate status codes is crucial, along with HTTPS implementation and secure token storage. Token refresh mechanisms should be in place, and request/response validation schemas should be implemented. Finally, continuous monitoring of shield execution performance is necessary to maintain system health. These real-world examples provide a robust foundation for building secure and scalable APIs using FastAPI Shield.