grasp_agents 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +88 -109
- grasp_agents/litellm/converters.py +4 -2
- grasp_agents/litellm/lite_llm.py +72 -83
- grasp_agents/llm.py +35 -68
- grasp_agents/llm_agent.py +32 -36
- grasp_agents/llm_agent_memory.py +3 -2
- grasp_agents/llm_policy_executor.py +63 -33
- grasp_agents/openai/converters.py +4 -2
- grasp_agents/openai/openai_llm.py +60 -87
- grasp_agents/openai/tool_converters.py +6 -4
- grasp_agents/processors/base_processor.py +18 -10
- grasp_agents/processors/parallel_processor.py +8 -6
- grasp_agents/processors/processor.py +10 -6
- grasp_agents/prompt_builder.py +22 -28
- grasp_agents/run_context.py +1 -1
- grasp_agents/runner.py +1 -1
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/tool.py +13 -5
- grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/RECORD +23 -23
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.10.dist-info}/licenses/LICENSE.md +0 -0
@@ -3,9 +3,9 @@ import logging
|
|
3
3
|
import os
|
4
4
|
from collections.abc import AsyncIterator, Iterable, Mapping
|
5
5
|
from copy import deepcopy
|
6
|
+
from dataclasses import dataclass, field
|
6
7
|
from typing import Any, Literal
|
7
8
|
|
8
|
-
import httpx
|
9
9
|
from openai import AsyncOpenAI, AsyncStream
|
10
10
|
from openai._types import NOT_GIVEN # type: ignore[import]
|
11
11
|
from openai.lib.streaming.chat import (
|
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
|
|
15
15
|
from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
|
16
16
|
from pydantic import BaseModel
|
17
17
|
|
18
|
-
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
|
19
|
-
from ..http_client import AsyncHTTPClientParams
|
18
|
+
from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
|
20
19
|
from ..typing.tool import BaseTool
|
21
20
|
from . import (
|
22
21
|
OpenAICompletion,
|
@@ -90,105 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
|
|
90
89
|
# TODO: support audio
|
91
90
|
|
92
91
|
|
92
|
+
@dataclass(frozen=True)
|
93
93
|
class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
response_schema_by_xml_tag: Mapping[str, Any] | None = None,
|
102
|
-
apply_response_schema_via_provider: bool = False,
|
103
|
-
model_id: str | None = None,
|
104
|
-
# Custom LLM provider
|
105
|
-
api_provider: APIProvider | None = None,
|
106
|
-
# Connection settings
|
107
|
-
max_client_retries: int = 2,
|
108
|
-
async_http_client: httpx.AsyncClient | None = None,
|
109
|
-
async_http_client_params: (
|
110
|
-
dict[str, Any] | AsyncHTTPClientParams | None
|
111
|
-
) = None,
|
112
|
-
async_openai_client_params: dict[str, Any] | None = None,
|
113
|
-
# Rate limiting
|
114
|
-
rate_limiter: LLMRateLimiter | None = None,
|
115
|
-
# LLM response retries: try to regenerate to pass validation
|
116
|
-
max_response_retries: int = 1,
|
117
|
-
) -> None:
|
94
|
+
converters: OpenAIConverters = field(default_factory=OpenAIConverters)
|
95
|
+
async_openai_client_params: dict[str, Any] | None = None
|
96
|
+
client: AsyncOpenAI = field(init=False)
|
97
|
+
|
98
|
+
def __post_init__(self):
|
99
|
+
super().__post_init__()
|
100
|
+
|
118
101
|
openai_compatible_providers = get_openai_compatible_providers()
|
119
102
|
|
120
|
-
|
121
|
-
|
122
|
-
|
103
|
+
_api_provider = self.api_provider
|
104
|
+
|
105
|
+
model_name_parts = self.model_name.split("/", 1)
|
106
|
+
if _api_provider is not None:
|
107
|
+
_model_name = self.model_name
|
123
108
|
elif len(model_name_parts) == 2:
|
124
109
|
compat_providers_map = {
|
125
110
|
provider["name"]: provider for provider in openai_compatible_providers
|
126
111
|
}
|
127
|
-
provider_name,
|
112
|
+
provider_name, _model_name = model_name_parts
|
128
113
|
if provider_name not in compat_providers_map:
|
129
114
|
raise ValueError(
|
130
115
|
f"API provider '{provider_name}' is not a supported OpenAI "
|
131
116
|
f"compatible provider. Supported providers are: "
|
132
117
|
f"{', '.join(compat_providers_map.keys())}"
|
133
118
|
)
|
134
|
-
|
119
|
+
_api_provider = compat_providers_map[provider_name]
|
135
120
|
else:
|
136
121
|
raise ValueError(
|
137
122
|
"Model name must be in the format 'provider/model_name' or "
|
138
123
|
"you must provide an 'api_provider' argument."
|
139
124
|
)
|
140
125
|
|
141
|
-
if llm_settings is not None:
|
142
|
-
stream_options = llm_settings.get("stream_options") or {}
|
126
|
+
if self.llm_settings is not None:
|
127
|
+
stream_options = self.llm_settings.get("stream_options") or {}
|
143
128
|
stream_options["include_usage"] = True
|
144
|
-
_llm_settings = deepcopy(llm_settings)
|
129
|
+
_llm_settings = deepcopy(self.llm_settings)
|
145
130
|
_llm_settings["stream_options"] = stream_options
|
146
131
|
else:
|
147
132
|
_llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})
|
148
133
|
|
149
|
-
super().__init__(
|
150
|
-
model_name=provider_model_name,
|
151
|
-
model_id=model_id,
|
152
|
-
llm_settings=_llm_settings,
|
153
|
-
converters=OpenAIConverters(),
|
154
|
-
tools=tools,
|
155
|
-
response_schema=response_schema,
|
156
|
-
response_schema_by_xml_tag=response_schema_by_xml_tag,
|
157
|
-
apply_response_schema_via_provider=apply_response_schema_via_provider,
|
158
|
-
api_provider=api_provider,
|
159
|
-
async_http_client=async_http_client,
|
160
|
-
async_http_client_params=async_http_client_params,
|
161
|
-
rate_limiter=rate_limiter,
|
162
|
-
max_client_retries=max_client_retries,
|
163
|
-
max_response_retries=max_response_retries,
|
164
|
-
)
|
165
|
-
|
166
134
|
response_schema_support: bool = any(
|
167
|
-
fnmatch.fnmatch(
|
168
|
-
for pat in
|
135
|
+
fnmatch.fnmatch(_model_name, pat)
|
136
|
+
for pat in _api_provider.get("response_schema_support") or []
|
169
137
|
)
|
170
|
-
if apply_response_schema_via_provider:
|
171
|
-
|
172
|
-
for
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
"Native response schema validation is not supported for model "
|
177
|
-
f"'{self._model_name}' by the API provider. Please set "
|
178
|
-
"apply_response_schema_via_provider=False."
|
179
|
-
)
|
138
|
+
if self.apply_response_schema_via_provider and not response_schema_support:
|
139
|
+
raise ValueError(
|
140
|
+
"Native response schema validation is not supported for model "
|
141
|
+
f"'{_model_name}' by the API provider. Please set "
|
142
|
+
"apply_response_schema_via_provider=False."
|
143
|
+
)
|
180
144
|
|
181
|
-
_async_openai_client_params = deepcopy(async_openai_client_params or {})
|
182
|
-
if self.
|
183
|
-
_async_openai_client_params["http_client"] = self.
|
145
|
+
_async_openai_client_params = deepcopy(self.async_openai_client_params or {})
|
146
|
+
if self.async_http_client is not None:
|
147
|
+
_async_openai_client_params["http_client"] = self.async_http_client
|
184
148
|
|
185
|
-
|
186
|
-
base_url=
|
187
|
-
api_key=
|
188
|
-
max_retries=max_client_retries,
|
149
|
+
_client = AsyncOpenAI(
|
150
|
+
base_url=_api_provider.get("base_url"),
|
151
|
+
api_key=_api_provider.get("api_key"),
|
152
|
+
max_retries=self.max_client_retries,
|
189
153
|
**_async_openai_client_params,
|
190
154
|
)
|
191
155
|
|
156
|
+
object.__setattr__(self, "model_name", _model_name)
|
157
|
+
object.__setattr__(self, "api_provider", _api_provider)
|
158
|
+
object.__setattr__(self, "llm_settings", _llm_settings)
|
159
|
+
object.__setattr__(self, "client", _client)
|
160
|
+
|
192
161
|
async def _get_completion(
|
193
162
|
self,
|
194
163
|
api_messages: Iterable[OpenAIMessageParam],
|
@@ -203,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
203
172
|
response_format = api_response_schema or NOT_GIVEN
|
204
173
|
n = n_choices or NOT_GIVEN
|
205
174
|
|
206
|
-
if self.
|
207
|
-
return await self.
|
208
|
-
model=self.
|
175
|
+
if self.apply_response_schema_via_provider:
|
176
|
+
return await self.client.beta.chat.completions.parse(
|
177
|
+
model=self.model_name,
|
209
178
|
messages=api_messages,
|
210
179
|
tools=tools,
|
211
180
|
tool_choice=tool_choice,
|
@@ -214,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
214
183
|
**api_llm_settings,
|
215
184
|
)
|
216
185
|
|
217
|
-
return await self.
|
218
|
-
model=self.
|
186
|
+
return await self.client.chat.completions.create(
|
187
|
+
model=self.model_name,
|
219
188
|
messages=api_messages,
|
220
189
|
tools=tools,
|
221
190
|
tool_choice=tool_choice,
|
@@ -238,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
238
207
|
response_format = api_response_schema or NOT_GIVEN
|
239
208
|
n = n_choices or NOT_GIVEN
|
240
209
|
|
241
|
-
if self.
|
210
|
+
if self.apply_response_schema_via_provider:
|
242
211
|
stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
|
243
|
-
self.
|
244
|
-
model=self.
|
212
|
+
self.client.beta.chat.completions.stream(
|
213
|
+
model=self.model_name,
|
245
214
|
messages=api_messages,
|
246
215
|
tools=tools,
|
247
216
|
tool_choice=tool_choice,
|
@@ -257,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
257
226
|
else:
|
258
227
|
stream_generator: AsyncStream[
|
259
228
|
OpenAICompletionChunk
|
260
|
-
] = await self.
|
261
|
-
model=self.
|
229
|
+
] = await self.client.chat.completions.create(
|
230
|
+
model=self.model_name,
|
262
231
|
messages=api_messages,
|
263
232
|
tools=tools,
|
264
233
|
tool_choice=tool_choice,
|
@@ -271,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
271
240
|
yield completion_chunk
|
272
241
|
|
273
242
|
def combine_completion_chunks(
|
274
|
-
self,
|
243
|
+
self,
|
244
|
+
completion_chunks: list[OpenAICompletionChunk],
|
245
|
+
response_schema: Any | None = None,
|
246
|
+
tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
|
275
247
|
) -> OpenAICompletion:
|
276
248
|
response_format = NOT_GIVEN
|
277
249
|
input_tools = NOT_GIVEN
|
278
|
-
if self.
|
279
|
-
if
|
280
|
-
response_format =
|
281
|
-
if
|
250
|
+
if self.apply_response_schema_via_provider:
|
251
|
+
if response_schema:
|
252
|
+
response_format = response_schema
|
253
|
+
if tools:
|
282
254
|
input_tools = [
|
283
|
-
self.
|
255
|
+
self.converters.to_tool(tool, strict=True)
|
256
|
+
for tool in tools.values()
|
284
257
|
]
|
285
258
|
state = ChatCompletionStreamState[Any](
|
286
259
|
input_tools=input_tools, response_format=response_format
|
@@ -13,8 +13,10 @@ from . import (
|
|
13
13
|
)
|
14
14
|
|
15
15
|
|
16
|
-
def to_api_tool(
|
17
|
-
|
16
|
+
def to_api_tool(
|
17
|
+
tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
|
18
|
+
) -> OpenAIToolParam:
|
19
|
+
if strict:
|
18
20
|
return pydantic_function_tool(
|
19
21
|
model=tool.in_type, name=tool.name, description=tool.description
|
20
22
|
)
|
@@ -23,9 +25,9 @@ def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
|
|
23
25
|
name=tool.name,
|
24
26
|
description=tool.description,
|
25
27
|
parameters=tool.in_type.model_json_schema(),
|
26
|
-
strict=
|
28
|
+
strict=strict,
|
27
29
|
)
|
28
|
-
if
|
30
|
+
if strict is None:
|
29
31
|
function.pop("strict")
|
30
32
|
|
31
33
|
return OpenAIToolParam(type="function", function=function)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import logging
|
2
2
|
from abc import ABC, abstractmethod
|
3
|
-
from collections.abc import AsyncIterator, Callable, Coroutine
|
3
|
+
from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
|
4
4
|
from functools import wraps
|
5
5
|
from typing import (
|
6
6
|
Any,
|
@@ -37,7 +37,6 @@ from ..typing.tool import BaseTool
|
|
37
37
|
|
38
38
|
logger = logging.getLogger(__name__)
|
39
39
|
|
40
|
-
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
41
40
|
|
42
41
|
F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
|
43
42
|
F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
|
@@ -102,10 +101,13 @@ def with_retry_stream(func: F_stream) -> F_stream:
|
|
102
101
|
return cast("F_stream", wrapper)
|
103
102
|
|
104
103
|
|
104
|
+
_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
105
|
+
|
106
|
+
|
105
107
|
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
106
108
|
def __call__(
|
107
|
-
self, output: _OutT_contra, ctx: RunContext[CtxT]
|
108
|
-
) ->
|
109
|
+
self, output: _OutT_contra, ctx: RunContext[CtxT]
|
110
|
+
) -> Sequence[ProcName] | None: ...
|
109
111
|
|
110
112
|
|
111
113
|
class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
|
@@ -118,7 +120,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
|
|
118
120
|
self,
|
119
121
|
name: ProcName,
|
120
122
|
max_retries: int = 0,
|
121
|
-
recipients:
|
123
|
+
recipients: Sequence[ProcName] | None = None,
|
122
124
|
**kwargs: Any,
|
123
125
|
) -> None:
|
124
126
|
self._in_type: type[InT]
|
@@ -239,7 +241,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
|
|
239
241
|
) from err
|
240
242
|
|
241
243
|
def _validate_recipients(
|
242
|
-
self, recipients:
|
244
|
+
self, recipients: Sequence[ProcName] | None, call_id: str
|
243
245
|
) -> None:
|
244
246
|
for r in recipients or []:
|
245
247
|
if r not in (self.recipients or []):
|
@@ -252,8 +254,8 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
|
|
252
254
|
|
253
255
|
@final
|
254
256
|
def _select_recipients(
|
255
|
-
self, output: OutT, ctx: RunContext[CtxT]
|
256
|
-
) ->
|
257
|
+
self, output: OutT, ctx: RunContext[CtxT]
|
258
|
+
) -> Sequence[ProcName] | None:
|
257
259
|
if self.recipient_selector:
|
258
260
|
return self.recipient_selector(output=output, ctx=ctx)
|
259
261
|
|
@@ -310,9 +312,15 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, C
|
|
310
312
|
name: str = tool_name
|
311
313
|
description: str = tool_description
|
312
314
|
|
313
|
-
async def run(
|
315
|
+
async def run(
|
316
|
+
self,
|
317
|
+
inp: InT,
|
318
|
+
*,
|
319
|
+
call_id: str | None = None,
|
320
|
+
ctx: RunContext[CtxT] | None = None,
|
321
|
+
) -> OutT:
|
314
322
|
result = await processor_instance.run(
|
315
|
-
in_args=inp, forgetful=True, ctx=ctx
|
323
|
+
in_args=inp, forgetful=True, call_id=call_id, ctx=ctx
|
316
324
|
)
|
317
325
|
|
318
326
|
return result.payloads[0]
|
@@ -30,7 +30,7 @@ class ParallelProcessor(
|
|
30
30
|
in_args: InT | None = None,
|
31
31
|
memory: MemT,
|
32
32
|
call_id: str,
|
33
|
-
ctx: RunContext[CtxT]
|
33
|
+
ctx: RunContext[CtxT],
|
34
34
|
) -> OutT:
|
35
35
|
return cast("OutT", in_args)
|
36
36
|
|
@@ -41,7 +41,7 @@ class ParallelProcessor(
|
|
41
41
|
in_args: InT | None = None,
|
42
42
|
memory: MemT,
|
43
43
|
call_id: str,
|
44
|
-
ctx: RunContext[CtxT]
|
44
|
+
ctx: RunContext[CtxT],
|
45
45
|
) -> AsyncIterator[Event[Any]]:
|
46
46
|
output = cast("OutT", in_args)
|
47
47
|
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
@@ -67,7 +67,7 @@ class ParallelProcessor(
|
|
67
67
|
in_args: InT | None = None,
|
68
68
|
forgetful: bool = False,
|
69
69
|
call_id: str,
|
70
|
-
ctx: RunContext[CtxT]
|
70
|
+
ctx: RunContext[CtxT],
|
71
71
|
) -> Packet[OutT]:
|
72
72
|
memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
73
73
|
|
@@ -86,7 +86,7 @@ class ParallelProcessor(
|
|
86
86
|
return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
|
87
87
|
|
88
88
|
async def _run_parallel(
|
89
|
-
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
|
89
|
+
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
|
90
90
|
) -> Packet[OutT]:
|
91
91
|
tasks = [
|
92
92
|
self._run_single(
|
@@ -114,6 +114,7 @@ class ParallelProcessor(
|
|
114
114
|
ctx: RunContext[CtxT] | None = None,
|
115
115
|
) -> Packet[OutT]:
|
116
116
|
call_id = self._generate_call_id(call_id)
|
117
|
+
ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
|
117
118
|
|
118
119
|
val_in_args = self._validate_inputs(
|
119
120
|
call_id=call_id,
|
@@ -143,7 +144,7 @@ class ParallelProcessor(
|
|
143
144
|
in_args: InT | None = None,
|
144
145
|
forgetful: bool = False,
|
145
146
|
call_id: str,
|
146
|
-
ctx: RunContext[CtxT]
|
147
|
+
ctx: RunContext[CtxT],
|
147
148
|
) -> AsyncIterator[Event[Any]]:
|
148
149
|
memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
149
150
|
|
@@ -178,7 +179,7 @@ class ParallelProcessor(
|
|
178
179
|
self,
|
179
180
|
in_args: list[InT],
|
180
181
|
call_id: str,
|
181
|
-
ctx: RunContext[CtxT]
|
182
|
+
ctx: RunContext[CtxT],
|
182
183
|
) -> AsyncIterator[Event[Any]]:
|
183
184
|
streams = [
|
184
185
|
self._run_single_stream(
|
@@ -222,6 +223,7 @@ class ParallelProcessor(
|
|
222
223
|
ctx: RunContext[CtxT] | None = None,
|
223
224
|
) -> AsyncIterator[Event[Any]]:
|
224
225
|
call_id = self._generate_call_id(call_id)
|
226
|
+
ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
|
225
227
|
|
226
228
|
val_in_args = self._validate_inputs(
|
227
229
|
call_id=call_id,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import logging
|
2
|
-
from collections.abc import AsyncIterator
|
2
|
+
from collections.abc import AsyncIterator, Sequence
|
3
3
|
from typing import Any, ClassVar, Generic, cast
|
4
4
|
|
5
5
|
from ..memory import MemT
|
@@ -25,7 +25,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
25
25
|
in_args: list[InT] | None = None,
|
26
26
|
memory: MemT,
|
27
27
|
call_id: str,
|
28
|
-
ctx: RunContext[CtxT]
|
28
|
+
ctx: RunContext[CtxT],
|
29
29
|
) -> list[OutT]:
|
30
30
|
return cast("list[OutT]", in_args)
|
31
31
|
|
@@ -36,7 +36,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
36
36
|
in_args: list[InT] | None = None,
|
37
37
|
memory: MemT,
|
38
38
|
call_id: str,
|
39
|
-
ctx: RunContext[CtxT]
|
39
|
+
ctx: RunContext[CtxT],
|
40
40
|
) -> AsyncIterator[Event[Any]]:
|
41
41
|
outputs = await self._process(
|
42
42
|
chat_inputs=chat_inputs,
|
@@ -58,7 +58,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
58
58
|
in_args: InT | list[InT] | None = None,
|
59
59
|
forgetful: bool = False,
|
60
60
|
call_id: str | None = None,
|
61
|
-
ctx: RunContext[CtxT]
|
61
|
+
ctx: RunContext[CtxT],
|
62
62
|
) -> tuple[list[InT] | None, MemT, str]:
|
63
63
|
call_id = self._generate_call_id(call_id)
|
64
64
|
|
@@ -74,10 +74,10 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
74
74
|
return val_in_args, memory, call_id
|
75
75
|
|
76
76
|
def _postprocess(
|
77
|
-
self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT]
|
77
|
+
self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT]
|
78
78
|
) -> Packet[OutT]:
|
79
79
|
payloads: list[OutT] = []
|
80
|
-
routing: dict[int,
|
80
|
+
routing: dict[int, Sequence[ProcName] | None] = {}
|
81
81
|
for idx, output in enumerate(outputs):
|
82
82
|
val_output = self._validate_output(output, call_id=call_id)
|
83
83
|
recipients = self._select_recipients(output=val_output, ctx=ctx)
|
@@ -105,6 +105,8 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
105
105
|
call_id: str | None = None,
|
106
106
|
ctx: RunContext[CtxT] | None = None,
|
107
107
|
) -> Packet[OutT]:
|
108
|
+
ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
|
109
|
+
|
108
110
|
val_in_args, memory, call_id = self._preprocess(
|
109
111
|
chat_inputs=chat_inputs,
|
110
112
|
in_packet=in_packet,
|
@@ -134,6 +136,8 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
134
136
|
call_id: str | None = None,
|
135
137
|
ctx: RunContext[CtxT] | None = None,
|
136
138
|
) -> AsyncIterator[Event[Any]]:
|
139
|
+
ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
|
140
|
+
|
137
141
|
val_in_args, memory, call_id = self._preprocess(
|
138
142
|
chat_inputs=chat_inputs,
|
139
143
|
in_packet=in_packet,
|
grasp_agents/prompt_builder.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
from collections.abc import Sequence
|
3
|
-
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, final
|
3
|
+
from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar, cast, final
|
4
4
|
|
5
5
|
from pydantic import BaseModel, TypeAdapter
|
6
6
|
|
@@ -15,13 +15,11 @@ _InT_contra = TypeVar("_InT_contra", contravariant=True)
|
|
15
15
|
|
16
16
|
|
17
17
|
class SystemPromptBuilder(Protocol[CtxT]):
|
18
|
-
def __call__(self, ctx: RunContext[CtxT]
|
18
|
+
def __call__(self, ctx: RunContext[CtxT]) -> str | None: ...
|
19
19
|
|
20
20
|
|
21
21
|
class InputContentBuilder(Protocol[_InT_contra, CtxT]):
|
22
|
-
def __call__(
|
23
|
-
self, in_args: _InT_contra | None, *, ctx: RunContext[CtxT] | None
|
24
|
-
) -> Content: ...
|
22
|
+
def __call__(self, in_args: _InT_contra, ctx: RunContext[CtxT]) -> Content: ...
|
25
23
|
|
26
24
|
|
27
25
|
PromptArgumentType: TypeAlias = str | bool | int | ImageData
|
@@ -45,7 +43,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
45
43
|
self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
46
44
|
|
47
45
|
@final
|
48
|
-
def build_system_prompt(self, ctx: RunContext[CtxT]
|
46
|
+
def build_system_prompt(self, ctx: RunContext[CtxT]) -> str | None:
|
49
47
|
if self.system_prompt_builder:
|
50
48
|
return self.system_prompt_builder(ctx=ctx)
|
51
49
|
|
@@ -73,23 +71,19 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
73
71
|
|
74
72
|
@final
|
75
73
|
def _build_input_content(
|
76
|
-
self,
|
77
|
-
in_args: InT | None = None,
|
78
|
-
ctx: RunContext[CtxT] | None = None,
|
74
|
+
self, in_args: InT | None, ctx: RunContext[CtxT]
|
79
75
|
) -> Content:
|
80
|
-
|
81
|
-
if in_args is not None:
|
82
|
-
val_in_args = self._validate_input_args(in_args=in_args)
|
83
|
-
|
84
|
-
if self.input_content_builder:
|
85
|
-
return self.input_content_builder(in_args=val_in_args, ctx=ctx)
|
86
|
-
|
87
|
-
if val_in_args is None:
|
76
|
+
if in_args is None and self._in_type is not type(None):
|
88
77
|
raise InputPromptBuilderError(
|
89
78
|
proc_name=self._agent_name,
|
90
|
-
message="
|
91
|
-
f"
|
79
|
+
message="Either chat inputs or input arguments must be provided "
|
80
|
+
f"when input type is not None [agent_name={self._agent_name}]",
|
92
81
|
)
|
82
|
+
in_args = cast("InT", in_args)
|
83
|
+
|
84
|
+
val_in_args = self._validate_input_args(in_args)
|
85
|
+
if self.input_content_builder:
|
86
|
+
return self.input_content_builder(in_args=val_in_args, ctx=ctx)
|
93
87
|
|
94
88
|
if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
|
95
89
|
val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
|
@@ -106,17 +100,17 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
106
100
|
def build_input_message(
|
107
101
|
self,
|
108
102
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
103
|
+
*,
|
109
104
|
in_args: InT | None = None,
|
110
|
-
ctx: RunContext[CtxT]
|
105
|
+
ctx: RunContext[CtxT],
|
111
106
|
) -> UserMessage | None:
|
112
|
-
if chat_inputs is not None
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
if chat_inputs:
|
107
|
+
if chat_inputs is not None:
|
108
|
+
if in_args is not None:
|
109
|
+
raise InputPromptBuilderError(
|
110
|
+
proc_name=self._agent_name,
|
111
|
+
message="Cannot use both chat inputs and input arguments "
|
112
|
+
f"at the same time [agent_name={self._agent_name}]",
|
113
|
+
)
|
120
114
|
if isinstance(chat_inputs, LLMPrompt):
|
121
115
|
return UserMessage.from_text(chat_inputs, name=self._agent_name)
|
122
116
|
return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
|
grasp_agents/run_context.py
CHANGED
grasp_agents/runner.py
CHANGED
@@ -69,7 +69,9 @@ class Converters(ABC):
|
|
69
69
|
|
70
70
|
@staticmethod
|
71
71
|
@abstractmethod
|
72
|
-
def to_tool(
|
72
|
+
def to_tool(
|
73
|
+
tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
|
74
|
+
) -> Any:
|
73
75
|
pass
|
74
76
|
|
75
77
|
@staticmethod
|
grasp_agents/typing/tool.py
CHANGED
@@ -48,8 +48,6 @@ class BaseTool(
|
|
48
48
|
name: str
|
49
49
|
description: str
|
50
50
|
|
51
|
-
strict: bool | None = None
|
52
|
-
|
53
51
|
_in_type: type[_InT] = PrivateAttr()
|
54
52
|
_out_type: type[_OutT_co] = PrivateAttr()
|
55
53
|
|
@@ -62,14 +60,24 @@ class BaseTool(
|
|
62
60
|
return self._out_type
|
63
61
|
|
64
62
|
@abstractmethod
|
65
|
-
async def run(
|
63
|
+
async def run(
|
64
|
+
self,
|
65
|
+
inp: _InT,
|
66
|
+
*,
|
67
|
+
call_id: str | None = None,
|
68
|
+
ctx: RunContext[CtxT] | None = None,
|
69
|
+
) -> _OutT_co:
|
66
70
|
pass
|
67
71
|
|
68
72
|
async def __call__(
|
69
|
-
self,
|
73
|
+
self,
|
74
|
+
*,
|
75
|
+
call_id: str | None = None,
|
76
|
+
ctx: RunContext[CtxT] | None = None,
|
77
|
+
**kwargs: Any,
|
70
78
|
) -> _OutT_co:
|
71
79
|
input_args = TypeAdapter(self._in_type).validate_python(kwargs)
|
72
|
-
output = await self.run(input_args, ctx=ctx)
|
80
|
+
output = await self.run(input_args, call_id=call_id, ctx=ctx)
|
73
81
|
|
74
82
|
return TypeAdapter(self._out_type).validate_python(output)
|
75
83
|
|
@@ -21,7 +21,7 @@ class WorkflowProcessor(
|
|
21
21
|
subprocs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
|
22
22
|
start_proc: BaseProcessor[InT, Any, Any, CtxT],
|
23
23
|
end_proc: BaseProcessor[Any, OutT, Any, CtxT],
|
24
|
-
recipients:
|
24
|
+
recipients: Sequence[ProcName] | None = None,
|
25
25
|
max_retries: int = 0,
|
26
26
|
) -> None:
|
27
27
|
super().__init__(name=name, recipients=recipients, max_retries=max_retries)
|
@@ -57,11 +57,11 @@ class WorkflowProcessor(
|
|
57
57
|
return func
|
58
58
|
|
59
59
|
@property
|
60
|
-
def recipients(self) ->
|
60
|
+
def recipients(self) -> Sequence[ProcName] | None:
|
61
61
|
return self._end_proc.recipients
|
62
62
|
|
63
63
|
@recipients.setter
|
64
|
-
def recipients(self, value:
|
64
|
+
def recipients(self, value: Sequence[ProcName] | None) -> None:
|
65
65
|
if hasattr(self, "_end_proc"):
|
66
66
|
self._end_proc.recipients = value
|
67
67
|
|
@@ -96,7 +96,7 @@ class WorkflowProcessor(
|
|
96
96
|
pass
|
97
97
|
|
98
98
|
@abstractmethod
|
99
|
-
async def run_stream(
|
99
|
+
async def run_stream(
|
100
100
|
self,
|
101
101
|
chat_inputs: Any | None = None,
|
102
102
|
*,
|