grasp_agents 0.5.9__py3-none-any.whl → 0.5.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +87 -109
- grasp_agents/litellm/converters.py +4 -2
- grasp_agents/litellm/lite_llm.py +72 -83
- grasp_agents/llm.py +35 -68
- grasp_agents/llm_agent.py +76 -52
- grasp_agents/llm_agent_memory.py +4 -2
- grasp_agents/llm_policy_executor.py +91 -55
- grasp_agents/openai/converters.py +4 -2
- grasp_agents/openai/openai_llm.py +61 -88
- grasp_agents/openai/tool_converters.py +6 -4
- grasp_agents/processors/base_processor.py +18 -10
- grasp_agents/processors/parallel_processor.py +8 -6
- grasp_agents/processors/processor.py +10 -6
- grasp_agents/prompt_builder.py +38 -28
- grasp_agents/run_context.py +1 -1
- grasp_agents/runner.py +1 -1
- grasp_agents/typing/converters.py +3 -1
- grasp_agents/typing/tool.py +15 -5
- grasp_agents/workflow/workflow_processor.py +4 -4
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/METADATA +4 -5
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/RECORD +23 -23
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.9.dist-info → grasp_agents-0.5.11.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py
@@ -1,6 +1,6 @@
 import asyncio
 import json
-from collections.abc import AsyncIterator, Coroutine, Sequence
+from collections.abc import AsyncIterator, Coroutine, Mapping, Sequence
 from itertools import starmap
 from logging import getLogger
 from typing import Any, Generic, Protocol, final
@@ -36,7 +36,8 @@ class ToolCallLoopTerminator(Protocol[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool: ...
 
@@ -46,7 +47,8 @@ class MemoryManager(Protocol[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None: ...
 
@@ -54,9 +56,12 @@ class MemoryManager(Protocol[CtxT]):
 class LLMPolicyExecutor(Generic[CtxT]):
     def __init__(
         self,
+        *,
         agent_name: str,
         llm: LLM[LLMSettings, Converters],
         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+        response_schema: Any | None = None,
+        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
         max_turns: int,
         react_mode: bool = False,
         final_answer_type: type[BaseModel] = BaseModel,
@@ -70,12 +75,15 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self._final_answer_as_tool_call = final_answer_as_tool_call
         self._final_answer_tool = self.get_final_answer_tool()
 
-
+        tools_list: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
         if tools and final_answer_as_tool_call:
-
+            tools_list = tools + [self._final_answer_tool]
+        self._tools = {t.name: t for t in tools_list} if tools_list else None
+
+        self._response_schema = response_schema
+        self._response_schema_by_xml_tag = response_schema_by_xml_tag
 
         self._llm = llm
-        self._llm.tools = _tools
 
         self._max_turns = max_turns
         self._react_mode = react_mode
@@ -91,9 +99,21 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def llm(self) -> LLM[LLMSettings, Converters]:
         return self._llm
 
+    @property
+    def response_schema(self) -> Any | None:
+        return self._response_schema
+
+    @response_schema.setter
+    def response_schema(self, value: Any | None) -> None:
+        self._response_schema = value
+
+    @property
+    def response_schema_by_xml_tag(self) -> Mapping[str, Any] | None:
+        return self._response_schema_by_xml_tag
+
     @property
     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
-        return self.
+        return self._tools or {}
 
     @property
     def max_turns(self) -> int:
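Taken together, the hunks above make the LLMPolicyExecutor constructor keyword-only and move response-schema handling from the LLM onto the executor. A minimal construction sketch follows; `my_llm` and `my_tools` are placeholders for objects built elsewhere with grasp_agents' `LLM` and `BaseTool` types, not code taken from this diff:

```python
from pydantic import BaseModel


class Answer(BaseModel):
    text: str


# Hypothetical setup: my_llm and my_tools are assumed to exist.
executor = LLMPolicyExecutor(
    agent_name="assistant",
    llm=my_llm,
    tools=my_tools,
    response_schema=Answer,  # now stored on the executor and passed per completion call
    max_turns=5,
    react_mode=False,
)

# The new property/setter pair allows swapping the schema between runs.
executor.response_schema = Answer
```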
@@ -104,11 +124,14 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         conversation: Messages,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> bool:
         if self.tool_call_loop_terminator:
-            return self.tool_call_loop_terminator(
+            return self.tool_call_loop_terminator(
+                conversation, ctx=ctx, call_id=call_id, **kwargs
+            )
 
         return False
 
@@ -117,21 +140,26 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         memory: LLMAgentMemory,
         *,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
         **kwargs: Any,
     ) -> None:
         if self.memory_manager:
-            self.memory_manager(memory=memory, ctx=ctx, **kwargs)
+            self.memory_manager(memory=memory, ctx=ctx, call_id=call_id, **kwargs)
 
     async def generate_message(
         self,
         memory: LLMAgentMemory,
-
+        *,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AssistantMessage:
         completion = await self.llm.generate_completion(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -139,7 +167,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
         memory.update(completion.messages)
         self._process_completion(
-            completion,
+            completion, ctx=ctx, call_id=call_id, print_messages=True
         )
 
         return completion.messages[0]
@@ -147,9 +175,10 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def generate_message_stream(
         self,
         memory: LLMAgentMemory,
-
+        *,
         tool_choice: ToolChoice | None = None,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[
         CompletionChunkEvent[CompletionChunk]
         | CompletionEvent
@@ -160,6 +189,9 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         llm_event_stream = self.llm.generate_completion_stream(
             memory.message_history,
+            response_schema=self.response_schema,
+            response_schema_by_xml_tag=self.response_schema_by_xml_tag,
+            tools=self.tools,
             tool_choice=tool_choice,
             n_choices=1,
             proc_name=self.agent_name,
@@ -181,22 +213,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
         memory.update(completion.messages)
 
         self._process_completion(
-            completion,
+            completion, print_messages=True, ctx=ctx, call_id=call_id
         )
 
     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
         # TODO: Add image support
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
         for call in calls:
             tool = self.tools[call.tool_name]
             args = json.loads(call.tool_arguments)
-            corouts.append(tool(ctx=ctx, **args))
+            corouts.append(tool(ctx=ctx, call_id=call_id, **args))
 
         outs = await asyncio.gather(*corouts)
         tool_messages = list(
@@ -216,11 +248,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
         tool_messages = await self.call_tools(
-            calls, memory=memory,
+            calls, memory=memory, ctx=ctx, call_id=call_id
         )
         for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(
@@ -245,7 +277,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!"
@@ -258,7 +290,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         await self.generate_message(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         )
 
         final_answer_message = self._extract_final_answer_from_tool_calls(memory=memory)
@@ -268,7 +300,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         return final_answer_message
 
     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AsyncIterator[Event[Any]]:
         user_message = UserMessage.from_text(
             "Exceeded the maximum number of turns: provide a final answer now!",
@@ -284,7 +316,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         tool_choice = NamedToolChoice(name=self._final_answer_tool.name)
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             yield event
 
@@ -296,7 +328,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def execute(
-        self, memory: LLMAgentMemory,
+        self, memory: LLMAgentMemory, ctx: RunContext[CtxT], call_id: str
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -306,7 +338,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
            )
         if not self.tools:
             return gen_message
@@ -319,7 +351,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # If a final answer is not provided via a tool call, we use
             # _terminate_tool_call_loop to determine whether to exit the loop.
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return gen_message
 
@@ -338,7 +370,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             # Otherwise, we simply return the last generated message.
             if self._final_answer_as_tool_call:
                 final_answer = await self._generate_final_answer(
-                    memory,
+                    memory, ctx=ctx, call_id=call_id
                 )
             else:
                 final_answer = gen_message
@@ -351,11 +383,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
             if gen_message.tool_calls:
                 await self.call_tools(
-                    gen_message.tool_calls, memory=memory,
+                    gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
                 )
 
             # Apply memory management (e.g. compacting or pruning memory)
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             # 4. Generate the next message based on the updated memory.
             # In ReAct mode, we set tool_choice to "none" if we just called tools,
@@ -370,7 +402,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             gen_message = await self.generate_message(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             )
 
             turns += 1
@@ -378,13 +410,13 @@ class LLMPolicyExecutor(Generic[CtxT]):
     async def execute_stream(
         self,
         memory: LLMAgentMemory,
+        ctx: RunContext[CtxT],
         call_id: str,
-        ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[Event[Any]]:
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
         async for event in self.generate_message_stream(
-            memory, tool_choice=tool_choice,
+            memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
         ):
             if isinstance(event, GenMessageEvent):
                 gen_message = event.data
@@ -399,7 +431,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
 
         while True:
             if not self._final_answer_as_tool_call and self._terminate_tool_call_loop(
-                memory.message_history, ctx=ctx, num_turns=turns
+                memory.message_history, ctx=ctx, call_id=call_id, num_turns=turns
             ):
                 return
 
@@ -418,7 +450,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
             if turns >= self.max_turns:
                 if self._final_answer_as_tool_call:
                     async for event in self._generate_final_answer_stream(
-                        memory,
+                        memory, ctx=ctx, call_id=call_id
                     ):
                         yield event
                     logger.info(
@@ -433,11 +465,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
             )
 
             async for event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory,
+                gen_message.tool_calls, memory=memory, ctx=ctx, call_id=call_id
             ):
                 yield event
 
-            self._manage_memory(memory, ctx=ctx, num_turns=turns)
+            self._manage_memory(memory, ctx=ctx, call_id=call_id, num_turns=turns)
 
             if self._react_mode and gen_message.tool_calls:
                 tool_choice = "none"
@@ -447,7 +479,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
                 tool_choice = "required"
 
             async for event in self.generate_message_stream(
-                memory, tool_choice=tool_choice,
+                memory, tool_choice=tool_choice, ctx=ctx, call_id=call_id
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -464,7 +496,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
         )
 
     async def run(
-        self,
+        self,
+        inp: BaseModel,
+        *,
+        ctx: RunContext[Any] | None = None,
+        call_id: str | None = None,
     ) -> None:
         return None
 
@@ -473,22 +509,22 @@ class LLMPolicyExecutor(Generic[CtxT]):
     def _process_completion(
         self,
         completion: Completion,
-
+        *,
         print_messages: bool = False,
-        ctx: RunContext[CtxT]
+        ctx: RunContext[CtxT],
+        call_id: str,
     ) -> None:
-
-
-
+        ctx.completions[self.agent_name].append(completion)
+        ctx.usage_tracker.update(
+            agent_name=self.agent_name,
+            completions=[completion],
+            model_name=self.llm.model_name,
+        )
+        if ctx.printer and print_messages:
+            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
+            ctx.printer.print_messages(
+                completion.messages,
+                usages=usages,
                 agent_name=self.agent_name,
-
-                model_name=self.llm.model_name,
+                call_id=call_id,
             )
-        if ctx.printer and print_messages:
-            usages = [None] * (len(completion.messages) - 1) + [completion.usage]
-            ctx.printer.print_messages(
-                completion.messages,
-                usages=usages,
-                agent_name=self.agent_name,
-                call_id=call_id,
-            )
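The hunks above thread a required `ctx` and `call_id` through every executor entry point (`execute`, `execute_stream`, tool calls, memory management, and completion processing). A sketch of the updated call sites; the `RunContext` and `LLMAgentMemory` constructors shown here are assumptions, not taken from this diff:

```python
# Hypothetical driver code illustrating the new required arguments.
memory = LLMAgentMemory()  # assumed default constructor
ctx = RunContext()         # assumed default constructor

final_message = await executor.execute(memory, ctx=ctx, call_id="run-001")

async for event in executor.execute_stream(memory, ctx=ctx, call_id="run-002"):
    ...  # handle CompletionChunkEvent / GenMessageEvent / ToolMessageEvent
```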
grasp_agents/openai/converters.py
@@ -96,8 +96,10 @@ class OpenAIConverters(Converters):
         return from_api_tool_message(raw_message, name=name, **kwargs)
 
     @staticmethod
-    def to_tool(
-
+    def to_tool(
+        tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None, **kwargs: Any
+    ) -> OpenAIToolParam:
+        return to_api_tool(tool, strict=strict, **kwargs)
 
     @staticmethod
     def to_tool_choice(
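`OpenAIConverters.to_tool` now takes the tool plus an optional `strict` flag and forwards to `to_api_tool`. A usage sketch; `my_tool` is a placeholder for any grasp_agents `BaseTool` instance:

```python
# Hypothetical call: request strict JSON-schema mode for the converted tool definition.
api_tool_param = OpenAIConverters.to_tool(my_tool, strict=True)
```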
grasp_agents/openai/openai_llm.py
@@ -3,9 +3,9 @@ import logging
 import os
 from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
+from dataclasses import dataclass, field
 from typing import Any, Literal
 
-import httpx
 from openai import AsyncOpenAI, AsyncStream
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -15,8 +15,7 @@ from openai.lib.streaming.chat import ChatCompletionStreamState
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel
 
-from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
-from ..http_client import AsyncHTTPClientParams
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..typing.tool import BaseTool
 from . import (
     OpenAICompletion,
@@ -61,7 +60,7 @@ def get_openai_compatible_providers() -> list[APIProvider]:
 
 
 class OpenAILLMSettings(CloudLLMSettings, total=False):
-    reasoning_effort: Literal["low", "medium", "high"] | None
+    reasoning_effort: Literal["disable", "minimal", "low", "medium", "high"] | None
 
     parallel_tool_calls: bool
 
@@ -90,105 +89,75 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
 # TODO: support audio
 
 
+@dataclass(frozen=True)
 class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
-
-
-
-
-
-
-
-        response_schema_by_xml_tag: Mapping[str, Any] | None = None,
-        apply_response_schema_via_provider: bool = False,
-        model_id: str | None = None,
-        # Custom LLM provider
-        api_provider: APIProvider | None = None,
-        # Connection settings
-        max_client_retries: int = 2,
-        async_http_client: httpx.AsyncClient | None = None,
-        async_http_client_params: (
-            dict[str, Any] | AsyncHTTPClientParams | None
-        ) = None,
-        async_openai_client_params: dict[str, Any] | None = None,
-        # Rate limiting
-        rate_limiter: LLMRateLimiter | None = None,
-        # LLM response retries: try to regenerate to pass validation
-        max_response_retries: int = 1,
-    ) -> None:
+    converters: OpenAIConverters = field(default_factory=OpenAIConverters)
+    async_openai_client_params: dict[str, Any] | None = None
+    client: AsyncOpenAI = field(init=False)
+
+    def __post_init__(self):
+        super().__post_init__()
+
         openai_compatible_providers = get_openai_compatible_providers()
 
-
-
-
+        _api_provider = self.api_provider
+
+        model_name_parts = self.model_name.split("/", 1)
+        if _api_provider is not None:
+            _model_name = self.model_name
         elif len(model_name_parts) == 2:
             compat_providers_map = {
                 provider["name"]: provider for provider in openai_compatible_providers
             }
-            provider_name,
+            provider_name, _model_name = model_name_parts
             if provider_name not in compat_providers_map:
                 raise ValueError(
                     f"API provider '{provider_name}' is not a supported OpenAI "
                     f"compatible provider. Supported providers are: "
                     f"{', '.join(compat_providers_map.keys())}"
                 )
-
+            _api_provider = compat_providers_map[provider_name]
         else:
             raise ValueError(
                 "Model name must be in the format 'provider/model_name' or "
                 "you must provide an 'api_provider' argument."
            )
 
-        if llm_settings is not None:
-            stream_options = llm_settings.get("stream_options") or {}
+        if self.llm_settings is not None:
+            stream_options = self.llm_settings.get("stream_options") or {}
             stream_options["include_usage"] = True
-            _llm_settings = deepcopy(llm_settings)
+            _llm_settings = deepcopy(self.llm_settings)
             _llm_settings["stream_options"] = stream_options
         else:
             _llm_settings = OpenAILLMSettings(stream_options={"include_usage": True})
 
-        super().__init__(
-            model_name=provider_model_name,
-            model_id=model_id,
-            llm_settings=_llm_settings,
-            converters=OpenAIConverters(),
-            tools=tools,
-            response_schema=response_schema,
-            response_schema_by_xml_tag=response_schema_by_xml_tag,
-            apply_response_schema_via_provider=apply_response_schema_via_provider,
-            api_provider=api_provider,
-            async_http_client=async_http_client,
-            async_http_client_params=async_http_client_params,
-            rate_limiter=rate_limiter,
-            max_client_retries=max_client_retries,
-            max_response_retries=max_response_retries,
-        )
-
         response_schema_support: bool = any(
-            fnmatch.fnmatch(
-            for pat in
+            fnmatch.fnmatch(_model_name, pat)
+            for pat in _api_provider.get("response_schema_support") or []
         )
-        if apply_response_schema_via_provider:
-
-            for
-
-
-
-                "Native response schema validation is not supported for model "
-                f"'{self._model_name}' by the API provider. Please set "
-                "apply_response_schema_via_provider=False."
-            )
+        if self.apply_response_schema_via_provider and not response_schema_support:
+            raise ValueError(
+                "Native response schema validation is not supported for model "
+                f"'{_model_name}' by the API provider. Please set "
+                "apply_response_schema_via_provider=False."
+            )
 
-        _async_openai_client_params = deepcopy(async_openai_client_params or {})
-        if self.
-            _async_openai_client_params["http_client"] = self.
+        _async_openai_client_params = deepcopy(self.async_openai_client_params or {})
+        if self.async_http_client is not None:
+            _async_openai_client_params["http_client"] = self.async_http_client
 
-
-            base_url=
-            api_key=
-            max_retries=max_client_retries,
+        _client = AsyncOpenAI(
+            base_url=_api_provider.get("base_url"),
+            api_key=_api_provider.get("api_key"),
+            max_retries=self.max_client_retries,
             **_async_openai_client_params,
         )
 
+        object.__setattr__(self, "model_name", _model_name)
+        object.__setattr__(self, "api_provider", _api_provider)
+        object.__setattr__(self, "llm_settings", _llm_settings)
+        object.__setattr__(self, "client", _client)
+
     async def _get_completion(
         self,
         api_messages: Iterable[OpenAIMessageParam],
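The hunk above turns `OpenAILLM` into a frozen dataclass: the explicit `__init__` is gone, the `AsyncOpenAI` client is built in `__post_init__`, and the resolved model name, provider, and settings are written back with `object.__setattr__`. A construction sketch based on the fields referenced in `__post_init__`; the exact set of inherited `CloudLLM` fields is an assumption:

```python
# Hypothetical instantiation; "openai/gpt-4o-mini" follows the "provider/model_name"
# convention that __post_init__ uses to resolve the API provider.
llm = OpenAILLM(
    model_name="openai/gpt-4o-mini",
    llm_settings=OpenAILLMSettings(reasoning_effort="low"),
    apply_response_schema_via_provider=False,
)

# After __post_init__, the provider prefix is stripped and the client is ready.
print(llm.model_name)    # "gpt-4o-mini"
print(type(llm.client))  # openai.AsyncOpenAI
```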
@@ -203,9 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self.
-            return await self.
-                model=self.
+        if self.apply_response_schema_via_provider:
+            return await self.client.beta.chat.completions.parse(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -214,8 +183,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
                 **api_llm_settings,
             )
 
-        return await self.
-            model=self.
+        return await self.client.chat.completions.create(
+            model=self.model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -238,10 +207,10 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_schema or NOT_GIVEN
         n = n_choices or NOT_GIVEN
 
-        if self.
+        if self.apply_response_schema_via_provider:
             stream_manager: OpenAIAsyncChatCompletionStreamManager[Any] = (
-                self.
-                    model=self.
+                self.client.beta.chat.completions.stream(
+                    model=self.model_name,
                     messages=api_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -257,8 +226,8 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         else:
             stream_generator: AsyncStream[
                 OpenAICompletionChunk
-            ] = await self.
-                model=self.
+            ] = await self.client.chat.completions.create(
+                model=self.model_name,
                 messages=api_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -271,16 +240,20 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             yield completion_chunk
 
     def combine_completion_chunks(
-        self,
+        self,
+        completion_chunks: list[OpenAICompletionChunk],
+        response_schema: Any | None = None,
+        tools: Mapping[str, BaseTool[BaseModel, Any, Any]] | None = None,
     ) -> OpenAICompletion:
         response_format = NOT_GIVEN
         input_tools = NOT_GIVEN
-        if self.
-            if
-                response_format =
-            if
+        if self.apply_response_schema_via_provider:
+            if response_schema:
+                response_format = response_schema
+            if tools:
                 input_tools = [
-                    self.
+                    self.converters.to_tool(tool, strict=True)
+                    for tool in tools.values()
                 ]
         state = ChatCompletionStreamState[Any](
             input_tools=input_tools, response_format=response_format