grasp_agents 0.3.11__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- grasp_agents/cloud_llm.py +57 -74
- grasp_agents/comm_processor.py +21 -11
- grasp_agents/errors.py +34 -0
- grasp_agents/http_client.py +7 -5
- grasp_agents/llm.py +3 -9
- grasp_agents/llm_agent.py +92 -103
- grasp_agents/llm_agent_memory.py +36 -27
- grasp_agents/llm_policy_executor.py +73 -66
- grasp_agents/memory.py +3 -1
- grasp_agents/openai/completion_chunk_converters.py +4 -3
- grasp_agents/openai/openai_llm.py +14 -20
- grasp_agents/openai/tool_converters.py +0 -1
- grasp_agents/packet_pool.py +1 -1
- grasp_agents/printer.py +6 -6
- grasp_agents/processor.py +182 -48
- grasp_agents/prompt_builder.py +41 -55
- grasp_agents/run_context.py +1 -5
- grasp_agents/typing/completion_chunk.py +10 -5
- grasp_agents/typing/content.py +3 -2
- grasp_agents/typing/io.py +4 -4
- grasp_agents/typing/message.py +3 -8
- grasp_agents/typing/tool.py +5 -23
- grasp_agents/usage_tracker.py +2 -4
- grasp_agents/utils.py +37 -15
- grasp_agents/workflow/looped_workflow.py +14 -9
- grasp_agents/workflow/sequential_workflow.py +11 -6
- grasp_agents/workflow/workflow_processor.py +30 -13
- {grasp_agents-0.3.11.dist-info → grasp_agents-0.4.2.dist-info}/METADATA +3 -2
- grasp_agents-0.4.2.dist-info/RECORD +50 -0
- grasp_agents/message_history.py +0 -140
- grasp_agents/workflow/parallel_processor.py +0 -95
- grasp_agents-0.3.11.dist-info/RECORD +0 -51
- {grasp_agents-0.3.11.dist-info → grasp_agents-0.4.2.dist-info}/WHEEL +0 -0
- {grasp_agents-0.3.11.dist-info → grasp_agents-0.4.2.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_policy_executor.py
CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import json
 from collections.abc import AsyncIterator, Coroutine, Sequence
-from itertools import
+from itertools import starmap
 from logging import getLogger
 from typing import Any, ClassVar, Generic, Protocol, TypeVar

@@ -132,62 +132,58 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         if self.manage_memory_impl:
             self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)

-    async def
+    async def generate_messages(
         self,
         memory: LLMAgentMemory,
+        run_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[AssistantMessage]:
-
+        completion = await self.llm.generate_completion(
             memory.message_history, tool_choice=tool_choice
         )
-
-            chain.from_iterable([c.messages for c in completion_batch])
-        )
-        memory.update(message_batch=message_batch)
+        memory.update(completion.messages)

         if ctx is not None:
-            ctx.completions[self.agent_name].
-            self._track_usage(self.agent_name,
-            self.
+            ctx.completions[self.agent_name].append(completion)
+            self._track_usage(self.agent_name, completion, ctx=ctx)
+            self._print_completion(completion, run_id=run_id, ctx=ctx)

-        return
+        return completion.messages

-    async def
+    async def generate_messages_stream(
         self,
         memory: LLMAgentMemory,
+        run_id: str,
         tool_choice: ToolChoice | None = None,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
         message_hist = memory.message_history
-        if memory.message_history.batch_size > 1:
-            raise ValueError("Batch size must be 1 when streaming completions.")
-        conversation = message_hist.conversations[0]

         completion: Completion | None = None
         async for event in await self.llm.generate_completion_stream(
-
+            message_hist, tool_choice=tool_choice
         ):
             yield event
             if isinstance(event, CompletionEvent):
                 completion = event.data
-
         if completion is None:
             raise RuntimeError("No completion generated during stream.")

-        memory.update(
+        memory.update(completion.messages)

         for message in completion.messages:
             yield GenMessageEvent(name=self.agent_name, data=message)

         if ctx is not None:
-            self._track_usage(self.agent_name,
+            self._track_usage(self.agent_name, completion, ctx=ctx)
             ctx.completions[self.agent_name].append(completion)

     async def call_tools(
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> Sequence[ToolMessage]:
         corouts: list[Coroutine[Any, Any, BaseModel]] = []
@@ -198,12 +194,14 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         outs = await asyncio.gather(*corouts)
         tool_messages = list(
-            starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=
+            starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=True))
         )
         memory.update(tool_messages)

         if ctx is not None:
-            ctx.printer.print_llm_messages(
+            ctx.printer.print_llm_messages(
+                tool_messages, agent_name=self.agent_name, run_id=run_id
+            )

         return tool_messages

@@ -211,10 +209,13 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         self,
         calls: Sequence[ToolCall],
         memory: LLMAgentMemory,
+        run_id: str,
         ctx: RunContext[CtxT] | None = None,
     ) -> AsyncIterator[ToolMessageEvent]:
-        tool_messages = await self.call_tools(
-
+        tool_messages = await self.call_tools(
+            calls, memory=memory, run_id=run_id, ctx=ctx
+        )
+        for tool_message, call in zip(tool_messages, calls, strict=True):
             yield ToolMessageEvent(name=call.tool_name, data=tool_message)

     def _extract_final_answer_from_tool_calls(
@@ -233,7 +234,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         return final_answer_message

     async def _generate_final_answer(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage:
         assert self._final_answer_tool_name is not None

@@ -242,11 +243,15 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         )
         memory.update([user_message])
         if ctx is not None:
-            ctx.printer.print_llm_messages(
+            ctx.printer.print_llm_messages(
+                [user_message], agent_name=self.agent_name, run_id=run_id
+            )

         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
         gen_message = (
-            await self.
+            await self.generate_messages(
+                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+            )
         )[0]

         final_answer_message = self._extract_final_answer_from_tool_calls(
@@ -260,7 +265,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         return final_answer_message

     async def _generate_final_answer_stream(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
         assert self._final_answer_tool_name is not None

@@ -272,8 +277,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
         event: Event[Any] | None = None
-        async for event in self.
-            memory, tool_choice=tool_choice, ctx=ctx
+        async for event in self.generate_messages_stream(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         ):
             yield event

@@ -289,7 +294,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         yield GenMessageEvent(name=self.agent_name, data=final_answer_message)

     async def execute(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AssistantMessage | Sequence[AssistantMessage]:
         # 1. Generate the first message:
         # In ReAct mode, we generate the first message without tool calls
@@ -297,26 +302,24 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         tool_choice: ToolChoice | None = None
         if self.tools:
             tool_choice = "none" if self._react_mode else "auto"
-
-            memory, tool_choice=tool_choice, ctx=ctx
+        gen_messages = await self.generate_messages(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         )
         if not self.tools:
-            return
+            return gen_messages

-        if
-            raise ValueError("
-        gen_message =
+        if len(gen_messages) > 1:
+            raise ValueError("n_choices must be 1 when executing the tool call loop.")
+        gen_message = gen_messages[0]
         turns = 0

         while True:
-            conversation = memory.message_history.conversations[0]
-
             # 2. Check if we should exit the tool call loop

             # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
             # to determine whether to exit the loop.
             if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-
+                memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return gen_message

@@ -336,7 +339,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             # tool call.
             # Otherwise, we simply return the last generated message.
             if self._final_answer_tool_name is not None:
-                final_answer = await self._generate_final_answer(
+                final_answer = await self._generate_final_answer(
+                    memory, run_id=run_id, ctx=ctx
+                )
             else:
                 final_answer = gen_message
             logger.info(
@@ -347,7 +352,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             # 3. Call tools if there are any tool calls in the generated message.

             if gen_message.tool_calls:
-                await self.call_tools(
+                await self.call_tools(
+                    gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
+                )

             # Apply the memory management function if provided.
             self._manage_memory(memory, ctx=ctx, num_turns=turns)
@@ -359,27 +366,28 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             # If we are not in ReAct mode, we set tool_choice to "auto" to allow
             # the LLM to choose freely whether to call tools.

-
-
-
+            if self._react_mode and gen_message.tool_calls:
+                tool_choice = "none"
+            elif gen_message.tool_calls:
+                tool_choice = "auto"
+            else:
+                tool_choice = "required"
+
             gen_message = (
-                await self.
-                    memory, tool_choice=tool_choice, ctx=ctx
+                await self.generate_messages(
+                    memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
                 )
             )[0]

             turns += 1

     async def execute_stream(
-        self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+        self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
     ) -> AsyncIterator[Event[Any]]:
-        if memory.message_history.batch_size > 1:
-            raise ValueError("Batch size must be 1 when streaming.")
-
         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
         gen_message: AssistantMessage | None = None
-        async for event in self.
-            memory, tool_choice=tool_choice, ctx=ctx
+        async for event in self.generate_messages_stream(
+            memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
         ):
             yield event
             if isinstance(event, GenMessageEvent):
@@ -389,10 +397,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
         turns = 0

         while True:
-            conversation = memory.message_history.conversations[0]
-
             if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-
+                memory.message_history, ctx=ctx, num_turns=turns
             ):
                 return

@@ -409,7 +415,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             if turns >= self.max_turns:
                 if self._final_answer_tool_name is not None:
                     async for event in self._generate_final_answer_stream(
-                        memory, ctx=ctx
+                        memory, run_id=run_id, ctx=ctx
                     ):
                         yield event
                 logger.info(
@@ -422,7 +428,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                 yield ToolCallEvent(name=self.agent_name, data=tool_call)

             async for tool_message_event in self.call_tools_stream(
-                gen_message.tool_calls, memory=memory, ctx=ctx
+                gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
             ):
                 yield tool_message_event

@@ -431,8 +437,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
             tool_choice = (
                 "none" if (self._react_mode and gen_message.tool_calls) else "required"
             )
-            async for event in self.
-                memory, tool_choice=tool_choice, ctx=ctx
+            async for event in self.generate_messages_stream(
+                memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
             ):
                 yield event
                 if isinstance(event, GenMessageEvent):
@@ -443,12 +449,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
     def _track_usage(
         self,
         agent_name: str,
-
+        completion: Completion,
         ctx: RunContext[CtxT],
     ) -> None:
         ctx.usage_tracker.update(
             agent_name=agent_name,
-            completions=
+            completions=[completion],
             model_name=self.llm.model_name,
         )

@@ -473,11 +479,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

         return FinalAnswerTool()

-    def
-        self,
+    def _print_completion(
+        self, completion: Completion, run_id: str, ctx: RunContext[CtxT]
     ) -> None:
-        messages = [c.messages[0] for c in completion_batch]
-        usages = [c.usage for c in completion_batch]
         ctx.printer.print_llm_messages(
-            messages,
+            completion.messages,
+            usages=[completion.usage],
+            agent_name=self.agent_name,
+            run_id=run_id,
         )
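Every public entry point of `LLMPolicyExecutor` (`generate_messages`, `generate_messages_stream`, `call_tools`, `call_tools_stream`, `execute`, `execute_stream`) now takes a `run_id`, which the executor threads into usage tracking, completion printing, and tool-call printing. A minimal usage sketch under the new `execute` signature; how the executor, memory, and run context are constructed is not shown in this diff and is assumed here:

```python
# Sketch only: `executor`, `memory`, and `ctx` are assumed to be an
# LLMPolicyExecutor, an LLMAgentMemory, and a RunContext built elsewhere.
async def run_policy(executor, memory, ctx=None):
    # run_id is now a required argument and is propagated to
    # ctx.usage_tracker and ctx.printer by the executor itself.
    final_answer = await executor.execute(memory, run_id="run-001", ctx=ctx)
    return final_answer
```

The streaming variant follows the same shape: `execute_stream(memory, run_id, ctx=ctx)` yields events instead of returning the final message.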
grasp_agents/memory.py
CHANGED
@@ -1,10 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, TypeVar

 from pydantic import BaseModel, ConfigDict

 from .run_context import RunContext

+MemT = TypeVar("MemT", bound="Memory")
+

 class Memory(BaseModel, ABC):
     @abstractmethod
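The new module-level `MemT` TypeVar, bounded by `Memory`, lets downstream code write generics that preserve the concrete memory subclass. A small illustration of the pattern; the simplified `Memory` stand-in and the `checkpoint` helper below are hypothetical, not part of the package:

```python
from typing import TypeVar

from pydantic import BaseModel


class Memory(BaseModel):
    """Stand-in for grasp_agents.memory.Memory (the real class is abstract)."""


MemT = TypeVar("MemT", bound=Memory)


def checkpoint(memory: MemT) -> MemT:
    # Because MemT is bound to Memory, callers get back the same concrete
    # subclass they passed in (e.g. an LLMAgentMemory stays an LLMAgentMemory).
    return memory.model_copy(deep=True)
```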
grasp_agents/openai/completion_chunk_converters.py
CHANGED
@@ -1,3 +1,4 @@
+from ..errors import CompletionError
 from ..typing.completion_chunk import (
     CompletionChunk,
     CompletionChunkChoice,
@@ -12,7 +13,7 @@ def from_api_completion_chunk(
     api_completion_chunk: OpenAICompletionChunk, name: str | None = None
 ) -> CompletionChunk:
     if api_completion_chunk.choices is None:  # type: ignore
-        raise
+        raise CompletionError(
             f"Completion chunk API error: "
             f"{getattr(api_completion_chunk, 'error', None)}"
         )
@@ -24,12 +25,12 @@
     finish_reason = api_choice.finish_reason

     if api_choice.delta is None:  # type: ignore
-        raise
+        raise CompletionError(
             "API returned None for delta content in completion chunk "
             f"with finish_reason: {finish_reason}."
         )
     # if api_choice.delta.content is None:
-    #     raise
+    #     raise CompletionError(
     #         "API returned None for delta content in completion chunk "
     #         f"with finish_reason: {finish_reason}."
     #     )
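Malformed streaming chunks now raise the package's own `CompletionError` (defined in the new `grasp_agents/errors.py`) instead of a bare exception, so callers can catch a dedicated type. A sketch of handling it around the converter; the `convert_or_skip` wrapper is illustrative only:

```python
from grasp_agents.errors import CompletionError
from grasp_agents.openai.completion_chunk_converters import from_api_completion_chunk


def convert_or_skip(api_chunk):
    # `api_chunk` is a raw OpenAI streaming chunk obtained elsewhere.
    try:
        return from_api_completion_chunk(api_chunk)
    except CompletionError as exc:
        # Raised when the chunk has no choices or no delta.
        print(f"Skipping malformed chunk: {exc}")
        return None
```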
grasp_agents/openai/openai_llm.py
CHANGED
@@ -3,6 +3,7 @@ from collections.abc import AsyncIterator, Iterable, Mapping
 from copy import deepcopy
 from typing import Any, Literal, NamedTuple

+import httpx
 from openai import AsyncOpenAI
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from openai.lib.streaming.chat import (
@@ -11,7 +12,7 @@ from openai.lib.streaming.chat import (
 from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
 from pydantic import BaseModel

-from ..cloud_llm import CloudLLM, CloudLLMSettings
+from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
 from ..http_client import AsyncHTTPClientParams
 from ..rate_limiting.rate_limiter_chunked import RateLimiterC
 from ..typing.message import AssistantMessage, Messages
@@ -57,8 +58,6 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
     store: bool | None
     user: str

-    strict_tool_args: bool
-
     # response_format: (
     #     OpenAIResponseFormatText
     #     | OpenAIResponseFormatJSONSchema
@@ -71,16 +70,18 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         self,
         # Base LLM args
         model_name: str,
-        model_id: str | None = None,
         llm_settings: OpenAILLMSettings | None = None,
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | Mapping[str, type] | None = None,
+        model_id: str | None = None,
+        # Custom LLM provider
+        api_provider: APIProvider | None = None,
         # Connection settings
+        async_http_client: httpx.AsyncClient | None = None,
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
         async_openai_client_params: dict[str, Any] | None = None,
-        client: AsyncOpenAI | None = None,
         # Rate limiting
         rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
         rate_limiter_rpm: float | None = None,
@@ -88,9 +89,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         rate_limiter_max_concurrency: int = 300,
         # Retries
         num_generation_retries: int = 0,
-        # Disable tqdm for batch processing
-        no_tqdm: bool = True,
-        **kwargs: Any,
     ) -> None:
         super().__init__(
             model_name=model_name,
@@ -99,33 +97,29 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             converters=OpenAIConverters(),
             tools=tools,
             response_format=response_format,
+            api_provider=api_provider,
+            async_http_client=async_http_client,
             async_http_client_params=async_http_client_params,
             rate_limiter=rate_limiter,
             rate_limiter_rpm=rate_limiter_rpm,
             rate_limiter_chunk_size=rate_limiter_chunk_size,
             rate_limiter_max_concurrency=rate_limiter_max_concurrency,
             num_generation_retries=num_generation_retries,
-            no_tqdm=no_tqdm,
-            **kwargs,
         )

         self._tool_call_settings = {
-            "strict": self._llm_settings.
+            "strict": self._llm_settings.get("use_struct_outputs", False),
         }

         _async_openai_client_params = deepcopy(async_openai_client_params or {})
         if self._async_http_client is not None:
             _async_openai_client_params["http_client"] = self._async_http_client

-
-
-        self.
-
-
-            base_url=self._base_url,
-            api_key=self._api_key,
-            **_async_openai_client_params,
-        )
+        self._client: AsyncOpenAI = AsyncOpenAI(
+            base_url=self._base_url,
+            api_key=self._api_key,
+            **_async_openai_client_params,
+        )

     async def _get_completion(
         self,
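The `OpenAILLM` constructor drops `client`, `no_tqdm`, and `**kwargs` and gains `api_provider` plus an explicit `async_http_client`; the `AsyncOpenAI` client is now always built internally from `base_url` and `api_key`. A hedged construction sketch using only parameters visible in this diff; the model name, timeout, and settings values are placeholders, and provider credentials are assumed to be resolved elsewhere (e.g. via `api_provider` or environment configuration):

```python
import httpx

from grasp_agents.openai.openai_llm import OpenAILLM, OpenAILLMSettings

# Placeholder values throughout; only the parameter names come from the diff.
llm = OpenAILLM(
    model_name="gpt-4o-mini",
    llm_settings=OpenAILLMSettings(user="example-user"),
    async_http_client=httpx.AsyncClient(timeout=30.0),
    num_generation_retries=2,
)
```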
grasp_agents/openai/tool_converters.py
CHANGED
@@ -17,7 +17,6 @@ def to_api_tool(
     tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
 ) -> OpenAIToolParam:
     if strict:
-        # Enforce strict mode for pydantic models
         return pydantic_function_tool(
             model=tool.in_type, name=tool.name, description=tool.description
         )
grasp_agents/packet_pool.py
CHANGED
@@ -80,7 +80,7 @@ class PacketPool(Generic[CtxT]):
             try:
                 await task
             except asyncio.CancelledError:
-                logger.
+                logger.info(f"{processor_name} exited")

             self._tasks.pop(processor_name, None)
             self._queues.pop(processor_name, None)
grasp_agents/printer.py
CHANGED
@@ -36,12 +36,10 @@ AVAILABLE_COLORS: list[Color] = [
 class Printer:
     def __init__(
         self,
-        source_id: str,
         color_by: ColoringMode = "role",
         msg_trunc_len: int = 20000,
         print_messages: bool = False,
     ) -> None:
-        self.source_id = source_id
         self.color_by = color_by
         self.msg_trunc_len = msg_trunc_len
         self.print_messages = print_messages
@@ -84,7 +82,7 @@ class Printer:
         return content_str

     def print_llm_message(
-        self, message: Message, agent_name: str, usage: Usage | None = None
+        self, message: Message, agent_name: str, run_id: str, usage: Usage | None = None
     ) -> None:
         if not self.print_messages:
             return
@@ -106,8 +104,7 @@ class Printer:

         # Print message title

-        out = f"\n
-        out += "[" + role.value.upper() + "]"
+        out = f"\n[agent: {agent_name} | role: {role.value} | run: {run_id}]"

         if isinstance(message, ToolMessage):
             out += f"\n{message.name} | {message.tool_call_id}"
@@ -159,6 +156,7 @@ class Printer:
         self,
         messages: Sequence[Message],
         agent_name: str,
+        run_id: str,
         usages: Sequence[Usage | None] | None = None,
     ) -> None:
         if not self.print_messages:
@@ -167,4 +165,6 @@ class Printer:
         _usages: Sequence[Usage | None] = usages or [None] * len(messages)

         for _message, _usage in zip(messages, _usages, strict=False):
-            self.print_llm_message(
+            self.print_llm_message(
+                _message, usage=_usage, agent_name=agent_name, run_id=run_id
+            )
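`Printer` no longer takes a `source_id`; output is labeled per call with `agent_name` and `run_id`. A small sketch of the updated call; the message list is left empty because message construction is not part of this diff:

```python
from grasp_agents.printer import Printer

# source_id is gone from the constructor; only display options remain.
printer = Printer(print_messages=True)

# Would normally hold AssistantMessage / ToolMessage objects from a run.
messages: list = []

printer.print_llm_messages(
    messages,
    agent_name="example-agent",  # placeholder
    run_id="run-001",            # placeholder
)
```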