Skip to content

Custom Initialization

FastAPI Shield provides flexible ways to customize how shields are initialized and configured.

Shield Factory Pattern

The shield factory pattern allows you to create reusable, configurable shields:

from typing import List, Optional

from fastapi import FastAPI, Header, Depends, HTTPException, status
from fastapi_shield import ShieldedDepends, shield

# Application instance the example routes are registered on
app = FastAPI()

# Mock user data for testing.
# Keys double as bearer tokens: get_current_user looks the token up directly
# in this dict. Each record carries the username and the list of role names.
user_db = {
    "user1": {"username": "user1", "roles": ["viewer"]},
    "user2": {"username": "user2", "roles": ["editor", "viewer"]},
    "admin": {"username": "admin", "roles": ["admin", "editor", "viewer"]}
}

# Dependency to extract token from header
def get_token_from_header(authorization: str = Header(None)) -> Optional[str]:
    """Extract the bearer token from the ``Authorization`` header.

    Args:
        authorization: Raw header value; ``None`` when the header is absent.

    Returns:
        The text that follows the leading ``"Bearer "`` scheme, or ``None``
        when the header is missing or does not start with that scheme.
    """
    prefix = "Bearer "
    if not authorization or not authorization.startswith(prefix):
        return None
    # Strip only the leading scheme. str.replace() would also delete any
    # later "Bearer " occurrence inside the token and corrupt it.
    return authorization[len(prefix):]

# Dependency to get current user from token
def get_current_user(token: str = Depends(get_token_from_header)) -> Optional[dict]:
    """Look up the user record that belongs to the given token.

    Returns ``None`` when no token was supplied or when ``user_db`` has no
    entry for it.
    """
    return user_db.get(token) if token else None

# Create a shield factory for role-based authorization
def create_role_shield(allowed_roles: List[str], auto_error: bool = True):
    """Build a shield that admits users holding any of ``allowed_roles``.

    Args:
        allowed_roles: Role names; holding one of them is sufficient.
        auto_error: Forwarded unchanged to ``shield``.

    Returns:
        The decorated ``role_checker`` shield. It yields the user dict when
        the user has a matching role and ``None`` otherwise.
    """
    denied = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Access denied. Required roles: {allowed_roles}"
    )
    shield_name = f"Role Shield ({', '.join(allowed_roles)})"

    @shield(name=shield_name, auto_error=auto_error, exception_to_raise_if_fail=denied)
    def role_checker(user: Optional[dict] = Depends(get_current_user)) -> Optional[dict]:
        # A missing user and a user without a matching role are both rejected.
        if user and any(role in allowed_roles for role in user.get("roles", [])):
            return user
        return None

    return role_checker

# Create shields for different roles
# Each call builds an independent shield; a user passes when they hold at
# least one of the listed roles.
admin_shield = create_role_shield(["admin"])
editor_shield = create_role_shield(["admin", "editor"])
viewer_shield = create_role_shield(["admin", "editor", "viewer"])
# Same admin-only check, but with auto_error=False forwarded to the shield.
# NOTE(review): presumably this returns a default response instead of raising
# the 403 -- confirm against the fastapi_shield documentation.
non_auto_error_shield = create_role_shield(["admin"], auto_error=False)

# Apply shields to endpoints
@app.get("/admin")
@admin_shield
def admin_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    # ShieldedDepends passes the value returned by the shield (the user dict)
    # into this parameter. A plain Depends(lambda x: x) would make FastAPI
    # treat ``x`` as a required query parameter and hand over a string.
    return {"message": "Admin endpoint", "user": user["username"]}

@app.get("/editor")
@editor_shield
def editor_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    # ShieldedDepends passes the value returned by the shield (the user dict)
    # into this parameter. A plain Depends(lambda x: x) would make FastAPI
    # treat ``x`` as a required query parameter and hand over a string.
    return {"message": "Editor endpoint", "user": user["username"]}

@app.get("/viewer")
@viewer_shield
def viewer_endpoint(user: dict = ShieldedDepends(lambda user: user)):
    # ShieldedDepends passes the value returned by the shield (the user dict)
    # into this parameter. A plain Depends(lambda x: x) would make FastAPI
    # treat ``x`` as a required query parameter and hand over a string.
    return {"message": "Viewer endpoint", "user": user["username"]}

# Example with non-auto-error shield
@app.get("/optional-admin")
@non_auto_error_shield
def optional_admin_endpoint(user: Optional[dict] = ShieldedDepends(lambda user: user)):
    # ShieldedDepends passes the value returned by the shield into ``user``.
    # A plain Depends(lambda x: x) would make FastAPI treat ``x`` as a
    # required query parameter, so ``user.get`` would fail on a string.
    if user and "admin" in user.get("roles", []):
        return {"message": "Admin content", "is_admin": True}
    return {"message": "Regular content", "is_admin": False}

Shield Classes

For more complex shields, you can use a class to encapsulate the shield's state and behavior:

from fastapi import FastAPI, Request, HTTPException, status
from fastapi_shield import shield
import time
from typing import Dict, Any, Optional

# Application instance the rate-limited routes are registered on
app = FastAPI()

class RateLimiter:
    """Per-client request limiter.

    For every client key the limiter keeps the timestamps of the requests it
    accepted during the last ``window_seconds`` seconds and rejects a request
    once ``max_requests`` timestamps are already recorded.
    """

    # Bucket shared by all requests whose client address is not reported.
    UNKNOWN_CLIENT = "unknown"

    def __init__(self, max_requests: int, window_seconds: int):
        """Store the limit and start with an empty request history.

        Args:
            max_requests: Accepted requests allowed per client per window.
            window_seconds: Length of the window in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client key -> time.time() values of accepted requests
        self.client_requests: Dict[str, list] = {}

    def check_rate_limit(self, client_ip: str) -> bool:
        """Record a request for ``client_ip`` if it is within the limit.

        Returns:
            ``True`` when the request was accepted (and recorded), ``False``
            when the client already used up its allowance for the window.
        """
        now = time.time()

        # Keep only the timestamps that still fall inside the window.
        recent = [
            ts for ts in self.client_requests.get(client_ip, [])
            if now - ts < self.window_seconds
        ]

        allowed = len(recent) < self.max_requests
        if allowed:
            # Rejected requests are deliberately not recorded.
            recent.append(now)

        self.client_requests[client_ip] = recent
        return allowed

    @staticmethod
    def _client_key(request) -> str:
        """Return the key used to track ``request``'s client.

        ``request.client`` can be ``None`` when the server does not report a
        peer address; such requests share the ``UNKNOWN_CLIENT`` bucket
        instead of failing with an AttributeError.
        """
        client = getattr(request, "client", None)
        if client is None:
            return RateLimiter.UNKNOWN_CLIENT
        return client.host

    def create_shield(self, name: str = "Rate Limiter"):
        """Create a shield function backed by this rate limiter.

        Args:
            name: Name passed to the ``shield`` decorator.

        Returns:
            The decorated shield. It yields a small info dict when the
            request is within the limit and ``None`` otherwise; the shield is
            configured with a 429 ``HTTPException`` carrying ``Retry-After``.
        """

        @shield(
            name=name,
            auto_error=True,
            exception_to_raise_if_fail=HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Rate limit exceeded. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                headers={"Retry-After": str(self.window_seconds)}
            )
        )
        def rate_limit_shield(request: Request) -> Optional[Dict[str, Any]]:
            client_ip = self._client_key(request)
            if self.check_rate_limit(client_ip):
                return {"client_ip": client_ip, "rate_limited": False}
            return None

        return rate_limit_shield

# Create rate limiters with different configurations
# Each limiter keeps its own request history, so the two budgets below are
# counted independently of each other.
default_limiter = RateLimiter(max_requests=10, window_seconds=60)
strict_limiter = RateLimiter(max_requests=3, window_seconds=60)

# Create shields from the rate limiters
default_rate_shield = default_limiter.create_shield(name="Default Rate Limiter")
strict_rate_shield = strict_limiter.create_shield(name="Strict Rate Limiter")

# Apply the shields to different endpoints
@app.get("/normal")
@default_rate_shield
async def normal_endpoint():
    # Guarded by the shield built from default_limiter.
    return dict(message="Normal endpoint with standard rate limiting")

@app.get("/sensitive")
@strict_rate_shield
async def sensitive_endpoint():
    # Guarded by the shield built from strict_limiter.
    return dict(message="Sensitive endpoint with strict rate limiting")

Conditional Shield Creation

You can create shields conditionally based on configuration:

from fastapi import FastAPI
from fastapi_shield import shield
import os

# Application instance the conditionally shielded route is registered on
app = FastAPI()

# Get configuration from environment
# ENABLE_SECURITY is True only for the (case-insensitive) value "true";
# it defaults to enabled when the variable is unset.
ENABLE_SECURITY = os.getenv("ENABLE_SECURITY", "true").lower() == "true"
# Defaults to "production" when the variable is unset.
ENVIRONMENT = os.getenv("ENVIRONMENT", "production")

# Basic shield
@shield(name="Basic Security")
def basic_security():
    """Yield ``True`` when ``ENABLE_SECURITY`` is set, otherwise ``None``."""
    return True if ENABLE_SECURITY else None

# Create shields conditionally
if ENVIRONMENT != "production":
    # Any non-production environment gets the lenient variant.
    @shield(name="Development Mode")
    def environment_shield():
        """Always yield a truthy marker dict."""
        return dict(development_mode=True)
else:
    # Production gets the variant registered under the "Strict Mode" name.
    @shield(name="Strict Mode")
    def environment_shield():
        """Always yield ``True``."""
        return True

@app.get("/api/data")
@basic_security
@environment_shield
async def get_data():
    # Stacked under both basic_security and environment_shield.
    return dict(message="Secured data endpoint")

Using Dependency Injection for Shield Configuration

You can use FastAPI's dependency injection to configure shields:

from fastapi import FastAPI, Depends
from fastapi_shield import shield
from functools import lru_cache
from pydantic import BaseSettings

class SecuritySettings(BaseSettings):
    """Security settings that can be loaded from environment variables"""
    # NOTE(review): importing BaseSettings from ``pydantic`` matches pydantic
    # v1; on v2 it presumably has to come from ``pydantic_settings`` --
    # confirm which version this example targets.

    # When False, the auth shield skips the token check entirely.
    enable_auth: bool = True
    # Token value the auth shield compares the request's api_token against.
    admin_token: str = "admin_token"
    # Role reported in the shield's result after a successful token match.
    required_role: str = "admin"

    class Config:
        # Prefix applied to the environment variable names for these fields
        env_prefix = "SECURITY_"

@lru_cache()
def get_security_settings():
    """Build the ``SecuritySettings`` object once and reuse it afterwards."""
    settings = SecuritySettings()
    return settings

def create_auth_shield(settings: SecuritySettings = Depends(get_security_settings)):
    """Create an authentication shield based on settings.

    Args:
        settings: Settings to enforce. When omitted (or left as the
            unresolved ``Depends`` default because the factory was called
            directly), the cached ``get_security_settings()`` value is used.

    Returns:
        The decorated ``auth_shield``. It yields ``{"auth_skipped": True}``
        when auth is disabled, ``{"role": ...}`` when the ``api_token``
        header equals ``settings.admin_token``, and ``None`` otherwise.
    """
    # Imported locally because this example's imports do not include Header.
    from fastapi import Header

    # FastAPI only resolves Depends defaults for callables it invokes itself.
    # A direct call such as ``@create_auth_shield()`` leaves the marker object
    # in ``settings``, which has no ``enable_auth`` attribute.
    if not isinstance(settings, SecuritySettings):
        settings = get_security_settings()

    @shield(name="Auth Shield")
    def auth_shield(api_token: str = Header()):
        if not settings.enable_auth:
            # Skip auth check if disabled
            return {"auth_skipped": True}

        if api_token == settings.admin_token:
            return {"role": settings.required_role}
        return None

    return auth_shield

# Application instance the auth-shielded route is registered on
app = FastAPI()

@app.get("/admin")
@create_auth_shield()
async def admin_endpoint():
    # Guarded by a shield produced by the create_auth_shield factory.
    return dict(message="Admin endpoint")

Dynamic Shield Configuration

For even more flexibility, you can create shields that adjust their behavior dynamically:

from fastapi import FastAPI, Header, Request
from fastapi_shield import shield
import json
import os

# Application instance the configurable route is registered on
app = FastAPI()

# Load shield configuration from file
def load_shield_config():
    """Load the shield configuration from a JSON file.

    The path comes from the ``SHIELD_CONFIG`` environment variable and falls
    back to ``shield_config.json``.

    Returns:
        The parsed JSON object. The built-in default configuration is
        returned instead when the file cannot be read (missing, unreadable,
        a directory, ...), is not valid UTF-8 JSON, or does not contain a
        JSON object at the top level.
    """
    default_config = {
        "enabled": True,
        "allowed_tokens": ["admin_token", "user_token"],
        "allowed_ips": ["127.0.0.1"]
    }

    config_path = os.getenv("SHIELD_CONFIG", "shield_config.json")
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError and other read failures;
        # ValueError covers json.JSONDecodeError and undecodable bytes.
        return default_config

    # Callers use config.get(...), so anything but an object is unusable.
    if not isinstance(config, dict):
        return default_config
    return config

@shield(name="Configurable Shield")
def configurable_shield(request: Request, api_token: str = Header()):
    """Admit a request based on the configuration loaded for this request.

    Checks run in this order: shield disabled, client IP allow-list, token
    allow-list. The first match yields a marker dict; ``None`` is returned
    when nothing matches, which blocks the request.
    """
    # Load config on each request (could be cached for performance)
    config = load_shield_config()

    # Check if shield is enabled
    if not config.get("enabled", True):
        return {"shield_disabled": True}

    # request.client can be None when no peer address is reported; in that
    # case the IP check is skipped instead of raising AttributeError.
    client = request.client
    client_ip = client.host if client is not None else None
    allowed_ips = config.get("allowed_ips") or []
    if client_ip is not None and client_ip in allowed_ips:
        return {"allowed_by_ip": True}

    # ``or []`` also tolerates an explicit null in the configuration file.
    allowed_tokens = config.get("allowed_tokens") or []
    if api_token in allowed_tokens:
        return {"allowed_by_token": True}

    # If none of the checks passed, block the request
    return None

@app.get("/configurable")
@configurable_shield
async def configurable_endpoint():
    # Guarded by configurable_shield.
    return dict(message="Endpoint with configurable shield")

These patterns provide flexible ways to customize and configure shields in your FastAPI application.