agentseek-api 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. agentseek_api/__init__.py +1 -0
  2. agentseek_api/a2a_server.py +879 -0
  3. agentseek_api/api/__init__.py +1 -0
  4. agentseek_api/api/assistants.py +251 -0
  5. agentseek_api/api/crons.py +88 -0
  6. agentseek_api/api/runs.py +299 -0
  7. agentseek_api/api/stateless_runs.py +80 -0
  8. agentseek_api/api/store.py +111 -0
  9. agentseek_api/api/threads.py +795 -0
  10. agentseek_api/cli.py +1017 -0
  11. agentseek_api/core/__init__.py +1 -0
  12. agentseek_api/core/a2a_config.py +23 -0
  13. agentseek_api/core/auth_deps.py +18 -0
  14. agentseek_api/core/auth_middleware.py +301 -0
  15. agentseek_api/core/config_file.py +30 -0
  16. agentseek_api/core/database.py +245 -0
  17. agentseek_api/core/mcp_config.py +23 -0
  18. agentseek_api/core/oceanbase_checkpointer.py +50 -0
  19. agentseek_api/core/orm.py +170 -0
  20. agentseek_api/core/runtime_store.py +326 -0
  21. agentseek_api/core/store_config.py +178 -0
  22. agentseek_api/main.py +249 -0
  23. agentseek_api/mcp_server.py +192 -0
  24. agentseek_api/models/__init__.py +1 -0
  25. agentseek_api/models/api.py +250 -0
  26. agentseek_api/models/auth.py +6 -0
  27. agentseek_api/models/protocol.py +18 -0
  28. agentseek_api/scheduler.py +97 -0
  29. agentseek_api/services/__init__.py +1 -0
  30. agentseek_api/services/cron_models.py +57 -0
  31. agentseek_api/services/cron_rrule.py +205 -0
  32. agentseek_api/services/cron_scheduler.py +406 -0
  33. agentseek_api/services/cron_service.py +183 -0
  34. agentseek_api/services/cron_webhooks.py +195 -0
  35. agentseek_api/services/executor.py +39 -0
  36. agentseek_api/services/langgraph_service.py +370 -0
  37. agentseek_api/services/redis_queue.py +121 -0
  38. agentseek_api/services/run_executor.py +854 -0
  39. agentseek_api/services/run_jobs.py +205 -0
  40. agentseek_api/services/run_preparation.py +397 -0
  41. agentseek_api/services/run_state.py +78 -0
  42. agentseek_api/services/sample_graphs.py +187 -0
  43. agentseek_api/services/stream_persistence.py +237 -0
  44. agentseek_api/services/thread_checkpoint_store.py +232 -0
  45. agentseek_api/services/thread_protocol.py +787 -0
  46. agentseek_api/services/thread_service.py +23 -0
  47. agentseek_api/settings.py +41 -0
  48. agentseek_api/worker.py +103 -0
  49. agentseek_api-0.0.1.dist-info/METADATA +646 -0
  50. agentseek_api-0.0.1.dist-info/RECORD +53 -0
  51. agentseek_api-0.0.1.dist-info/WHEEL +4 -0
  52. agentseek_api-0.0.1.dist-info/entry_points.txt +2 -0
  53. agentseek_api-0.0.1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,879 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import suppress
7
+ from dataclasses import dataclass, field
8
+ from threading import Lock
9
+ from typing import Any
10
+ from uuid import uuid4
11
+
12
+ from fastapi import HTTPException
13
+ from fastapi.responses import StreamingResponse
14
+ from langchain_core.messages import BaseMessage, HumanMessage
15
+ from langgraph.constants import CONF, CONFIG_KEY_CHECKPOINTER
16
+ from sqlalchemy import select
17
+
18
+ from agentseek_api import __version__
19
+ from agentseek_api.core.auth_middleware import get_config_auth_openapi
20
+ from agentseek_api.core.database import db_manager
21
+ from agentseek_api.core.orm import Assistant
22
+ from agentseek_api.core.runtime_store import UserScopedStore
23
+ from agentseek_api.models.api import AssistantRead
24
+ from agentseek_api.models.auth import User
25
+ from agentseek_api.services.langgraph_service import GraphEntry
26
+ from agentseek_api.settings import settings
27
+
28
+
29
@dataclass
class A2ATaskRecord:
    """Mutable in-memory state of one A2A task, as held by ``A2ATaskRegistry``."""

    # Identity of the task, the assistant it runs against, and the owning user.
    task_id: str
    assistant_id: str
    user_id: str
    # A2A context id; also used as the graph's thread id when the task runs.
    context_id: str
    # Lifecycle state string (e.g. "submitted", "working", "completed").
    state: str = field(default="submitted")
    # Optional human-readable status text (error text, "Task cancelled", ...).
    status_message: str = field(default="")
    # Artifact dicts produced by the task; a fresh list per record.
    artifacts: list[dict[str, Any]] = field(default_factory=list)
    # Set by a cancel request; polled by the running stream.
    cancellation_requested: bool = field(default=False)
39
+
40
+
41
class A2ATaskRegistry:
    """Lock-guarded, size-bounded in-memory map of task id -> ``A2ATaskRecord``.

    Records are kept in insertion order; re-saving a record moves it to the newest
    position. Once more than ``max_tasks`` records are held, the oldest record in a
    terminal state is evicted first; if none is terminal, the oldest record overall
    is evicted.
    """

    def __init__(self, *, max_tasks: int = 1000) -> None:
        self._max_tasks = max_tasks
        self._lock = Lock()
        self._tasks: dict[str, A2ATaskRecord] = {}

    def save(self, record: A2ATaskRecord) -> None:
        """Insert or refresh ``record`` and enforce the size bound."""
        with self._lock:
            # Drop any previous entry first so the record lands at the newest slot.
            self._tasks.pop(record.task_id, None)
            self._tasks[record.task_id] = record
            self._prune_locked()

    def get(self, task_id: str) -> A2ATaskRecord:
        """Return the record for ``task_id``.

        Raises:
            ValueError: no record with that id is held.
        """
        with self._lock:
            try:
                found = self._tasks[task_id]
            except KeyError as exc:
                raise ValueError(f"Unknown task: {task_id}") from exc
            return found

    def _prune_locked(self) -> None:
        """Evict records until the bound holds; caller must hold ``self._lock``."""
        while len(self._tasks) > self._max_tasks:
            oldest = next(iter(self._tasks), None)
            victim = next(
                (tid for tid, rec in self._tasks.items() if _is_terminal_state(rec.state)),
                oldest,
            )
            if victim is None:
                break
            del self._tasks[victim]
73
+
74
+
75
def is_a2a_compatible_entry(entry: GraphEntry) -> bool:
    """Return True when the graph's input schema is an object with a required
    ``messages`` property of JSON-schema type ``array``.
    """
    schema = entry.input_schema
    if schema.get("type") != "object":
        return False

    props, required = schema.get("properties"), schema.get("required")
    if not (isinstance(props, dict) and isinstance(required, list)):
        return False

    messages_schema = props.get("messages")
    return (
        isinstance(messages_schema, dict)
        and messages_schema.get("type") == "array"
        and "messages" in required
    )
90
+
91
+
92
def _openapi_agent_card_auth(auth_openapi: Any) -> dict[str, Any]:
    """Translate configured OpenAPI auth into agent-card auth metadata.

    Returns ``{}`` when the config is not a dict, lacks a ``securitySchemes`` dict
    and ``security`` list, or none of its schemes can be translated.
    """
    if not isinstance(auth_openapi, dict):
        return {}
    security_schemes = auth_openapi.get("securitySchemes")
    security = auth_openapi.get("security")
    if not isinstance(security_schemes, dict) or not isinstance(security, list):
        return {}

    schemes: dict[str, Any] = {}
    for scheme_name, scheme in security_schemes.items():
        if isinstance(scheme_name, str) and isinstance(scheme, dict):
            converted = _translate_openapi_security_scheme(scheme)
            if converted is not None:
                schemes[scheme_name] = converted
    if not schemes:
        return {}

    metadata: dict[str, Any] = {"securitySchemes": schemes}
    requirements = _translate_openapi_security_requirements(
        security,
        retained_scheme_names=set(schemes),
    )
    if requirements:
        metadata["securityRequirements"] = requirements
    return metadata


def _agent_card_auth_metadata() -> dict[str, Any]:
    """Return the ``securitySchemes``/``securityRequirements`` part of an agent card.

    Translatable OpenAPI auth config (from ``get_config_auth_openapi``) takes
    precedence. Otherwise a default is derived from ``settings.AUTH_TYPE``:
    ``api_key`` -> ``x-api-key`` header, ``jwt`` -> bearer JWT, else ``{}``.
    """
    from_config = _openapi_agent_card_auth(get_config_auth_openapi())
    if from_config:
        return from_config

    auth_type = settings.AUTH_TYPE.strip().lower()
    if auth_type == "api_key":
        scheme_name = "apiKeyAuth"
        scheme = {"apiKeySecurityScheme": {"location": "header", "name": "x-api-key"}}
    elif auth_type == "jwt":
        scheme_name = "bearerAuth"
        scheme = {"httpAuthSecurityScheme": {"scheme": "bearer", "bearerFormat": "JWT"}}
    else:
        return {}
    return {
        "securitySchemes": {scheme_name: scheme},
        "securityRequirements": [{scheme_name: []}],
    }
142
+
143
+
144
+ def _translate_openapi_security_scheme(scheme: dict[str, Any]) -> dict[str, Any] | None:
145
+ description = scheme.get("description")
146
+ description_value = description if isinstance(description, str) and description else None
147
+
148
+ if scheme.get("type") == "apiKey":
149
+ location = scheme.get("in")
150
+ name = scheme.get("name")
151
+ if isinstance(location, str) and isinstance(name, str):
152
+ translated: dict[str, Any] = {
153
+ "location": location,
154
+ "name": name,
155
+ }
156
+ if description_value is not None:
157
+ translated["description"] = description_value
158
+ return {"apiKeySecurityScheme": translated}
159
+
160
+ if scheme.get("type") == "http" and scheme.get("scheme") == "bearer":
161
+ translated: dict[str, Any] = {"scheme": "bearer"}
162
+ bearer_format = scheme.get("bearerFormat")
163
+ if isinstance(bearer_format, str) and bearer_format:
164
+ translated["bearerFormat"] = bearer_format
165
+ if description_value is not None:
166
+ translated["description"] = description_value
167
+ return {"httpAuthSecurityScheme": translated}
168
+
169
+ return None
170
+
171
+
172
+ def _translate_openapi_security_requirements(
173
+ security: list[Any],
174
+ *,
175
+ retained_scheme_names: set[str],
176
+ ) -> list[dict[str, list[Any]]]:
177
+ translated: list[dict[str, list[Any]]] = []
178
+ for item in security:
179
+ if not isinstance(item, dict):
180
+ continue
181
+ if not all(isinstance(name, str) and name in retained_scheme_names and isinstance(scopes, list) for name, scopes in item.items()):
182
+ continue
183
+ translated.append(item)
184
+ return translated
185
+
186
+
187
def build_agent_card(base_url: str, assistant: AssistantRead, entry: GraphEntry) -> dict[str, Any]:
    """Build the A2A agent card advertised for one assistant.

    The card exposes a single JSON-RPC interface at ``{base_url}/a2a/{assistant_id}``
    and one skill named after the graph's tool. Streaming is advertised and push
    notifications are not. The auth metadata from ``_agent_card_auth_metadata`` is
    merged in at the top level.
    """
    # Normalise the base URL so "https://host/" does not produce "https://host//a2a/...".
    url = f"{base_url.rstrip('/')}/a2a/{assistant.assistant_id}"
    description = assistant.description or ""
    skill_description = assistant.description or entry.description or f"Runs the {assistant.graph_id} graph."

    card: dict[str, Any] = {
        "name": assistant.name,
        "description": description,
        "supportedInterfaces": [
            {
                "url": url,
                "protocolBinding": "JSONRPC",
                "protocolVersion": "1.0",
            }
        ],
        "version": __version__,
        "capabilities": {"streaming": True, "pushNotifications": False},
        "defaultInputModes": ["text/plain"],
        "defaultOutputModes": ["text/plain"],
        "skills": [
            {
                "id": entry.tool_name,
                "name": assistant.name,
                "description": skill_description,
                "tags": [assistant.graph_id, entry.tool_name],
            }
        ],
    }
    card.update(_agent_card_auth_metadata())
    return card
217
+
218
+
219
def build_graph_config(*, user: User, context_id: str) -> tuple[UserScopedStore, dict[str, Any]]:
    """Create the per-user store and the LangGraph run config for an A2A call.

    Returns ``(store, config)``. ``config[CONF]`` carries the thread id (the A2A
    context id), a checkpoint namespace, the shared checkpointer, the user-scoped
    store and the serialised auth user.
    """
    user_store = UserScopedStore(db_manager.get_store(), user_id=user.identity)
    shared_checkpointer = db_manager.get_langgraph_checkpointer()
    # NOTE(review): every call gets a fresh "a2a:<uuid>" namespace, so checkpoints
    # from earlier calls with the same context_id may not be visible — confirm intent.
    configurable: dict[str, Any] = {
        "thread_id": context_id,
        "checkpoint_ns": f"a2a:{uuid4()}",
        CONFIG_KEY_CHECKPOINTER: shared_checkpointer,
        "store": user_store,
        "langgraph_auth_user": user.model_dump(),
    }
    return user_store, {CONF: configurable}
231
+
232
+
233
def make_text_artifact(text: str) -> dict[str, Any]:
    """Build a ``kind``-tagged (JSON-RPC dialect) text artifact with a new random id."""
    text_part = {"kind": "text", "text": text}
    return dict(artifactId=str(uuid4()), name="Assistant Response", parts=[text_part])
239
+
240
+
241
def make_sdk_text_artifact(text: str) -> dict[str, Any]:
    """Build an SDK-dialect text artifact (parts carry no ``kind``) with a new random id."""
    text_part = {"text": text}
    return dict(artifactId=str(uuid4()), name="Assistant Response", parts=[text_part])
247
+
248
+
249
+ def _jsonrpc_result(request_id: Any, result: dict[str, Any]) -> dict[str, Any]:
250
+ return {"jsonrpc": "2.0", "id": request_id, "result": result}
251
+
252
+
253
+ def _jsonrpc_error(request_id: Any, *, code: int, message: str) -> dict[str, Any]:
254
+ return {"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}
255
+
256
+
257
+ def _task_result(record: A2ATaskRecord, *, sdk_compatible: bool = False) -> dict[str, Any]:
258
+ status_state = _sdk_task_state(record.state) if sdk_compatible else record.state
259
+ status: dict[str, Any] = {"state": status_state}
260
+ if record.status_message and not sdk_compatible:
261
+ status["message"] = {"kind": "text", "text": record.status_message}
262
+ artifacts = (
263
+ [make_sdk_text_artifact(_artifact_text(artifact)) for artifact in record.artifacts]
264
+ if sdk_compatible
265
+ else record.artifacts
266
+ )
267
+ result = {
268
+ "id": record.task_id,
269
+ "contextId": record.context_id,
270
+ "status": status,
271
+ "artifacts": artifacts,
272
+ }
273
+ if not sdk_compatible:
274
+ result["kind"] = "task"
275
+ return result
276
+
277
+
278
+ def _is_terminal_state(state: str) -> bool:
279
+ return state in {"completed", "failed", "cancelled"}
280
+
281
+
282
+ def _sdk_task_state(state: str) -> str:
283
+ return {
284
+ "submitted": "TASK_STATE_SUBMITTED",
285
+ "working": "TASK_STATE_WORKING",
286
+ "completed": "TASK_STATE_COMPLETED",
287
+ "failed": "TASK_STATE_FAILED",
288
+ "cancelled": "TASK_STATE_CANCELED",
289
+ }.get(state, "TASK_STATE_UNSPECIFIED")
290
+
291
+
292
+ def _artifact_text(artifact: dict[str, Any]) -> str:
293
+ parts = artifact.get("parts")
294
+ if not isinstance(parts, list):
295
+ return ""
296
+ for part in parts:
297
+ if not isinstance(part, dict):
298
+ continue
299
+ text = part.get("text")
300
+ if isinstance(text, str):
301
+ return text
302
+ if part.get("kind") == "text" and isinstance(part.get("text"), str):
303
+ return part["text"]
304
+ return ""
305
+
306
+
307
+ def _message_content(message: BaseMessage) -> str:
308
+ content = getattr(message, "content", "")
309
+ if isinstance(content, str):
310
+ return content
311
+ return json.dumps(content, default=str)
312
+
313
+
314
def _extract_chunk_messages(chunk: Any) -> list[BaseMessage]:
    """Recursively collect every ``BaseMessage`` contained in a stream chunk.

    Walks dicts (values in order), lists and tuples. A ``messages`` *list* inside a
    dict contributes its direct message items first and is not descended into again.
    Any other value under ``messages`` (for example a single message, as a node
    returning ``{"messages": msg}`` produces) is recursed into like any other value.
    """
    if isinstance(chunk, BaseMessage):
        return [chunk]

    found: list[BaseMessage] = []
    if isinstance(chunk, dict):
        nested = chunk.get("messages")
        nested_list = nested if isinstance(nested, list) else None
        if nested_list is not None:
            found.extend(item for item in nested_list if isinstance(item, BaseMessage))
        for value in chunk.values():
            # Skip only the list already scanned above. Previously any non-list
            # ``messages`` value was skipped too, silently dropping single messages.
            if nested_list is not None and value is nested_list:
                continue
            found.extend(_extract_chunk_messages(value))
    elif isinstance(chunk, (list, tuple)):
        for item in chunk:
            found.extend(_extract_chunk_messages(item))
    return found
330
+
331
+
332
def _extract_stream_chunk_text(chunk: Any) -> str | None:
    """Return the displayable text carried by one graph stream chunk, or ``None``.

    Messages found anywhere in the chunk take precedence: their non-empty contents
    are joined with newlines. Otherwise a non-empty string ``text`` attribute, or a
    ``"text"`` key of a dict chunk, is used. Blank results are reported as ``None``.
    """
    messages = _extract_chunk_messages(chunk)
    if messages:
        # Compute each message's content once (it was previously computed twice).
        contents = [content for content in map(_message_content, messages) if content]
        combined = "\n".join(contents)
        # strip() only decides whether there is anything to show; the text itself is
        # returned untouched so token chunks such as " world" keep their spacing
        # when a client concatenates appended chunks.
        return combined if combined.strip() else None

    attr_text = getattr(chunk, "text", None)
    if isinstance(attr_text, str) and attr_text:
        return attr_text
    if isinstance(chunk, dict):
        raw_text = chunk.get("text")
        if isinstance(raw_text, str) and raw_text:
            return raw_text
    return None
346
+
347
+
348
+ def _is_root_stream_event(event: dict[str, Any]) -> bool:
349
+ parent_ids = event.get("parent_ids")
350
+ return isinstance(parent_ids, list) and not parent_ids
351
+
352
+
353
+ def _extract_request_text(message: dict[str, Any]) -> str:
354
+ parts = message.get("parts")
355
+ if not isinstance(parts, list) or not parts:
356
+ raise ValueError("message.parts must be a non-empty array.")
357
+
358
+ texts: list[str] = []
359
+ for part in parts:
360
+ if not isinstance(part, dict):
361
+ raise ValueError("Only text parts are supported.")
362
+ direct_text = part.get("text")
363
+ if isinstance(direct_text, str):
364
+ texts.append(direct_text)
365
+ continue
366
+ if part.get("kind") == "text" and isinstance(direct_text, str):
367
+ texts.append(direct_text)
368
+ continue
369
+ if isinstance(direct_text, dict) and isinstance(direct_text.get("text"), str):
370
+ texts.append(direct_text["text"])
371
+ continue
372
+ raise ValueError("Only text parts are supported.")
373
+ return "\n".join(texts)
374
+
375
+
376
+ def _normalize_optional_id(value: Any) -> str | None:
377
+ return value if isinstance(value, str) and value else None
378
+
379
+
380
+ def _extract_output_text(extracted: Any) -> str:
381
+ if isinstance(extracted, str):
382
+ return extracted
383
+ if isinstance(extracted, dict):
384
+ final_text = extracted.get("final_text")
385
+ if isinstance(final_text, str):
386
+ return final_text
387
+ text = extracted.get("text")
388
+ if isinstance(text, str):
389
+ return text
390
+ messages = extracted.get("messages")
391
+ if isinstance(messages, list):
392
+ message_texts = [_message_content(message) for message in messages if isinstance(message, BaseMessage)]
393
+ if message_texts:
394
+ return message_texts[-1]
395
+ return json.dumps(extracted, default=str)
396
+
397
+
398
async def _invoke_a2a_graph(
    *,
    entry: GraphEntry,
    user: User,
    context_id: str,
    text: str,
) -> dict[str, Any]:
    """Run the entry's graph once on a single human message and return its output.

    The graph is built with the checkpointer from ``build_graph_config`` and the
    user-scoped store. The result of ``entry.extract_output`` is returned as-is if
    it is a dict, otherwise wrapped as ``{"result": value}``.
    """
    user_store, run_config = build_graph_config(user=user, context_id=context_id)
    checkpointer = run_config.get(CONF, {}).get(CONFIG_KEY_CHECKPOINTER)
    graph = entry.build_graph(checkpointer=checkpointer, store=user_store)

    graph_payload = {"messages": [HumanMessage(content=text)]}
    graph_input = entry.prepare_input(graph_payload)

    if hasattr(graph, "ainvoke"):
        raw_output = await graph.ainvoke(graph_input, run_config)
    else:  # pragma: no cover - graphs exposing only a synchronous API
        raw_output = graph.invoke(graph_input, run_config)

    output = entry.extract_output(raw_output, graph_payload)
    return output if isinstance(output, dict) else {"result": output}
421
+
422
+
423
async def _invoke_a2a_graph_stream(
    *,
    entry: GraphEntry,
    user: User,
    context_id: str,
    text: str,
) -> AsyncIterator[dict[str, Any]]:
    """Run the entry's graph on one human message, yielding dict items as it progresses.

    With ``astream_events`` (v2): yields ``{"text": ...}`` for every chat-model/LLM/
    chain stream chunk that carries text, then the extracted output of the root
    ``on_chain_end`` event, then stops. Otherwise it falls back to ``astream`` (one
    item per step), then ``ainvoke``, then ``invoke`` (one item). Extracted outputs
    that are not dicts are wrapped as ``{"result": value}``.
    """
    user_store, run_config = build_graph_config(user=user, context_id=context_id)
    checkpointer = run_config.get(CONF, {}).get(CONFIG_KEY_CHECKPOINTER)
    graph = entry.build_graph(checkpointer=checkpointer, store=user_store)
    graph_payload = {"messages": [HumanMessage(content=text)]}
    graph_input = entry.prepare_input(graph_payload)

    def _as_item(raw_output: Any) -> dict[str, Any]:
        """Apply ``entry.extract_output`` and guarantee a dict result."""
        output = entry.extract_output(raw_output, graph_payload)
        return output if isinstance(output, dict) else {"result": output}

    if hasattr(graph, "astream_events"):
        text_events = {"on_chat_model_stream", "on_llm_stream", "on_chain_stream"}
        async for stream_event in graph.astream_events(graph_input, run_config, version="v2"):
            event_name = stream_event.get("event")
            if event_name in text_events:
                data = stream_event.get("data", {})
                chunk = data.get("chunk") if isinstance(data, dict) else None
                # Distinct name so the ``text`` argument is not shadowed.
                chunk_text = _extract_stream_chunk_text(chunk)
                if chunk_text:
                    yield {"text": chunk_text}
            if event_name == "on_chain_end" and _is_root_stream_event(stream_event):
                data = stream_event.get("data", {})
                final_output = data.get("output") if isinstance(data, dict) else None
                yield _as_item(final_output)
                return
        return

    if hasattr(graph, "astream"):
        async for step_output in graph.astream(graph_input, run_config):
            yield _as_item(step_output)
        return

    if hasattr(graph, "ainvoke"):
        raw_output = await graph.ainvoke(graph_input, run_config)
    else:  # pragma: no cover - graphs exposing only a synchronous API
        raw_output = graph.invoke(graph_input, run_config)
    yield _as_item(raw_output)
483
+
484
+
485
+ def _resolve_task_record(
486
+ *,
487
+ registry: A2ATaskRegistry,
488
+ assistant_id: str,
489
+ user: User,
490
+ context_id: str | None,
491
+ task_id: str,
492
+ ) -> A2ATaskRecord | None:
493
+ try:
494
+ existing = registry.get(task_id)
495
+ except ValueError:
496
+ existing = None
497
+
498
+ if existing is not None:
499
+ if existing.user_id != user.identity or existing.assistant_id != assistant_id:
500
+ return None
501
+ existing.context_id = context_id or existing.context_id
502
+ existing.state = "submitted"
503
+ existing.status_message = ""
504
+ existing.artifacts = []
505
+ existing.cancellation_requested = False
506
+ return existing
507
+
508
+ return A2ATaskRecord(
509
+ task_id=task_id,
510
+ assistant_id=assistant_id,
511
+ user_id=user.identity,
512
+ context_id=context_id or str(uuid4()),
513
+ )
514
+
515
+
516
def _cancel_task(record: A2ATaskRecord) -> A2ATaskRecord:
    """Mark a non-terminal task as cancelled in place and return it.

    Terminal tasks are returned untouched. The ``cancellation_requested`` flag is
    what a running stream polls in order to stop.
    """
    if _is_terminal_state(record.state):
        return record
    record.cancellation_requested = True
    record.state = "cancelled"
    record.status_message = "Task cancelled"
    return record
522
+
523
+
524
+ def _task_status_update_event(
525
+ record: A2ATaskRecord,
526
+ *,
527
+ final: bool,
528
+ sdk_compatible: bool = False,
529
+ ) -> dict[str, Any]:
530
+ status_state = _sdk_task_state(record.state) if sdk_compatible else record.state
531
+ status: dict[str, Any] = {"state": status_state}
532
+ if record.status_message and not sdk_compatible:
533
+ status["message"] = {"kind": "text", "text": record.status_message}
534
+ event = {
535
+ "taskId": record.task_id,
536
+ "contextId": record.context_id,
537
+ "status": status,
538
+ }
539
+ if sdk_compatible:
540
+ return {"statusUpdate": event}
541
+ event["final"] = final
542
+ event["kind"] = "status-update"
543
+ return event
544
+
545
+
546
+ def _task_artifact_update_event(
547
+ record: A2ATaskRecord,
548
+ *,
549
+ artifact: dict[str, Any],
550
+ append: bool,
551
+ last_chunk: bool,
552
+ sdk_compatible: bool = False,
553
+ ) -> dict[str, Any]:
554
+ artifact_payload = make_sdk_text_artifact(_artifact_text(artifact)) if sdk_compatible else artifact
555
+ event = {
556
+ "taskId": record.task_id,
557
+ "contextId": record.context_id,
558
+ "artifact": artifact_payload,
559
+ "append": append,
560
+ "lastChunk": last_chunk,
561
+ }
562
+ if sdk_compatible:
563
+ return {"artifactUpdate": event}
564
+ event["kind"] = "artifact-update"
565
+ return event
566
+
567
+
568
def _sdk_send_message_result(record: A2ATaskRecord) -> dict[str, Any]:
    """Build the SDK-dialect ``SendMessage`` result: ``{"task": <SDK task>}``."""
    sdk_task = _task_result(record, sdk_compatible=True)
    return {"task": sdk_task}
570
+
571
+
572
+ def _canonical_a2a_method(method: str) -> tuple[str, bool]:
573
+ mapping = {
574
+ "message/send": ("message/send", False),
575
+ "message/stream": ("message/stream", False),
576
+ "tasks/get": ("tasks/get", False),
577
+ "tasks/cancel": ("tasks/cancel", False),
578
+ "SendMessage": ("message/send", True),
579
+ "SendStreamingMessage": ("message/stream", True),
580
+ "GetTask": ("tasks/get", True),
581
+ "CancelTask": ("tasks/cancel", True),
582
+ }
583
+ return mapping.get(method, (method, False))
584
+
585
+
586
def _sse_jsonrpc_event(*, request_id: Any, result: dict[str, Any]) -> str:
    """Frame a JSON-RPC success envelope as one Server-Sent Events ``message`` event."""
    encoded = json.dumps(_jsonrpc_result(request_id, result))
    return "event: message\n" + f"data: {encoded}\n\n"
588
+
589
+
590
def _find_owned_task(
    registry: A2ATaskRegistry,
    *,
    task_id: str,
    assistant_id: str,
    user: User,
) -> A2ATaskRecord | None:
    """Return the record for ``task_id`` if it exists and belongs to this user and assistant.

    A missing task and a task owned by someone else both yield ``None``, so callers
    report both with the same "Unknown task" error.
    """
    try:
        record = registry.get(task_id)
    except ValueError:
        return None
    if record.assistant_id != assistant_id or record.user_id != user.identity:
        return None
    return record


async def _stream_a2a_events(
    *,
    request_id: Any,
    record: A2ATaskRecord,
    registry: A2ATaskRegistry,
    entry: GraphEntry,
    user: User,
    text: str,
    sdk_compatible: bool,
) -> AsyncIterator[str]:
    """Run the graph for a ``message/stream`` request and yield SSE-framed JSON-RPC events.

    Sequence: one non-final ``working`` status update, zero or more artifact-update
    chunks, then exactly one final status update (completed, failed or cancelled).
    Text chunks are held back by one so the last can be flagged ``lastChunk``.
    ``record.cancellation_requested`` (set by a concurrent ``tasks/cancel``) is
    polled every 50 ms while waiting for graph output.
    """
    # All chunks of this response belong to ONE artifact: ``append=True`` means
    # "append to the previously sent artifact with the same id", so minting a new
    # id per chunk (as make_text_artifact does) would make chunks unjoinable.
    artifact_id = str(uuid4())
    streamed_parts: list[str] = []

    def _artifact(artifact_text: str) -> dict[str, Any]:
        """Text artifact carrying this stream's shared artifact id."""
        artifact = make_text_artifact(artifact_text)
        artifact["artifactId"] = artifact_id
        return artifact

    def _chunk_event(chunk_text: str, *, last_chunk: bool) -> str:
        """SSE event for one artifact chunk; records the chunk as sent."""
        result = _task_artifact_update_event(
            record,
            artifact=_artifact(chunk_text),
            # The first chunk opens the artifact; later chunks extend it.
            append=bool(streamed_parts),
            last_chunk=last_chunk,
            sdk_compatible=sdk_compatible,
        )
        # The SDK dialect rebuilds the artifact payload with a fresh id; pin it back.
        update = result.get("artifactUpdate", result)
        update["artifact"]["artifactId"] = artifact_id
        streamed_parts.append(chunk_text)
        return _sse_jsonrpc_event(request_id=request_id, result=result)

    def _final_status_event() -> str:
        """SSE event carrying the record's current status as the final update."""
        return _sse_jsonrpc_event(
            request_id=request_id,
            result=_task_status_update_event(record, final=True, sdk_compatible=sdk_compatible),
        )

    def _cancellation_events(pending_text: str | None) -> list[str]:
        """Flush the held-back chunk, persist the partial artifact and emit final status."""
        events: list[str] = []
        if pending_text is not None:
            events.append(_chunk_event(pending_text, last_chunk=True))
            # Persist everything the client received, not just the last chunk.
            record.artifacts = [_artifact("".join(streamed_parts))]
        else:
            record.artifacts = []
        registry.save(record)
        events.append(_final_status_event())
        return events

    record.state = "working"
    registry.save(record)
    yield _sse_jsonrpc_event(
        request_id=request_id,
        result=_task_status_update_event(record, final=False, sdk_compatible=sdk_compatible),
    )

    final_extracted: dict[str, Any] | None = None
    pending_text: str | None = None
    stream_queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue()

    async def _produce_stream() -> None:
        """Pump graph output into ``stream_queue`` so the consumer can poll for cancellation."""
        try:
            async for extracted in _invoke_a2a_graph_stream(
                entry=entry,
                user=user,
                context_id=record.context_id,
                text=text,
            ):
                await stream_queue.put(("item", extracted))
        except asyncio.CancelledError:
            return
        except Exception as exc:  # noqa: BLE001
            await stream_queue.put(("error", exc))
            return
        await stream_queue.put(("done", None))

    producer_task = asyncio.create_task(_produce_stream())
    try:
        while True:
            if record.cancellation_requested:
                for event in _cancellation_events(pending_text):
                    yield event
                return
            try:
                event_kind, event_payload = await asyncio.wait_for(stream_queue.get(), timeout=0.05)
            except asyncio.TimeoutError:
                continue
            if event_kind == "error":
                raise event_payload
            if event_kind == "done":
                break
            final_extracted = event_payload
            chunk_text = _extract_stream_chunk_text(event_payload)
            if chunk_text is None:
                continue
            if pending_text is not None:
                yield _chunk_event(pending_text, last_chunk=False)
            pending_text = chunk_text
    except Exception as exc:  # noqa: BLE001
        record.state = "failed"
        record.status_message = str(exc)
        registry.save(record)
        yield _final_status_event()
        return
    finally:
        if not producer_task.done():
            producer_task.cancel()
        with suppress(asyncio.CancelledError):
            await producer_task

    if record.cancellation_requested:
        for event in _cancellation_events(pending_text):
            yield event
        return

    # ``None`` means the graph produced no items at all: report empty text instead
    # of the JSON encoding of an empty dict ("{}").
    final_text = _extract_output_text(final_extracted) if final_extracted is not None else ""
    terminal_chunk_text = pending_text or final_text
    if terminal_chunk_text:
        yield _chunk_event(terminal_chunk_text, last_chunk=True)
    record.state = "completed"
    record.artifacts = [_artifact(final_text)]
    registry.save(record)
    yield _final_status_event()


async def handle_a2a_request(
    *,
    assistant_id: str,
    payload: dict[str, Any],
    user: User,
    service,
    registry: A2ATaskRegistry,
) -> dict[str, Any] | StreamingResponse:
    """Dispatch one A2A JSON-RPC request addressed to ``assistant_id``.

    Accepts the JSON-RPC method names (``message/send``, ``message/stream``,
    ``tasks/get``, ``tasks/cancel``) and their SDK-style aliases (``SendMessage``,
    ...). The aliases switch payload shapes and some error codes to the SDK dialect.

    Returns a JSON-RPC response dict, or a ``StreamingResponse`` of SSE events for
    ``message/stream``. Protocol, lookup and graph errors are returned as JSON-RPC
    error objects rather than raised.
    """
    request_id = payload.get("id")
    if payload.get("jsonrpc") != "2.0" or not isinstance(payload.get("method"), str):
        return _jsonrpc_error(request_id, code=-32600, message="Invalid JSON-RPC request.")

    method, sdk_compatible = _canonical_a2a_method(payload["method"])
    # "Not found" (assistant or task) uses a different code in each dialect.
    not_found_code = -32001 if sdk_compatible else -32004

    try:
        assistant = await load_assistant(assistant_id)
    except HTTPException as exc:
        return _jsonrpc_error(request_id, code=not_found_code, message=str(exc.detail))

    entry = service.get_entry(assistant.graph_id)
    if not is_a2a_compatible_entry(entry):
        return _jsonrpc_error(
            request_id,
            code=-32004 if sdk_compatible else -32000,
            message="Assistant graph is not A2A-compatible.",
        )

    params = payload.get("params")
    if not isinstance(params, dict):
        return _jsonrpc_error(request_id, code=-32602, message="params must be an object.")

    if method in {"tasks/get", "tasks/cancel"}:
        task_id = params.get("id")
        if not isinstance(task_id, str) or not task_id:
            return _jsonrpc_error(request_id, code=-32602, message=f"{method} requires params.id.")
        record = _find_owned_task(registry, task_id=task_id, assistant_id=assistant_id, user=user)
        if record is None:
            return _jsonrpc_error(request_id, code=not_found_code, message=f"Unknown task: {task_id}")
        if method == "tasks/cancel":
            registry.save(_cancel_task(record))
        return _jsonrpc_result(request_id, _task_result(record, sdk_compatible=sdk_compatible))

    if method not in {"message/send", "message/stream"}:
        return _jsonrpc_error(request_id, code=-32601, message=f"Unsupported method: {method}")

    message = params.get("message")
    if not isinstance(message, dict):
        return _jsonrpc_error(request_id, code=-32602, message=f"{method} requires params.message.")

    try:
        text = _extract_request_text(message)
    except ValueError as exc:
        return _jsonrpc_error(request_id, code=-32602, message=str(exc))

    context_id = _normalize_optional_id(message.get("contextId")) or _normalize_optional_id(params.get("contextId"))
    task_id = _normalize_optional_id(message.get("taskId")) or _normalize_optional_id(params.get("taskId")) or str(uuid4())
    record = _resolve_task_record(
        registry=registry,
        assistant_id=assistant_id,
        user=user,
        context_id=context_id,
        task_id=task_id,
    )
    if record is None:
        return _jsonrpc_error(request_id, code=not_found_code, message=f"Unknown task: {task_id}")
    registry.save(record)

    if method == "message/stream":
        return StreamingResponse(
            _stream_a2a_events(
                request_id=request_id,
                record=record,
                registry=registry,
                entry=entry,
                user=user,
                text=text,
                sdk_compatible=sdk_compatible,
            ),
            media_type="text/event-stream",
        )

    try:
        record.state = "working"
        registry.save(record)
        extracted = await _invoke_a2a_graph(
            entry=entry,
            user=user,
            context_id=record.context_id,
            text=text,
        )
    except Exception as exc:  # noqa: BLE001
        record.state = "failed"
        record.status_message = str(exc)
        registry.save(record)
        return _jsonrpc_error(
            request_id,
            code=-32603 if sdk_compatible else -32000,
            message=str(exc),
        )

    # A concurrent tasks/cancel wins: keep the cancelled state and its empty artifacts.
    if not record.cancellation_requested:
        record.state = "completed"
        record.artifacts = [make_text_artifact(_extract_output_text(extracted))]
    registry.save(record)
    return _jsonrpc_result(
        request_id,
        _sdk_send_message_result(record) if sdk_compatible else _task_result(record),
    )
860
+
861
+
862
async def load_assistant(assistant_id: str) -> AssistantRead:
    """Load one assistant row by id and convert it to the ``AssistantRead`` model.

    Raises:
        HTTPException: 404 "Assistant not found" when no row has that id.
    """
    open_session = db_manager.get_session_factory()
    async with open_session() as session:
        query = select(Assistant).where(Assistant.assistant_id == assistant_id)
        row = await session.scalar(query)
        if row is None:
            raise HTTPException(status_code=404, detail="Assistant not found")
        # ORM *_json columns map onto the API model's plain field names.
        return AssistantRead(
            assistant_id=row.assistant_id,
            name=row.name,
            graph_id=row.graph_id,
            created_at=row.created_at,
            updated_at=row.updated_at,
            metadata=row.metadata_json,
            config=row.config_json,
            context=row.context_json,
            version=row.version,
            description=row.description,
        )