grasp_agents 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,9 +3,9 @@ import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, Literal
 
-import httpx
 from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel
 
-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings, LLMRateLimiter
-from ..http_client import AsyncHTTPClientParams
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -90,105 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     # TODO: support audio
 
 
+@dataclass(frozen=True)
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
-    def __init__(
-        self,
-        # Base LLM args
-        model_name: str,
-        llm_settings: OpenAILLMSettings | None = None,
-        tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
-        response_schema: Any | None = None,
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
-        # Connection settings
-        max_client_retries: int = 2,
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        async_openai_client_params: dict[str, Any] | None = None,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
+    converters: OpenAIConverters = field(default_factory=OpenAIConverters)
+    async_openai_client_params: dict[str, Any] | None = None
+    client: AsyncOpenAI = field(init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
         openai_compatible_providers = get_openai_compatible_providers()
 
-        model_name_parts = model_name.split("/", 1)
-        if api_provider is not None:
-            provider_model_name = model_name
+        _api_provider = self.api_provider
+
+        model_name_parts = self.model_name.split("/", 1)
+        if _api_provider is not None:
+            _model_name = self.model_name
         elif len(model_name_parts) == 2:
             compat_providers_map = {
                 provider["name"]: provider for provider in openai_compatible_providers
             }
-            provider_name, provider_model_name = model_name_parts
+            provider_name, _model_name = model_name_parts
             if provider_name not in compat_providers_map:
                 raise ValueError(
                     f"API provider '{provider_name}' is not a supported OpenAI "
                     f"compatible provider. Supported providers are: "
                     f"{', '.join(compat_providers_map.keys())}"
                 )
-            api_provider = compat_providers_map[provider_name]
+            _api_provider = compat_providers_map[provider_name]
         else:
             raise ValueError(
                 "Model name must be in the format 'provider/model_name' or "
                 "you must provide an 'api_provider' argument."
            )
 
-        if llm_settings is not None:
-            stream_options = llm_settings.get("stream_options") or {}
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
             stream_options["include_usage"] = True
-            _llm_settings = deepcopy(llm_settings)
+            _llm_settings = deepcopy(self.llm_settings)
             _llm_settings["stream_options"] = stream_options
         else:
             _llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})
 
-        super().__init__(
-            model_name=provider_model_name,
-            model_id=model_id,
-            llm_settings=_llm_settings,
-            converters=OpenAIConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            async_http_client=async_http_client,
-            async_http_client_params=async_http_client_params,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
-
         response_schema_support: bool = any(
-            fnmatch.fnmatch(self._model_name, pat)
-            for pat in api_provider.get("response_schema_support") or []
+            fnmatch.fnmatch(_model_name, pat)
+            for pat in _api_provider.get("response_schema_support") or []
         )
-        if apply_response_schema_via_provider:
-            if self._tools:
-                for tool in self._tools.values():
-                    tool.strict = True
-            if not response_schema_support:
-                raise ValueError(
-                    "Native response schema validation is not supported for model "
-                    f"'{self._model_name}' by the API provider. Please set "
-                    "apply_response_schema_via_provider=False."
-                )
+        if self.apply_response_schema_via_provider and not response_schema_support:
+            raise ValueError(
+                "Native response schema validation is not supported for model "
+                f"'{_model_name}' by the API provider. Please set "
+                "apply_response_schema_via_provider=False."
+            )
 
-        _async_openai_client_params = deepcopy(async_openai_client_params or {})
-        if self._async_http_client is not None:
-            _async_openai_client_params["http_client"] = self._async_http_client
+        _async_openai_client_params = deepcopy(self.async_openai_client_params or {})
+        if self.async_http_client is not None:
+            _async_openai_client_params["http_client"] = self.async_http_client
 
-        self._client: AsyncOpenAI = AsyncOpenAI(
-            base_url=self.api_provider.get("base_url"),
-            api_key=self.api_provider.get("api_key"),
-            max_retries=max_client_retries,
+        _client = AsyncOpenAI(
+            base_url=_api_provider.get("base_url"),
+            api_key=_api_provider.get("api_key"),
+            max_retries=self.max_client_retries,
             **_async_openai_client_params,
         )
 
+        object.__setattr__(self, "model_name", _model_name)
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
+        object.__setattr__(self, "client", _client)
+
     async def _get_completion(
         self,
         api_messages: Iterable[OpenAIMessageParam],
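
Note on the hunk above: the hand-written __init__ is replaced by a frozen dataclass whose __post_init__ normalizes model_name, api_provider, and llm_settings, then writes the results back with object.__setattr__, the documented escape hatch for frozen dataclasses. A minimal, self-contained sketch of the pattern (hypothetical Client class, not the package's API):

from dataclasses import dataclass, field


@dataclass(frozen=True)
class Client:
    model_name: str                    # accepts "provider/model"
    provider: str = field(init=False)  # derived in __post_init__

    def __post_init__(self) -> None:
        provider, _, model = self.model_name.partition("/")
        if not model:
            raise ValueError("expected 'provider/model'")
        # Frozen dataclasses forbid plain attribute assignment, so bypass it:
        object.__setattr__(self, "provider", provider)
        object.__setattr__(self, "model_name", model)


c = Client(model_name="openai/gpt-4o")
assert (c.provider, c.model_name) == ("openai", "gpt-4o")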
@@ -203,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self._apply_response_schema_via_provider:
-            return await self._client.beta.chat.completions.parse(
-                model=self._model_name,
+        if self.apply_response_schema_via_provider:
+            return await self.client.beta.chat.completions.parse(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -214,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             **api_llm_settings,
         )
 
-        return await self._client.chat.completions.create(
-            model=self._model_name,
+        return await self.client.chat.completions.create(
+            model=self.model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -238,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self._apply_response_schema_via_provider:
+        if self.apply_response_schema_via_provider:
             stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
-                self._client.beta.chat.completions.stream(
-                    model=self._model_name,
+                self.client.beta.chat.completions.stream(
+                    model=self.model_name,
                     messages=api_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -257,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         else:
             stream_generator: AsyncStream[
                 OpenAICompletionChunk
-            ] = await self._client.chat.completions.create(
-                model=self._model_name,
+            ] = await self.client.chat.completions.create(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -271,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
                 yield completion_chunk
 
     def combine_completion_chunks(
-        self, completion_chunks: list[OpenAICompletionChunk]
+        self,
+        completion_chunks: list[OpenAICompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> OpenAICompletion:
         response_format = NOT_GIVEN
         input_tools = NOT_GIVEN
-        if self._apply_response_schema_via_provider:
-            if self._response_schema:
-                response_format = self._response_schema
-            if self._tools:
+        if self.apply_response_schema_via_provider:
+            if response_schema:
+                response_format = response_schema
+            if tools:
                 input_tools = [
-                    self._converters.to_tool(tool) for tool in self._tools.values()
+                    self.converters.to_tool(tool, strict=True)
+                    for tool in tools.values()
                 ]
         state = ChatCompletionStreamState[Any](
             input_tools=input_tools, response_format=response_format
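
Note: combine_completion_chunks no longer reads instance state; the response schema and tool map are passed per call, and tools are converted with strict=True to match the provider-side schema enforcement enabled above. The accumulation itself delegates to openai-python's ChatCompletionStreamState, typically driven as in this sketch (chunks is assumed to hold ChatCompletionChunk objects collected from a streamed create call):

from openai.lib.streaming.chat import ChatCompletionStreamState

state = ChatCompletionStreamState()  # optionally input_tools=..., response_format=...
for chunk in chunks:                 # chunks: collected ChatCompletionChunk objects
    state.handle_chunk(chunk)
completion = state.get_final_completion()  # the reassembled ChatCompletion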
@@ -13,8 +13,10 @@ from . import (
 )
 
 
-def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
-    if tool.strict:
+def to_api_tool(
+    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
+) -> OpenAIToolParam:
+    if strict:
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
@@ -23,9 +25,9 @@ def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
         name=tool.name,
         description=tool.description,
         parameters=tool.in_type.model_json_schema(),
-        strict=tool.strict,
+        strict=strict,
     )
-    if tool.strict is None:
+    if strict is None:
         function.pop("strict")
 
     return OpenAIToolParam(type="function", function=function)
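
With strictness now decided by the caller (the strict field is removed from BaseTool later in this diff), to_api_tool has a three-way contract. Illustrative calls against the new signature, using a hypothetical my_tool:

to_api_tool(my_tool, strict=True)   # strict JSON schema via pydantic_function_tool
to_api_tool(my_tool, strict=False)  # plain schema, "strict": false sent explicitly
to_api_tool(my_tool)                # strict=None: key dropped, provider default applies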
@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
 from functools import wraps
 from typing import (
     Any,
@@ -37,7 +37,6 @@ from ..typing.tool import BaseTool
 
 logger = logging.getLogger(__name__)
 
-_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
 
 F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
 F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
@@ -102,10 +101,13 @@ def with_retry_stream(func: F_stream) -> F_stream:
     return cast("F_stream", wrapper)
 
 
+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
 class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     def __call__(
-        self, output: _OutT_contra, ctx: RunContext[CtxT] | None
-    ) -> list[ProcName] | None: ...
+        self, output: _OutT_contra, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None: ...
 
 
 class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
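
Moving _OutT_contra next to its only consumer and widening the return type from list[ProcName] to Sequence[ProcName] lets selectors return any sequence (list, tuple) while the contravariant parameter keeps broadly-typed selectors assignable to narrowly-typed slots. A self-contained sketch of the variance at work, with plain str standing in for ProcName:

from collections.abc import Sequence
from typing import Protocol, TypeVar

_OutT_contra = TypeVar("_OutT_contra", contravariant=True)


class Selector(Protocol[_OutT_contra]):
    def __call__(self, output: _OutT_contra) -> Sequence[str] | None: ...


def route_any(output: object) -> list[str]:  # list[str] satisfies Sequence[str]
    return ["worker_a", "worker_b"]


selector: Selector[int] = route_any  # OK: accepts object, so it accepts int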
@@ -118,7 +120,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
         self,
         name: ProcName,
         max_retries: int = 0,
-        recipients: list[ProcName] | None = None,
+        recipients: Sequence[ProcName] | None = None,
         **kwargs: Any,
     ) -> None:
         self._in_type: type[InT]
@@ -239,7 +241,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
         ) from err
 
     def _validate_recipients(
-        self, recipients: list[ProcName] | None, call_id: str
+        self, recipients: Sequence[ProcName] | None, call_id: str
     ) -> None:
         for r in recipients or []:
             if r not in (self.recipients or []):
@@ -252,8 +254,8 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
 
     @final
     def _select_recipients(
-        self, output: OutT, ctx: RunContext[CtxT] | None = None
-    ) -> list[ProcName] | None:
+        self, output: OutT, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None:
         if self.recipient_selector:
             return self.recipient_selector(output=output, ctx=ctx)
 
@@ -310,9 +312,15 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
             name: str = tool_name
             description: str = tool_description
 
-            async def run(self, inp: InT, ctx: RunContext[CtxT] | None = None) -> OutT:
+            async def run(
+                self,
+                inp: InT,
+                *,
+                call_id: str | None = None,
+                ctx: RunContext[CtxT] | None = None,
+            ) -> OutT:
                 result = await processor_instance.run(
-                    in_args=inp, forgetful=True, ctx=ctx
+                    in_args=inp, forgetful=True, call_id=call_id, ctx=ctx
                 )
 
                 return result.payloads[0]
@@ -30,7 +30,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> OutT:
         return cast("OutT", in_args)
 
@@ -41,7 +41,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         output = cast("OutT", in_args)
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
@@ -67,7 +67,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> Packet[OutT]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory
 
@@ -86,7 +86,7 @@ class ParallelProcessor(
         return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
 
     async def _run_parallel(
-        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
+        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
     ) -> Packet[OutT]:
         tasks = [
             self._run_single(
@@ -114,6 +114,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
@@ -143,7 +144,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory
 
@@ -178,7 +179,7 @@ class ParallelProcessor(
         self,
         in_args: list[InT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         streams = [
             self._run_single_stream(
@@ -222,6 +223,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
 
         val_in_args = self._validate_inputs(
             call_id=call_id,
@@ -1,5 +1,5 @@
 import logging
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Sequence
 from typing import Any, ClassVar, Generic, cast
 
 from ..memory import MemT
@@ -25,7 +25,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         in_args: list[InT] | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> list[OutT]:
         return cast("list[OutT]", in_args)
 
@@ -36,7 +36,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         in_args: list[InT] | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         outputs = await self._process(
             chat_inputs=chat_inputs,
@@ -58,7 +58,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         in_args: InT | list[InT] | None = None,
         forgetful: bool = False,
         call_id: str | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> tuple[list[InT] | None, MemT, str]:
         call_id = self._generate_call_id(call_id)
 
@@ -74,10 +74,10 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         return val_in_args, memory, call_id
 
     def _postprocess(
-        self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT] | None = None
+        self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT]
     ) -> Packet[OutT]:
         payloads: list[OutT] = []
-        routing: dict[int, list[ProcName] | None] = {}
+        routing: dict[int, Sequence[ProcName] | None] = {}
         for idx, output in enumerate(outputs):
             val_output = self._validate_output(output, call_id=call_id)
             recipients = self._select_recipients(output=val_output, ctx=ctx)
@@ -105,6 +105,8 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
+
         val_in_args, memory, call_id = self._preprocess(
             chat_inputs=chat_inputs,
             in_packet=in_packet,
@@ -134,6 +136,8 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
         call_id: str | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore
+
         val_in_args, memory, call_id = self._preprocess(
             chat_inputs=chat_inputs,
             in_packet=in_packet,
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Sequence
-from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, final
+from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, cast, final
 
 from pydantic import BaseModel, TypeAdapter
 
@@ -15,13 +15,11 @@ _InT_contra = TypeVar("_InT_contra", contravariant=True)
 
 
 class SystemPromptBuilder(Protocol[CtxT]):
-    def __call__(self, ctx: RunContext[CtxT] | None) -> str | None: ...
+    def __call__(self, ctx: RunContext[CtxT]) -> str | None: ...
 
 
 class InputContentBuilder(Protocol[_InT_contra, CtxT]):
-    def __call__(
-        self, in_args: _InT_contra | None, *, ctx: RunContext[CtxT] | None
-    ) -> Content: ...
+    def __call__(self, in_args: _InT_contra, ctx: RunContext[CtxT]) -> Content: ...
 
 
 PromptArgumentType: TypeAlias = str | bool | int | ImageData
@@ -45,7 +43,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
         self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
 
     @final
-    def build_system_prompt(self, ctx: RunContext[CtxT] | None = None) -> str | None:
+    def build_system_prompt(self, ctx: RunContext[CtxT]) -> str | None:
         if self.system_prompt_builder:
             return self.system_prompt_builder(ctx=ctx)
 
@@ -73,23 +71,19 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
 
     @final
     def _build_input_content(
-        self,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        self, in_args: InT | None, ctx: RunContext[CtxT]
     ) -> Content:
-        val_in_args = in_args
-        if in_args is not None:
-            val_in_args = self._validate_input_args(in_args=in_args)
-
-        if self.input_content_builder:
-            return self.input_content_builder(in_args=val_in_args, ctx=ctx)
-
-        if val_in_args is None:
+        if in_args is None and self._in_type is not type(None):
             raise InputPromptBuilderError(
                 proc_name=self._agent_name,
-                message="Input arguments are not provided, "
-                f"but input content is required [agent_name={self._agent_name}]",
+                message="Either chat inputs or input arguments must be provided "
+                f"when input type is not None [agent_name={self._agent_name}]",
             )
+        in_args = cast("InT", in_args)
+
+        val_in_args = self._validate_input_args(in_args)
+        if self.input_content_builder:
+            return self.input_content_builder(in_args=val_in_args, ctx=ctx)
 
         if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
             val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
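
The rewrite above also changes when missing input is an error: _build_input_content now accepts absent in_args only if the processor's declared input type is NoneType, hinging on an identity check against type(None). Condensed to its core:

def requires_input(in_type: type) -> bool:
    # type(None) is NoneType; the identity comparison detects "no input expected"
    return in_type is not type(None)


assert requires_input(str)
assert not requires_input(type(None))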
@@ -106,17 +100,17 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
     def build_input_message(
         self,
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        *,
         in_args: InT | None = None,
-        ctx: RunContext[CtxT] | None = None,
+        ctx: RunContext[CtxT],
     ) -> UserMessage | None:
-        if chat_inputs is not None and in_args is not None:
-            raise InputPromptBuilderError(
-                proc_name=self._agent_name,
-                message="Cannot use both chat inputs and input arguments "
-                f"at the same time [agent_name={self._agent_name}]",
-            )
-
-        if chat_inputs:
+        if chat_inputs is not None:
+            if in_args is not None:
+                raise InputPromptBuilderError(
+                    proc_name=self._agent_name,
+                    message="Cannot use both chat inputs and input arguments "
+                    f"at the same time [agent_name={self._agent_name}]",
+                )
             if isinstance(chat_inputs, LLMPrompt):
                 return UserMessage.from_text(chat_inputs, name=self._agent_name)
             return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
@@ -13,7 +13,7 @@ CtxT = TypeVar("CtxT")
 
 
 class RunContext(BaseModel, Generic[CtxT]):
-    state: CtxT | None = None
+    state: CtxT = None  # type: ignore
 
     completions: dict[ProcName, list[Completion]] = Field(
         default_factory=lambda: defaultdict(list)
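
Typing state as CtxT = None (instead of CtxT | None) means ctx.state no longer needs None-narrowing at every use site; the cost is the # type: ignore above and the explicit RunContext[CtxT](state=None) construction added at call sites earlier in this diff. A sketch of the effect, assuming pydantic v2 semantics:

from typing import Generic, TypeVar

from pydantic import BaseModel

CtxT = TypeVar("CtxT")


class RunContext(BaseModel, Generic[CtxT]):
    state: CtxT = None  # type: ignore[assignment]


class AppState(BaseModel):
    counter: int = 0


ctx = RunContext[AppState](state=AppState())
ctx.state.counter += 1  # typed as AppState; no `is not None` guard required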
grasp_agents/runner.py CHANGED
@@ -33,7 +33,7 @@ class Runner(Generic[OutT, CtxT]):
 
         self._entry_proc = entry_proc
         self._procs = procs
-        self._ctx = ctx or RunContext[CtxT]()
+        self._ctx = ctx or RunContext[CtxT](state=None)  # type: ignore
 
     @property
     def ctx(self) -> RunContext[CtxT]:
@@ -69,7 +69,9 @@ class Converters(ABC):
 
     @staticmethod
     @abstractmethod
-    def to_tool(tool: BaseTool[BaseModel, Any, Any], **kwargs: Any) -> Any:
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> Any:
         pass
 
     @staticmethod
@@ -48,8 +48,6 @@ class BaseTool(
     name: str
     description: str
 
-    strict: bool | None = None
-
     _in_type: type[_InT] = PrivateAttr()
     _out_type: type[_OutT_co] = PrivateAttr()
 
@@ -62,14 +60,24 @@ class BaseTool(
         return self._out_type
 
     @abstractmethod
-    async def run(self, inp: _InT, ctx: RunContext[CtxT] | None = None) -> _OutT_co:
+    async def run(
+        self,
+        inp: _InT,
+        *,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+    ) -> _OutT_co:
         pass
 
     async def __call__(
-        self, ctx: RunContext[CtxT] | None = None, **kwargs: Any
+        self,
+        *,
+        call_id: str | None = None,
+        ctx: RunContext[CtxT] | None = None,
+        **kwargs: Any,
     ) -> _OutT_co:
         input_args = TypeAdapter(self._in_type).validate_python(kwargs)
-        output = await self.run(input_args, ctx=ctx)
+        output = await self.run(input_args, call_id=call_id, ctx=ctx)
 
         return TypeAdapter(self._out_type).validate_python(output)
 
@@ -21,7 +21,7 @@ class WorkflowProcessor(
         subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
         start_proc: BaseProcessor[InT, Any, Any, CtxT],
         end_proc: BaseProcessor[Any, OutT, Any, CtxT],
-        recipients: list[ProcName] | None = None,
+        recipients: Sequence[ProcName] | None = None,
         max_retries: int = 0,
     ) -> None:
         super().__init__(name=name, recipients=recipients, max_retries=max_retries)
@@ -57,11 +57,11 @@ class WorkflowProcessor(
         return func
 
     @property
-    def recipients(self) -> list[ProcName] | None:
+    def recipients(self) -> Sequence[ProcName] | None:
         return self._end_proc.recipients
 
     @recipients.setter
-    def recipients(self, value: list[ProcName] | None) -> None:
+    def recipients(self, value: Sequence[ProcName] | None) -> None:
         if hasattr(self, "_end_proc"):
             self._end_proc.recipients = value
 
@@ -96,7 +96,7 @@ class WorkflowProcessor(
         pass
 
     @abstractmethod
-    async def run_stream(  # type: ignore[override]
+    async def run_stream(
         self,
         chat_inputs: Any | None = None,
         *,