aegra-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aegra_api/__init__.py +3 -0
- aegra_api/api/__init__.py +1 -0
- aegra_api/api/assistants.py +235 -0
- aegra_api/api/runs.py +1110 -0
- aegra_api/api/store.py +200 -0
- aegra_api/api/threads.py +761 -0
- aegra_api/config.py +204 -0
- aegra_api/constants.py +5 -0
- aegra_api/core/__init__.py +0 -0
- aegra_api/core/app_loader.py +91 -0
- aegra_api/core/auth_ctx.py +65 -0
- aegra_api/core/auth_deps.py +186 -0
- aegra_api/core/auth_handlers.py +248 -0
- aegra_api/core/auth_middleware.py +331 -0
- aegra_api/core/database.py +123 -0
- aegra_api/core/health.py +131 -0
- aegra_api/core/orm.py +165 -0
- aegra_api/core/route_merger.py +69 -0
- aegra_api/core/serializers/__init__.py +7 -0
- aegra_api/core/serializers/base.py +22 -0
- aegra_api/core/serializers/general.py +54 -0
- aegra_api/core/serializers/langgraph.py +102 -0
- aegra_api/core/sse.py +178 -0
- aegra_api/main.py +303 -0
- aegra_api/middleware/__init__.py +4 -0
- aegra_api/middleware/double_encoded_json.py +74 -0
- aegra_api/middleware/logger_middleware.py +95 -0
- aegra_api/models/__init__.py +76 -0
- aegra_api/models/assistants.py +81 -0
- aegra_api/models/auth.py +62 -0
- aegra_api/models/enums.py +29 -0
- aegra_api/models/errors.py +29 -0
- aegra_api/models/runs.py +124 -0
- aegra_api/models/store.py +67 -0
- aegra_api/models/threads.py +152 -0
- aegra_api/observability/__init__.py +1 -0
- aegra_api/observability/base.py +88 -0
- aegra_api/observability/otel.py +133 -0
- aegra_api/observability/setup.py +27 -0
- aegra_api/observability/targets/__init__.py +11 -0
- aegra_api/observability/targets/base.py +18 -0
- aegra_api/observability/targets/langfuse.py +33 -0
- aegra_api/observability/targets/otlp.py +38 -0
- aegra_api/observability/targets/phoenix.py +24 -0
- aegra_api/services/__init__.py +0 -0
- aegra_api/services/assistant_service.py +569 -0
- aegra_api/services/base_broker.py +59 -0
- aegra_api/services/broker.py +141 -0
- aegra_api/services/event_converter.py +157 -0
- aegra_api/services/event_store.py +196 -0
- aegra_api/services/graph_streaming.py +433 -0
- aegra_api/services/langgraph_service.py +456 -0
- aegra_api/services/streaming_service.py +362 -0
- aegra_api/services/thread_state_service.py +128 -0
- aegra_api/settings.py +124 -0
- aegra_api/utils/__init__.py +3 -0
- aegra_api/utils/assistants.py +23 -0
- aegra_api/utils/run_utils.py +60 -0
- aegra_api/utils/setup_logging.py +122 -0
- aegra_api/utils/sse_utils.py +26 -0
- aegra_api/utils/status_compat.py +57 -0
- aegra_api-0.1.0.dist-info/METADATA +244 -0
- aegra_api-0.1.0.dist-info/RECORD +64 -0
- aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/core/sse.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Server-Sent Events utilities and formatting"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from datetime import UTC, datetime
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
# Import our serializer for handling complex objects
|
|
10
|
+
from aegra_api.core.serializers import GeneralSerializer
|
|
11
|
+
|
|
12
|
+
# Shared serializer whose ``serialize`` method is the default ``json.dumps``
# fallback in format_sse_message for objects the JSON encoder cannot handle.
_serializer = GeneralSerializer()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_sse_headers() -> dict[str, str]:
    """Return a fresh dict of the standard HTTP headers for an SSE response.

    The headers disable caching, keep the connection open, declare the
    ``text/event-stream`` media type, and let cross-origin clients send
    ``Last-Event-ID`` when resuming a stream.
    """
    headers: dict[str, str] = {}
    headers["Cache-Control"] = "no-cache"
    headers["Connection"] = "keep-alive"
    headers["Content-Type"] = "text/event-stream"
    # NOTE(review): this wildcard is emitted regardless of the app's configured
    # CORS policy — confirm that is intended for restricted-origin deployments.
    headers["Access-Control-Allow-Origin"] = "*"
    headers["Access-Control-Allow-Headers"] = "Last-Event-ID"
    return headers
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def format_sse_message(
|
|
28
|
+
event: str,
|
|
29
|
+
data: Any,
|
|
30
|
+
event_id: str | None = None,
|
|
31
|
+
serializer: Callable[[Any], Any] | None = None,
|
|
32
|
+
) -> str:
|
|
33
|
+
"""Format a message as Server-Sent Event following SSE standard
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
event: SSE event type
|
|
37
|
+
data: Data to serialize and send
|
|
38
|
+
event_id: Optional event ID
|
|
39
|
+
serializer: Optional custom serializer function
|
|
40
|
+
"""
|
|
41
|
+
lines = []
|
|
42
|
+
|
|
43
|
+
lines.append(f"event: {event}")
|
|
44
|
+
|
|
45
|
+
# Convert data to JSON string
|
|
46
|
+
if data is None:
|
|
47
|
+
data_str = ""
|
|
48
|
+
else:
|
|
49
|
+
# Use our general serializer by default to handle complex objects
|
|
50
|
+
default_serializer = serializer or _serializer.serialize
|
|
51
|
+
data_str = json.dumps(data, default=default_serializer, separators=(",", ":"))
|
|
52
|
+
|
|
53
|
+
lines.append(f"data: {data_str}")
|
|
54
|
+
|
|
55
|
+
if event_id:
|
|
56
|
+
lines.append(f"id: {event_id}")
|
|
57
|
+
|
|
58
|
+
lines.append("") # Empty line to end the event
|
|
59
|
+
|
|
60
|
+
return "\n".join(lines) + "\n"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def create_metadata_event(run_id: str, event_id: str | None = None, attempt: int = 1) -> str:
    """Build the ``metadata`` SSE frame carrying *run_id* and *attempt* (LangSmith Studio compatibility)."""
    return format_sse_message("metadata", {"run_id": run_id, "attempt": attempt}, event_id)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def create_debug_event(debug_data: dict[str, Any], event_id: str | None = None) -> str:
    """Build a ``debug`` SSE frame, back-filling checkpoint references for LangSmith Studio.

    When ``debug_data["payload"]`` is a dict, it is updated in place as follows:

    * ``checkpoint`` is derived from ``payload["config"]["configurable"]`` when
      missing and that config is well-formed.
    * ``parent_checkpoint`` is derived the same way from
      ``payload["parent_config"]``. It is set to ``None`` when ``parent_config``
      is explicitly ``None``.

    Each derived reference holds ``thread_id``, ``checkpoint_id`` and
    ``checkpoint_ns``; ``checkpoint_ns`` defaults to ``""``.

    NOTE: the caller's payload dict is mutated, not copied.
    """

    def _checkpoint_ref(cfg: Any) -> dict[str, Any] | None:
        """Return the checkpoint reference in ``cfg["configurable"]``, or None when absent/malformed."""
        if not isinstance(cfg, dict):
            return None
        configurable = cfg.get("configurable")
        if not isinstance(configurable, dict):
            return None
        return {
            "thread_id": configurable.get("thread_id"),
            "checkpoint_id": configurable.get("checkpoint_id"),
            "checkpoint_ns": configurable.get("checkpoint_ns", ""),
        }

    payload = debug_data.get("payload")
    if isinstance(payload, dict):
        if "checkpoint" not in payload and "config" in payload:
            ref = _checkpoint_ref(payload["config"])
            if ref is not None:
                payload["checkpoint"] = ref

        if "parent_checkpoint" not in payload and "parent_config" in payload:
            parent_config = payload["parent_config"]
            if parent_config is None:
                # A root checkpoint has no parent; make that explicit.
                payload["parent_checkpoint"] = None
            else:
                ref = _checkpoint_ref(parent_config)
                if ref is not None:
                    payload["parent_checkpoint"] = ref

    return format_sse_message("debug", debug_data, event_id)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def create_end_event(event_id: str | None = None) -> str:
    """Build the terminal ``end`` SSE frame that signals stream completion.

    The payload is always ``{"status": "success"}``, the standard status value,
    rather than ``"completed"``.
    """
    end_payload = {"status": "success"}
    return format_sse_message("end", end_payload, event_id)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def create_error_event(error: str | dict[str, Any], event_id: str | None = None) -> str:
    """Build an ``error`` SSE frame with the payload shape ``{"error": str, "message": str}``.

    This format ensures compatibility with standard SSE error event consumers.

    Args:
        error: Either a plain message (wrapped with the generic ``"Error"``
            type), or a dict whose ``error``/``message`` keys are copied. Missing
            keys default to ``"Error"`` and ``str(error)`` respectively.
        event_id: Optional SSE event ID for reconnection support.

    Returns:
        SSE-formatted error event string with standard error format.
    """
    if not isinstance(error, dict):
        return format_sse_message("error", {"error": "Error", "message": str(error)}, event_id)

    structured = {
        "error": error.get("error", "Error"),
        "message": error.get("message", str(error)),
    }
    return format_sse_message("error", structured, event_id)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def create_messages_event(messages_data: Any, event_type: str = "messages", event_id: str | None = None) -> str:
    """Build a messages-family SSE frame (messages, messages/partial, messages/complete, messages/metadata).

    A two-element tuple ``(message_chunk, metadata)``, the token-streaming
    shape, is sent as a two-element list, as expected by the LangGraph SDK
    client. Any other value is sent unchanged.
    """
    payload = messages_data
    if isinstance(messages_data, tuple) and len(messages_data) == 2:
        payload = list(messages_data)
    return format_sse_message(event_type, payload, event_id)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Legacy compatibility - used by event_store.py
|
|
156
|
+
@dataclass
|
|
157
|
+
class SSEEvent:
|
|
158
|
+
"""SSE Event data structure for event storage"""
|
|
159
|
+
|
|
160
|
+
id: str
|
|
161
|
+
event: str
|
|
162
|
+
data: dict[str, Any]
|
|
163
|
+
timestamp: datetime | None = None
|
|
164
|
+
|
|
165
|
+
def __post_init__(self) -> None:
|
|
166
|
+
if self.timestamp is None:
|
|
167
|
+
self.timestamp = datetime.now(UTC)
|
|
168
|
+
|
|
169
|
+
def format(self) -> str:
|
|
170
|
+
"""Format as proper SSE event"""
|
|
171
|
+
json_data = json.dumps(self.data, default=str)
|
|
172
|
+
return f"id: {self.id}\nevent: {self.event}\ndata: {json_data}\n\n"
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def format_sse_event(id: str, event: str, data: dict[str, Any]) -> str:
    """Render an SSE frame with ``id``, ``event`` and JSON ``data`` fields (used by event_store).

    Values the JSON encoder cannot handle are stringified with ``str``.
    """
    fields = (("id", id), ("event", event), ("data", json.dumps(data, default=str)))
    return "".join(f"{name}: {value}\n" for name, value in fields) + "\n"
|
aegra_api/main.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""FastAPI application for Aegra (Agent Protocol Server)"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import sys
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from contextlib import asynccontextmanager
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from dotenv import load_dotenv
|
|
11
|
+
|
|
12
|
+
# Populate the process environment from a .env file before the settings and
# configuration modules imported further down are loaded.
load_dotenv()
|
|
14
|
+
|
|
15
|
+
# Add graphs directory to Python path so react_agent can be imported
# This MUST happen before importing any modules that depend on graphs/
# NOTE(review): three levels above aegra_api/main.py is the repository root
# only in a source checkout; in an installed wheel it resolves inside the
# interpreter's lib directory, where no graphs/ folder is expected — confirm
# the intended layout.
current_dir = Path(__file__).parent.parent.parent  # Go up to aegra root
graphs_dir = current_dir / "graphs"
# Only prepend the directory when it exists, so a missing graphs/ folder does
# not leave a dead entry at the front of sys.path.
if graphs_dir.is_dir() and str(graphs_dir) not in sys.path:
    sys.path.insert(0, str(graphs_dir))
|
|
21
|
+
|
|
22
|
+
# ruff: noqa: E402 - imports below require sys.path modification above
|
|
23
|
+
import structlog
|
|
24
|
+
from asgi_correlation_id import CorrelationIdMiddleware
|
|
25
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
26
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
27
|
+
from fastapi.responses import JSONResponse
|
|
28
|
+
|
|
29
|
+
from aegra_api.api.assistants import router as assistants_router
|
|
30
|
+
from aegra_api.api.runs import router as runs_router
|
|
31
|
+
from aegra_api.api.store import router as store_router
|
|
32
|
+
from aegra_api.api.threads import router as threads_router
|
|
33
|
+
from aegra_api.config import HttpConfig, get_config_dir, load_http_config
|
|
34
|
+
from aegra_api.core.app_loader import load_custom_app
|
|
35
|
+
from aegra_api.core.auth_deps import auth_dependency
|
|
36
|
+
from aegra_api.core.database import db_manager
|
|
37
|
+
from aegra_api.core.health import router as health_router
|
|
38
|
+
from aegra_api.core.route_merger import (
|
|
39
|
+
merge_exception_handlers,
|
|
40
|
+
merge_lifespans,
|
|
41
|
+
)
|
|
42
|
+
from aegra_api.middleware import DoubleEncodedJSONMiddleware, StructLogMiddleware
|
|
43
|
+
from aegra_api.models.errors import AgentProtocolError, get_error_type
|
|
44
|
+
from aegra_api.observability.setup import setup_observability
|
|
45
|
+
from aegra_api.services.event_store import event_store
|
|
46
|
+
from aegra_api.services.langgraph_service import get_langgraph_service
|
|
47
|
+
from aegra_api.settings import settings
|
|
48
|
+
from aegra_api.utils.setup_logging import setup_logging
|
|
49
|
+
|
|
50
|
+
# In-flight run tasks keyed by a string id (presumably the run id — confirm
# against callers); any still running are cancelled in lifespan() on shutdown.
active_runs: dict[str, asyncio.Task] = {}

# Configure logging once at import time, then bind this module's logger.
setup_logging()
logger = structlog.getLogger(__name__)

# Response headers exposed to browsers by default; required for LangGraph SDK
# stream reconnection.
DEFAULT_EXPOSE_HEADERS = ["Content-Location", "Location"]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan context manager for startup/shutdown.

    Startup initialises the database, observability, the LangGraph service and
    the event-store cleanup task.

    Shutdown runs in a ``finally`` so it happens even when the serving phase
    ends with an exception. It cancels unfinished tasks in ``active_runs`` and
    waits (bounded) for them to unwind, then stops the event-store cleanup task
    and closes the database.
    """
    # Startup: Initialize database and LangGraph components
    await db_manager.initialize()

    # Observability
    setup_observability()

    # Initialize LangGraph service
    langgraph_service = get_langgraph_service()
    await langgraph_service.initialize()

    # Initialize event store cleanup task
    await event_store.start_cleanup_task()

    try:
        yield
    finally:
        # Shutdown: cancel active runs and give them a bounded grace period to
        # run their cancellation/cleanup code before the database is closed
        # underneath them. The bound keeps a task that swallows cancellation
        # from hanging shutdown forever.
        cancel_grace_seconds = 10.0
        pending = [task for task in list(active_runs.values()) if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            _, still_running = await asyncio.wait(pending, timeout=cancel_grace_seconds)
            if still_running:
                logger.warning(
                    "Active runs did not finish cancelling before shutdown",
                    count=len(still_running),
                )

        # Stop event store cleanup task
        await event_store.stop_cleanup_task()

        await db_manager.close()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Define core exception handlers
async def agent_protocol_exception_handler(_request: Request, exc: HTTPException) -> JSONResponse:
    """Convert HTTP exceptions to Agent Protocol error format.

    The response uses ``exc.status_code`` and an ``AgentProtocolError`` body:
    the error type is derived from the status code, ``exc.detail`` becomes the
    message, and an optional ``details`` attribute is included. Headers attached
    to the exception (e.g. ``WWW-Authenticate`` on a 401 or ``Retry-After`` on a
    429) are forwarded, matching FastAPI's default HTTPException handler.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=AgentProtocolError(
            error=get_error_type(exc.status_code),
            message=exc.detail,
            details=getattr(exc, "details", None),
        ).model_dump(),
        headers=getattr(exc, "headers", None),
    )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
async def general_exception_handler(_request: Request, exc: Exception) -> JSONResponse:
    """Turn any unhandled exception into a 500 Agent Protocol ``internal_error`` response.

    NOTE(review): ``str(exc)`` is returned to the client under ``details`` —
    confirm that exposing internal error text is acceptable outside development.
    """
    error_body = AgentProtocolError(
        error="internal_error",
        message="An unexpected error occurred",
        details={"exception": str(exc)},
    )
    return JSONResponse(status_code=500, content=error_body.model_dump())
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Core handlers: registered directly on the default app, or merged into a
# user-supplied app by create_app().
exception_handlers: dict[type[Exception], Any] = {
    HTTPException: agent_protocol_exception_handler,
    Exception: general_exception_handler,
}
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Define root endpoint handler
async def root_handler() -> dict[str, str]:
    """Root endpoint: report the project name, version and a fixed ``running`` status."""
    app_settings = settings.app
    return {"message": app_settings.PROJECT_NAME, "version": app_settings.VERSION, "status": "running"}
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _apply_auth_to_routes(app: FastAPI, auth_deps: list[Any]) -> None:
    """Apply auth dependency to all existing routes in the FastAPI app.

    Walks ``app.routes`` recursively, descending into anything that exposes a
    nested ``routes`` attribute (such as mounts). Every ``APIRoute`` that does
    not already carry one of *auth_deps* gets them prepended, so auth runs
    first (fail-fast). The unauthenticated core endpoints — the health router
    and ``root_handler`` — are left untouched.

    ``APIRoute`` compiles its dependencies into ``route.dependant`` and its ASGI
    handler ``route.app`` when it is constructed (see
    ``fastapi.routing.APIRoute.__init__``). Reassigning ``route.dependencies``
    afterwards is therefore not enforced on its own. Each updated route also gets
    the auth sub-dependants inserted into its dependant, and its handler rebuilt.

    Args:
        app: FastAPI application instance
        auth_deps: List of dependencies to apply (e.g., [Depends(require_auth)])
    """
    from fastapi.dependencies.utils import get_parameterless_sub_dependant
    from fastapi.routing import APIRoute, request_response

    # Identity set used to detect routes that already carry the auth deps
    # (e.g. the core routers included with ``dependencies=auth_dependency``).
    auth_dep_ids = {id(dep) for dep in auth_deps}

    # Core endpoints documented as public (health checks, root status).
    public_endpoints = {route.endpoint for route in health_router.routes if isinstance(route, APIRoute)}
    public_endpoints.add(root_handler)

    def secure_route(route: APIRoute) -> None:
        """Prepend auth_deps to one route and make the change effective at request time."""
        existing_deps = list(route.dependencies or [])
        if auth_dep_ids.intersection(id(dep) for dep in existing_deps):
            return
        # Keep the declarative list in sync for introspection / re-inclusion.
        route.dependencies = list(auth_deps) + existing_deps
        # Mirror what APIRoute.__init__ does for decorator-level dependencies.
        route.dependant.dependencies[0:0] = [
            get_parameterless_sub_dependant(depends=dep, path=route.path_format) for dep in auth_deps
        ]
        route.app = request_response(route.get_route_handler())

    def process_routes(routes: list) -> None:
        """Recursively process routes and nested route containers."""
        for route in routes:
            if isinstance(route, APIRoute):
                if route.endpoint not in public_endpoints:
                    secure_route(route)
            elif hasattr(route, "routes"):
                # Mounts and other containers expose their children via .routes
                process_routes(route.routes)

    process_routes(app.routes)
    logger.info("Applied authentication dependency to custom routes")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _add_cors_middleware(app: FastAPI, cors_config: dict[str, Any] | None) -> None:
    """Install CORSMiddleware, falling back to permissive defaults for unset options.

    Defaults are: all origins, methods and headers allowed; credentials
    allowed; ``DEFAULT_EXPOSE_HEADERS`` exposed; and a 600-second preflight
    cache (the same value as CORSMiddleware's own ``max_age`` default).

    Args:
        app: FastAPI application instance
        cors_config: CORS configuration dict or None for defaults
    """
    options = cors_config or {}
    app.add_middleware(
        CORSMiddleware,
        allow_origins=options.get("allow_origins", ["*"]),
        allow_credentials=options.get("allow_credentials", True),
        allow_methods=options.get("allow_methods", ["*"]),
        allow_headers=options.get("allow_headers", ["*"]),
        expose_headers=options.get("expose_headers", DEFAULT_EXPOSE_HEADERS),
        max_age=options.get("max_age", 600),
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _add_common_middleware(app: FastAPI, cors_config: dict[str, Any] | None) -> None:
    """Install the shared middleware stack.

    Each middleware added wraps the ones added before it, so registering
    innermost-first (as below) gives this request-time order:

    1. DoubleEncodedJSONMiddleware (outermost - runs first)
    2. CORSMiddleware (handles preflight early)
    3. CorrelationIdMiddleware (adds request ID)
    4. StructLogMiddleware (innermost - logs with correlation ID)

    Args:
        app: FastAPI application instance
        cors_config: CORS configuration dict or None for defaults
    """
    # Registration order is innermost -> outermost.
    for middleware_cls in (StructLogMiddleware, CorrelationIdMiddleware):
        app.add_middleware(middleware_cls)
    _add_cors_middleware(app, cors_config)
    app.add_middleware(DoubleEncodedJSONMiddleware)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _include_core_routers(app: FastAPI) -> None:
    """Include all core API routers, applying the auth dependency to all but health.

    Routers are included in this fixed order:
    Health (no auth), Assistants, Threads, Runs, Store (all with auth).

    Args:
        app: FastAPI application instance
    """
    app.include_router(health_router, prefix="", tags=["Health"])

    authenticated = (
        (assistants_router, "Assistants"),
        (threads_router, "Threads"),
        (runs_router, "Runs"),
        (store_router, "Store"),
    )
    for router, tag in authenticated:
        app.include_router(router, dependencies=auth_dependency, prefix="", tags=[tag])
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _load_user_app(http_config: HttpConfig | None) -> Any:
    """Load the custom application named by ``http_config["app"]``; None when not configured.

    Load failures are logged with a traceback and re-raised.
    """
    if not (http_config and http_config.get("app")):
        return None
    try:
        config_dir = get_config_dir()
        loaded = load_custom_app(http_config["app"], base_dir=config_dir)
        logger.info("Custom app loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load custom app: {e}", exc_info=True)
        raise
    return loaded


def _configure_user_app(
    user_app: Any, http_config: HttpConfig | None, cors_config: dict[str, Any] | None
) -> FastAPI:
    """Attach the core routers, root endpoint, lifespan, handlers, middleware and optional auth to a user app.

    Raises:
        TypeError: if *user_app* is not a FastAPI instance.
    """
    if not isinstance(user_app, FastAPI):
        raise TypeError(
            "Custom apps must be FastAPI applications. Use: from fastapi import FastAPI; app = FastAPI()"
        )

    application = user_app
    _include_core_routers(application)

    # Add root endpoint if not already defined
    has_root = any(route.path == "/" for route in application.routes if hasattr(route, "path"))
    if not has_root:
        application.get("/")(root_handler)

    application = merge_lifespans(application, lifespan)
    application = merge_exception_handlers(application, exception_handlers)
    _add_common_middleware(application, cors_config)

    # Apply auth to custom routes if enabled
    if http_config and http_config.get("enable_custom_route_auth", False):
        _apply_auth_to_routes(application, auth_dependency)
    return application


def _build_default_app(cors_config: dict[str, Any] | None) -> FastAPI:
    """Create the stock Aegra FastAPI application with middleware, routers, handlers and root endpoint."""
    application = FastAPI(
        title=settings.app.PROJECT_NAME,
        description="Production-ready Agent Protocol server",
        version=settings.app.VERSION,
        debug=settings.app.DEBUG,
        docs_url="/docs",
        redoc_url="/redoc",
        lifespan=lifespan,
    )
    _add_common_middleware(application, cors_config)
    _include_core_routers(application)
    for exc_type, handler in exception_handlers.items():
        application.exception_handler(exc_type)(handler)
    application.get("/")(root_handler)
    return application


def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    Uses the custom app from the HTTP config when one is configured and loads
    to a truthy value; otherwise builds the default application.

    Returns:
        Configured FastAPI application instance
    """
    http_config: HttpConfig | None = load_http_config()
    cors_config = http_config.get("cors") if http_config else None

    user_app = _load_user_app(http_config)
    if user_app:
        return _configure_user_app(user_app, http_config, cors_config)
    return _build_default_app(cors_config)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# Application instance, built once at import time.
app = create_app()


if __name__ == "__main__":
    import uvicorn

    # Serve directly using the configured host and port.
    uvicorn.run(app, host=settings.app.HOST, port=int(settings.app.PORT))  # nosec B104 - binding to all interfaces is intentional
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
5
|
+
|
|
6
|
+
# Module logger (not referenced by the middleware below).
logger = structlog.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DoubleEncodedJSONMiddleware:
    """Middleware to handle double-encoded JSON payloads from frontend.

    Some frontend clients may send JSON that's been stringified twice,
    resulting in payloads like '"{\"key\":\"value\"}"' instead of '{"key":"value"}'.
    This middleware detects and corrects such cases.

    For POST/PUT/PATCH requests that carry a Content-Type header:

    * All ``http.request`` chunks are buffered and handed to the app as a
      single message, so the body is never split or duplicated.
    * If the body is a JSON string whose content is itself JSON, the inner
      document is re-serialized and replaces the body. ``Content-Length`` is
      updated when that header is present.
    * If the body is valid JSON, ``Content-Type`` is normalised to
      ``application/json``. Ordinary JSON bodies are forwarded byte-for-byte.
    * Anything else (non-UTF-8, invalid JSON, a JSON string that does not wrap
      JSON) is forwarded unchanged.
    """

    def __init__(self, app: "ASGIApp"):
        self.app = app

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        method = scope["method"]
        headers = dict(scope.get("headers", []))
        content_type = headers.get(b"content-type", b"").decode("latin1")

        if method not in ("POST", "PUT", "PATCH") or not content_type:
            await self.app(scope, receive, send)
            return

        async def receive_wrapper() -> dict:
            message = await receive()
            if message["type"] != "http.request":
                return message

            # Buffer the complete body before inspecting it. Forwarding early
            # chunks and then returning a rewritten full body on the last chunk
            # would hand the app a duplicated, corrupt body.
            body_parts = [message.get("body", b"")]
            while message.get("more_body", False):
                message = await receive()
                if message["type"] != "http.request":
                    # e.g. http.disconnect before the body completed: surface it as-is.
                    return message
                body_parts.append(message.get("body", b""))

            body = b"".join(body_parts)
            if body:
                body = self._normalize_body(scope, headers, content_type, body)
            return {"type": "http.request", "body": body, "more_body": False}

        await self.app(scope, receive_wrapper, send)

    @classmethod
    def _normalize_body(cls, scope: "Scope", headers: dict, content_type: str, body: bytes) -> bytes:
        """Return the body to deliver, unwrapping one level of JSON string encoding.

        Adjusts ``content-type`` (and ``content-length`` when the body changes)
        in *scope* whenever the body is valid JSON.
        """
        try:
            parsed = json.loads(body.decode("utf-8"))
            if isinstance(parsed, str):
                new_body = json.dumps(json.loads(parsed)).encode("utf-8")
            else:
                new_body = body
        except (json.JSONDecodeError, ValueError, UnicodeDecodeError, RecursionError):
            # Not (double-encoded) JSON, or too deeply nested to parse: leave untouched.
            return body

        if b"content-type" in headers and content_type != "application/json":
            cls._replace_header(scope, b"content-type", b"application/json")
        if new_body is not body and b"content-length" in headers:
            cls._replace_header(scope, b"content-length", str(len(new_body)).encode("latin1"))
        return new_body

    @staticmethod
    def _replace_header(scope: "Scope", name: bytes, value: bytes) -> None:
        """Drop every *name* header from *scope* and append a single ``name: value``."""
        scope["headers"] = [(k, v) for k, v in scope.get("headers", []) if k != name] + [(name, value)]
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import TypedDict
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
from asgi_correlation_id import correlation_id
|
|
6
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
7
|
+
from uvicorn.protocols.utils import get_path_with_query_string
|
|
8
|
+
|
|
9
|
+
from aegra_api.settings import settings
|
|
10
|
+
|
|
11
|
+
# Named structlog stdlib loggers: general application messages vs. one line per request.
app_logger = structlog.stdlib.get_logger("app.app_logs")
access_logger = structlog.stdlib.get_logger("app.access_logs")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AccessInfo(TypedDict, total=False):
    """Per-request data collected by StructLogMiddleware for its access-log entry."""

    # HTTP status taken from the ``http.response.start`` message, if one was sent.
    status_code: int
    # Request start time from ``time.perf_counter_ns()`` (integer nanoseconds).
    start_time: int
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StructLogMiddleware:
    """ASGI middleware that writes one structured access-log entry per HTTP request.

    For ``http`` scopes it performs these steps:

    * Clears structlog's context variables and binds the current correlation ID
      as ``request_id``.
    * Records the status from ``http.response.start``.
    * When the downstream app finishes, successfully or not, logs the request
      at WARNING for 4xx, ERROR for 5xx and INFO otherwise.

    A request that ends without a response start is logged as 500. Unhandled
    exceptions are logged, then re-raised so app-level handlers still run.
    Other scope types pass straight through.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # If the request is not an HTTP request, we don't need to do anything special
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(request_id=correlation_id.get())

        info = AccessInfo()

        async def inner_send(message):
            # Capture the status as the response starts, then forward unchanged.
            if message.get("type") == "http.response.start":
                info["status_code"] = message.get("status", 500)
            await send(message)

        info["start_time"] = time.perf_counter_ns()
        try:
            await self.app(scope, receive, inner_send)
        except Exception as e:
            # Log the exception here, but re-raise so application-level
            # exception handlers (e.g. Agent Protocol formatter) can run.
            app_logger.exception(
                "An unhandled exception was caught by middleware; re-raising to allow app handlers to format response",
                exception_class=e.__class__.__name__,
                exc_info=e,
                stack_info=True,
            )
            raise
        finally:
            self._log_access(scope, info)

    @staticmethod
    def _log_access(scope: Scope, info: AccessInfo) -> None:
        """Emit the access-log entry for a finished HTTP request."""
        process_time = time.perf_counter_ns() - info["start_time"]
        # ASGI allows ``client`` to be missing or None (e.g. Unix-domain sockets);
        # unpacking it blindly would raise inside ``finally`` and mask the
        # request's real outcome.
        client = scope.get("client")
        client_host, client_port = client if client else (None, None)
        http_method = scope["method"]
        http_version = scope["http_version"]
        url = get_path_with_query_string(scope)
        status_code = info.get("status_code", 500)

        # Recreate the Uvicorn access log format, but add all parameters as structured information
        log_data = {
            "url": str(url),
            "status_code": status_code,
            "method": http_method,
            "version": http_version,
        }
        if settings.app.LOG_VERBOSITY == "verbose":
            log_data["request_id"] = correlation_id.get()

        if 400 <= status_code < 500:
            log = access_logger.warning  # client errors (4xx)
        elif 500 <= status_code < 600:
            log = access_logger.error  # server errors (5xx)
        else:
            log = access_logger.info  # everything else (1xx-3xx)

        log(
            f"""{client_host}:{client_port} - "{http_method} {scope["path"]} HTTP/{http_version}" {status_code}""",
            http=log_data,
            network={"client": {"ip": client_host, "port": client_port}},
            duration=process_time,
        )
|