grasp_agents 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff compares publicly available package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two released versions.
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from collections.abc import AsyncIterator, Coroutine, Sequence
+from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
 from itertools import starmap
 from logging import getLogger
 from typing import Any, Generic, Protocol, final
@@ -36,7 +36,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool: ...
 
@@ -46,7 +46,7 @@ class MemoryManager(Protocol[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None,
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None: ...
 
@@ -54,9 +54,12 @@ class MemoryManager(Protocol[CtxT]):
 class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
+        *,
         agent_name: str,
         llm: LLM[LLMSettings, Converters],
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         max_turns: int,
         react_mode: bool = False,
         final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +73,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self._final_answer_as_tool_call = final_answer_as_tool_call
         self._final_answer_tool = self.get_final_answer_tool()
 
-        _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
+        tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
         if tools and final_answer_as_tool_call:
-            _tools = tools + [self._final_answer_tool]
+            tools_list = tools + [self._final_answer_tool]
+        self._tools = {t.name: t for t in tools_list} if tools_list else None
+
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag
 
         self._llm = llm
-        self._llm.tools = _tools
 
         self._max_turns = max_turns
         self._react_mode = react_mode
@@ -91,9 +97,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._llm
 
+    @property
+    def response_schema(self) -> Any | None:
+        return self._response_schema
+
+    @response_schema.setter
+    def response_schema(self, value: Any | None) -> None:
+        self._response_schema = value
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
+
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self._llm.tools or {}
+        return self._tools or {}
 
     @property
     def max_turns(self) -> int:
@@ -104,7 +122,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
@@ -117,7 +135,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
@@ -126,12 +144,16 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -147,9 +169,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -160,6 +183,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         llm_event_stream = self.llm.generate_completion_stream(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -189,14 +215,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(ctx=ctx, **args))
+            corouts.append(tool(call_id=call_id, ctx=ctx, **args))
 
         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -217,7 +243,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
             calls, memory=memory, call_id=call_id, ctx=ctx
@@ -245,7 +271,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -268,7 +294,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -296,7 +322,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def execute(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -379,7 +405,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
@@ -464,7 +490,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def run(
-        self, inp: BaseModel, ctx: RunContext[Any] | None = None
+        self,
+        inp: BaseModel,
+        *,
+        call_id: str | None = None,
+        ctx: RunContext[Any] | None = None,
     ) -> None:
         return None
 
@@ -473,22 +503,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def _process_completion(
         self,
         completion: Completion,
+        *,
         call_id: str,
         print_messages: bool = False,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> None:
-        if ctx is not None:
-            ctx.completions[self.agent_name].append(completion)
-            ctx.usage_tracker.update(
+        ctx.completions[self.agent_name].append(completion)
+        ctx.usage_tracker.update(
+            agent_name=self.agent_name,
+            completions=[completion],
+            model_name=self.llm.model_name,
+        )
+        if ctx.printer and print_messages:
+            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+            ctx.printer.print_messages(
+                completion.messages,
+                usages=usages,
                 agent_name=self.agent_name,
-                completions=[completion],
-                model_name=self.llm.model_name,
+                call_id=call_id,
             )
-            if ctx.printer and print_messages:
-                usages = [None] * (len(completion.messages) - 1) + [completion.usage]
-                ctx.printer.print_messages(
-                    completion.messages,
-                    usages=usages,
-                    agent_name=self.agent_name,
-                    call_id=call_id,
-                )
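Taken together, the hunks above move tool and schema handling off the LLM object and onto the executor: tools are now kept in a name-keyed dict on LLMPolicyExecutor, response_schema / response_schema_by_xml_tag are forwarded to the LLM on every completion call, and ctx becomes a required keyword argument for the internal methods. A minimal usage sketch under those assumptions; my_llm, my_tools, memory, ctx, and MyAnswer are placeholders, not names from this diff:

# Hypothetical 0.5.10-style usage; the objects passed in stand in for real
# LLM, BaseTool, LLMAgentMemory, and RunContext instances.
executor = LLMPolicyExecutor(
    agent_name="assistant",
    llm=my_llm,
    tools=my_tools,
    response_schema=MyAnswer,  # applied per completion call, no longer set on the LLM
    max_turns=5,
)
message = await executor.generate_message(
    memory,
    call_id="call-1",
    ctx=ctx,  # keyword-only and required in 0.5.10
)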
@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)
 
     @staticmethod
-    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> OpenAIToolParam:
-        return to_api_tool(tool, **kwargs)
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)
 
     @staticmethod
     def to_tool_choice(
@@ -3,9 +3,9 @@ import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, Literal
 
-import httpx
 from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel
 
-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings, LLMRateLimiter
-from ..http_client import AsyncHTTPClientParams
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -90,97 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     # TODO: support audio
 
 
+@dataclass(frozen=True)
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        llm_settings: OpenAILLMSettings | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
-        # Connection settings
-        max_client_retries: int = 2,
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        async_openai_client_params: dict[str, Any] | None = None,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
+    converters: OpenAIConverters = field(default_factory=OpenAIConverters)
+    async_openai_client_params: dict[str, Any] | None = None
+    client: AsyncOpenAI = field(init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
         openai_compatible_providers = get_openai_compatible_providers()
 
-        model_name_parts = model_name.split("/", 1)
-        if api_provider is not None:
-            provider_model_name = model_name
+        _api_provider = self.api_provider
+
+        model_name_parts = self.model_name.split("/", 1)
+        if _api_provider is not None:
+            _model_name = self.model_name
         elif len(model_name_parts) == 2:
             compat_providers_map = {
                 provider["name"]: provider for provider in openai_compatible_providers
             }
-            provider_name, provider_model_name = model_name_parts
+            provider_name, _model_name = model_name_parts
             if provider_name not in compat_providers_map:
                 raise ValueError(
-                    f"OpenAI compatible API provider '{provider_name}' "
-                    "is not supported. Supported providers are: "
+                    f"API provider '{provider_name}' is not a supported OpenAI "
+                    f"compatible provider. Supported providers are: "
                     f"{', '.join(compat_providers_map.keys())}"
                 )
-            api_provider = compat_providers_map[provider_name]
        else:
            raise ValueError(
                "Model name must be in the format 'provider/model_name' or "
                "you must provide an 'api_provider' argument."
            )
 
-        super().__init__(
-            model_name=provider_model_name,
-            model_id=model_id,
-            llm_settings=llm_settings,
-            converters=OpenAIConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            async_http_client=async_http_client,
-            async_http_client_params=async_http_client_params,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
+            stream_options["include_usage"] = True
+            _llm_settings = deepcopy(self.llm_settings)
+            _llm_settings["stream_options"] = stream_options
+        else:
+            _llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})
 
         response_schema_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in api_provider.get("response_schema_support") or []
+            fnmatch.fnmatch(_model_name, pat)
+            for pat in _api_provider.get("response_schema_support") or []
         )
-        if apply_response_schema_via_provider:
-            if self._tools:
-                for tool in self._tools.values():
-                    tool.strict = True
-            if not response_schema_support:
-                raise ValueError(
-                    "Native response schema validation is not supported for model "
-                    f"'{self._model_name}' by the API provider. Please set "
-                    "apply_response_schema_via_provider=False."
-                )
+        if self.apply_response_schema_via_provider and not response_schema_support:
+            raise ValueError(
+                "Native response schema validation is not supported for model "
+                f"'{_model_name}' by the API provider. Please set "
+                "apply_response_schema_via_provider=False."
+            )
 
-        _async_openai_client_params = deepcopy(async_openai_client_params or {})
-        if self._async_http_client is not None:
-            _async_openai_client_params["http_client"] = self._async_http_client
+        _async_openai_client_params = deepcopy(self.async_openai_client_params or {})
+        if self.async_http_client is not None:
+            _async_openai_client_params["http_client"] = self.async_http_client
 
-        self._client: AsyncOpenAI = AsyncOpenAI(
-            base_url=self.api_provider.get("base_url"),
-            api_key=self.api_provider.get("api_key"),
-            max_retries=max_client_retries,
+        _client = AsyncOpenAI(
+            base_url=_api_provider.get("base_url"),
+            api_key=_api_provider.get("api_key"),
+            max_retries=self.max_client_retries,
             **_async_openai_client_params,
         )
 
+        object.__setattr__(self, "model_name", _model_name)
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
+        object.__setattr__(self, "client", _client)
+
     async def _get_completion(
         self,
         api_messages: Iterable[OpenAIMessageParam],
@@ -195,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self._apply_response_schema_via_provider:
-            return await self._client.beta.chat.completions.parse(
-                model=self._model_name,
+        if self.apply_response_schema_via_provider:
+            return await self.client.beta.chat.completions.parse(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -206,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
                 **api_llm_settings,
             )
 
-        return await self._client.chat.completions.create(
-            model=self._model_name,
+        return await self.client.chat.completions.create(
+            model=self.model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -230,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self._apply_response_schema_via_provider:
+        if self.apply_response_schema_via_provider:
             stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
-                self._client.beta.chat.completions.stream(
-                    model=self._model_name,
+                self.client.beta.chat.completions.stream(
+                    model=self.model_name,
                     messages=api_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -249,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         else:
             stream_generator: AsyncStream[
                 OpenAICompletionChunk
-            ] = await self._client.chat.completions.create(
-                model=self._model_name,
+            ] = await self.client.chat.completions.create(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -263,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             yield completion_chunk
 
     def combine_completion_chunks(
-        self, completion_chunks: list[OpenAICompletionChunk]
+        self,
+        completion_chunks: list[OpenAICompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> OpenAICompletion:
         response_format = NOT_GIVEN
         input_tools = NOT_GIVEN
-        if self._apply_response_schema_via_provider:
-            if self._response_schema:
-                response_format = self._response_schema
-            if self._tools:
+        if self.apply_response_schema_via_provider:
+            if response_schema:
+                response_format = response_schema
+            if tools:
                 input_tools = [
-                    self._converters.to_tool(tool) for tool in self._tools.values()
+                    self.converters.to_tool(tool, strict=True)
+                    for tool in tools.values()
                 ]
         state = ChatCompletionStreamState[Any](
            input_tools=input_tools, response_format=response_format
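The hunks above rework OpenAILLM from an __init__-based class into a frozen dataclass: provider resolution, the forced stream_options={"include_usage": True}, and the AsyncOpenAI client are now set up in __post_init__ via object.__setattr__, and combine_completion_chunks receives the response schema and tools per call instead of reading them from instance state. A construction sketch under those assumptions; the model name is a placeholder, and any field not visible in this diff is presumed to be inherited from CloudLLM:

# Hypothetical construction of the dataclass-based OpenAILLM; "provider/model"
# selects one of the OpenAI-compatible providers, otherwise a ValueError is raised.
llm = OpenAILLM(
    model_name="openai/gpt-4o",
    apply_response_schema_via_provider=False,
)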
@@ -13,8 +13,10 @@ from . import (
 )
 
 
-def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
-    if tool.strict:
+def to_api_tool(
+    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
+) -> OpenAIToolParam:
+    if strict:
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
@@ -23,9 +25,9 @@ def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
         name=tool.name,
         description=tool.description,
         parameters=tool.in_type.model_json_schema(),
-        strict=tool.strict,
+        strict=strict,
     )
-    if tool.strict is None:
+    if strict is None:
         function.pop("strict")
 
     return OpenAIToolParam(type="function", function=function)
@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
 from functools import wraps
 from typing import (
     Any,
@@ -37,7 +37,6 @@ from ..typing.tool import BaseTool
 
 logger = logging.getLogger(__name__)
 
-_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
 F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
 F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
@@ -102,10 +101,13 @@ def with_retry_stream(func: F_stream) -> F_stream:
     return cast("F_stream", wrapper)
 
 
+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
 class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     def __call__(
-        self, output: _OutT_contra, ctx: RunContext[CtxT] | None
-    ) -> list[ProcName] | None: ...
+        self, output: _OutT_contra, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None: ...
 
 
 class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
@@ -118,7 +120,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
         self,
         name: ProcName,
         max_retries: int = 0,
-        recipients: list[ProcName] | None = None,
+        recipients: Sequence[ProcName] | None = None,
         **kwargs: Any,
     ) -> None:
         self._in_type: type[InT]
@@ -239,7 +241,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
             ) from err
 
     def _validate_recipients(
-        self, recipients: list[ProcName] | None, call_id: str
+        self, recipients: Sequence[ProcName] | None, call_id: str
     ) -> None:
         for r in recipients or []:
             if r not in (self.recipients or []):
@@ -252,8 +254,8 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
 
     @final
     def _select_recipients(
-        self, output: OutT, ctx: RunContext[CtxT] | None = None
-    ) -> list[ProcName] | None:
+        self, output: OutT, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None:
         if self.recipient_selector:
             return self.recipient_selector(output=output, ctx=ctx)
 
@@ -310,9 +312,15 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
             name: str = tool_name
             description: str = tool_description
 
-            async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
+            async def run(
+                self,
+                inp: InT,
+                *,
+                call_id: str | None = None,
+                ctx: RunContext[CtxT] | None = None,
+            ) -> OutT:
                 result = await processor_instance.run(
-                    in_args=inp, forgetful=True, ctx=ctx
+                    in_args=inp, forgetful=True, call_id=call_id, ctx=ctx
                 )
 
                 return result.payloads[0]
@@ -30,7 +30,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> OutT:
         return cast("OutT", in_args)
 
@@ -41,7 +41,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         output = cast("OutT", in_args)
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
@@ -67,7 +67,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> Packet[OutT]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory
 
@@ -86,7 +86,7 @@ class ParallelProcessor(
         return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
 
     async def _run_parallel(
-        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
+        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
     ) -> Packet[OutT]:
         tasks = [
             self._run_single(
@@ -114,6 +114,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
@@ -143,7 +144,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory
 
@@ -178,7 +179,7 @@ class ParallelProcessor(
         self,
         in_args: list[InT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         streams = [
             self._run_single_stream(
@@ -222,6 +223,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
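The processor hunks follow the same pattern as the executor: the public run / run_stream entry points still default ctx to None but now materialize a RunContext up front, so _run_single, _run_parallel, and the streaming variants can require a non-optional ctx. A rough sketch of the resulting call shape; my_processor, my_input, and my_run_context are placeholders:

# Hypothetical call shape after 0.5.10: ctx may be omitted at the public entry
# point, but downstream helpers always receive a concrete RunContext.
packet = await my_processor.run(in_args=my_input)                      # ctx auto-created
packet = await my_processor.run(in_args=my_input, ctx=my_run_context)  # ctx supplied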