agentstack-sdk 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. agentstack_sdk/__init__.py +6 -0
  2. agentstack_sdk/a2a/__init__.py +2 -0
  3. agentstack_sdk/a2a/extensions/__init__.py +8 -0
  4. agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
  5. agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
  6. agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
  7. agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
  8. agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
  9. agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
  10. agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
  11. agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
  12. agentstack_sdk/a2a/extensions/base.py +205 -0
  13. agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
  14. agentstack_sdk/a2a/extensions/common/form.py +149 -0
  15. agentstack_sdk/a2a/extensions/exceptions.py +11 -0
  16. agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
  17. agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
  18. agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
  19. agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
  20. agentstack_sdk/a2a/extensions/services/form.py +54 -0
  21. agentstack_sdk/a2a/extensions/services/llm.py +100 -0
  22. agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
  23. agentstack_sdk/a2a/extensions/services/platform.py +141 -0
  24. agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
  25. agentstack_sdk/a2a/extensions/tools/call.py +114 -0
  26. agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
  27. agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
  28. agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
  29. agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
  30. agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
  31. agentstack_sdk/a2a/extensions/ui/error.py +223 -0
  32. agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
  33. agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
  34. agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
  35. agentstack_sdk/a2a/types.py +104 -0
  36. agentstack_sdk/platform/__init__.py +12 -0
  37. agentstack_sdk/platform/client.py +123 -0
  38. agentstack_sdk/platform/common.py +37 -0
  39. agentstack_sdk/platform/configuration.py +47 -0
  40. agentstack_sdk/platform/context.py +291 -0
  41. agentstack_sdk/platform/file.py +295 -0
  42. agentstack_sdk/platform/model_provider.py +131 -0
  43. agentstack_sdk/platform/provider.py +219 -0
  44. agentstack_sdk/platform/provider_build.py +190 -0
  45. agentstack_sdk/platform/types.py +45 -0
  46. agentstack_sdk/platform/user.py +70 -0
  47. agentstack_sdk/platform/user_feedback.py +42 -0
  48. agentstack_sdk/platform/variables.py +44 -0
  49. agentstack_sdk/platform/vector_store.py +217 -0
  50. agentstack_sdk/py.typed +0 -0
  51. agentstack_sdk/server/__init__.py +4 -0
  52. agentstack_sdk/server/agent.py +594 -0
  53. agentstack_sdk/server/app.py +87 -0
  54. agentstack_sdk/server/constants.py +9 -0
  55. agentstack_sdk/server/context.py +68 -0
  56. agentstack_sdk/server/dependencies.py +117 -0
  57. agentstack_sdk/server/exceptions.py +3 -0
  58. agentstack_sdk/server/middleware/__init__.py +3 -0
  59. agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
  60. agentstack_sdk/server/server.py +376 -0
  61. agentstack_sdk/server/store/__init__.py +3 -0
  62. agentstack_sdk/server/store/context_store.py +35 -0
  63. agentstack_sdk/server/store/memory_context_store.py +59 -0
  64. agentstack_sdk/server/store/platform_context_store.py +58 -0
  65. agentstack_sdk/server/telemetry.py +53 -0
  66. agentstack_sdk/server/utils.py +26 -0
  67. agentstack_sdk/types.py +15 -0
  68. agentstack_sdk/util/__init__.py +4 -0
  69. agentstack_sdk/util/file.py +260 -0
  70. agentstack_sdk/util/httpx.py +18 -0
  71. agentstack_sdk/util/logging.py +63 -0
  72. agentstack_sdk/util/resource_context.py +44 -0
  73. agentstack_sdk/util/utils.py +47 -0
  74. agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
  75. agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
  76. agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
@@ -0,0 +1,594 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import asyncio
5
+ import inspect
6
+ import typing
7
+ from asyncio import CancelledError
8
+ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
9
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, suppress
10
+ from datetime import datetime, timedelta
11
+ from typing import Any, NamedTuple, TypeAlias, TypeVar, cast
12
+
13
+ import janus
14
+ from a2a.server.agent_execution import AgentExecutor, RequestContext
15
+ from a2a.server.events import EventQueue, QueueManager
16
+ from a2a.server.tasks import TaskManager, TaskStore, TaskUpdater
17
+ from a2a.types import (
18
+ AgentCapabilities,
19
+ AgentCard,
20
+ AgentInterface,
21
+ AgentProvider,
22
+ AgentSkill,
23
+ Artifact,
24
+ DataPart,
25
+ FilePart,
26
+ FileWithBytes,
27
+ FileWithUri,
28
+ Message,
29
+ Part,
30
+ SecurityScheme,
31
+ TaskArtifactUpdateEvent,
32
+ TaskState,
33
+ TaskStatus,
34
+ TaskStatusUpdateEvent,
35
+ TextPart,
36
+ )
37
+ from typing_extensions import override
38
+
39
+ from agentstack_sdk.a2a.extensions.ui.agent_detail import AgentDetail, AgentDetailExtensionSpec
40
+ from agentstack_sdk.a2a.extensions.ui.error import (
41
+ ErrorExtensionParams,
42
+ ErrorExtensionServer,
43
+ ErrorExtensionSpec,
44
+ get_error_extension_context,
45
+ )
46
+ from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
47
+ from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
48
+ from agentstack_sdk.server.context import RunContext
49
+ from agentstack_sdk.server.dependencies import Dependency, Depends, extract_dependencies
50
+ from agentstack_sdk.server.store.context_store import ContextStore
51
+ from agentstack_sdk.server.utils import cancel_task
52
+ from agentstack_sdk.util.logging import logger
53
+
54
# Zero-argument callable returning an async generator that yields RunYield values and is sent
# RunYieldResume values back.
AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]]
# Callable that, given the request context and the context store, returns an async context manager
# producing an AgentFunction.
AgentFunctionFactory: TypeAlias = Callable[[RequestContext, ContextStore], AbstractAsyncContextManager[AgentFunction]]

# Type of the user function wrapped by the ``agent`` decorator.
OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any])  # pyright: ignore[reportExplicitAny]
58
+
59
+
60
class AgentExecuteFn(typing.Protocol):
    """Structural type of the execution callable stored on ``Agent.execute_fn``."""

    async def __call__(self, _ctx: RunContext, **kwargs: Any) -> None:
        """Drive the wrapped agent function against ``_ctx``; resolved dependencies arrive as keyword arguments."""
62
+
63
+
64
class Agent(NamedTuple):
    """Bundle produced by an ``agent``-decorated factory: the card, the dependency declarations and the runner."""

    # A2A agent card advertised for this agent.
    card: AgentCard
    # Dependency declarations of the wrapped function, keyed by name.
    dependencies: dict[str, Depends]
    # Coroutine function that runs the wrapped function for a single run.
    execute_fn: AgentExecuteFn
68
+
69
+
70
# Factory returned by ``agent``: it takes a callback that may modify the extracted dependency mapping
# in place, then assembles and returns the Agent (card, dependencies, execute_fn).
AgentFactory: TypeAlias = Callable[[Callable[[dict[str, Depends]], None]], Agent]
71
+
72
+
73
def agent(
    name: str | None = None,
    description: str | None = None,
    *,
    url: str = "http://invalid",  # Default will be replaced by the server
    additional_interfaces: list[AgentInterface] | None = None,
    capabilities: AgentCapabilities | None = None,
    default_input_modes: list[str] | None = None,
    default_output_modes: list[str] | None = None,
    detail: AgentDetail | None = None,
    documentation_url: str | None = None,
    icon_url: str | None = None,
    preferred_transport: str | None = None,
    provider: AgentProvider | None = None,
    security: list[dict[str, list[str]]] | None = None,
    security_schemes: dict[str, SecurityScheme] | None = None,
    skills: list[AgentSkill] | None = None,
    supports_authenticated_extended_card: bool | None = None,
    version: str | None = None,
) -> Callable[[OriginalFnType], AgentFactory]:
    """
    Create an Agent function.

    The decorated function may be an async generator function, a coroutine function, a generator function or a
    plain function. Generator and plain functions are run in a worker thread via ``asyncio.to_thread``. Every
    value the function yields (or returns) is handed to the run context. Exceptions it raises are handed to the
    run context as values instead of propagating. The run context is always shut down when the function ends.

    Agent Stack SDK extensions are not passed explicitly. They are collected from the function's dependencies,
    and their agent card extensions are advertised in ``capabilities.extensions``. A default error extension is
    added when no ``ErrorExtensionServer`` dependency is declared.

    :param name: A human-readable name for the agent (inferred from the function name if not provided).
    :param description: A human-readable description of the agent, assisting users and other agents in understanding
        its purpose (inferred from the function docstring if not provided).
    :param url: The URL advertised in the agent card. The placeholder default is meant to be replaced by the server.
    :param additional_interfaces: A list of additional supported interfaces (transport and URL combinations).
        A client can use any of these to communicate with the agent.
    :param capabilities: A declaration of optional capabilities supported by the agent. It is copied, never
        mutated. Defaults to ``AgentCapabilities(streaming=True)``.
    :param default_input_modes: Default set of supported input MIME types for all skills, which can be overridden on
        a per-skill basis. Defaults to ``["text"]``.
    :param default_output_modes: Default set of supported output MIME types for all skills, which can be overridden on
        a per-skill basis. Defaults to ``["text"]``.
    :param detail: Agent Stack SDK details extending the agent metadata.
    :param documentation_url: An optional URL to the agent's documentation.
    :param icon_url: An optional URL to an icon for the agent.
    :param preferred_transport: The transport protocol for the preferred endpoint. Defaults to 'JSONRPC' if not
        specified.
    :param provider: Information about the agent's service provider.
    :param security: A list of security requirement objects that apply to all agent interactions. Each object lists
        security schemes that can be used. Follows the OpenAPI 3.0 Security Requirement Object.
    :param security_schemes: A declaration of the security schemes available to authorize requests. The key is the
        scheme name. Follows the OpenAPI 3.0 Security Scheme Object.
    :param skills: The set of skills, or distinct capabilities, that the agent can perform.
    :param supports_authenticated_extended_card: If true, the agent can provide an extended agent card with additional
        details to authenticated users. Defaults to false.
    :param version: The agent's own version number. The format is defined by the provider. Defaults to "1.0.0".
    :return: A decorator that turns the function into an ``AgentFactory``.
    """

    base_capabilities = capabilities.model_copy(deep=True) if capabilities else AgentCapabilities(streaming=True)
    detail = detail or AgentDetail()  # pyright: ignore [reportCallIssue]

    def decorator(fn: OriginalFnType) -> AgentFactory:
        def agent_factory(modify_dependencies: Callable[[dict[str, Depends]], None]) -> Agent:
            signature = inspect.signature(fn)
            dependencies = extract_dependencies(signature)
            # Give the caller a chance to add/replace dependencies before the card is built.
            modify_dependencies(dependencies)

            sdk_extensions = [dep.extension for dep in dependencies.values() if dep.extension is not None]

            resolved_name = name or fn.__name__
            resolved_description = description or fn.__doc__ or ""

            # Check if user has provided an ErrorExtensionServer, if not add default
            has_error_extension = any(isinstance(ext, ErrorExtensionServer) for ext in sdk_extensions)
            error_extension_spec = ErrorExtensionSpec(ErrorExtensionParams()) if not has_error_extension else None

            # Work on a fresh copy per factory call. Mutating the shared object would re-append the
            # extension cards on every invocation and make all produced cards share one capabilities object.
            resolved_capabilities = base_capabilities.model_copy(deep=True)
            resolved_capabilities.extensions = [
                *(resolved_capabilities.extensions or []),
                *(AgentDetailExtensionSpec(detail).to_agent_card_extensions()),
                *(error_extension_spec.to_agent_card_extensions() if error_extension_spec else []),
                *(e_card for ext in sdk_extensions for e_card in ext.spec.to_agent_card_extensions()),
            ]

            card = AgentCard(
                url=url,
                preferred_transport=preferred_transport,
                additional_interfaces=additional_interfaces,
                capabilities=resolved_capabilities,
                default_input_modes=default_input_modes or ["text"],
                default_output_modes=default_output_modes or ["text"],
                description=resolved_description,
                documentation_url=documentation_url,
                icon_url=icon_url,
                name=resolved_name,
                provider=provider,
                security=security,
                security_schemes=security_schemes,
                skills=skills or [],
                supports_authenticated_extended_card=supports_authenticated_extended_card,
                version=version or "1.0.0",
            )

            if inspect.isasyncgenfunction(fn):

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs)
                        value: RunYieldResume = None
                        # Each yielded value goes to the run context; its response is sent back into the generator.
                        while True:
                            value = await _ctx.yield_async(await gen.asend(value))
                    except StopAsyncIteration:
                        pass
                    except Exception as e:
                        await _ctx.yield_async(e)
                    finally:
                        _ctx.shutdown()

            elif inspect.iscoroutinefunction(fn):

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        await _ctx.yield_async(await fn(*args, **kwargs))
                    except Exception as e:
                        await _ctx.yield_async(e)
                    finally:
                        _ctx.shutdown()

            elif inspect.isgeneratorfunction(fn):

                def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs)
                        value = None
                        while True:
                            value = _ctx.yield_sync(gen.send(value))
                    except StopIteration:
                        pass
                    except Exception as e:
                        _ctx.yield_sync(e)
                    finally:
                        _ctx.shutdown()

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    # Sync generators are driven in a worker thread so they do not block the event loop.
                    await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)

            else:

                def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        _ctx.yield_sync(fn(*args, **kwargs))
                    except Exception as e:
                        _ctx.yield_sync(e)
                    finally:
                        _ctx.shutdown()

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    # Plain functions are run in a worker thread so they do not block the event loop.
                    await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)

            return Agent(card=card, dependencies=dependencies, execute_fn=execute_fn)

        return agent_factory

    return decorator
228
+
229
+
230
+ class AgentRun:
231
+ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None:
232
+ self._agent: Agent = agent
233
+ self._task: asyncio.Task[None] | None = None
234
+ self.last_invocation: datetime = datetime.now()
235
+ self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
236
+ self._run_context: RunContext | None = None
237
+ self._request_context: RequestContext | None = None
238
+ self._task_updater: TaskUpdater | None = None
239
+ self._context_store: ContextStore = context_store
240
+ self._lock: asyncio.Lock = asyncio.Lock()
241
+ self._on_finish: Callable[[], None] | None = on_finish
242
+ self._working: bool = False
243
+
244
+ @property
245
+ def run_context(self) -> RunContext:
246
+ if not self._run_context:
247
+ raise RuntimeError("Accessing run context for run that has not been started")
248
+ return self._run_context
249
+
250
+ @property
251
+ def request_context(self) -> RequestContext:
252
+ if not self._request_context:
253
+ raise RuntimeError("Accessing request context for run that has not been started")
254
+ return self._request_context
255
+
256
+ @property
257
+ def task_updater(self) -> TaskUpdater:
258
+ if not self._task_updater:
259
+ raise RuntimeError("Accessing task updater for run that has not been started")
260
+ return self._task_updater
261
+
262
+ @property
263
+ def done(self) -> bool:
264
+ return self._task is not None and self._task.done()
265
+
266
+ def _handle_finish(self) -> None:
267
+ if self._on_finish:
268
+ self._on_finish()
269
+
270
+ async def start(self, request_context: RequestContext, event_queue: EventQueue):
271
+ async with self._lock:
272
+ if self._working or self.done:
273
+ raise RuntimeError("Attempting to start a run that is already executing or done")
274
+ task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
275
+ assert task_id and context_id and message
276
+ self._run_context = RunContext(
277
+ configuration=request_context.configuration,
278
+ context_id=context_id,
279
+ task_id=task_id,
280
+ current_task=request_context.current_task,
281
+ related_tasks=request_context.related_tasks,
282
+ )
283
+ self._request_context = request_context
284
+ self._task_updater = TaskUpdater(event_queue, task_id, context_id)
285
+ if not request_context.current_task:
286
+ await self._task_updater.submit()
287
+ await self._task_updater.start_work()
288
+ self._working = True
289
+ self._task = asyncio.create_task(self._run_agent_function(initial_message=message))
290
+
291
+ async def resume(self, request_context: RequestContext, event_queue: EventQueue):
292
+ # These are incorrectly typed in a2a
293
+ async with self._lock:
294
+ if self._working or self.done:
295
+ raise RuntimeError("Attempting to resume a run that is already executing or done")
296
+ task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
297
+ assert task_id and context_id and message
298
+ self._request_context = request_context
299
+ self._task_updater = TaskUpdater(event_queue, task_id, context_id)
300
+
301
+ for dependency in self._agent.dependencies.values():
302
+ if dependency.extension:
303
+ dependency.extension.handle_incoming_message(message, self.run_context, request_context)
304
+
305
+ self._working = True
306
+ await self.resume_queue.put(message)
307
+
308
+ async def cancel(self, request_context: RequestContext, event_queue: EventQueue):
309
+ if not self._task:
310
+ raise RuntimeError("Cannot cancel run that has not been started")
311
+
312
+ async with self._lock:
313
+ try:
314
+ assert request_context.task_id
315
+ assert request_context.context_id
316
+ self._task_updater = TaskUpdater(event_queue, request_context.task_id, request_context.context_id)
317
+ await self._task_updater.cancel()
318
+ finally:
319
+ await cancel_task(self._task)
320
+
321
+ @asynccontextmanager
322
+ async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Dependency]]:
323
+ async with AsyncExitStack() as stack:
324
+ dependency_args: dict[str, Dependency] = {}
325
+ initialize_deps_exceptions: list[Exception] = []
326
+ for pname, depends in self._agent.dependencies.items():
327
+ # call dependencies with the first message and initialize their lifespan
328
+ try:
329
+ dependency_args[pname] = await stack.enter_async_context(
330
+ depends(message, self.run_context, self.request_context, dependency_args)
331
+ )
332
+ except Exception as e:
333
+ initialize_deps_exceptions.append(e)
334
+
335
+ if initialize_deps_exceptions:
336
+ raise (
337
+ ExceptionGroup("Failed to initialize dependencies", initialize_deps_exceptions)
338
+ if len(initialize_deps_exceptions) > 1
339
+ else initialize_deps_exceptions[0]
340
+ )
341
+
342
+ self.run_context._store = await self._context_store.create( # pyright: ignore[reportPrivateUsage]
343
+ context_id=self.run_context.context_id,
344
+ initialized_dependencies=list(dependency_args.values()),
345
+ )
346
+
347
+ yield {k: v for k, v in dependency_args.items() if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)}
348
+
349
+ def _with_context(self, message: Message | None = None) -> Message | None:
350
+ if message is None:
351
+ return None
352
+ # Note: This check would require extra handling in agents just forwarding messages from other agents
353
+ # Instead, we just silently replace it.
354
+ # if message.task_id and message.task_id != task_updater.task_id:
355
+ # raise ValueError("Message must have the same task_id as the task")
356
+ # if message.context_id and message.context_id != task_updater.context_id:
357
+ # raise ValueError("Message must have the same context_id as the task")
358
+ return message.model_copy(
359
+ deep=True, update={"context_id": self.task_updater.context_id, "task_id": self.task_updater.task_id}
360
+ )
361
+
362
+ async def _run_agent_function(self, initial_message: Message) -> None:
363
+ yield_queue = self.run_context._yield_queue # pyright: ignore[reportPrivateUsage]
364
+ yield_resume_queue = self.run_context._yield_resume_queue # pyright: ignore[reportPrivateUsage]
365
+
366
+ try:
367
+ async with self._dependencies_lifespan(initial_message) as dependency_args:
368
+ task = asyncio.create_task(self._agent.execute_fn(self.run_context, **dependency_args))
369
+ try:
370
+ resume_value: RunYieldResume = None
371
+ opened_artifacts: set[str] = set()
372
+ while not task.done() or yield_queue.async_q.qsize() > 0:
373
+ yielded_value = await yield_queue.async_q.get()
374
+
375
+ self.last_invocation = datetime.now()
376
+
377
+ match yielded_value:
378
+ case str(text):
379
+ await self.task_updater.update_status(
380
+ TaskState.working,
381
+ message=self.task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
382
+ )
383
+ case Part(root=part) | (TextPart() | FilePart() | DataPart() as part):
384
+ await self.task_updater.update_status(
385
+ TaskState.working,
386
+ message=self.task_updater.new_agent_message(parts=[Part(root=part)]),
387
+ )
388
+ case FileWithBytes() | FileWithUri() as file:
389
+ await self.task_updater.update_status(
390
+ TaskState.working,
391
+ message=self.task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
392
+ )
393
+ case Message() as message:
394
+ await self.task_updater.update_status(
395
+ TaskState.working, message=self._with_context(message)
396
+ )
397
+ case ArtifactChunk(
398
+ parts=parts,
399
+ artifact_id=artifact_id,
400
+ name=name,
401
+ metadata=metadata,
402
+ last_chunk=last_chunk,
403
+ ):
404
+ await self.task_updater.add_artifact(
405
+ parts=cast(list[Part], parts),
406
+ artifact_id=artifact_id,
407
+ name=name,
408
+ metadata=metadata,
409
+ append=artifact_id in opened_artifacts,
410
+ last_chunk=last_chunk,
411
+ )
412
+ opened_artifacts.add(artifact_id)
413
+ case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata):
414
+ await self.task_updater.add_artifact(
415
+ parts=parts,
416
+ artifact_id=artifact_id,
417
+ name=name,
418
+ metadata=metadata,
419
+ last_chunk=True,
420
+ append=False,
421
+ )
422
+ case TaskStatus(
423
+ state=(TaskState.auth_required | TaskState.input_required) as state,
424
+ message=message,
425
+ timestamp=timestamp,
426
+ ):
427
+ await self.task_updater.update_status(
428
+ state=state, message=self._with_context(message), final=True, timestamp=timestamp
429
+ )
430
+ self._working = False
431
+ resume_value = await self.resume_queue.get()
432
+ self.resume_queue.task_done()
433
+ case TaskStatus(state=state, message=message, timestamp=timestamp):
434
+ await self.task_updater.update_status(
435
+ state=state, message=self._with_context(message), timestamp=timestamp
436
+ )
437
+ case TaskStatusUpdateEvent(
438
+ status=TaskStatus(state=state, message=message, timestamp=timestamp),
439
+ final=final,
440
+ metadata=metadata,
441
+ ):
442
+ await self.task_updater.update_status(
443
+ state=state,
444
+ message=self._with_context(message),
445
+ timestamp=timestamp,
446
+ final=final,
447
+ metadata=metadata,
448
+ )
449
+ case TaskArtifactUpdateEvent(
450
+ artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts),
451
+ append=append,
452
+ last_chunk=last_chunk,
453
+ ):
454
+ await self.task_updater.add_artifact(
455
+ parts=parts,
456
+ artifact_id=artifact_id,
457
+ name=name,
458
+ metadata=metadata,
459
+ append=append,
460
+ last_chunk=last_chunk,
461
+ )
462
+ case Metadata() as metadata:
463
+ await self.task_updater.update_status(
464
+ state=TaskState.working,
465
+ message=self.task_updater.new_agent_message(parts=[], metadata=metadata),
466
+ )
467
+ case dict() as data:
468
+ await self.task_updater.update_status(
469
+ state=TaskState.working,
470
+ message=self.task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
471
+ )
472
+ case Exception() as ex:
473
+ raise ex
474
+ case _:
475
+ raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}")
476
+
477
+ await yield_resume_queue.async_q.put(resume_value)
478
+
479
+ await self.task_updater.complete()
480
+
481
+ except (janus.AsyncQueueShutDown, GeneratorExit):
482
+ await self.task_updater.complete()
483
+ except Exception as ex:
484
+ logger.error("Error when executing agent", exc_info=ex)
485
+ await self.task_updater.failed(get_error_extension_context().server.message(ex))
486
+ await cancel_task(task)
487
+ except Exception as ex:
488
+ logger.error("Error when executing agent", exc_info=ex)
489
+ await self.task_updater.failed(get_error_extension_context().server.message(ex))
490
+ finally:
491
+ self._working = False
492
+ with suppress(Exception):
493
+ self._handle_finish()
494
+
495
+
496
+ class Executor(AgentExecutor):
497
+ def __init__(
498
+ self,
499
+ agent: Agent,
500
+ queue_manager: QueueManager,
501
+ context_store: ContextStore,
502
+ task_timeout: timedelta,
503
+ task_store: TaskStore,
504
+ ) -> None:
505
+ self._agent: Agent = agent
506
+ self._running_tasks: dict[str, AgentRun] = {}
507
+ self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {}
508
+ self._context_store: ContextStore = context_store
509
+ self._task_timeout: timedelta = task_timeout
510
+ self._task_store: TaskStore = task_store
511
+
512
+ @override
513
+ async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
514
+ # this is only executed in the context of SendMessage request
515
+ message, task_id, context_id = context.message, context.task_id, context.context_id
516
+ assert message and task_id and context_id
517
+ agent_run: AgentRun | None = None
518
+ try:
519
+ if not context.current_task:
520
+ agent_run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id))
521
+ self._running_tasks[task_id] = agent_run
522
+ await self._schedule_run_cleanup(request_context=context)
523
+ await agent_run.start(request_context=context, event_queue=event_queue)
524
+ elif agent_run := self._running_tasks.get(task_id):
525
+ await agent_run.resume(request_context=context, event_queue=event_queue)
526
+ else:
527
+ raise self._run_not_found_error(task_id)
528
+
529
+ # will run until complete or next input/auth required task state
530
+ tapped_queue = event_queue.tap()
531
+ while True:
532
+ match await tapped_queue.dequeue_event():
533
+ case TaskStatusUpdateEvent(final=True):
534
+ break
535
+ case _:
536
+ pass
537
+
538
+ except CancelledError:
539
+ if agent_run:
540
+ await agent_run.cancel(request_context=context, event_queue=event_queue)
541
+ except Exception as ex:
542
+ logger.error("Unhandled error when executing agent:", exc_info=ex)
543
+
544
+ @override
545
+ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
546
+ if not context.task_id or not context.context_id:
547
+ raise ValueError("Task ID and context ID must be set to cancel a task")
548
+ if not (run := self._running_tasks.get(context.task_id)):
549
+ raise self._run_not_found_error(context.task_id)
550
+ await run.cancel(context, event_queue)
551
+
552
+ def _handle_finish(self, task_id: str) -> None:
553
+ if task := self._scheduled_cleanups.pop(task_id, None):
554
+ task.cancel()
555
+ self._running_tasks.pop(task_id, None)
556
+
557
+ def _run_not_found_error(self, task_id: str | None) -> Exception:
558
+ return RuntimeError(
559
+ f"Run for task ID {task_id} not found. "
560
+ + "It may be on another replica, make sure to enable sticky sessions in your load balancer"
561
+ )
562
+
563
+ async def _schedule_run_cleanup(self, request_context: RequestContext):
564
+ task_id, context_id = request_context.task_id, request_context.context_id
565
+ assert task_id and context_id
566
+
567
+ async def cleanup_fn():
568
+ await asyncio.sleep(self._task_timeout.total_seconds())
569
+ if not (run := self._running_tasks.get(task_id)):
570
+ return
571
+ try:
572
+ while not run.done:
573
+ if run.last_invocation + self._task_timeout < datetime.now():
574
+ logger.warning(f"Task {task_id} did not finish in {self._task_timeout}")
575
+ queue = EventQueue()
576
+ await run.cancel(request_context=request_context, event_queue=queue)
577
+ # the original request queue is closed at this point, we need to propagate state to store manually
578
+ manager = TaskManager(
579
+ task_id=task_id, context_id=context_id, task_store=self._task_store, initial_message=None
580
+ )
581
+ event = await queue.dequeue_event(no_wait=True)
582
+ if not isinstance(event, TaskStatusUpdateEvent) or event.status.state != TaskState.canceled:
583
+ raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
584
+ await manager.save_task_event(event)
585
+ break
586
+ await asyncio.sleep(2)
587
+ except Exception as ex:
588
+ logger.error("Error when cleaning up task", exc_info=ex)
589
+ finally:
590
+ self._running_tasks.pop(task_id, None)
591
+ self._scheduled_cleanups.pop(task_id, None)
592
+
593
+ self._scheduled_cleanups[task_id] = asyncio.create_task(cleanup_fn())
594
+ self._scheduled_cleanups[task_id].add_done_callback(lambda _: ...)
@@ -0,0 +1,87 @@
1
+ # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from datetime import timedelta
6
+
7
+ from a2a.server.agent_execution import RequestContextBuilder
8
+ from a2a.server.apps.jsonrpc import A2AFastAPIApplication
9
+ from a2a.server.apps.rest import A2ARESTFastAPIApplication
10
+ from a2a.server.events import InMemoryQueueManager, QueueManager
11
+ from a2a.server.request_handlers import DefaultRequestHandler
12
+ from a2a.server.tasks import (
13
+ InMemoryTaskStore,
14
+ PushNotificationConfigStore,
15
+ PushNotificationSender,
16
+ TaskStore,
17
+ )
18
+ from a2a.types import AgentInterface, TransportProtocol
19
+ from fastapi import APIRouter, Depends, FastAPI
20
+ from fastapi.applications import AppType
21
+ from starlette.authentication import AuthenticationBackend
22
+ from starlette.middleware.authentication import AuthenticationMiddleware
23
+ from starlette.types import Lifespan
24
+
25
+ from agentstack_sdk.server.agent import Agent, Executor
26
+ from agentstack_sdk.server.store.context_store import ContextStore
27
+ from agentstack_sdk.server.store.memory_context_store import InMemoryContextStore
28
+
29
+
30
def create_app(
    agent: Agent,
    task_store: TaskStore | None = None,
    context_store: ContextStore | None = None,
    queue_manager: QueueManager | None = None,
    push_config_store: PushNotificationConfigStore | None = None,
    push_sender: PushNotificationSender | None = None,
    request_context_builder: RequestContextBuilder | None = None,
    lifespan: Lifespan[AppType] | None = None,
    dependencies: list[Depends] | None = None,  # pyright: ignore [reportGeneralTypeIssues]
    override_interfaces: bool = True,
    task_timeout: timedelta = timedelta(minutes=10),
    auth_backend: AuthenticationBackend | None = None,
    **kwargs,
) -> FastAPI:
    """
    Create a FastAPI application serving ``agent`` over A2A.

    The returned app serves the HTTP+JSON (REST) transport and mounts a JSON-RPC app under ``/jsonrpc``. Both
    share one ``DefaultRequestHandler`` backed by an ``Executor``.

    :param agent: Agent to serve.
    :param task_store: A2A task store; defaults to ``InMemoryTaskStore``.
    :param context_store: Context store; defaults to ``InMemoryContextStore``.
    :param queue_manager: Event queue manager; defaults to ``InMemoryQueueManager``.
    :param push_config_store: Optional push notification config store passed to the request handler.
    :param push_sender: Optional push notification sender passed to the request handler.
    :param request_context_builder: Optional request context builder passed to the request handler.
    :param lifespan: Optional lifespan, passed to an ``APIRouter`` that is included in the returned app.
    :param dependencies: FastAPI dependencies applied to both the REST and the JSON-RPC app.
    :param override_interfaces: When true, rewrite ``agent.card`` in place. The card then advertises the REST
        interface at the card URL and JSON-RPC at ``<url>/jsonrpc/``, with JSON-RPC as the preferred transport.
    :param task_timeout: Inactivity timeout after which the executor cancels a run.
    :param auth_backend: Optional Starlette authentication backend, added as middleware to both apps.
    :param kwargs: Extra keyword arguments forwarded to both apps' ``build``.
    :return: The REST FastAPI app, with the JSON-RPC app mounted under ``/jsonrpc``.
    """
    queue_manager = queue_manager or InMemoryQueueManager()
    task_store = task_store or InMemoryTaskStore()
    context_store = context_store or InMemoryContextStore()
    http_handler = DefaultRequestHandler(
        agent_executor=Executor(
            agent,
            queue_manager,
            context_store=context_store,
            task_timeout=task_timeout,
            task_store=task_store,
        ),
        task_store=task_store,
        queue_manager=queue_manager,
        push_config_store=push_config_store,
        push_sender=push_sender,
        request_context_builder=request_context_builder,
    )

    if override_interfaces:
        # NOTE(review): this mutates agent.card, so calling create_app twice with the same agent appends
        # "/jsonrpc/" twice.
        # Strip trailing slashes so a card URL like "http://host/" does not produce "http://host//jsonrpc/".
        base_url = agent.card.url.rstrip("/")
        jsonrpc_url = base_url + "/jsonrpc/"
        agent.card.additional_interfaces = [
            AgentInterface(url=base_url, transport=TransportProtocol.http_json),
            AgentInterface(url=jsonrpc_url, transport=TransportProtocol.jsonrpc),
        ]
        agent.card.url = jsonrpc_url
        agent.card.preferred_transport = TransportProtocol.jsonrpc

    jsonrpc_app = A2AFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build(
        dependencies=dependencies,
        **kwargs,
    )

    rest_app = A2ARESTFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build(
        dependencies=dependencies,
        **kwargs,
    )

    if auth_backend:
        rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)
        jsonrpc_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)

    rest_app.mount("/jsonrpc", jsonrpc_app)
    # An otherwise empty router is used only to attach the caller's lifespan to the app.
    rest_app.include_router(APIRouter(lifespan=lifespan))
    return rest_app