grasp_agents 0.3.10__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +70 -77
- grasp_agents/comm_processor.py +21 -11
- grasp_agents/errors.py +34 -0
- grasp_agents/http_client.py +7 -5
- grasp_agents/llm.py +3 -9
- grasp_agents/llm_agent.py +92 -103
- grasp_agents/llm_agent_memory.py +36 -27
- grasp_agents/llm_policy_executor.py +66 -63
- grasp_agents/memory.py +3 -1
- grasp_agents/openai/completion_chunk_converters.py +4 -3
- grasp_agents/openai/openai_llm.py +14 -20
- grasp_agents/openai/tool_converters.py +0 -1
- grasp_agents/packet_pool.py +1 -1
- grasp_agents/printer.py +6 -6
- grasp_agents/processor.py +182 -48
- grasp_agents/prompt_builder.py +41 -55
- grasp_agents/run_context.py +1 -5
- grasp_agents/typing/completion_chunk.py +10 -5
- grasp_agents/typing/content.py +2 -2
- grasp_agents/typing/io.py +4 -4
- grasp_agents/typing/message.py +3 -6
- grasp_agents/typing/tool.py +5 -23
- grasp_agents/usage_tracker.py +2 -4
- grasp_agents/utils.py +37 -15
- grasp_agents/workflow/looped_workflow.py +14 -9
- grasp_agents/workflow/sequential_workflow.py +11 -6
- grasp_agents/workflow/workflow_processor.py +30 -13
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/METADATA +2 -1
- grasp_agents-0.4.0.dist-info/RECORD +50 -0
- grasp_agents/message_history.py +0 -140
- grasp_agents/workflow/parallel_processor.py +0 -95
- grasp_agents-0.3.10.dist-info/RECORD +0 -51
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.3.10.dist-info → grasp_agents-0.4.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py
CHANGED

@@ -1,7 +1,7 @@
 import asyncio
 import json
 from collections.abc import AsyncIterator, Coroutine, Sequence
-from itertools import
+from itertools import starmap
 from logging import getLogger
 from typing import Any, ClassVar, Generic, Protocol, TypeVar

@@ -132,62 +132,58 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         if self.manage_memory_impl:
             self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)

-    async def
+    async def generate_messages(
         self,
         memory: LLMAgentMemory,
+        run_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[AssistantMessage]:
-
+        completion = await self.llm.generate_completion(
             memory.message_history, tool_choice=tool_choice
         )
-
-            chain.from_iterable([c.messages for c in completion_batch])
-        )
-        memory.update(message_batch=message_batch)
+        memory.update(completion.messages)

         if ctx is not None:
-            ctx.completions[self.agent_name].
-            self._track_usage(self.agent_name,
-            self.
+            ctx.completions[self.agent_name].append(completion)
+            self._track_usage(self.agent_name, completion, ctx=ctx)
+            self._print_completion(completion, run_id=run_id, ctx=ctx)

-        return
+        return completion.messages

-    async def
+    async def generate_messages_stream(
         self,
         memory: LLMAgentMemory,
+        run_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
         message_hist = memory.message_history
-        if memory.message_history.batch_size > 1:
-            raise ValueError("Batch size must be 1 when streaming completions.")
-        conversation = message_hist.conversations[0]

         completion: Completion | None = None
         async for event in await self.llm.generate_completion_stream(
-
+            message_hist, tool_choice=tool_choice
         ):
             yield event
             if isinstance(event, CompletionEvent):
                 completion = event.data
-
         if completion is None:
             raise RuntimeError("No completion generated during stream.")

-        memory.update(
+        memory.update(completion.messages)

         for message in completion.messages:
             yield GenMessageEvent(name=self.agent_name, data=message)

         if ctx is not None:
-            self._track_usage(self.agent_name,
+            self._track_usage(self.agent_name, completion, ctx=ctx)
             ctx.completions[self.agent_name].append(completion)

     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
         corouts: list[Coroutine[Any, Any, BaseModel]] = []

@@ -198,12 +194,14 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         outs = await asyncio.gather(*corouts)
         tool_messages = list(
-            starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=
+            starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=True))
         )
         memory.update(tool_messages)

         if ctx is not None:
-            ctx.printer.print_llm_messages(
+            ctx.printer.print_llm_messages(
+                tool_messages, agent_name=self.agent_name, run_id=run_id
+            )

         return tool_messages

@@ -211,10 +209,13 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
-        tool_messages = await self.call_tools(
-
+        tool_messages = await self.call_tools(
+            calls, memory=memory, run_id=run_id, ctx=ctx
+        )
+        for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(name=call.tool_name, data=tool_message)

     def _extract_final_answer_from_tool_calls(

@@ -233,7 +234,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         return final_answer_message

     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage:
         assert self._final_answer_tool_name is not None

@@ -242,11 +243,15 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         )
         memory.update([user_message])
         if ctx is not None:
-            ctx.printer.print_llm_messages(
+            ctx.printer.print_llm_messages(
+                [user_message], agent_name=self.agent_name, run_id=run_id
+            )

         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
         gen_message = (
-            await self.
+            await self.generate_messages(
+                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+            )
         )[0]

         final_answer_message = self._extract_final_answer_from_tool_calls(

@@ -260,7 +265,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         return final_answer_message

     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
         assert self._final_answer_tool_name is not None

@@ -272,8 +277,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
         event: Event[Any] | None = None
-        async for event in self.
-            memory, tool_choice=tool_choice, ctx=ctx
+        async for event in self.generate_messages_stream(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         ):
             yield event

@@ -289,7 +294,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             yield GenMessageEvent(name=self.agent_name, data=final_answer_message)

     async def execute(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls

@@ -297,26 +302,24 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         tool_choice: ToolChoice | None = None
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
-
-            memory, tool_choice=tool_choice, ctx=ctx
+        gen_messages = await self.generate_messages(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         )
         if not self.tools:
-            return
+            return gen_messages

-        if
-            raise ValueError("
-        gen_message =
+        if len(gen_messages) > 1:
+            raise ValueError("n_choices must be 1 when executing the tool call loop.")
+        gen_message = gen_messages[0]
         turns = 0

         while True:
-            conversation = memory.message_history.conversations[0]
-
             # 2. Check if we should exit the tool call loop

             # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
             # to determine whether to exit the loop.
             if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-
+                memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return gen_message

@@ -336,7 +339,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             # tool call.
             # Otherwise, we simply return the last generated message.
             if self._final_answer_tool_name is not None:
-                final_answer = await self._generate_final_answer(
+                final_answer = await self._generate_final_answer(
+                    memory, run_id=run_id, ctx=ctx
+                )
             else:
                 final_answer = gen_message
             logger.info(

@@ -347,7 +352,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             # 3. Call tools if there are any tool calls in the generated message.

             if gen_message.tool_calls:
-                await self.call_tools(
+                await self.call_tools(
+                    gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
+                )

             # Apply the memory management function if provided.
             self._manage_memory(memory, ctx=ctx, num_turns=turns)

@@ -363,23 +370,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                 "none" if (self._react_mode and gen_message.tool_calls) else "required"
             )
             gen_message = (
-                await self.
-                    memory, tool_choice=tool_choice, ctx=ctx
+                await self.generate_messages(
+                    memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
                 )
             )[0]

             turns += 1

     async def execute_stream(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
-        if memory.message_history.batch_size > 1:
-            raise ValueError("Batch size must be 1 when streaming.")
-
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
-        async for event in self.
-            memory, tool_choice=tool_choice, ctx=ctx
+        async for event in self.generate_messages_stream(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         ):
             yield event
             if isinstance(event, GenMessageEvent):

@@ -389,10 +393,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         turns = 0

         while True:
-            conversation = memory.message_history.conversations[0]
-
             if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-
+                memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return

@@ -409,7 +411,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         if turns >= self.max_turns:
             if self._final_answer_tool_name is not None:
                 async for event in self._generate_final_answer_stream(
-                    memory, ctx=ctx
+                    memory, run_id=run_id, ctx=ctx
                 ):
                     yield event
                 logger.info(

@@ -422,7 +424,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                 yield ToolCallEvent(name=self.agent_name, data=tool_call)

             async for tool_message_event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory, ctx=ctx
+                gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
             ):
                 yield tool_message_event

@@ -431,8 +433,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             tool_choice = (
                 "none" if (self._react_mode and gen_message.tool_calls) else "required"
             )
-            async for event in self.
-                memory, tool_choice=tool_choice, ctx=ctx
+            async for event in self.generate_messages_stream(
+                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):

@@ -443,12 +445,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
     def _track_usage(
         self,
         agent_name: str,
-
+        completion: Completion,
         ctx: RunContext[CtxT],
     ) -> None:
         ctx.usage_tracker.update(
             agent_name=agent_name,
-            completions=
+            completions=[completion],
             model_name=self.llm.model_name,
         )

@@ -473,11 +475,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         return FinalAnswerTool()

-    def
-        self,
+    def _print_completion(
+        self, completion: Completion, run_id: str, ctx: RunContext[CtxT]
     ) -> None:
-        messages = [c.messages[0] for c in completion_batch]
-        usages = [c.usage for c in completion_batch]
         ctx.printer.print_llm_messages(
-            messages,
+            completion.messages,
+            usages=[completion.usage],
+            agent_name=self.agent_name,
+            run_id=run_id,
         )
grasp_agents/memory.py
CHANGED
@@ -1,10 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, TypeVar

 from pydantic import BaseModel, ConfigDict

 from .run_context import RunContext

+MemT = TypeVar("MemT", bound="Memory")
+

 class Memory(BaseModel, ABC):
     @abstractmethod
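
The only addition here is a MemT type variable bound to Memory, which lets generic code preserve the concrete memory subclass in its signatures. A short sketch of the pattern it enables, using a hypothetical helper that is not part of the package:

from grasp_agents.memory import MemT


def passthrough(memory: MemT) -> MemT:
    # A caller passing an LLMAgentMemory gets an LLMAgentMemory back,
    # not the abstract Memory base type.
    return memory
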
grasp_agents/openai/completion_chunk_converters.py
CHANGED

@@ -1,3 +1,4 @@
+from ..errors import CompletionError
 from ..typing.completion_chunk import (
     CompletionChunk,
     CompletionChunkChoice,

@@ -12,7 +13,7 @@ def from_api_completion_chunk(
     api_completion_chunk: OpenAICompletionChunk, name: str | None = None
 ) -> CompletionChunk:
     if api_completion_chunk.choices is None:  # type: ignore
-        raise
+        raise CompletionError(
             f"Completion chunk API error: "
             f"{getattr(api_completion_chunk, 'error', None)}"
         )

@@ -24,12 +25,12 @@ def from_api_completion_chunk(
     finish_reason = api_choice.finish_reason

     if api_choice.delta is None:  # type: ignore
-        raise
+        raise CompletionError(
             "API returned None for delta content in completion chunk "
             f"with finish_reason: {finish_reason}."
         )
     # if api_choice.delta.content is None:
-    #     raise
+    #     raise CompletionError(
     #         "API returned None for delta content in completion chunk "
    #         f"with finish_reason: {finish_reason}."
    #     )
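
Malformed streaming chunks now surface as the typed CompletionError from the new grasp_agents/errors.py module instead of a bare raise. A small sketch of how a caller might handle it; the api_chunk argument stands in for an OpenAI chat completion chunk object and is not constructed here:

from grasp_agents.errors import CompletionError
from grasp_agents.openai.completion_chunk_converters import from_api_completion_chunk


def safe_convert(api_chunk):
    try:
        return from_api_completion_chunk(api_chunk)
    except CompletionError as exc:
        # Provider-side chunk problems (missing choices or delta) can now be
        # caught separately from other runtime errors.
        print(f"skipping malformed chunk: {exc}")
        return None
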
grasp_agents/openai/openai_llm.py
CHANGED

@@ -3,6 +3,7 @@ from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
 from typing import Any, Literal, NamedTuple

+import httpx
 from openai import AsyncOpenAI
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (

@@ -11,7 +12,7 @@ from openai.lib.streaming.chat import (
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import CloudLLM, CloudLLMSettings
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..http_client import AsyncHTTPClientParams
 from ..rate_limiting.rate_limiter_chunked import RateLimiterC
 from ..typing.message import AssistantMessage, Messages

@@ -57,8 +58,6 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     store: bool | None
     user: str

-    strict_tool_args: bool
-
     # response_format: (
     #     OpenAIResponseFormatText
     #     | OpenAIResponseFormatJSONSchema

@@ -71,16 +70,18 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         self,
         # Base LLM args
         model_name: str,
-        model_id: str | None = None,
         llm_settings: OpenAILLMSettings | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | Mapping[str, type] | None = None,
+        model_id: str | None = None,
+        # Custom LLM provider
+        api_provider: APIProvider | None = None,
         # Connection settings
+        async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
         async_openai_client_params: dict[str, Any] | None = None,
-        client: AsyncOpenAI | None = None,
         # Rate limiting
         rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
         rate_limiter_rpm: float | None = None,

@@ -88,9 +89,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         rate_limiter_max_concurrency: int = 300,
         # Retries
         num_generation_retries: int = 0,
-        # Disable tqdm for batch processing
-        no_tqdm: bool = True,
-        **kwargs: Any,
     ) -> None:
         super().__init__(
             model_name=model_name,

@@ -99,33 +97,29 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             converters=OpenAIConverters(),
             tools=tools,
             response_format=response_format,
+            api_provider=api_provider,
+            async_http_client=async_http_client,
             async_http_client_params=async_http_client_params,
             rate_limiter=rate_limiter,
             rate_limiter_rpm=rate_limiter_rpm,
             rate_limiter_chunk_size=rate_limiter_chunk_size,
             rate_limiter_max_concurrency=rate_limiter_max_concurrency,
             num_generation_retries=num_generation_retries,
-            no_tqdm=no_tqdm,
-            **kwargs,
         )

         self._tool_call_settings = {
-            "strict": self._llm_settings.
+            "strict": self._llm_settings.get("use_struct_outputs", False),
         }

         _async_openai_client_params = deepcopy(async_openai_client_params or {})
         if self._async_http_client is not None:
             _async_openai_client_params["http_client"] = self._async_http_client

-
-
-        self.
-
-
-            base_url=self._base_url,
-            api_key=self._api_key,
-            **_async_openai_client_params,
-        )
+        self._client: AsyncOpenAI = AsyncOpenAI(
+            base_url=self._base_url,
+            api_key=self._api_key,
+            **_async_openai_client_params,
+        )

     async def _get_completion(
         self,
grasp_agents/openai/tool_converters.py
CHANGED

@@ -17,7 +17,6 @@ def to_api_tool(
     tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
 ) -> OpenAIToolParam:
     if strict:
-        # Enforce strict mode for pydantic models
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
grasp_agents/packet_pool.py
CHANGED
@@ -80,7 +80,7 @@ class PacketPool(Generic[CtxT]):
             try:
                 await task
             except asyncio.CancelledError:
-                logger.
+                logger.info(f"{processor_name} exited")

         self._tasks.pop(processor_name, None)
         self._queues.pop(processor_name, None)
grasp_agents/printer.py
CHANGED
@@ -36,12 +36,10 @@ AVAILABLE_COLORS: list[Color] = [
 class Printer:
     def __init__(
         self,
-        source_id: str,
         color_by: ColoringMode = "role",
         msg_trunc_len: int = 20000,
         print_messages: bool = False,
     ) -> None:
-        self.source_id = source_id
         self.color_by = color_by
         self.msg_trunc_len = msg_trunc_len
         self.print_messages = print_messages

@@ -84,7 +82,7 @@ class Printer:
         return content_str

     def print_llm_message(
-        self, message: Message, agent_name: str, usage: Usage | None = None
+        self, message: Message, agent_name: str, run_id: str, usage: Usage | None = None
     ) -> None:
         if not self.print_messages:
             return

@@ -106,8 +104,7 @@ class Printer:

         # Print message title

-        out = f"\n
-        out += "[" + role.value.upper() + "]"
+        out = f"\n[agent: {agent_name} | role: {role.value} | run: {run_id}]"

         if isinstance(message, ToolMessage):
             out += f"\n{message.name} | {message.tool_call_id}"

@@ -159,6 +156,7 @@ class Printer:
         self,
         messages: Sequence[Message],
         agent_name: str,
+        run_id: str,
         usages: Sequence[Usage | None] | None = None,
     ) -> None:
         if not self.print_messages:

@@ -167,4 +165,6 @@ class Printer:
         _usages: Sequence[Usage | None] = usages or [None] * len(messages)

         for _message, _usage in zip(messages, _usages, strict=False):
-            self.print_llm_message(
+            self.print_llm_message(
+                _message, usage=_usage, agent_name=agent_name, run_id=run_id
+            )
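
Printer no longer carries a source_id; instead, each print call receives the run_id and renders it in the message title line. A short sketch of the new call shape; the messages argument is assumed to be a sequence of grasp_agents Message objects produced by an LLM call and is not constructed here:

from grasp_agents.printer import Printer

printer = Printer(color_by="role", print_messages=True)  # source_id argument is gone


def show(messages, agent_name: str, run_id: str) -> None:
    # run_id is now required and appears in the "[agent: ... | role: ... | run: ...]" title.
    printer.print_llm_messages(messages, agent_name=agent_name, run_id=run_id)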