agentstack-sdk 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/__init__.py +6 -0
- agentstack_sdk/a2a/__init__.py +2 -0
- agentstack_sdk/a2a/extensions/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/auth/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/oauth/oauth.py +151 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/base.py +11 -0
- agentstack_sdk/a2a/extensions/auth/oauth/storage/memory.py +38 -0
- agentstack_sdk/a2a/extensions/auth/secrets/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/auth/secrets/secrets.py +77 -0
- agentstack_sdk/a2a/extensions/base.py +205 -0
- agentstack_sdk/a2a/extensions/common/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/common/form.py +149 -0
- agentstack_sdk/a2a/extensions/exceptions.py +11 -0
- agentstack_sdk/a2a/extensions/interactions/__init__.py +4 -0
- agentstack_sdk/a2a/extensions/interactions/approval.py +125 -0
- agentstack_sdk/a2a/extensions/services/__init__.py +8 -0
- agentstack_sdk/a2a/extensions/services/embedding.py +106 -0
- agentstack_sdk/a2a/extensions/services/form.py +54 -0
- agentstack_sdk/a2a/extensions/services/llm.py +100 -0
- agentstack_sdk/a2a/extensions/services/mcp.py +193 -0
- agentstack_sdk/a2a/extensions/services/platform.py +141 -0
- agentstack_sdk/a2a/extensions/tools/__init__.py +5 -0
- agentstack_sdk/a2a/extensions/tools/call.py +114 -0
- agentstack_sdk/a2a/extensions/tools/exceptions.py +6 -0
- agentstack_sdk/a2a/extensions/ui/__init__.py +10 -0
- agentstack_sdk/a2a/extensions/ui/agent_detail.py +54 -0
- agentstack_sdk/a2a/extensions/ui/canvas.py +71 -0
- agentstack_sdk/a2a/extensions/ui/citation.py +78 -0
- agentstack_sdk/a2a/extensions/ui/error.py +223 -0
- agentstack_sdk/a2a/extensions/ui/form_request.py +52 -0
- agentstack_sdk/a2a/extensions/ui/settings.py +73 -0
- agentstack_sdk/a2a/extensions/ui/trajectory.py +70 -0
- agentstack_sdk/a2a/types.py +104 -0
- agentstack_sdk/platform/__init__.py +12 -0
- agentstack_sdk/platform/client.py +123 -0
- agentstack_sdk/platform/common.py +37 -0
- agentstack_sdk/platform/configuration.py +47 -0
- agentstack_sdk/platform/context.py +291 -0
- agentstack_sdk/platform/file.py +295 -0
- agentstack_sdk/platform/model_provider.py +131 -0
- agentstack_sdk/platform/provider.py +219 -0
- agentstack_sdk/platform/provider_build.py +190 -0
- agentstack_sdk/platform/types.py +45 -0
- agentstack_sdk/platform/user.py +70 -0
- agentstack_sdk/platform/user_feedback.py +42 -0
- agentstack_sdk/platform/variables.py +44 -0
- agentstack_sdk/platform/vector_store.py +217 -0
- agentstack_sdk/py.typed +0 -0
- agentstack_sdk/server/__init__.py +4 -0
- agentstack_sdk/server/agent.py +594 -0
- agentstack_sdk/server/app.py +87 -0
- agentstack_sdk/server/constants.py +9 -0
- agentstack_sdk/server/context.py +68 -0
- agentstack_sdk/server/dependencies.py +117 -0
- agentstack_sdk/server/exceptions.py +3 -0
- agentstack_sdk/server/middleware/__init__.py +3 -0
- agentstack_sdk/server/middleware/platform_auth_backend.py +131 -0
- agentstack_sdk/server/server.py +376 -0
- agentstack_sdk/server/store/__init__.py +3 -0
- agentstack_sdk/server/store/context_store.py +35 -0
- agentstack_sdk/server/store/memory_context_store.py +59 -0
- agentstack_sdk/server/store/platform_context_store.py +58 -0
- agentstack_sdk/server/telemetry.py +53 -0
- agentstack_sdk/server/utils.py +26 -0
- agentstack_sdk/types.py +15 -0
- agentstack_sdk/util/__init__.py +4 -0
- agentstack_sdk/util/file.py +260 -0
- agentstack_sdk/util/httpx.py +18 -0
- agentstack_sdk/util/logging.py +63 -0
- agentstack_sdk/util/resource_context.py +44 -0
- agentstack_sdk/util/utils.py +47 -0
- agentstack_sdk-0.5.2rc2.dist-info/METADATA +120 -0
- agentstack_sdk-0.5.2rc2.dist-info/RECORD +76 -0
- agentstack_sdk-0.5.2rc2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import inspect
|
|
6
|
+
import typing
|
|
7
|
+
from asyncio import CancelledError
|
|
8
|
+
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
|
|
9
|
+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, suppress
|
|
10
|
+
from datetime import datetime, timedelta
|
|
11
|
+
from typing import Any, NamedTuple, TypeAlias, TypeVar, cast
|
|
12
|
+
|
|
13
|
+
import janus
|
|
14
|
+
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
15
|
+
from a2a.server.events import EventQueue, QueueManager
|
|
16
|
+
from a2a.server.tasks import TaskManager, TaskStore, TaskUpdater
|
|
17
|
+
from a2a.types import (
|
|
18
|
+
AgentCapabilities,
|
|
19
|
+
AgentCard,
|
|
20
|
+
AgentInterface,
|
|
21
|
+
AgentProvider,
|
|
22
|
+
AgentSkill,
|
|
23
|
+
Artifact,
|
|
24
|
+
DataPart,
|
|
25
|
+
FilePart,
|
|
26
|
+
FileWithBytes,
|
|
27
|
+
FileWithUri,
|
|
28
|
+
Message,
|
|
29
|
+
Part,
|
|
30
|
+
SecurityScheme,
|
|
31
|
+
TaskArtifactUpdateEvent,
|
|
32
|
+
TaskState,
|
|
33
|
+
TaskStatus,
|
|
34
|
+
TaskStatusUpdateEvent,
|
|
35
|
+
TextPart,
|
|
36
|
+
)
|
|
37
|
+
from typing_extensions import override
|
|
38
|
+
|
|
39
|
+
from agentstack_sdk.a2a.extensions.ui.agent_detail import AgentDetail, AgentDetailExtensionSpec
|
|
40
|
+
from agentstack_sdk.a2a.extensions.ui.error import (
|
|
41
|
+
ErrorExtensionParams,
|
|
42
|
+
ErrorExtensionServer,
|
|
43
|
+
ErrorExtensionSpec,
|
|
44
|
+
get_error_extension_context,
|
|
45
|
+
)
|
|
46
|
+
from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
|
|
47
|
+
from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
|
|
48
|
+
from agentstack_sdk.server.context import RunContext
|
|
49
|
+
from agentstack_sdk.server.dependencies import Dependency, Depends, extract_dependencies
|
|
50
|
+
from agentstack_sdk.server.store.context_store import ContextStore
|
|
51
|
+
from agentstack_sdk.server.utils import cancel_task
|
|
52
|
+
from agentstack_sdk.util.logging import logger
|
|
53
|
+
|
|
54
|
+
# Zero-argument callable that returns the async generator implementing an agent.
AgentFunction: TypeAlias = Callable[
    [],
    AsyncGenerator[RunYield, RunYieldResume],
]

# Callable that, given the request context and context store, returns an async context manager
# yielding an AgentFunction.
AgentFunctionFactory: TypeAlias = Callable[
    [RequestContext, ContextStore],
    AbstractAsyncContextManager[AgentFunction],
]

# Type of the user function passed to the `agent` decorator (any callable).
OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any])  # pyright: ignore[reportExplicitAny]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AgentExecuteFn(typing.Protocol):
    """Call signature of the coroutine that executes an agent for one run.

    The run context is passed positionally; resolved dependencies arrive as keyword arguments.
    """

    async def __call__(self, _ctx: RunContext, **kwargs: Any) -> None:
        """Execute the wrapped agent function, reporting its yields through ``_ctx``."""
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Agent(NamedTuple):
    """Everything an agent factory produces for the server to run an agent."""

    # A2A agent card advertised for this agent.
    card: AgentCard
    # Dependency resolvers keyed by the agent function's parameter names; resolved values are passed
    # to `execute_fn` as keyword arguments of the same names.
    dependencies: dict[str, Depends]
    # Coroutine function that drives the wrapped user function for one run.
    execute_fn: AgentExecuteFn


# An agent factory takes a hook that may modify the extracted dependencies in place and returns the built Agent.
AgentFactory: TypeAlias = Callable[
    [Callable[[dict[str, Depends]], None]],
    Agent,
]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def agent(
    name: str | None = None,
    description: str | None = None,
    *,
    url: str = "http://invalid",  # Default will be replaced by the server
    additional_interfaces: list[AgentInterface] | None = None,
    capabilities: AgentCapabilities | None = None,
    default_input_modes: list[str] | None = None,
    default_output_modes: list[str] | None = None,
    detail: AgentDetail | None = None,
    documentation_url: str | None = None,
    icon_url: str | None = None,
    preferred_transport: str | None = None,
    provider: AgentProvider | None = None,
    security: list[dict[str, list[str]]] | None = None,
    security_schemes: dict[str, SecurityScheme] | None = None,
    skills: list[AgentSkill] | None = None,
    supports_authenticated_extended_card: bool | None = None,
    version: str | None = None,
) -> Callable[[OriginalFnType], AgentFactory]:
    """
    Create an Agent function.

    The decorated function may be an async generator function, a coroutine function, a generator function or
    a plain function; the two synchronous variants are executed in a worker thread via ``asyncio.to_thread``.
    Its parameters are resolved as dependencies. The decorator returns an agent factory: a callable that takes
    a hook allowed to modify the extracted dependencies in place, and returns the built ``Agent``.

    The card's capability extensions are assembled from ``capabilities.extensions``, the agent detail, a default
    error extension (only when no dependency already provides an ``ErrorExtensionServer``) and the extensions of
    all dependencies.

    :param name: A human-readable name for the agent (inferred from the function name if not provided).
    :param description: A human-readable description of the agent, assisting users and other agents in understanding
        its purpose (inferred from the function docstring if not provided).
    :param url: The URL written to the agent card. The placeholder default is meant to be replaced by the server.
    :param additional_interfaces: A list of additional supported interfaces (transport and URL combinations).
        A client can use any of these to communicate with the agent.
    :param capabilities: A declaration of optional capabilities supported by the agent. Defaults to
        ``AgentCapabilities(streaming=True)``. The object is copied, so later changes to it do not affect the agent.
    :param default_input_modes: Default set of supported input MIME types for all skills, which can be overridden on
        a per-skill basis. Defaults to ``["text"]``.
    :param default_output_modes: Default set of supported output MIME types for all skills, which can be overridden on
        a per-skill basis. Defaults to ``["text"]``.
    :param detail: Agent Stack SDK details extending the agent metadata.
    :param documentation_url: An optional URL to the agent's documentation.
    :param icon_url: An optional URL to an icon for the agent.
    :param preferred_transport: The transport protocol for the preferred endpoint. Defaults to 'JSONRPC' if not
        specified.
    :param provider: Information about the agent's service provider.
    :param security: A list of security requirement objects that apply to all agent interactions. Each object lists
        security schemes that can be used. Follows the OpenAPI 3.0 Security Requirement Object.
    :param security_schemes: A declaration of the security schemes available to authorize requests. The key is the
        scheme name. Follows the OpenAPI 3.0 Security Scheme Object.
    :param skills: The set of skills, or distinct capabilities, that the agent can perform.
    :param supports_authenticated_extended_card: If true, the agent can provide an extended agent card with additional
        details to authenticated users. Defaults to false.
    :param version: The agent's own version number. The format is defined by the provider. Defaults to ``"1.0.0"``.
    """

    capabilities = capabilities.model_copy(deep=True) if capabilities else AgentCapabilities(streaming=True)
    detail = detail or AgentDetail()  # pyright: ignore [reportCallIssue]

    def decorator(fn: OriginalFnType) -> AgentFactory:
        def agent_factory(modify_dependencies: Callable[[dict[str, Depends]], None]) -> Agent:
            signature = inspect.signature(fn)
            dependencies = extract_dependencies(signature)
            modify_dependencies(dependencies)

            sdk_extensions = [dep.extension for dep in dependencies.values() if dep.extension is not None]

            resolved_name = name or fn.__name__
            resolved_description = description or fn.__doc__ or ""

            # Check if user has provided an ErrorExtensionServer, if not add default
            has_error_extension = any(isinstance(ext, ErrorExtensionServer) for ext in sdk_extensions)
            error_extension_spec = ErrorExtensionSpec(ErrorExtensionParams()) if not has_error_extension else None

            # Build the card extensions on a per-call copy. Mutating the shared `capabilities` object would make
            # every additional factory call append the detail/error/dependency extensions again.
            resolved_capabilities = capabilities.model_copy(deep=True)
            resolved_capabilities.extensions = [
                *(resolved_capabilities.extensions or []),
                *(AgentDetailExtensionSpec(detail).to_agent_card_extensions()),
                *(error_extension_spec.to_agent_card_extensions() if error_extension_spec else []),
                *(e_card for ext in sdk_extensions for e_card in ext.spec.to_agent_card_extensions()),
            ]

            card = AgentCard(
                url=url,
                preferred_transport=preferred_transport,
                additional_interfaces=additional_interfaces,
                capabilities=resolved_capabilities,
                default_input_modes=default_input_modes or ["text"],
                default_output_modes=default_output_modes or ["text"],
                description=resolved_description,
                documentation_url=documentation_url,
                icon_url=icon_url,
                name=resolved_name,
                provider=provider,
                security=security,
                security_schemes=security_schemes,
                skills=skills or [],
                supports_authenticated_extended_card=supports_authenticated_extended_card,
                version=version or "1.0.0",
            )

            # Every variant forwards each produced value (or the raised exception) to the run context and
            # always shuts the context down when the user function ends.
            if inspect.isasyncgenfunction(fn):

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        gen: AsyncGenerator[RunYield, RunYieldResume] = fn(*args, **kwargs)
                        value: RunYieldResume = None
                        while True:
                            # Send the previous resume value in and publish the next yielded value.
                            value = await _ctx.yield_async(await gen.asend(value))
                    except StopAsyncIteration:
                        pass
                    except Exception as e:
                        await _ctx.yield_async(e)
                    finally:
                        _ctx.shutdown()

            elif inspect.iscoroutinefunction(fn):

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        await _ctx.yield_async(await fn(*args, **kwargs))
                    except Exception as e:
                        await _ctx.yield_async(e)
                    finally:
                        _ctx.shutdown()

            elif inspect.isgeneratorfunction(fn):

                def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        gen: Generator[RunYield, RunYieldResume] = fn(*args, **kwargs)
                        value = None
                        while True:
                            value = _ctx.yield_sync(gen.send(value))
                    except StopIteration:
                        pass
                    except Exception as e:
                        _ctx.yield_sync(e)
                    finally:
                        _ctx.shutdown()

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)

            else:

                def _execute_fn_sync(_ctx: RunContext, *args, **kwargs) -> None:
                    try:
                        _ctx.yield_sync(fn(*args, **kwargs))
                    except Exception as e:
                        _ctx.yield_sync(e)
                    finally:
                        _ctx.shutdown()

                async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
                    await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)

            return Agent(card=card, dependencies=dependencies, execute_fn=execute_fn)

        return agent_factory

    return decorator
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class AgentRun:
|
|
231
|
+
def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None:
|
|
232
|
+
self._agent: Agent = agent
|
|
233
|
+
self._task: asyncio.Task[None] | None = None
|
|
234
|
+
self.last_invocation: datetime = datetime.now()
|
|
235
|
+
self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
|
|
236
|
+
self._run_context: RunContext | None = None
|
|
237
|
+
self._request_context: RequestContext | None = None
|
|
238
|
+
self._task_updater: TaskUpdater | None = None
|
|
239
|
+
self._context_store: ContextStore = context_store
|
|
240
|
+
self._lock: asyncio.Lock = asyncio.Lock()
|
|
241
|
+
self._on_finish: Callable[[], None] | None = on_finish
|
|
242
|
+
self._working: bool = False
|
|
243
|
+
|
|
244
|
+
@property
|
|
245
|
+
def run_context(self) -> RunContext:
|
|
246
|
+
if not self._run_context:
|
|
247
|
+
raise RuntimeError("Accessing run context for run that has not been started")
|
|
248
|
+
return self._run_context
|
|
249
|
+
|
|
250
|
+
@property
|
|
251
|
+
def request_context(self) -> RequestContext:
|
|
252
|
+
if not self._request_context:
|
|
253
|
+
raise RuntimeError("Accessing request context for run that has not been started")
|
|
254
|
+
return self._request_context
|
|
255
|
+
|
|
256
|
+
@property
|
|
257
|
+
def task_updater(self) -> TaskUpdater:
|
|
258
|
+
if not self._task_updater:
|
|
259
|
+
raise RuntimeError("Accessing task updater for run that has not been started")
|
|
260
|
+
return self._task_updater
|
|
261
|
+
|
|
262
|
+
@property
|
|
263
|
+
def done(self) -> bool:
|
|
264
|
+
return self._task is not None and self._task.done()
|
|
265
|
+
|
|
266
|
+
def _handle_finish(self) -> None:
|
|
267
|
+
if self._on_finish:
|
|
268
|
+
self._on_finish()
|
|
269
|
+
|
|
270
|
+
async def start(self, request_context: RequestContext, event_queue: EventQueue):
|
|
271
|
+
async with self._lock:
|
|
272
|
+
if self._working or self.done:
|
|
273
|
+
raise RuntimeError("Attempting to start a run that is already executing or done")
|
|
274
|
+
task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
|
|
275
|
+
assert task_id and context_id and message
|
|
276
|
+
self._run_context = RunContext(
|
|
277
|
+
configuration=request_context.configuration,
|
|
278
|
+
context_id=context_id,
|
|
279
|
+
task_id=task_id,
|
|
280
|
+
current_task=request_context.current_task,
|
|
281
|
+
related_tasks=request_context.related_tasks,
|
|
282
|
+
)
|
|
283
|
+
self._request_context = request_context
|
|
284
|
+
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
285
|
+
if not request_context.current_task:
|
|
286
|
+
await self._task_updater.submit()
|
|
287
|
+
await self._task_updater.start_work()
|
|
288
|
+
self._working = True
|
|
289
|
+
self._task = asyncio.create_task(self._run_agent_function(initial_message=message))
|
|
290
|
+
|
|
291
|
+
async def resume(self, request_context: RequestContext, event_queue: EventQueue):
|
|
292
|
+
# These are incorrectly typed in a2a
|
|
293
|
+
async with self._lock:
|
|
294
|
+
if self._working or self.done:
|
|
295
|
+
raise RuntimeError("Attempting to resume a run that is already executing or done")
|
|
296
|
+
task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
|
|
297
|
+
assert task_id and context_id and message
|
|
298
|
+
self._request_context = request_context
|
|
299
|
+
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
300
|
+
|
|
301
|
+
for dependency in self._agent.dependencies.values():
|
|
302
|
+
if dependency.extension:
|
|
303
|
+
dependency.extension.handle_incoming_message(message, self.run_context, request_context)
|
|
304
|
+
|
|
305
|
+
self._working = True
|
|
306
|
+
await self.resume_queue.put(message)
|
|
307
|
+
|
|
308
|
+
async def cancel(self, request_context: RequestContext, event_queue: EventQueue):
|
|
309
|
+
if not self._task:
|
|
310
|
+
raise RuntimeError("Cannot cancel run that has not been started")
|
|
311
|
+
|
|
312
|
+
async with self._lock:
|
|
313
|
+
try:
|
|
314
|
+
assert request_context.task_id
|
|
315
|
+
assert request_context.context_id
|
|
316
|
+
self._task_updater = TaskUpdater(event_queue, request_context.task_id, request_context.context_id)
|
|
317
|
+
await self._task_updater.cancel()
|
|
318
|
+
finally:
|
|
319
|
+
await cancel_task(self._task)
|
|
320
|
+
|
|
321
|
+
@asynccontextmanager
|
|
322
|
+
async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Dependency]]:
|
|
323
|
+
async with AsyncExitStack() as stack:
|
|
324
|
+
dependency_args: dict[str, Dependency] = {}
|
|
325
|
+
initialize_deps_exceptions: list[Exception] = []
|
|
326
|
+
for pname, depends in self._agent.dependencies.items():
|
|
327
|
+
# call dependencies with the first message and initialize their lifespan
|
|
328
|
+
try:
|
|
329
|
+
dependency_args[pname] = await stack.enter_async_context(
|
|
330
|
+
depends(message, self.run_context, self.request_context, dependency_args)
|
|
331
|
+
)
|
|
332
|
+
except Exception as e:
|
|
333
|
+
initialize_deps_exceptions.append(e)
|
|
334
|
+
|
|
335
|
+
if initialize_deps_exceptions:
|
|
336
|
+
raise (
|
|
337
|
+
ExceptionGroup("Failed to initialize dependencies", initialize_deps_exceptions)
|
|
338
|
+
if len(initialize_deps_exceptions) > 1
|
|
339
|
+
else initialize_deps_exceptions[0]
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
self.run_context._store = await self._context_store.create( # pyright: ignore[reportPrivateUsage]
|
|
343
|
+
context_id=self.run_context.context_id,
|
|
344
|
+
initialized_dependencies=list(dependency_args.values()),
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
yield {k: v for k, v in dependency_args.items() if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)}
|
|
348
|
+
|
|
349
|
+
def _with_context(self, message: Message | None = None) -> Message | None:
|
|
350
|
+
if message is None:
|
|
351
|
+
return None
|
|
352
|
+
# Note: This check would require extra handling in agents just forwarding messages from other agents
|
|
353
|
+
# Instead, we just silently replace it.
|
|
354
|
+
# if message.task_id and message.task_id != task_updater.task_id:
|
|
355
|
+
# raise ValueError("Message must have the same task_id as the task")
|
|
356
|
+
# if message.context_id and message.context_id != task_updater.context_id:
|
|
357
|
+
# raise ValueError("Message must have the same context_id as the task")
|
|
358
|
+
return message.model_copy(
|
|
359
|
+
deep=True, update={"context_id": self.task_updater.context_id, "task_id": self.task_updater.task_id}
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
async def _run_agent_function(self, initial_message: Message) -> None:
|
|
363
|
+
yield_queue = self.run_context._yield_queue # pyright: ignore[reportPrivateUsage]
|
|
364
|
+
yield_resume_queue = self.run_context._yield_resume_queue # pyright: ignore[reportPrivateUsage]
|
|
365
|
+
|
|
366
|
+
try:
|
|
367
|
+
async with self._dependencies_lifespan(initial_message) as dependency_args:
|
|
368
|
+
task = asyncio.create_task(self._agent.execute_fn(self.run_context, **dependency_args))
|
|
369
|
+
try:
|
|
370
|
+
resume_value: RunYieldResume = None
|
|
371
|
+
opened_artifacts: set[str] = set()
|
|
372
|
+
while not task.done() or yield_queue.async_q.qsize() > 0:
|
|
373
|
+
yielded_value = await yield_queue.async_q.get()
|
|
374
|
+
|
|
375
|
+
self.last_invocation = datetime.now()
|
|
376
|
+
|
|
377
|
+
match yielded_value:
|
|
378
|
+
case str(text):
|
|
379
|
+
await self.task_updater.update_status(
|
|
380
|
+
TaskState.working,
|
|
381
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
|
|
382
|
+
)
|
|
383
|
+
case Part(root=part) | (TextPart() | FilePart() | DataPart() as part):
|
|
384
|
+
await self.task_updater.update_status(
|
|
385
|
+
TaskState.working,
|
|
386
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=part)]),
|
|
387
|
+
)
|
|
388
|
+
case FileWithBytes() | FileWithUri() as file:
|
|
389
|
+
await self.task_updater.update_status(
|
|
390
|
+
TaskState.working,
|
|
391
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
|
|
392
|
+
)
|
|
393
|
+
case Message() as message:
|
|
394
|
+
await self.task_updater.update_status(
|
|
395
|
+
TaskState.working, message=self._with_context(message)
|
|
396
|
+
)
|
|
397
|
+
case ArtifactChunk(
|
|
398
|
+
parts=parts,
|
|
399
|
+
artifact_id=artifact_id,
|
|
400
|
+
name=name,
|
|
401
|
+
metadata=metadata,
|
|
402
|
+
last_chunk=last_chunk,
|
|
403
|
+
):
|
|
404
|
+
await self.task_updater.add_artifact(
|
|
405
|
+
parts=cast(list[Part], parts),
|
|
406
|
+
artifact_id=artifact_id,
|
|
407
|
+
name=name,
|
|
408
|
+
metadata=metadata,
|
|
409
|
+
append=artifact_id in opened_artifacts,
|
|
410
|
+
last_chunk=last_chunk,
|
|
411
|
+
)
|
|
412
|
+
opened_artifacts.add(artifact_id)
|
|
413
|
+
case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata):
|
|
414
|
+
await self.task_updater.add_artifact(
|
|
415
|
+
parts=parts,
|
|
416
|
+
artifact_id=artifact_id,
|
|
417
|
+
name=name,
|
|
418
|
+
metadata=metadata,
|
|
419
|
+
last_chunk=True,
|
|
420
|
+
append=False,
|
|
421
|
+
)
|
|
422
|
+
case TaskStatus(
|
|
423
|
+
state=(TaskState.auth_required | TaskState.input_required) as state,
|
|
424
|
+
message=message,
|
|
425
|
+
timestamp=timestamp,
|
|
426
|
+
):
|
|
427
|
+
await self.task_updater.update_status(
|
|
428
|
+
state=state, message=self._with_context(message), final=True, timestamp=timestamp
|
|
429
|
+
)
|
|
430
|
+
self._working = False
|
|
431
|
+
resume_value = await self.resume_queue.get()
|
|
432
|
+
self.resume_queue.task_done()
|
|
433
|
+
case TaskStatus(state=state, message=message, timestamp=timestamp):
|
|
434
|
+
await self.task_updater.update_status(
|
|
435
|
+
state=state, message=self._with_context(message), timestamp=timestamp
|
|
436
|
+
)
|
|
437
|
+
case TaskStatusUpdateEvent(
|
|
438
|
+
status=TaskStatus(state=state, message=message, timestamp=timestamp),
|
|
439
|
+
final=final,
|
|
440
|
+
metadata=metadata,
|
|
441
|
+
):
|
|
442
|
+
await self.task_updater.update_status(
|
|
443
|
+
state=state,
|
|
444
|
+
message=self._with_context(message),
|
|
445
|
+
timestamp=timestamp,
|
|
446
|
+
final=final,
|
|
447
|
+
metadata=metadata,
|
|
448
|
+
)
|
|
449
|
+
case TaskArtifactUpdateEvent(
|
|
450
|
+
artifact=Artifact(artifact_id=artifact_id, name=name, metadata=metadata, parts=parts),
|
|
451
|
+
append=append,
|
|
452
|
+
last_chunk=last_chunk,
|
|
453
|
+
):
|
|
454
|
+
await self.task_updater.add_artifact(
|
|
455
|
+
parts=parts,
|
|
456
|
+
artifact_id=artifact_id,
|
|
457
|
+
name=name,
|
|
458
|
+
metadata=metadata,
|
|
459
|
+
append=append,
|
|
460
|
+
last_chunk=last_chunk,
|
|
461
|
+
)
|
|
462
|
+
case Metadata() as metadata:
|
|
463
|
+
await self.task_updater.update_status(
|
|
464
|
+
state=TaskState.working,
|
|
465
|
+
message=self.task_updater.new_agent_message(parts=[], metadata=metadata),
|
|
466
|
+
)
|
|
467
|
+
case dict() as data:
|
|
468
|
+
await self.task_updater.update_status(
|
|
469
|
+
state=TaskState.working,
|
|
470
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
|
|
471
|
+
)
|
|
472
|
+
case Exception() as ex:
|
|
473
|
+
raise ex
|
|
474
|
+
case _:
|
|
475
|
+
raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}")
|
|
476
|
+
|
|
477
|
+
await yield_resume_queue.async_q.put(resume_value)
|
|
478
|
+
|
|
479
|
+
await self.task_updater.complete()
|
|
480
|
+
|
|
481
|
+
except (janus.AsyncQueueShutDown, GeneratorExit):
|
|
482
|
+
await self.task_updater.complete()
|
|
483
|
+
except Exception as ex:
|
|
484
|
+
logger.error("Error when executing agent", exc_info=ex)
|
|
485
|
+
await self.task_updater.failed(get_error_extension_context().server.message(ex))
|
|
486
|
+
await cancel_task(task)
|
|
487
|
+
except Exception as ex:
|
|
488
|
+
logger.error("Error when executing agent", exc_info=ex)
|
|
489
|
+
await self.task_updater.failed(get_error_extension_context().server.message(ex))
|
|
490
|
+
finally:
|
|
491
|
+
self._working = False
|
|
492
|
+
with suppress(Exception):
|
|
493
|
+
self._handle_finish()
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
class Executor(AgentExecutor):
    """A2A ``AgentExecutor`` that keeps in-process ``AgentRun`` objects keyed by task ID.

    A message without a current task starts a new run. A message for an existing task resumes the run that is
    waiting for input. Every new run also gets a watchdog task. After an initial ``task_timeout`` delay, the
    watchdog cancels the run once it has not yielded anything for longer than ``task_timeout``.
    """

    def __init__(
        self,
        agent: Agent,
        queue_manager: QueueManager,
        context_store: ContextStore,
        task_timeout: timedelta,
        task_store: TaskStore,
    ) -> None:
        # NOTE(review): `queue_manager` is accepted but not stored or used by this class.
        self._agent: Agent = agent
        self._context_store: ContextStore = context_store
        self._task_timeout: timedelta = task_timeout
        self._task_store: TaskStore = task_store
        self._running_tasks: dict[str, AgentRun] = {}
        self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {}

    @override
    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Start or resume the run for this request and wait until it publishes a final status event."""
        # Only reached for SendMessage requests.
        message = context.message
        task_id = context.task_id
        context_id = context.context_id
        assert message and task_id and context_id

        run: AgentRun | None = None
        try:
            if context.current_task:
                run = self._running_tasks.get(task_id)
                if run is None:
                    raise self._run_not_found_error(task_id)
                await run.resume(request_context=context, event_queue=event_queue)
            else:
                run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id))
                self._running_tasks[task_id] = run
                await self._schedule_run_cleanup(request_context=context)
                await run.start(request_context=context, event_queue=event_queue)

            # Block until the run completes or pauses for input/auth (both publish a final status update).
            listener = event_queue.tap()
            while True:
                event = await listener.dequeue_event()
                if isinstance(event, TaskStatusUpdateEvent) and event.final is True:
                    break

        except CancelledError:
            if run:
                await run.cancel(request_context=context, event_queue=event_queue)
        except Exception as ex:
            logger.error("Unhandled error when executing agent:", exc_info=ex)

    @override
    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Cancel the in-process run for ``context.task_id``.

        :raises ValueError: If the task or context ID is missing.
        :raises RuntimeError: If no run for the task exists in this process.
        """
        task_id, context_id = context.task_id, context.context_id
        if not (task_id and context_id):
            raise ValueError("Task ID and context ID must be set to cancel a task")
        run = self._running_tasks.get(task_id)
        if run is None:
            raise self._run_not_found_error(task_id)
        await run.cancel(context, event_queue)

    def _handle_finish(self, task_id: str) -> None:
        """Forget a finished run and cancel its pending watchdog, if any."""
        watchdog = self._scheduled_cleanups.pop(task_id, None)
        if watchdog is not None:
            watchdog.cancel()
        self._running_tasks.pop(task_id, None)

    def _run_not_found_error(self, task_id: str | None) -> Exception:
        """Build the error reported when no in-process run exists for ``task_id``."""
        hint = "It may be on another replica, make sure to enable sticky sessions in your load balancer"
        return RuntimeError(f"Run for task ID {task_id} not found. " + hint)

    async def _schedule_run_cleanup(self, request_context: RequestContext):
        """Start the watchdog task that cancels this request's run once it stops making progress."""
        task_id = request_context.task_id
        context_id = request_context.context_id
        assert task_id and context_id

        async def watchdog() -> None:
            await asyncio.sleep(self._task_timeout.total_seconds())
            run = self._running_tasks.get(task_id)
            if run is None:
                return
            try:
                while not run.done:
                    deadline = run.last_invocation + self._task_timeout
                    if deadline < datetime.now():
                        logger.warning(f"Task {task_id} did not finish in {self._task_timeout}")
                        cancel_queue = EventQueue()
                        await run.cancel(request_context=request_context, event_queue=cancel_queue)
                        # The original request's queue is already closed, so persist the cancellation manually.
                        manager = TaskManager(
                            task_id=task_id, context_id=context_id, task_store=self._task_store, initial_message=None
                        )
                        event = await cancel_queue.dequeue_event(no_wait=True)
                        is_cancel_event = (
                            isinstance(event, TaskStatusUpdateEvent) and event.status.state == TaskState.canceled
                        )
                        if not is_cancel_event:
                            raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
                        await manager.save_task_event(event)
                        break
                    await asyncio.sleep(2)
            except Exception as ex:
                logger.error("Error when cleaning up task", exc_info=ex)
            finally:
                self._running_tasks.pop(task_id, None)
                self._scheduled_cleanups.pop(task_id, None)

        watchdog_task = asyncio.create_task(watchdog())
        self._scheduled_cleanups[task_id] = watchdog_task
        # NOTE(review): this done-callback does nothing; kept to mirror the existing behavior.
        watchdog_task.add_done_callback(lambda _: ...)
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
|
|
7
|
+
from a2a.server.agent_execution import RequestContextBuilder
|
|
8
|
+
from a2a.server.apps.jsonrpc import A2AFastAPIApplication
|
|
9
|
+
from a2a.server.apps.rest import A2ARESTFastAPIApplication
|
|
10
|
+
from a2a.server.events import InMemoryQueueManager, QueueManager
|
|
11
|
+
from a2a.server.request_handlers import DefaultRequestHandler
|
|
12
|
+
from a2a.server.tasks import (
|
|
13
|
+
InMemoryTaskStore,
|
|
14
|
+
PushNotificationConfigStore,
|
|
15
|
+
PushNotificationSender,
|
|
16
|
+
TaskStore,
|
|
17
|
+
)
|
|
18
|
+
from a2a.types import AgentInterface, TransportProtocol
|
|
19
|
+
from fastapi import APIRouter, Depends, FastAPI
|
|
20
|
+
from fastapi.applications import AppType
|
|
21
|
+
from starlette.authentication import AuthenticationBackend
|
|
22
|
+
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
23
|
+
from starlette.types import Lifespan
|
|
24
|
+
|
|
25
|
+
from agentstack_sdk.server.agent import Agent, Executor
|
|
26
|
+
from agentstack_sdk.server.store.context_store import ContextStore
|
|
27
|
+
from agentstack_sdk.server.store.memory_context_store import InMemoryContextStore
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_app(
    agent: Agent,
    task_store: TaskStore | None = None,
    context_store: ContextStore | None = None,
    queue_manager: QueueManager | None = None,
    push_config_store: PushNotificationConfigStore | None = None,
    push_sender: PushNotificationSender | None = None,
    request_context_builder: RequestContextBuilder | None = None,
    lifespan: Lifespan[AppType] | None = None,
    dependencies: list[Depends] | None = None,  # pyright: ignore [reportGeneralTypeIssues]
    override_interfaces: bool = True,
    task_timeout: timedelta = timedelta(minutes=10),
    auth_backend: AuthenticationBackend | None = None,
    **kwargs,
) -> FastAPI:
    """Build the FastAPI application that serves ``agent`` over A2A.

    The returned app serves the HTTP+JSON (REST) transport at its root and mounts
    the JSON-RPC transport under ``/jsonrpc``. Both transports share one
    ``DefaultRequestHandler`` backed by an ``Executor`` wrapping ``agent``.

    Args:
        agent: The agent to serve. Its card is mutated in place when
            ``override_interfaces`` is true.
        task_store: Task store shared by the executor and the request handler;
            ``InMemoryTaskStore`` when omitted.
        context_store: Context store for the executor; ``InMemoryContextStore``
            when omitted.
        queue_manager: Queue manager shared by the executor and the request
            handler; ``InMemoryQueueManager`` when omitted.
        push_config_store: Forwarded to ``DefaultRequestHandler``.
        push_sender: Forwarded to ``DefaultRequestHandler``.
        request_context_builder: Forwarded to ``DefaultRequestHandler``.
        lifespan: Attached to the returned app through an ``APIRouter``.
        dependencies: Passed to ``build()`` of both transport apps.
        override_interfaces: When true, ``agent.card.url`` is rewritten to the
            JSON-RPC endpoint (``<url>/jsonrpc/``) and JSON-RPC becomes the
            preferred transport. ``additional_interfaces`` then lists the
            original URL (HTTP+JSON) and the JSON-RPC URL.
        task_timeout: Forwarded to ``Executor``.
        auth_backend: Optional Starlette authentication backend. It is installed
            once on the outer app, which also covers the mounted JSON-RPC app.
        **kwargs: Forwarded to ``build()`` of both transport apps.

    Returns:
        The REST application with the JSON-RPC application mounted at ``/jsonrpc``.
    """
    # Explicit ``is None`` checks: only substitute defaults when nothing was passed,
    # so a caller-supplied store that happens to be falsy is still used.
    if queue_manager is None:
        queue_manager = InMemoryQueueManager()
    if task_store is None:
        task_store = InMemoryTaskStore()
    if context_store is None:
        context_store = InMemoryContextStore()

    http_handler = DefaultRequestHandler(
        agent_executor=Executor(
            agent,
            queue_manager,
            context_store=context_store,
            task_timeout=task_timeout,
            task_store=task_store,
        ),
        task_store=task_store,
        queue_manager=queue_manager,
        push_config_store=push_config_store,
        push_sender=push_sender,
        request_context_builder=request_context_builder,
    )

    if override_interfaces:
        base_url = agent.card.url
        # Strip trailing slashes so a base URL like "http://host:8000/" does not
        # become "http://host:8000//jsonrpc/".
        jsonrpc_url = base_url.rstrip("/") + "/jsonrpc/"
        agent.card.additional_interfaces = [
            AgentInterface(url=base_url, transport=TransportProtocol.http_json),
            AgentInterface(url=jsonrpc_url, transport=TransportProtocol.jsonrpc),
        ]
        agent.card.url = jsonrpc_url
        agent.card.preferred_transport = TransportProtocol.jsonrpc

    jsonrpc_app = A2AFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build(
        dependencies=dependencies,
        **kwargs,
    )

    rest_app = A2ARESTFastAPIApplication(agent_card=agent.card, http_handler=http_handler).build(
        dependencies=dependencies,
        **kwargs,
    )

    if auth_backend is not None:
        # Middleware on the outer app wraps its whole router, including the
        # JSON-RPC app mounted below. Installing it only here authenticates each
        # request once. Adding it to jsonrpc_app as well ran the backend twice.
        rest_app.add_middleware(AuthenticationMiddleware, backend=auth_backend)

    rest_app.mount("/jsonrpc", jsonrpc_app)
    rest_app.include_router(APIRouter(lifespan=lifespan))
    return rest_app
|