Skip to content

Dependency Injection with FastAPI Shield

FastAPI Shield provides a powerful dependency injection system that extends FastAPI's built-in dependency injection with shield-based validation and transformation. This guide explores how to use shields with dependencies based on the actual implementation patterns.

Understanding Shield Architecture

FastAPI Shield works by decorating endpoints with shields that: 1. Execute before the endpoint function 2. Can validate, transform, or block requests 3. Pass validated data to the endpoint via ShieldedDepends

The key components are: - @shield decorator: Creates a shield function that validates/transforms data - ShieldedDepends: A dependency that receives data from shield functions - Shield composition: Multiple shields can be chained together

Basic Shield with Dependencies

Here's how shields work with dependencies:

from fastapi import FastAPI, Header, HTTPException, Depends
from fastapi_shield import shield, ShieldedDepends
from typing import Optional

app = FastAPI()

# Mock database
USERS_DB = {
    "user1": {"username": "user1", "email": "user1@example.com", "roles": ["user"]},
    "admin1": {"username": "admin1", "email": "admin1@example.com", "roles": ["admin", "user"]},
}

def get_database():
    """Dependency that provides database access"""
    return USERS_DB

def validate_token(token: str) -> bool:
    """Helper function to validate tokens"""
    return token in ["valid_user_token", "valid_admin_token"]

def get_user_from_token(token: str) -> Optional[str]:
    """Helper function to extract username from token"""
    if token == "valid_user_token":
        return "user1"
    elif token == "valid_admin_token":
        return "admin1"
    return None

# Authentication shield
@shield(name="Authentication Shield")
def auth_shield(authorization: str = Header()) -> Optional[str]:
    """Shield that validates authorization header and returns token"""
    if not authorization.startswith("Bearer "):
        return None

    token = authorization.replace("Bearer ", "")
    if validate_token(token):
        return token
    return None

# User data retrieval function (used with ShieldedDepends)
def get_user_data(
    token: str,  # This comes from the shield
    db: dict = Depends(get_database)  # This is a regular FastAPI dependency
) -> dict:
    """Function that gets user data using token from shield and database dependency"""
    username = get_user_from_token(token)
    if username and username in db:
        return db[username]
    raise HTTPException(status_code=404, detail="User not found")

# Endpoint using shield and ShieldedDepends
@app.get("/profile")
@auth_shield
async def get_profile(
    user: dict = ShieldedDepends(get_user_data)
):
    """Endpoint that requires authentication and returns user profile"""
    return {
        "username": user["username"],
        "email": user["email"],
        "roles": user["roles"]
    }

Shield Composition and Chaining

Shields can be composed to create complex validation chains:

from fastapi import FastAPI, Header, HTTPException
from fastapi_shield import shield, ShieldedDepends
from typing import List, Optional

app = FastAPI()

# Authentication shield (first in chain)
@shield(name="JWT Auth")
def jwt_auth_shield(authorization: str = Header()) -> Optional[dict]:
    """Validates JWT token and returns payload"""
    if not authorization.startswith("Bearer "):
        return None

    token = authorization.replace("Bearer ", "")
    # In real app, decode JWT here
    if token == "valid_jwt_token":
        return {
            "user_id": "user123",
            "username": "john_doe",
            "roles": ["user", "admin"],
            "permissions": ["read:profile", "write:profile"]
        }
    return None

# Role validation shield (second in chain)
def require_role(required_role: str):
    """Factory function that creates role-checking shields"""

    @shield(
        name=f"Role Check ({required_role})",
        exception_to_raise_if_fail=HTTPException(
            status_code=403,
            detail=f"Role '{required_role}' required"
        )
    )
    def role_shield(
        payload: dict = ShieldedDepends(lambda payload: payload)  # Gets data from previous shield
    ) -> Optional[dict]:
        """Shield that checks if user has required role"""
        user_roles = payload.get("roles", [])
        if required_role in user_roles:
            return payload
        return None

    return role_shield

# Permission validation shield (third in chain)
def require_permission(required_permission: str):
    """Factory function that creates permission-checking shields"""

    @shield(
        name=f"Permission Check ({required_permission})",
        exception_to_raise_if_fail=HTTPException(
            status_code=403,
            detail=f"Permission '{required_permission}' required"
        )
    )
    def permission_shield(
        payload: dict = ShieldedDepends(lambda payload: payload)  # Gets data from previous shield
    ) -> Optional[dict]:
        """Shield that checks if user has required permission"""
        user_permissions = payload.get("permissions", [])
        if required_permission in user_permissions:
            return payload
        return None

    return permission_shield

# Create specific shield instances
admin_role_shield = require_role("admin")
write_permission_shield = require_permission("write:profile")

# Endpoint with multiple shields
@app.get("/admin-profile")
@jwt_auth_shield
@admin_role_shield
async def admin_profile(
    user_data: dict = ShieldedDepends(lambda payload: payload)
):
    """Endpoint requiring JWT auth and admin role"""
    return {
        "message": "Admin profile access granted",
        "user_id": user_data["user_id"],
        "username": user_data["username"]
    }

@app.post("/update-profile")
@jwt_auth_shield
@write_permission_shield
async def update_profile(
    profile_data: dict,
    user_data: dict = ShieldedDepends(lambda payload: payload)
):
    """Endpoint requiring JWT auth and write permission"""
    return {
        "message": "Profile updated",
        "user_id": user_data["user_id"],
        "updated_data": profile_data
    }

Working with Pydantic Models

FastAPI Shield integrates seamlessly with Pydantic for data validation:

from fastapi import FastAPI, Body, HTTPException
from fastapi_shield import shield, ShieldedDepends
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List

app = FastAPI()

# Pydantic models
class UserInput(BaseModel):
    username: str = Field(..., min_length=3, max_length=20)
    email: str = Field(..., pattern=r'^[^@]+@[^@]+\.[^@]+$')
    full_name: Optional[str] = None
    age: int = Field(..., ge=13, le=120)

class ValidatedUser(BaseModel):
    username: str
    email: str
    full_name: Optional[str]
    age: int
    is_valid: bool = True
    validation_notes: List[str] = []

# Shield that validates and transforms user data
@shield(
    name="User Validator",
    exception_to_raise_if_fail=HTTPException(
        status_code=400,
        detail="User validation failed"
    )
)
def validate_user_shield(user_input: UserInput = Body()) -> Optional[ValidatedUser]:
    """Shield that performs additional validation beyond Pydantic"""

    # Check for reserved usernames
    reserved_usernames = ["admin", "system", "root", "api"]
    if user_input.username.lower() in reserved_usernames:
        return None

    # Check email domain restrictions
    allowed_domains = ["company.com", "partner.org"]
    email_domain = user_input.email.split("@")[1]
    if email_domain not in allowed_domains:
        return None

    # Create validated user with additional metadata
    validated_user = ValidatedUser(
        username=user_input.username,
        email=user_input.email,
        full_name=user_input.full_name,
        age=user_input.age,
        validation_notes=["Email domain approved", "Username available"]
    )

    return validated_user

# Function to enrich user data (used with ShieldedDepends)
def enrich_user_data(validated_user: ValidatedUser) -> dict:
    """Function that enriches validated user data"""
    return {
        "user": validated_user.dict(),
        "account_type": "premium" if validated_user.age >= 18 else "standard",
        "welcome_message": f"Welcome, {validated_user.username}!",
        "next_steps": ["verify_email", "complete_profile"]
    }

@app.post("/register")
@validate_user_shield
async def register_user(
    enriched_data: dict = ShieldedDepends(enrich_user_data)
):
    """Endpoint that registers a user with validation and enrichment"""
    return {
        "message": "User registered successfully",
        "data": enriched_data
    }

Database Integration with Shields

Here's how to integrate shields with database operations:

from fastapi import FastAPI, Depends, HTTPException, Header
from fastapi_shield import shield, ShieldedDepends
from typing import Optional, Dict, Any
import asyncio

app = FastAPI()

# Mock database
USERS_DB = {
    "user1": {"id": 1, "username": "user1", "active": True, "role": "user"},
    "admin1": {"id": 2, "username": "admin1", "active": True, "role": "admin"},
    "inactive1": {"id": 3, "username": "inactive1", "active": False, "role": "user"},
}

async def get_database():
    """Async database dependency"""
    # Simulate database connection
    await asyncio.sleep(0.01)
    return USERS_DB

# Authentication shield with database lookup
@shield(name="Database Auth")
async def db_auth_shield(
    api_key: str = Header(),
    db: Dict[str, Any] = Depends(get_database)
) -> Optional[dict]:
    """Shield that authenticates user against database"""

    # Simple API key to username mapping
    api_key_mapping = {
        "user1_key": "user1",
        "admin1_key": "admin1",
        "inactive1_key": "inactive1"
    }

    username = api_key_mapping.get(api_key)
    if not username:
        return None

    user = db.get(username)
    if not user or not user["active"]:
        return None

    return user

# Function to get user permissions (used with ShieldedDepends)
async def get_user_permissions(
    user: dict,  # Comes from shield
    db: Dict[str, Any] = Depends(get_database)  # Database dependency
) -> dict:
    """Function that retrieves user permissions from database"""

    # Mock permission lookup
    permissions_map = {
        "user": ["read:own_data"],
        "admin": ["read:own_data", "read:all_data", "write:all_data"]
    }

    permissions = permissions_map.get(user["role"], [])

    return {
        "user": user,
        "permissions": permissions,
        "can_read_all": "read:all_data" in permissions,
        "can_write_all": "write:all_data" in permissions
    }

@app.get("/user-data")
@db_auth_shield
async def get_user_data(
    user_info: dict = ShieldedDepends(get_user_permissions)
):
    """Endpoint that returns user data based on permissions"""

    if user_info["can_read_all"]:
        # Admin can see all users
        return {
            "message": "All user data",
            "data": list(USERS_DB.values()),
            "user": user_info["user"]
        }
    else:
        # Regular user can only see their own data
        return {
            "message": "Your user data",
            "data": user_info["user"],
            "permissions": user_info["permissions"]
        }

Advanced Shield Patterns

Conditional Shield Execution

from fastapi import FastAPI, Header, Query
from fastapi_shield import shield, ShieldedDepends
from typing import Optional

app = FastAPI()

# Feature flag shield
@shield(name="Feature Flag Check")
def feature_flag_shield(
    feature: str = Query(...),
    user_type: str = Header(default="regular")
) -> Optional[dict]:
    """Shield that checks if feature is enabled for user type"""

    feature_flags = {
        "beta_feature": ["premium", "admin"],
        "experimental_api": ["admin"],
        "new_ui": ["regular", "premium", "admin"]
    }

    allowed_user_types = feature_flags.get(feature, [])
    if user_type in allowed_user_types:
        return {
            "feature": feature,
            "user_type": user_type,
            "access_granted": True
        }
    return None

@app.get("/feature/{feature_name}")
@feature_flag_shield
async def access_feature(
    feature_name: str,
    access_info: dict = ShieldedDepends(lambda info: info)
):
    """Endpoint that provides access to features based on user type"""
    return {
        "message": f"Access granted to {access_info['feature']}",
        "user_type": access_info["user_type"],
        "feature_data": f"Data for {feature_name}"
    }

Error Handling in Shields

from fastapi import FastAPI, HTTPException, Header
from fastapi_shield import shield, ShieldedDepends
from typing import Optional

app = FastAPI()

# Shield with custom error handling
@shield(
    name="Rate Limit Shield",
    exception_to_raise_if_fail=HTTPException(
        status_code=429,
        detail="Rate limit exceeded",
        headers={"Retry-After": "60"}
    )
)
def rate_limit_shield(
    x_client_id: str = Header()
) -> Optional[dict]:
    """Shield that implements rate limiting"""

    # Mock rate limiting logic
    rate_limits = {
        "client1": {"requests": 5, "window": 60},
        "client2": {"requests": 100, "window": 60}
    }

    if x_client_id not in rate_limits:
        return None

    # In real implementation, check against Redis or similar
    # For demo, always allow
    return {
        "client_id": x_client_id,
        "rate_limit": rate_limits[x_client_id]
    }

@app.get("/api/data")
@rate_limit_shield
async def get_api_data(
    client_info: dict = ShieldedDepends(lambda info: info)
):
    """Rate-limited API endpoint"""
    return {
        "data": "API response data",
        "client": client_info["client_id"],
        "rate_limit": client_info["rate_limit"]
    }

Best Practices

1. Shield Naming and Organization

# Good: Descriptive shield names
@shield(name="JWT Authentication")
def jwt_auth_shield(token: str = Header()) -> Optional[dict]:
    pass

@shield(name="Admin Role Check")
def admin_role_shield(user: dict = ShieldedDepends(lambda u: u)) -> Optional[dict]:
    pass

# Good: Shield factory functions for reusability
def require_role(role: str):
    @shield(name=f"Require {role} Role")
    def role_shield(user: dict = ShieldedDepends(lambda u: u)) -> Optional[dict]:
        return user if role in user.get("roles", []) else None
    return role_shield

2. Proper ShieldedDepends Usage

# Correct: Use lambda functions to pass data from shields
@app.get("/endpoint")
@auth_shield
async def endpoint(
    user_data: dict = ShieldedDepends(lambda user: user)  # Gets data from auth_shield
):
    pass

# Correct: Use functions for complex dependency resolution
def get_user_with_permissions(user_data: dict, db = Depends(get_db)) -> dict:
    # Complex logic here
    return enriched_user_data

@app.get("/endpoint")
@auth_shield
async def endpoint(
    user: dict = ShieldedDepends(get_user_with_permissions)
):
    pass

3. Shield Composition Order

# Correct order: Authentication -> Authorization -> Business Logic
@app.get("/admin-endpoint")
@jwt_auth_shield          # 1. Authenticate user
@admin_role_shield        # 2. Check admin role
@rate_limit_shield        # 3. Apply rate limiting
async def admin_endpoint(
    user: dict = ShieldedDepends(lambda user: user)
):
    pass

4. Error Handling

# Good: Specific error messages and status codes
@shield(
    name="Permission Check",
    exception_to_raise_if_fail=HTTPException(
        status_code=403,
        detail="Insufficient permissions for this operation",
        headers={"X-Required-Permission": "admin:write"}
    )
)
def permission_shield(user: dict = ShieldedDepends(lambda u: u)) -> Optional[dict]:
    return user if "admin:write" in user.get("permissions", []) else None

5. Testing Shields

# Test shields independently
def test_auth_shield():
    # Test shield logic directly
    result = auth_shield.__wrapped__("Bearer valid_token")
    assert result is not None

    result = auth_shield.__wrapped__("Bearer invalid_token")
    assert result is None

# Test shield composition
def test_endpoint_with_shields(client):
    response = client.get("/protected", headers={"Authorization": "Bearer valid_token"})
    assert response.status_code == 200

This documentation reflects the actual implementation patterns used in FastAPI Shield, showing how shields work as decorators that validate requests and pass data to endpoints via ShieldedDepends.