agentstack-sdk 0.4.3rc2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,17 +3,17 @@
3
3
 
4
4
  import asyncio
5
5
  import inspect
6
+ import typing
6
7
  from asyncio import CancelledError
7
8
  from collections.abc import AsyncGenerator, AsyncIterator, Callable, Generator
8
- from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
9
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, suppress
9
10
  from datetime import datetime, timedelta
10
- from typing import NamedTuple, TypeAlias, TypedDict, cast
11
+ from typing import Any, NamedTuple, TypeAlias, TypeVar, cast
11
12
 
12
13
  import janus
13
- from a2a.client import create_text_message_object
14
14
  from a2a.server.agent_execution import AgentExecutor, RequestContext
15
15
  from a2a.server.events import EventQueue, QueueManager
16
- from a2a.server.tasks import TaskUpdater
16
+ from a2a.server.tasks import TaskManager, TaskStore, TaskUpdater
17
17
  from a2a.types import (
18
18
  AgentCapabilities,
19
19
  AgentCard,
@@ -34,29 +34,40 @@ from a2a.types import (
34
34
  TaskStatusUpdateEvent,
35
35
  TextPart,
36
36
  )
37
+ from typing_extensions import override
37
38
 
38
39
  from agentstack_sdk.a2a.extensions.ui.agent_detail import AgentDetail, AgentDetailExtensionSpec
39
- from agentstack_sdk.a2a.extensions.ui.error import ErrorExtensionParams, ErrorExtensionServer, ErrorExtensionSpec
40
- from agentstack_sdk.a2a.types import AgentMessage, ArtifactChunk, Metadata, RunYield, RunYieldResume
41
- from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX, DEFAULT_ERROR_EXTENSION
40
+ from agentstack_sdk.a2a.extensions.ui.error import (
41
+ ErrorExtensionParams,
42
+ ErrorExtensionServer,
43
+ ErrorExtensionSpec,
44
+ get_error_extension_context,
45
+ )
46
+ from agentstack_sdk.a2a.types import ArtifactChunk, Metadata, RunYield, RunYieldResume
47
+ from agentstack_sdk.server.constants import _IMPLICIT_DEPENDENCY_PREFIX
42
48
  from agentstack_sdk.server.context import RunContext
43
- from agentstack_sdk.server.dependencies import extract_dependencies
49
+ from agentstack_sdk.server.dependencies import Depends, extract_dependencies
44
50
  from agentstack_sdk.server.store.context_store import ContextStore
45
- from agentstack_sdk.server.utils import cancel_task, close_queue
51
+ from agentstack_sdk.server.utils import cancel_task
46
52
  from agentstack_sdk.util.logging import logger
47
53
 
48
54
  AgentFunction: TypeAlias = Callable[[], AsyncGenerator[RunYield, RunYieldResume]]
49
- AgentFunctionFactory: TypeAlias = Callable[
50
- [TaskUpdater, RequestContext, ContextStore], AbstractAsyncContextManager[tuple[AgentFunction, RunContext]]
51
- ]
55
+ AgentFunctionFactory: TypeAlias = Callable[[RequestContext, ContextStore], AbstractAsyncContextManager[AgentFunction]]
56
+
57
+ OriginalFnType = TypeVar("OriginalFnType", bound=Callable[..., Any]) # pyright: ignore[reportExplicitAny]
58
+
59
+
60
+ class AgentExecuteFn(typing.Protocol):
61
+ async def __call__(self, _ctx: RunContext, **kwargs: Any) -> None: ...
52
62
 
53
63
 
54
64
  class Agent(NamedTuple):
55
65
  card: AgentCard
56
- execute: AgentFunctionFactory
66
+ dependencies: dict[str, Depends]
67
+ execute_fn: AgentExecuteFn
57
68
 
58
69
 
59
- AgentFactory: TypeAlias = Callable[[ContextStore], Agent]
70
+ AgentFactory: TypeAlias = Callable[[Callable[[dict[str, Depends]], None]], Agent]
60
71
 
61
72
 
62
73
  def agent(
@@ -78,7 +89,7 @@ def agent(
78
89
  skills: list[AgentSkill] | None = None,
79
90
  supports_authenticated_extended_card: bool | None = None,
80
91
  version: str | None = None,
81
- ) -> Callable[[Callable], AgentFactory]:
92
+ ) -> Callable[[OriginalFnType], AgentFactory]:
82
93
  """
83
94
  Create an Agent function.
84
95
 
@@ -112,11 +123,11 @@ def agent(
112
123
  capabilities = capabilities.model_copy(deep=True) if capabilities else AgentCapabilities(streaming=True)
113
124
  detail = detail or AgentDetail() # pyright: ignore [reportCallIssue]
114
125
 
115
- def decorator(fn: Callable) -> AgentFactory:
116
- def agent_factory(context_store: ContextStore):
126
+ def decorator(fn: OriginalFnType) -> AgentFactory:
127
+ def agent_factory(modify_dependencies: Callable[[dict[str, Depends]], None]):
117
128
  signature = inspect.signature(fn)
118
129
  dependencies = extract_dependencies(signature)
119
- context_store.modify_dependencies(dependencies)
130
+ modify_dependencies(dependencies)
120
131
 
121
132
  sdk_extensions = [dep.extension for dep in dependencies.values() if dep.extension is not None]
122
133
 
@@ -209,178 +220,172 @@ def agent(
209
220
  async def execute_fn(_ctx: RunContext, *args, **kwargs) -> None:
210
221
  await asyncio.to_thread(_execute_fn_sync, _ctx, *args, **kwargs)
211
222
 
212
- @asynccontextmanager
213
- async def agent_executor_lifespan(
214
- task_updater: TaskUpdater, request_context: RequestContext, context_store: ContextStore
215
- ) -> AsyncIterator[tuple[AgentFunction, RunContext]]:
216
- message = request_context.message
217
- assert message # this is only executed in the context of SendMessage request
218
- # These are incorrectly typed in a2a
219
- assert request_context.task_id
220
- assert request_context.context_id
221
- context = RunContext(
222
- configuration=request_context.configuration,
223
- context_id=request_context.context_id,
224
- task_id=request_context.task_id,
225
- task_updater=task_updater,
226
- current_task=request_context.current_task,
227
- related_tasks=request_context.related_tasks,
228
- call_context=request_context.call_context,
229
- )
230
-
231
- async with AsyncExitStack() as stack:
232
- dependency_args = {}
233
- for pname, depends in dependencies.items():
234
- # call dependencies with the first message and initialize their lifespan
235
- dependency_args[pname] = await stack.enter_async_context(
236
- depends(message, context, dependency_args)
237
- )
238
-
239
- context._error_extension = next(
240
- (ext for ext in dependency_args.values() if isinstance(ext, ErrorExtensionServer)),
241
- DEFAULT_ERROR_EXTENSION,
242
- )
243
-
244
- context._store = await context_store.create(
245
- context_id=request_context.context_id,
246
- initialized_dependencies=list(dependency_args.values()),
247
- )
248
-
249
- async def agent_generator():
250
- yield_queue = context._yield_queue
251
- yield_resume_queue = context._yield_resume_queue
252
-
253
- task = asyncio.create_task(
254
- execute_fn(
255
- context,
256
- **{
257
- k: v
258
- for k, v in dependency_args.items()
259
- if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)
260
- },
261
- )
262
- )
263
- try:
264
- while not task.done() or yield_queue.async_q.qsize() > 0:
265
- value = yield await yield_queue.async_q.get()
266
- if isinstance(value, Exception):
267
- raise value
268
-
269
- if value:
270
- # TODO: context.call_context should be updated here
271
- # Unfortunately queue implementation does not support passing external types
272
- # (only a2a.event_queue.Event is supported:
273
- # Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
274
- for ext in sdk_extensions:
275
- ext.handle_incoming_message(value, context)
276
-
277
- await yield_resume_queue.async_q.put(value)
278
- except janus.AsyncQueueShutDown:
279
- pass
280
- except GeneratorExit:
281
- return
282
- finally:
283
- await cancel_task(task)
284
-
285
- yield agent_generator, context
286
-
287
- return Agent(card=card, execute=agent_executor_lifespan)
223
+ return Agent(card=card, dependencies=dependencies, execute_fn=execute_fn)
288
224
 
289
225
  return agent_factory
290
226
 
291
227
  return decorator
292
228
 
293
229
 
294
- class RunningTask(TypedDict):
295
- task: asyncio.Task
296
- last_invocation: datetime
230
+ class AgentRun:
231
+ def __init__(self, agent: Agent, context_store: ContextStore, on_finish: Callable[[], None] | None = None) -> None:
232
+ self._agent: Agent = agent
233
+ self._task: asyncio.Task[None] | None = None
234
+ self.last_invocation: datetime = datetime.now()
235
+ self.resume_queue: asyncio.Queue[RunYieldResume] = asyncio.Queue()
236
+ self._run_context: RunContext | None = None
237
+ self._task_updater: TaskUpdater | None = None
238
+ self._context_store: ContextStore = context_store
239
+ self._lock: asyncio.Lock = asyncio.Lock()
240
+ self._on_finish: Callable[[], None] | None = on_finish
241
+ self._working: bool = False
242
+
243
+ @property
244
+ def run_context(self) -> RunContext:
245
+ if not self._run_context:
246
+ raise RuntimeError("Accessing run context for run that has not been started")
247
+ return self._run_context
248
+
249
+ @property
250
+ def task_updater(self) -> TaskUpdater:
251
+ if not self._task_updater:
252
+ raise RuntimeError("Accessing task updater for run that has not been started")
253
+ return self._task_updater
254
+
255
+ @property
256
+ def done(self) -> bool:
257
+ return self._task is not None and self._task.done()
258
+
259
+ def _handle_finish(self) -> None:
260
+ if self._on_finish:
261
+ self._on_finish()
262
+
263
+ async def start(self, request_context: RequestContext, event_queue: EventQueue):
264
+ # These are incorrectly typed in a2a
265
+ async with self._lock:
266
+ if self._working or self.done:
267
+ raise RuntimeError("Attempting to start a run that is already executing or done")
268
+ task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
269
+ assert task_id and context_id and message
270
+ self._run_context = RunContext(
271
+ configuration=request_context.configuration,
272
+ context_id=context_id,
273
+ task_id=task_id,
274
+ current_task=request_context.current_task,
275
+ related_tasks=request_context.related_tasks,
276
+ )
277
+ self._task_updater = TaskUpdater(event_queue, task_id, context_id)
278
+ if not request_context.current_task:
279
+ await self._task_updater.submit()
280
+ await self._task_updater.start_work()
281
+ self._working = True
282
+ self._task = asyncio.create_task(self._run_agent_function(initial_message=message))
283
+
284
+ async def resume(self, request_context: RequestContext, event_queue: EventQueue):
285
+ # These are incorrectly typed in a2a
286
+ async with self._lock:
287
+ if self._working or self.done:
288
+ raise RuntimeError("Attempting to resume a run that is already executing or done")
289
+ task_id, context_id, message = request_context.task_id, request_context.context_id, request_context.message
290
+ assert task_id and context_id and message
291
+ self._task_updater = TaskUpdater(event_queue, task_id, context_id)
297
292
 
293
+ for dependency in self._agent.dependencies.values():
294
+ if dependency.extension:
295
+ dependency.extension.handle_incoming_message(message, self.run_context)
298
296
 
299
- class Executor(AgentExecutor):
300
- def __init__(
301
- self,
302
- execute_fn: AgentFunctionFactory,
303
- queue_manager: QueueManager,
304
- context_store: ContextStore,
305
- task_timeout: timedelta,
306
- ) -> None:
307
- self._agent_executor_span = execute_fn
308
- self._queue_manager = queue_manager
309
- self._running_tasks: dict[str, RunningTask] = {}
310
- self._cancel_queues: dict[str, EventQueue] = {}
311
- self._context_store = context_store
312
- self._task_timeout = task_timeout
297
+ self._working = True
298
+ await self.resume_queue.put(message)
313
299
 
314
- async def _watch_for_cancellation(self, task_id: str, task: asyncio.Task) -> None:
315
- cancel_queue = await self._queue_manager.create_or_tap(f"_cancel_{task_id}")
316
- self._cancel_queues[task_id] = cancel_queue
300
+ async def cancel(self, request_context: RequestContext, event_queue: EventQueue):
301
+ if not self._task:
302
+ raise RuntimeError("Cannot cancel run that has not been started")
317
303
 
318
- try:
319
- await cancel_queue.dequeue_event()
320
- cancel_queue.task_done()
321
- task.cancel()
322
- finally:
323
- await self._queue_manager.close(f"_cancel_{task_id}")
324
- self._cancel_queues.pop(task_id)
304
+ async with self._lock:
305
+ try:
306
+ assert request_context.task_id
307
+ assert request_context.context_id
308
+ self._task_updater = TaskUpdater(event_queue, request_context.task_id, request_context.context_id)
309
+ await self._task_updater.cancel()
310
+ finally:
311
+ await cancel_task(self._task)
312
+
313
+ @asynccontextmanager
314
+ async def _dependencies_lifespan(self, message: Message) -> AsyncIterator[dict[str, Depends]]:
315
+ async with AsyncExitStack() as stack:
316
+ dependency_args: dict[str, Depends] = {}
317
+ initialize_deps_exceptions: list[Exception] = []
318
+ for pname, depends in self._agent.dependencies.items():
319
+ # call dependencies with the first message and initialize their lifespan
320
+ try:
321
+ dependency_args[pname] = await stack.enter_async_context(
322
+ depends(message, self.run_context, dependency_args)
323
+ )
324
+ except Exception as e:
325
+ initialize_deps_exceptions.append(e)
326
+
327
+ if initialize_deps_exceptions:
328
+ raise (
329
+ ExceptionGroup("Failed to initialize dependencies", initialize_deps_exceptions)
330
+ if len(initialize_deps_exceptions) > 1
331
+ else initialize_deps_exceptions[0]
332
+ )
325
333
 
326
- async def _run_agent_function(
327
- self,
328
- *,
329
- context: RequestContext,
330
- context_store: ContextStore,
331
- task_updater: TaskUpdater,
332
- resume_queue: EventQueue,
333
- ) -> None:
334
- current_task = asyncio.current_task()
335
- assert current_task
336
- cancellation_task = asyncio.create_task(self._watch_for_cancellation(task_updater.task_id, current_task))
337
-
338
- def with_context(message: Message | None = None) -> Message | None:
339
- if message is None:
340
- return None
341
- # Note: This check would require extra handling in agents just forwarding messages from other agents
342
- # Instead, we just silently replace it.
343
- # if message.task_id and message.task_id != task_updater.task_id:
344
- # raise ValueError("Message must have the same task_id as the task")
345
- # if message.context_id and message.context_id != task_updater.context_id:
346
- # raise ValueError("Message must have the same context_id as the task")
347
- return message.model_copy(
348
- deep=True, update={"context_id": task_updater.context_id, "task_id": task_updater.task_id}
334
+ self.run_context._store = await self._context_store.create( # pyright: ignore[reportPrivateUsage]
335
+ context_id=self.run_context.context_id,
336
+ initialized_dependencies=list(dependency_args.values()),
349
337
  )
350
338
 
351
- run_context: RunContext | None = None
339
+ yield {k: v for k, v in dependency_args.items() if not k.startswith(_IMPLICIT_DEPENDENCY_PREFIX)}
340
+
341
+ def _with_context(self, message: Message | None = None) -> Message | None:
342
+ if message is None:
343
+ return None
344
+ # Note: This check would require extra handling in agents just forwarding messages from other agents
345
+ # Instead, we just silently replace it.
346
+ # if message.task_id and message.task_id != task_updater.task_id:
347
+ # raise ValueError("Message must have the same task_id as the task")
348
+ # if message.context_id and message.context_id != task_updater.context_id:
349
+ # raise ValueError("Message must have the same context_id as the task")
350
+ return message.model_copy(
351
+ deep=True, update={"context_id": self.task_updater.context_id, "task_id": self.task_updater.task_id}
352
+ )
353
+
354
+ async def _run_agent_function(self, initial_message: Message) -> None:
355
+ yield_queue = self.run_context._yield_queue # pyright: ignore[reportPrivateUsage]
356
+ yield_resume_queue = self.run_context._yield_resume_queue # pyright: ignore[reportPrivateUsage]
357
+
352
358
  try:
353
- async with self._agent_executor_span(task_updater, context, context_store) as (execute_fn, run_context):
359
+ async with self._dependencies_lifespan(initial_message) as dependency_args:
360
+ task = asyncio.create_task(self._agent.execute_fn(self.run_context, **dependency_args))
354
361
  try:
355
- agent_generator_fn = execute_fn()
356
-
357
- await task_updater.start_work()
358
- value: RunYieldResume = None
362
+ resume_value: RunYieldResume = None
359
363
  opened_artifacts: set[str] = set()
360
- while True:
361
- # update invocation time
362
- self._running_tasks[task_updater.task_id]["last_invocation"] = datetime.now()
364
+ while not task.done() or yield_queue.async_q.qsize() > 0:
365
+ yielded_value = await yield_queue.async_q.get()
363
366
 
364
- yielded_value = await agent_generator_fn.asend(value)
367
+ self.last_invocation = datetime.now()
365
368
 
366
369
  match yielded_value:
367
370
  case str(text):
368
- await task_updater.update_status(
371
+ await self.task_updater.update_status(
369
372
  TaskState.working,
370
- message=task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
373
+ message=self.task_updater.new_agent_message(parts=[Part(root=TextPart(text=text))]),
371
374
  )
372
375
  case Part(root=part) | (TextPart() | FilePart() | DataPart() as part):
373
- await task_updater.update_status(
376
+ await self.task_updater.update_status(
374
377
  TaskState.working,
375
- message=task_updater.new_agent_message(parts=[Part(root=part)]),
378
+ message=self.task_updater.new_agent_message(parts=[Part(root=part)]),
376
379
  )
377
380
  case FileWithBytes() | FileWithUri() as file:
378
- await task_updater.update_status(
381
+ await self.task_updater.update_status(
379
382
  TaskState.working,
380
- message=task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
383
+ message=self.task_updater.new_agent_message(parts=[Part(root=FilePart(file=file))]),
381
384
  )
382
385
  case Message() as message:
383
- await task_updater.update_status(TaskState.working, message=with_context(message))
386
+ await self.task_updater.update_status(
387
+ TaskState.working, message=self._with_context(message)
388
+ )
384
389
  case ArtifactChunk(
385
390
  parts=parts,
386
391
  artifact_id=artifact_id,
@@ -388,7 +393,7 @@ class Executor(AgentExecutor):
388
393
  metadata=metadata,
389
394
  last_chunk=last_chunk,
390
395
  ):
391
- await task_updater.add_artifact(
396
+ await self.task_updater.add_artifact(
392
397
  parts=cast(list[Part], parts),
393
398
  artifact_id=artifact_id,
394
399
  name=name,
@@ -398,7 +403,7 @@ class Executor(AgentExecutor):
398
403
  )
399
404
  opened_artifacts.add(artifact_id)
400
405
  case Artifact(parts=parts, artifact_id=artifact_id, name=name, metadata=metadata):
401
- await task_updater.add_artifact(
406
+ await self.task_updater.add_artifact(
402
407
  parts=parts,
403
408
  artifact_id=artifact_id,
404
409
  name=name,
@@ -406,28 +411,29 @@ class Executor(AgentExecutor):
406
411
  last_chunk=True,
407
412
  append=False,
408
413
  )
409
- case TaskStatus(state=TaskState.input_required, message=message, timestamp=timestamp):
410
- await task_updater.requires_input(message=with_context(message), final=True)
411
- value = cast(RunYieldResume, await resume_queue.dequeue_event())
412
- resume_queue.task_done()
413
- continue
414
- case TaskStatus(state=TaskState.auth_required, message=message, timestamp=timestamp):
415
- await task_updater.requires_auth(message=with_context(message), final=True)
416
- value = cast(RunYieldResume, await resume_queue.dequeue_event())
417
- resume_queue.task_done()
418
- continue
414
+ case TaskStatus(
415
+ state=(TaskState.auth_required | TaskState.input_required) as state,
416
+ message=message,
417
+ timestamp=timestamp,
418
+ ):
419
+ await self.task_updater.update_status(
420
+ state=state, message=self._with_context(message), final=True, timestamp=timestamp
421
+ )
422
+ self._working = False
423
+ resume_value = await self.resume_queue.get()
424
+ self.resume_queue.task_done()
419
425
  case TaskStatus(state=state, message=message, timestamp=timestamp):
420
- await task_updater.update_status(
421
- state=state, message=with_context(message), timestamp=timestamp
426
+ await self.task_updater.update_status(
427
+ state=state, message=self._with_context(message), timestamp=timestamp
422
428
  )
423
429
  case TaskStatusUpdateEvent(
424
430
  status=TaskStatus(state=state, message=message, timestamp=timestamp),
425
431
  final=final,
426
432
  metadata=metadata,
427
433
  ):
428
- await task_updater.update_status(
434
+ await self.task_updater.update_status(
429
435
  state=state,
430
- message=with_context(message),
436
+ message=self._with_context(message),
431
437
  timestamp=timestamp,
432
438
  final=final,
433
439
  metadata=metadata,
@@ -437,7 +443,7 @@ class Executor(AgentExecutor):
437
443
  append=append,
438
444
  last_chunk=last_chunk,
439
445
  ):
440
- await task_updater.add_artifact(
446
+ await self.task_updater.add_artifact(
441
447
  parts=parts,
442
448
  artifact_id=artifact_id,
443
449
  name=name,
@@ -446,128 +452,133 @@ class Executor(AgentExecutor):
446
452
  last_chunk=last_chunk,
447
453
  )
448
454
  case Metadata() as metadata:
449
- await task_updater.update_status(
455
+ await self.task_updater.update_status(
450
456
  state=TaskState.working,
451
- message=task_updater.new_agent_message(parts=[], metadata=metadata),
457
+ message=self.task_updater.new_agent_message(parts=[], metadata=metadata),
452
458
  )
453
459
  case dict() as data:
454
- await task_updater.update_status(
460
+ await self.task_updater.update_status(
455
461
  state=TaskState.working,
456
- message=task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
462
+ message=self.task_updater.new_agent_message(parts=[Part(root=DataPart(data=data))]),
457
463
  )
458
464
  case Exception() as ex:
459
465
  raise ex
460
466
  case _:
461
467
  raise ValueError(f"Invalid value yielded from agent: {type(yielded_value)}")
462
- value = None
463
- except StopAsyncIteration:
464
- await task_updater.complete()
465
- except CancelledError:
466
- await task_updater.cancel()
468
+
469
+ await yield_resume_queue.async_q.put(resume_value)
470
+
471
+ await self.task_updater.complete()
472
+
473
+ except (janus.AsyncQueueShutDown, GeneratorExit):
474
+ await self.task_updater.complete()
467
475
  except Exception as ex:
468
476
  logger.error("Error when executing agent", exc_info=ex)
469
- try:
470
- error_extension = run_context._error_extension if run_context else None
471
- error_extension = error_extension if error_extension is not None else DEFAULT_ERROR_EXTENSION
472
- error_msg = error_extension.message(ex)
473
- except Exception as error_exc:
474
- error_msg = AgentMessage(
475
- text=(f"Failed to create error message: {error_exc!s}\noriginal exc: {ex!s}")
476
- )
477
- await task_updater.failed(error_msg)
478
- finally: # cleanup
479
- await cancel_task(cancellation_task)
480
- is_cancelling = bool(current_task.cancelling())
481
- try:
482
- async with asyncio.timeout(10): # grace period to read all events from queue
483
- await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=is_cancelling)
484
- await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=is_cancelling)
485
- except (TimeoutError, CancelledError):
486
- await close_queue(self._queue_manager, f"_event_{context.task_id}", immediate=True)
487
- await close_queue(self._queue_manager, f"_resume_{context.task_id}", immediate=True)
477
+ await self.task_updater.failed(get_error_extension_context().server.message(ex))
478
+ await cancel_task(task)
479
+ except Exception as ex:
480
+ logger.error("Error when executing agent", exc_info=ex)
481
+ await self.task_updater.failed(get_error_extension_context().server.message(ex))
482
+ finally:
483
+ self._working = False
484
+ with suppress(Exception):
485
+ self._handle_finish()
486
+
488
487
 
488
+ class Executor(AgentExecutor):
489
+ def __init__(
490
+ self,
491
+ agent: Agent,
492
+ queue_manager: QueueManager,
493
+ context_store: ContextStore,
494
+ task_timeout: timedelta,
495
+ task_store: TaskStore,
496
+ ) -> None:
497
+ self._agent: Agent = agent
498
+ self._running_tasks: dict[str, AgentRun] = {}
499
+ self._scheduled_cleanups: dict[str, asyncio.Task[None]] = {}
500
+ self._context_store: ContextStore = context_store
501
+ self._task_timeout: timedelta = task_timeout
502
+ self._task_store: TaskStore = task_store
503
+
504
+ @override
489
505
  async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
490
- assert context.message # this is only executed in the context of SendMessage request
491
- # These are incorrectly typed in a2a
492
- assert context.context_id
493
- assert context.task_id
506
+ # this is only executed in the context of SendMessage request
507
+ message, task_id, context_id = context.message, context.task_id, context.context_id
508
+ assert message and task_id and context_id
509
+ agent_run: AgentRun | None = None
494
510
  try:
495
- current_status = context.current_task and context.current_task.status.state
496
- if current_status == TaskState.working:
497
- raise RuntimeError("Cannot resume working task")
498
- if not context.task_id:
499
- raise RuntimeError("Task ID was not created")
500
-
501
- if not (resume_queue := await self._queue_manager.get(task_id=f"_resume_{context.task_id}")):
502
- resume_queue = await self._queue_manager.create_or_tap(task_id=f"_resume_{context.task_id}")
503
-
504
- if not (long_running_event_queue := await self._queue_manager.get(task_id=f"_event_{context.task_id}")):
505
- long_running_event_queue = await self._queue_manager.create_or_tap(task_id=f"_event_{context.task_id}")
506
-
507
- if current_status in {TaskState.input_required, TaskState.auth_required}:
508
- await resume_queue.enqueue_event(context.message)
511
+ if not context.current_task:
512
+ agent_run = AgentRun(self._agent, self._context_store, lambda: self._handle_finish(task_id))
513
+ self._running_tasks[task_id] = agent_run
514
+ await self._schedule_run_cleanup(request_context=context)
515
+ await agent_run.start(request_context=context, event_queue=event_queue)
516
+ elif agent_run := self._running_tasks.get(task_id):
517
+ await agent_run.resume(request_context=context, event_queue=event_queue)
509
518
  else:
510
- task_updater = TaskUpdater(long_running_event_queue, context.task_id, context.context_id)
511
- run_generator = self._run_agent_function(
512
- context=context,
513
- context_store=self._context_store,
514
- task_updater=task_updater,
515
- resume_queue=resume_queue,
516
- )
517
-
518
- self._running_tasks[context.task_id] = RunningTask(
519
- task=asyncio.create_task(run_generator), last_invocation=datetime.now()
520
- )
521
- asyncio.create_task(
522
- self._schedule_run_cleanup(task_id=context.task_id, task_timeout=self._task_timeout)
523
- ).add_done_callback(lambda _: ...)
519
+ raise self._run_not_found_error(task_id)
524
520
 
521
+ # will run until complete or next input/auth required task state
522
+ tapped_queue = event_queue.tap()
525
523
  while True:
526
- # Forward messages to local event queue
527
- event = await long_running_event_queue.dequeue_event()
528
- long_running_event_queue.task_done()
529
- await event_queue.enqueue_event(event)
530
- match event:
524
+ match await tapped_queue.dequeue_event():
531
525
  case TaskStatusUpdateEvent(final=True):
532
526
  break
533
- except CancelledError:
534
- # Handles cancellation of this handler:
535
- # When a streaming request is canceled, this executor is canceled first meaning that "cancellation" event
536
- # passed from the agent's long_running_event_queue is not forwarded. Instead of shielding this function,
537
- # we report the cancellation explicitly
538
- await self._cancel_task(context.task_id)
539
- local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
540
- await local_updater.cancel()
541
- except Exception as ex:
542
- logger.error("Error executing agent", exc_info=ex)
543
- local_updater = TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id)
544
- await local_updater.failed(local_updater.new_agent_message(parts=[Part(root=TextPart(text=str(ex)))]))
545
-
546
- async def _cancel_task(self, task_id: str):
547
- if queue := self._cancel_queues.get(task_id):
548
- await queue.enqueue_event(create_text_message_object(content="canceled"))
549
-
550
- async def _schedule_run_cleanup(self, task_id: str, task_timeout: timedelta):
551
- task = self._running_tasks.get(task_id)
552
- assert task
553
527
 
554
- try:
555
- while not task["task"].done():
556
- await asyncio.sleep(5)
557
- if not task["task"].done() and task["last_invocation"] + task_timeout < datetime.now():
558
- # Task might be stuck waiting for queue events to be processed
559
- logger.warning(f"Task {task_id} did not finish in {task_timeout}")
560
- await self._cancel_task(task_id)
561
- break
528
+ except CancelledError:
529
+ if agent_run:
530
+ await agent_run.cancel(request_context=context, event_queue=event_queue)
562
531
  except Exception as ex:
563
- logger.error("Error when cleaning up task", exc_info=ex)
564
- finally:
565
- self._running_tasks.pop(task_id)
532
+ logger.error("Unhandled error when executing agent:", exc_info=ex)
566
533
 
534
+ @override
567
535
  async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
568
536
  if not context.task_id or not context.context_id:
569
537
  raise ValueError("Task ID and context ID must be set to cancel a task")
570
- try:
571
- await self._cancel_task(task_id=context.task_id)
572
- finally:
573
- await TaskUpdater(event_queue, task_id=context.task_id, context_id=context.context_id).cancel()
538
+ if not (run := self._running_tasks.get(context.task_id)):
539
+ raise self._run_not_found_error(context.task_id)
540
+ await run.cancel(context, event_queue)
541
+
542
+ def _handle_finish(self, task_id: str) -> None:
543
+ if task := self._scheduled_cleanups.pop(task_id, None):
544
+ task.cancel()
545
+ self._running_tasks.pop(task_id, None)
546
+
547
+ def _run_not_found_error(self, task_id: str | None) -> Exception:
548
+ return RuntimeError(
549
+ f"Run for task ID {task_id} not found. "
550
+ + "It may be on another replica, make sure to enable sticky sessions in your load balancer"
551
+ )
552
+
553
+ async def _schedule_run_cleanup(self, request_context: RequestContext):
554
+ task_id, context_id = request_context.task_id, request_context.context_id
555
+ assert task_id and context_id
556
+
557
+ async def cleanup_fn():
558
+ await asyncio.sleep(self._task_timeout.total_seconds())
559
+ if not (run := self._running_tasks.get(task_id)):
560
+ return
561
+ try:
562
+ while not run.done:
563
+ if run.last_invocation + self._task_timeout < datetime.now():
564
+ logger.warning(f"Task {task_id} did not finish in {self._task_timeout}")
565
+ queue = EventQueue()
566
+ await run.cancel(request_context=request_context, event_queue=queue)
567
+ # the original request queue is closed at this point, we need to propagate state to store manually
568
+ manager = TaskManager(
569
+ task_id=task_id, context_id=context_id, task_store=self._task_store, initial_message=None
570
+ )
571
+ event = await queue.dequeue_event(no_wait=True)
572
+ if not isinstance(event, TaskStatusUpdateEvent) or event.status.state != TaskState.canceled:
573
+ raise RuntimeError(f"Something strange occured during scheduled cancel, event: {event}")
574
+ _ = await manager.save_task_event(event)
575
+ break
576
+ await asyncio.sleep(2)
577
+ except Exception as ex:
578
+ logger.error("Error when cleaning up task", exc_info=ex)
579
+ finally:
580
+ _ = self._running_tasks.pop(task_id, None)
581
+ _ = self._scheduled_cleanups.pop(task_id, None)
582
+
583
+ self._scheduled_cleanups[task_id] = asyncio.create_task(cleanup_fn())
584
+ self._scheduled_cleanups[task_id].add_done_callback(lambda _: ...)