aegra-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aegra_api/__init__.py +3 -0
- aegra_api/api/__init__.py +1 -0
- aegra_api/api/assistants.py +235 -0
- aegra_api/api/runs.py +1110 -0
- aegra_api/api/store.py +200 -0
- aegra_api/api/threads.py +761 -0
- aegra_api/config.py +204 -0
- aegra_api/constants.py +5 -0
- aegra_api/core/__init__.py +0 -0
- aegra_api/core/app_loader.py +91 -0
- aegra_api/core/auth_ctx.py +65 -0
- aegra_api/core/auth_deps.py +186 -0
- aegra_api/core/auth_handlers.py +248 -0
- aegra_api/core/auth_middleware.py +331 -0
- aegra_api/core/database.py +123 -0
- aegra_api/core/health.py +131 -0
- aegra_api/core/orm.py +165 -0
- aegra_api/core/route_merger.py +69 -0
- aegra_api/core/serializers/__init__.py +7 -0
- aegra_api/core/serializers/base.py +22 -0
- aegra_api/core/serializers/general.py +54 -0
- aegra_api/core/serializers/langgraph.py +102 -0
- aegra_api/core/sse.py +178 -0
- aegra_api/main.py +303 -0
- aegra_api/middleware/__init__.py +4 -0
- aegra_api/middleware/double_encoded_json.py +74 -0
- aegra_api/middleware/logger_middleware.py +95 -0
- aegra_api/models/__init__.py +76 -0
- aegra_api/models/assistants.py +81 -0
- aegra_api/models/auth.py +62 -0
- aegra_api/models/enums.py +29 -0
- aegra_api/models/errors.py +29 -0
- aegra_api/models/runs.py +124 -0
- aegra_api/models/store.py +67 -0
- aegra_api/models/threads.py +152 -0
- aegra_api/observability/__init__.py +1 -0
- aegra_api/observability/base.py +88 -0
- aegra_api/observability/otel.py +133 -0
- aegra_api/observability/setup.py +27 -0
- aegra_api/observability/targets/__init__.py +11 -0
- aegra_api/observability/targets/base.py +18 -0
- aegra_api/observability/targets/langfuse.py +33 -0
- aegra_api/observability/targets/otlp.py +38 -0
- aegra_api/observability/targets/phoenix.py +24 -0
- aegra_api/services/__init__.py +0 -0
- aegra_api/services/assistant_service.py +569 -0
- aegra_api/services/base_broker.py +59 -0
- aegra_api/services/broker.py +141 -0
- aegra_api/services/event_converter.py +157 -0
- aegra_api/services/event_store.py +196 -0
- aegra_api/services/graph_streaming.py +433 -0
- aegra_api/services/langgraph_service.py +456 -0
- aegra_api/services/streaming_service.py +362 -0
- aegra_api/services/thread_state_service.py +128 -0
- aegra_api/settings.py +124 -0
- aegra_api/utils/__init__.py +3 -0
- aegra_api/utils/assistants.py +23 -0
- aegra_api/utils/run_utils.py +60 -0
- aegra_api/utils/setup_logging.py +122 -0
- aegra_api/utils/sse_utils.py +26 -0
- aegra_api/utils/status_compat.py +57 -0
- aegra_api-0.1.0.dist-info/METADATA +244 -0
- aegra_api-0.1.0.dist-info/RECORD +64 -0
- aegra_api-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
"""Service layer for assistant business logic
|
|
2
|
+
|
|
3
|
+
This service encapsulates all business logic for assistant management, following
|
|
4
|
+
a layered architecture pattern. The code was extracted from api/assistants.py
|
|
5
|
+
to separate concerns and improve maintainability.
|
|
6
|
+
|
|
7
|
+
Responsibilities:
|
|
8
|
+
- Business logic and validation
|
|
9
|
+
- Database operations via SQLAlchemy ORM
|
|
10
|
+
- Graph schema extraction and manipulation
|
|
11
|
+
- Coordination between different components
|
|
12
|
+
|
|
13
|
+
This is the first service layer implementation in Aegra. The pattern will be
|
|
14
|
+
applied to other APIs (runs, threads, crons) as part of ongoing refactoring.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import uuid
|
|
18
|
+
from datetime import UTC, datetime
|
|
19
|
+
from typing import Any
|
|
20
|
+
from uuid import uuid4
|
|
21
|
+
|
|
22
|
+
from fastapi import Depends, HTTPException
|
|
23
|
+
from sqlalchemy import func, or_, select, update
|
|
24
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
25
|
+
|
|
26
|
+
from aegra_api.core.orm import Assistant as AssistantORM
|
|
27
|
+
from aegra_api.core.orm import AssistantVersion as AssistantVersionORM
|
|
28
|
+
from aegra_api.core.orm import get_session
|
|
29
|
+
from aegra_api.models import Assistant, AssistantCreate, AssistantUpdate
|
|
30
|
+
from aegra_api.services.langgraph_service import LangGraphService, get_langgraph_service
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def to_pydantic(row: AssistantORM) -> Assistant:
|
|
34
|
+
"""Convert SQLAlchemy ORM object to Pydantic model with proper type casting.
|
|
35
|
+
|
|
36
|
+
Uses from_attributes=True because Assistant ORM has attribute/column name mismatch:
|
|
37
|
+
- ORM attribute: metadata_dict
|
|
38
|
+
- DB column: metadata
|
|
39
|
+
- Pydantic field: metadata (with alias="metadata_dict")
|
|
40
|
+
|
|
41
|
+
This is different from Thread/Run where attribute names match column names.
|
|
42
|
+
"""
|
|
43
|
+
# Cast UUIDs to str so they match the Pydantic schema
|
|
44
|
+
if hasattr(row, "assistant_id") and row.assistant_id is not None:
|
|
45
|
+
row.assistant_id = str(row.assistant_id)
|
|
46
|
+
if hasattr(row, "user_id") and isinstance(row.user_id, uuid.UUID):
|
|
47
|
+
row.user_id = str(row.user_id)
|
|
48
|
+
|
|
49
|
+
# Use Pydantic's built-in ORM conversion with from_attributes=True
|
|
50
|
+
return Assistant.model_validate(row, from_attributes=True)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _state_jsonschema(graph) -> dict | None:
|
|
54
|
+
"""Extract state schema from graph channels"""
|
|
55
|
+
from typing import Any
|
|
56
|
+
|
|
57
|
+
from langchain_core.runnables.utils import create_model
|
|
58
|
+
|
|
59
|
+
fields: dict = {}
|
|
60
|
+
for k in graph.stream_channels_list:
|
|
61
|
+
v = graph.channels[k]
|
|
62
|
+
try:
|
|
63
|
+
create_model(k, __root__=(v.UpdateType, None)).model_json_schema()
|
|
64
|
+
fields[k] = (v.UpdateType, None)
|
|
65
|
+
except Exception:
|
|
66
|
+
fields[k] = (Any, None)
|
|
67
|
+
return create_model(graph.get_name("State"), **fields).model_json_schema()
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _get_configurable_jsonschema(graph) -> dict:
|
|
71
|
+
"""Get the JSON schema for the configurable part of the graph"""
|
|
72
|
+
from pydantic import TypeAdapter
|
|
73
|
+
|
|
74
|
+
EXCLUDED_CONFIG_SCHEMA = {"__pregel_resuming", "__pregel_checkpoint_id"}
|
|
75
|
+
|
|
76
|
+
config_schema = graph.config_schema()
|
|
77
|
+
model_fields = getattr(config_schema, "model_fields", None) or getattr(config_schema, "__fields__", None)
|
|
78
|
+
|
|
79
|
+
if model_fields is not None and "configurable" in model_fields:
|
|
80
|
+
configurable = TypeAdapter(model_fields["configurable"].annotation)
|
|
81
|
+
json_schema = configurable.json_schema()
|
|
82
|
+
if json_schema:
|
|
83
|
+
for key in EXCLUDED_CONFIG_SCHEMA:
|
|
84
|
+
json_schema["properties"].pop(key, None)
|
|
85
|
+
if hasattr(graph, "config_type") and graph.config_type is not None and hasattr(graph.config_type, "__name__"):
|
|
86
|
+
json_schema["title"] = graph.config_type.__name__
|
|
87
|
+
return json_schema
|
|
88
|
+
return {}
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _extract_graph_schemas(graph) -> dict:
|
|
92
|
+
"""Extract schemas from a compiled LangGraph graph object"""
|
|
93
|
+
try:
|
|
94
|
+
input_schema = graph.get_input_jsonschema()
|
|
95
|
+
except Exception:
|
|
96
|
+
input_schema = None
|
|
97
|
+
|
|
98
|
+
try:
|
|
99
|
+
output_schema = graph.get_output_jsonschema()
|
|
100
|
+
except Exception:
|
|
101
|
+
output_schema = None
|
|
102
|
+
|
|
103
|
+
try:
|
|
104
|
+
state_schema = _state_jsonschema(graph)
|
|
105
|
+
except Exception:
|
|
106
|
+
state_schema = None
|
|
107
|
+
|
|
108
|
+
try:
|
|
109
|
+
config_schema = _get_configurable_jsonschema(graph)
|
|
110
|
+
except Exception:
|
|
111
|
+
config_schema = None
|
|
112
|
+
|
|
113
|
+
try:
|
|
114
|
+
context_schema = graph.get_context_jsonschema()
|
|
115
|
+
except Exception:
|
|
116
|
+
context_schema = None
|
|
117
|
+
|
|
118
|
+
return {
|
|
119
|
+
"input_schema": input_schema,
|
|
120
|
+
"output_schema": output_schema,
|
|
121
|
+
"state_schema": state_schema,
|
|
122
|
+
"config_schema": config_schema,
|
|
123
|
+
"context_schema": context_schema,
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class AssistantService:
|
|
128
|
+
"""Service for managing assistants"""
|
|
129
|
+
|
|
130
|
+
def __init__(self, session: AsyncSession, langgraph_service: LangGraphService):
|
|
131
|
+
self.session = session
|
|
132
|
+
self.langgraph_service = langgraph_service
|
|
133
|
+
|
|
134
|
+
async def create_assistant(self, request: AssistantCreate, user_identity: str) -> Assistant:
|
|
135
|
+
"""Create a new assistant"""
|
|
136
|
+
# Get LangGraph service to validate graph
|
|
137
|
+
available_graphs = self.langgraph_service.list_graphs()
|
|
138
|
+
|
|
139
|
+
# Use graph_id as the main identifier
|
|
140
|
+
graph_id = request.graph_id
|
|
141
|
+
|
|
142
|
+
if graph_id not in available_graphs:
|
|
143
|
+
raise HTTPException(
|
|
144
|
+
400,
|
|
145
|
+
f"Graph '{graph_id}' not found in aegra.json. Available: {list(available_graphs.keys())}",
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# Validate graph can be loaded
|
|
149
|
+
try:
|
|
150
|
+
await self.langgraph_service.get_graph_for_validation(graph_id)
|
|
151
|
+
except Exception as e:
|
|
152
|
+
raise HTTPException(400, f"Failed to load graph: {str(e)}") from e
|
|
153
|
+
|
|
154
|
+
config = request.config
|
|
155
|
+
context = request.context
|
|
156
|
+
|
|
157
|
+
if config.get("configurable") and context:
|
|
158
|
+
raise HTTPException(
|
|
159
|
+
status_code=400,
|
|
160
|
+
detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
# Keep config and context up to date with one another
|
|
164
|
+
if config.get("configurable"):
|
|
165
|
+
context = config["configurable"]
|
|
166
|
+
elif context:
|
|
167
|
+
config["configurable"] = context
|
|
168
|
+
|
|
169
|
+
# Generate assistant_id if not provided
|
|
170
|
+
assistant_id = request.assistant_id or str(uuid4())
|
|
171
|
+
|
|
172
|
+
# Generate name if not provided
|
|
173
|
+
name = request.name or f"Assistant for {graph_id}"
|
|
174
|
+
|
|
175
|
+
# Check if an assistant already exists for this user, graph and config pair
|
|
176
|
+
existing_stmt = select(AssistantORM).where(
|
|
177
|
+
AssistantORM.user_id == user_identity,
|
|
178
|
+
or_(
|
|
179
|
+
(AssistantORM.graph_id == graph_id) & (AssistantORM.config == config),
|
|
180
|
+
AssistantORM.assistant_id == assistant_id,
|
|
181
|
+
),
|
|
182
|
+
)
|
|
183
|
+
existing = await self.session.scalar(existing_stmt)
|
|
184
|
+
|
|
185
|
+
if existing:
|
|
186
|
+
if request.if_exists == "do_nothing":
|
|
187
|
+
return to_pydantic(existing)
|
|
188
|
+
else: # error (default)
|
|
189
|
+
raise HTTPException(409, f"Assistant '{assistant_id}' already exists")
|
|
190
|
+
|
|
191
|
+
# Create assistant record
|
|
192
|
+
assistant_orm = AssistantORM(
|
|
193
|
+
assistant_id=assistant_id,
|
|
194
|
+
name=name,
|
|
195
|
+
description=request.description,
|
|
196
|
+
config=config,
|
|
197
|
+
context=context,
|
|
198
|
+
graph_id=graph_id,
|
|
199
|
+
user_id=user_identity,
|
|
200
|
+
metadata_dict=request.metadata,
|
|
201
|
+
version=1,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
self.session.add(assistant_orm)
|
|
205
|
+
await self.session.commit()
|
|
206
|
+
await self.session.refresh(assistant_orm)
|
|
207
|
+
|
|
208
|
+
# Create initial version record
|
|
209
|
+
assistant_version_orm = AssistantVersionORM(
|
|
210
|
+
assistant_id=assistant_id,
|
|
211
|
+
version=1,
|
|
212
|
+
graph_id=graph_id,
|
|
213
|
+
config=config,
|
|
214
|
+
context=context,
|
|
215
|
+
created_at=datetime.now(UTC),
|
|
216
|
+
name=name,
|
|
217
|
+
description=request.description,
|
|
218
|
+
metadata_dict=request.metadata,
|
|
219
|
+
)
|
|
220
|
+
self.session.add(assistant_version_orm)
|
|
221
|
+
await self.session.commit()
|
|
222
|
+
|
|
223
|
+
return to_pydantic(assistant_orm)
|
|
224
|
+
|
|
225
|
+
async def list_assistants(self, user_identity: str) -> list[Assistant]:
|
|
226
|
+
"""List user's assistants and system assistants"""
|
|
227
|
+
# Include both user's assistants and system assistants (like search_assistants does)
|
|
228
|
+
stmt = select(AssistantORM).where(or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"))
|
|
229
|
+
result = await self.session.scalars(stmt)
|
|
230
|
+
user_assistants = [to_pydantic(a) for a in result.all()]
|
|
231
|
+
return user_assistants
|
|
232
|
+
|
|
233
|
+
async def search_assistants(
|
|
234
|
+
self,
|
|
235
|
+
request: Any, # AssistantSearchRequest
|
|
236
|
+
user_identity: str,
|
|
237
|
+
) -> list[Assistant]:
|
|
238
|
+
"""Search assistants with filters"""
|
|
239
|
+
# Start with user's assistants
|
|
240
|
+
stmt = select(AssistantORM).where(or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"))
|
|
241
|
+
|
|
242
|
+
# Apply filters
|
|
243
|
+
if request.name:
|
|
244
|
+
stmt = stmt.where(AssistantORM.name.ilike(f"%{request.name}%"))
|
|
245
|
+
|
|
246
|
+
if request.description:
|
|
247
|
+
stmt = stmt.where(AssistantORM.description.ilike(f"%{request.description}%"))
|
|
248
|
+
|
|
249
|
+
if request.graph_id:
|
|
250
|
+
stmt = stmt.where(AssistantORM.graph_id == request.graph_id)
|
|
251
|
+
|
|
252
|
+
if request.metadata:
|
|
253
|
+
stmt = stmt.where(AssistantORM.metadata_dict.op("@>")(request.metadata))
|
|
254
|
+
|
|
255
|
+
# Apply pagination
|
|
256
|
+
offset = request.offset or 0
|
|
257
|
+
limit = request.limit or 20
|
|
258
|
+
stmt = stmt.offset(offset).limit(limit)
|
|
259
|
+
|
|
260
|
+
result = await self.session.scalars(stmt)
|
|
261
|
+
paginated_assistants = [to_pydantic(a) for a in result.all()]
|
|
262
|
+
|
|
263
|
+
return paginated_assistants
|
|
264
|
+
|
|
265
|
+
async def count_assistants(
|
|
266
|
+
self,
|
|
267
|
+
request: Any, # AssistantSearchRequest
|
|
268
|
+
user_identity: str,
|
|
269
|
+
) -> int:
|
|
270
|
+
"""Count assistants with filters"""
|
|
271
|
+
# Include both user's assistants and system assistants (like search_assistants does)
|
|
272
|
+
stmt = select(func.count()).where(or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"))
|
|
273
|
+
|
|
274
|
+
if request.name:
|
|
275
|
+
stmt = stmt.where(AssistantORM.name.ilike(f"%{request.name}%"))
|
|
276
|
+
|
|
277
|
+
if request.description:
|
|
278
|
+
stmt = stmt.where(AssistantORM.description.ilike(f"%{request.description}%"))
|
|
279
|
+
|
|
280
|
+
if request.graph_id:
|
|
281
|
+
stmt = stmt.where(AssistantORM.graph_id == request.graph_id)
|
|
282
|
+
|
|
283
|
+
if request.metadata:
|
|
284
|
+
stmt = stmt.where(AssistantORM.metadata_dict.op("@>")(request.metadata))
|
|
285
|
+
|
|
286
|
+
total = await self.session.scalar(stmt)
|
|
287
|
+
return total or 0
|
|
288
|
+
|
|
289
|
+
async def get_assistant(self, assistant_id: str, user_identity: str) -> Assistant:
|
|
290
|
+
"""Get assistant by ID"""
|
|
291
|
+
stmt = select(AssistantORM).where(
|
|
292
|
+
AssistantORM.assistant_id == assistant_id,
|
|
293
|
+
or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"),
|
|
294
|
+
)
|
|
295
|
+
assistant = await self.session.scalar(stmt)
|
|
296
|
+
|
|
297
|
+
if not assistant:
|
|
298
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
299
|
+
|
|
300
|
+
return to_pydantic(assistant)
|
|
301
|
+
|
|
302
|
+
async def update_assistant(self, assistant_id: str, request: AssistantUpdate, user_identity: str) -> Assistant:
|
|
303
|
+
"""Update assistant by ID"""
|
|
304
|
+
metadata = request.metadata or {}
|
|
305
|
+
config = request.config or {}
|
|
306
|
+
context = request.context or {}
|
|
307
|
+
|
|
308
|
+
if config.get("configurable") and context:
|
|
309
|
+
raise HTTPException(
|
|
310
|
+
status_code=400,
|
|
311
|
+
detail="Cannot specify both configurable and context. Use only one.",
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
# Keep config and context up to date with one another
|
|
315
|
+
if config.get("configurable"):
|
|
316
|
+
context = config["configurable"]
|
|
317
|
+
elif context:
|
|
318
|
+
config["configurable"] = context
|
|
319
|
+
|
|
320
|
+
stmt = select(AssistantORM).where(
|
|
321
|
+
AssistantORM.assistant_id == assistant_id,
|
|
322
|
+
AssistantORM.user_id == user_identity,
|
|
323
|
+
)
|
|
324
|
+
assistant = await self.session.scalar(stmt)
|
|
325
|
+
if not assistant:
|
|
326
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
327
|
+
|
|
328
|
+
now = datetime.now(UTC)
|
|
329
|
+
version_stmt = select(func.max(AssistantVersionORM.version)).where(
|
|
330
|
+
AssistantVersionORM.assistant_id == assistant_id
|
|
331
|
+
)
|
|
332
|
+
max_version = await self.session.scalar(version_stmt)
|
|
333
|
+
new_version = (max_version or 1) + 1 if max_version is not None else 1
|
|
334
|
+
|
|
335
|
+
new_version_details = {
|
|
336
|
+
"assistant_id": assistant_id,
|
|
337
|
+
"version": new_version,
|
|
338
|
+
"graph_id": request.graph_id or assistant.graph_id,
|
|
339
|
+
"config": config,
|
|
340
|
+
"context": context,
|
|
341
|
+
"created_at": now,
|
|
342
|
+
"name": request.name or assistant.name,
|
|
343
|
+
"description": request.description or assistant.description,
|
|
344
|
+
"metadata_dict": metadata,
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
assistant_version_orm = AssistantVersionORM(**new_version_details)
|
|
348
|
+
self.session.add(assistant_version_orm)
|
|
349
|
+
await self.session.commit()
|
|
350
|
+
|
|
351
|
+
assistant_update = (
|
|
352
|
+
update(AssistantORM)
|
|
353
|
+
.where(
|
|
354
|
+
AssistantORM.assistant_id == assistant_id,
|
|
355
|
+
AssistantORM.user_id == user_identity,
|
|
356
|
+
)
|
|
357
|
+
.values(
|
|
358
|
+
name=new_version_details["name"],
|
|
359
|
+
description=new_version_details["description"],
|
|
360
|
+
graph_id=new_version_details["graph_id"],
|
|
361
|
+
config=new_version_details["config"],
|
|
362
|
+
context=new_version_details["context"],
|
|
363
|
+
version=new_version,
|
|
364
|
+
updated_at=now,
|
|
365
|
+
)
|
|
366
|
+
)
|
|
367
|
+
await self.session.execute(assistant_update)
|
|
368
|
+
await self.session.commit()
|
|
369
|
+
updated_assistant = await self.session.scalar(stmt)
|
|
370
|
+
return to_pydantic(updated_assistant)
|
|
371
|
+
|
|
372
|
+
async def delete_assistant(self, assistant_id: str, user_identity: str) -> dict:
|
|
373
|
+
"""Delete assistant by ID"""
|
|
374
|
+
stmt = select(AssistantORM).where(
|
|
375
|
+
AssistantORM.assistant_id == assistant_id,
|
|
376
|
+
AssistantORM.user_id == user_identity,
|
|
377
|
+
)
|
|
378
|
+
assistant = await self.session.scalar(stmt)
|
|
379
|
+
|
|
380
|
+
if not assistant:
|
|
381
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
382
|
+
|
|
383
|
+
await self.session.delete(assistant)
|
|
384
|
+
await self.session.commit()
|
|
385
|
+
|
|
386
|
+
return {"status": "deleted"}
|
|
387
|
+
|
|
388
|
+
async def set_assistant_latest(self, assistant_id: str, version: int, user_identity: str) -> Assistant:
|
|
389
|
+
"""Set the given version as the latest version of an assistant"""
|
|
390
|
+
stmt = select(AssistantORM).where(
|
|
391
|
+
AssistantORM.assistant_id == assistant_id,
|
|
392
|
+
AssistantORM.user_id == user_identity,
|
|
393
|
+
)
|
|
394
|
+
assistant = await self.session.scalar(stmt)
|
|
395
|
+
if not assistant:
|
|
396
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
397
|
+
|
|
398
|
+
version_stmt = select(AssistantVersionORM).where(
|
|
399
|
+
AssistantVersionORM.assistant_id == assistant_id,
|
|
400
|
+
AssistantVersionORM.version == version,
|
|
401
|
+
)
|
|
402
|
+
assistant_version = await self.session.scalar(version_stmt)
|
|
403
|
+
if not assistant_version:
|
|
404
|
+
raise HTTPException(404, f"Version '{version}' for Assistant '{assistant_id}' not found")
|
|
405
|
+
|
|
406
|
+
assistant_update = (
|
|
407
|
+
update(AssistantORM)
|
|
408
|
+
.where(
|
|
409
|
+
AssistantORM.assistant_id == assistant_id,
|
|
410
|
+
AssistantORM.user_id == user_identity,
|
|
411
|
+
)
|
|
412
|
+
.values(
|
|
413
|
+
name=assistant_version.name,
|
|
414
|
+
description=assistant_version.description,
|
|
415
|
+
config=assistant_version.config,
|
|
416
|
+
context=assistant_version.context,
|
|
417
|
+
graph_id=assistant_version.graph_id,
|
|
418
|
+
version=version,
|
|
419
|
+
updated_at=datetime.now(UTC),
|
|
420
|
+
)
|
|
421
|
+
)
|
|
422
|
+
await self.session.execute(assistant_update)
|
|
423
|
+
await self.session.commit()
|
|
424
|
+
updated_assistant = await self.session.scalar(stmt)
|
|
425
|
+
return to_pydantic(updated_assistant)
|
|
426
|
+
|
|
427
|
+
async def list_assistant_versions(self, assistant_id: str, user_identity: str) -> list[Assistant]:
|
|
428
|
+
"""List all versions of an assistant"""
|
|
429
|
+
stmt = select(AssistantORM).where(
|
|
430
|
+
AssistantORM.assistant_id == assistant_id,
|
|
431
|
+
or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"),
|
|
432
|
+
)
|
|
433
|
+
assistant = await self.session.scalar(stmt)
|
|
434
|
+
if not assistant:
|
|
435
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
436
|
+
|
|
437
|
+
stmt = (
|
|
438
|
+
select(AssistantVersionORM)
|
|
439
|
+
.where(AssistantVersionORM.assistant_id == assistant_id)
|
|
440
|
+
.order_by(AssistantVersionORM.version.desc())
|
|
441
|
+
)
|
|
442
|
+
result = await self.session.scalars(stmt)
|
|
443
|
+
versions = result.all()
|
|
444
|
+
|
|
445
|
+
if not versions:
|
|
446
|
+
raise HTTPException(404, f"No versions found for Assistant '{assistant_id}'")
|
|
447
|
+
|
|
448
|
+
# Convert to Pydantic models
|
|
449
|
+
version_list = [
|
|
450
|
+
Assistant(
|
|
451
|
+
assistant_id=assistant_id,
|
|
452
|
+
name=v.name,
|
|
453
|
+
description=v.description,
|
|
454
|
+
config=v.config or {},
|
|
455
|
+
context=v.context or {},
|
|
456
|
+
graph_id=v.graph_id,
|
|
457
|
+
user_id=user_identity,
|
|
458
|
+
version=v.version,
|
|
459
|
+
created_at=v.created_at,
|
|
460
|
+
updated_at=v.created_at,
|
|
461
|
+
metadata_dict=v.metadata_dict or {},
|
|
462
|
+
)
|
|
463
|
+
for v in versions
|
|
464
|
+
]
|
|
465
|
+
|
|
466
|
+
return version_list
|
|
467
|
+
|
|
468
|
+
async def get_assistant_schemas(self, assistant_id: str, user_identity: str) -> dict:
|
|
469
|
+
"""Get input, output, state, config and context schemas for an assistant"""
|
|
470
|
+
stmt = select(AssistantORM).where(
|
|
471
|
+
AssistantORM.assistant_id == assistant_id,
|
|
472
|
+
or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"),
|
|
473
|
+
)
|
|
474
|
+
assistant = await self.session.scalar(stmt)
|
|
475
|
+
|
|
476
|
+
if not assistant:
|
|
477
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
478
|
+
|
|
479
|
+
try:
|
|
480
|
+
# Use get_graph_for_validation since we only need schema extraction,
|
|
481
|
+
# not checkpointer/store for execution
|
|
482
|
+
graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)
|
|
483
|
+
schemas = _extract_graph_schemas(graph)
|
|
484
|
+
|
|
485
|
+
return {"graph_id": assistant.graph_id, **schemas}
|
|
486
|
+
|
|
487
|
+
except Exception as e:
|
|
488
|
+
raise HTTPException(400, f"Failed to extract schemas: {str(e)}") from e
|
|
489
|
+
|
|
490
|
+
async def get_assistant_graph(self, assistant_id: str, xray: bool | int, user_identity: str) -> dict:
|
|
491
|
+
"""Get the graph structure for visualization"""
|
|
492
|
+
stmt = select(AssistantORM).where(
|
|
493
|
+
AssistantORM.assistant_id == assistant_id,
|
|
494
|
+
or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"),
|
|
495
|
+
)
|
|
496
|
+
assistant = await self.session.scalar(stmt)
|
|
497
|
+
|
|
498
|
+
if not assistant:
|
|
499
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
500
|
+
|
|
501
|
+
try:
|
|
502
|
+
# Use get_graph_for_validation since we only need graph structure,
|
|
503
|
+
# not checkpointer/store for execution
|
|
504
|
+
graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)
|
|
505
|
+
|
|
506
|
+
# Validate xray if it's an integer (not a boolean)
|
|
507
|
+
if isinstance(xray, int) and not isinstance(xray, bool) and xray <= 0:
|
|
508
|
+
raise HTTPException(422, detail="Invalid xray value")
|
|
509
|
+
|
|
510
|
+
try:
|
|
511
|
+
drawable_graph = await graph.aget_graph(xray=xray)
|
|
512
|
+
json_graph = drawable_graph.to_json()
|
|
513
|
+
|
|
514
|
+
for node in json_graph.get("nodes", []):
|
|
515
|
+
if (data := node.get("data")) and isinstance(data, dict):
|
|
516
|
+
data.pop("id", None)
|
|
517
|
+
|
|
518
|
+
return json_graph
|
|
519
|
+
except NotImplementedError as e:
|
|
520
|
+
raise HTTPException(422, detail="The graph does not support visualization") from e
|
|
521
|
+
|
|
522
|
+
except HTTPException:
|
|
523
|
+
raise
|
|
524
|
+
except Exception as e:
|
|
525
|
+
raise HTTPException(400, f"Failed to get graph: {str(e)}") from e
|
|
526
|
+
|
|
527
|
+
async def get_assistant_subgraphs(
|
|
528
|
+
self,
|
|
529
|
+
assistant_id: str,
|
|
530
|
+
namespace: str | None,
|
|
531
|
+
recurse: bool,
|
|
532
|
+
user_identity: str,
|
|
533
|
+
) -> dict:
|
|
534
|
+
"""Get subgraphs of an assistant"""
|
|
535
|
+
stmt = select(AssistantORM).where(
|
|
536
|
+
AssistantORM.assistant_id == assistant_id,
|
|
537
|
+
or_(AssistantORM.user_id == user_identity, AssistantORM.user_id == "system"),
|
|
538
|
+
)
|
|
539
|
+
assistant = await self.session.scalar(stmt)
|
|
540
|
+
|
|
541
|
+
if not assistant:
|
|
542
|
+
raise HTTPException(404, f"Assistant '{assistant_id}' not found")
|
|
543
|
+
|
|
544
|
+
try:
|
|
545
|
+
# Use get_graph_for_validation since we only need schema extraction,
|
|
546
|
+
# not checkpointer/store for execution
|
|
547
|
+
graph = await self.langgraph_service.get_graph_for_validation(assistant.graph_id)
|
|
548
|
+
|
|
549
|
+
try:
|
|
550
|
+
subgraphs = {
|
|
551
|
+
ns: _extract_graph_schemas(subgraph)
|
|
552
|
+
async for ns, subgraph in graph.aget_subgraphs(namespace=namespace, recurse=recurse)
|
|
553
|
+
}
|
|
554
|
+
return subgraphs
|
|
555
|
+
except NotImplementedError as e:
|
|
556
|
+
raise HTTPException(422, detail="The graph does not support subgraphs") from e
|
|
557
|
+
|
|
558
|
+
except HTTPException:
|
|
559
|
+
raise
|
|
560
|
+
except Exception as e:
|
|
561
|
+
raise HTTPException(400, f"Failed to get subgraphs: {str(e)}") from e
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def get_assistant_service(
|
|
565
|
+
session: AsyncSession = Depends(get_session),
|
|
566
|
+
langgraph_service: LangGraphService = Depends(get_langgraph_service),
|
|
567
|
+
) -> AssistantService:
|
|
568
|
+
"""Dependency injection for AssistantService"""
|
|
569
|
+
return AssistantService(session, langgraph_service)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Abstract base classes for the broker system"""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseRunBroker(ABC):
|
|
9
|
+
"""Abstract base class for a run-specific event broker"""
|
|
10
|
+
|
|
11
|
+
@abstractmethod
|
|
12
|
+
async def put(self, event_id: str, payload: Any) -> None:
|
|
13
|
+
"""Put an event into the broker queue"""
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
@abstractmethod
|
|
17
|
+
def aiter(self) -> AsyncIterator[tuple[str, Any]]:
|
|
18
|
+
"""Async iterator yielding (event_id, payload) pairs"""
|
|
19
|
+
# Abstract async generator method; must be implemented by subclass
|
|
20
|
+
raise NotImplementedError("aiter method must be implemented by subclass")
|
|
21
|
+
|
|
22
|
+
@abstractmethod
|
|
23
|
+
def mark_finished(self) -> None:
|
|
24
|
+
"""Mark this broker as finished"""
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def is_finished(self) -> bool:
|
|
29
|
+
"""Check if this broker is finished"""
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseBrokerManager(ABC):
|
|
34
|
+
"""Abstract base class for managing multiple RunBroker instances"""
|
|
35
|
+
|
|
36
|
+
@abstractmethod
|
|
37
|
+
def get_or_create_broker(self, run_id: str) -> BaseRunBroker:
|
|
38
|
+
"""Get or create a broker for a run"""
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
@abstractmethod
|
|
42
|
+
def get_broker(self, run_id: str) -> BaseRunBroker | None:
|
|
43
|
+
"""Get an existing broker or None"""
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
@abstractmethod
|
|
47
|
+
def cleanup_broker(self, run_id: str) -> None:
|
|
48
|
+
"""Clean up a broker for a run"""
|
|
49
|
+
pass
|
|
50
|
+
|
|
51
|
+
@abstractmethod
|
|
52
|
+
async def start_cleanup_task(self) -> None:
|
|
53
|
+
"""Start background cleanup task for old brokers"""
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
@abstractmethod
|
|
57
|
+
async def stop_cleanup_task(self) -> None:
|
|
58
|
+
"""Stop background cleanup task"""
|
|
59
|
+
pass
|