Skip to content

Basic Examples

This section provides basic examples of using FastAPI Shield, demonstrating its fundamental features and usage patterns.

Simple Authentication Shield

Here's a basic authentication shield that validates API tokens:

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield

# Application instance that this example's route is registered on.
app = FastAPI()

# Define valid API tokens
# Hard-coded for demonstration; auth_shield accepts a request only when its
# api_token header value is one of these strings.
VALID_TOKENS = ["token1", "token2", "token3"]

@shield(
    name="API Token Auth",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API token"
    )
)
def auth_shield(api_token: str = Header()):
    """Check the ``api_token`` header against the known tokens.

    Returns the token itself when it appears in VALID_TOKENS, and None
    otherwise.
    """
    return api_token if api_token in VALID_TOKENS else None

@app.get("/protected")
@auth_shield
async def protected_endpoint():
    # Guarded by auth_shield, which is applied as a decorator above.
    payload = {"message": "You have access to the protected endpoint"}
    return payload

User Role Shield

A shield that checks user roles:

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield, ShieldedDepends

# Application instance that this example's routes are registered on.
app = FastAPI()

# User database with roles
# Keyed by API token; each entry carries the user's numeric id and a role
# string ("admin", "editor" or "user") that the role shields inspect.
USERS = {
    "token1": {"user_id": 1, "role": "admin"},
    "token2": {"user_id": 2, "role": "editor"},
    "token3": {"user_id": 3, "role": "user"}
}

@shield(name="Auth Shield")
def auth_shield(api_token: str = Header()):
    """Look up the caller by the ``api_token`` header.

    Returns the matching USERS entry, or None when the token is unknown.
    """
    # dict.get yields None for a missing key, which is the "no match" result.
    return USERS.get(api_token)

@shield(name="Admin Shield")
def admin_shield(user_data = ShieldedDepends(lambda user: user)):
    """Pass the user record through only for the "admin" role.

    ``user_data`` is supplied via ShieldedDepends (presumably the value
    returned by the shield applied before this one; confirm against the
    library documentation). Returns that record for admins, else None.
    """
    if user_data["role"] != "admin":
        return None
    return user_data

@shield(name="Editor Shield")
def editor_shield(user_data = ShieldedDepends(lambda user: user)):
    """Pass the user record through for the "admin" and "editor" roles.

    ``user_data`` is supplied via ShieldedDepends. Returns that record when
    its role is admin or editor, else None.
    """
    roles_with_edit_access = ("admin", "editor")
    has_access = user_data["role"] in roles_with_edit_access
    return user_data if has_access else None

@app.get("/admin-only")
@auth_shield
@admin_shield
async def admin_endpoint():
    # Decorated with auth_shield and then admin_shield.
    body = {"message": "Admin endpoint"}
    return body

@app.get("/editor-access")
@auth_shield
@editor_shield
async def editor_endpoint():
    # Decorated with auth_shield and then editor_shield.
    body = {"message": "Editor endpoint"}
    return body

@app.get("/all-users")
@auth_shield
async def all_users_endpoint(user_data = ShieldedDepends(lambda user: user)):
    # Only auth_shield is applied, so no role check happens here; the handler
    # reads the user's role and id from the ShieldedDepends-supplied record.
    role = user_data['role']
    greeting = f"Welcome, {role}"
    return {"message": greeting, "user_id": user_data["user_id"]}

Shield with Parameters

Creating shields with a factory function that accepts parameters:

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield

# Application instance that this example's /data routes are registered on.
app = FastAPI()

def create_permission_shield(required_permission: str):
    """Factory function to create a permission shield.

    Args:
        required_permission: Permission name that must appear in the
            comma-separated ``permissions`` request header.

    Returns:
        The shield-decorated function that checks for that permission. It is
        configured to fail with an HTTP 403 whose detail names the missing
        permission.
    """

    @shield(
        name=f"Permission Shield ({required_permission})",
        exception_to_raise_if_fail=HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required permission: {required_permission}"
        )
    )
    def permission_shield(permissions: str = Header()):
        """Check if the user has the required permission.

        Returns a dict naming the granted permission when it is present in
        the header value, else None.
        """
        # Comma-separated header values are often written with a space after
        # each comma ("read, write"); strip every entry so those still match.
        user_permissions = {entry.strip() for entry in permissions.split(",")}
        if required_permission in user_permissions:
            return {"granted_permission": required_permission}
        return None

    return permission_shield

# One shield per permission, built once at import time and applied as
# decorators on the /data routes below.
read_shield = create_permission_shield("read")
write_shield = create_permission_shield("write")
delete_shield = create_permission_shield("delete")

@app.get("/data")
@read_shield
async def read_data():
    # Guarded by the "read" permission shield.
    result = {"message": "Reading data"}
    return result

@app.post("/data")
@write_shield
async def write_data():
    # Guarded by the "write" permission shield.
    result = {"message": "Writing data"}
    return result

@app.delete("/data")
@delete_shield
async def delete_data():
    # Guarded by the "delete" permission shield.
    result = {"message": "Deleting data"}
    return result

Multiple Shields

Applying multiple shields to an endpoint:

from fastapi import FastAPI, Header, Query, HTTPException, status
from fastapi_shield import shield
import time

# Application instance that this example's route is registered on.
app = FastAPI()

# Simple rate limiting
# Maps each client_id header value to the time of that client's most recent
# accepted request, as recorded by rate_limit_shield. In-memory only.
last_request_time = {}
# Minimum gap a client must leave between two accepted requests.
MIN_REQUEST_INTERVAL = 2  # seconds

@shield(name="Rate Limit Shield")
def rate_limit_shield(client_id: str = Header()):
    """Limit request rate per client.

    A client, identified by its ``client_id`` header, is accepted at most
    once every MIN_REQUEST_INTERVAL seconds.

    Returns:
        The client id when the request is accepted, or None when it arrives
        too soon after that client's previous accepted request.
    """
    # Use the monotonic clock for interval measurement: wall-clock time can
    # jump when the system clock is adjusted, which would wrongly block or
    # admit requests.
    now = time.monotonic()

    previous = last_request_time.get(client_id)
    if previous is not None and now - previous < MIN_REQUEST_INTERVAL:
        # Rejected requests do not refresh the stored time.
        return None

    # NOTE(review): entries are never evicted, so this dict grows with the
    # number of distinct client ids seen.
    last_request_time[client_id] = now
    return client_id

@shield(name="API Key Shield")
def api_key_shield(api_key: str = Header()):
    """Accept an ``api_key`` header only if it begins with "valid_key_".

    Returns the key when the prefix matches, else None.
    """
    if not api_key.startswith("valid_key_"):
        return None
    return api_key

@shield(name="Parameter Validator")
def param_validator(action: str = Query(...)):
    """Accept only a known value for the required ``action`` query parameter.

    Returns the action when it is one of read, write, update or delete,
    else None.
    """
    permitted_actions = ("read", "write", "update", "delete")
    return action if action in permitted_actions else None

@app.get("/api/resource")
@rate_limit_shield
@api_key_shield
@param_validator
async def resource_endpoint(action: str = Query(...)):
    # Decorated with three shields; the handler itself only echoes the
    # requested action together with the current time.
    description = f"Performing {action} action"
    return {"message": description, "timestamp": time.time()}

Shield with Custom Response

Customizing the response when a shield blocks the request:

from fastapi import FastAPI, Path, Response, status
from fastapi_shield import shield

# Application instance that this example's route is registered on.
app = FastAPI()

# Fake user database with subscription plans
# Keyed by the user_id path parameter; feature_flag_shield treats ids that
# are missing from this mapping as being on the "free" plan.
USER_DB = {
    "user1": {"subscription": "free"},
    "user2": {"subscription": "premium"},
    "user3": {"subscription": "enterprise"},
    "user4": {"subscription": "unlimited"}
}

@shield(
    name="Feature Flag Shield",
    auto_error=False,
    default_response_to_return_if_fail=Response(
        content='{"message": "This feature is not available in your subscription plan"}',
        media_type="application/json",
        status_code=status.HTTP_402_PAYMENT_REQUIRED
    )
)
def feature_flag_shield(user_id: str = Path(...)):
    """Gate a feature on the subscription plan of the user in the URL path.

    Returns the plan name when it is premium, enterprise or unlimited, and
    None for any other plan.
    """
    # Users missing from USER_DB fall back to a "free" record.
    fallback_record = {"subscription": "free"}
    plan = USER_DB.get(user_id, fallback_record)["subscription"]

    paid_plans = ("premium", "enterprise", "unlimited")
    return plan if plan in paid_plans else None

@app.get("/premium-feature/{user_id}")
@feature_flag_shield
async def premium_feature():
    # The user_id path segment is read by feature_flag_shield, not by this
    # handler.
    body = {"message": "Welcome to the premium feature!"}
    return body

Shield with Asynchronous Function

Using an asynchronous function as a shield:

from fastapi import FastAPI, Header, HTTPException, status
from fastapi_shield import shield
import asyncio

# Application instance that this example's route is registered on.
app = FastAPI()

async def validate_token_async(token: str) -> bool:
    """Pretend to validate a token against a remote service.

    Waits briefly to stand in for network latency, then reports whether the
    token begins with "valid_".
    """
    simulated_latency_seconds = 0.05  # stands in for an external API call
    await asyncio.sleep(simulated_latency_seconds)
    has_valid_prefix = token.startswith("valid_")
    return has_valid_prefix

@shield(
    name="Async Auth Shield",
    exception_to_raise_if_fail=HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication token"
    )
)
async def async_auth_shield(auth_token: str = Header()):
    """Validate the ``auth_token`` header with an awaited check.

    Returns a dict holding the token and a ``validated`` flag when
    validate_token_async accepts it, else None.
    """
    # Delegates to the coroutine that simulates an external auth service.
    if not await validate_token_async(auth_token):
        return None
    return {"token": auth_token, "validated": True}

@app.get("/async-protected")
@async_auth_shield
async def async_protected_endpoint():
    # Guarded by async_auth_shield, which is applied as a decorator above.
    body = {"message": "Protected by async shield"}
    return body

These examples demonstrate the basic usage of FastAPI Shield. You can combine these patterns and extend them to implement more complex security requirements.