Shield Class

The Shield class is the core component of FastAPI Shield, providing request interception and validation functionality.

Overview

The Shield class works as a decorator that wraps FastAPI endpoint functions. It intercepts requests before they reach the endpoint, runs validation logic, and either allows or blocks the request based on the validation result.

Class Reference

Bases: Generic[U]

The main shield decorator class for request interception and validation.

Shield provides a powerful framework for intercepting FastAPI requests before they reach the endpoint handlers. It can validate authentication, authorization, rate limiting, input sanitization, and any other request-level logic.

The Shield class works as a decorator that wraps endpoint functions. When a request is made to a shielded endpoint:

  1. The shield function is called first with request parameters
  2. If the shield returns truthy data, the request proceeds
  3. The data returned by the shield is available to ShieldedDepends dependencies
  4. If the shield returns None/False, the request is blocked
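The four steps above can be modelled without FastAPI. The sketch below is a simplified, synchronous stand-in and not the library's implementation. `toy_shield`, `Blocked`, `check_token` and the `shield_data` keyword are hypothetical names, and the real Shield delivers the guard's data through ShieldedDepends rather than through a keyword argument.

```python
from functools import wraps
from typing import Any, Callable, Optional


class Blocked(Exception):
    """Hypothetical stand-in for the HTTP error a real shield raises."""


def toy_shield(guard: Callable[[dict], Optional[Any]]):
    """Hypothetical, simplified model of the shield flow (not fastapi_shield code)."""

    def decorator(endpoint: Callable[..., Any]):
        @wraps(endpoint)
        def wrapper(request: dict) -> Any:
            data = guard(request)          # 1. the guard runs first
            if not data:                   # 4. None/False (any falsy value here) blocks
                raise Blocked("request blocked by shield")
            # 2-3. the request proceeds and the guard's data is handed on
            return endpoint(shield_data=data)

        return wrapper

    return decorator


def check_token(request: dict) -> Optional[dict]:
    """Hypothetical guard: accepts one hard-coded bearer token."""
    if request.get("Authorization") == "Bearer good-token":
        return {"user_id": 123}
    return None


@toy_shield(check_token)
def protected(shield_data: dict) -> dict:
    return {"message": "Access granted", "user": shield_data["user_id"]}
```

Calling `protected({"Authorization": "Bearer good-token"})` returns the endpoint's result; a request without that header raises `Blocked` before the endpoint body runs.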

Attributes:

  • auto_error: Whether to raise HTTP exceptions on shield failure
  • name: Human-readable name for the shield (used in error messages)
  • _guard_func: The actual shield validation function
  • _guard_func_is_async: Whether the shield function is async
  • _guard_func_params: Parameters of the shield function
  • _exception_to_raise_if_fail: Exception to raise when the shield blocks a request
  • _default_response_to_return_if_fail: Response to return when auto_error is False

Examples:

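from typing import Optional

from fastapi import FastAPI, Request, Response
from fastapi_shield import Shield, shield

app = FastAPI()

# Basic authentication shield (validate_token is your own token check)
@shield
def auth_shield(request: Request) -> Optional[dict]:
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123, "username": "john"}
    return None  # Block the request

# Apply shield to endpoint
@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}

# Custom error handling: Shield() takes the plain validation function,
# not auth_shield, which the @shield decorator already turned into a Shield
def check_auth(request: Request) -> Optional[dict]:
    if validate_token(request.headers.get("Authorization")):
        return {"user_id": 123, "username": "john"}
    return None

auth_shield_custom = Shield(
    check_auth,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Authentication required",
        status_code=401
    )
)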

__init__(shield_func, *, name=None, auto_error=True, exception_to_raise_if_fail=None, default_response_to_return_if_fail=None)

Initialize a new Shield instance.

Parameters:

  • shield_func (U, required): The validation function to use for this shield. It should return truthy data to allow a request, or None/False to block it. It can be sync or async.
  • name (str, default None): Human-readable name for the shield, used in error messages and logging. Defaults to "unknown" if not provided.
  • auto_error (bool, default True): Whether to automatically raise an HTTP exception when the shield blocks a request. If False, the default response is returned instead.
  • exception_to_raise_if_fail (Optional[HTTPException], default None): Custom HTTP exception to raise when the shield blocks a request and auto_error=True. Defaults to a 500 error that includes the shield name.
  • default_response_to_return_if_fail (Optional[Response], default None): Custom response to return when the shield blocks a request and auto_error=False. Defaults to a 500 response that includes the shield name.

Raises:

  • AssertionError: If shield_func is not callable, or if the exception/response parameters are not of the correct types.

Examples:

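from fastapi import HTTPException, Response
from fastapi_shield import Shield

# Basic shield (my_auth_function is your own validation function)
shield = Shield(my_auth_function)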

# Named shield with custom error
shield = Shield(
    my_auth_function,
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(401, "Authentication required")
)

# Shield with custom response instead of exceptions
shield = Shield(
    my_auth_function,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Please log in",
        status_code=401,
        headers={"WWW-Authenticate": "Bearer"}
    )
)

__call__(endpoint)

Apply the shield to a FastAPI endpoint function.

This method implements the decorator functionality, wrapping the endpoint function with shield validation logic. When the returned wrapper is called, it:

  1. Extracts the relevant parameters for the shield function
  2. Calls the shield function with those parameters
  3. If the shield returns truthy data:
       • Resolves all endpoint dependencies
       • Injects the data returned by the shield into ShieldedDepends dependencies
       • Calls the original endpoint with the resolved parameters
  4. If the shield returns None/False:
       • Blocks the request by raising an exception or returning an error response

The wrapper handles both sync and async shield functions and endpoints automatically, and integrates with FastAPI's dependency injection system.
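Extracting the parameters a shield needs can be illustrated with `inspect.signature`: given everything resolved for a request, keep only the names the shield function declares. This is a hedged sketch of the idea; `extract_shield_kwargs` and `ownership_check` are hypothetical names and not part of fastapi_shield.

```python
import inspect
from typing import Any, Callable, Dict


def extract_shield_kwargs(shield_func: Callable[..., Any],
                          available: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only the entries shield_func accepts."""
    params = inspect.signature(shield_func).parameters
    # A shield that declares **kwargs can receive everything.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(available)
    return {name: value for name, value in available.items() if name in params}


def ownership_check(request: Any, user_id: int) -> bool:
    """Hypothetical shield that only needs the request and a path parameter."""
    return user_id == 7


# Everything the endpoint would receive; the shield only asks for two of them.
resolved = {"request": object(), "user_id": 7, "db": "session", "limit": 10}
kwargs = extract_shield_kwargs(ownership_check, resolved)
```

Here `kwargs` keeps only `request` and `user_id`, so `ownership_check(**kwargs)` can be called without errors about unexpected arguments.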

Parameters:

  • endpoint (EndPointFunc, required): The FastAPI endpoint function to protect with this shield. Can be sync or async.

Returns:

  • EndPointFunc: The wrapped endpoint function with shield protection.

Raises:

  • AssertionError: If endpoint is not callable.

Examples:

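from typing import Optional

from fastapi import FastAPI, Request
from fastapi_shield import shield

app = FastAPI()

@shield
def auth_shield(request: Request) -> Optional[dict]:
    # Validation logic; lookup_user is a placeholder for your own check
    user_data = lookup_user(request.headers.get("Authorization"))
    return user_data or None

# Using as decorator
@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Success"}

Note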

The wrapper function is always async, even if the original endpoint is sync, because dependency resolution is inherently async in FastAPI.
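The always-async wrapper described in the note can be sketched with the standard library. This is an assumption-laden illustration, not the library's code: FastAPI itself runs sync callables in a thread pool via Starlette's `run_in_threadpool`, and here `asyncio.to_thread` (Python 3.9+) plays that role. `call_maybe_async` and `always_async` are hypothetical names.

```python
import asyncio
import functools
import inspect
from typing import Any, Callable


async def call_maybe_async(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Hypothetical helper: await coroutine functions, run plain ones in a thread."""
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    return await asyncio.to_thread(func, *args, **kwargs)


def always_async(endpoint: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical wrapper: exposes any endpoint as a coroutine function."""

    @functools.wraps(endpoint)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        return await call_maybe_async(endpoint, *args, **kwargs)

    return wrapper


def sync_endpoint() -> dict:
    return {"kind": "sync"}


async def async_endpoint() -> dict:
    return {"kind": "async"}
```

Both `always_async(sync_endpoint)` and `always_async(async_endpoint)` are coroutine functions, so a framework can await either one the same way.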

_raise_or_return_default_response()

Handle shield failure by raising an exception or returning a default response.

This method is called when the shield blocks a request (returns None/False). The behavior depends on the auto_error setting:

  • If auto_error=True: raises the configured HTTP exception
  • If auto_error=False: returns the configured default response

Returns:

  • Response: The default response if auto_error=False

Raises:

  • HTTPException: The configured exception if auto_error=True
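The auto_error branch can be modelled in a few lines. This sketch uses hypothetical stand-ins (`ToyHTTPException`, `ToyResponse`, `ToyShieldConfig`) so it runs without FastAPI. The "unknown" default name and the 500 status follow the parameter descriptions above; the exact default message text is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


class ToyHTTPException(Exception):
    """Hypothetical stand-in for fastapi.HTTPException."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


@dataclass
class ToyResponse:
    """Hypothetical stand-in for fastapi.Response."""
    content: str
    status_code: int = 200


@dataclass
class ToyShieldConfig:
    """Hypothetical model of the failure-handling settings of a shield."""
    name: str = "unknown"
    auto_error: bool = True
    exception: Optional[ToyHTTPException] = None
    response: Optional[ToyResponse] = None

    def raise_or_return_default_response(self) -> ToyResponse:
        message = f"Shield with name `{self.name}` blocks the request"  # assumed wording
        if self.auto_error:
            # auto_error=True: raise the configured exception, else a default 500
            raise self.exception or ToyHTTPException(500, message)
        # auto_error=False: return the configured response, else a default 500
        return self.response or ToyResponse(message, 500)
```

With `auto_error=False` and no custom response, the method returns a 500 response; supplying `exception_to_raise_if_fail`-style overrides changes only the status and message, not the control flow.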

Usage Examples

Basic Shield

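from fastapi import FastAPI, Request
from fastapi_shield import Shield

app = FastAPI()

def validate_token(token: str | None) -> bool:
    """Placeholder check; replace with real token verification."""
    return token == "Bearer my-secret-token"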

def auth_validation(request: Request) -> dict | None:
    """Validate authentication token."""
    token = request.headers.get("Authorization")
    if validate_token(token):
        return {"user_id": 123, "role": "user"}
    return None

# Create shield instance
auth_shield = Shield(auth_validation, name="Authentication")

@app.get("/protected")
@auth_shield
def protected_endpoint():
    return {"message": "Access granted"}

Custom Error Handling

from fastapi import HTTPException, Response
from fastapi_shield import Shield

# Shield with custom exception
auth_shield = Shield(
    auth_validation,
    name="Authentication",
    exception_to_raise_if_fail=HTTPException(
        status_code=401,
        detail="Authentication required"
    )
)

# Shield with custom response (no exception)
auth_shield_no_error = Shield(
    auth_validation,
    name="Authentication",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content="Please authenticate",
        status_code=401,
        headers={"WWW-Authenticate": "Bearer"}
    )
)

Async Shield Functions

import redis.asyncio as aioredis  # the standalone aioredis package is deprecated
from fastapi import FastAPI, Request
from fastapi_shield import Shield

app = FastAPI()

# Create the client once and reuse its connection pool across requests
redis = aioredis.from_url("redis://localhost")

async def rate_limit_shield(request: Request) -> dict | None:
    """Rate limiting with Redis: at most 100 requests per IP per minute."""
    client_ip = request.client.host if request.client else "unknown"
    key = f"rate_limit:{client_ip}"

    # Check rate limit
    current_count = await redis.incr(key)
    if current_count == 1:
        await redis.expire(key, 60)  # start a 60-second window

    if current_count > 100:  # 100 requests per minute
        return None

    return {"requests_remaining": 100 - current_count}

rate_limiter = Shield(rate_limit_shield, name="RateLimit")

@app.get("/api/data")
@rate_limiter
def get_data():
    return {"data": "sensitive information"}
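The Redis example implements a fixed-window counter: increment a per-IP key, give it a 60-second expiry on the first hit, and block once the count passes the limit. For tests or a single-process deployment the same logic can be sketched in memory. `FixedWindowLimiter` is a hypothetical helper; unlike Redis, its state is neither shared across worker processes nor persisted.

```python
import time
from typing import Callable, Dict, Optional, Tuple


class FixedWindowLimiter:
    """Hypothetical in-memory counter mirroring the INCR/EXPIRE pattern."""

    def __init__(self, limit: int = 100, window: float = 60.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.limit = limit
        self.window = window
        self.clock = clock
        # key -> (request count, timestamp at which the window expires)
        self._buckets: Dict[str, Tuple[int, float]] = {}

    def hit(self, key: str) -> Optional[dict]:
        now = self.clock()
        count, expires_at = self._buckets.get(key, (0, 0.0))
        if now >= expires_at:
            # Equivalent of the Redis key having expired: start a new window
            count, expires_at = 0, now + self.window
        count += 1
        self._buckets[key] = (count, expires_at)
        if count > self.limit:
            return None  # block, like the shield returning None
        return {"requests_remaining": self.limit - count}
```

A sync shield could then return `limiter.hit(request.client.host)`. The injectable `clock` makes window expiry easy to test without sleeping.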

Integration with Dependencies

Shields integrate with FastAPI's dependency injection system. A shield function can declare the same parameters an endpoint would receive, such as path parameters, Depends() dependencies, and the Request object:

from fastapi import Depends, FastAPI, Path, Request
from fastapi_shield import Shield

app = FastAPI()

# get_database, get_current_user_from_token and get_profile_from_db are
# placeholders for your application's own helpers
def get_database():
    # Database connection logic
    return database

def ownership_shield(
    request: Request,
    user_id: int = Path(...),
    db = Depends(get_database)
) -> dict | None:
    """Verify user owns the resource."""
    current_user = get_current_user_from_token(request)
    if current_user is None:
        return None  # Unauthenticated: block the request
    if current_user.id == user_id or current_user.is_admin:
        return {"current_user": current_user}
    return None

ownership_guard = Shield(ownership_shield, name="Ownership")

@app.get("/users/{user_id}/profile")
@ownership_guard
def get_user_profile(user_id: int, db = Depends(get_database)):
    return get_profile_from_db(db, user_id)

Shield Chaining

Multiple shields can be chained together for layered protection:

@app.get("/admin/users/{user_id}")
@auth_shield
@admin_shield
@rate_limiter
def admin_get_user(user_id: int):
    return {"user": get_user(user_id)}
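Python applies stacked decorators bottom-up, so in the example above `rate_limiter` wraps the endpoint first and `auth_shield` ends up outermost. Assuming each shield wrapper only calls the layer beneath it when it allows the request, the outermost shield runs first and a blocking shield skips everything below it. The toy decorators below are hypothetical stand-ins for real shields that record the order in which the checks run.

```python
from functools import wraps
from typing import Callable, List

calls: List[str] = []


def toy_guard(name: str, allow: bool = True) -> Callable:
    """Hypothetical shield stand-in that logs when its check runs."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            calls.append(name)
            if not allow:
                return {"blocked_by": name}  # short-circuit: inner layers never run
            return func(*args, **kwargs)

        return wrapper

    return decorator


@toy_guard("auth")
@toy_guard("admin")
@toy_guard("rate_limit")
def admin_get_user(user_id: int) -> dict:
    calls.append("endpoint")
    return {"user": user_id}


@toy_guard("auth")
@toy_guard("admin", allow=False)
@toy_guard("rate_limit")
def admin_only() -> dict:
    calls.append("endpoint")
    return {"ok": True}
```

Under this model, placing `rate_limiter` topmost would make rate limiting apply even to requests that later fail authentication.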

Error Handling

Shields offer several ways to handle a request that fails validation:

  • auto_error=True (default): Raises HTTP exceptions when validation fails
  • auto_error=False: Returns custom responses without raising exceptions
  • Custom exceptions: Define specific HTTPExceptions for different failure scenarios
  • Custom responses: Full control over response content, headers, and status codes

Performance Considerations

  • Shields run on every request to a shielded endpoint, so keep shield functions lightweight and fast
  • Use async shield functions for I/O operations (database, Redis, HTTP calls)
  • Consider caching validation results when appropriate
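One way to act on the caching advice is a small time-bounded cache in front of an expensive check such as token verification. This is a sketch under stated assumptions: `ttl_cache` and `verify_token_remote` are hypothetical, the cache is per-process and not thread-safe, and because results are reused, a revoked token may keep passing until its entry expires, so keep the TTL short.

```python
import time
from functools import wraps
from typing import Any, Callable, Dict, Hashable, List, Tuple


def ttl_cache(ttl: float, clock: Callable[[], float] = time.monotonic):
    """Hypothetical helper: cache a one-argument function's results for `ttl` seconds."""

    def decorator(func: Callable[[Hashable], Any]):
        store: Dict[Hashable, Tuple[Any, float]] = {}

        @wraps(func)
        def wrapper(key: Hashable) -> Any:
            now = clock()
            cached = store.get(key)
            if cached is not None and cached[1] > now:
                return cached[0]          # fresh cached result
            value = func(key)             # miss or expired: recompute
            store[key] = (value, now + ttl)
            return value

        return wrapper

    return decorator


lookups: List[str] = []


@ttl_cache(ttl=30.0)
def verify_token_remote(token: str) -> bool:
    """Hypothetical expensive check, e.g. a call to an auth service."""
    lookups.append(token)
    return token == "Bearer good-token"
```

Repeated calls with the same token within 30 seconds reuse the first result, so the expensive check runs once per token per window.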

See Also