aegra-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. aegra_api/__init__.py +3 -0
  2. aegra_api/api/__init__.py +1 -0
  3. aegra_api/api/assistants.py +235 -0
  4. aegra_api/api/runs.py +1110 -0
  5. aegra_api/api/store.py +200 -0
  6. aegra_api/api/threads.py +761 -0
  7. aegra_api/config.py +204 -0
  8. aegra_api/constants.py +5 -0
  9. aegra_api/core/__init__.py +0 -0
  10. aegra_api/core/app_loader.py +91 -0
  11. aegra_api/core/auth_ctx.py +65 -0
  12. aegra_api/core/auth_deps.py +186 -0
  13. aegra_api/core/auth_handlers.py +248 -0
  14. aegra_api/core/auth_middleware.py +331 -0
  15. aegra_api/core/database.py +123 -0
  16. aegra_api/core/health.py +131 -0
  17. aegra_api/core/orm.py +165 -0
  18. aegra_api/core/route_merger.py +69 -0
  19. aegra_api/core/serializers/__init__.py +7 -0
  20. aegra_api/core/serializers/base.py +22 -0
  21. aegra_api/core/serializers/general.py +54 -0
  22. aegra_api/core/serializers/langgraph.py +102 -0
  23. aegra_api/core/sse.py +178 -0
  24. aegra_api/main.py +303 -0
  25. aegra_api/middleware/__init__.py +4 -0
  26. aegra_api/middleware/double_encoded_json.py +74 -0
  27. aegra_api/middleware/logger_middleware.py +95 -0
  28. aegra_api/models/__init__.py +76 -0
  29. aegra_api/models/assistants.py +81 -0
  30. aegra_api/models/auth.py +62 -0
  31. aegra_api/models/enums.py +29 -0
  32. aegra_api/models/errors.py +29 -0
  33. aegra_api/models/runs.py +124 -0
  34. aegra_api/models/store.py +67 -0
  35. aegra_api/models/threads.py +152 -0
  36. aegra_api/observability/__init__.py +1 -0
  37. aegra_api/observability/base.py +88 -0
  38. aegra_api/observability/otel.py +133 -0
  39. aegra_api/observability/setup.py +27 -0
  40. aegra_api/observability/targets/__init__.py +11 -0
  41. aegra_api/observability/targets/base.py +18 -0
  42. aegra_api/observability/targets/langfuse.py +33 -0
  43. aegra_api/observability/targets/otlp.py +38 -0
  44. aegra_api/observability/targets/phoenix.py +24 -0
  45. aegra_api/services/__init__.py +0 -0
  46. aegra_api/services/assistant_service.py +569 -0
  47. aegra_api/services/base_broker.py +59 -0
  48. aegra_api/services/broker.py +141 -0
  49. aegra_api/services/event_converter.py +157 -0
  50. aegra_api/services/event_store.py +196 -0
  51. aegra_api/services/graph_streaming.py +433 -0
  52. aegra_api/services/langgraph_service.py +456 -0
  53. aegra_api/services/streaming_service.py +362 -0
  54. aegra_api/services/thread_state_service.py +128 -0
  55. aegra_api/settings.py +124 -0
  56. aegra_api/utils/__init__.py +3 -0
  57. aegra_api/utils/assistants.py +23 -0
  58. aegra_api/utils/run_utils.py +60 -0
  59. aegra_api/utils/setup_logging.py +122 -0
  60. aegra_api/utils/sse_utils.py +26 -0
  61. aegra_api/utils/status_compat.py +57 -0
  62. aegra_api-0.1.0.dist-info/METADATA +244 -0
  63. aegra_api-0.1.0.dist-info/RECORD +64 -0
  64. aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/core/sse.py ADDED
@@ -0,0 +1,178 @@
1
+ """Server-Sent Events utilities and formatting"""
2
+
3
+ import json
4
+ from collections.abc import Callable
5
+ from dataclasses import dataclass
6
+ from datetime import UTC, datetime
7
+ from typing import Any
8
+
9
+ # Import our serializer for handling complex objects
10
+ from aegra_api.core.serializers import GeneralSerializer
11
+
12
# Global serializer instance.
# format_sse_message passes its ``serialize`` method as the ``json.dumps``
# ``default`` hook whenever the caller does not supply a custom serializer.
_serializer = GeneralSerializer()
14
+
15
+
16
def get_sse_headers() -> dict[str, str]:
    """Return the HTTP response headers used for Server-Sent Event streams.

    A new dict is built on every call, so callers may safely mutate it. The
    headers disable caching, keep the connection open, declare the
    ``text/event-stream`` media type, allow any origin, and permit the
    ``Last-Event-ID`` request header used when reconnecting.
    """
    headers: dict[str, str] = {}
    headers["Cache-Control"] = "no-cache"
    headers["Connection"] = "keep-alive"
    headers["Content-Type"] = "text/event-stream"
    headers["Access-Control-Allow-Origin"] = "*"
    headers["Access-Control-Allow-Headers"] = "Last-Event-ID"
    return headers
25
+
26
+
27
+ def format_sse_message(
28
+ event: str,
29
+ data: Any,
30
+ event_id: str | None = None,
31
+ serializer: Callable[[Any], Any] | None = None,
32
+ ) -> str:
33
+ """Format a message as Server-Sent Event following SSE standard
34
+
35
+ Args:
36
+ event: SSE event type
37
+ data: Data to serialize and send
38
+ event_id: Optional event ID
39
+ serializer: Optional custom serializer function
40
+ """
41
+ lines = []
42
+
43
+ lines.append(f"event: {event}")
44
+
45
+ # Convert data to JSON string
46
+ if data is None:
47
+ data_str = ""
48
+ else:
49
+ # Use our general serializer by default to handle complex objects
50
+ default_serializer = serializer or _serializer.serialize
51
+ data_str = json.dumps(data, default=default_serializer, separators=(",", ":"))
52
+
53
+ lines.append(f"data: {data_str}")
54
+
55
+ if event_id:
56
+ lines.append(f"id: {event_id}")
57
+
58
+ lines.append("") # Empty line to end the event
59
+
60
+ return "\n".join(lines) + "\n"
61
+
62
+
63
def create_metadata_event(run_id: str, event_id: str | None = None, attempt: int = 1) -> str:
    """Build the ``metadata`` SSE frame announcing a run (LangSmith Studio compatibility).

    The payload is ``{"run_id": run_id, "attempt": attempt}``.
    """
    return format_sse_message("metadata", {"run_id": run_id, "attempt": attempt}, event_id)
67
+
68
+
69
def _checkpoint_from_config(config: Any) -> dict[str, Any] | None:
    """Build a checkpoint descriptor from ``config["configurable"]``; None if unavailable."""
    if not isinstance(config, dict):
        return None
    configurable = config.get("configurable")
    if not isinstance(configurable, dict):
        return None
    return {
        "thread_id": configurable.get("thread_id"),
        "checkpoint_id": configurable.get("checkpoint_id"),
        "checkpoint_ns": configurable.get("checkpoint_ns", ""),
    }


def create_debug_event(debug_data: dict[str, Any], event_id: str | None = None) -> str:
    """Create debug event with checkpoint fields for LangSmith Studio compatibility.

    When ``debug_data["payload"]`` is a dict, missing ``checkpoint`` and
    ``parent_checkpoint`` entries are derived from
    ``payload["config"]["configurable"]`` and
    ``payload["parent_config"]["configurable"]`` respectively. Each descriptor
    holds ``thread_id``, ``checkpoint_id`` and ``checkpoint_ns`` (default
    ``""``). ``parent_checkpoint`` is set to ``None`` when ``parent_config`` is
    explicitly ``None``.

    The additions are made on shallow copies, so the caller's ``debug_data``
    and its ``payload`` dict are left unmodified.

    Args:
        debug_data: Debug stream chunk, optionally containing a ``payload`` dict.
        event_id: Optional SSE event ID.

    Returns:
        SSE-formatted ``debug`` event string.
    """
    payload = debug_data.get("payload")
    if isinstance(payload, dict):
        # Copy so the original chunk (possibly shared with other consumers) is untouched.
        payload = dict(payload)

        if "checkpoint" not in payload and "config" in payload:
            checkpoint = _checkpoint_from_config(payload["config"])
            if checkpoint is not None:
                payload["checkpoint"] = checkpoint

        if "parent_checkpoint" not in payload and "parent_config" in payload:
            parent_config = payload["parent_config"]
            if parent_config is None:
                payload["parent_checkpoint"] = None
            else:
                parent_checkpoint = _checkpoint_from_config(parent_config)
                if parent_checkpoint is not None:
                    payload["parent_checkpoint"] = parent_checkpoint

        # Overriding an existing key keeps its position, so key order is preserved.
        debug_data = {**debug_data, "payload": payload}

    return format_sse_message("debug", debug_data, event_id)
103
+
104
+
105
def create_end_event(event_id: str | None = None) -> str:
    """Build the terminal ``end`` SSE frame that signals the stream is complete.

    The payload is always ``{"status": "success"}``, the standard status value,
    rather than ``"completed"``.
    """
    end_payload = {"status": "success"}
    return format_sse_message("end", end_payload, event_id)
111
+
112
+
113
def create_error_event(error: str | dict[str, Any], event_id: str | None = None) -> str:
    """Build an ``error`` SSE frame with a normalized ``{"error", "message"}`` payload.

    Args:
        error: Either a plain message, or a dict that may carry ``"error"``
            (the error type) and ``"message"`` keys. A missing ``"error"``
            becomes ``"Error"``. A missing ``"message"`` becomes ``str()`` of
            the whole dict. Any other dict keys are dropped.
        event_id: Optional SSE event ID for reconnection support.

    Returns:
        SSE-formatted error event string.
    """
    if not isinstance(error, dict):
        # Plain value: wrap it with the generic error type.
        return format_sse_message("error", {"error": "Error", "message": str(error)}, event_id)

    error_type = error.get("error", "Error")
    message = error.get("message", str(error))
    return format_sse_message("error", {"error": error_type, "message": message}, event_id)
140
+
141
+
142
def create_messages_event(messages_data: Any, event_type: str = "messages", event_id: str | None = None) -> str:
    """Build a messages-family SSE frame.

    The event type may be ``messages``, ``messages/partial``,
    ``messages/complete`` or ``messages/metadata``. A 2-tuple
    ``(message_chunk, metadata)`` from token streaming is sent as the list
    ``[message_chunk, metadata]``, the shape expected by the LangGraph SDK
    client. Any other value is passed through as-is.
    """
    is_token_pair = isinstance(messages_data, tuple) and len(messages_data) == 2
    payload = list(messages_data) if is_token_pair else messages_data
    return format_sse_message(event_type, payload, event_id)
153
+
154
+
155
+ # Legacy compatibility - used by event_store.py
156
+ @dataclass
157
+ class SSEEvent:
158
+ """SSE Event data structure for event storage"""
159
+
160
+ id: str
161
+ event: str
162
+ data: dict[str, Any]
163
+ timestamp: datetime | None = None
164
+
165
+ def __post_init__(self) -> None:
166
+ if self.timestamp is None:
167
+ self.timestamp = datetime.now(UTC)
168
+
169
+ def format(self) -> str:
170
+ """Format as proper SSE event"""
171
+ json_data = json.dumps(self.data, default=str)
172
+ return f"id: {self.id}\nevent: {self.event}\ndata: {json_data}\n\n"
173
+
174
+
175
def format_sse_event(id: str, event: str, data: dict[str, Any]) -> str:
    """Render an SSE frame with ``id``, ``event`` and ``data`` fields (used by event_store).

    Values JSON cannot encode natively are converted with ``str``.
    """
    body = json.dumps(data, default=str)
    return "\n".join((f"id: {id}", f"event: {event}", f"data: {body}", "", ""))
aegra_api/main.py ADDED
@@ -0,0 +1,303 @@
1
+ """FastAPI application for Aegra (Agent Protocol Server)"""
2
+
3
+ import asyncio
4
+ import sys
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import asynccontextmanager
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from dotenv import load_dotenv
11
+
12
# Load environment variables from .env file
load_dotenv()

# Add graphs directory to Python path so react_agent can be imported
# This MUST happen before importing any modules that depend on graphs/
# NOTE(review): Path(__file__).parent.parent.parent is the grandparent of the
# aegra_api package directory. In a source checkout that is presumably the
# aegra root. When installed from a wheel, aegra_api sits in site-packages, so
# this resolves to the directory containing site-packages instead. Confirm the
# intended layout.
current_dir = Path(__file__).parent.parent.parent  # Go up to aegra root
graphs_dir = current_dir / "graphs"
# Prepend (index 0) so graph modules shadow same-named installed packages.
if str(graphs_dir) not in sys.path:
    sys.path.insert(0, str(graphs_dir))
21
+
22
+ # ruff: noqa: E402 - imports below require sys.path modification above
23
+ import structlog
24
+ from asgi_correlation_id import CorrelationIdMiddleware
25
+ from fastapi import FastAPI, HTTPException, Request
26
+ from fastapi.middleware.cors import CORSMiddleware
27
+ from fastapi.responses import JSONResponse
28
+
29
+ from aegra_api.api.assistants import router as assistants_router
30
+ from aegra_api.api.runs import router as runs_router
31
+ from aegra_api.api.store import router as store_router
32
+ from aegra_api.api.threads import router as threads_router
33
+ from aegra_api.config import HttpConfig, get_config_dir, load_http_config
34
+ from aegra_api.core.app_loader import load_custom_app
35
+ from aegra_api.core.auth_deps import auth_dependency
36
+ from aegra_api.core.database import db_manager
37
+ from aegra_api.core.health import router as health_router
38
+ from aegra_api.core.route_merger import (
39
+ merge_exception_handlers,
40
+ merge_lifespans,
41
+ )
42
+ from aegra_api.middleware import DoubleEncodedJSONMiddleware, StructLogMiddleware
43
+ from aegra_api.models.errors import AgentProtocolError, get_error_type
44
+ from aegra_api.observability.setup import setup_observability
45
+ from aegra_api.services.event_store import event_store
46
+ from aegra_api.services.langgraph_service import get_langgraph_service
47
+ from aegra_api.settings import settings
48
+ from aegra_api.utils.setup_logging import setup_logging
49
+
50
# Task management for run cancellation
# Maps a string key (presumably run_id; populated elsewhere, verify) to its
# asyncio task. lifespan() cancels every task that is not done at shutdown.
active_runs: dict[str, asyncio.Task] = {}

# Configure logging at import time, before the module logger is created.
setup_logging()
logger = structlog.getLogger(__name__)

# Default CORS headers required for LangGraph SDK stream reconnection
# Used as CORSMiddleware expose_headers unless the CORS config overrides it.
DEFAULT_EXPOSE_HEADERS = ["Content-Location", "Location"]
58
+
59
+
60
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan context manager for startup/shutdown.

    Startup initializes the database manager, observability, the LangGraph
    service and the event-store cleanup task. Shutdown cancels every task in
    ``active_runs`` that is not done, stops the event-store cleanup task and
    closes the database manager.

    Shutdown runs in ``finally`` blocks. It therefore also happens when the
    serving context exits with an exception (including cancellation). The
    database is closed even if stopping the cleanup task raises.
    """
    # Startup: Initialize database and LangGraph components
    await db_manager.initialize()

    # Observability
    setup_observability()

    # Initialize LangGraph service
    langgraph_service = get_langgraph_service()
    await langgraph_service.initialize()

    # Initialize event store cleanup task
    await event_store.start_cleanup_task()

    try:
        yield
    finally:
        # Shutdown: cancel active runs (cancel() only requests cancellation;
        # it does not wait for the tasks to finish)
        for task in active_runs.values():
            if not task.done():
                task.cancel()

        try:
            # Stop event store cleanup task
            await event_store.stop_cleanup_task()
        finally:
            # Always release database connections, even if the step above failed
            await db_manager.close()
87
+
88
+
89
# Define core exception handlers
async def agent_protocol_exception_handler(_request: Request, exc: HTTPException) -> JSONResponse:
    """Convert HTTP exceptions to Agent Protocol error format.

    The response body is an ``AgentProtocolError`` built from the status code,
    the exception ``detail`` and an optional ``details`` attribute. Headers
    attached to the exception (for example ``WWW-Authenticate`` on a 401) are
    forwarded on the response, as FastAPI's default HTTPException handler does.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=AgentProtocolError(
            error=get_error_type(exc.status_code),
            message=exc.detail,
            details=getattr(exc, "details", None),
        ).model_dump(),
        # Forward exception headers instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
100
+
101
+
102
async def general_exception_handler(_request: Request, exc: Exception) -> JSONResponse:
    """Turn any unhandled exception into a 500 Agent Protocol error response.

    Note: ``str(exc)`` is returned to the client under ``details.exception``.
    """
    error_body = AgentProtocolError(
        error="internal_error",
        message="An unexpected error occurred",
        details={"exception": str(exc)},
    )
    return JSONResponse(status_code=500, content=error_body.model_dump())
112
+
113
+
114
# Core exception handlers. create_app() registers them directly on the default
# app, or merges them into a custom app via merge_exception_handlers.
# NOTE(review): the key is fastapi.HTTPException. HTTP errors raised as
# Starlette's base HTTPException (e.g. router 404/405) presumably do not match
# it and keep the framework's default response format. Confirm this is intended.
exception_handlers = {
    HTTPException: agent_protocol_exception_handler,
    Exception: general_exception_handler,
}
118
+
119
+
120
# Define root endpoint handler
async def root_handler() -> dict[str, str]:
    """Report the project name, version and a static ``running`` status at ``/``."""
    app_settings = settings.app
    return dict(
        message=app_settings.PROJECT_NAME,
        version=app_settings.VERSION,
        status="running",
    )
128
+
129
+
130
def _apply_auth_to_routes(app: FastAPI, auth_deps: list[Any]) -> None:
    """Apply auth dependency to all existing routes in the FastAPI app.

    Recursively walks ``app.routes``, including any route object exposing
    nested ``routes``. It prepends *auth_deps* to every ``APIRoute`` that does
    not already carry one of them (compared by object identity), so auth runs
    first (fail-fast).

    An ``APIRoute`` resolves its dependency tree (``route.dependant``) and its
    request handler (``route.app``) when it is constructed. Assigning
    ``route.dependencies`` afterwards therefore has no effect on request
    handling. This function also inserts the new dependencies into
    ``route.dependant`` and rebuilds ``route.app``, mirroring what
    ``APIRoute.__init__`` does for dependencies given at construction time.

    Args:
        app: FastAPI application instance
        auth_deps: List of dependencies to apply (e.g., [Depends(require_auth)])
    """
    from fastapi.dependencies.utils import get_parameterless_sub_dependant
    from fastapi.routing import APIRoute, APIRouter, request_response

    # Loop-invariant: identities of the auth dependencies.
    auth_dep_ids = {id(dep) for dep in auth_deps}

    def secure_route(route: APIRoute) -> None:
        """Prepend auth_deps to a single route and rebuild its request handler."""
        existing_deps = list(route.dependencies or [])
        if auth_dep_ids.intersection(id(dep) for dep in existing_deps):
            return  # already protected (e.g. core routers included with auth)
        route.dependencies = list(auth_deps) + existing_deps
        # Insert in reverse so the resolved order matches auth_deps (as APIRoute.__init__ does).
        for depends in reversed(auth_deps):
            route.dependant.dependencies.insert(
                0,
                get_parameterless_sub_dependant(depends=depends, path=route.path_format),
            )
        # Rebuild the ASGI handler so the updated dependant is actually used.
        route.app = request_response(route.get_route_handler())

    def process_routes(routes: list) -> None:
        """Recursively process routes and nested routers."""
        for route in routes:
            if isinstance(route, APIRoute):
                secure_route(route)
            elif isinstance(route, APIRouter):
                # Process nested router
                process_routes(route.routes)
            elif hasattr(route, "routes"):
                # Handle other route types that have nested routes (e.g. mounts)
                process_routes(route.routes)

    process_routes(app.routes)
    logger.info("Applied authentication dependency to custom routes")
164
+
165
+
166
def _add_cors_middleware(app: FastAPI, cors_config: dict[str, Any] | None) -> None:
    """Register CORSMiddleware from *cors_config*, falling back to permissive defaults.

    Without a config (``None`` or empty), all origins, methods and headers are
    allowed, credentials are allowed, and ``DEFAULT_EXPOSE_HEADERS`` are
    exposed. With a config, each of those options is read from it with the same
    fallback, and ``max_age`` (default 600) is passed as well.

    NOTE(review): wildcard origins combined with ``allow_credentials=True`` is
    very permissive. Confirm this default is acceptable for production.

    Args:
        app: FastAPI application instance
        cors_config: CORS configuration dict or None for defaults
    """
    options: dict[str, Any] = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "expose_headers": DEFAULT_EXPOSE_HEADERS,
    }
    if cors_config:
        options = {key: cors_config.get(key, fallback) for key, fallback in options.items()}
        options["max_age"] = cors_config.get("max_age", 600)
    app.add_middleware(CORSMiddleware, **options)
192
+
193
+
194
def _add_common_middleware(app: FastAPI, cors_config: dict[str, Any] | None) -> None:
    """Install the shared middleware stack on *app*.

    Each ``add_middleware`` call wraps the stack, so the last one registered is
    the outermost. Registration below goes StructLog, CorrelationId, CORS,
    DoubleEncodedJSON, which gives this request-time order:

    1. DoubleEncodedJSONMiddleware (outermost - runs first)
    2. CORSMiddleware (handles preflight early)
    3. CorrelationIdMiddleware (adds request ID)
    4. StructLogMiddleware (innermost - logs with correlation ID)

    Args:
        app: FastAPI application instance
        cors_config: CORS configuration dict or None for defaults
    """
    # Innermost first: StructLog reads the correlation ID, so CorrelationId must wrap it.
    for middleware_cls in (StructLogMiddleware, CorrelationIdMiddleware):
        app.add_middleware(middleware_cls)
    _add_cors_middleware(app, cors_config)
    app.add_middleware(DoubleEncodedJSONMiddleware)
211
+
212
+
213
def _include_core_routers(app: FastAPI) -> None:
    """Mount the built-in API routers on *app*.

    Health is mounted first, without auth. Assistants, Threads, Runs and Store
    follow in that order, each guarded by ``auth_dependency``. All routers use
    an empty prefix.

    Args:
        app: FastAPI application instance
    """
    app.include_router(health_router, prefix="", tags=["Health"])

    protected_routers = (
        (assistants_router, "Assistants"),
        (threads_router, "Threads"),
        (runs_router, "Runs"),
        (store_router, "Store"),
    )
    for router, tag in protected_routers:
        app.include_router(router, dependencies=auth_dependency, prefix="", tags=[tag])
231
+
232
+
233
def _load_configured_app(http_config: HttpConfig | None) -> Any:
    """Load the custom app referenced by ``http_config["app"]``; None when not configured."""
    if not (http_config and http_config.get("app")):
        return None
    try:
        config_dir = get_config_dir()
        loaded_app = load_custom_app(http_config["app"], base_dir=config_dir)
        logger.info("Custom app loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load custom app: {e}", exc_info=True)
        raise
    return loaded_app


def _wrap_custom_app(
    user_app: Any,
    http_config: HttpConfig | None,
    cors_config: dict[str, Any] | None,
) -> FastAPI:
    """Attach Aegra's routers, root route, lifespan, handlers and middleware to a user app."""
    if not isinstance(user_app, FastAPI):
        raise TypeError(
            "Custom apps must be FastAPI applications. Use: from fastapi import FastAPI; app = FastAPI()"
        )

    application = user_app
    _include_core_routers(application)

    # Only add "/" when the custom app does not define it already.
    has_root = any(route.path == "/" for route in application.routes if hasattr(route, "path"))
    if not has_root:
        application.get("/")(root_handler)

    application = merge_lifespans(application, lifespan)
    application = merge_exception_handlers(application, exception_handlers)
    _add_common_middleware(application, cors_config)

    # Custom routes get auth only when explicitly enabled in the HTTP config.
    if http_config and http_config.get("enable_custom_route_auth", False):
        _apply_auth_to_routes(application, auth_dependency)
    return application


def _build_default_app(cors_config: dict[str, Any] | None) -> FastAPI:
    """Create the stock Aegra FastAPI application."""
    application = FastAPI(
        title=settings.app.PROJECT_NAME,
        description="Production-ready Agent Protocol server",
        version=settings.app.VERSION,
        debug=settings.app.DEBUG,
        docs_url="/docs",
        redoc_url="/redoc",
        lifespan=lifespan,
    )

    _add_common_middleware(application, cors_config)
    _include_core_routers(application)

    for exc_type, handler in exception_handlers.items():
        application.exception_handler(exc_type)(handler)

    application.get("/")(root_handler)
    return application


def create_app() -> FastAPI:
    """Create and configure the FastAPI application.

    When the HTTP config names a custom app, that app is loaded (it must be a
    FastAPI instance) and Aegra's routers, lifespan, exception handlers and
    middleware are attached to it. Otherwise a default Aegra app is built.

    Returns:
        Configured FastAPI application instance
    """
    http_config: HttpConfig | None = load_http_config()
    cors_config = http_config.get("cors") if http_config else None

    user_app = _load_configured_app(http_config)
    if user_app:
        return _wrap_custom_app(user_app, http_config, cors_config)
    return _build_default_app(cors_config)
293
+
294
+
295
# Create application instance
# Built at import time: importing this module loads the HTTP config and, when
# configured, the custom app.
app = create_app()


if __name__ == "__main__":
    # Script entry point: serve ``app`` with uvicorn on the host/port from settings.
    import uvicorn

    port = int(settings.app.PORT)
    uvicorn.run(app, host=settings.app.HOST, port=port)  # nosec B104 - binding to all interfaces is intentional
@@ -0,0 +1,4 @@
1
+ from aegra_api.middleware.double_encoded_json import DoubleEncodedJSONMiddleware
2
+ from aegra_api.middleware.logger_middleware import StructLogMiddleware
3
+
4
# Public middleware classes re-exported by this package.
__all__ = ["DoubleEncodedJSONMiddleware", "StructLogMiddleware"]
@@ -0,0 +1,74 @@
1
+ import json
2
+
3
+ import structlog
4
+ from starlette.types import ASGIApp, Receive, Scope, Send
5
+
6
# Module logger.
# NOTE(review): DoubleEncodedJSONMiddleware does not reference it. Confirm
# whether it is still needed.
logger = structlog.getLogger(__name__)
7
+
8
+
9
class DoubleEncodedJSONMiddleware:
    r"""Middleware to handle double-encoded JSON payloads from frontend.

    Some frontend clients may send JSON that's been stringified twice,
    resulting in payloads like '"{\"key\":\"value\"}"' instead of '{"key":"value"}'.
    This middleware detects and corrects such cases.

    For POST/PUT/PATCH requests that carry a Content-Type header, the full
    request body is buffered and then processed as follows:

    * If it parses as JSON whose top-level value is a string that itself parses
      as JSON, the inner value is re-serialized and delivered in its place.
    * If it parses as JSON at all (double-encoded or not) and the Content-Type
      is not exactly ``application/json``, the header is rewritten to
      ``application/json``.
    * Otherwise the body is delivered unchanged.

    The body always reaches the app as one ``http.request`` message. Forwarding
    early chunks and then a rewritten full body would corrupt multi-chunk
    requests.
    """

    def __init__(self, app: "ASGIApp") -> None:
        self.app = app

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        method = scope["method"]
        headers = dict(scope.get("headers", []))
        content_type = headers.get(b"content-type", b"").decode("latin1")

        if method not in ("POST", "PUT", "PATCH") or not content_type:
            await self.app(scope, receive, send)
            return

        async def receive_wrapper() -> dict:
            message = await receive()
            if message["type"] != "http.request":
                return message

            # Buffer every chunk before handing anything to the app.
            body_parts = [message.get("body", b"")]
            while message.get("more_body", False):
                message = await receive()
                if message["type"] != "http.request":
                    # e.g. http.disconnect mid-body: the request is gone, pass it on.
                    return message
                body_parts.append(message.get("body", b""))

            body = self._normalize_body(scope, b"".join(body_parts), content_type)
            return {"type": "http.request", "body": body, "more_body": False}

        await self.app(scope, receive_wrapper, send)

    @staticmethod
    def _normalize_body(scope: "Scope", body: bytes, content_type: str) -> bytes:
        """Return the body to deliver, unwrapping double-encoded JSON.

        May rewrite ``scope["headers"]`` so the Content-Type is ``application/json``.
        """
        if not body:
            return body
        try:
            parsed = json.loads(body.decode("utf-8"))
            double_encoded = isinstance(parsed, str)
            if double_encoded:
                parsed = json.loads(parsed)
        except (json.JSONDecodeError, ValueError, UnicodeDecodeError):
            # Not (double-encoded) JSON: leave body and headers untouched.
            return body

        if content_type != "application/json":
            # Let downstream JSON parsing accept bodies sent with another Content-Type.
            scope["headers"] = [
                (name, value) for name, value in scope.get("headers", []) if name != b"content-type"
            ] + [(b"content-type", b"application/json")]

        return json.dumps(parsed).encode("utf-8") if double_encoded else body
@@ -0,0 +1,95 @@
1
+ import time
2
+ from typing import TypedDict
3
+
4
+ import structlog
5
+ from asgi_correlation_id import correlation_id
6
+ from starlette.types import ASGIApp, Receive, Scope, Send
7
+ from uvicorn.protocols.utils import get_path_with_query_string
8
+
9
+ from aegra_api.settings import settings
10
+
11
# Two named structlog loggers. StructLogMiddleware uses "app.app_logs" to report
# unhandled exceptions and "app.access_logs" for the per-request access line.
app_logger = structlog.stdlib.get_logger("app.app_logs")
access_logger = structlog.stdlib.get_logger("app.access_logs")
13
+
14
+
15
class AccessInfo(TypedDict, total=False):
    """Per-request bookkeeping filled in by StructLogMiddleware.

    Every key is optional (``total=False``). ``status_code`` is the status from
    the ``http.response.start`` message. ``start_time`` holds a
    ``time.perf_counter_ns()`` reading.
    """

    status_code: int
    start_time: float
18
+
19
+
20
class StructLogMiddleware:
    """ASGI middleware that emits one structured access-log line per HTTP request.

    For each HTTP request it:

    * resets structlog's context variables and binds the current correlation ID
      as ``request_id``;
    * records the status from ``http.response.start``;
    * after the app finishes or raises, logs a Uvicorn-style access line with
      structured ``http``, ``network`` and ``duration`` fields.

    ``duration`` is a ``time.perf_counter_ns`` difference, i.e. nanoseconds.
    4xx responses log at warning level, 5xx at error, everything else at info.
    A request that raises before sending a status is logged as 500. Non-HTTP
    scopes are passed straight through.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # If the request is not an HTTP request, we don't need to do anything special
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(request_id=correlation_id.get())

        info = AccessInfo()

        async def inner_send(message: dict) -> None:
            # Capture the response status as it is sent.
            if message.get("type") == "http.response.start":
                info["status_code"] = message.get("status", 500)
            await send(message)

        # Set before the try so the finally block can always read it.
        info["start_time"] = time.perf_counter_ns()
        try:
            await self.app(scope, receive, inner_send)
        except Exception as e:
            # Log the exception here, but re-raise so application-level
            # exception handlers (e.g. Agent Protocol formatter) can run.
            app_logger.exception(
                "An unhandled exception was caught by middleware; re-raising to allow app handlers to format response",
                exception_class=e.__class__.__name__,
                exc_info=e,
                stack_info=True,
            )
            raise
        finally:
            self._log_access(scope, info)

    @staticmethod
    def _log_access(scope: Scope, info: AccessInfo) -> None:
        """Emit the access-log line for a finished request."""
        process_time = time.perf_counter_ns() - info["start_time"]

        # ASGI allows "client" to be absent or None (e.g. Unix sockets, some test
        # clients). Unpacking None here would raise inside ``finally`` and mask
        # the request's real outcome.
        client = scope.get("client")
        client_host, client_port = client if client else (None, None)
        client_addr = f"{client_host}:{client_port}" if client else "-"

        http_method = scope["method"]
        http_version = scope["http_version"]
        url = get_path_with_query_string(scope)
        status_code = info.get("status_code", 500)

        # Recreate the Uvicorn access log format, but add all parameters as structured information
        log_data = {
            "url": str(url),
            "status_code": status_code,
            "method": http_method,
            "version": http_version,
        }
        if settings.app.LOG_VERBOSITY == "verbose":
            log_data["request_id"] = correlation_id.get()

        if 400 <= status_code < 500:
            log = access_logger.warning  # client errors (4xx)
        elif 500 <= status_code < 600:
            log = access_logger.error  # server errors (5xx)
        else:
            log = access_logger.info  # everything else (e.g. 2xx, 3xx)

        log(
            f'{client_addr} - "{http_method} {scope["path"]} HTTP/{http_version}" {status_code}',
            http=log_data,
            network={"client": {"ip": client_host, "port": client_port}},
            duration=process_time,
        )