grasp_agents 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +88 -109
- grasp_agents/litellm/converters.py +4 -2
- grasp_agents/litellm/lite_llm.py +90 -80
- grasp_agents/llm.py +52 -97
- grasp_agents/llm_agent.py +32 -36
- grasp_agents/llm_agent_memory.py +3 -2
- grasp_agents/llm_policy_executor.py +63 -33
- grasp_agents/openai/converters.py +4 -2
- grasp_agents/openai/openai_llm.py +66 -85
- grasp_agents/openai/tool_converters.py +6 -4
- grasp_agents/processors/base_processor.py +18 -10
- grasp_agents/processors/parallel_processor.py +8 -6
- grasp_agents/processors/processor.py +10 -6
- grasp_agents/prompt_builder.py +22 -28
- grasp_agents/run_context.py +1 -1
- grasp_agents/runner.py +1 -1
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/tool.py +13 -5
- grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.8.dist-info → grasp_agents-0.5.10.dist-info}/METADATA +12 -13
- {grasp_agents-0.5.8.dist-info → grasp_agents-0.5.10.dist-info}/RECORD +23 -23
- {grasp_agents-0.5.8.dist-info → grasp_agents-0.5.10.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.8.dist-info → grasp_agents-0.5.10.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py

@@ -1,6 +1,6 @@
 import asyncio
 import json
-from collections.abc import AsyncIterator, Coroutine, Sequence
+from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
 from itertools import starmap
 from logging import getLogger
 from typing import Any, Generic, Protocol, final
@@ -36,7 +36,7 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool: ...

@@ -46,7 +46,7 @@ class MemoryManager(Protocol[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None: ...

@@ -54,9 +54,12 @@ class MemoryManager(Protocol[CtxT]):
 class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
+        *,
         agent_name: str,
         llm: LLM[LLMSettings, Converters],
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         max_turns: int,
         react_mode: bool = False,
         final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +73,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self._final_answer_as_tool_call = final_answer_as_tool_call
         self._final_answer_tool = self.get_final_answer_tool()

-
+        tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
         if tools and final_answer_as_tool_call:
-
+            tools_list = tools + [self._final_answer_tool]
+        self._tools = {t.name: t for t in tools_list} if tools_list else None
+
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag

         self._llm = llm
-        self._llm.tools = _tools

         self._max_turns = max_turns
         self._react_mode = react_mode
@@ -91,9 +97,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._llm

+    @property
+    def response_schema(self) -> Any | None:
+        return self._response_schema
+
+    @response_schema.setter
+    def response_schema(self, value: Any | None) -> None:
+        self._response_schema = value
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
+
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self.
+        return self._tools or {}

     @property
     def max_turns(self) -> int:
@@ -104,7 +122,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
@@ -117,7 +135,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
@@ -126,12 +144,16 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -147,9 +169,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
+        *,
         call_id: str,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -160,6 +183,9 @@ class LLMPolicyExecutor(Generic[CtxT]):

         llm_event_stream = self.llm.generate_completion_stream(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -189,14 +215,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(ctx=ctx, **args))
+            corouts.append(tool(call_id=call_id, ctx=ctx, **args))

         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -217,7 +243,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
             calls, memory=memory, call_id=call_id, ctx=ctx
@@ -245,7 +271,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message

     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -268,7 +294,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message

     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -296,7 +322,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )

     async def execute(
-        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
+        self, memory: LLMAgentMemory, call_id: str, ctx: RunContext[CtxT]
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -379,7 +405,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
@@ -464,7 +490,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )

     async def run(
-        self,
+        self,
+        inp: BaseModel,
+        *,
+        call_id: str | None = None,
+        ctx: RunContext[Any] | None = None,
     ) -> None:
         return None

@@ -473,22 +503,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def _process_completion(
         self,
         completion: Completion,
+        *,
         call_id: str,
         print_messages: bool = False,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> None:
-
-
-
+        ctx.completions[self.agent_name].append(completion)
+        ctx.usage_tracker.update(
+            agent_name=self.agent_name,
+            completions=[completion],
+            model_name=self.llm.model_name,
+        )
+        if ctx.printer and print_messages:
+            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+            ctx.printer.print_messages(
+                completion.messages,
+                usages=usages,
                 agent_name=self.agent_name,
-
-                model_name=self.llm.model_name,
+                call_id=call_id,
             )
-        if ctx.printer and print_messages:
-            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
-            ctx.printer.print_messages(
-                completion.messages,
-                usages=usages,
-                agent_name=self.agent_name,
-                call_id=call_id,
-            )
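Taken together, these hunks make `LLMPolicyExecutor.__init__` keyword-only and move the structured-output configuration (`response_schema`, `response_schema_by_xml_tag`) and the tool map onto the executor, which now forwards them on every completion call instead of assigning them to the LLM (`self._llm.tools = ...` in 0.5.8). A minimal usage sketch based only on the signature visible above; `my_llm`, `search_tool`, and `FinalAnswer` are placeholders, not names from the package:

```python
from pydantic import BaseModel


class FinalAnswer(BaseModel):
    answer: str


# Hypothetical wiring: `my_llm` and `search_tool` stand in for real
# LLM / BaseTool instances, which are not defined in this diff.
executor = LLMPolicyExecutor(
    agent_name="researcher",
    llm=my_llm,
    tools=[search_tool],
    response_schema=FinalAnswer,      # held by the executor in 0.5.10
    response_schema_by_xml_tag=None,
    max_turns=8,
    react_mode=False,
    final_answer_type=FinalAnswer,
)
# The executor passes response_schema and tools through on each
# generate_message / generate_message_stream call.
```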
grasp_agents/openai/converters.py

@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)

     @staticmethod
-    def to_tool(
-
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)

     @staticmethod
     def to_tool_choice(
grasp_agents/openai/openai_llm.py

@@ -3,9 +3,9 @@ import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, Literal

-import httpx
 from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
-from ..http_client import AsyncHTTPClientParams
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -90,97 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     # TODO: support audio


+@dataclass(frozen=True)
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
-
-
-
-
-
-
-
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
-        # Connection settings
-        max_client_retries: int = 2,
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        async_openai_client_params: dict[str, Any] | None = None,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
+    converters: OpenAIConverters = field(default_factory=OpenAIConverters)
+    async_openai_client_params: dict[str, Any] | None = None
+    client: AsyncOpenAI = field(init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
         openai_compatible_providers = get_openai_compatible_providers()

-
-
-
+        _api_provider = self.api_provider
+
+        model_name_parts = self.model_name.split("/", 1)
+        if _api_provider is not None:
+            _model_name = self.model_name
         elif len(model_name_parts) == 2:
             compat_providers_map = {
                 provider["name"]: provider for provider in openai_compatible_providers
             }
-            provider_name,
+            provider_name, _model_name = model_name_parts
             if provider_name not in compat_providers_map:
                 raise ValueError(
-                    f"
-                    "
+                    f"API provider '{provider_name}' is not a supported OpenAI "
+                    f"compatible provider. Supported providers are: "
                     f"{', '.join(compat_providers_map.keys())}"
                 )
-
+            _api_provider = compat_providers_map[provider_name]
         else:
             raise ValueError(
                 "Model name must be in the format 'provider/model_name' or "
                 "you must provide an 'api_provider' argument."
             )

-
-
-
-
-
-
-
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            async_http_client=async_http_client,
-            async_http_client_params=async_http_client_params,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
+            stream_options["include_usage"] = True
+            _llm_settings = deepcopy(self.llm_settings)
+            _llm_settings["stream_options"] = stream_options
+        else:
+            _llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})

         response_schema_support: bool = any(
-            fnmatch.fnmatch(
-            for pat in
+            fnmatch.fnmatch(_model_name, pat)
+            for pat in _api_provider.get("response_schema_support") or []
         )
-        if apply_response_schema_via_provider:
-
-            for
-
-
-
-                    "Native response schema validation is not supported for model "
-                    f"'{self._model_name}' by the API provider. Please set "
-                    "apply_response_schema_via_provider=False."
-                )
+        if self.apply_response_schema_via_provider and not response_schema_support:
+            raise ValueError(
+                "Native response schema validation is not supported for model "
+                f"'{_model_name}' by the API provider. Please set "
+                "apply_response_schema_via_provider=False."
+            )

-        _async_openai_client_params = deepcopy(async_openai_client_params or {})
-        if self.
-            _async_openai_client_params["http_client"] = self.
+        _async_openai_client_params = deepcopy(self.async_openai_client_params or {})
+        if self.async_http_client is not None:
+            _async_openai_client_params["http_client"] = self.async_http_client

-
-            base_url=
-            api_key=
-            max_retries=max_client_retries,
+        _client = AsyncOpenAI(
+            base_url=_api_provider.get("base_url"),
+            api_key=_api_provider.get("api_key"),
+            max_retries=self.max_client_retries,
             **_async_openai_client_params,
         )

+        object.__setattr__(self, "model_name", _model_name)
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
+        object.__setattr__(self, "client", _client)
+
     async def _get_completion(
         self,
         api_messages: Iterable[OpenAIMessageParam],
@@ -195,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN

-        if self.
-            return await self.
-                model=self.
+        if self.apply_response_schema_via_provider:
+            return await self.client.beta.chat.completions.parse(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -206,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
                 **api_llm_settings,
             )

-        return await self.
-            model=self.
+        return await self.client.chat.completions.create(
+            model=self.model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -230,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN

-        if self.
+        if self.apply_response_schema_via_provider:
             stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
-                self.
-                    model=self.
+                self.client.beta.chat.completions.stream(
+                    model=self.model_name,
                     messages=api_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -249,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         else:
             stream_generator: AsyncStream[
                 OpenAICompletionChunk
-            ] = await self.
-                model=self.
+            ] = await self.client.chat.completions.create(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -263,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
                 yield completion_chunk

     def combine_completion_chunks(
-        self,
+        self,
+        completion_chunks: list[OpenAICompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> OpenAICompletion:
         response_format = NOT_GIVEN
         input_tools = NOT_GIVEN
-        if self.
-            if
-                response_format =
-            if
+        if self.apply_response_schema_via_provider:
+            if response_schema:
+                response_format = response_schema
+            if tools:
                 input_tools = [
-                    self.
+                    self.converters.to_tool(tool, strict=True)
+                    for tool in tools.values()
                 ]
         state = ChatCompletionStreamState[Any](
             input_tools=input_tools, response_format=response_format
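`OpenAILLM` is now declared as a frozen dataclass whose derived fields (the normalized `model_name`, the resolved `api_provider`, the patched `llm_settings`, and the `AsyncOpenAI` `client`) are computed in `__post_init__` and written with `object.__setattr__`, since frozen instances reject ordinary attribute assignment. A small self-contained sketch of that pattern, with toy field names rather than the real class:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ProviderModel:
    model_name: str                    # e.g. "openai/gpt-4o-mini"
    provider: str = field(init=False)  # derived, so excluded from __init__

    def __post_init__(self) -> None:
        # A frozen dataclass raises FrozenInstanceError on normal assignment,
        # so derived fields are written through object.__setattr__, much as
        # OpenAILLM.__post_init__ does in the hunk above.
        provider, _, model = self.model_name.partition("/")
        object.__setattr__(self, "provider", provider)
        object.__setattr__(self, "model_name", model)


pm = ProviderModel("openai/gpt-4o-mini")
assert (pm.provider, pm.model_name) == ("openai", "gpt-4o-mini")
```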
grasp_agents/openai/tool_converters.py

@@ -13,8 +13,10 @@ from . import (
 )


-def to_api_tool(
-
+def to_api_tool(
+    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
+) -> OpenAIToolParam:
+    if strict:
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
@@ -23,9 +25,9 @@ def to_api_tool(tool: BaseTool[BaseModel, Any, Any]) -> OpenAIToolParam:
         name=tool.name,
         description=tool.description,
         parameters=tool.in_type.model_json_schema(),
-        strict=
+        strict=strict,
     )
-    if
+    if strict is None:
         function.pop("strict")

     return OpenAIToolParam(type="function", function=function)
grasp_agents/processors/base_processor.py

@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Callable, Coroutine
+from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
 from functools import wraps
 from typing import (
     Any,
@@ -37,7 +37,6 @@ from ..typing.tool import BaseTool

 logger = logging.getLogger(__name__)

-_OutT_contra = TypeVar("_OutT_contra", contravariant=True)

 F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Packet[Any]]])
 F_stream = TypeVar("F_stream", bound=Callable[..., AsyncIterator[Event[Any]]])
@@ -102,10 +101,13 @@ def with_retry_stream(func: F_stream) -> F_stream:
     return cast("F_stream", wrapper)


+_OutT_contra = TypeVar("_OutT_contra", contravariant=True)
+
+
 class RecipientSelector(Protocol[_OutT_contra, CtxT]):
     def __call__(
-        self, output: _OutT_contra, ctx: RunContext[CtxT]
-    ) ->
+        self, output: _OutT_contra, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None: ...


 class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
@@ -118,7 +120,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
         self,
         name: ProcName,
         max_retries: int = 0,
-        recipients:
+        recipients: Sequence[ProcName] | None = None,
         **kwargs: Any,
     ) -> None:
         self._in_type: type[InT]
@@ -239,7 +241,7 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
             ) from err

     def _validate_recipients(
-        self, recipients:
+        self, recipients: Sequence[ProcName] | None, call_id: str
     ) -> None:
         for r in recipients or []:
             if r not in (self.recipients or []):
@@ -252,8 +254,8 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):

     @final
     def _select_recipients(
-        self, output: OutT, ctx: RunContext[CtxT]
-    ) ->
+        self, output: OutT, ctx: RunContext[CtxT]
+    ) -> Sequence[ProcName] | None:
         if self.recipient_selector:
             return self.recipient_selector(output=output, ctx=ctx)

@@ -310,9 +312,15 @@ class BaseProcessor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT, MemT, CtxT]):
             name: str = tool_name
             description: str = tool_description

-            async def run(
+            async def run(
+                self,
+                inp: InT,
+                *,
+                call_id: str | None = None,
+                ctx: RunContext[CtxT] | None = None,
+            ) -> OutT:
                 result = await processor_instance.run(
-                    in_args=inp, forgetful=True, ctx=ctx
+                    in_args=inp, forgetful=True, call_id=call_id, ctx=ctx
                 )

                 return result.payloads[0]
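Here the `_OutT_contra` TypeVar moves next to the `RecipientSelector` protocol it parametrizes, and the recipient-related signatures are now annotated as `Sequence[ProcName] | None`. A hedged sketch of a selector matching the updated `__call__` shape; `ScoredOutput`, the routing rule, and the `ProcName = str` alias are assumptions for illustration only:

```python
from collections.abc import Sequence
from typing import Any

from pydantic import BaseModel

ProcName = str  # assumption: grasp_agents' ProcName behaves like a plain string alias


class ScoredOutput(BaseModel):
    text: str
    confidence: float


def route_by_confidence(output: ScoredOutput, ctx: Any) -> Sequence[ProcName] | None:
    # Shape matches RecipientSelector.__call__: (output, ctx) -> Sequence[ProcName] | None.
    return ["reviewer"] if output.confidence < 0.5 else None
```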
grasp_agents/processors/parallel_processor.py

@@ -30,7 +30,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> OutT:
         return cast("OutT", in_args)

@@ -41,7 +41,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         memory: MemT,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         output = cast("OutT", in_args)
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
@@ -67,7 +67,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> Packet[OutT]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory

@@ -86,7 +86,7 @@ class ParallelProcessor(
         return Packet(payloads=[val_output], sender=self.name, recipients=recipients)

     async def _run_parallel(
-        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
+        self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
     ) -> Packet[OutT]:
         tasks = [
             self._run_single(
@@ -114,6 +114,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> Packet[OutT]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore

         val_in_args = self._validate_inputs(
             call_id=call_id,
@@ -143,7 +144,7 @@ class ParallelProcessor(
         in_args: InT | None = None,
         forgetful: bool = False,
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         memory = self.memory.model_copy(deep=True) if forgetful else self.memory

@@ -178,7 +179,7 @@ class ParallelProcessor(
         self,
         in_args: list[InT],
         call_id: str,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
     ) -> AsyncIterator[Event[Any]]:
         streams = [
             self._run_single_stream(
@@ -222,6 +223,7 @@ class ParallelProcessor(
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         call_id = self._generate_call_id(call_id)
+        ctx = RunContext[CtxT](state=None) if ctx is None else ctx  # type: ignore

         val_in_args = self._validate_inputs(
             call_id=call_id,
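Both the batch run path and its streaming variant in `ParallelProcessor` now build a default `RunContext(state=None)` when no context is passed, so `ctx` stays optional for callers. A hedged one-line sketch, with `processor` and `my_inputs` as placeholders:

```python
# `processor` and `my_inputs` are placeholders; with 0.5.10 the ctx argument
# can simply be omitted and a default RunContext(state=None) is built internally.
packet = await processor.run(in_args=my_inputs, call_id="demo-call")
```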