grasp_agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -273
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +175 -192
  24. grasp_agents/run_context.py +20 -37
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -134
  47. grasp_agents/workflow/sequential_agent.py +0 -72
  48. grasp_agents/workflow/workflow_agent.py +0 -88
  49. grasp_agents-0.2.11.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.11.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/__init__.py CHANGED
@@ -1,38 +1,39 @@
1
1
  # pyright: reportUnusedImport=false
2
2
 
3
3
 
4
- from .agent_message import AgentMessage
5
- from .base_agent import BaseAgent
6
- from .comm_agent import CommunicatingAgent
4
+ from .comm_processor import CommProcessor
7
5
  from .llm import LLM, LLMSettings
8
6
  from .llm_agent import LLMAgent
9
- from .run_context import RunArgs, RunContextWrapper
7
+ from .llm_agent_memory import LLMAgentMemory
8
+ from .memory import Memory
9
+ from .packet import Packet
10
+ from .processor import Processor
11
+ from .run_context import RunArgs, RunContext
10
12
  from .typing.completion import Completion
11
13
  from .typing.content import Content, ImageData
12
- from .typing.io import AgentID, AgentState, LLMFormattedArgs, LLMPrompt, LLMPromptArgs
13
- from .typing.message import AssistantMessage, Conversation, SystemMessage, UserMessage
14
+ from .typing.io import LLMPrompt, LLMPromptArgs, ProcName
15
+ from .typing.message import AssistantMessage, Messages, SystemMessage, UserMessage
14
16
  from .typing.tool import BaseTool
15
17
 
16
18
  __all__ = [
17
19
  "LLM",
18
- "AgentID",
19
- "AgentMessage",
20
- "AgentState",
21
20
  "AssistantMessage",
22
- "BaseAgent",
23
21
  "BaseTool",
24
- "CommunicatingAgent",
22
+ "CommProcessor",
25
23
  "Completion",
26
24
  "Content",
27
- "Conversation",
28
25
  "ImageData",
29
26
  "LLMAgent",
30
- "LLMFormattedArgs",
31
27
  "LLMPrompt",
32
28
  "LLMPromptArgs",
33
29
  "LLMSettings",
30
+ "Messages",
31
+ "Packet",
32
+ "Packet",
33
+ "ProcName",
34
+ "Processor",
34
35
  "RunArgs",
35
- "RunContextWrapper",
36
+ "RunContext",
36
37
  "SystemMessage",
37
38
  "UserMessage",
38
39
  ]
grasp_agents/cloud_llm.py CHANGED
@@ -17,16 +17,14 @@ from tenacity import (
17
17
  from typing_extensions import TypedDict
18
18
 
19
19
  from .http_client import AsyncHTTPClientParams, create_async_http_client
20
- from .llm import LLM, ConvertT, LLMSettings, SettingsT
21
- from .memory import MessageHistory
22
- from .rate_limiting.rate_limiter_chunked import ( # type: ignore
23
- RateLimiterC,
24
- limit_rate_chunked,
25
- )
26
- from .typing.completion import Completion, CompletionChunk
27
- from .typing.message import AssistantMessage, Conversation
20
+ from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
21
+ from .message_history import MessageHistory
22
+ from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate_chunked
23
+ from .typing.completion import Completion
24
+ from .typing.completion_chunk import CompletionChunk, combine_completion_chunks
25
+ from .typing.events import CompletionChunkEvent, CompletionEvent
26
+ from .typing.message import AssistantMessage, Messages
28
27
  from .typing.tool import BaseTool, ToolChoice
29
- from .utils import validate_obj_from_json_or_py_string
30
28
 
31
29
  logger = logging.getLogger(__name__)
32
30
 
@@ -38,7 +36,7 @@ class APIProviderInfo(TypedDict):
38
36
  name: APIProvider
39
37
  base_url: str
40
38
  api_key: str | None
41
- struct_output_support: tuple[str, ...]
39
+ struct_outputs_support: tuple[str, ...]
42
40
 
43
41
 
44
42
  PROVIDERS: dict[APIProvider, APIProviderInfo] = {
@@ -46,19 +44,19 @@ PROVIDERS: dict[APIProvider, APIProviderInfo] = {
46
44
  name="openai",
47
45
  base_url="https://api.openai.com/v1",
48
46
  api_key=os.getenv("OPENAI_API_KEY"),
49
- struct_output_support=("*",),
47
+ struct_outputs_support=("*",),
50
48
  ),
51
49
  "openrouter": APIProviderInfo(
52
50
  name="openrouter",
53
51
  base_url="https://openrouter.ai/api/v1",
54
52
  api_key=os.getenv("OPENROUTER_API_KEY"),
55
- struct_output_support=(),
53
+ struct_outputs_support=(),
56
54
  ),
57
55
  "google_ai_studio": APIProviderInfo(
58
56
  name="google_ai_studio",
59
57
  base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
60
58
  api_key=os.getenv("GOOGLE_AI_STUDIO_API_KEY"),
61
- struct_output_support=("*",),
59
+ struct_outputs_support=("*",),
62
60
  ),
63
61
  }
64
62
 
@@ -88,20 +86,16 @@ def retry_before_callback(retry_state: RetryCallState) -> None:
88
86
 
89
87
 
90
88
  class CloudLLMSettings(LLMSettings, total=False):
91
- max_completion_tokens: int | None
92
- temperature: float | None
93
- top_p: float | None
94
- seed: int | None
95
- use_structured_outputs: bool
89
+ use_struct_outputs: bool
96
90
 
97
91
 
98
- class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
92
+ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
99
93
  def __init__(
100
94
  self,
101
95
  # Base LLM args
102
96
  model_name: str,
103
- converters: ConvertT,
104
- llm_settings: SettingsT | None = None,
97
+ converters: ConvertT_co,
98
+ llm_settings: SettingsT_co | None = None,
105
99
  model_id: str | None = None,
106
100
  tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
107
101
  response_format: type | Mapping[str, type] | None = None,
@@ -110,7 +104,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
110
104
  dict[str, Any] | AsyncHTTPClientParams | None
111
105
  ) = None,
112
106
  # Rate limiting
113
- rate_limiter: (RateLimiterC[Conversation, AssistantMessage] | None) = None,
107
+ rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
114
108
  rate_limiter_rpm: float | None = None,
115
109
  rate_limiter_chunk_size: int = 1000,
116
110
  rate_limiter_max_concurrency: int = 300,
@@ -144,24 +138,26 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
144
138
  self._api_provider: APIProvider = api_provider
145
139
  self._api_model_name: str = api_model_name
146
140
 
147
- self._struct_output_support: bool = any(
141
+ self._struct_outputs_support: bool = any(
148
142
  fnmatch.fnmatch(self._model_name, pat)
149
- for pat in PROVIDERS[api_provider]["struct_output_support"]
143
+ for pat in PROVIDERS[api_provider]["struct_outputs_support"]
150
144
  )
151
145
  if (
152
- self._llm_settings.get("use_structured_outputs")
153
- and not self._struct_output_support
146
+ self._llm_settings.get("use_struct_outputs")
147
+ and not self._struct_outputs_support
154
148
  ):
155
149
  raise ValueError(
156
150
  f"Model {self._model_name} does not support structured outputs."
157
151
  )
158
152
 
159
- self._rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = (
153
+ self._tool_call_settings: dict[str, Any] = {}
154
+
155
+ self._rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = (
160
156
  self._get_rate_limiter(
161
157
  rate_limiter=rate_limiter,
162
- rate_limiter_rpm=rate_limiter_rpm,
163
- rate_limiter_chunk_size=rate_limiter_chunk_size,
164
- rate_limiter_max_concurrency=rate_limiter_max_concurrency,
158
+ rpm=rate_limiter_rpm,
159
+ chunk_size=rate_limiter_chunk_size,
160
+ max_concurrency=rate_limiter_max_concurrency,
165
161
  )
166
162
  )
167
163
  self.no_tqdm = no_tqdm
@@ -188,29 +184,36 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
188
184
  @property
189
185
  def rate_limiter(
190
186
  self,
191
- ) -> RateLimiterC[Conversation, AssistantMessage] | None:
187
+ ) -> RateLimiterC[Messages, AssistantMessage] | None:
192
188
  return self._rate_limiter
193
189
 
194
190
  def _make_completion_kwargs(
195
- self, conversation: Conversation, tool_choice: ToolChoice | None = None
191
+ self,
192
+ conversation: Messages,
193
+ tool_choice: ToolChoice | None = None,
194
+ n_choices: int | None = None,
196
195
  ) -> dict[str, Any]:
197
196
  api_messages = [self._converters.to_message(m) for m in conversation]
198
197
 
199
198
  api_tools = None
200
199
  api_tool_choice = None
201
200
  if self.tools:
202
- api_tools = [self._converters.to_tool(t) for t in self.tools.values()]
201
+ api_tools = [
202
+ self._converters.to_tool(t, **self._tool_call_settings)
203
+ for t in self.tools.values()
204
+ ]
203
205
  if tool_choice is not None:
204
206
  api_tool_choice = self._converters.to_tool_choice(tool_choice)
205
207
 
206
208
  api_llm_settings = deepcopy(self.llm_settings or {})
207
- api_llm_settings.pop("use_structured_outputs", None)
209
+ api_llm_settings.pop("use_struct_outputs", None)
208
210
 
209
211
  return dict(
210
212
  api_messages=api_messages,
211
213
  api_tools=api_tools,
212
214
  api_tool_choice=api_tool_choice,
213
215
  api_response_format=self._response_format,
216
+ n_choices=n_choices,
214
217
  **api_llm_settings,
215
218
  )
216
219
 
@@ -221,6 +224,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
221
224
  *,
222
225
  api_tools: list[Any] | None = None,
223
226
  api_tool_choice: Any | None = None,
227
+ n_choices: int | None = None,
224
228
  **api_llm_settings: Any,
225
229
  ) -> Any:
226
230
  pass
@@ -233,6 +237,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
233
237
  api_tools: list[Any] | None = None,
234
238
  api_tool_choice: Any | None = None,
235
239
  api_response_format: type | None = None,
240
+ n_choices: int | None = None,
236
241
  **api_llm_settings: Any,
237
242
  ) -> Any:
238
243
  pass
@@ -244,161 +249,143 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
244
249
  *,
245
250
  api_tools: list[Any] | None = None,
246
251
  api_tool_choice: Any | None = None,
252
+ n_choices: int | None = None,
247
253
  **api_llm_settings: Any,
248
254
  ) -> AsyncIterator[Any]:
249
255
  pass
250
256
 
251
- async def generate_completion(
257
+ @abstractmethod
258
+ async def _get_parsed_completion_stream(
259
+ self,
260
+ api_messages: list[Any],
261
+ *,
262
+ api_tools: list[Any] | None = None,
263
+ api_tool_choice: Any | None = None,
264
+ api_response_format: type | None = None,
265
+ n_choices: int | None = None,
266
+ **api_llm_settings: Any,
267
+ ) -> AsyncIterator[Any]:
268
+ pass
269
+
270
+ async def generate_completion_no_retry(
252
271
  self,
253
- conversation: Conversation,
272
+ conversation: Messages,
254
273
  *,
255
274
  tool_choice: ToolChoice | None = None,
256
- **kwargs: Any,
275
+ n_choices: int | None = None,
257
276
  ) -> Completion:
258
277
  completion_kwargs = self._make_completion_kwargs(
259
- conversation=conversation, tool_choice=tool_choice
278
+ conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
260
279
  )
261
280
 
262
- if (
263
- self._response_format is None
264
- or (not self._struct_output_support)
265
- or (not self._llm_settings.get("use_structured_outputs"))
266
- ):
281
+ if not self._llm_settings.get("use_struct_outputs"):
267
282
  completion_kwargs.pop("api_response_format", None)
268
- api_completion = await self._get_completion(**completion_kwargs, **kwargs)
283
+ api_completion = await self._get_completion(**completion_kwargs)
269
284
  else:
270
- api_completion = await self._get_parsed_completion(
271
- **completion_kwargs, **kwargs
272
- )
285
+ api_completion = await self._get_parsed_completion(**completion_kwargs)
273
286
 
274
287
  completion = self._converters.from_completion(
275
- api_completion, model_id=self.model_id
288
+ api_completion, name=self.model_id
276
289
  )
277
290
 
278
- self._validate_completion(completion)
291
+ if not self._llm_settings.get("use_struct_outputs"):
292
+ # If validation is not handled by the structured output functionality
293
+ # of the LLM provider
294
+ self._validate_completion(completion)
295
+ self._validate_tool_calls(completion)
279
296
 
280
297
  return completion
281
298
 
282
- def _validate_completion(self, completion: Completion) -> None:
283
- for choice in completion.choices:
284
- message = choice.message
285
- if (
286
- self._response_format_pyd is not None
287
- and not self._llm_settings.get("use_structured_outputs")
288
- and not message.tool_calls
289
- ):
290
- validate_obj_from_json_or_py_string(
291
- message.content or "",
292
- adapter=self._response_format_pyd,
293
- from_substring=True,
294
- )
295
-
296
299
  async def generate_completion_stream(
297
300
  self,
298
- conversation: Conversation,
301
+ conversation: Messages,
299
302
  *,
300
303
  tool_choice: ToolChoice | None = None,
301
- **kwargs: Any,
302
- ) -> AsyncIterator[CompletionChunk]:
304
+ n_choices: int | None = None,
305
+ ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
303
306
  completion_kwargs = self._make_completion_kwargs(
304
- conversation=conversation, tool_choice=tool_choice
305
- )
306
- completion_kwargs.pop("api_response_format", None)
307
- api_completion_chunk_iterator = await self._get_completion_stream(
308
- **completion_kwargs, **kwargs
307
+ conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
309
308
  )
310
309
 
311
- return self._converters.from_completion_chunk_iterator(
312
- api_completion_chunk_iterator, model_id=self.model_id
313
- )
310
+ if not self._llm_settings.get("use_struct_outputs"):
311
+ completion_kwargs.pop("api_response_format", None)
312
+ api_stream = await self._get_completion_stream(**completion_kwargs)
313
+ else:
314
+ api_stream = await self._get_parsed_completion_stream(**completion_kwargs)
315
+
316
+ async def iterate() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
317
+ completion_chunks: list[CompletionChunk] = []
318
+ async for api_completion_chunk in api_stream:
319
+ completion_chunk = self._converters.from_completion_chunk(
320
+ api_completion_chunk, name=self.model_id
321
+ )
322
+ completion_chunks.append(completion_chunk)
323
+ yield CompletionChunkEvent(data=completion_chunk, name=self.model_id)
324
+
325
+ # TODO: can be done using the OpenAI final_completion_chunk
326
+ completion = combine_completion_chunks(completion_chunks)
327
+
328
+ yield CompletionEvent(data=completion, name=self.model_id)
314
329
 
315
- async def _generate_completion_with_retry(
330
+ if not self._llm_settings.get("use_struct_outputs"):
331
+ # If validation is not handled by the structured outputs functionality
332
+ # of the LLM provider
333
+ self._validate_completion(completion)
334
+ self._validate_tool_calls(completion)
335
+
336
+ return iterate()
337
+
338
+ async def generate_completion(
316
339
  self,
317
- conversation: Conversation,
340
+ conversation: Messages,
318
341
  *,
319
342
  tool_choice: ToolChoice | None = None,
320
- **kwargs: Any,
343
+ n_choices: int | None = None,
321
344
  ) -> Completion:
322
345
  wrapped_func = retry(
323
346
  wait=wait_random_exponential(min=1, max=8),
324
347
  stop=stop_after_attempt(self.num_generation_retries + 1),
325
348
  before=retry_before_callback,
326
349
  retry_error_callback=retry_error_callback,
327
- )(self.__class__.generate_completion)
350
+ )(self.__class__.generate_completion_no_retry)
328
351
 
329
- return await wrapped_func(self, conversation, tool_choice=tool_choice, **kwargs)
352
+ return await wrapped_func(
353
+ self, conversation, tool_choice=tool_choice, n_choices=n_choices
354
+ )
330
355
 
331
356
  @limit_rate_chunked # type: ignore
332
- async def _generate_completion_batch_with_retry_and_rate_lim(
357
+ async def _generate_completion_batch(
333
358
  self,
334
- conversation: Conversation,
359
+ conversation: Messages,
335
360
  *,
336
361
  tool_choice: ToolChoice | None = None,
337
- **kwargs: Any,
338
362
  ) -> Completion:
339
- return await self._generate_completion_with_retry(
340
- conversation, tool_choice=tool_choice, **kwargs
341
- )
363
+ return await self.generate_completion(conversation, tool_choice=tool_choice)
342
364
 
343
365
  async def generate_completion_batch(
344
- self,
345
- message_history: MessageHistory,
346
- *,
347
- tool_choice: ToolChoice | None = None,
348
- **kwargs: Any,
366
+ self, message_history: MessageHistory, *, tool_choice: ToolChoice | None = None
349
367
  ) -> Sequence[Completion]:
350
- return await self._generate_completion_batch_with_retry_and_rate_lim(
351
- list(message_history.batched_conversations), # type: ignore
368
+ return await self._generate_completion_batch(
369
+ list(message_history.conversations), # type: ignore
352
370
  tool_choice=tool_choice,
353
- **kwargs,
354
- )
355
-
356
- async def generate_message(
357
- self,
358
- conversation: Conversation,
359
- *,
360
- tool_choice: ToolChoice | None = None,
361
- **kwargs: Any,
362
- ) -> AssistantMessage:
363
- completion = await self.generate_completion(
364
- conversation, tool_choice=tool_choice, **kwargs
365
371
  )
366
372
 
367
- return completion.choices[0].message
368
-
369
- async def generate_message_batch(
370
- self,
371
- message_history: MessageHistory,
372
- *,
373
- tool_choice: ToolChoice | None = None,
374
- **kwargs: Any,
375
- ) -> Sequence[AssistantMessage]:
376
- completion_batch = await self.generate_completion_batch(
377
- message_history, tool_choice=tool_choice, **kwargs
378
- )
379
-
380
- return [completion.choices[0].message for completion in completion_batch]
381
-
382
373
  def _get_rate_limiter(
383
374
  self,
384
- rate_limiter: RateLimiterC[Conversation, AssistantMessage] | None = None,
385
- rate_limiter_rpm: float | None = None,
386
- rate_limiter_chunk_size: int = 1000,
387
- rate_limiter_max_concurrency: int = 300,
388
- ) -> RateLimiterC[Conversation, AssistantMessage] | None:
375
+ rate_limiter: RateLimiterC[Messages, AssistantMessage] | None = None,
376
+ rpm: float | None = None,
377
+ chunk_size: int = 1000,
378
+ max_concurrency: int = 300,
379
+ ) -> RateLimiterC[Messages, AssistantMessage] | None:
389
380
  if rate_limiter is not None:
390
381
  logger.info(
391
382
  f"[{self.__class__.__name__}] Set rate limit to {rate_limiter.rpm} RPM"
392
383
  )
393
384
  return rate_limiter
394
- if rate_limiter_rpm is not None:
395
- logger.info(
396
- f"[{self.__class__.__name__}] Set rate limit to {rate_limiter_rpm} RPM"
397
- )
385
+ if rpm is not None:
386
+ logger.info(f"[{self.__class__.__name__}] Set rate limit to {rpm} RPM")
398
387
  return RateLimiterC(
399
- rpm=rate_limiter_rpm,
400
- chunk_size=rate_limiter_chunk_size,
401
- max_concurrency=rate_limiter_max_concurrency,
388
+ rpm=rpm, chunk_size=chunk_size, max_concurrency=max_concurrency
402
389
  )
403
390
 
404
391
  return None
@@ -0,0 +1,201 @@
1
+ import logging
2
+ from collections.abc import AsyncIterator, Sequence
3
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast
4
+
5
+ from pydantic import BaseModel
6
+ from pydantic.json_schema import SkipJsonSchema
7
+
8
+ from .packet import Packet
9
+ from .packet_pool import PacketPool
10
+ from .processor import Processor
11
+ from .run_context import CtxT, RunContext
12
+ from .typing.events import Event, PacketEvent
13
+ from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class DynCommPayload(BaseModel):
19
+ selected_recipients: SkipJsonSchema[Sequence[ProcName]]
20
+
21
+
22
+ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
23
+
24
+
25
+ class ExitCommunicationHandler(Protocol[_OutT_contra, CtxT]):
26
+ def __call__(
27
+ self,
28
+ out_packet: Packet[_OutT_contra],
29
+ ctx: RunContext[CtxT] | None,
30
+ ) -> bool: ...
31
+
32
+
33
+ class CommProcessor(
34
+ Processor[InT_contra, OutT_co, MemT_co, CtxT],
35
+ Generic[InT_contra, OutT_co, MemT_co, CtxT],
36
+ ):
37
+ _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
38
+ 0: "_in_type",
39
+ 1: "_out_type",
40
+ }
41
+
42
+ def __init__(
43
+ self,
44
+ name: ProcName,
45
+ *,
46
+ recipients: Sequence[ProcName] | None = None,
47
+ packet_pool: PacketPool[CtxT] | None = None,
48
+ ) -> None:
49
+ super().__init__(name=name)
50
+
51
+ self.recipients = recipients or []
52
+
53
+ self._packet_pool = packet_pool
54
+ self._is_listening = False
55
+ self._exit_communication_impl: (
56
+ ExitCommunicationHandler[OutT_co, CtxT] | None
57
+ ) = None
58
+
59
+ def _validate_routing(self, payloads: Sequence[OutT_co]) -> Sequence[ProcName]:
60
+ if all(isinstance(p, DynCommPayload) for p in payloads):
61
+ payloads_ = cast("Sequence[DynCommPayload]", payloads)
62
+ selected_recipients_per_payload = [
63
+ set(p.selected_recipients or []) for p in payloads_
64
+ ]
65
+ assert all(
66
+ x == selected_recipients_per_payload[0]
67
+ for x in selected_recipients_per_payload
68
+ ), "All payloads must have the same recipient IDs for dynamic routing"
69
+
70
+ assert payloads_[0].selected_recipients is not None
71
+ selected_recipients = payloads_[0].selected_recipients
72
+
73
+ assert all(rid in self.recipients for rid in selected_recipients), (
74
+ "Dynamic routing is enabled, but recipient IDs are not in "
75
+ "the allowed agent's recipient IDs"
76
+ )
77
+
78
+ return selected_recipients
79
+
80
+ if all((not isinstance(p, DynCommPayload)) for p in payloads):
81
+ return self.recipients
82
+
83
+ raise ValueError(
84
+ "All payloads must be either DCommAgentPayload or not DCommAgentPayload"
85
+ )
86
+
87
+ async def run(
88
+ self,
89
+ chat_inputs: Any | None = None,
90
+ *,
91
+ in_packet: Packet[InT_contra] | None = None,
92
+ in_args: InT_contra | Sequence[InT_contra] | None = None,
93
+ forgetful: bool = True,
94
+ ctx: RunContext[CtxT] | None = None,
95
+ ) -> Packet[OutT_co]:
96
+ out_packet = await super().run(
97
+ chat_inputs=chat_inputs,
98
+ in_packet=in_packet,
99
+ in_args=in_args,
100
+ ctx=ctx,
101
+ )
102
+ recipients = self._validate_routing(out_packet.payloads)
103
+ routed_out_packet = Packet(
104
+ payloads=out_packet.payloads, sender=self.name, recipients=recipients
105
+ )
106
+ if self._packet_pool is not None and in_packet is None and in_args is None:
107
+ # If no input packet or args, we assume this is the first run.
108
+ await self._packet_pool.post(routed_out_packet)
109
+
110
+ return routed_out_packet
111
+
112
+ async def run_stream(
113
+ self,
114
+ chat_inputs: Any | None = None,
115
+ *,
116
+ in_packet: Packet[InT_contra] | None = None,
117
+ in_args: InT_contra | Sequence[InT_contra] | None = None,
118
+ forgetful: bool = True,
119
+ ctx: RunContext[CtxT] | None = None,
120
+ ) -> AsyncIterator[Event[Any]]:
121
+ out_packet: Packet[OutT_co] | None = None
122
+ async for event in super().run_stream(
123
+ chat_inputs=chat_inputs,
124
+ in_packet=in_packet,
125
+ in_args=in_args,
126
+ ctx=ctx,
127
+ ):
128
+ if isinstance(event, PacketEvent):
129
+ out_packet = event.data
130
+ else:
131
+ yield event
132
+
133
+ if out_packet is None:
134
+ raise RuntimeError("No output packet generated during stream run")
135
+
136
+ recipients = self._validate_routing(out_packet.payloads)
137
+ routed_out_packet = Packet(
138
+ payloads=out_packet.payloads, sender=self.name, recipients=recipients
139
+ )
140
+ if self._packet_pool is not None and in_packet is None and in_args is None:
141
+ # If no input packet or args, we assume this is the first run.
142
+ await self._packet_pool.post(routed_out_packet)
143
+
144
+ yield PacketEvent(data=routed_out_packet, name=self.name)
145
+
146
+ def exit_communication(
147
+ self, func: ExitCommunicationHandler[OutT_co, CtxT]
148
+ ) -> ExitCommunicationHandler[OutT_co, CtxT]:
149
+ self._exit_communication_impl = func
150
+
151
+ return func
152
+
153
+ def _exit_communication_fn(
154
+ self, out_packet: Packet[OutT_co], ctx: RunContext[CtxT] | None
155
+ ) -> bool:
156
+ if self._exit_communication_impl:
157
+ return self._exit_communication_impl(out_packet=out_packet, ctx=ctx)
158
+
159
+ return False
160
+
161
+ async def _packet_handler(
162
+ self,
163
+ packet: Packet[InT_contra],
164
+ ctx: RunContext[CtxT] | None = None,
165
+ **run_kwargs: Any,
166
+ ) -> None:
167
+ assert self._packet_pool is not None, "Packet pool must be initialized"
168
+
169
+ out_packet = await self.run(ctx=ctx, in_packet=packet, **run_kwargs)
170
+
171
+ if self._exit_communication_fn(out_packet=out_packet, ctx=ctx):
172
+ await self._packet_pool.stop_all()
173
+ return
174
+
175
+ await self._packet_pool.post(out_packet)
176
+
177
+ @property
178
+ def is_listening(self) -> bool:
179
+ return self._is_listening
180
+
181
+ async def start_listening(
182
+ self, ctx: RunContext[CtxT] | None = None, **run_kwargs: Any
183
+ ) -> None:
184
+ assert self._packet_pool is not None, "Packet pool must be initialized"
185
+
186
+ if self._is_listening:
187
+ return
188
+
189
+ self._is_listening = True
190
+ self._packet_pool.register_packet_handler(
191
+ processor_name=self.name,
192
+ handler=self._packet_handler,
193
+ ctx=ctx,
194
+ **run_kwargs,
195
+ )
196
+
197
+ async def stop_listening(self) -> None:
198
+ assert self._packet_pool is not None, "Packet pool must be initialized"
199
+
200
+ self._is_listening = False
201
+ await self._packet_pool.unregister_packet_handler(self.name)