shield Function
The shield function is the main entry point for creating Shield instances. It can be used as a decorator with or without parameters, or as a factory function.
Overview
The shield function provides a convenient interface for creating Shield instances. It supports multiple usage patterns:
- Direct decorator usage: @shield
- Parameterized decorator: @shield(name="Auth", auto_error=False)
- Factory function: auth_shield = shield(my_function, name="Auth")
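The three call styles above can all be served by one function that checks whether it received the target function directly. The following is a minimal sketch of that general Python pattern, not fastapi-shield's actual implementation; ToyShield and toy_shield are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Optional


class ToyShield:
    """Stand-in for a Shield instance: stores the function and its options."""

    def __init__(self, func: Callable[..., Any], name: Optional[str] = None,
                 auto_error: bool = True) -> None:
        self.func = func
        self.name = name
        self.auto_error = auto_error


def toy_shield(func: Optional[Callable[..., Any]] = None, *,
               name: Optional[str] = None, auto_error: bool = True):
    """Wrap func at once, or return a decorator when only options are given."""
    if func is None:
        # Reached via @toy_shield(name=...): the function arrives in a second call.
        def decorator(inner: Callable[..., Any]) -> ToyShield:
            return ToyShield(inner, name=name, auto_error=auto_error)
        return decorator
    # Reached via bare @toy_shield or toy_shield(func, name=...).
    return ToyShield(func, name=name, auto_error=auto_error)


@toy_shield
def bare(token):
    return token or None


@toy_shield(name="Auth", auto_error=False)
def parameterized(token):
    return token or None


def plain_function(token):
    return token or None


factory = toy_shield(plain_function, name="Auth")
```

In all three cases the resulting name is bound to a ToyShield object rather than to the original function, which mirrors how the decorated names in this page are later applied to endpoints.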
Function Reference
Factory function and decorator for creating Shield instances.
This is the main entry point for creating shields. It can be used as a decorator with or without parameters, or as a factory function to create Shield instances.
The function passed to shield (the shield function) should accept parameters that match the FastAPI endpoint's parameters (such as the request, path parameters, and query parameters) and return either:
- Truthy data (dict, object, etc.) to allow the request and provide data to ShieldedDepends dependencies
- None or False to block the request
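The return contract can be illustrated with a small stand-alone sketch. It assumes the allow/block decision is a plain truthiness test, which is an assumption made here rather than something the text above confirms; is_allowed and check_token are hypothetical helpers.

```python
from typing import Any, Optional


def is_allowed(shield_result: Any) -> bool:
    """Assumed rule: a truthy result allows the request, a falsy one blocks it."""
    return bool(shield_result)


def check_token(token: Optional[str]) -> Optional[dict]:
    """Hypothetical validation function that follows the return contract."""
    if token == "secret-token":
        return {"user_id": 123}  # truthy: request allowed, data passed along
    return None                  # request blocked


allowed = is_allowed(check_token("secret-token"))
blocked = is_allowed(check_token(None))
empty_dict_allowed = is_allowed({})
```

Under this assumption an empty dict is falsy and would not allow the request, so a shield function that wants to allow a request without carrying data is safer returning a non-empty value such as True.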
Parameters:
Name | Type | Description | Default |
---|---|---|---|
shield_func | Optional[U] | The validation function to use. If None, returns a decorator function that accepts the shield function. | None |
name | str | Human-readable name for the shield (used in error messages). | None |
auto_error | bool | Whether to raise HTTP exceptions on shield failure. | True |
exception_to_raise_if_fail | Optional[HTTPException] | Custom exception to raise on failure. | None |
default_response_to_return_if_fail | Optional[Response] | Custom response when auto_error=False. | None |
Returns:
Type | Description |
---|---|
Shield[U] | Shield[U]: A |
Examples:
# Basic usage as decorator
@shield
def auth_shield(request: Request) -> Optional[dict]:
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123}
    return None

# With parameters
@shield(name="Authentication", auto_error=False)
def auth_shield(request: Request) -> Optional[dict]:
    # validation logic
    pass

# As factory function
auth_shield = shield(
    my_auth_function,
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(401, "Unauthorized")
)

# Apply to endpoints
@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}

# Shield with path parameters
@shield
def user_ownership_shield(request: Request, user_id: int) -> Optional[dict]:
    current_user = get_current_user_from_token(request)
    if current_user.id == user_id or current_user.is_admin:
        return {"current_user": current_user}
    return None
Note
Shield functions can be sync or async. The shield system handles both transparently. Shield functions should use type hints for better integration with FastAPI's dependency injection.
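How the library handles both kinds of function internally is not shown on this page. One common way to treat sync and async callables uniformly is to call the function and await the result only when it is awaitable; the sketch below illustrates that idea with the standard library, using hypothetical names (call_shield_func, sync_check, async_check).

```python
import asyncio
import inspect
from typing import Any, Callable


async def call_shield_func(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call func and await its result only if that result is awaitable."""
    result = func(*args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result


def sync_check(token: str):
    """Synchronous validation stand-in."""
    return {"user": token} if token == "ok" else None


async def async_check(token: str):
    """Asynchronous validation stand-in."""
    await asyncio.sleep(0)  # stands in for real I/O
    return {"user": token} if token == "ok" else None


sync_result = asyncio.run(call_shield_func(sync_check, "ok"))
async_result = asyncio.run(call_shield_func(async_check, "ok"))
rejected = asyncio.run(call_shield_func(async_check, "bad"))
```

With a wrapper like this, the caller does not need to know which kind of function it was given, which is the behavior the note describes.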
Usage Patterns
Direct Decorator Usage
The simplest way to create a shield:
from fastapi import FastAPI, Request
from fastapi_shield import shield

app = FastAPI()

@shield
def auth_shield(request: Request) -> dict | None:
    """Validate authentication token."""
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123, "username": "john"}
    return None

@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}
Parameterized Decorator
Configure shield behavior with parameters:
@shield(name="Authentication", auto_error=False)
def auth_shield(request: Request) -> dict | None:
    """Named shield with custom error handling."""
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123}
    return None

@shield(
    name="RateLimit",
    exception_to_raise_if_fail=HTTPException(429, "Rate limit exceeded")
)
async def rate_limit_shield(request: Request) -> dict | None:
    """Rate limiting shield with custom exception."""
    client_ip = request.client.host
    if await check_rate_limit(client_ip):
        return {"client_ip": client_ip}
    return None
Factory Function Usage
Create shield instances programmatically:
def auth_validation(request: Request) -> dict | None:
    """Authentication validation logic."""
    # Validation logic here
    return user_data or None

# Create shield instance
auth_shield = shield(
    auth_validation,
    name="Authentication",
    auto_error=True,
    exception_to_raise_if_fail=HTTPException(401, "Unauthorized")
)

# Apply to endpoints
@app.get("/profile")
@auth_shield
def get_profile():
    return {"profile": "user data"}
Dynamic Shield Creation
Create shields based on configuration:
def create_auth_shield(config: dict):
    """Factory for creating authentication shields."""
    def auth_function(request: Request) -> dict | None:
        # Use config to determine validation method
        if config["method"] == "jwt":
            return validate_jwt(request)
        elif config["method"] == "api_key":
            return validate_api_key(request)
        return None

    return shield(
        auth_function,
        name=config.get("name", "Authentication"),
        auto_error=config.get("auto_error", True)
    )

# Create shields from configuration
jwt_config = {"method": "jwt", "name": "JWT Auth"}
api_key_config = {"method": "api_key", "name": "API Key Auth"}
jwt_shield = create_auth_shield(jwt_config)
api_key_shield = create_auth_shield(api_key_config)
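In the configuration-driven factory, an unrecognized method falls through to return None, so a typo in the configuration would block every request at runtime. A possible variation is to resolve the validator once, when the shield is created, and fail early on unknown methods. The sketch below shows only that lookup step; the validators are hypothetical stand-ins that take a plain headers dict instead of a Request, and make_auth_function is not part of the library.

```python
from typing import Callable, Dict, Optional

Validator = Callable[[dict], Optional[dict]]


def validate_jwt(headers: dict) -> Optional[dict]:
    """Stand-in JWT check: accepts any Bearer token."""
    token = headers.get("Authorization", "")
    return {"method": "jwt"} if token.startswith("Bearer ") else None


def validate_api_key(headers: dict) -> Optional[dict]:
    """Stand-in API key check: accepts any non-empty key."""
    return {"method": "api_key"} if headers.get("X-API-Key") else None


VALIDATORS: Dict[str, Validator] = {
    "jwt": validate_jwt,
    "api_key": validate_api_key,
}


def make_auth_function(config: dict) -> Validator:
    """Pick the validator at creation time; reject unknown methods immediately."""
    method = config["method"]
    try:
        return VALIDATORS[method]
    except KeyError:
        raise ValueError(f"Unsupported auth method: {method!r}") from None


jwt_auth = make_auth_function({"method": "jwt"})
jwt_ok = jwt_auth({"Authorization": "Bearer abc"})
jwt_missing = jwt_auth({})
```

The returned function could then be passed to shield in place of the inner auth_function.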
Advanced Patterns
Conditional Shield Application
from datetime import datetime

def create_conditional_shield(condition_func):
    """Create a shield that only applies under certain conditions."""
    @shield
    def conditional_shield(request: Request) -> dict | None:
        if condition_func(request):
            return perform_validation(request)
        return {"bypass": True}  # Allow request without validation
    return conditional_shield

# Example: Only validate during business hours
def is_business_hours(request: Request) -> bool:
    now = datetime.now()
    return 9 <= now.hour < 17  # 9 AM up to (not including) 5 PM, server local time

business_hours_auth = create_conditional_shield(is_business_hours)

@app.get("/api/data")
@business_hours_auth
def get_data():
    return {"data": "sensitive information"}
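Hour-based windows are easy to get off by one: a comparison such as now.hour <= 17 still matches 17:59. The sketch below is a testable variant of the condition that takes the time as an argument, which is a change from the request-based signature used above; it also ignores weekends and time zones, which a real deployment would likely need to handle.

```python
from datetime import datetime
from typing import Optional


def is_business_hours(now: Optional[datetime] = None) -> bool:
    """True from 09:00 up to, but not including, 17:00 (local time)."""
    now = now or datetime.now()
    return 9 <= now.hour < 17


at_open = is_business_hours(datetime(2024, 1, 1, 9, 0))
before_close = is_business_hours(datetime(2024, 1, 1, 16, 59))
at_close = is_business_hours(datetime(2024, 1, 1, 17, 0))
before_open = is_business_hours(datetime(2024, 1, 1, 8, 59))
```

Passing the clock value in makes the boundary behavior easy to check without patching datetime.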
Shield with Retry Logic
import asyncio

# TemporaryError and validate_token_async are assumed to be defined by the application.
@shield(name="ResilientAuth")
async def resilient_auth_shield(request: Request) -> dict | None:
    """Authentication shield with retry logic."""
    max_retries = 3
    for attempt in range(max_retries):
        try:
            token = request.headers.get("Authorization")
            user_data = await validate_token_async(token)
            if user_data:
                return user_data
        except TemporaryError:
            if attempt == max_retries - 1:
                break
            await asyncio.sleep(0.1 * (2 ** attempt))  # Exponential backoff
    return None
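The retry loop can be exercised outside FastAPI. The following is a minimal sketch under a few assumptions: TemporaryError is a hypothetical application exception, the validator is passed in as a zero-argument coroutine function, and the sleep function is injectable so the delays can be recorded. Unlike the shield above, this variant stops at the first definite answer and retries only on TemporaryError.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


class TemporaryError(Exception):
    """Hypothetical transient failure, e.g. an auth service timeout."""


async def validate_with_retry(
    validate: Callable[[], Awaitable[Optional[dict]]],
    max_retries: int = 3,
    base_delay: float = 0.1,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> Optional[dict]:
    """Retry only on TemporaryError, doubling the delay after each failure."""
    for attempt in range(max_retries):
        try:
            return await validate() or None
        except TemporaryError:
            if attempt == max_retries - 1:
                break
            await sleep(base_delay * (2 ** attempt))
    return None


calls = {"count": 0}
delays: List[float] = []


async def flaky_validate() -> Optional[dict]:
    """Fails twice with a transient error, then succeeds."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise TemporaryError("auth service unavailable")
    return {"user_id": 123}


async def record_delay(seconds: float) -> None:
    """Records the requested delay instead of sleeping."""
    delays.append(seconds)


result = asyncio.run(validate_with_retry(flaky_validate, sleep=record_delay))
```

Here the validator is called three times and the recorded delays are 0.1 and 0.2 seconds, matching the 0.1 * 2 ** attempt schedule.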
Multi-Environment Shields
def create_environment_specific_shield(environment: str):
    """Create shields with environment-specific behavior."""
    if environment == "development":
        @shield(name="DevAuth", auto_error=False)
        def dev_shield(request: Request) -> dict | None:
            # Relaxed validation for development
            return {"user_id": 1, "username": "dev_user"}
        return dev_shield
    elif environment == "production":
        @shield(
            name="ProdAuth",
            exception_to_raise_if_fail=HTTPException(401, "Unauthorized")
        )
        def prod_shield(request: Request) -> dict | None:
            # Strict validation for production
            return validate_production_token(request)
        return prod_shield
    # Fail loudly instead of returning None, which would break at decoration time
    raise ValueError(f"Unsupported environment: {environment!r}")

# Usage
import os
env = os.getenv("ENVIRONMENT", "development")
auth_shield = create_environment_specific_shield(env)
Shield Function Patterns
Path Parameter Validation
@shield
def ownership_shield(
    request: Request,
    user_id: int,  # Path parameter
    item_id: int   # Path parameter
) -> dict | None:
    """Validate user owns the item."""
    current_user = get_current_user_from_request(request)
    if current_user.id == user_id:
        item = get_item(item_id)
        if item and item.owner_id == user_id:
            return {"user": current_user, "item": item}
    return None

@app.get("/users/{user_id}/items/{item_id}")
@ownership_shield
def get_user_item(user_id: int, item_id: int):
    return {"message": "Access granted to item"}
Query Parameter Validation
from fastapi import Query

@shield
def api_version_shield(
    request: Request,
    version: str = Query(default="v1")  # Query parameter
) -> dict | None:
    """Validate API version compatibility."""
    supported_versions = ["v1", "v2", "v3"]
    if version in supported_versions:
        return {"api_version": version, "features": get_features(version)}
    return None

@app.get("/api/data")
@api_version_shield
def get_api_data(version: str = Query(default="v1")):
    return {"data": "API response", "version": version}
Request Body Validation
from pydantic import BaseModel

class AuthRequest(BaseModel):
    username: str
    password: str

@shield
def login_shield(request: Request, auth_data: AuthRequest) -> dict | None:
    """Validate login credentials."""
    if authenticate_user(auth_data.username, auth_data.password):
        return {"user_id": get_user_id(auth_data.username)}
    return None

@app.post("/login")
@login_shield
def login(auth_data: AuthRequest):
    return {"message": "Login successful"}
Error Handling Patterns
Custom Error Responses
from fastapi import Response

custom_auth_shield = shield(
    auth_function,
    name="CustomAuth",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content='{"error": "Authentication required", "code": "AUTH_001"}',
        status_code=401,
        headers={
            "Content-Type": "application/json",
            "WWW-Authenticate": "Bearer realm=\"API\""
        }
    )
)
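Hand-written JSON strings are easy to break when a message contains quotes or non-ASCII text. As a small sketch of an alternative, the body and headers for the response above can be built with the standard library and then passed as the content and headers arguments; error_body and error_headers are illustrative names only.

```python
import json

# Serialize the payload instead of writing the JSON string by hand.
error_body = json.dumps({"error": "Authentication required", "code": "AUTH_001"})

# Single quotes around the value avoid escaping the inner double quotes.
error_headers = {
    "Content-Type": "application/json",
    "WWW-Authenticate": 'Bearer realm="API"',
}

decoded = json.loads(error_body)
```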
Progressive Error Responses
def create_progressive_shield(max_attempts: int = 3):
    """Shield with progressive error responses."""
    @shield(name="ProgressiveAuth")
    def progressive_shield(request: Request) -> dict | None:
        client_ip = request.client.host
        attempts = get_failed_attempts(client_ip)
        if attempts >= max_attempts:
            # Temporary ban
            raise HTTPException(
                status_code=429,
                detail="Too many failed attempts. Try again later.",
                headers={"Retry-After": "300"}
            )
        auth_result = validate_request(request)
        if not auth_result:
            increment_failed_attempts(client_ip)
            if attempts >= max_attempts - 1:
                raise HTTPException(
                    status_code=401,
                    detail="Last attempt before temporary ban"
                )
            else:
                raise HTTPException(
                    status_code=401,
                    detail=f"Authentication failed. {max_attempts - attempts - 1} attempts remaining"
                )
        reset_failed_attempts(client_ip)
        return auth_result
    return progressive_shield
Best Practices
- Name your shields - Use descriptive names for better error messages and debugging
- Handle async properly - Use async shield functions for I/O operations
- Return structured data - Return dictionaries or objects that can be used by ShieldedDepends
- Fail fast - Validate early and return None/False to block invalid requests
- Use type hints - Provide proper type annotations for better IDE support
- Document shield behavior - Include docstrings explaining what the shield validates
See Also
- Shield Class - The underlying Shield class
- ShieldedDepends Factory - Factory for creating dependencies
- Authentication Patterns - Common authentication use cases