Advanced Examples
This section provides advanced examples of using FastAPI Shield for more complex security and validation requirements. These examples have been thoroughly tested and validated.
Chained Shield Processing
Creating shields that work together and pass data between them in a chain:
from fastapi import FastAPI, Header, HTTPException, status, Depends
from fastapi_shield import shield, ShieldedDepends
import jwt
from typing import Dict, Optional, List
from datetime import datetime, timedelta
# Application instance that the example routes below are registered on
app = FastAPI()
# Configuration
# NOTE(review): placeholder secret for the example — supply a real secret from configuration before any real use
JWT_SECRET = "your-secret-key"
# Algorithm name handed to jwt.encode / jwt.decode in this example
JWT_ALGORITHM = "HS256"
class AuthData:
    """Authenticated caller's identity together with its roles and permissions."""

    def __init__(self, user_id: str, roles: List[str], permissions: List[str]):
        self.user_id, self.roles, self.permissions = user_id, roles, permissions

    def has_role(self, role: str) -> bool:
        """Return True when ``role`` is among the caller's roles."""
        granted = self.roles
        return role in granted

    def has_permission(self, permission: str) -> bool:
        """Return True when ``permission`` is among the caller's permissions."""
        granted = self.permissions
        return permission in granted
@shield(name="JWT Authentication")
async def jwt_auth_shield(authorization: str = Header()) -> Optional[Dict]:
    """Validate the bearer JWT in the Authorization header and return its payload.

    Returns:
        The decoded payload when the token decodes successfully, or None
        when the header does not start with "Bearer " or decoding fails.
    """
    scheme = "Bearer "
    if not authorization.startswith(scheme):
        return None
    # Strip only the leading scheme. str.replace() would also delete any
    # later "Bearer " occurrence inside the credential itself.
    token = authorization[len(scheme):]
    try:
        return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except Exception:
        # Deliberately fail closed: any decode/verification problem denies access.
        return None
@shield(name="User Role Extraction")
async def role_extraction_shield(payload = ShieldedDepends(lambda payload: payload)) -> Optional[AuthData]:
    """Build an AuthData from the payload handed over by the preceding shield.

    Returns None when the payload is empty or carries no "user_id" key.
    """
    has_identity = bool(payload) and "user_id" in payload
    if not has_identity:
        return None
    return AuthData(
        payload.get("user_id"),
        payload.get("roles", []),
        payload.get("permissions", []),
    )
def require_role(role: str):
    """Build a shield that only passes callers whose AuthData holds ``role``."""
    forbidden = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Requires role: {role}"
    )

    @shield(name=f"Role Requirement ({role})", exception_to_raise_if_fail=forbidden)
    async def role_shield(auth_data: AuthData = ShieldedDepends(lambda auth_data: auth_data)) -> Optional[AuthData]:
        # None signals failure to the shield (see exception_to_raise_if_fail).
        return auth_data if auth_data.has_role(role) else None

    return role_shield
def require_permission(permission: str):
    """Build a shield that only passes callers whose AuthData holds ``permission``."""
    forbidden = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Requires permission: {permission}"
    )

    @shield(name=f"Permission Requirement ({permission})", exception_to_raise_if_fail=forbidden)
    async def permission_shield(auth_data: AuthData = ShieldedDepends(lambda auth_data: auth_data)) -> Optional[AuthData]:
        # None signals failure to the shield (see exception_to_raise_if_fail).
        return auth_data if auth_data.has_permission(permission) else None

    return permission_shield
# Endpoints with chained shield protection
@app.get("/user-profile")
@jwt_auth_shield
@role_extraction_shield
async def user_profile(auth_data: AuthData = ShieldedDepends(lambda auth_data: auth_data)):
    # Echo back the identity produced by the shield chain.
    profile = {"user_id": auth_data.user_id}
    profile["roles"] = auth_data.roles
    profile["permissions"] = auth_data.permissions
    return profile
@app.get("/admin-panel")
@jwt_auth_shield
@role_extraction_shield
@require_role("admin")
async def admin_panel(auth_data: AuthData = ShieldedDepends(lambda auth_data: auth_data)):
    # Only reached once the "admin" role shield has passed.
    caller_id = auth_data.user_id
    return {"message": "Welcome to admin panel", "user_id": caller_id}
@app.post("/user-management")
@jwt_auth_shield
@role_extraction_shield
@require_permission("manage_users")
async def user_management(auth_data: AuthData = ShieldedDepends(lambda auth_data: auth_data)):
    # Only reached once the "manage_users" permission shield has passed.
    caller_id = auth_data.user_id
    return {"message": "User management access granted", "user_id": caller_id}
# Helper function to create test tokens
def create_test_token(user_id: str, roles: List[str] = None, permissions: List[str] = None):
    """Return a signed JWT for ``user_id`` that expires one hour from now.

    Missing (or empty) ``roles`` / ``permissions`` are encoded as empty lists.
    """
    claims = {
        "user_id": user_id,
        "roles": roles if roles else [],
        "permissions": permissions if permissions else [],
        "exp": datetime.utcnow() + timedelta(hours=1),
    }
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
Key Features of Chained Processing
- Sequential Execution: Shields execute in sequence, with each shield depending on the previous one
- Data Flow: Each shield can access and transform data from previous shields using ShieldedDepends
- Validation Chain: If any shield in the chain fails, the entire request is rejected
- Flexible Composition: Shield factories allow dynamic creation of requirement shields
Dynamic Shield Configuration with Database
Loading shield configuration and user data from a database:
from fastapi import FastAPI, Depends, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends
from pydantic import BaseModel
from typing import List, Optional
# Application instance that the example routes below are registered on
app = FastAPI()
# Mock database storage (in production, use actual database)
# API-key string -> record dict with "user_id", "key" and "is_active" entries
api_keys_data = {}
# user id -> list of permission-name strings
user_permissions_data = {}
# Models
class Permission(BaseModel):
    # Name of a single granted permission, e.g. "admin_access"
    permission: str
class User(BaseModel):
    # Identifier of the user the API key belongs to
    id: int
    # Permissions loaded for this user at request time
    permissions: List[Permission]
# Database dependency (mock)
async def get_db():
    """Stand-in database dependency; a production app would return a real database session."""
    connection = None  # Mock database connection
    return connection
@shield(name="Database Auth Shield")
async def db_auth_shield(api_key: str = Header(), db = Depends(get_db)):
    """Authenticate the caller by looking up its API key record.

    Returns the record's user id and key for a known, active key; None otherwise.
    """
    # Mock database query - in production, query actual database
    record = api_keys_data.get(api_key)
    if record and record.get("is_active", False):
        return {"user_id": record["user_id"], "key": record["key"]}
    return None
@shield(name="Permission Shield")
async def permission_shield(auth_data = ShieldedDepends(lambda auth_data: auth_data), db = Depends(get_db)):
    """Turn the authenticated key record into a User carrying its stored permissions."""
    if not auth_data:
        return None
    uid = auth_data["user_id"]
    # Mock database query for permissions - in production, query actual database
    granted = []
    for perm_name in user_permissions_data.get(uid, []):
        granted.append(Permission(permission=perm_name))
    return User(id=uid, permissions=granted)
def require_db_permission(required_permission: str):
    """Build a shield that only passes users holding ``required_permission``."""
    forbidden = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Requires permission: {required_permission}"
    )

    @shield(
        name=f"DB Permission Requirement ({required_permission})",
        exception_to_raise_if_fail=forbidden
    )
    async def permission_check(user: User = ShieldedDepends(lambda auth_data: auth_data)):
        # Pass the user through on the first matching permission; None otherwise.
        for entry in user.permissions:
            if entry.permission == required_permission:
                return user
        return None

    return permission_check
@app.get("/db-protected")
@db_auth_shield
@permission_shield
async def db_protected_endpoint(user: User = ShieldedDepends(lambda auth_data: auth_data)):
    # Flatten the Permission models back to plain permission names for the response.
    names = []
    for entry in user.permissions:
        names.append(entry.permission)
    return {
        "message": "Access granted via database authentication",
        "user_id": user.id,
        "permissions": names,
    }
@app.get("/admin-access")
@db_auth_shield
@permission_shield
@require_db_permission("admin_access")
async def admin_access_endpoint(user: User = ShieldedDepends(lambda auth_data: auth_data)):
    # Only reached once the "admin_access" permission shield has passed.
    caller_id = user.id
    return {"message": "Admin access granted", "user_id": caller_id}
# Setup functions for testing/demo
def setup_test_data():
    """Seed the in-memory stores with one active API key and its user's permissions."""
    # Mock API key
    demo_key = {"user_id": 1, "key": "valid_key", "is_active": True}
    api_keys_data["valid_key"] = demo_key
    # Mock user permissions
    demo_permissions = ["read_data", "write_data", "admin_access"]
    user_permissions_data[1] = demo_permissions


# Call setup for demo
setup_test_data()
Database Integration Features
- Dynamic Configuration: Authentication and authorization rules stored in database
- Runtime Permissions: User permissions loaded at request time from database
- Scalable Architecture: Easy to add new permissions without code changes
- Audit Trail: All access attempts can be logged to database
OAuth2 Integration
Integrating FastAPI Shield with OAuth2 authentication flows:
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi_shield import shield, ShieldedDepends
from pydantic import BaseModel
from typing import Dict, Optional
import jwt
from datetime import datetime, timedelta
# OAuth2 configuration
# Bearer-token scheme; tokenUrl names the relative URL where clients obtain a token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# NOTE(review): placeholder signing secret for the example — replace before any real use
SECRET_KEY = "your-secret-key"
# Algorithm name handed to jwt.encode / jwt.decode in this example
ALGORITHM = "HS256"
# Token lifetime in minutes, applied when the /token endpoint issues a token
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Mock user database
# Keyed by username. NOTE(review): in this example the "hashed_password" values
# are compared directly against the submitted password — they are not real hashes.
USERS_DB = {
"johndoe": {
"username": "johndoe",
"full_name": "John Doe",
"email": "john@example.com",
"hashed_password": "fakehashedsecret",
"roles": ["user"]
},
"alice": {
"username": "alice",
"full_name": "Alice Smith",
"email": "alice@example.com",
"hashed_password": "fakehashedsecret2",
"roles": ["user", "admin"]
}
}
# Models
class Token(BaseModel):
    # Encoded JWT returned to the client
    access_token: str
    # Token scheme reported to the client ("bearer" in this example)
    token_type: str
class TokenData(BaseModel):
    # Username taken from the token's "sub" claim
    username: Optional[str] = None
    # Role names taken from the token's "roles" claim
    roles: list[str] = []
class User(BaseModel):
    # Login name; also the key into the mock user database
    username: str
    # Optional contact and display details
    email: Optional[str] = None
    full_name: Optional[str] = None
    # Role names assigned to the user
    roles: list[str] = []
# Application instance that the OAuth2 example routes below are registered on
app = FastAPI()
# Helper functions
def verify_password(plain_password, hashed_password):
    """Verify password (simplified for example).

    The example stores passwords unhashed, so this is a direct equality
    check. String inputs are compared in constant time so the check does
    not leak how many leading characters matched.

    Args:
        plain_password: Password submitted by the client.
        hashed_password: Stored value to compare against.

    Returns:
        True when the two values are equal, otherwise False.
    """
    import hmac

    if isinstance(plain_password, str) and isinstance(hashed_password, str):
        # compare_digest only accepts ASCII str or bytes, so compare the UTF-8 bytes.
        return hmac.compare_digest(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    # Non-string inputs keep the original plain-equality behaviour.
    return plain_password == hashed_password
def get_user(db, username: str):
    """Look up ``username`` in ``db`` and return it as a User, or None when absent."""
    if username not in db:
        return None
    return User(**db[username])
def authenticate_user(fake_db, username: str, password: str):
    """Return the User for a matching username/password pair, otherwise False."""
    account = get_user(fake_db, username)
    if account and verify_password(password, fake_db[username]["hashed_password"]):
        return account
    return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT holding a copy of ``data`` plus an "exp" claim.

    The token expires ``expires_delta`` from now, or after 15 minutes when no
    (or a zero) delta is given.
    """
    claims = data.copy()
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    claims["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 token endpoint"""
    user = authenticate_user(USERS_DB, form_data.username, form_data.password)
    if user:
        # Valid credentials: issue a token carrying the username and roles.
        token = create_access_token(
            data={"sub": user.username, "roles": user.roles},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        return {"access_token": token, "token_type": "bearer"}
    # Unknown user or wrong password.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
@shield(name="OAuth2 Shield")
async def oauth2_shield(token: str = Depends(oauth2_scheme)):
    """Decode the bearer token and resolve it to a known user.

    Returns a dict with the "user" and its "token_data", or None when the
    token fails to decode, has no "sub" claim, or names an unknown user.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject: str = claims.get("sub")
        if subject is None:
            return None
        token_data = TokenData(username=subject, roles=claims.get("roles", []))
    except Exception:
        return None
    user = get_user(USERS_DB, username=token_data.username)
    # Return both user and token data
    return None if user is None else {"user": user, "token_data": token_data}
def require_oauth2_role(role: str):
    """Build a shield that only passes tokens whose roles include ``role``."""
    forbidden = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Role {role} required"
    )

    @shield(name=f"OAuth2 Role ({role})", exception_to_raise_if_fail=forbidden)
    async def role_check(oauth_data = ShieldedDepends(lambda payload: payload)):
        # None signals failure to the shield (see exception_to_raise_if_fail).
        granted_roles = oauth_data["token_data"].roles
        return oauth_data if role in granted_roles else None

    return role_check
@app.get("/users/me")
@oauth2_shield
async def read_users_me(oauth_data = ShieldedDepends(lambda oauth_data: oauth_data)):
    """Get current user information"""
    # The shield already resolved the token to a user; hand it straight back.
    return oauth_data["user"]
@app.get("/admin/settings")
@oauth2_shield
@require_oauth2_role("admin")
async def admin_settings(oauth_data = ShieldedDepends(lambda oauth_data: oauth_data)):
    """Admin-only endpoint"""
    current_user = oauth_data["user"]
    return {"message": "Admin settings", "user": current_user.username}
# Helper function to create test tokens
def create_oauth2_token(username: str, roles: Optional[list[str]] = None):
    """Create a test OAuth2 JWT token.

    Args:
        username: Value stored in the token's "sub" claim.
        roles: Role names for the "roles" claim; defaults to ["user"] when
            omitted or empty.

    Returns:
        The encoded JWT, expiring one hour from now.
    """
    # Annotated with Optional/list because this snippet imports only Dict and
    # Optional from typing; the previous ``List`` annotation was an undefined name.
    payload = {
        "sub": username,
        "roles": roles or ["user"],
        "exp": datetime.utcnow() + timedelta(hours=1)
    }
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
OAuth2 Integration Features
- Standard OAuth2 Flow: Implements proper OAuth2 password bearer flow
- Token Validation: Comprehensive JWT token validation with expiry checking
- Role-Based Access: Dynamic role checking through shield factories
- User Context: Full user information available in protected endpoints
- Token Generation: Secure token generation with configurable expiry
Advanced Error Handling
Robust error handling in shield implementations:
from fastapi import FastAPI, Header
from fastapi_shield import shield, ShieldedDepends
import asyncio
# Application instance that the error-handling example route is registered on
app = FastAPI()
@shield(name="Error Test Shield")
async def error_shield(test_mode: str = Header(None)):
    """Exercise different shield outcomes, selected by the ``test_mode`` header."""
    if test_mode == "exception":
        # Raise from inside the shield body to simulate a failing shield.
        raise ValueError("Simulated shield error")
    if test_mode == "timeout":
        # Simulate a slow operation before granting access.
        await asyncio.sleep(0.1)
        return {"mode": "timeout"}
    if test_mode == "valid":
        return {"mode": "valid"}
    # "none", any unrecognised mode, and a missing header all block access.
    return None
@app.get("/error-test")
@error_shield
async def error_test_endpoint(data = ShieldedDepends(lambda data: data)):
    # Report success together with whatever the shield returned.
    body = {"message": "Success"}
    body["data"] = data
    return body
Complex Shield Composition
Advanced patterns for composing multiple shields:
from fastapi import FastAPI, Header
from fastapi_shield import shield, ShieldedDepends
# Application instance that the composition example route is registered on
app = FastAPI()
@shield(name="Shield A")
async def shield_a(value_a: str = Header(None)):
    """Head of the chain: passes only when ``value_a`` equals "valid_a"."""
    return {"from_a": "data_a"} if value_a == "valid_a" else None
@shield(name="Shield B")
async def shield_b(
    value_b: str = Header(None),
    data_from_a = ShieldedDepends(lambda data: data)
):
    """Middle of the chain: needs "valid_b" plus data handed over from Shield A."""
    if value_b != "valid_b" or not data_from_a:
        return None
    # Wrap the upstream data so it keeps flowing down the chain.
    return {"from_b": "data_b", "chain_data": data_from_a}
@shield(name="Shield C")
async def shield_c(
    value_c: str = Header(None),
    data_from_b = ShieldedDepends(lambda data: data)
):
    """End of the chain: needs "valid_c" plus data handed over from Shield B."""
    if value_c != "valid_c" or not data_from_b:
        return None
    # Wrap the upstream data so the endpoint sees the whole chain.
    return {"from_c": "data_c", "chain_data": data_from_b}
@app.get("/triple-shield")
@shield_a
@shield_b
@shield_c
async def triple_shield_endpoint(final_data = ShieldedDepends(lambda data: data)):
    """Endpoint protected by three chained shields"""
    body = {"message": "Triple shield success"}
    body["final_data"] = final_data
    return body
Composition Features
- Data Chaining: Each shield can access and transform data from previous shields
- Fail-Fast: If any shield in the chain fails, the entire request is rejected
- Flexible Ordering: Shields execute in decorator order (top to bottom), so the shield listed first runs first and passes its data to the shields beneath it
- Context Preservation: Rich context data flows through the entire chain
Testing Your Advanced Implementations
These examples include comprehensive testing patterns:
import pytest
from fastapi.testclient import TestClient
def test_chained_shields():
    """Check /admin-panel with and without the required "admin" role."""
    client = TestClient(app)

    def admin_panel_status(roles, permissions):
        # Issue a token for user123 and report the status code /admin-panel returns.
        token = create_test_token("user123", roles, permissions)
        reply = client.get("/admin-panel", headers={"Authorization": f"Bearer {token}"})
        return reply.status_code

    # Test with valid token
    assert admin_panel_status(["admin"], ["manage_users"]) == 200
    # Test without required role
    assert admin_panel_status(["user"], ["read_profile"]) == 403
def test_database_shields():
    """Check /admin-access using a seeded API key whose user holds "admin_access"."""
    client = TestClient(app)
    # Setup test data
    key_record = {"user_id": 1, "key": "test_key", "is_active": True}
    api_keys_data["test_key"] = key_record
    user_permissions_data[1] = ["admin_access"]
    reply = client.get("/admin-access", headers={"api-key": "test_key"})
    assert reply.status_code == 200
def test_oauth2_integration():
    """Log in through /token, then call /users/me with the issued token."""
    client = TestClient(app)
    # Test token generation
    credentials = {"username": "johndoe", "password": "fakehashedsecret"}
    login = client.post("/token", data=credentials)
    assert login.status_code == 200
    # Test protected endpoint
    bearer = login.json()["access_token"]
    me = client.get("/users/me", headers={"Authorization": f"Bearer {bearer}"})
    assert me.status_code == 200
Production Considerations
When implementing these advanced patterns in production:
- Error Handling: Implement comprehensive error logging and monitoring
- Performance: Cache database queries and token validations where appropriate
- Security: Use proper password hashing and secure JWT secrets
- Scalability: Consider distributed caching for multi-instance deployments
- Monitoring: Track shield execution times and failure rates
- Testing: Implement comprehensive test suites covering all edge cases
- Documentation: Maintain clear documentation of shield dependencies and data flow
These advanced examples demonstrate the full power and flexibility of FastAPI Shield for complex authentication and authorization scenarios.