grasp_agents 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/__init__.py CHANGED
@@ -1,14 +1,13 @@
1
1
  # pyright: reportUnusedImport=false
2
2
 
3
3
 
4
- from .comm_processor import CommProcessor
5
4
  from .llm import LLM, LLMSettings
6
5
  from .llm_agent import LLMAgent
7
6
  from .llm_agent_memory import LLMAgentMemory
8
7
  from .memory import Memory
9
8
  from .packet import Packet
10
9
  from .processor import Processor
11
- from .run_context import RunArgs, RunContext
10
+ from .run_context import RunContext
12
11
  from .typing.completion import Completion
13
12
  from .typing.content import Content, ImageData
14
13
  from .typing.io import LLMPrompt, LLMPromptArgs, ProcName
@@ -19,20 +18,20 @@ __all__ = [
19
18
  "LLM",
20
19
  "AssistantMessage",
21
20
  "BaseTool",
22
- "CommProcessor",
23
21
  "Completion",
24
22
  "Content",
25
23
  "ImageData",
26
24
  "LLMAgent",
25
+ "LLMAgentMemory",
27
26
  "LLMPrompt",
28
27
  "LLMPromptArgs",
29
28
  "LLMSettings",
29
+ "Memory",
30
30
  "Messages",
31
31
  "Packet",
32
32
  "Packet",
33
33
  "ProcName",
34
34
  "Processor",
35
- "RunArgs",
36
35
  "RunContext",
37
36
  "SystemMessage",
38
37
  "UserMessage",
grasp_agents/errors.py CHANGED
@@ -1,24 +1,48 @@
1
1
  # from openai import APIResponseValidationError
2
- class CompletionError(Exception):
3
- pass
4
2
 
5
3
 
6
- class CombineCompletionChunksError(Exception):
7
- pass
4
+ class ProcRunError(Exception):
5
+ def __init__(
6
+ self, proc_name: str, call_id: str, message: str | None = None
7
+ ) -> None:
8
+ super().__init__(
9
+ message
10
+ or f"Processor run failed [proc_name: {proc_name}; call_id: {call_id}]."
11
+ )
12
+ self.proc_name = proc_name
13
+ self.call_id = call_id
8
14
 
9
15
 
10
- class ProcInputValidationError(Exception):
16
+ class ProcInputValidationError(ProcRunError):
11
17
  pass
12
18
 
13
19
 
14
- class ProcOutputValidationError(Exception):
15
- pass
20
+ class ProcOutputValidationError(ProcRunError):
21
+ def __init__(
22
+ self, schema: object, proc_name: str, call_id: str, message: str | None = None
23
+ ):
24
+ super().__init__(
25
+ proc_name=proc_name,
26
+ call_id=call_id,
27
+ message=message
28
+ or (
29
+ "Processor output validation failed "
30
+ f"[proc_name: {proc_name}; call_id: {call_id}]. "
31
+ f"Expected type:\n{schema}"
32
+ ),
33
+ )
16
34
 
17
35
 
18
- class AgentFinalAnswerError(Exception):
19
- def __init__(self, message: str | None = None) -> None:
36
+ class AgentFinalAnswerError(ProcRunError):
37
+ def __init__(
38
+ self, proc_name: str, call_id: str, message: str | None = None
39
+ ) -> None:
20
40
  super().__init__(
21
- message or "Final answer tool call did not return a final answer message."
41
+ proc_name=proc_name,
42
+ call_id=call_id,
43
+ message=message
44
+ or "Final answer tool call did not return a final answer message "
45
+ f"[proc_name={proc_name}; call_id={call_id}]",
22
46
  )
23
47
  self.message = message
24
48
 
@@ -27,28 +51,58 @@ class WorkflowConstructionError(Exception):
27
51
  pass
28
52
 
29
53
 
30
- class PacketRoutingError(Exception):
54
+ class PacketRoutingError(ProcRunError):
31
55
  def __init__(
32
56
  self,
33
- selected_recipient: str,
34
- allowed_recipients: list[str],
57
+ proc_name: str,
58
+ call_id: str,
59
+ selected_recipient: str | None = None,
60
+ allowed_recipients: list[str] | None = None,
35
61
  message: str | None = None,
36
62
  ) -> None:
37
63
  default_message = (
38
64
  f"Selected recipient '{selected_recipient}' is not in the allowed "
39
- f"recipients: {allowed_recipients}"
65
+ f"recipients: {allowed_recipients} "
66
+ f"[proc_name={proc_name}; call_id={call_id}]"
67
+ )
68
+ super().__init__(
69
+ proc_name=proc_name, call_id=call_id, message=message or default_message
40
70
  )
41
- super().__init__(message or default_message)
42
71
  self.selected_recipient = selected_recipient
43
72
  self.allowed_recipients = allowed_recipients
44
73
 
45
74
 
46
- class SystemPromptBuilderError(Exception):
75
+ class RunnerError(Exception):
47
76
  pass
48
77
 
49
78
 
50
- class InputPromptBuilderError(Exception):
51
- pass
79
+ class PromptBuilderError(Exception):
80
+ def __init__(self, proc_name: str, message: str | None = None) -> None:
81
+ super().__init__(message or f"Prompt builder failed [proc_name={proc_name}]")
82
+ self.proc_name = proc_name
83
+ self.message = message
84
+
85
+
86
+ class SystemPromptBuilderError(PromptBuilderError):
87
+ def __init__(self, proc_name: str, message: str | None = None) -> None:
88
+ super().__init__(
89
+ proc_name=proc_name,
90
+ message=message
91
+ or "System prompt builder failed to make system prompt "
92
+ f"[proc_name={proc_name}]",
93
+ )
94
+ self.message = message
95
+
96
+
97
+ class InputPromptBuilderError(PromptBuilderError):
98
+ def __init__(self, proc_name: str, message: str | None = None) -> None:
99
+ super().__init__(
100
+ proc_name=proc_name,
101
+ message=message
102
+ or "Input prompt builder failed to make input content "
103
+ f"[proc_name={proc_name}]",
104
+ )
105
+ self.message = message
52
106
 
53
107
 
54
108
  class PyJSONStringParsingError(Exception):
@@ -71,6 +125,14 @@ class JSONSchemaValidationError(Exception):
71
125
  self.schema = schema
72
126
 
73
127
 
128
+ class CompletionError(Exception):
129
+ pass
130
+
131
+
132
+ class CombineCompletionChunksError(Exception):
133
+ pass
134
+
135
+
74
136
  class LLMToolCallValidationError(Exception):
75
137
  def __init__(
76
138
  self, tool_name: str, tool_args: str, message: str | None = None
grasp_agents/llm_agent.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
- from .comm_processor import CommProcessor
8
7
  from .llm import LLM, LLMSettings
9
8
  from .llm_agent_memory import LLMAgentMemory, PrepareMemoryHandler
10
9
  from .llm_policy_executor import (
@@ -12,7 +11,7 @@ from .llm_policy_executor import (
12
11
  LLMPolicyExecutor,
13
12
  ManageMemoryHandler,
14
13
  )
15
- from .packet_pool import PacketPool
14
+ from .processor import Processor
16
15
  from .prompt_builder import (
17
16
  MakeInputContentHandler,
18
17
  MakeSystemPromptHandler,
@@ -27,7 +26,7 @@ from .typing.events import (
27
26
  SystemMessageEvent,
28
27
  UserMessageEvent,
29
28
  )
30
- from .typing.io import InT, LLMPrompt, LLMPromptArgs, OutT_co, ProcName
29
+ from .typing.io import InT, LLMPrompt, LLMPromptArgs, OutT, ProcName
31
30
  from .typing.message import Message, Messages, SystemMessage, UserMessage
32
31
  from .typing.tool import BaseTool
33
32
  from .utils import get_prompt, validate_obj_from_json_or_py_string
@@ -47,8 +46,8 @@ class ParseOutputHandler(Protocol[_InT_contra, _OutT_co, CtxT]):
47
46
 
48
47
 
49
48
  class LLMAgent(
50
- CommProcessor[InT, OutT_co, LLMAgentMemory, CtxT],
51
- Generic[InT, OutT_co, CtxT],
49
+ Processor[InT, OutT, LLMAgentMemory, CtxT],
50
+ Generic[InT, OutT, CtxT],
52
51
  ):
53
52
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
54
53
  0: "_in_type",
@@ -71,8 +70,6 @@ class LLMAgent(
71
70
  sys_prompt_path: str | Path | None = None,
72
71
  # System args (static args provided via RunContext)
73
72
  sys_args_schema: type[LLMPromptArgs] | None = None,
74
- # User args (static args provided via RunContext)
75
- usr_args_schema: type[LLMPromptArgs] | None = None,
76
73
  # Agent loop settings
77
74
  max_turns: int = 100,
78
75
  react_mode: bool = False,
@@ -82,15 +79,9 @@ class LLMAgent(
82
79
  # Retries
83
80
  max_retries: int = 0,
84
81
  # Multi-agent routing
85
- packet_pool: PacketPool[CtxT] | None = None,
86
82
  recipients: list[ProcName] | None = None,
87
83
  ) -> None:
88
- super().__init__(
89
- name=name,
90
- packet_pool=packet_pool,
91
- recipients=recipients,
92
- max_retries=max_retries,
93
- )
84
+ super().__init__(name=name, recipients=recipients, max_retries=max_retries)
94
85
 
95
86
  # Agent memory
96
87
 
@@ -132,14 +123,13 @@ class LLMAgent(
132
123
  self.in_type, CtxT
133
124
  ](
134
125
  agent_name=self._name,
135
- sys_prompt_template=sys_prompt,
136
- in_prompt_template=in_prompt,
126
+ sys_prompt=sys_prompt,
127
+ in_prompt=in_prompt,
137
128
  sys_args_schema=sys_args_schema,
138
- usr_args_schema=usr_args_schema,
139
129
  )
140
130
 
141
131
  self._prepare_memory_impl: PrepareMemoryHandler | None = None
142
- self._parse_output_impl: ParseOutputHandler[InT, OutT_co, CtxT] | None = None
132
+ self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None
143
133
  self._register_overridden_handlers()
144
134
 
145
135
  @property
@@ -158,17 +148,13 @@ class LLMAgent(
158
148
  def sys_args_schema(self) -> type[LLMPromptArgs] | None:
159
149
  return self._prompt_builder.sys_args_schema
160
150
 
161
- @property
162
- def usr_args_schema(self) -> type[LLMPromptArgs] | None:
163
- return self._prompt_builder.usr_args_schema
164
-
165
151
  @property
166
152
  def sys_prompt(self) -> LLMPrompt | None:
167
- return self._prompt_builder.sys_prompt_template
153
+ return self._prompt_builder.sys_prompt
168
154
 
169
155
  @property
170
156
  def in_prompt(self) -> LLMPrompt | None:
171
- return self._prompt_builder.in_prompt_template
157
+ return self._prompt_builder.in_prompt
172
158
 
173
159
  def _prepare_memory(
174
160
  self,
@@ -184,20 +170,13 @@ class LLMAgent(
184
170
 
185
171
  def _memorize_inputs(
186
172
  self,
173
+ memory: LLMAgentMemory,
187
174
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
188
- *,
189
175
  in_args: InT | None = None,
190
- memory: LLMAgentMemory,
191
176
  ctx: RunContext[CtxT] | None = None,
192
177
  ) -> tuple[SystemMessage | None, UserMessage | None]:
193
- # 1. Get run arguments
194
- sys_args: LLMPromptArgs | None = None
195
- usr_args: LLMPromptArgs | None = None
196
- if ctx is not None:
197
- run_args = ctx.run_args.get(self.name)
198
- if run_args is not None:
199
- sys_args = run_args.sys
200
- usr_args = run_args.usr
178
+ # 1. Get system arguments
179
+ sys_args = ctx.sys_args.get(self.name) if ctx and ctx.sys_args else None
201
180
 
202
181
  # 2. Make system prompt (can be None)
203
182
 
@@ -214,16 +193,13 @@ class LLMAgent(
214
193
  system_message = cast("SystemMessage", memory.message_history[0])
215
194
  else:
216
195
  self._prepare_memory(
217
- memory=memory,
218
- in_args=in_args,
219
- sys_prompt=formatted_sys_prompt,
220
- ctx=ctx,
196
+ memory=memory, in_args=in_args, sys_prompt=formatted_sys_prompt, ctx=ctx
221
197
  )
222
198
 
223
199
  # 3. Make and add input messages
224
200
 
225
201
  input_message = self._prompt_builder.make_input_message(
226
- chat_inputs=chat_inputs, in_args=in_args, usr_args=usr_args, ctx=ctx
202
+ chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
227
203
  )
228
204
  if input_message:
229
205
  memory.update([input_message])
@@ -236,7 +212,7 @@ class LLMAgent(
236
212
  *,
237
213
  in_args: InT | None = None,
238
214
  ctx: RunContext[CtxT] | None = None,
239
- ) -> OutT_co:
215
+ ) -> OutT:
240
216
  if self._parse_output_impl:
241
217
  return self._parse_output_impl(
242
218
  conversation=conversation, in_args=in_args, ctx=ctx
@@ -257,9 +233,12 @@ class LLMAgent(
257
233
  memory: LLMAgentMemory,
258
234
  call_id: str,
259
235
  ctx: RunContext[CtxT] | None = None,
260
- ) -> Sequence[OutT_co]:
236
+ ) -> OutT:
261
237
  system_message, input_message = self._memorize_inputs(
262
- chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
238
+ memory=memory,
239
+ chat_inputs=chat_inputs,
240
+ in_args=in_args,
241
+ ctx=ctx,
263
242
  )
264
243
  if system_message:
265
244
  self._print_messages([system_message], call_id=call_id, ctx=ctx)
@@ -268,11 +247,9 @@ class LLMAgent(
268
247
 
269
248
  await self._policy_executor.execute(memory, call_id=call_id, ctx=ctx)
270
249
 
271
- return [
272
- self._parse_output(
273
- conversation=memory.message_history, in_args=in_args, ctx=ctx
274
- )
275
- ]
250
+ return self._parse_output(
251
+ conversation=memory.message_history, in_args=in_args, ctx=ctx
252
+ )
276
253
 
277
254
  async def _process_stream(
278
255
  self,
@@ -284,7 +261,10 @@ class LLMAgent(
284
261
  ctx: RunContext[CtxT] | None = None,
285
262
  ) -> AsyncIterator[Event[Any]]:
286
263
  system_message, input_message = self._memorize_inputs(
287
- chat_inputs=chat_inputs, in_args=in_args, memory=memory, ctx=ctx
264
+ memory=memory,
265
+ chat_inputs=chat_inputs,
266
+ in_args=in_args,
267
+ ctx=ctx,
288
268
  )
289
269
  if system_message:
290
270
  self._print_messages([system_message], call_id=call_id, ctx=ctx)
@@ -322,6 +302,10 @@ class LLMAgent(
322
302
  cur_cls = type(self)
323
303
  base_cls = LLMAgent[Any, Any, Any]
324
304
 
305
+ # Packet routing
306
+ if cur_cls._select_recipients is not base_cls._select_recipients: # noqa: SLF001
307
+ self.select_recipients_impl = self._select_recipients
308
+
325
309
  # Prompt builder
326
310
 
327
311
  if cur_cls._make_system_prompt is not base_cls._make_system_prompt: # noqa: SLF001
@@ -354,15 +338,9 @@ class LLMAgent(
354
338
  return self._prompt_builder.make_system_prompt(sys_args=sys_args, ctx=ctx)
355
339
 
356
340
  def _make_input_content(
357
- self,
358
- *,
359
- in_args: InT | None = None,
360
- usr_args: LLMPromptArgs | None = None,
361
- ctx: RunContext[CtxT] | None = None,
341
+ self, in_args: InT | None = None, *, ctx: RunContext[CtxT] | None = None
362
342
  ) -> Content:
363
- return self._prompt_builder.make_input_content(
364
- in_args=in_args, usr_args=usr_args, ctx=ctx
365
- )
343
+ return self._prompt_builder.make_input_content(in_args=in_args, ctx=ctx)
366
344
 
367
345
  def _exit_tool_call_loop(
368
346
  self,
@@ -403,8 +381,8 @@ class LLMAgent(
403
381
  return func
404
382
 
405
383
  def parse_output(
406
- self, func: ParseOutputHandler[InT, OutT_co, CtxT]
407
- ) -> ParseOutputHandler[InT, OutT_co, CtxT]:
384
+ self, func: ParseOutputHandler[InT, OutT, CtxT]
385
+ ) -> ParseOutputHandler[InT, OutT, CtxT]:
408
386
  if self._used_default_llm_response_schema:
409
387
  self._policy_executor.llm.response_schema = None
410
388
  self._parse_output_impl = func
@@ -255,7 +255,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
255
255
 
256
256
  final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
257
257
  if final_answer_message is None:
258
- raise AgentFinalAnswerError
258
+ raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
259
259
 
260
260
  return final_answer_message
261
261
 
@@ -282,7 +282,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
282
282
 
283
283
  final_answer_message = self._extract_final_answer_from_tool_calls(memory)
284
284
  if final_answer_message is None:
285
- raise AgentFinalAnswerError
285
+ raise AgentFinalAnswerError(proc_name=self.agent_name, call_id=call_id)
286
286
  yield GenMessageEvent(
287
287
  proc_name=self.agent_name, call_id=call_id, data=final_answer_message
288
288
  )
grasp_agents/packet.py CHANGED
@@ -1,11 +1,13 @@
1
1
  from collections.abc import Sequence
2
- from typing import Generic, TypeVar
2
+ from typing import Annotated, Any, Generic, Literal, TypeVar
3
3
  from uuid import uuid4
4
4
 
5
- from pydantic import BaseModel, ConfigDict, Field
5
+ from pydantic import AfterValidator, BaseModel, ConfigDict, Field
6
6
 
7
7
  from .typing.io import ProcName
8
8
 
9
+ START_PROC_NAME: Literal["*START*"] = "*START*"
10
+
9
11
  _PayloadT_co = TypeVar("_PayloadT_co", covariant=True)
10
12
 
11
13
 
@@ -13,15 +15,32 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
13
15
  id: str = Field(default_factory=lambda: str(uuid4())[:8])
14
16
  payloads: Sequence[_PayloadT_co]
15
17
  sender: ProcName
16
- recipients: Sequence[ProcName] = Field(default_factory=list)
18
+ recipients: Sequence[ProcName] | None = None
17
19
 
18
20
  model_config = ConfigDict(extra="forbid")
19
21
 
20
22
  def __repr__(self) -> str:
23
+ _to = ", ".join(self.recipients) if self.recipients else "None"
21
24
  return (
22
25
  f"{self.__class__.__name__}:\n"
23
26
  f"ID: {self.id}\n"
24
27
  f"From: {self.sender}\n"
25
- f"To: {', '.join(self.recipients)}\n"
28
+ f"To: {_to}\n"
26
29
  f"Payloads: {len(self.payloads)}"
27
30
  )
31
+
32
+
33
+ def _check_recipients_length(v: Sequence[ProcName] | None) -> Sequence[ProcName] | None:
34
+ if v is not None and len(v) != 1:
35
+ raise ValueError("recipients must contain exactly one item")
36
+ return v
37
+
38
+
39
+ class StartPacket(Packet[_PayloadT_co]):
40
+ chat_inputs: Any | None = "start"
41
+ sender: ProcName = Field(default=START_PROC_NAME, frozen=True)
42
+ payloads: Sequence[_PayloadT_co] = Field(default=(), frozen=True)
43
+ recipients: Annotated[
44
+ Sequence[ProcName] | None,
45
+ AfterValidator(_check_recipients_length),
46
+ ] = None
@@ -1,7 +1,8 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from collections.abc import AsyncIterator
4
- from typing import Any, Generic, Protocol, TypeVar
4
+ from types import TracebackType
5
+ from typing import Any, Generic, Literal, Protocol, TypeVar
5
6
 
6
7
  from .packet import Packet
7
8
  from .run_context import CtxT, RunContext
@@ -11,6 +12,9 @@ from .typing.io import ProcName
11
12
  logger = logging.getLogger(__name__)
12
13
 
13
14
 
15
+ END_PROC_NAME: Literal["*END*"] = "*END*"
16
+
17
+
14
18
  _PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
15
19
 
16
20
 
@@ -20,73 +24,136 @@ class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
20
24
  packet: Packet[_PayloadT_contra],
21
25
  ctx: RunContext[CtxT],
22
26
  **kwargs: Any,
23
- ) -> AsyncIterator[Event[Any]] | None: ...
27
+ ) -> None: ...
24
28
 
25
29
 
26
30
  class PacketPool(Generic[CtxT]):
27
31
  def __init__(self) -> None:
28
- self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
32
+ self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
29
33
  self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
30
- self._tasks: dict[ProcName, asyncio.Task[AsyncIterator[Event[Any]] | None]] = {}
34
+ self._task_group: asyncio.TaskGroup | None = None
35
+
36
+ self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
37
+
38
+ self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
39
+
40
+ self._stopping = False
41
+ self._stopped_evt = asyncio.Event()
42
+
43
+ self._errors: list[Exception] = []
31
44
 
32
45
  async def post(self, packet: Packet[Any]) -> None:
33
- for recipient_id in packet.recipients:
34
- queue = self._queues.setdefault(recipient_id, asyncio.Queue())
46
+ if packet.recipients == [END_PROC_NAME]:
47
+ fut = self._ensure_final_future()
48
+ if not fut.done():
49
+ fut.set_result(packet)
50
+ await self.shutdown()
51
+ return
52
+
53
+ for recipient_id in packet.recipients or []:
54
+ queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
35
55
  await queue.put(packet)
36
56
 
57
+ def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
58
+ fut = self._final_result_fut
59
+ if fut is None:
60
+ fut = asyncio.get_running_loop().create_future()
61
+ self._final_result_fut = fut
62
+ return fut
63
+
64
+ async def final_result(self) -> Packet[Any]:
65
+ fut = self._ensure_final_future()
66
+ try:
67
+ return await fut
68
+ finally:
69
+ await self.shutdown()
70
+
37
71
  def register_packet_handler(
38
72
  self,
39
- processor_name: ProcName,
73
+ proc_name: ProcName,
40
74
  handler: PacketHandler[Any, CtxT],
41
75
  ctx: RunContext[CtxT],
42
76
  **run_kwargs: Any,
43
77
  ) -> None:
44
- self._packet_handlers[processor_name] = handler
45
- self._queues.setdefault(processor_name, asyncio.Queue())
46
- if processor_name not in self._tasks:
47
- self._tasks[processor_name] = asyncio.create_task(
48
- self._handle_packets(processor_name, ctx=ctx, **run_kwargs)
78
+ if self._stopping:
79
+ raise RuntimeError("PacketPool is stopping/stopped")
80
+
81
+ self._packet_handlers[proc_name] = handler
82
+ self._packet_queues.setdefault(proc_name, asyncio.Queue())
83
+
84
+ if self._task_group is not None:
85
+ self._task_group.create_task(
86
+ self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
87
+ name=f"packet-handler:{proc_name}",
49
88
  )
50
89
 
90
+ async def push_event(self, event: Event[Any]) -> None:
91
+ await self._event_queue.put(event)
92
+
93
+ async def __aenter__(self) -> "PacketPool[CtxT]":
94
+ self._task_group = asyncio.TaskGroup()
95
+ await self._task_group.__aenter__()
96
+
97
+ return self
98
+
99
+ async def __aexit__(
100
+ self,
101
+ exc_type: type[BaseException] | None,
102
+ exc: BaseException | None,
103
+ tb: TracebackType | None,
104
+ ) -> bool | None:
105
+ await self.shutdown()
106
+
107
+ if self._task_group is not None:
108
+ try:
109
+ return await self._task_group.__aexit__(exc_type, exc, tb)
110
+ finally:
111
+ self._task_group = None
112
+
113
+ if self._errors:
114
+ raise ExceptionGroup("PacketPool worker errors", self._errors)
115
+
116
+ return False
117
+
51
118
  async def _handle_packets(
52
- self, processor_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
53
- ) -> AsyncIterator[Event[Any]] | None:
54
- queue = self._queues[processor_name]
119
+ self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
120
+ ) -> None:
121
+ queue = self._packet_queues[proc_name]
122
+ handler = self._packet_handlers[proc_name]
123
+
55
124
  while True:
125
+ packet = await queue.get()
126
+ if packet is None:
127
+ break
56
128
  try:
57
- packet = await queue.get()
58
- handler = self._packet_handlers.get(processor_name)
59
- if handler is None:
60
- break
61
- try:
62
- if ctx.is_streaming:
63
- async for event in handler(packet, ctx=ctx, **run_kwargs): # type: ignore[return-value]
64
- yield event
65
- else:
66
- await handler(packet, ctx=ctx, **run_kwargs)
67
-
68
- except Exception:
69
- logger.exception(f"Error handling packet for {processor_name}")
70
-
71
- queue.task_done()
72
-
73
- except Exception:
74
- logger.exception(
75
- f"Unexpected error in processing loop for {processor_name}"
76
- )
77
-
78
- async def unregister_packet_handler(self, processor_name: ProcName) -> None:
79
- if task := self._tasks.get(processor_name):
80
- task.cancel()
81
- try:
82
- await task
129
+ await handler(packet, ctx=ctx, **run_kwargs)
83
130
  except asyncio.CancelledError:
84
- logger.info(f"{processor_name} exited")
85
-
86
- self._tasks.pop(processor_name, None)
87
- self._queues.pop(processor_name, None)
88
- self._packet_handlers.pop(processor_name, None)
89
-
90
- async def stop_all(self) -> None:
91
- for processor_name in list(self._tasks):
92
- await self.unregister_packet_handler(processor_name)
131
+ raise
132
+ except Exception as err:
133
+ logger.exception("Error handling packet for %s", proc_name)
134
+ self._errors.append(err)
135
+ fut = self._final_result_fut
136
+ if fut and not fut.done():
137
+ fut.set_exception(err)
138
+ await self.shutdown()
139
+ raise
140
+
141
+ async def stream_events(self) -> AsyncIterator[Event[Any]]:
142
+ while True:
143
+ event = await self._event_queue.get()
144
+ if event is None:
145
+ break
146
+ yield event
147
+
148
+ async def shutdown(self) -> None:
149
+ if self._stopping:
150
+ await self._stopped_evt.wait()
151
+ return
152
+ self._stopping = True
153
+ try:
154
+ await self._event_queue.put(None)
155
+ for queue in self._packet_queues.values():
156
+ await queue.put(None)
157
+
158
+ finally:
159
+ self._stopped_evt.set()