Source code for auth

#!/usr/bin/env python3
"""
auth.py - API Key Authentication Module for Gateway

File-based authentication system that enforces API keys while maintaining
OpenAI compatibility. Uses key_id:api_key format for easy management and auditing.

Usage:
    from auth import api_validator, authenticate_request, log_access

    # In your request handler:
    api_key_info = await authenticate_request(writer, headers)
    if api_key_info is None:
        # 401 response already sent
        return

    key_id = api_key_info  # The key_id for logging/tracking

    # Process request...
    await log_access(method, path, key_id, status_code)

Environment Variables:
    AUTH_ENABLED              - Enable/disable authentication (default: true)
    AUTH_KEYS_FILE            - Path to keys file (default: $DATA_DIR/api_keys.txt)
    MAX_REQUESTS_PER_MINUTE   - Rate limit per key_id (default: 100)

Keys File Format:
    key_id:api_key

Example:
    production:sk-prod-abc123def456
    alice-laptop:sk-alice-xyz789
    development:sk-dev-test123
"""

import asyncio
import datetime
import json
import os
import time
from collections import defaultdict
from typing import Optional, Tuple


[docs] class APIKeyValidator: """ Validates API keys for incoming requests. Features: - File-based configuration (key_id:api_key format) - Rate limiting per key_id - Key format validation - Audit trail with key_id tracking """
[docs] def __init__(self): self.enabled = os.environ.get("AUTH_ENABLED", "true").lower() == "true" self.keys_file = os.environ.get( "AUTH_KEYS_FILE", f"{os.environ.get('DATA_DIR', '/data')}/api_keys.txt", ) self.keys = self._load_keys() # Maps api_key -> key_id self.rate_limiter = defaultdict(list) # Maps key_id -> [timestamps] self.max_requests_per_minute = int(os.environ.get("MAX_REQUESTS_PER_MINUTE", "100")) if self.enabled: if self.keys: print(f"✅ Authentication enabled with {len(self.keys)} key(s)") else: print("⚠️ Authentication enabled but no keys configured. Allowing all requests.") print(f" Create keys file at: {self.keys_file}") else: print("⚠️ Authentication disabled! All requests will be accepted.")
def _load_keys(self) -> dict: """ Load API keys from file. File format: key_id:api_key # Comments allowed Returns: dict mapping api_key -> key_id for reverse lookup Example: {"sk-prod-abc123": "production", "sk-alice-xyz": "alice-laptop"} """ if not self.enabled: return {} if not os.path.exists(self.keys_file): print(f"⚠️ AUTH_ENABLED=true but keys file not found: {self.keys_file}") print(" Create file with format: key_id:api_key") print(" WARNING: Accepting all requests until keys file is configured!") return {} try: keys: dict[str, str] = {} # Maps api_key -> key_id with open(self.keys_file, "r") as f: for line_num, line in enumerate(f, 1): line = line.strip() # Skip comments and empty lines if not line or line.startswith("#"): continue # Parse key_id:api_key if ":" not in line: print(f"⚠️ Line {line_num}: Invalid format (missing ':'), skipping") continue parts = line.split(":", 1) if len(parts) != 2: print(f"⚠️ Line {line_num}: Invalid format, skipping") continue key_id = parts[0].strip() api_key = parts[1].strip() # Validate key_id if not key_id or not all(c.isalnum() or c in "-_" for c in key_id): print(f"⚠️ Line {line_num}: Invalid key_id '{key_id}', skipping") continue # Validate api_key format if not self._is_valid_format(api_key): print( f"⚠️ Line {line_num}: Invalid api_key format for '{key_id}', skipping" ) continue # Check for duplicate keys if api_key in keys: print( f"⚠️ Line {line_num}: Duplicate api_key for '{key_id}' " f"(already used by '{keys[api_key]}'), skipping" ) continue keys[api_key] = key_id if keys: print(f"🔐 Loaded {len(keys)} API key(s) from {self.keys_file}") # Show key IDs for verification key_ids = sorted(set(keys.values())) print(f" Key IDs: {', '.join(key_ids)}") return keys else: print(f"⚠️ Keys file exists but contains no valid keys: {self.keys_file}") return {} except Exception as e: print(f"❌ Error loading keys from {self.keys_file}: {e}") return {}
[docs] def validate(self, headers: dict) -> Tuple[bool, str]: """ Validate API key from request headers. Args: headers: Request headers (lowercase keys) Returns: (is_valid, key_id_or_error_message) - If valid: (True, key_id) - If invalid: (False, error_message) """ # If auth disabled, allow everything if not self.enabled: return True, "auth-disabled" # If no keys configured, allow everything if not self.keys: return True, "no-keys-configured" # Extract key from Authorization header auth_header = headers.get("authorization", "") if not auth_header: return False, "Missing Authorization header" # Remove "Bearer " prefix if present (OpenAI compatible) if auth_header.lower().startswith("bearer "): api_key = auth_header[7:].strip() else: api_key = auth_header.strip() if not api_key: return False, "Empty Authorization header" # Validate key format if not self._is_valid_format(api_key): return False, "Invalid API key format" # Check if key exists if api_key not in self.keys: return False, "Invalid API key" # Get key_id for this api_key key_id = self.keys[api_key] # Check rate limit if not self._check_rate_limit(key_id): return False, f"Rate limit exceeded for {key_id}" # Record successful request self._record_request(key_id) return True, key_id
def _is_valid_format(self, key: str) -> bool: """ Check if key format looks valid. Accepts alphanumeric, hyphens, underscores. Length: 16-128 characters. """ if not (16 <= len(key) <= 128): return False return all(c.isalnum() or c in "-_" for c in key) def _check_rate_limit(self, key_id: str) -> bool: """ Check if key_id has exceeded rate limit. Returns True if under limit, False if exceeded. """ now = time.time() minute_ago = now - 60 # Clean old requests (older than 1 minute) self.rate_limiter[key_id] = [ts for ts in self.rate_limiter[key_id] if ts > minute_ago] # Check if under limit if len(self.rate_limiter[key_id]) >= self.max_requests_per_minute: return False return True def _record_request(self, key_id: str): """Record a request timestamp for rate limiting.""" self.rate_limiter[key_id].append(time.time())
[docs] def get_metrics(self) -> dict: """ Get current rate limiter metrics per key_id. Returns: dict with request counts and limits per key_id """ now = time.time() minute_ago = now - 60 metrics = {} for key_id, timestamps in self.rate_limiter.items(): # Count recent requests recent = [ts for ts in timestamps if ts > minute_ago] metrics[key_id] = { "requests_last_minute": len(recent), "rate_limit": self.max_requests_per_minute, } return metrics
# Global validator instance api_validator = APIKeyValidator()
[docs] async def authenticate_request(writer: asyncio.StreamWriter, headers: dict) -> Optional[str]: """ Authenticate an incoming request. If authentication fails, sends a 401 response and returns None. If authentication succeeds, returns the key_id. Args: writer: asyncio StreamWriter to send response headers: Request headers (lowercase keys) Returns: key_id if valid, None if invalid (401 response already sent) """ is_valid, result = api_validator.validate(headers) if not is_valid: # Send OpenAI-compatible 401 error error_response = { "error": { "message": result, "type": "invalid_request_error", "param": "authorization", "code": "invalid_api_key", } } body = json.dumps(error_response) response = ( "HTTP/1.1 401 Unauthorized\r\n" "Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" "Connection: close\r\n" "\r\n" + body ) writer.write(response.encode()) await writer.drain() return None # Authentication succeeded, return key_id return result
[docs] async def send_rate_limit_error(writer: asyncio.StreamWriter): """Send OpenAI-compatible rate limit error (429).""" error_response = { "error": { "message": "Rate limit exceeded. Please slow down your requests.", "type": "rate_limit_error", "code": "rate_limit_exceeded", } } body = json.dumps(error_response) response = ( "HTTP/1.1 429 Too Many Requests\r\n" "Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" "Retry-After: 60\r\n" "Connection: close\r\n" "\r\n" + body ) writer.write(response.encode()) await writer.drain()
[docs] async def log_access(method: str, path: str, key_id: str, status_code: int): """ Log API access for auditing. Logs to: /data/logs/api_access.log Format: ISO8601_timestamp | key_id | method path | status_code Example: 2024-02-06T14:30:22.123456 | production | POST /v1/chat/completions | 200 Args: method: HTTP method path: Request path key_id: The key identifier (e.g., "production", "alice-laptop") status_code: HTTP status code """ timestamp = datetime.datetime.now().isoformat() log_entry = f"{timestamp} | {key_id} | {method} {path} | {status_code}" # Log to file log_file = "/data/logs/api_access.log" try: os.makedirs(os.path.dirname(log_file), exist_ok=True) with open(log_file, "a") as f: f.write(log_entry + "\n") except Exception as e: # Don't fail request if logging fails, but print warning print(f"Warning: Failed to log access: {e}")