aegra_api-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. aegra_api/__init__.py +3 -0
  2. aegra_api/api/__init__.py +1 -0
  3. aegra_api/api/assistants.py +235 -0
  4. aegra_api/api/runs.py +1110 -0
  5. aegra_api/api/store.py +200 -0
  6. aegra_api/api/threads.py +761 -0
  7. aegra_api/config.py +204 -0
  8. aegra_api/constants.py +5 -0
  9. aegra_api/core/__init__.py +0 -0
  10. aegra_api/core/app_loader.py +91 -0
  11. aegra_api/core/auth_ctx.py +65 -0
  12. aegra_api/core/auth_deps.py +186 -0
  13. aegra_api/core/auth_handlers.py +248 -0
  14. aegra_api/core/auth_middleware.py +331 -0
  15. aegra_api/core/database.py +123 -0
  16. aegra_api/core/health.py +131 -0
  17. aegra_api/core/orm.py +165 -0
  18. aegra_api/core/route_merger.py +69 -0
  19. aegra_api/core/serializers/__init__.py +7 -0
  20. aegra_api/core/serializers/base.py +22 -0
  21. aegra_api/core/serializers/general.py +54 -0
  22. aegra_api/core/serializers/langgraph.py +102 -0
  23. aegra_api/core/sse.py +178 -0
  24. aegra_api/main.py +303 -0
  25. aegra_api/middleware/__init__.py +4 -0
  26. aegra_api/middleware/double_encoded_json.py +74 -0
  27. aegra_api/middleware/logger_middleware.py +95 -0
  28. aegra_api/models/__init__.py +76 -0
  29. aegra_api/models/assistants.py +81 -0
  30. aegra_api/models/auth.py +62 -0
  31. aegra_api/models/enums.py +29 -0
  32. aegra_api/models/errors.py +29 -0
  33. aegra_api/models/runs.py +124 -0
  34. aegra_api/models/store.py +67 -0
  35. aegra_api/models/threads.py +152 -0
  36. aegra_api/observability/__init__.py +1 -0
  37. aegra_api/observability/base.py +88 -0
  38. aegra_api/observability/otel.py +133 -0
  39. aegra_api/observability/setup.py +27 -0
  40. aegra_api/observability/targets/__init__.py +11 -0
  41. aegra_api/observability/targets/base.py +18 -0
  42. aegra_api/observability/targets/langfuse.py +33 -0
  43. aegra_api/observability/targets/otlp.py +38 -0
  44. aegra_api/observability/targets/phoenix.py +24 -0
  45. aegra_api/services/__init__.py +0 -0
  46. aegra_api/services/assistant_service.py +569 -0
  47. aegra_api/services/base_broker.py +59 -0
  48. aegra_api/services/broker.py +141 -0
  49. aegra_api/services/event_converter.py +157 -0
  50. aegra_api/services/event_store.py +196 -0
  51. aegra_api/services/graph_streaming.py +433 -0
  52. aegra_api/services/langgraph_service.py +456 -0
  53. aegra_api/services/streaming_service.py +362 -0
  54. aegra_api/services/thread_state_service.py +128 -0
  55. aegra_api/settings.py +124 -0
  56. aegra_api/utils/__init__.py +3 -0
  57. aegra_api/utils/assistants.py +23 -0
  58. aegra_api/utils/run_utils.py +60 -0
  59. aegra_api/utils/setup_logging.py +122 -0
  60. aegra_api/utils/sse_utils.py +26 -0
  61. aegra_api/utils/status_compat.py +57 -0
  62. aegra_api-0.1.0.dist-info/METADATA +244 -0
  63. aegra_api-0.1.0.dist-info/RECORD +64 -0
  64. aegra_api-0.1.0.dist-info/WHEEL +4 -0
aegra_api/services/graph_streaming.py
@@ -0,0 +1,433 @@
+"""Graph event streaming service.
+
+This module provides streaming functionality for LangGraph graph executions,
+handling message accumulation, event processing, and multiple stream modes.
+"""
+
+import uuid
+from collections.abc import AsyncIterator, Callable
+from contextlib import aclosing
+from typing import Any, cast
+
+import structlog
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    ToolMessageChunk,
+    convert_to_messages,
+    message_chunk_to_message,
+)
+from langchain_core.runnables import RunnableConfig
+from langgraph.errors import (
+    EmptyChannelError,
+    EmptyInputError,
+    GraphRecursionError,
+    InvalidUpdateError,
+)
+from langgraph.pregel.debug import CheckpointPayload, TaskResultPayload
+from pydantic import ValidationError
+from pydantic.v1 import ValidationError as ValidationErrorLegacy
+
+from aegra_api.utils.run_utils import _filter_context_by_schema
+
+logger = structlog.getLogger(__name__)
+
+# Type alias for stream output
+AnyStream = AsyncIterator[tuple[str, Any]]
+
+
+def _normalize_checkpoint_task(task: dict[str, Any]) -> dict[str, Any]:
+    """Normalize checkpoint task structure by extracting configurable state."""
+    state_data = task.get("state")
+
+    # Only process if state contains configurable data
+    if not state_data or "configurable" not in state_data:
+        return task
+
+    configurable = state_data.get("configurable")
+    if not configurable:
+        return task
+
+    # Restructure the task with a checkpoint reference
+    task["checkpoint"] = configurable
+    del task["state"]
+    return task
+
+
+def _normalize_checkpoint_payload(
+    payload: CheckpointPayload | None,
+) -> dict[str, Any] | None:
+    """Normalize debug checkpoint payload structure.
+
+    Ensures checkpoint payloads have consistent task formatting.
+    """
+    if not payload:
+        return None
+
+    # Process all tasks in the checkpoint
+    normalized_tasks = [_normalize_checkpoint_task(t) for t in payload["tasks"]]
+
+    return {
+        **payload,
+        "tasks": normalized_tasks,
+    }
+
+
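
For intuition, a minimal sketch of the transformation these helpers apply (illustrative only, not part of the package): a task whose "state" carries a "configurable" mapping is rewritten so that mapping appears under "checkpoint" instead.

    # Hypothetical task payload, for illustration
    task = {
        "id": "task-1",
        "state": {"configurable": {"thread_id": "t1", "checkpoint_id": "c9"}},
    }
    normalized = _normalize_checkpoint_task(task)
    # normalized == {
    #     "id": "task-1",
    #     "checkpoint": {"thread_id": "t1", "checkpoint_id": "c9"},
    # }
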
+async def stream_graph_events(
+    graph: Any,
+    input_data: Any,
+    config: RunnableConfig,
+    *,
+    stream_mode: list[str],
+    context: dict[str, Any] | None = None,
+    subgraphs: bool = False,
+    output_keys: list[str] | None = None,
+    on_checkpoint: Callable[[CheckpointPayload | None], None] = lambda _: None,
+    on_task_result: Callable[[TaskResultPayload], None] = lambda _: None,
+) -> AnyStream:
+    """Stream events from a graph execution.
+
+    Handles both standard streaming (astream) and event-based streaming
+    (astream_events) depending on the graph type and requested stream modes.
+    Automatically accumulates message chunks and yields appropriate
+    partial/complete events.
+
+    Args:
+        graph: The graph instance to execute
+        input_data: Input data for graph execution
+        config: RunnableConfig for execution
+        stream_mode: List of stream modes (e.g., ["messages", "values", "debug"])
+        context: Optional context dictionary
+        subgraphs: Whether to include subgraph namespaces in event types
+        output_keys: Optional output channel keys for astream
+        on_checkpoint: Callback invoked when checkpoint events are received
+        on_task_result: Callback invoked when task result events are received
+
+    Yields:
+        Tuples of (mode, payload) where mode is the stream mode and payload
+        is the event data
+    """
+    run_id = str(config.get("configurable", {}).get("run_id", uuid.uuid4()))
+
+    # Prepare stream modes
+    stream_modes_set: set[str] = set(stream_mode) - {"events"}
+    if "debug" not in stream_modes_set:
+        stream_modes_set.add("debug")
+
+    # Check if graph is a remote (JavaScript) implementation
+    try:
+        from langgraph_api.js.base import BaseRemotePregel
+
+        is_js_graph = isinstance(graph, BaseRemotePregel)
+    except ImportError:
+        is_js_graph = False
+
+    # Python graphs need messages-tuple converted to standard messages mode
+    if "messages-tuple" in stream_modes_set and not is_js_graph:
+        stream_modes_set.discard("messages-tuple")
+        stream_modes_set.add("messages")
+
+    # Ensure updates mode is enabled for interrupt support
+    updates_explicitly_requested = "updates" in stream_modes_set
+    if not updates_explicitly_requested:
+        stream_modes_set.add("updates")
+
+    # Track whether to filter non-interrupt updates
+    only_interrupt_updates = not updates_explicitly_requested
+
+    # Apply context schema filtering if available
+    if context and not is_js_graph:
+        try:
+            context_schema = graph.get_context_jsonschema()
+            context = await _filter_context_by_schema(context, context_schema)
+        except Exception as e:
+            await logger.adebug(
+                f"Failed to get context schema for filtering: {e}", exc_info=e
+            )
+
+    # Initialize streaming state
+    messages: dict[str, BaseMessageChunk] = {}
+
+    # Choose the streaming method based on mode and graph type
+    use_astream_events = "events" in stream_mode or is_js_graph
+
+    # Yield metadata event
+    yield (
+        "metadata",
+        {"run_id": run_id, "attempt": config.get("metadata", {}).get("run_attempt", 1)},
+    )
+
+    # Stream execution using the appropriate method
+    if use_astream_events:
+        async with aclosing(
+            graph.astream_events(
+                input_data,
+                config,
+                context=context,
+                version="v2",
+                stream_mode=list(stream_modes_set),
+                subgraphs=subgraphs,
+            )
+        ) as stream:
+            async for event in stream:
+                event = cast("dict", event)
+
+                # Filter events marked as hidden
+                if event.get("tags") and "langsmith:hidden" in event["tags"]:
+                    continue
+
+                # Extract message events from JavaScript graphs
+                is_message_event = (
+                    "messages" in stream_mode
+                    and is_js_graph
+                    and event.get("event") == "on_custom_event"
+                )
+
+                if is_message_event:
+                    event_name = event.get("name")
+                    if event_name in (
+                        "messages/complete",
+                        "messages/partial",
+                        "messages/metadata",
+                    ):
+                        yield event_name, event["data"]
+
+                # Process on_chain_stream events
+                if event.get("event") == "on_chain_stream" and event.get("run_id") == run_id:
+                    chunk_data = event.get("data", {}).get("chunk")
+                    if chunk_data is None:
+                        continue
+
+                    if subgraphs:
+                        if isinstance(chunk_data, (tuple, list)) and len(chunk_data) == 3:
+                            ns, mode, chunk = chunk_data
+                        else:
+                            # Fallback: assume a 2-tuple
+                            mode, chunk = chunk_data
+                            ns = None
+                    else:
+                        if isinstance(chunk_data, (tuple, list)) and len(chunk_data) == 2:
+                            mode, chunk = chunk_data
+                        else:
+                            # Single value
+                            mode = "values"
+                            chunk = chunk_data
+                        ns = None
+
+                    # Shared logic for processing events
+                    processed = _process_stream_event(
+                        mode=mode,
+                        chunk=chunk,
+                        namespace=ns,
+                        subgraphs=subgraphs,
+                        stream_mode=stream_mode,
+                        messages=messages,
+                        only_interrupt_updates=only_interrupt_updates,
+                        on_checkpoint=on_checkpoint,
+                        on_task_result=on_task_result,
+                    )
+
+                    if processed:
+                        for event_tuple in processed:
+                            yield event_tuple
+
+                    # Update checkpoint state for debug tracking
+                    if mode == "debug" and chunk.get("type") == "checkpoint":
+                        _normalize_checkpoint_payload(chunk.get("payload"))
+
+                    # Also yield as a raw "events" event if "events" mode was requested;
+                    # this ensures on_chain_stream events are available as raw events
+                    if "events" in stream_mode:
+                        yield "events", event
+
+                # Pass through raw events if "events" mode was requested
+                elif "events" in stream_mode:
+                    yield "events", event
+
+    else:
+        # Use astream for standard streaming
+        if output_keys is None:
+            output_keys = getattr(graph, "output_channels", None)
+
+        async with aclosing(
+            graph.astream(
+                input_data,
+                config,
+                context=context,
+                stream_mode=list(stream_modes_set),
+                output_keys=output_keys,
+                subgraphs=subgraphs,
+            )
+        ) as stream:
+            async for event in stream:
+                # Parse the event tuple
+                if subgraphs:
+                    if isinstance(event, tuple) and len(event) == 3:
+                        ns, mode, chunk = event
+                    else:
+                        # Fallback: assume 2-tuple format
+                        mode, chunk = cast("tuple[str, dict[str, Any]]", event)
+                        ns = None
+                else:
+                    mode, chunk = cast("tuple[str, dict[str, Any]]", event)
+                    ns = None
+
+                # Shared logic for processing events
+                processed = _process_stream_event(
+                    mode=mode,
+                    chunk=chunk,
+                    namespace=ns,
+                    subgraphs=subgraphs,
+                    stream_mode=stream_mode,
+                    messages=messages,
+                    only_interrupt_updates=only_interrupt_updates,
+                    on_checkpoint=on_checkpoint,
+                    on_task_result=on_task_result,
+                )
+
+                if processed:
+                    for event_tuple in processed:
+                        yield event_tuple
+
+                # Update checkpoint state for debug tracking
+                if mode == "debug" and chunk.get("type") == "checkpoint":
+                    _normalize_checkpoint_payload(chunk.get("payload"))
+
+
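
A consumption sketch for the function above (the graph, config values, and input here are hypothetical; this assumes any compiled LangGraph graph):

    import asyncio

    async def consume(graph) -> None:
        config = {"configurable": {"run_id": "run-123", "thread_id": "t1"}}
        async for mode, payload in stream_graph_events(
            graph,
            {"messages": [{"role": "user", "content": "hi"}]},
            config,
            stream_mode=["messages", "values"],
        ):
            # First event is ("metadata", {...}); then e.g. "messages/metadata",
            # "messages/partial", "messages/complete", and "values" events.
            print(mode, payload)

    # asyncio.run(consume(my_compiled_graph))
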
+def _process_stream_event(
+    mode: str,
+    chunk: Any,
+    namespace: str | None,
+    subgraphs: bool,
+    stream_mode: list[str],
+    messages: dict[str, BaseMessageChunk],
+    only_interrupt_updates: bool,
+    on_checkpoint: Callable[[CheckpointPayload | None], None],
+    on_task_result: Callable[[TaskResultPayload], None],
+) -> list[tuple[str, Any]] | None:
+    """Process a single stream event and generate output events.
+
+    Handles message accumulation, debug events, and stream mode routing.
+    Used by both the astream and astream_events execution paths.
+
+    Args:
+        mode: The stream mode (e.g., "messages", "values", "debug")
+        chunk: The event chunk data
+        namespace: Optional namespace for subgraph events
+        subgraphs: Whether subgraph namespaces should be included
+        stream_mode: List of requested stream modes
+        messages: Dictionary for accumulating message chunks by ID
+        only_interrupt_updates: Whether to filter non-interrupt updates
+        on_checkpoint: Callback for checkpoint events
+        on_task_result: Callback for task result events
+
+    Returns:
+        List of (mode, payload) tuples to yield, or None if there is nothing
+        to yield
+    """
+    results: list[tuple[str, Any]] = []
+
+    # Process debug mode events
+    if mode == "debug":
+        debug_type = chunk.get("type")
+
+        if debug_type == "checkpoint":
+            # Normalize the checkpoint and invoke the callback
+            normalized = _normalize_checkpoint_payload(chunk.get("payload"))
+            chunk["payload"] = normalized
+            on_checkpoint(normalized)
+        elif debug_type == "task_result":
+            # Forward task results to the callback
+            on_task_result(chunk.get("payload"))
+
+    # Handle messages mode
+    if mode == "messages":
+        if "messages-tuple" in stream_mode:
+            # Pass through the raw tuple format
+            if subgraphs and namespace:
+                ns_str = (
+                    "|".join(namespace)
+                    if isinstance(namespace, (list, tuple))
+                    else str(namespace)
+                )
+                results.append((f"messages|{ns_str}", chunk))
+            else:
+                results.append(("messages", chunk))
+        else:
+            # Accumulate and yield messages/partial or messages/complete
+            msg_, meta = cast("tuple[BaseMessage | dict, dict[str, Any]]", chunk)
+
+            # Handle dict-to-message conversion
+            is_chunk_type = False
+            if isinstance(msg_, dict):
+                msg_type = msg_.get("type", "").lower()
+                msg_role = msg_.get("role", "").lower()
+
+                # Detect whether this is a streaming chunk based on type/role indicators
+                has_chunk_indicator = "chunk" in msg_type or "chunk" in msg_role
+
+                if has_chunk_indicator:
+                    # Instantiate the appropriate chunk class based on role
+                    if "ai" in msg_role:
+                        msg = AIMessageChunk(**msg_)  # type: ignore[arg-type]
+                    elif "tool" in msg_role:
+                        msg = ToolMessageChunk(**msg_)  # type: ignore[arg-type]
+                    else:
+                        msg = BaseMessageChunk(**msg_)  # type: ignore[arg-type]
+                    is_chunk_type = True
+                else:
+                    # Complete message - convert to a proper message instance
+                    msg = convert_to_messages([msg_])[0]
+            else:
+                msg = msg_
+
+            # Track and accumulate messages by ID
+            msg_id = msg.id
+            is_new_message = msg_id not in messages
+
+            if is_new_message:
+                messages[msg_id] = msg
+                # First time seeing this message - send metadata
+                results.append(("messages/metadata", {msg_id: {"metadata": meta}}))
+            else:
+                # Accumulate additional chunks
+                messages[msg_id] += msg
+
+            # Determine the event type based on the message instance type
+            is_partial_message = isinstance(msg, BaseMessageChunk)
+            event_name = "messages/partial" if is_partial_message else "messages/complete"
+
+            # Format the accumulated message for output
+            if is_chunk_type:
+                # Keep raw chunks for streaming messages
+                formatted_msg = messages[msg_id]
+            else:
+                # Convert accumulated chunks to a complete message
+                formatted_msg = message_chunk_to_message(messages[msg_id])
+
+            results.append((event_name, [formatted_msg]))
+
+    # Handle other stream modes
+    elif mode in stream_mode:
+        if subgraphs and namespace:
+            ns_str = (
+                "|".join(namespace)
+                if isinstance(namespace, (list, tuple))
+                else str(namespace)
+            )
+            results.append((f"{mode}|{ns_str}", chunk))
+        else:
+            results.append((mode, chunk))
+
+    # Special handling for interrupt events when updates mode was not explicitly requested
+    elif mode == "updates" and only_interrupt_updates:
+        # Check whether this update contains interrupt data
+        has_interrupt_data = (
+            isinstance(chunk, dict)
+            and "__interrupt__" in chunk
+            and len(chunk.get("__interrupt__", [])) > 0
+        )
+
+        if has_interrupt_data:
+            # Remap interrupt updates to values events for backward compatibility
+            if subgraphs and namespace:
+                ns_str = (
+                    "|".join(namespace)
+                    if isinstance(namespace, (list, tuple))
+                    else str(namespace)
+                )
+                results.append((f"values|{ns_str}", chunk))
+            else:
+                results.append(("values", chunk))
+
+    return results if results else None
+
+
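
The accumulation above leans on LangChain chunk semantics: adding two AIMessageChunk objects concatenates their content, and message_chunk_to_message finalizes the accumulated chunk. A standalone sketch of that behavior:

    from langchain_core.messages import AIMessageChunk, message_chunk_to_message

    acc = AIMessageChunk(content="Hel", id="m1")
    acc += AIMessageChunk(content="lo", id="m1")   # content is now "Hello"
    final = message_chunk_to_message(acc)          # AIMessage with content "Hello"
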
+# Expected error types for error handling
+EXPECTED_ERRORS = (
+    ValueError,
+    InvalidUpdateError,
+    GraphRecursionError,
+    EmptyInputError,
+    EmptyChannelError,
+    ValidationError,
+    ValidationErrorLegacy,
+)
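
One plausible caller-side use of EXPECTED_ERRORS (a sketch under the assumption that callers separate user-facing failures from bugs; run_safely is hypothetical, not part of the package):

    async def run_safely(graph, input_data, config) -> None:
        try:
            async for mode, payload in stream_graph_events(
                graph, input_data, config, stream_mode=["values"]
            ):
                ...
        except EXPECTED_ERRORS as exc:
            # Known failure classes: bad input, recursion limits, validation errors
            logger.warning("run failed with an expected error", error=str(exc))
        except Exception:
            # Anything else is likely a bug and deserves a full traceback
            logger.exception("run failed unexpectedly")
            raise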