Shield Class
The `Shield` class is the core component of FastAPI Shield, providing request interception and validation.
Overview
The `Shield` class works as a decorator that wraps FastAPI endpoint functions. It intercepts each request before it reaches the endpoint, runs the validation logic, and then either allows or blocks the request based on the result.
Class Reference
Bases: Generic[U]
The main shield decorator class for request interception and validation.
`Shield` intercepts FastAPI requests before they reach the endpoint handlers, so it can enforce authentication, authorization, rate limiting, input sanitization, and any other request-level logic.
The `Shield` class works as a decorator that wraps endpoint functions. When a request is made to a shielded endpoint:
- The shield function is called first with the request parameters
- If the shield returns truthy data, the request proceeds, and the data returned by the shield is available to `ShieldedDepends` dependencies
- If the shield returns `None`/`False`, the request is blocked
Attributes:

| Name | Description |
|---|---|
| `auto_error` | Whether to raise HTTP exceptions on shield failure |
| `name` | Human-readable name for the shield (used in error messages) |
| `_guard_func` | The actual shield validation function |
| `_guard_func_is_async` | Whether the shield function is async |
| `_guard_func_params` | Parameters of the shield function |
| `_exception_to_raise_if_fail` | Exception to raise when the shield blocks a request |
| `_default_response_to_return_if_fail` | Response to return when `auto_error` is `False` |
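The `_guard_func_is_async` and `_guard_func_params` attributes describe the shape of the guard function. The sketch below shows how such metadata can be computed with the standard `inspect` module. It assumes this is roughly what the constructor records and is not taken from the library's source. `describe_guard` is a hypothetical helper:

```python
import inspect
from typing import Any, Callable


def describe_guard(guard_func: Callable[..., Any]) -> dict:
    """Hypothetical helper: compute metadata a shield could store about its guard."""
    return {
        # Plays the role of _guard_func_is_async
        "is_async": inspect.iscoroutinefunction(guard_func),
        # Plays the role of _guard_func_params: parameter name -> inspect.Parameter
        "params": dict(inspect.signature(guard_func).parameters),
    }


def sync_guard(request, user_id: int = 0):
    return {"user_id": user_id}


async def async_guard(request):
    return None


info = describe_guard(sync_guard)
print(info["is_async"], list(info["params"]))   # False ['request', 'user_id']
print(describe_guard(async_guard)["is_async"])  # True
```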
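Examples:

```python
from typing import Optional

from fastapi import FastAPI, Request, Response
from fastapi_shield import Shield, shield

app = FastAPI()


def check_auth(request: Request) -> Optional[dict]:
    # validate_token is a placeholder for your own token check
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123, "username": "john"}
    return None  # Block the request


# Basic authentication shield, built with the shield factory
auth_shield = shield(check_auth)

# Apply shield to endpoint
@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}


# Custom error handling: wrap the plain validation function,
# not the already-built auth_shield
auth_shield_custom = Shield(
    check_auth,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Authentication required",
        status_code=401,
    ),
)
```
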
`__init__(shield_func, *, name=None, auto_error=True, exception_to_raise_if_fail=None, default_response_to_return_if_fail=None)`
Initialize a new Shield instance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `shield_func` | `U` | The validation function to use for this shield. Should return truthy data to allow requests, or `None`/`False` to block them. Can be sync or async. | required |
| `name` | `str` | Human-readable name for the shield, used in error messages and logging. Defaults to `"unknown"` if not provided. | `None` |
| `auto_error` | `bool` | Whether to automatically raise an HTTP exception when the shield blocks a request. If `False`, the default response is returned instead. | `True` |
| `exception_to_raise_if_fail` | `Optional[HTTPException]` | Custom HTTP exception to raise when the shield blocks a request and `auto_error=True`. Defaults to a 500 error that includes the shield name. | `None` |
| `default_response_to_return_if_fail` | `Optional[Response]` | Custom response to return when the shield blocks a request and `auto_error=False`. Defaults to a 500 response that includes the shield name. | `None` |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If `shield_func` is not callable, or if the exception/response parameters are not of the correct types. |

Examples:

```python
from fastapi import HTTPException, Response
from fastapi_shield import Shield

# my_auth_function is a placeholder for your own validation function

# Basic shield
basic_shield = Shield(my_auth_function)

# Named shield with custom error
named_shield = Shield(
    my_auth_function,
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(401, "Authentication required"),
)

# Shield with custom response instead of exceptions
response_shield = Shield(
    my_auth_function,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Please log in",
        status_code=401,
        headers={"WWW-Authenticate": "Bearer"},
    ),
)
```

`__call__(endpoint)`
Apply the shield to a FastAPI endpoint function.
This method implements the decorator: it wraps the endpoint function with shield validation logic. When the returned wrapper is called, it:
- Extracts the parameters the shield function needs
- Calls the shield function with those parameters
- If the shield returns truthy data:
    - Resolves all endpoint dependencies
    - Injects the data returned by the shield into `ShieldedDepends` dependencies
    - Calls the original endpoint with the resolved parameters
- If the shield returns `None`/`False`:
    - Blocks the request by raising an exception or returning an error response

The wrapper handles both sync and async shield functions and endpoints automatically, and integrates with FastAPI's dependency injection system.
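The first step, extracting the parameters the shield function needs, can be sketched as a filter over the shield's signature. This assumes the wrapper already holds the resolved values FastAPI would pass to the endpoint. `extract_shield_kwargs` and the plain-string stand-ins for the `Request` and the database session are hypothetical:

```python
import inspect
from typing import Any, Callable


def extract_shield_kwargs(shield_func: Callable[..., Any], available: dict) -> dict:
    """Keep only the values whose names appear in the shield's signature.

    `available` stands in for the resolved values the endpoint would receive
    (request, path parameters, dependencies). Sketch only.
    """
    params = inspect.signature(shield_func).parameters
    return {name: value for name, value in available.items() if name in params}


def ownership_check(request, user_id: int):
    return {"owner": user_id}


available = {"request": "<Request>", "user_id": 42, "db": "<Session>"}
kwargs = extract_shield_kwargs(ownership_check, available)
print(kwargs)                     # {'request': '<Request>', 'user_id': 42}
print(ownership_check(**kwargs))  # {'owner': 42}
```

Filtering by name lets a shield declare only what it uses. Here the `db` value is resolved for the endpoint but never passed to `ownership_check`.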
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `endpoint` | `EndPointFunc` | The FastAPI endpoint function to protect with this shield. Can be sync or async. | required |

Returns:

| Type | Description |
|---|---|
| `EndPointFunc` | The wrapped endpoint function with shield protection. |

Raises:

| Type | Description |
|---|---|
| `AssertionError` | If `endpoint` is not callable. |

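Examples:

```python
from typing import Optional

from fastapi import FastAPI, Request
from fastapi_shield import shield

app = FastAPI()

@shield
def auth_shield(request: Request) -> Optional[dict]:
    # Validation logic (lookup_user is a placeholder for your own code)
    user_data = lookup_user(request.headers.get("Authorization"))
    return user_data or None

# Using as decorator
@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Success"}
```
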
Note: The wrapper function is always async, even if the original endpoint is sync, because dependency resolution in FastAPI is async.
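The always-async wrapper can be illustrated with a framework-free sketch. `simple_shield` and `Blocked` are hypothetical names, and the request is a plain dict. Dependency resolution and `ShieldedDepends` injection are omitted. The point is how a single `async` wrapper can call sync or async guards and endpoints:

```python
import asyncio
import functools
import inspect
from typing import Any, Callable, Optional


class Blocked(Exception):
    """Hypothetical stand-in for the HTTPException a real shield would raise."""


async def call_maybe_async(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call func and await the result only if it is awaitable."""
    result = func(*args, **kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result


def simple_shield(guard: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap an endpoint so that `guard` runs first (sketch only)."""

    def decorator(endpoint: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(endpoint)
        async def wrapper(request: dict, **kwargs: Any) -> Any:
            # The wrapper is async even when both guard and endpoint are sync
            data = await call_maybe_async(guard, request)
            if not data:  # None/False blocks the request
                raise Blocked(guard.__name__)
            return await call_maybe_async(endpoint, **kwargs)

        return wrapper

    return decorator


async def check_auth(request: dict) -> Optional[dict]:
    return {"user": "alice"} if request.get("token") == "ok" else None


@simple_shield(check_auth)
def protected() -> dict:
    return {"message": "Access granted"}


print(asyncio.run(protected({"token": "ok"})))  # {'message': 'Access granted'}
try:
    asyncio.run(protected({"token": "bad"}))
except Blocked as exc:
    print("blocked by", exc)  # blocked by check_auth
```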
`_raise_or_return_default_response()`
Handle shield failure by raising an exception or returning a default response.
This method is called when the shield blocks a request (returns `None`/`False`). The behavior depends on the `auto_error` setting:
- If `auto_error=True`: raises the configured HTTP exception
- If `auto_error=False`: returns the configured default response
Returns:

| Type | Description |
|---|---|
| `Response` | The default response if `auto_error=False` |

Raises:

| Type | Description |
|---|---|
| `HTTPException` | The configured exception if `auto_error=True` |
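The `auto_error` branch can be sketched with stand-in types. `HTTPError`, `FakeResponse`, and `ShieldFailurePolicy` are hypothetical substitutes for `HTTPException`, `Response`, and the shield's stored settings, and the fallback message text is invented. Only the raise-versus-return decision follows the documented behavior:

```python
from dataclasses import dataclass
from typing import Optional


class HTTPError(Exception):
    """Hypothetical stand-in for fastapi.HTTPException."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@dataclass
class FakeResponse:
    """Hypothetical stand-in for fastapi.Response."""

    content: str
    status_code: int


@dataclass
class ShieldFailurePolicy:
    """Hypothetical container for a shield's failure settings."""

    name: str = "unknown"
    auto_error: bool = True
    exception_to_raise_if_fail: Optional[HTTPError] = None
    default_response_to_return_if_fail: Optional[FakeResponse] = None

    def raise_or_return_default_response(self) -> FakeResponse:
        # Invented message text; the documented default is a 500 naming the shield
        message = f"Shield '{self.name}' blocked the request"
        if self.auto_error:
            if self.exception_to_raise_if_fail is not None:
                raise self.exception_to_raise_if_fail
            raise HTTPError(500, message)
        if self.default_response_to_return_if_fail is not None:
            return self.default_response_to_return_if_fail
        return FakeResponse(content=message, status_code=500)


strict = ShieldFailurePolicy(
    name="Authentication",
    exception_to_raise_if_fail=HTTPError(401, "Authentication required"),
)
lenient = ShieldFailurePolicy(name="Authentication", auto_error=False)

try:
    strict.raise_or_return_default_response()
except HTTPError as exc:
    print(exc.status_code, exc.detail)  # 401 Authentication required

print(lenient.raise_or_return_default_response().status_code)  # 500
```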
Usage Examples
Basic Shield

```python
from fastapi import FastAPI, Request
from fastapi_shield import Shield

app = FastAPI()


def validate_token(token: str | None) -> bool:
    """Hypothetical placeholder: replace with real token verification."""
    return token == "Bearer secret-token"


def auth_validation(request: Request) -> dict | None:
    """Validate authentication token."""
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123, "role": "user"}
    return None


# Create shield instance
auth_shield = Shield(auth_validation, name="Authentication")


@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}
```

Custom Error Handling

```python
from fastapi import HTTPException, Response
from fastapi_shield import Shield

# auth_validation is the function defined in the Basic Shield example

# Shield with custom exception
auth_shield = Shield(
    auth_validation,
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=401,
        detail="Authentication required",
    ),
)

# Shield with custom response (no exception)
auth_shield_no_error = Shield(
    auth_validation,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Please authenticate",
        status_code=401,
        headers={"WWW-Authenticate": "Bearer"},
    ),
)
```

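Async Shield Functions

```python
import redis.asyncio as redis  # replaces the deprecated aioredis package
from fastapi import FastAPI, Request
from fastapi_shield import Shield

app = FastAPI()

# Create one client at import time and reuse it across requests
redis_client = redis.from_url("redis://localhost")

RATE_LIMIT = 100  # maximum requests per window
WINDOW_SECONDS = 60


async def rate_limit_shield(request: Request) -> dict | None:
    """Fixed-window rate limiting with Redis."""
    client_ip = request.client.host if request.client else "unknown"
    key = f"rate_limit:{client_ip}"

    # Count this request; start the window on the first hit
    current_count = await redis_client.incr(key)
    if current_count == 1:
        await redis_client.expire(key, WINDOW_SECONDS)

    if current_count > RATE_LIMIT:
        return None  # Block the request
    return {"requests_remaining": RATE_LIMIT - current_count}


rate_limiter = Shield(rate_limit_shield, name="RateLimit")


@app.get("/api/data")
@rate_limiter
def get_data():
    return {"data": "sensitive information"}
```
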
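The Redis example implements a fixed-window counter. `INCR` counts requests per client IP, and `EXPIRE` on the first hit starts a 60-second window. For local testing without Redis, the same logic can be sketched in memory. `FixedWindowLimiter` is a hypothetical class, its counts live in a single process, and the injectable `clock` exists only to make the behavior testable:

```python
import time
from typing import Callable, Optional


class FixedWindowLimiter:
    """In-memory fixed-window counter mirroring the Redis INCR/EXPIRE pattern.

    Single-process only: counts are not shared across workers.
    """

    def __init__(self, limit: int = 100, window_seconds: float = 60.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.limit = limit
        self.window_seconds = window_seconds
        self.clock = clock
        self._counters: dict = {}  # key -> (count, window_expires_at)

    def hit(self, key: str) -> Optional[dict]:
        now = self.clock()
        count, expires_at = self._counters.get(key, (0, 0.0))
        if now >= expires_at:
            # Like EXPIRE after the first INCR: start a new window
            count, expires_at = 0, now + self.window_seconds
        count += 1
        self._counters[key] = (count, expires_at)
        if count > self.limit:
            return None  # Block, as the Redis shield does
        return {"requests_remaining": self.limit - count}


now = [0.0]
limiter = FixedWindowLimiter(limit=2, window_seconds=60, clock=lambda: now[0])
print(limiter.hit("1.2.3.4"))  # {'requests_remaining': 1}
print(limiter.hit("1.2.3.4"))  # {'requests_remaining': 0}
print(limiter.hit("1.2.3.4"))  # None
now[0] = 61.0
print(limiter.hit("1.2.3.4"))  # {'requests_remaining': 1}
```

Inside a shield function, `limiter.hit(client_ip)` could replace the Redis calls. It returns a dict to allow the request and `None` to block it, the same convention the Redis shield uses.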
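Integration with Dependencies
Shields integrate with FastAPI's dependency injection system and can declare any parameter the endpoint could receive, such as the `Request`, path parameters, and `Depends(...)` dependencies:

```python
from fastapi import Depends, FastAPI, Path, Request
from fastapi_shield import Shield

app = FastAPI()

# Placeholders for your own code: get_current_user_from_token(request) returns
# the authenticated user or None; get_profile_from_db(db, user_id) loads a profile.

def get_database():
    # Database connection logic (placeholder: return your connection or session)
    return database

def ownership_shield(
    request: Request,
    user_id: int = Path(...),
    db=Depends(get_database),
) -> dict | None:
    """Verify the current user owns the resource (or is an admin)."""
    current_user = get_current_user_from_token(request)
    if current_user is None:
        return None  # Not authenticated
    if current_user.id == user_id or current_user.is_admin:
        return {"current_user": current_user}
    return None

ownership_guard = Shield(ownership_shield, name="Ownership")

@app.get("/users/{user_id}/profile")
@ownership_guard
def get_user_profile(user_id: int, db=Depends(get_database)):
    return get_profile_from_db(db, user_id)
```
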
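Shield Chaining
Multiple shields can be chained together for layered protection:

```python
# auth_shield and rate_limiter come from the examples above; admin_shield and
# get_user are placeholders. Keep @app.get outermost so FastAPI registers the
# fully shielded function.
@app.get("/admin/users/{user_id}")
@auth_shield
@admin_shield
@rate_limiter
def admin_get_user(user_id: int):
    return {"user": get_user(user_id)}
```
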
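Decorators are applied bottom-up, but at request time the outermost shield checks first: `auth_shield`, then `admin_shield`, then `rate_limiter`. A blocking shield stops the chain before any inner shield or the endpoint runs. This sketch assumes each shield wrapper calls its inner function only after its own check passes, as described for `__call__`. `logging_shield` is a hypothetical decorator that records the order:

```python
import functools
from typing import Any, Callable

call_log: list = []


def logging_shield(name: str, allow: bool = True) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical shield-like decorator that records when its check runs."""

    def decorator(endpoint: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(endpoint)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            call_log.append(name)  # this check runs before the wrapped function
            if not allow:
                return {"blocked_by": name}  # stop the chain here
            return endpoint(*args, **kwargs)

        return wrapper

    return decorator


@logging_shield("auth")
@logging_shield("admin")
@logging_shield("rate_limit")
def admin_get_user(user_id: int) -> dict:
    call_log.append("endpoint")
    return {"user": user_id}


print(admin_get_user(7))  # {'user': 7}
print(call_log)           # ['auth', 'admin', 'rate_limit', 'endpoint']

call_log.clear()


@logging_shield("auth", allow=False)
@logging_shield("admin")
def blocked_endpoint() -> dict:
    call_log.append("endpoint")
    return {}


print(blocked_endpoint())  # {'blocked_by': 'auth'}
print(call_log)            # ['auth']
```

Because checks run outermost first, placing a cheap check such as authentication on the outside keeps more expensive inner shields from running on requests that will be rejected anyway.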
Error Handling
Shields offer several ways to handle a blocked request:
- `auto_error=True` (default): Raises an `HTTPException` when validation fails (a 500 error naming the shield, unless `exception_to_raise_if_fail` is set)
- `auto_error=False`: Returns a response instead of raising (a 500 response naming the shield, unless `default_response_to_return_if_fail` is set)
- Custom exceptions: Define specific `HTTPException`s for different failure scenarios
- Custom responses: Full control over response content, headers, and status codes
Performance Considerations
- Shields run on every request to a shielded endpoint, so keep shield functions lightweight and fast
- Use async shield functions for I/O operations (database, Redis, HTTP calls)
- Consider caching validation results (for example, token lookups) when slightly stale results are acceptable
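One way to cache validation results is to put a small time-to-live cache around an expensive check such as token verification. `ttl_cached` is a hypothetical helper, not part of FastAPI Shield. It caches only successful results, so a rejected token is re-checked on its next use:

```python
import time
from typing import Callable, Optional


def ttl_cached(ttl_seconds: float, clock: Callable[[], float] = time.monotonic):
    """Hypothetical helper: cache successful results of a one-argument validator."""

    def decorator(func: Callable[[str], Optional[dict]]) -> Callable[[str], Optional[dict]]:
        cache: dict = {}  # token -> (result, expires_at)

        def wrapper(token: str) -> Optional[dict]:
            now = clock()
            entry = cache.get(token)
            if entry is not None and now < entry[1]:
                return entry[0]  # fresh cached result
            result = func(token)
            if result:  # cache only successful validations
                cache[token] = (result, now + ttl_seconds)
            return result

        return wrapper

    return decorator


fake_now = [0.0]
lookups = []


@ttl_cached(30, clock=lambda: fake_now[0])
def validate_token(token: str) -> Optional[dict]:
    lookups.append(token)  # stands in for a database or HTTP call
    return {"user_id": 123} if token == "good" else None


validate_token("good")
validate_token("good")
print(len(lookups))  # 1
fake_now[0] = 31.0
validate_token("good")
print(len(lookups))  # 2
```

Caching trades freshness for speed: a revoked token keeps passing until its cached entry expires. The cache in this sketch also has no size bound, which a production version would need.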
See Also
- ShieldedDepends - Dependency injection for shields
- shield factory function - Convenient decorator interface
- Utils - Utility functions used internally