agentstack-sdk 0.4.3rc2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentstack_sdk/a2a/extensions/base.py +1 -1
- agentstack_sdk/a2a/extensions/services/form.py +2 -1
- agentstack_sdk/a2a/extensions/services/mcp.py +8 -13
- agentstack_sdk/a2a/extensions/ui/canvas.py +4 -1
- agentstack_sdk/a2a/extensions/ui/citation.py +1 -6
- agentstack_sdk/a2a/extensions/ui/error.py +39 -30
- agentstack_sdk/a2a/extensions/ui/trajectory.py +1 -2
- agentstack_sdk/a2a/types.py +15 -14
- agentstack_sdk/platform/__init__.py +2 -0
- agentstack_sdk/platform/context.py +2 -6
- agentstack_sdk/platform/file.py +45 -1
- agentstack_sdk/platform/user.py +49 -4
- agentstack_sdk/platform/user_feedback.py +42 -0
- agentstack_sdk/server/agent.py +289 -278
- agentstack_sdk/server/app.py +13 -2
- agentstack_sdk/server/constants.py +0 -3
- agentstack_sdk/server/context.py +0 -9
- agentstack_sdk/server/dependencies.py +5 -11
- agentstack_sdk/server/server.py +3 -1
- agentstack_sdk/util/utils.py +5 -1
- agentstack_sdk-0.5.0.dist-info/METADATA +118 -0
- {agentstack_sdk-0.4.3rc2.dist-info → agentstack_sdk-0.5.0.dist-info}/RECORD +23 -22
- agentstack_sdk-0.4.3rc2.dist-info/METADATA +0 -69
- {agentstack_sdk-0.4.3rc2.dist-info → agentstack_sdk-0.5.0.dist-info}/WHEEL +0 -0
agentstack_sdk/server/agent.py
CHANGED
|
@@ -3,17 +3,17 @@
|
|
|
3
3
|
|
|
4
4
|
import asyncio
|
|
5
5
|
import inspect
|
|
6
|
+
import typing
|
|
6
7
|
from asyncio import CancelledError
|
|
7
8
|
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
|
|
8
|
-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
|
|
9
|
+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, suppress
|
|
9
10
|
from datetime import datetime, timedelta
|
|
10
|
-
from typing import NamedTuple, TypeAlias,
|
|
11
|
+
from typing import Any, NamedTuple, TypeAlias, TypeVar, cast
|
|
11
12
|
|
|
12
13
|
import janus
|
|
13
|
-
from a2a.client import create_text_message_object
|
|
14
14
|
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
15
15
|
from a2a.server.events import EventQueue, QueueManager
|
|
16
|
-
from a2a.server.tasks import TaskUpdater
|
|
16
|
+
from a2a.server.tasks import TaskManager, TaskStore, TaskUpdater
|
|
17
17
|
from a2a.types import (
|
|
18
18
|
AgentCapabilities,
|
|
19
19
|
AgentCard,
|
|
@@ -34,29 +34,40 @@ from a2a.types import (
|
|
|
34
34
|
TaskStatusUpdateEvent,
|
|
35
35
|
TextPart,
|
|
36
36
|
)
|
|
37
|
+
from typing_extensions import override
|
|
37
38
|
|
|
38
39
|
from agentstack_sdk.a2a.extensions.ui.agent_detail import AgentDetail, AgentDetailExtensionSpec
|
|
39
|
-
from agentstack_sdk.a2a.extensions.ui.error import
|
|
40
|
-
|
|
41
|
-
|
|
40
|
+
from agentstack_sdk.a2a.extensions.ui.error import (
|
|
41
|
+
ErrorExtensionParams,
|
|
42
|
+
ErrorExtensionServer,
|
|
43
|
+
ErrorExtensionSpec,
|
|
44
|
+
get_error_extension_context,
|
|
45
|
+
)
|
|
46
|
+
from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
|
|
47
|
+
from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
|
|
42
48
|
from agentstack_sdk.server.context import RunContext
|
|
43
|
-
from agentstack_sdk.server.dependencies import extract_dependencies
|
|
49
|
+
from agentstack_sdk.server.dependencies import Depends, extract_dependencies
|
|
44
50
|
from agentstack_sdk.server.store.context_store import ContextStore
|
|
45
|
-
from agentstack_sdk.server.utils import cancel_task
|
|
51
|
+
from agentstack_sdk.server.utils import cancel_task
|
|
46
52
|
from agentstack_sdk.util.logging import logger
|
|
47
53
|
|
|
48
54
|
AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]]
|
|
49
|
-
AgentFunctionFactory: TypeAlias = Callable[
|
|
50
|
-
|
|
51
|
-
]
|
|
55
|
+
AgentFunctionFactory: TypeAlias = Callable[[RequestContext, ContextStore], AbstractAsyncContextManager[AgentFunction]]
|
|
56
|
+
|
|
57
|
+
OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any]) # pyright: ignore[reportExplicitAny]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AgentExecuteFn(typing.Protocol):
|
|
61
|
+
async def __call__(self, _ctx: RunContext, **kwargs: Any) -> None: ...
|
|
52
62
|
|
|
53
63
|
|
|
54
64
|
class Agent(NamedTuple):
|
|
55
65
|
card: AgentCard
|
|
56
|
-
|
|
66
|
+
dependencies: dict[str, Depends]
|
|
67
|
+
execute_fn: AgentExecuteFn
|
|
57
68
|
|
|
58
69
|
|
|
59
|
-
AgentFactory: TypeAlias = Callable[[
|
|
70
|
+
AgentFactory: TypeAlias = Callable[[Callable[[dict[str, Depends]], None]], Agent]
|
|
60
71
|
|
|
61
72
|
|
|
62
73
|
def agent(
|
|
@@ -78,7 +89,7 @@ def agent(
|
|
|
78
89
|
skills: list[AgentSkill] | None = None,
|
|
79
90
|
supports_authenticated_extended_card: bool | None = None,
|
|
80
91
|
version: str | None = None,
|
|
81
|
-
) -> Callable[[
|
|
92
|
+
) -> Callable[[OriginalFnType], AgentFactory]:
|
|
82
93
|
"""
|
|
83
94
|
Create an Agent function.
|
|
84
95
|
|
|
@@ -112,11 +123,11 @@ def agent(
|
|
|
112
123
|
capabilities = capabilities.model_copy(deep=True) if capabilities else AgentCapabilities(streaming=True)
|
|
113
124
|
detail = detail or AgentDetail() # pyright: ignore [reportCallIssue]
|
|
114
125
|
|
|
115
|
-
def decorator(fn:
|
|
116
|
-
def agent_factory(
|
|
126
|
+
def decorator(fn: OriginalFnType) -> AgentFactory:
|
|
127
|
+
def agent_factory(modify_dependencies: Callable[[dict[str, Depends]], None]):
|
|
117
128
|
signature = inspect.signature(fn)
|
|
118
129
|
dependencies = extract_dependencies(signature)
|
|
119
|
-
|
|
130
|
+
modify_dependencies(dependencies)
|
|
120
131
|
|
|
121
132
|
sdk_extensions = [dep.extension for dep in dependencies.values() if dep.extension is not None]
|
|
122
133
|
|
|
@@ -209,178 +220,172 @@ def agent(
|
|
|
209
220
|
async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
|
|
210
221
|
await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)
|
|
211
222
|
|
|
212
|
-
|
|
213
|
-
async def agent_executor_lifespan(
|
|
214
|
-
task_updater: TaskUpdater, request_context: RequestContext, context_store: ContextStore
|
|
215
|
-
) -> AsyncIterator[tuple[AgentFunction, RunContext]]:
|
|
216
|
-
message = request_context.message
|
|
217
|
-
assert message # this is only executed in the context of SendMessage request
|
|
218
|
-
# These are incorrectly typed in a2a
|
|
219
|
-
assert request_context.task_id
|
|
220
|
-
assert request_context.context_id
|
|
221
|
-
context = RunContext(
|
|
222
|
-
configuration=request_context.configuration,
|
|
223
|
-
context_id=request_context.context_id,
|
|
224
|
-
task_id=request_context.task_id,
|
|
225
|
-
task_updater=task_updater,
|
|
226
|
-
current_task=request_context.current_task,
|
|
227
|
-
related_tasks=request_context.related_tasks,
|
|
228
|
-
call_context=request_context.call_context,
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
async with AsyncExitStack() as stack:
|
|
232
|
-
dependency_args = {}
|
|
233
|
-
for pname, depends in dependencies.items():
|
|
234
|
-
# call dependencies with the first message and initialize their lifespan
|
|
235
|
-
dependency_args[pname] = await stack.enter_async_context(
|
|
236
|
-
depends(message, context, dependency_args)
|
|
237
|
-
)
|
|
238
|
-
|
|
239
|
-
context._error_extension = next(
|
|
240
|
-
(ext for ext in dependency_args.values() if isinstance(ext, ErrorExtensionServer)),
|
|
241
|
-
DEFAULT_ERROR_EXTENSION,
|
|
242
|
-
)
|
|
243
|
-
|
|
244
|
-
context._store = await context_store.create(
|
|
245
|
-
context_id=request_context.context_id,
|
|
246
|
-
initialized_dependencies=list(dependency_args.values()),
|
|
247
|
-
)
|
|
248
|
-
|
|
249
|
-
async def agent_generator():
|
|
250
|
-
yield_queue = context._yield_queue
|
|
251
|
-
yield_resume_queue = context._yield_resume_queue
|
|
252
|
-
|
|
253
|
-
task = asyncio.create_task(
|
|
254
|
-
execute_fn(
|
|
255
|
-
context,
|
|
256
|
-
**{
|
|
257
|
-
k: v
|
|
258
|
-
for k, v in dependency_args.items()
|
|
259
|
-
if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)
|
|
260
|
-
},
|
|
261
|
-
)
|
|
262
|
-
)
|
|
263
|
-
try:
|
|
264
|
-
while not task.done() or yield_queue.async_q.qsize() > 0:
|
|
265
|
-
value = yield await yield_queue.async_q.get()
|
|
266
|
-
if isinstance(value, Exception):
|
|
267
|
-
raise value
|
|
268
|
-
|
|
269
|
-
if value:
|
|
270
|
-
# TODO: context.call_context should be updated here
|
|
271
|
-
# Unfortunately queue implementation does not support passing external types
|
|
272
|
-
# (only a2a.event_queue.Event is supported:
|
|
273
|
-
# Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
|
|
274
|
-
for ext in sdk_extensions:
|
|
275
|
-
ext.handle_incoming_message(value, context)
|
|
276
|
-
|
|
277
|
-
await yield_resume_queue.async_q.put(value)
|
|
278
|
-
except janus.AsyncQueueShutDown:
|
|
279
|
-
pass
|
|
280
|
-
except GeneratorExit:
|
|
281
|
-
return
|
|
282
|
-
finally:
|
|
283
|
-
await cancel_task(task)
|
|
284
|
-
|
|
285
|
-
yield agent_generator, context
|
|
286
|
-
|
|
287
|
-
return Agent(card=card, execute=agent_executor_lifespan)
|
|
223
|
+
return Agent(card=card, dependencies=dependencies, execute_fn=execute_fn)
|
|
288
224
|
|
|
289
225
|
return agent_factory
|
|
290
226
|
|
|
291
227
|
return decorator
|
|
292
228
|
|
|
293
229
|
|
|
294
|
-
class
|
|
295
|
-
|
|
296
|
-
|
|
230
|
+
class AgentRun:
|
|
231
|
+
def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None:
|
|
232
|
+
self._agent: Agent = agent
|
|
233
|
+
self._task: asyncio.Task[None] | None = None
|
|
234
|
+
self.last_invocation: datetime = datetime.now()
|
|
235
|
+
self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
|
|
236
|
+
self._run_context: RunContext | None = None
|
|
237
|
+
self._task_updater: TaskUpdater | None = None
|
|
238
|
+
self._context_store: ContextStore = context_store
|
|
239
|
+
self._lock: asyncio.Lock = asyncio.Lock()
|
|
240
|
+
self._on_finish: Callable[[], None] | None = on_finish
|
|
241
|
+
self._working: bool = False
|
|
242
|
+
|
|
243
|
+
@property
|
|
244
|
+
def run_context(self) -> RunContext:
|
|
245
|
+
if not self._run_context:
|
|
246
|
+
raise RuntimeError("Accessing run context for run that has not been started")
|
|
247
|
+
return self._run_context
|
|
248
|
+
|
|
249
|
+
@property
|
|
250
|
+
def task_updater(self) -> TaskUpdater:
|
|
251
|
+
if not self._task_updater:
|
|
252
|
+
raise RuntimeError("Accessing task updater for run that has not been started")
|
|
253
|
+
return self._task_updater
|
|
254
|
+
|
|
255
|
+
@property
|
|
256
|
+
def done(self) -> bool:
|
|
257
|
+
return self._task is not None and self._task.done()
|
|
258
|
+
|
|
259
|
+
def _handle_finish(self) -> None:
|
|
260
|
+
if self._on_finish:
|
|
261
|
+
self._on_finish()
|
|
262
|
+
|
|
263
|
+
async def start(self, request_context: RequestContext, event_queue: EventQueue):
|
|
264
|
+
# These are incorrectly typed in a2a
|
|
265
|
+
async with self._lock:
|
|
266
|
+
if self._working or self.done:
|
|
267
|
+
raise RuntimeError("Attempting to start a run that is already executing or done")
|
|
268
|
+
task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
|
|
269
|
+
assert task_id and context_id and message
|
|
270
|
+
self._run_context = RunContext(
|
|
271
|
+
configuration=request_context.configuration,
|
|
272
|
+
context_id=context_id,
|
|
273
|
+
task_id=task_id,
|
|
274
|
+
current_task=request_context.current_task,
|
|
275
|
+
related_tasks=request_context.related_tasks,
|
|
276
|
+
)
|
|
277
|
+
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
278
|
+
if not request_context.current_task:
|
|
279
|
+
await self._task_updater.submit()
|
|
280
|
+
await self._task_updater.start_work()
|
|
281
|
+
self._working = True
|
|
282
|
+
self._task = asyncio.create_task(self._run_agent_function(initial_message=message))
|
|
283
|
+
|
|
284
|
+
async def resume(self, request_context: RequestContext, event_queue: EventQueue):
|
|
285
|
+
# These are incorrectly typed in a2a
|
|
286
|
+
async with self._lock:
|
|
287
|
+
if self._working or self.done:
|
|
288
|
+
raise RuntimeError("Attempting to resume a run that is already executing or done")
|
|
289
|
+
task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
|
|
290
|
+
assert task_id and context_id and message
|
|
291
|
+
self._task_updater = TaskUpdater(event_queue, task_id, context_id)
|
|
297
292
|
|
|
293
|
+
for dependency in self._agent.dependencies.values():
|
|
294
|
+
if dependency.extension:
|
|
295
|
+
dependency.extension.handle_incoming_message(message, self.run_context)
|
|
298
296
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
self,
|
|
302
|
-
execute_fn: AgentFunctionFactory,
|
|
303
|
-
queue_manager: QueueManager,
|
|
304
|
-
context_store: ContextStore,
|
|
305
|
-
task_timeout: timedelta,
|
|
306
|
-
) -> None:
|
|
307
|
-
self._agent_executor_span = execute_fn
|
|
308
|
-
self._queue_manager = queue_manager
|
|
309
|
-
self._running_tasks: dict[str, RunningTask] = {}
|
|
310
|
-
self._cancel_queues: dict[str, EventQueue] = {}
|
|
311
|
-
self._context_store = context_store
|
|
312
|
-
self._task_timeout = task_timeout
|
|
297
|
+
self._working = True
|
|
298
|
+
await self.resume_queue.put(message)
|
|
313
299
|
|
|
314
|
-
async def
|
|
315
|
-
|
|
316
|
-
|
|
300
|
+
async def cancel(self, request_context: RequestContext, event_queue: EventQueue):
|
|
301
|
+
if not self._task:
|
|
302
|
+
raise RuntimeError("Cannot cancel run that has not been started")
|
|
317
303
|
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
304
|
+
async with self._lock:
|
|
305
|
+
try:
|
|
306
|
+
assert request_context.task_id
|
|
307
|
+
assert request_context.context_id
|
|
308
|
+
self._task_updater = TaskUpdater(event_queue, request_context.task_id, request_context.context_id)
|
|
309
|
+
await self._task_updater.cancel()
|
|
310
|
+
finally:
|
|
311
|
+
await cancel_task(self._task)
|
|
312
|
+
|
|
313
|
+
@asynccontextmanager
|
|
314
|
+
async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Depends]]:
|
|
315
|
+
async with AsyncExitStack() as stack:
|
|
316
|
+
dependency_args: dict[str, Depends] = {}
|
|
317
|
+
initialize_deps_exceptions: list[Exception] = []
|
|
318
|
+
for pname, depends in self._agent.dependencies.items():
|
|
319
|
+
# call dependencies with the first message and initialize their lifespan
|
|
320
|
+
try:
|
|
321
|
+
dependency_args[pname] = await stack.enter_async_context(
|
|
322
|
+
depends(message, self.run_context, dependency_args)
|
|
323
|
+
)
|
|
324
|
+
except Exception as e:
|
|
325
|
+
initialize_deps_exceptions.append(e)
|
|
326
|
+
|
|
327
|
+
if initialize_deps_exceptions:
|
|
328
|
+
raise (
|
|
329
|
+
ExceptionGroup("Failed to initialize dependencies", initialize_deps_exceptions)
|
|
330
|
+
if len(initialize_deps_exceptions) > 1
|
|
331
|
+
else initialize_deps_exceptions[0]
|
|
332
|
+
)
|
|
325
333
|
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
context: RequestContext,
|
|
330
|
-
context_store: ContextStore,
|
|
331
|
-
task_updater: TaskUpdater,
|
|
332
|
-
resume_queue: EventQueue,
|
|
333
|
-
) -> None:
|
|
334
|
-
current_task = asyncio.current_task()
|
|
335
|
-
assert current_task
|
|
336
|
-
cancellation_task = asyncio.create_task(self._watch_for_cancellation(task_updater.task_id, current_task))
|
|
337
|
-
|
|
338
|
-
def with_context(message: Message | None = None) -> Message | None:
|
|
339
|
-
if message is None:
|
|
340
|
-
return None
|
|
341
|
-
# Note: This check would require extra handling in agents just forwarding messages from other agents
|
|
342
|
-
# Instead, we just silently replace it.
|
|
343
|
-
# if message.task_id and message.task_id != task_updater.task_id:
|
|
344
|
-
# raise ValueError("Message must have the same task_id as the task")
|
|
345
|
-
# if message.context_id and message.context_id != task_updater.context_id:
|
|
346
|
-
# raise ValueError("Message must have the same context_id as the task")
|
|
347
|
-
return message.model_copy(
|
|
348
|
-
deep=True, update={"context_id": task_updater.context_id, "task_id": task_updater.task_id}
|
|
334
|
+
self.run_context._store = await self._context_store.create( # pyright: ignore[reportPrivateUsage]
|
|
335
|
+
context_id=self.run_context.context_id,
|
|
336
|
+
initialized_dependencies=list(dependency_args.values()),
|
|
349
337
|
)
|
|
350
338
|
|
|
351
|
-
|
|
339
|
+
yield {k: v for k, v in dependency_args.items() if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)}
|
|
340
|
+
|
|
341
|
+
def _with_context(self, message: Message | None = None) -> Message | None:
|
|
342
|
+
if message is None:
|
|
343
|
+
return None
|
|
344
|
+
# Note: This check would require extra handling in agents just forwarding messages from other agents
|
|
345
|
+
# Instead, we just silently replace it.
|
|
346
|
+
# if message.task_id and message.task_id != task_updater.task_id:
|
|
347
|
+
# raise ValueError("Message must have the same task_id as the task")
|
|
348
|
+
# if message.context_id and message.context_id != task_updater.context_id:
|
|
349
|
+
# raise ValueError("Message must have the same context_id as the task")
|
|
350
|
+
return message.model_copy(
|
|
351
|
+
deep=True, update={"context_id": self.task_updater.context_id, "task_id": self.task_updater.task_id}
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
async def _run_agent_function(self, initial_message: Message) -> None:
|
|
355
|
+
yield_queue = self.run_context._yield_queue # pyright: ignore[reportPrivateUsage]
|
|
356
|
+
yield_resume_queue = self.run_context._yield_resume_queue # pyright: ignore[reportPrivateUsage]
|
|
357
|
+
|
|
352
358
|
try:
|
|
353
|
-
async with self.
|
|
359
|
+
async with self._dependencies_lifespan(initial_message) as dependency_args:
|
|
360
|
+
task = asyncio.create_task(self._agent.execute_fn(self.run_context, **dependency_args))
|
|
354
361
|
try:
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
await task_updater.start_work()
|
|
358
|
-
value: RunYieldResume = None
|
|
362
|
+
resume_value: RunYieldResume = None
|
|
359
363
|
opened_artifacts: set[str] = set()
|
|
360
|
-
while
|
|
361
|
-
|
|
362
|
-
self._running_tasks[task_updater.task_id]["last_invocation"] = datetime.now()
|
|
364
|
+
while not task.done() or yield_queue.async_q.qsize() > 0:
|
|
365
|
+
yielded_value = await yield_queue.async_q.get()
|
|
363
366
|
|
|
364
|
-
|
|
367
|
+
self.last_invocation = datetime.now()
|
|
365
368
|
|
|
366
369
|
match yielded_value:
|
|
367
370
|
case str(text):
|
|
368
|
-
await task_updater.update_status(
|
|
371
|
+
await self.task_updater.update_status(
|
|
369
372
|
TaskState.working,
|
|
370
|
-
message=task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
|
|
373
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
|
|
371
374
|
)
|
|
372
375
|
case Part(root=part) | (TextPart() | FilePart() | DataPart() as part):
|
|
373
|
-
await task_updater.update_status(
|
|
376
|
+
await self.task_updater.update_status(
|
|
374
377
|
TaskState.working,
|
|
375
|
-
message=task_updater.new_agent_message(parts=[Part(root=part)]),
|
|
378
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=part)]),
|
|
376
379
|
)
|
|
377
380
|
case FileWithBytes() | FileWithUri() as file:
|
|
378
|
-
await task_updater.update_status(
|
|
381
|
+
await self.task_updater.update_status(
|
|
379
382
|
TaskState.working,
|
|
380
|
-
message=task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
|
|
383
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
|
|
381
384
|
)
|
|
382
385
|
case Message() as message:
|
|
383
|
-
await task_updater.update_status(
|
|
386
|
+
await self.task_updater.update_status(
|
|
387
|
+
TaskState.working, message=self._with_context(message)
|
|
388
|
+
)
|
|
384
389
|
case ArtifactChunk(
|
|
385
390
|
parts=parts,
|
|
386
391
|
artifact_id=artifact_id,
|
|
@@ -388,7 +393,7 @@ class Executor(AgentExecutor):
|
|
|
388
393
|
metadata=metadata,
|
|
389
394
|
last_chunk=last_chunk,
|
|
390
395
|
):
|
|
391
|
-
await task_updater.add_artifact(
|
|
396
|
+
await self.task_updater.add_artifact(
|
|
392
397
|
parts=cast(list[Part], parts),
|
|
393
398
|
artifact_id=artifact_id,
|
|
394
399
|
name=name,
|
|
@@ -398,7 +403,7 @@ class Executor(AgentExecutor):
|
|
|
398
403
|
)
|
|
399
404
|
opened_artifacts.add(artifact_id)
|
|
400
405
|
case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata):
|
|
401
|
-
await task_updater.add_artifact(
|
|
406
|
+
await self.task_updater.add_artifact(
|
|
402
407
|
parts=parts,
|
|
403
408
|
artifact_id=artifact_id,
|
|
404
409
|
name=name,
|
|
@@ -406,28 +411,29 @@ class Executor(AgentExecutor):
|
|
|
406
411
|
last_chunk=True,
|
|
407
412
|
append=False,
|
|
408
413
|
)
|
|
409
|
-
case TaskStatus(
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
414
|
+
case TaskStatus(
|
|
415
|
+
state=(TaskState.auth_required | TaskState.input_required) as state,
|
|
416
|
+
message=message,
|
|
417
|
+
timestamp=timestamp,
|
|
418
|
+
):
|
|
419
|
+
await self.task_updater.update_status(
|
|
420
|
+
state=state, message=self._with_context(message), final=True, timestamp=timestamp
|
|
421
|
+
)
|
|
422
|
+
self._working = False
|
|
423
|
+
resume_value = await self.resume_queue.get()
|
|
424
|
+
self.resume_queue.task_done()
|
|
419
425
|
case TaskStatus(state=state, message=message, timestamp=timestamp):
|
|
420
|
-
await task_updater.update_status(
|
|
421
|
-
state=state, message=
|
|
426
|
+
await self.task_updater.update_status(
|
|
427
|
+
state=state, message=self._with_context(message), timestamp=timestamp
|
|
422
428
|
)
|
|
423
429
|
case TaskStatusUpdateEvent(
|
|
424
430
|
status=TaskStatus(state=state, message=message, timestamp=timestamp),
|
|
425
431
|
final=final,
|
|
426
432
|
metadata=metadata,
|
|
427
433
|
):
|
|
428
|
-
await task_updater.update_status(
|
|
434
|
+
await self.task_updater.update_status(
|
|
429
435
|
state=state,
|
|
430
|
-
message=
|
|
436
|
+
message=self._with_context(message),
|
|
431
437
|
timestamp=timestamp,
|
|
432
438
|
final=final,
|
|
433
439
|
metadata=metadata,
|
|
@@ -437,7 +443,7 @@ class Executor(AgentExecutor):
|
|
|
437
443
|
append=append,
|
|
438
444
|
last_chunk=last_chunk,
|
|
439
445
|
):
|
|
440
|
-
await task_updater.add_artifact(
|
|
446
|
+
await self.task_updater.add_artifact(
|
|
441
447
|
parts=parts,
|
|
442
448
|
artifact_id=artifact_id,
|
|
443
449
|
name=name,
|
|
@@ -446,128 +452,133 @@ class Executor(AgentExecutor):
|
|
|
446
452
|
last_chunk=last_chunk,
|
|
447
453
|
)
|
|
448
454
|
case Metadata() as metadata:
|
|
449
|
-
await task_updater.update_status(
|
|
455
|
+
await self.task_updater.update_status(
|
|
450
456
|
state=TaskState.working,
|
|
451
|
-
message=task_updater.new_agent_message(parts=[], metadata=metadata),
|
|
457
|
+
message=self.task_updater.new_agent_message(parts=[], metadata=metadata),
|
|
452
458
|
)
|
|
453
459
|
case dict() as data:
|
|
454
|
-
await task_updater.update_status(
|
|
460
|
+
await self.task_updater.update_status(
|
|
455
461
|
state=TaskState.working,
|
|
456
|
-
message=task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
|
|
462
|
+
message=self.task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
|
|
457
463
|
)
|
|
458
464
|
case Exception() as ex:
|
|
459
465
|
raise ex
|
|
460
466
|
case _:
|
|
461
467
|
raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}")
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
468
|
+
|
|
469
|
+
await yield_resume_queue.async_q.put(resume_value)
|
|
470
|
+
|
|
471
|
+
await self.task_updater.complete()
|
|
472
|
+
|
|
473
|
+
except (janus.AsyncQueueShutDown, GeneratorExit):
|
|
474
|
+
await self.task_updater.complete()
|
|
467
475
|
except Exception as ex:
|
|
468
476
|
logger.error("Error when executing agent", exc_info=ex)
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
await cancel_task(cancellation_task)
|
|
480
|
-
is_cancelling = bool(current_task.cancelling())
|
|
481
|
-
try:
|
|
482
|
-
async with asyncio.timeout(10): # grace period to read all events from queue
|
|
483
|
-
await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=is_cancelling)
|
|
484
|
-
await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=is_cancelling)
|
|
485
|
-
except (TimeoutError, CancelledError):
|
|
486
|
-
await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=True)
|
|
487
|
-
await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=True)
|
|
477
|
+
await self.task_updater.failed(get_error_extension_context().server.message(ex))
|
|
478
|
+
await cancel_task(task)
|
|
479
|
+
except Exception as ex:
|
|
480
|
+
logger.error("Error when executing agent", exc_info=ex)
|
|
481
|
+
await self.task_updater.failed(get_error_extension_context().server.message(ex))
|
|
482
|
+
finally:
|
|
483
|
+
self._working = False
|
|
484
|
+
with suppress(Exception):
|
|
485
|
+
self._handle_finish()
|
|
486
|
+
|
|
488
487
|
|
|
488
|
+
class Executor(AgentExecutor):
|
|
489
|
+
def __init__(
|
|
490
|
+
self,
|
|
491
|
+
agent: Agent,
|
|
492
|
+
queue_manager: QueueManager,
|
|
493
|
+
context_store: ContextStore,
|
|
494
|
+
task_timeout: timedelta,
|
|
495
|
+
task_store: TaskStore,
|
|
496
|
+
) -> None:
|
|
497
|
+
self._agent: Agent = agent
|
|
498
|
+
self._running_tasks: dict[str, AgentRun] = {}
|
|
499
|
+
self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {}
|
|
500
|
+
self._context_store: ContextStore = context_store
|
|
501
|
+
self._task_timeout: timedelta = task_timeout
|
|
502
|
+
self._task_store: TaskStore = task_store
|
|
503
|
+
|
|
504
|
+
@override
|
|
489
505
|
async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
assert
|
|
493
|
-
|
|
506
|
+
# this is only executed in the context of SendMessage request
|
|
507
|
+
message, task_id, context_id = context.message, context.task_id, context.context_id
|
|
508
|
+
assert message and task_id and context_id
|
|
509
|
+
agent_run: AgentRun | None = None
|
|
494
510
|
try:
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
resume_queue = await self._queue_manager.create_or_tap(task_id=f"_resume_{context.task_id}")
|
|
503
|
-
|
|
504
|
-
if not (long_running_event_queue := await self._queue_manager.get(task_id=f"_event_{context.task_id}")):
|
|
505
|
-
long_running_event_queue = await self._queue_manager.create_or_tap(task_id=f"_event_{context.task_id}")
|
|
506
|
-
|
|
507
|
-
if current_status in {TaskState.input_required, TaskState.auth_required}:
|
|
508
|
-
await resume_queue.enqueue_event(context.message)
|
|
511
|
+
if not context.current_task:
|
|
512
|
+
agent_run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id))
|
|
513
|
+
self._running_tasks[task_id] = agent_run
|
|
514
|
+
await self._schedule_run_cleanup(request_context=context)
|
|
515
|
+
await agent_run.start(request_context=context, event_queue=event_queue)
|
|
516
|
+
elif agent_run := self._running_tasks.get(task_id):
|
|
517
|
+
await agent_run.resume(request_context=context, event_queue=event_queue)
|
|
509
518
|
else:
|
|
510
|
-
|
|
511
|
-
run_generator = self._run_agent_function(
|
|
512
|
-
context=context,
|
|
513
|
-
context_store=self._context_store,
|
|
514
|
-
task_updater=task_updater,
|
|
515
|
-
resume_queue=resume_queue,
|
|
516
|
-
)
|
|
517
|
-
|
|
518
|
-
self._running_tasks[context.task_id] = RunningTask(
|
|
519
|
-
task=asyncio.create_task(run_generator), last_invocation=datetime.now()
|
|
520
|
-
)
|
|
521
|
-
asyncio.create_task(
|
|
522
|
-
self._schedule_run_cleanup(task_id=context.task_id, task_timeout=self._task_timeout)
|
|
523
|
-
).add_done_callback(lambda _: ...)
|
|
519
|
+
raise self._run_not_found_error(task_id)
|
|
524
520
|
|
|
521
|
+
# will run until complete or next input/auth required task state
|
|
522
|
+
tapped_queue = event_queue.tap()
|
|
525
523
|
while True:
|
|
526
|
-
|
|
527
|
-
event = await long_running_event_queue.dequeue_event()
|
|
528
|
-
long_running_event_queue.task_done()
|
|
529
|
-
await event_queue.enqueue_event(event)
|
|
530
|
-
match event:
|
|
524
|
+
match await tapped_queue.dequeue_event():
|
|
531
525
|
case TaskStatusUpdateEvent(final=True):
|
|
532
526
|
break
|
|
533
|
-
except CancelledError:
|
|
534
|
-
# Handles cancellation of this handler:
|
|
535
|
-
# When a streaming request is canceled, this executor is canceled first meaning that "cancellation" event
|
|
536
|
-
# passed from the agent's long_running_event_queue is not forwarded. Instead of shielding this function,
|
|
537
|
-
# we report the cancellation explicitly
|
|
538
|
-
await self._cancel_task(context.task_id)
|
|
539
|
-
local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
|
|
540
|
-
await local_updater.cancel()
|
|
541
|
-
except Exception as ex:
|
|
542
|
-
logger.error("Error executing agent", exc_info=ex)
|
|
543
|
-
local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
|
|
544
|
-
await local_updater.failed(local_updater.new_agent_message(parts=[Part(root=TextPart(text=str(ex)))]))
|
|
545
|
-
|
|
546
|
-
async def _cancel_task(self, task_id: str):
|
|
547
|
-
if queue := self._cancel_queues.get(task_id):
|
|
548
|
-
await queue.enqueue_event(create_text_message_object(content="canceled"))
|
|
549
|
-
|
|
550
|
-
async def _schedule_run_cleanup(self, task_id: str, task_timeout: timedelta):
|
|
551
|
-
task = self._running_tasks.get(task_id)
|
|
552
|
-
assert task
|
|
553
527
|
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
await
|
|
557
|
-
if not task["task"].done() and task["last_invocation"] + task_timeout < datetime.now():
|
|
558
|
-
# Task might be stuck waiting for queue events to be processed
|
|
559
|
-
logger.warning(f"Task {task_id} did not finish in {task_timeout}")
|
|
560
|
-
await self._cancel_task(task_id)
|
|
561
|
-
break
|
|
528
|
+
except CancelledError:
|
|
529
|
+
if agent_run:
|
|
530
|
+
await agent_run.cancel(request_context=context, event_queue=event_queue)
|
|
562
531
|
except Exception as ex:
|
|
563
|
-
logger.error("
|
|
564
|
-
finally:
|
|
565
|
-
self._running_tasks.pop(task_id)
|
|
532
|
+
logger.error("Unhandled error when executing agent:", exc_info=ex)
|
|
566
533
|
|
|
534
|
+
@override
|
|
567
535
|
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
|
|
568
536
|
if not context.task_id or not context.context_id:
|
|
569
537
|
raise ValueError("Task ID and context ID must be set to cancel a task")
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
538
|
+
if not (run := self._running_tasks.get(context.task_id)):
|
|
539
|
+
raise self._run_not_found_error(context.task_id)
|
|
540
|
+
await run.cancel(context, event_queue)
|
|
541
|
+
|
|
542
|
+
def _handle_finish(self, task_id: str) -> None:
|
|
543
|
+
if task := self._scheduled_cleanups.pop(task_id, None):
|
|
544
|
+
task.cancel()
|
|
545
|
+
self._running_tasks.pop(task_id, None)
|
|
546
|
+
|
|
547
|
+
def _run_not_found_error(self, task_id: str | None) -> Exception:
|
|
548
|
+
return RuntimeError(
|
|
549
|
+
f"Run for task ID {task_id} not found. "
|
|
550
|
+
+ "It may be on another replica, make sure to enable sticky sessions in your load balancer"
|
|
551
|
+
)
|
|
552
|
+
|
|
553
|
+
async def _schedule_run_cleanup(self, request_context: RequestContext):
|
|
554
|
+
task_id, context_id = request_context.task_id, request_context.context_id
|
|
555
|
+
assert task_id and context_id
|
|
556
|
+
|
|
557
|
+
async def cleanup_fn():
|
|
558
|
+
await asyncio.sleep(self._task_timeout.total_seconds())
|
|
559
|
+
if not (run := self._running_tasks.get(task_id)):
|
|
560
|
+
return
|
|
561
|
+
try:
|
|
562
|
+
while not run.done:
|
|
563
|
+
if run.last_invocation + self._task_timeout < datetime.now():
|
|
564
|
+
logger.warning(f"Task {task_id} did not finish in {self._task_timeout}")
|
|
565
|
+
queue = EventQueue()
|
|
566
|
+
await run.cancel(request_context=request_context, event_queue=queue)
|
|
567
|
+
# the original request queue is closed at this point, we need to propagate state to store manually
|
|
568
|
+
manager = TaskManager(
|
|
569
|
+
task_id=task_id, context_id=context_id, task_store=self._task_store, initial_message=None
|
|
570
|
+
)
|
|
571
|
+
event = await queue.dequeue_event(no_wait=True)
|
|
572
|
+
if not isinstance(event, TaskStatusUpdateEvent) or event.status.state != TaskState.canceled:
|
|
573
|
+
raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
|
|
574
|
+
_ = await manager.save_task_event(event)
|
|
575
|
+
break
|
|
576
|
+
await asyncio.sleep(2)
|
|
577
|
+
except Exception as ex:
|
|
578
|
+
logger.error("Error when cleaning up task", exc_info=ex)
|
|
579
|
+
finally:
|
|
580
|
+
_ = self._running_tasks.pop(task_id, None)
|
|
581
|
+
_ = self._scheduled_cleanups.pop(task_id, None)
|
|
582
|
+
|
|
583
|
+
self._scheduled_cleanups[task_id] = asyncio.create_task(cleanup_fn())
|
|
584
|
+
self._scheduled_cleanups[task_id].add_done_callback(lambda _: ...)
|