aegra_api-0.1.0-py3-none-any.whl

This diff shows the contents of package versions publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
Files changed (64)
  1. aegra_api/__init__.py +3 -0
  2. aegra_api/api/__init__.py +1 -0
  3. aegra_api/api/assistants.py +235 -0
  4. aegra_api/api/runs.py +1110 -0
  5. aegra_api/api/store.py +200 -0
  6. aegra_api/api/threads.py +761 -0
  7. aegra_api/config.py +204 -0
  8. aegra_api/constants.py +5 -0
  9. aegra_api/core/__init__.py +0 -0
  10. aegra_api/core/app_loader.py +91 -0
  11. aegra_api/core/auth_ctx.py +65 -0
  12. aegra_api/core/auth_deps.py +186 -0
  13. aegra_api/core/auth_handlers.py +248 -0
  14. aegra_api/core/auth_middleware.py +331 -0
  15. aegra_api/core/database.py +123 -0
  16. aegra_api/core/health.py +131 -0
  17. aegra_api/core/orm.py +165 -0
  18. aegra_api/core/route_merger.py +69 -0
  19. aegra_api/core/serializers/__init__.py +7 -0
  20. aegra_api/core/serializers/base.py +22 -0
  21. aegra_api/core/serializers/general.py +54 -0
  22. aegra_api/core/serializers/langgraph.py +102 -0
  23. aegra_api/core/sse.py +178 -0
  24. aegra_api/main.py +303 -0
  25. aegra_api/middleware/__init__.py +4 -0
  26. aegra_api/middleware/double_encoded_json.py +74 -0
  27. aegra_api/middleware/logger_middleware.py +95 -0
  28. aegra_api/models/__init__.py +76 -0
  29. aegra_api/models/assistants.py +81 -0
  30. aegra_api/models/auth.py +62 -0
  31. aegra_api/models/enums.py +29 -0
  32. aegra_api/models/errors.py +29 -0
  33. aegra_api/models/runs.py +124 -0
  34. aegra_api/models/store.py +67 -0
  35. aegra_api/models/threads.py +152 -0
  36. aegra_api/observability/__init__.py +1 -0
  37. aegra_api/observability/base.py +88 -0
  38. aegra_api/observability/otel.py +133 -0
  39. aegra_api/observability/setup.py +27 -0
  40. aegra_api/observability/targets/__init__.py +11 -0
  41. aegra_api/observability/targets/base.py +18 -0
  42. aegra_api/observability/targets/langfuse.py +33 -0
  43. aegra_api/observability/targets/otlp.py +38 -0
  44. aegra_api/observability/targets/phoenix.py +24 -0
  45. aegra_api/services/__init__.py +0 -0
  46. aegra_api/services/assistant_service.py +569 -0
  47. aegra_api/services/base_broker.py +59 -0
  48. aegra_api/services/broker.py +141 -0
  49. aegra_api/services/event_converter.py +157 -0
  50. aegra_api/services/event_store.py +196 -0
  51. aegra_api/services/graph_streaming.py +433 -0
  52. aegra_api/services/langgraph_service.py +456 -0
  53. aegra_api/services/streaming_service.py +362 -0
  54. aegra_api/services/thread_state_service.py +128 -0
  55. aegra_api/settings.py +124 -0
  56. aegra_api/utils/__init__.py +3 -0
  57. aegra_api/utils/assistants.py +23 -0
  58. aegra_api/utils/run_utils.py +60 -0
  59. aegra_api/utils/setup_logging.py +122 -0
  60. aegra_api/utils/sse_utils.py +26 -0
  61. aegra_api/utils/status_compat.py +57 -0
  62. aegra_api-0.1.0.dist-info/METADATA +244 -0
  63. aegra_api-0.1.0.dist-info/RECORD +64 -0
  64. aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/services/broker.py
@@ -0,0 +1,141 @@
+ """Event broker for managing run-specific event queues"""
+
+ import asyncio
+ import contextlib
+ from collections.abc import AsyncIterator
+ from typing import Any
+
+ import structlog
+
+ from aegra_api.services.base_broker import BaseBrokerManager, BaseRunBroker
+
+ logger = structlog.get_logger(__name__)
+
+
+ class RunBroker(BaseRunBroker):
+     """Manages event queuing and distribution for a specific run"""
+
+     def __init__(self, run_id: str):
+         self.run_id = run_id
+         self.queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()
+         self.finished = asyncio.Event()
+         self._created_at = asyncio.get_event_loop().time()
+
+     async def put(self, event_id: str, payload: Any) -> None:
+         """Put an event into the broker queue"""
+         if self.finished.is_set():
+             logger.warning(f"Attempted to put event {event_id} into finished broker for run {self.run_id}")
+             return
+
+         await self.queue.put((event_id, payload))
+
+         # Check if this is an end event
+         if isinstance(payload, tuple) and len(payload) >= 1 and payload[0] == "end":
+             self.mark_finished()
+
+     async def aiter(self) -> AsyncIterator[tuple[str, Any]]:
+         """Async iterator yielding (event_id, payload) pairs"""
+         while True:
+             try:
+                 # Use timeout to check if run is finished
+                 event_id, payload = await asyncio.wait_for(self.queue.get(), timeout=0.1)
+                 yield event_id, payload
+
+                 # Check if this is an end event
+                 if isinstance(payload, tuple) and len(payload) >= 1 and payload[0] == "end":
+                     break
+
+             except TimeoutError:
+                 # Check if run is finished and queue is empty
+                 if self.finished.is_set() and self.queue.empty():
+                     break
+                 continue
+
+     def mark_finished(self) -> None:
+         """Mark this broker as finished"""
+         self.finished.set()
+         logger.debug(f"Broker for run {self.run_id} marked as finished")
+
+     def is_finished(self) -> bool:
+         """Check if this broker is finished"""
+         return self.finished.is_set()
+
+     def is_empty(self) -> bool:
+         """Check if the queue is empty"""
+         return self.queue.empty()
+
+     def get_age(self) -> float:
+         """Get the age of this broker in seconds"""
+         return asyncio.get_event_loop().time() - self._created_at
+
+
+ class BrokerManager(BaseBrokerManager):
+     """Manages multiple RunBroker instances"""
+
+     def __init__(self) -> None:
+         self._brokers: dict[str, RunBroker] = {}
+         self._cleanup_task: asyncio.Task | None = None
+
+     def get_or_create_broker(self, run_id: str) -> RunBroker:
+         """Get or create a broker for a run"""
+         if run_id not in self._brokers:
+             self._brokers[run_id] = RunBroker(run_id)
+             logger.debug(f"Created new broker for run {run_id}")
+         return self._brokers[run_id]
+
+     def get_broker(self, run_id: str) -> RunBroker | None:
+         """Get an existing broker or None"""
+         return self._brokers.get(run_id)
+
+     def cleanup_broker(self, run_id: str) -> None:
+         """Clean up a broker for a run"""
+         if run_id in self._brokers:
+             self._brokers[run_id].mark_finished()
+             # Don't immediately delete in case there are still consumers
+             logger.debug(f"Marked broker for run {run_id} for cleanup")
+
+     def remove_broker(self, run_id: str) -> None:
+         """Remove a broker completely"""
+         if run_id in self._brokers:
+             self._brokers[run_id].mark_finished()
+             del self._brokers[run_id]
+             logger.debug(f"Removed broker for run {run_id}")
+
+     async def start_cleanup_task(self) -> None:
+         """Start background cleanup task for old brokers"""
+         if self._cleanup_task is None or self._cleanup_task.done():
+             self._cleanup_task = asyncio.create_task(self._cleanup_old_brokers())
+
+     async def stop_cleanup_task(self) -> None:
+         """Stop background cleanup task"""
+         if self._cleanup_task and not self._cleanup_task.done():
+             self._cleanup_task.cancel()
+             with contextlib.suppress(asyncio.CancelledError):
+                 await self._cleanup_task
+
+     async def _cleanup_old_brokers(self) -> None:
+         """Background task to clean up old finished brokers"""
+         while True:
+             try:
+                 await asyncio.sleep(300)  # Check every 5 minutes
+
+                 to_remove = []
+
+                 for run_id, broker in self._brokers.items():
+                     # Remove brokers that are finished and older than 1 hour
+                     if broker.is_finished() and broker.is_empty() and broker.get_age() > 3600:
+                         to_remove.append(run_id)
+
+                 for run_id in to_remove:
+                     self.remove_broker(run_id)
+                     logger.info(f"Cleaned up old broker for run {run_id}")
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"Error in broker cleanup task: {e}")
+
+
+ # Global broker manager instance
+ broker_manager = BrokerManager()
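A minimal usage sketch of the broker API above (the run id, event ids, and payloads are hypothetical; in the package itself the producers and consumers are the streaming services, not user code):

import asyncio

from aegra_api.services.broker import broker_manager


async def producer(run_id: str) -> None:
    # Push one data event, then the ("end", ...) sentinel that put()
    # recognizes and uses to mark the broker finished.
    broker = broker_manager.get_or_create_broker(run_id)
    await broker.put(f"{run_id}_event_1", ("values", {"step": 1}))
    await broker.put(f"{run_id}_event_2", ("end", None))


async def consumer(run_id: str) -> None:
    # aiter() exits on the "end" payload, or once the broker is marked
    # finished and its queue has drained.
    broker = broker_manager.get_or_create_broker(run_id)
    async for event_id, payload in broker.aiter():
        print(event_id, payload)
    broker_manager.cleanup_broker(run_id)


async def main() -> None:
    await asyncio.gather(producer("run-123"), consumer("run-123"))


asyncio.run(main())

Note that cleanup_broker only marks the broker finished; actual removal is deferred to the five-minute background sweep, which drops brokers that are finished, empty, and older than an hour.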
aegra_api/services/event_converter.py
@@ -0,0 +1,157 @@
+ """Event converter for SSE streaming"""
+
+ from typing import Any
+
+ from aegra_api.core.sse import (
+     create_debug_event,
+     create_end_event,
+     create_error_event,
+     create_messages_event,
+     create_metadata_event,
+     format_sse_message,
+ )
+
+
+ class EventConverter:
+     """Converts events to SSE format"""
+
+     def __init__(self):
+         """Initialize event converter"""
+         self.subgraphs: bool = False
+
+     def set_subgraphs(self, subgraphs: bool) -> None:
+         """Set whether subgraphs mode is enabled for namespace extraction"""
+         self.subgraphs = subgraphs
+
+     def convert_raw_to_sse(self, event_id: str, raw_event: Any) -> str | None:
+         """Convert raw event to SSE format"""
+         stream_mode, payload, namespace = self._parse_raw_event(raw_event)
+         return self._create_sse_event(stream_mode, payload, event_id, namespace)
+
+     def convert_stored_to_sse(self, stored_event, run_id: str | None = None) -> str | None:
+         """Convert stored event to SSE format"""
+         event_type = stored_event.event
+         data = stored_event.data
+         event_id = stored_event.id
+
+         # Handle special cases with custom logic
+         if event_type == "messages":
+             message_chunk = data.get("message_chunk")
+             metadata = data.get("metadata")
+             if message_chunk is None:
+                 return None
+             message_data = (message_chunk, metadata) if metadata is not None else message_chunk
+             return create_messages_event(message_data, event_id=event_id)
+         elif event_type == "metadata":
+             return create_metadata_event(run_id, event_id)
+         elif event_type == "debug":
+             return create_debug_event(data.get("debug"), event_id)
+         elif event_type == "end":
+             return create_end_event(event_id)
+         elif event_type == "error":
+             return create_error_event(data.get("error"), event_id)
+         else:
+             # Handle all other event types generically (values, state, logs, tasks, etc.)
+             # Extract payload - try common patterns
+             payload = data.get(event_type) or data.get("chunk") or data
+             return format_sse_message(event_type, payload, event_id)
+
+     def _parse_raw_event(self, raw_event: Any) -> tuple[str, Any, list[str] | None]:
+         """
+         Parse raw event into (stream_mode, payload, namespace).
+
+         When subgraphs=True, 3-tuple format is (namespace, mode, chunk).
+         When subgraphs=False, 3-tuple format is (node_path, mode, chunk) for legacy support.
+         """
+         namespace = None
+
+         if isinstance(raw_event, tuple):
+             if len(raw_event) == 2:
+                 # Standard format: (mode, chunk)
+                 return raw_event[0], raw_event[1], None
+             elif len(raw_event) == 3:
+                 if self.subgraphs:
+                     # Subgraphs format: (namespace, mode, chunk)
+                     namespace, mode, chunk = raw_event
+                     # Normalize namespace to list format
+                     if namespace is None or (isinstance(namespace, (list, tuple)) and not namespace):
+                         # Handle None or empty tuple/list - no namespace prefix
+                         namespace_list = None
+                     elif isinstance(namespace, (list, tuple)):
+                         # Convert tuple/list to list of strings
+                         namespace_list = [str(item) for item in namespace]
+                     elif isinstance(namespace, str):
+                         # Handle string namespace (shouldn't happen but be safe)
+                         namespace_list = [namespace] if namespace else None
+                     else:
+                         # Fallback - shouldn't reach here
+                         namespace_list = [str(namespace)]
+                     return mode, chunk, namespace_list
+                 else:
+                     # Legacy format: (node_path, mode, chunk)
+                     return raw_event[1], raw_event[2], None
+
+         # Non-tuple events are values mode
+         return "values", raw_event, None
+
+     def _create_sse_event(
+         self,
+         stream_mode: str,
+         payload: Any,
+         event_id: str,
+         namespace: list[str] | None = None,
+     ) -> str | None:
+         """
+         Create SSE event based on stream mode.
+
+         Args:
+             stream_mode: The stream mode (e.g., "messages", "values")
+             payload: The event payload
+             event_id: The event ID
+             namespace: Optional namespace for subgraph events (e.g., ["subagent_name"])
+
+         Returns:
+             SSE-formatted event string or None
+         """
+         # Prefix event type with namespace if subgraphs enabled
+         if self.subgraphs and namespace:
+             event_type = f"{stream_mode}|{'|'.join(namespace)}"
+         else:
+             event_type = stream_mode
+
+         # Handle updates events (rarely reached - updates are filtered in graph_streaming)
+         if stream_mode == "updates":
+             if isinstance(payload, dict) and "__interrupt__" in payload:
+                 # Convert interrupt updates to values events
+                 if self.subgraphs and namespace:
+                     event_type = f"values|{'|'.join(namespace)}"
+                 else:
+                     event_type = "values"
+                 return format_sse_message(event_type, payload, event_id)
+             else:
+                 # Non-interrupt updates (pass through as-is)
+                 return format_sse_message(event_type, payload, event_id)
+
+         # Handle specific message event types (Studio compatibility and standard messages)
+         if stream_mode in (
+             "messages/metadata",
+             "messages/partial",
+             "messages/complete",
+         ):
+             # Studio-specific message events - pass through as-is
+             return format_sse_message(stream_mode, payload, event_id)
+         elif stream_mode == "messages" or event_type.startswith("messages"):
+             return create_messages_event(payload, event_type=event_type, event_id=event_id)
+         elif stream_mode == "values" or event_type.startswith("values"):
+             # For values events, use format_sse_message directly to support namespaces
+             return format_sse_message(event_type, payload, event_id)
+         elif stream_mode == "debug":
+             return create_debug_event(payload, event_id)
+         elif stream_mode == "end":
+             return create_end_event(event_id)
+         elif stream_mode == "error":
+             return create_error_event(payload, event_id)
+         else:
+             # Generic handler for all other event types (state, logs, tasks, events, etc.)
+             # This automatically supports any new event types without code changes
+             return format_sse_message(event_type, payload, event_id)
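For orientation, a small sketch of how raw LangGraph stream tuples flow through the converter (the run and event ids here are made up, and the exact wire format of the returned string is defined by format_sse_message in aegra_api.core.sse, which is not part of this diff):

from aegra_api.services.event_converter import EventConverter

converter = EventConverter()

# Plain 2-tuple (mode, chunk): the SSE event type is the stream mode itself.
print(converter.convert_raw_to_sse("run-123_event_1", ("values", {"messages": []})))

# With subgraphs enabled, 3-tuples carry a namespace first, and the
# event type is namespace-prefixed: "values|research_agent".
converter.set_subgraphs(True)
raw = (("research_agent",), "values", {"messages": []})
print(converter.convert_raw_to_sse("run-123_event_2", raw))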
aegra_api/services/event_store.py
@@ -0,0 +1,196 @@
+ """Persistent event store for SSE replay functionality (Postgres-backed)."""
+
+ import asyncio
+ import contextlib
+ import json
+ from datetime import UTC, datetime
+
+ import structlog
+ from psycopg.types.json import Jsonb
+
+ from aegra_api.core.database import db_manager
+ from aegra_api.core.serializers import GeneralSerializer
+ from aegra_api.core.sse import SSEEvent
+
+ logger = structlog.get_logger(__name__)
+
+
+ class EventStore:
+     """Postgres-backed event store for SSE replay functionality"""
+
+     CLEANUP_INTERVAL = 300  # seconds
+
+     def __init__(self) -> None:
+         self._cleanup_task: asyncio.Task | None = None
+
+     async def start_cleanup_task(self) -> None:
+         if self._cleanup_task is None or self._cleanup_task.done():
+             self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+
+     async def stop_cleanup_task(self) -> None:
+         if self._cleanup_task and not self._cleanup_task.done():
+             self._cleanup_task.cancel()
+             with contextlib.suppress(asyncio.CancelledError):
+                 await self._cleanup_task
+
+     async def store_event(self, run_id: str, event: SSEEvent) -> None:
+         """Persist an event with sequence extracted from id suffix.
+
+         We expect event.id format: f"{run_id}_event_{seq}".
+         """
+         try:
+             seq = int(str(event.id).split("_event_")[-1])
+         except Exception:
+             seq = 0
+
+         # Use the shared pool
+         if not db_manager.lg_pool:
+             logger.error("Database pool not initialized!")
+             return
+
+         async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+             await cur.execute(
+                 """
+                 INSERT INTO run_events (id, run_id, seq, event, data, created_at)
+                 VALUES (%(id)s, %(run_id)s, %(seq)s, %(event)s, %(data)s, NOW())
+                 ON CONFLICT (id) DO NOTHING
+                 """,
+                 {
+                     "id": event.id,
+                     "run_id": run_id,
+                     "seq": seq,
+                     "event": event.event,
+                     "data": Jsonb(event.data),
+                 },
+             )
+
+     async def get_events_since(self, run_id: str, last_event_id: str) -> list[SSEEvent]:
+         """Fetch all events for run after last_event_id sequence."""
+         try:
+             last_seq = int(str(last_event_id).split("_event_")[-1])
+         except Exception:
+             last_seq = -1
+
+         if not db_manager.lg_pool:
+             return []
+
+         async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+             await cur.execute(
+                 """
+                 SELECT id, event, data, created_at
+                 FROM run_events
+                 WHERE run_id = %(run_id)s AND seq > %(last_seq)s
+                 ORDER BY seq ASC
+                 """,
+                 {"run_id": run_id, "last_seq": last_seq},
+             )
+             rows = await cur.fetchall()
+
+         return [SSEEvent(id=r["id"], event=r["event"], data=r["data"], timestamp=r["created_at"]) for r in rows]
+
+     async def get_all_events(self, run_id: str) -> list[SSEEvent]:
+         if not db_manager.lg_pool:
+             return []
+
+         async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+             await cur.execute(
+                 """
+                 SELECT id, event, data, created_at
+                 FROM run_events
+                 WHERE run_id = %(run_id)s
+                 ORDER BY seq ASC
+                 """,
+                 {"run_id": run_id},
+             )
+             rows = await cur.fetchall()
+
+         return [SSEEvent(id=r["id"], event=r["event"], data=r["data"], timestamp=r["created_at"]) for r in rows]
+
+     async def cleanup_events(self, run_id: str) -> None:
+         if not db_manager.lg_pool:
+             return
+
+         async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+             await cur.execute(
+                 "DELETE FROM run_events WHERE run_id = %(run_id)s",
+                 {"run_id": run_id},
+             )
+
+     async def get_run_info(self, run_id: str) -> dict | None:
+         if not db_manager.lg_pool:
+             return None
+
+         async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+             # 1. Fetch sequence range
+             await cur.execute(
+                 """
+                 SELECT MIN(seq) AS first_seq, MAX(seq) AS last_seq
+                 FROM run_events
+                 WHERE run_id = %(run_id)s
+                 """,
+                 {"run_id": run_id},
+             )
+             row = await cur.fetchone()
+
+             if not row or row["last_seq"] is None:
+                 return None
+
+             # 2. Fetch last event
+             await cur.execute(
+                 """
+                 SELECT id, created_at
+                 FROM run_events
+                 WHERE run_id = %(run_id)s AND seq = %(last_seq)s
+                 LIMIT 1
+                 """,
+                 {"run_id": run_id, "last_seq": row["last_seq"]},
+             )
+             last = await cur.fetchone()
+
+         return {
+             "run_id": run_id,
+             "event_count": int(row["last_seq"]) - int(row["first_seq"]) + 1 if row["first_seq"] is not None else 0,
+             "first_event_time": None,
+             "last_event_time": last["created_at"] if last else None,
+             "last_event_id": last["id"] if last else None,
+         }
+
+     async def _cleanup_loop(self) -> None:
+         while True:
+             try:
+                 await asyncio.sleep(self.CLEANUP_INTERVAL)
+                 await self._cleanup_old_runs()
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"Error in event store cleanup: {e}")
+
+     async def _cleanup_old_runs(self) -> None:
+         # Retain events for 1 hour by default
+         if not db_manager.lg_pool:
+             return
+
+         try:
+             async with db_manager.lg_pool.connection() as conn, conn.cursor() as cur:
+                 await cur.execute("DELETE FROM run_events WHERE created_at < NOW() - INTERVAL '1 hour'")
+         except Exception as e:
+             logger.error(f"Failed to cleanup old runs: {e}")
+
+
+ # Global event store instance
+ event_store = EventStore()
+
+
+ async def store_sse_event(run_id: str, event_id: str, event_type: str, data: dict) -> SSEEvent:
+     """Store SSE event with proper serialization"""
+     serializer = GeneralSerializer()
+
+     # Ensure JSONB-safe data by serializing complex objects
+     try:
+         safe_data = json.loads(json.dumps(data, default=serializer.serialize))
+     except Exception:
+         # Fallback to stringifying as a last resort to avoid crashing the run
+         safe_data = {"raw": str(data)}
+     event = SSEEvent(id=event_id, event=event_type, data=safe_data, timestamp=datetime.now(UTC))
+     await event_store.store_event(run_id, event)
+     return event
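The queries above imply a run_events table that is not created anywhere in this wheel (migrations presumably live elsewhere). Below is a hedged reconstruction of its shape as a psycopg 3 bootstrap script; every column type is inferred from the INSERT/SELECT statements (data must be JSONB for the Jsonb adapter), and the code reads rows by column name, so the shared pool is presumably configured with a dict-style row factory:

import asyncio

import psycopg

# Assumed DDL, reconstructed from EventStore's queries; not shipped in
# the package. id gets a primary key so ON CONFLICT (id) has a target.
RUN_EVENTS_DDL = """
CREATE TABLE IF NOT EXISTS run_events (
    id         TEXT PRIMARY KEY,
    run_id     TEXT NOT NULL,
    seq        INTEGER NOT NULL,
    event      TEXT NOT NULL,
    data       JSONB,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"""


async def main() -> None:
    # Placeholder connection string.
    async with await psycopg.AsyncConnection.connect("postgresql://localhost/aegra") as conn:
        await conn.execute(RUN_EVENTS_DDL)
        # Replay and cleanup both filter on run_id and order by seq.
        await conn.execute(
            "CREATE INDEX IF NOT EXISTS run_events_run_id_seq_idx ON run_events (run_id, seq)"
        )
        await conn.commit()


asyncio.run(main())

Also worth noting: get_run_info derives event_count from the seq range rather than COUNT(*), so it assumes sequences are contiguous within a run.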