Source code for gateway

#!/usr/bin/env python3
"""
gateway.py – Async HTTP gateway for llama.cpp llama-server

Features:
- API key authentication with health endpoint exemption
- Proper SSE/streaming support for chat completions
- /ping and /health endpoints (no auth required)
- Request timeout handling
- Basic metrics tracking
- Graceful error handling
- Per-key access logging

Note: /ping and /health return 200 immediately without backend checks
or authentication to enable scale-to-zero in serverless environments.

Environment Variables:
    GATEWAY_PORT    - Port to listen on (default: 8000)
    BACKEND_HOST    - llama-server host (default: 127.0.0.1)
    PORT_BACKEND    - llama-server port (default: 8080)
    BACKEND_API_KEY - API key for backend authentication (optional)
    REQUEST_TIMEOUT - Max request time in seconds (default: 300)
    HEALTH_TIMEOUT  - Health check timeout in seconds (default: 2)
    AUTH_ENABLED    - Enable API key authentication (default: true)
    AUTH_KEYS_FILE  - Path to API keys file (default: $DATA_DIR/api_keys.txt)
"""

import asyncio
import json
import os
import re
import signal
import socket
import sys
import time
from dataclasses import dataclass, field
from typing import Optional

# Import authentication module
try:
    from auth import api_validator, authenticate_request, log_access

    AUTH_AVAILABLE = True
except ImportError:
    print("[gateway] Warning: auth.py not found, authentication disabled")
    AUTH_AVAILABLE = False


[docs] def log(msg: str): """Simple logging to stderr.""" print(f"[gateway] {msg}", file=sys.stderr, flush=True)
# Configuration GATEWAY_HOST = "0.0.0.0" # nosec B104 - intentional bind-all for container networking GATEWAY_PORT = int(os.environ.get("GATEWAY_PORT", os.environ.get("PORT", "8000"))) BACKEND_HOST = os.environ.get("BACKEND_HOST", "127.0.0.1") # Support both PORT_BACKEND (new) and BACKEND_PORT (old, deprecated) if "BACKEND_PORT" in os.environ: print("[gateway] WARNING: BACKEND_PORT is deprecated, use PORT_BACKEND instead") BACKEND_PORT = int(os.environ.get("BACKEND_PORT", "8080")) else: BACKEND_PORT = int(os.environ.get("PORT_BACKEND", "8080")) REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", "300")) HEALTH_TIMEOUT = float(os.environ.get("HEALTH_TIMEOUT", "2")) # Backend authentication BACKEND_API_KEY = os.environ.get("BACKEND_API_KEY") if BACKEND_API_KEY: # Validate key format: "gateway-" + 43 base64url characters # Total length should be 51 characters (8 + 43) if not re.match(r"^gateway-[A-Za-z0-9_-]{43}$", BACKEND_API_KEY): log("ERROR: BACKEND_API_KEY has invalid format (expected: gateway-{43 base64url chars})") log(f"ERROR: Received length: {len(BACKEND_API_KEY)}, expected: 51") sys.exit(1) if len(BACKEND_API_KEY) != 51: log(f"ERROR: BACKEND_API_KEY has invalid length: {len(BACKEND_API_KEY)} (expected: 51)") sys.exit(1) log("Backend key format validated successfully") else: log("WARNING: BACKEND_API_KEY not set - backend will not require authentication") # Metrics (simple in-memory counters)
[docs] @dataclass class Metrics: requests_total: int = 0 requests_success: int = 0 requests_error: int = 0 requests_active: int = 0 requests_authenticated: int = 0 requests_unauthorized: int = 0 bytes_sent: int = 0 start_time: float = field(default_factory=time.time)
[docs] def to_dict(self): return { "requests_total": self.requests_total, "requests_success": self.requests_success, "requests_error": self.requests_error, "requests_active": self.requests_active, "requests_authenticated": self.requests_authenticated, "requests_unauthorized": self.requests_unauthorized, "bytes_sent": self.bytes_sent, "uptime_seconds": int(time.time() - self.start_time), }
metrics = Metrics()
[docs] def backend_tcp_ready() -> bool: """Check if backend is accepting TCP connections.""" try: with socket.create_connection((BACKEND_HOST, BACKEND_PORT), timeout=0.5): return True except (OSError, socket.timeout): return False
[docs] async def backend_health_check() -> dict: """Check backend health via /health endpoint. Returns health status dict or error. """ try: reader, writer = await asyncio.wait_for( asyncio.open_connection(BACKEND_HOST, BACKEND_PORT), timeout=HEALTH_TIMEOUT ) request = ( f"GET /health HTTP/1.1\r\n" f"Host: {BACKEND_HOST}:{BACKEND_PORT}\r\n" f"Connection: close\r\n\r\n" ) writer.write(request.encode()) await writer.drain() response = await asyncio.wait_for(reader.read(4096), timeout=HEALTH_TIMEOUT) writer.close() await writer.wait_closed() # Parse response response_str = response.decode("utf-8", errors="replace") # Extract status code first_line = response_str.split("\r\n")[0] status_code = int(first_line.split()[1]) if len(first_line.split()) > 1 else 0 # Extract body (after \r\n\r\n) if "\r\n\r\n" in response_str: body = response_str.split("\r\n\r\n", 1)[1] try: return { "status": "ok", "code": status_code, "backend": json.loads(body), } except json.JSONDecodeError: return { "status": "ok", "code": status_code, "backend_raw": body[:200], } return {"status": "ok", "code": status_code} except asyncio.TimeoutError: return {"status": "timeout", "error": "Backend health check timed out"} except Exception as e: return {"status": "error", "error": str(e)}
[docs] async def handle_ping(writer: asyncio.StreamWriter): """Handle /ping endpoint for RunPod health checks. Always returns 200 OK without authentication or backend checks. For detailed backend status, use /health endpoint instead. """ response = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" writer.write(response.encode()) await writer.drain()
[docs] async def handle_health(writer: asyncio.StreamWriter): """Handle /health endpoint with detailed backend status. No authentication required. """ health = await backend_health_check() health["gateway"] = { "status": "ok", "metrics": metrics.to_dict(), } # Add auth metrics if available if AUTH_AVAILABLE and api_validator.enabled: health["authentication"] = { "enabled": True, "keys_configured": len(api_validator.keys), "rate_limits": api_validator.get_metrics(), } body = json.dumps(health, indent=2) response = ( f"HTTP/1.1 200 OK\r\n" f"Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" f"Connection: close\r\n" f"\r\n" f"{body}" ) writer.write(response.encode()) await writer.drain()
[docs] async def handle_metrics(writer: asyncio.StreamWriter): """Handle /metrics endpoint. No authentication required. """ metrics_data = {"gateway": metrics.to_dict()} # Add auth metrics if available if AUTH_AVAILABLE and api_validator.enabled: metrics_data["authentication"] = api_validator.get_metrics() body = json.dumps(metrics_data, indent=2) response = ( f"HTTP/1.1 200 OK\r\n" f"Content-Type: application/json\r\n" f"Content-Length: {len(body)}\r\n" f"Connection: close\r\n" f"\r\n" f"{body}" ) writer.write(response.encode()) await writer.drain()
[docs] async def proxy_request( method: str, path: str, headers: dict, body: Optional[bytes], writer: asyncio.StreamWriter, key_id: str = "unknown", ): """Proxy a request to the backend with streaming support. Args: method: HTTP method (GET, POST, etc.) path: Request path to forward to backend headers: Request headers dict (lowercase keys) body: Request body bytes, or None for bodyless requests writer: asyncio StreamWriter for the client connection key_id: The authenticated key_id for logging """ metrics.requests_total += 1 metrics.requests_active += 1 try: # Connect to backend try: backend_reader, backend_writer = await asyncio.wait_for( asyncio.open_connection(BACKEND_HOST, BACKEND_PORT), timeout=5.0 ) except (asyncio.TimeoutError, OSError): metrics.requests_error += 1 error_response = ( "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" ) writer.write(error_response.encode()) await writer.drain() # Log failed request if AUTH_AVAILABLE: await log_access(method, path, key_id, 502) return # Build request to backend request_line = f"{method} {path} HTTP/1.1\r\n" # Forward headers, adjusting Host header_lines = [f"Host: {BACKEND_HOST}:{BACKEND_PORT}"] for key, value in headers.items(): key_lower = key.lower() if key_lower in ( "host", "connection", "keep-alive", "transfer-encoding", "authorization", ): continue # Skip user's authorization header header_lines.append(f"{key}: {value}") # Add backend authentication if configured if BACKEND_API_KEY: header_lines.append(f"Authorization: Bearer {BACKEND_API_KEY}") header_lines.append("Connection: close") request = request_line + "\r\n".join(header_lines) + "\r\n\r\n" backend_writer.write(request.encode()) if body: backend_writer.write(body) await backend_writer.drain() # Read and forward response headers response_headers = b"" while True: line = await asyncio.wait_for(backend_reader.readline(), timeout=REQUEST_TIMEOUT) response_headers += line if line == b"\r\n" or line == b"": break # Send headers to client writer.write(response_headers) await writer.drain() # Stream response body bytes_sent = 0 try: while True: chunk = await asyncio.wait_for(backend_reader.read(8192), timeout=REQUEST_TIMEOUT) if not chunk: break writer.write(chunk) await writer.drain() bytes_sent += len(chunk) except asyncio.TimeoutError: pass # Connection closed or timeout metrics.bytes_sent += bytes_sent metrics.requests_success += 1 backend_writer.close() await backend_writer.wait_closed() # Log successful request if AUTH_AVAILABLE: await log_access(method, path, key_id, 200) except Exception as e: metrics.requests_error += 1 log(f"Proxy error: {e}") try: error_response = ( "HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n" ) writer.write(error_response.encode()) await writer.drain() except Exception: log("Cleanup: failed to send error response to client") # Log error if AUTH_AVAILABLE: await log_access(method, path, key_id, 502) finally: metrics.requests_active -= 1
[docs] async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): """Handle an incoming client connection.""" try: # Read request line request_line_raw = await asyncio.wait_for(reader.readline(), timeout=30) if not request_line_raw: return request_line = request_line_raw.decode("utf-8", errors="replace").strip() parts = request_line.split() if len(parts) < 2: return method = parts[0] path = parts[1] # Read headers headers: dict[str, str] = {} content_length = 0 while True: header_line_raw = await asyncio.wait_for(reader.readline(), timeout=30) if header_line_raw == b"\r\n" or header_line_raw == b"": break header_line = header_line_raw.decode("utf-8", errors="replace").strip() if ":" in header_line: key, value = header_line.split(":", 1) headers[key.strip().lower()] = value.strip() if key.lower() == "content-length": content_length = int(value.strip()) # Read body if present body = None if content_length > 0: body = await asyncio.wait_for(reader.readexactly(content_length), timeout=30) # Route request - health endpoints bypass auth if path in ("/ping", "/health", "/metrics"): if path == "/ping": await handle_ping(writer) elif path == "/health": await handle_health(writer) elif path == "/metrics": await handle_metrics(writer) return # All other endpoints require authentication if AUTH_AVAILABLE: key_id = await authenticate_request(writer, headers) if key_id is None: # 401 response already sent by authenticate_request metrics.requests_unauthorized += 1 return metrics.requests_authenticated += 1 else: # Auth not available, allow request key_id = "auth-disabled" # Proxy to backend await proxy_request(method, path, headers, body, writer, key_id) except asyncio.TimeoutError: pass except Exception as e: log(f"Client handler error: {e}") finally: try: writer.close() await writer.wait_closed() except Exception: log("Cleanup: failed to close client writer")
[docs] async def main(): """Start the gateway server.""" log(f"Starting gateway on {GATEWAY_HOST}:{GATEWAY_PORT}") log(f"Backend: {BACKEND_HOST}:{BACKEND_PORT}") log(f"Request timeout: {REQUEST_TIMEOUT}s") if AUTH_AVAILABLE: if api_validator.enabled: log(f"Authentication: ENABLED ({len(api_validator.keys)} keys configured)") else: log("Authentication: DISABLED") else: log("Authentication: NOT AVAILABLE (auth.py not found)") server = await asyncio.start_server( handle_client, GATEWAY_HOST, GATEWAY_PORT, reuse_address=True, ) # Handle shutdown signals loop = asyncio.get_event_loop() def signal_handler(): log("Shutdown signal received") server.close() for sig in (signal.SIGTERM, signal.SIGINT): loop.add_signal_handler(sig, signal_handler) log(f"Gateway listening on http://{GATEWAY_HOST}:{GATEWAY_PORT}") log("Public endpoints (no auth): /ping, /health, /metrics") log("Protected endpoints (auth required): /v1/*") async with server: await server.serve_forever()
if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: log("Interrupted") except Exception as e: log(f"Fatal error: {e}") sys.exit(1)