aegra-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. aegra_api/__init__.py +3 -0
  2. aegra_api/api/__init__.py +1 -0
  3. aegra_api/api/assistants.py +235 -0
  4. aegra_api/api/runs.py +1110 -0
  5. aegra_api/api/store.py +200 -0
  6. aegra_api/api/threads.py +761 -0
  7. aegra_api/config.py +204 -0
  8. aegra_api/constants.py +5 -0
  9. aegra_api/core/__init__.py +0 -0
  10. aegra_api/core/app_loader.py +91 -0
  11. aegra_api/core/auth_ctx.py +65 -0
  12. aegra_api/core/auth_deps.py +186 -0
  13. aegra_api/core/auth_handlers.py +248 -0
  14. aegra_api/core/auth_middleware.py +331 -0
  15. aegra_api/core/database.py +123 -0
  16. aegra_api/core/health.py +131 -0
  17. aegra_api/core/orm.py +165 -0
  18. aegra_api/core/route_merger.py +69 -0
  19. aegra_api/core/serializers/__init__.py +7 -0
  20. aegra_api/core/serializers/base.py +22 -0
  21. aegra_api/core/serializers/general.py +54 -0
  22. aegra_api/core/serializers/langgraph.py +102 -0
  23. aegra_api/core/sse.py +178 -0
  24. aegra_api/main.py +303 -0
  25. aegra_api/middleware/__init__.py +4 -0
  26. aegra_api/middleware/double_encoded_json.py +74 -0
  27. aegra_api/middleware/logger_middleware.py +95 -0
  28. aegra_api/models/__init__.py +76 -0
  29. aegra_api/models/assistants.py +81 -0
  30. aegra_api/models/auth.py +62 -0
  31. aegra_api/models/enums.py +29 -0
  32. aegra_api/models/errors.py +29 -0
  33. aegra_api/models/runs.py +124 -0
  34. aegra_api/models/store.py +67 -0
  35. aegra_api/models/threads.py +152 -0
  36. aegra_api/observability/__init__.py +1 -0
  37. aegra_api/observability/base.py +88 -0
  38. aegra_api/observability/otel.py +133 -0
  39. aegra_api/observability/setup.py +27 -0
  40. aegra_api/observability/targets/__init__.py +11 -0
  41. aegra_api/observability/targets/base.py +18 -0
  42. aegra_api/observability/targets/langfuse.py +33 -0
  43. aegra_api/observability/targets/otlp.py +38 -0
  44. aegra_api/observability/targets/phoenix.py +24 -0
  45. aegra_api/services/__init__.py +0 -0
  46. aegra_api/services/assistant_service.py +569 -0
  47. aegra_api/services/base_broker.py +59 -0
  48. aegra_api/services/broker.py +141 -0
  49. aegra_api/services/event_converter.py +157 -0
  50. aegra_api/services/event_store.py +196 -0
  51. aegra_api/services/graph_streaming.py +433 -0
  52. aegra_api/services/langgraph_service.py +456 -0
  53. aegra_api/services/streaming_service.py +362 -0
  54. aegra_api/services/thread_state_service.py +128 -0
  55. aegra_api/settings.py +124 -0
  56. aegra_api/utils/__init__.py +3 -0
  57. aegra_api/utils/assistants.py +23 -0
  58. aegra_api/utils/run_utils.py +60 -0
  59. aegra_api/utils/setup_logging.py +122 -0
  60. aegra_api/utils/sse_utils.py +26 -0
  61. aegra_api/utils/status_compat.py +57 -0
  62. aegra_api-0.1.0.dist-info/METADATA +244 -0
  63. aegra_api-0.1.0.dist-info/RECORD +64 -0
  64. aegra_api-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,569 @@
1
+ """Service layer for assistant business logic
2
+
3
+ This service encapsulates all business logic for assistant management, following
4
+ a layered architecture pattern. The code was extracted from api/assistants.py
5
+ to separate concerns and improve maintainability.
6
+
7
+ Responsibilities:
8
+ - Business logic and validation
9
+ - Database operations via SQLAlchemy ORM
10
+ - Graph schema extraction and manipulation
11
+ - Coordination between different components
12
+
13
+ This is the first service layer implementation in Aegra. The pattern will be
14
+ applied to other APIs (runs, threads, crons) as part of ongoing refactoring.
15
+ """
16
+
17
+ import uuid
18
+ from datetime import UTC, datetime
19
+ from typing import Any
20
+ from uuid import uuid4
21
+
22
+ from fastapi import Depends, HTTPException
23
+ from sqlalchemy import func, or_, select, update
24
+ from sqlalchemy.ext.asyncio import AsyncSession
25
+
26
+ from aegra_api.core.orm import Assistant as AssistantORM
27
+ from aegra_api.core.orm import AssistantVersion as AssistantVersionORM
28
+ from aegra_api.core.orm import get_session
29
+ from aegra_api.models import Assistant, AssistantCreate, AssistantUpdate
30
+ from aegra_api.services.langgraph_service import LangGraphService, get_langgraph_service
31
+
32
+
33
def to_pydantic(row: AssistantORM) -> Assistant:
    """Build the public ``Assistant`` model from an ``AssistantORM`` row.

    ``from_attributes=True`` is required because the ORM exposes the metadata
    column under the attribute ``metadata_dict`` while the Pydantic field is
    ``metadata`` (declared with ``alias="metadata_dict"``). Thread/Run models do
    not need this because their attribute and column names match.

    Note: the identifier-to-``str`` casts below are written back onto ``row``.
    """
    # Stringify the assistant id whenever it is present and not None.
    assistant_id = getattr(row, "assistant_id", None)
    if assistant_id is not None:
        row.assistant_id = str(assistant_id)

    # Only UUID owners need casting; plain-string owners are left untouched.
    owner = getattr(row, "user_id", None)
    if isinstance(owner, uuid.UUID):
        row.user_id = str(owner)

    return Assistant.model_validate(row, from_attributes=True)
51
+
52
+
53
def _state_jsonschema(graph) -> dict | None:
    """Build a JSON schema describing the graph's stream channels.

    Every channel named in ``graph.stream_channels_list`` becomes an optional
    field typed by that channel's ``UpdateType``. When that type cannot be
    rendered as a JSON schema on its own, the field falls back to ``Any`` so a
    single awkward channel does not break the whole state schema.
    """
    from langchain_core.runnables.utils import create_model

    def _channel_field(channel_name: str) -> tuple:
        """Return the (type, default) field spec for one channel."""
        channel = graph.channels[channel_name]
        try:
            # Probe: keep the concrete type only if it is JSON-schema serialisable.
            create_model(channel_name, __root__=(channel.UpdateType, None)).model_json_schema()
            return (channel.UpdateType, None)
        except Exception:
            return (Any, None)

    field_specs = {name: _channel_field(name) for name in graph.stream_channels_list}
    return create_model(graph.get_name("State"), **field_specs).model_json_schema()
68
+
69
+
70
+ def _get_configurable_jsonschema(graph) -> dict:
71
+ """Get the JSON schema for the configurable part of the graph"""
72
+ from pydantic import TypeAdapter
73
+
74
+ EXCLUDED_CONFIG_SCHEMA = {"__pregel_resuming", "__pregel_checkpoint_id"}
75
+
76
+ config_schema = graph.config_schema()
77
+ model_fields = getattr(config_schema, "model_fields", None) or getattr(config_schema, "__fields__", None)
78
+
79
+ if model_fields is not None and "configurable" in model_fields:
80
+ configurable = TypeAdapter(model_fields["configurable"].annotation)
81
+ json_schema = configurable.json_schema()
82
+ if json_schema:
83
+ for key in EXCLUDED_CONFIG_SCHEMA:
84
+ json_schema["properties"].pop(key, None)
85
+ if hasattr(graph, "config_type") and graph.config_type is not None and hasattr(graph.config_type, "__name__"):
86
+ json_schema["title"] = graph.config_type.__name__
87
+ return json_schema
88
+ return {}
89
+
90
+
91
+ def _extract_graph_schemas(graph) -> dict:
92
+ """Extract schemas from a compiled LangGraph graph object"""
93
+ try:
94
+ input_schema = graph.get_input_jsonschema()
95
+ except Exception:
96
+ input_schema = None
97
+
98
+ try:
99
+ output_schema = graph.get_output_jsonschema()
100
+ except Exception:
101
+ output_schema = None
102
+
103
+ try:
104
+ state_schema = _state_jsonschema(graph)
105
+ except Exception:
106
+ state_schema = None
107
+
108
+ try:
109
+ config_schema = _get_configurable_jsonschema(graph)
110
+ except Exception:
111
+ config_schema = None
112
+
113
+ try:
114
+ context_schema = graph.get_context_jsonschema()
115
+ except Exception:
116
+ context_schema = None
117
+
118
+ return {
119
+ "input_schema": input_schema,
120
+ "output_schema": output_schema,
121
+ "state_schema": state_schema,
122
+ "config_schema": config_schema,
123
+ "context_schema": context_schema,
124
+ }
125
+
126
+
127
class AssistantService:
    """Business logic for assistant CRUD, versioning and graph introspection.

    Visibility rules applied by the queries below:

    * reads (get, list, search, count, versions, schemas, graph, subgraphs) see
      assistants owned by the caller *and* those owned by ``"system"``;
    * writes (update, delete, set-latest) only match the caller's own assistants.
    """

    def __init__(self, session: AsyncSession, langgraph_service: LangGraphService):
        self.session = session
        self.langgraph_service = langgraph_service

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _visible_to(user_identity: str):
        """SQL condition matching assistants owned by the caller or by ``"system"``."""
        return or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system")

    @staticmethod
    def _contains_pattern(text: str) -> str:
        """Build a LIKE pattern matching *text* as a literal substring.

        ``/`` is the escape character (pair with ``escape="/"``) so ``%`` and
        ``_`` typed by the user are not interpreted as wildcards.
        """
        escaped = text.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    @classmethod
    def _apply_search_filters(cls, stmt: Any, request: Any) -> Any:
        """Narrow *stmt* by the optional name/description/graph_id/metadata filters."""
        if request.name:
            stmt = stmt.where(AssistantORM.name.ilike(cls._contains_pattern(request.name), escape="/"))
        if request.description:
            stmt = stmt.where(
                AssistantORM.description.ilike(cls._contains_pattern(request.description), escape="/")
            )
        if request.graph_id:
            stmt = stmt.where(AssistantORM.graph_id == request.graph_id)
        if request.metadata:
            # PostgreSQL ``@>`` containment: metadata must include every filter key/value.
            stmt = stmt.where(AssistantORM.metadata_dict.op("@>")(request.metadata))
        return stmt

    def _require_known_graph(self, graph_id: str) -> None:
        """Raise 400 unless *graph_id* is one of the graphs declared in aegra.json."""
        available_graphs = self.langgraph_service.list_graphs()
        if graph_id not in available_graphs:
            raise HTTPException(
                400,
                f"Graph '{graph_id}' not found in aegra.json. Available: {list(available_graphs.keys())}",
            )

    async def _get_visible_assistant(self, assistant_id: str, user_identity: str) -> AssistantORM:
        """Load an assistant the caller may read (own or system), or raise 404."""
        stmt = select(AssistantORM).where(
            AssistantORM.assistant_id == assistant_id,
            self._visible_to(user_identity),
        )
        assistant = await self.session.scalar(stmt)
        if not assistant:
            raise HTTPException(404, f"Assistant '{assistant_id}' not found")
        return assistant

    async def _get_owned_assistant(self, assistant_id: str, user_identity: str) -> AssistantORM:
        """Load an assistant owned by the caller (system assistants excluded), or raise 404."""
        stmt = select(AssistantORM).where(
            AssistantORM.assistant_id == assistant_id,
            AssistantORM.user_id == user_identity,
        )
        assistant = await self.session.scalar(stmt)
        if not assistant:
            raise HTTPException(404, f"Assistant '{assistant_id}' not found")
        return assistant

    # ------------------------------------------------------------------- public

    async def create_assistant(self, request: AssistantCreate, user_identity: str) -> Assistant:
        """Create an assistant for one of the graphs declared in aegra.json.

        ``config["configurable"]`` and ``context`` are kept in sync: whichever is
        supplied is copied into the other. The initial version record (version 1)
        is committed in the same transaction as the assistant row.

        Raises:
            HTTPException(400): unknown or unloadable graph, or both
                ``config.configurable`` and ``context`` supplied.
            HTTPException(409): a matching assistant already exists for this user
                (same id, or same graph and config) and ``if_exists`` is not
                ``"do_nothing"``; or the requested id belongs to another user.
        """
        graph_id = request.graph_id
        self._require_known_graph(graph_id)

        # Validate graph can be loaded
        try:
            await self.langgraph_service.get_graph_for_validation(graph_id)
        except Exception as e:
            raise HTTPException(400, f"Failed to load graph: {str(e)}") from e

        # Work on a copy so syncing "configurable" never mutates the request object.
        config = dict(request.config or {})
        context = request.context

        if config.get("configurable") and context:
            raise HTTPException(
                status_code=400,
                detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
            )

        # Keep config and context up to date with one another
        if config.get("configurable"):
            context = config["configurable"]
        elif context:
            config["configurable"] = context

        assistant_id = request.assistant_id or str(uuid4())
        name = request.name or f"Assistant for {graph_id}"

        # Check if an assistant already exists for this user, graph and config pair
        existing_stmt = select(AssistantORM).where(
            AssistantORM.user_id == user_identity,
            or_(
                (AssistantORM.graph_id == graph_id) & (AssistantORM.config == config),
                AssistantORM.assistant_id == assistant_id,
            ),
        )
        existing = await self.session.scalar(existing_stmt)
        if existing:
            if request.if_exists == "do_nothing":
                return to_pydantic(existing)
            raise HTTPException(409, f"Assistant '{assistant_id}' already exists")

        # The lookup above is scoped to this user; an explicitly requested id may
        # still be taken by someone else, which would otherwise fail on insert.
        if request.assistant_id:
            taken = await self.session.scalar(
                select(AssistantORM.assistant_id).where(AssistantORM.assistant_id == assistant_id)
            )
            if taken is not None:
                raise HTTPException(409, f"Assistant '{assistant_id}' already exists")

        assistant_orm = AssistantORM(
            assistant_id=assistant_id,
            name=name,
            description=request.description,
            config=config,
            context=context,
            graph_id=graph_id,
            user_id=user_identity,
            metadata_dict=request.metadata,
            version=1,
        )
        self.session.add(assistant_orm)
        # Flush so the assistant row exists before its version row is inserted,
        # but commit both together so a failure cannot leave a version-less assistant.
        await self.session.flush()

        self.session.add(
            AssistantVersionORM(
                assistant_id=assistant_id,
                version=1,
                graph_id=graph_id,
                config=config,
                context=context,
                created_at=datetime.now(UTC),
                name=name,
                description=request.description,
                metadata_dict=request.metadata,
            )
        )
        await self.session.commit()
        # Reload any server-generated column values after the commit.
        await self.session.refresh(assistant_orm)
        return to_pydantic(assistant_orm)

    async def list_assistants(self, user_identity: str) -> list[Assistant]:
        """List the caller's assistants together with system assistants."""
        result = await self.session.scalars(select(AssistantORM).where(self._visible_to(user_identity)))
        return [to_pydantic(a) for a in result.all()]

    async def search_assistants(
        self,
        request: Any,  # AssistantSearchRequest
        user_identity: str,
    ) -> list[Assistant]:
        """Search visible assistants by name/description substring, graph_id and metadata.

        Pagination defaults to ``offset=0`` and ``limit=20`` when not supplied.
        """
        stmt = self._apply_search_filters(select(AssistantORM).where(self._visible_to(user_identity)), request)
        offset = request.offset or 0
        limit = request.limit or 20
        result = await self.session.scalars(stmt.offset(offset).limit(limit))
        return [to_pydantic(a) for a in result.all()]

    async def count_assistants(
        self,
        request: Any,  # AssistantSearchRequest
        user_identity: str,
    ) -> int:
        """Count visible assistants matching the same filters as ``search_assistants``."""
        stmt = select(func.count()).select_from(AssistantORM).where(self._visible_to(user_identity))
        stmt = self._apply_search_filters(stmt, request)
        total = await self.session.scalar(stmt)
        return total or 0

    async def get_assistant(self, assistant_id: str, user_identity: str) -> Assistant:
        """Get an assistant (own or system) by ID; raises 404 if not visible."""
        return to_pydantic(await self._get_visible_assistant(assistant_id, user_identity))

    async def update_assistant(self, assistant_id: str, request: AssistantUpdate, user_identity: str) -> Assistant:
        """Apply a partial update to the caller's assistant and record it as a new version.

        Fields omitted from *request* keep their current values, and ``metadata`` is
        merged into the existing metadata. ``config["configurable"]`` and ``context``
        stay in sync, with an explicitly supplied ``context`` taking precedence.
        The new version row and the assistant update are committed together.

        Raises:
            HTTPException(400): both ``config.configurable`` and ``context`` supplied,
                or ``graph_id`` is not declared in aegra.json.
            HTTPException(404): assistant missing or not owned by the caller.
        """
        if request.config and request.config.get("configurable") and request.context:
            raise HTTPException(
                status_code=400,
                detail="Cannot specify both configurable and context. Use only one.",
            )
        if request.graph_id:
            self._require_known_graph(request.graph_id)

        assistant = await self._get_owned_assistant(assistant_id, user_identity)

        # PATCH semantics: omitted config/context fall back to the stored values.
        config = dict(request.config) if request.config is not None else dict(assistant.config or {})
        context = request.context if request.context is not None else dict(assistant.context or {})
        if request.context:
            config["configurable"] = context
        elif config.get("configurable"):
            context = config["configurable"]
        elif context:
            config["configurable"] = context

        metadata = {**(assistant.metadata_dict or {}), **(request.metadata or {})}
        name = request.name or assistant.name
        description = request.description or assistant.description
        graph_id = request.graph_id or assistant.graph_id

        max_version = await self.session.scalar(
            select(func.max(AssistantVersionORM.version)).where(AssistantVersionORM.assistant_id == assistant_id)
        )
        new_version = (max_version or 0) + 1
        now = datetime.now(UTC)

        self.session.add(
            AssistantVersionORM(
                assistant_id=assistant_id,
                version=new_version,
                graph_id=graph_id,
                config=config,
                context=context,
                created_at=now,
                name=name,
                description=description,
                metadata_dict=metadata,
            )
        )
        await self.session.execute(
            update(AssistantORM)
            .where(
                AssistantORM.assistant_id == assistant_id,
                AssistantORM.user_id == user_identity,
            )
            .values(
                {
                    AssistantORM.name: name,
                    AssistantORM.description: description,
                    AssistantORM.graph_id: graph_id,
                    AssistantORM.config: config,
                    AssistantORM.context: context,
                    AssistantORM.metadata_dict: metadata,
                    AssistantORM.version: new_version,
                    AssistantORM.updated_at: now,
                }
            )
        )
        await self.session.commit()
        await self.session.refresh(assistant)
        return to_pydantic(assistant)

    async def delete_assistant(self, assistant_id: str, user_identity: str) -> dict:
        """Delete the caller's assistant; raises 404 if it is missing or not owned."""
        assistant = await self._get_owned_assistant(assistant_id, user_identity)
        await self.session.delete(assistant)
        await self.session.commit()
        return {"status": "deleted"}

    async def set_assistant_latest(self, assistant_id: str, version: int, user_identity: str) -> Assistant:
        """Make *version* the current version of the caller's assistant.

        Copies name, description, config, context and graph_id from the stored
        version record onto the assistant.

        Raises:
            HTTPException(404): assistant or version not found.
        """
        assistant = await self._get_owned_assistant(assistant_id, user_identity)

        version_stmt = select(AssistantVersionORM).where(
            AssistantVersionORM.assistant_id == assistant_id,
            AssistantVersionORM.version == version,
        )
        assistant_version = await self.session.scalar(version_stmt)
        if not assistant_version:
            raise HTTPException(404, f"Version '{version}' for Assistant '{assistant_id}' not found")

        await self.session.execute(
            update(AssistantORM)
            .where(
                AssistantORM.assistant_id == assistant_id,
                AssistantORM.user_id == user_identity,
            )
            .values(
                {
                    AssistantORM.name: assistant_version.name,
                    AssistantORM.description: assistant_version.description,
                    AssistantORM.config: assistant_version.config,
                    AssistantORM.context: assistant_version.context,
                    AssistantORM.graph_id: assistant_version.graph_id,
                    AssistantORM.version: version,
                    AssistantORM.updated_at: datetime.now(UTC),
                }
            )
        )
        await self.session.commit()
        await self.session.refresh(assistant)
        return to_pydantic(assistant)

    async def list_assistant_versions(self, assistant_id: str, user_identity: str) -> list[Assistant]:
        """List all versions of a visible assistant, newest first.

        Raises:
            HTTPException(404): assistant not visible, or it has no version records.
        """
        assistant = await self._get_visible_assistant(assistant_id, user_identity)

        result = await self.session.scalars(
            select(AssistantVersionORM)
            .where(AssistantVersionORM.assistant_id == assistant_id)
            .order_by(AssistantVersionORM.version.desc())
        )
        versions = result.all()
        if not versions:
            raise HTTPException(404, f"No versions found for Assistant '{assistant_id}'")

        # Report the assistant's real owner (e.g. "system"), not the caller.
        owner = assistant.user_id
        if isinstance(owner, uuid.UUID):
            owner = str(owner)

        return [
            Assistant(
                assistant_id=assistant_id,
                name=v.name,
                description=v.description,
                config=v.config or {},
                context=v.context or {},
                graph_id=v.graph_id,
                user_id=owner,
                version=v.version,
                created_at=v.created_at,
                updated_at=v.created_at,
                metadata_dict=v.metadata_dict or {},
            )
            for v in versions
        ]

    async def get_assistant_schemas(self, assistant_id: str, user_identity: str) -> dict:
        """Return the graph_id plus input/output/state/config/context schemas.

        Raises:
            HTTPException(404): assistant not visible.
            HTTPException(400): the graph could not be loaded.
        """
        assistant = await self._get_visible_assistant(assistant_id, user_identity)
        try:
            # Schema extraction does not need the checkpointer/store used for execution.
            graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)
            schemas = _extract_graph_schemas(graph)
            return {"graph_id": assistant.graph_id, **schemas}
        except Exception as e:
            raise HTTPException(400, f"Failed to extract schemas: {str(e)}") from e

    async def get_assistant_graph(self, assistant_id: str, xray: bool | int, user_identity: str) -> dict:
        """Return the assistant's drawable graph as JSON for visualisation.

        ``xray`` is forwarded to ``aget_graph``; a non-boolean integer must be > 0.
        The ``id`` entry is removed from each node's ``data`` dict.

        Raises:
            HTTPException(404): assistant not visible.
            HTTPException(422): invalid ``xray`` or the graph cannot be visualised.
            HTTPException(400): any other failure while loading or drawing the graph.
        """
        assistant = await self._get_visible_assistant(assistant_id, user_identity)
        try:
            # Only the graph structure is needed, not checkpointer/store.
            graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)

            # Validate xray if it's an integer (not a boolean)
            if isinstance(xray, int) and not isinstance(xray, bool) and xray <= 0:
                raise HTTPException(422, detail="Invalid xray value")

            try:
                drawable_graph = await graph.aget_graph(xray=xray)
                json_graph = drawable_graph.to_json()
                for node in json_graph.get("nodes", []):
                    if (data := node.get("data")) and isinstance(data, dict):
                        data.pop("id", None)
                return json_graph
            except NotImplementedError as e:
                raise HTTPException(422, detail="The graph does not support visualization") from e

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(400, f"Failed to get graph: {str(e)}") from e

    async def get_assistant_subgraphs(
        self,
        assistant_id: str,
        namespace: str | None,
        recurse: bool,
        user_identity: str,
    ) -> dict:
        """Return ``{namespace: schemas}`` for the assistant graph's subgraphs.

        Raises:
            HTTPException(404): assistant not visible.
            HTTPException(422): the graph does not support subgraphs.
            HTTPException(400): any other failure while loading the graph.
        """
        assistant = await self._get_visible_assistant(assistant_id, user_identity)
        try:
            # Schema extraction does not need the checkpointer/store used for execution.
            graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)
            try:
                return {
                    ns: _extract_graph_schemas(subgraph)
                    async for ns, subgraph in graph.aget_subgraphs(namespace=namespace, recurse=recurse)
                }
            except NotImplementedError as e:
                raise HTTPException(422, detail="The graph does not support subgraphs") from e

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(400, f"Failed to get subgraphs: {str(e)}") from e
562
+
563
+
564
def get_assistant_service(
    session: AsyncSession = Depends(get_session),
    langgraph_service: LangGraphService = Depends(get_langgraph_service),
) -> AssistantService:
    """FastAPI dependency that wires the injected session and LangGraph service
    into an ``AssistantService``."""
    service = AssistantService(session=session, langgraph_service=langgraph_service)
    return service
@@ -0,0 +1,59 @@
1
+ """Abstract base classes for the broker system"""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import AsyncIterator
5
+ from typing import Any
6
+
7
+
8
class BaseRunBroker(ABC):
    """Abstract interface for an event broker dedicated to a single run.

    Events are added with :meth:`put` and iterated as ``(event_id, payload)``
    pairs through :meth:`aiter`; :meth:`mark_finished` / :meth:`is_finished`
    track the broker's finished flag.
    """

    @abstractmethod
    async def put(self, event_id: str, payload: Any) -> None:
        """Enqueue *payload* under *event_id* in the broker queue."""

    @abstractmethod
    def aiter(self) -> AsyncIterator[tuple[str, Any]]:
        """Return an async iterator yielding ``(event_id, payload)`` pairs."""
        # Subclasses must override this (for example as an async generator).
        raise NotImplementedError("aiter method must be implemented by subclass")

    @abstractmethod
    def mark_finished(self) -> None:
        """Flag this broker as finished."""

    @abstractmethod
    def is_finished(self) -> bool:
        """Report whether this broker has been flagged as finished."""
31
+
32
+
33
class BaseBrokerManager(ABC):
    """Abstract interface for managing one ``BaseRunBroker`` per run.

    Implementations look brokers up by ``run_id``, create them on demand, clean
    them up, and start/stop a background cleanup task for old brokers.
    """

    @abstractmethod
    def get_or_create_broker(self, run_id: str) -> BaseRunBroker:
        """Return the broker for *run_id*, creating it if needed."""

    @abstractmethod
    def get_broker(self, run_id: str) -> BaseRunBroker | None:
        """Return the existing broker for *run_id*, or ``None``."""

    @abstractmethod
    def cleanup_broker(self, run_id: str) -> None:
        """Release the broker associated with *run_id*."""

    @abstractmethod
    async def start_cleanup_task(self) -> None:
        """Start the background task that cleans up old brokers."""

    @abstractmethod
    async def stop_cleanup_task(self) -> None:
        """Stop the background cleanup task."""