grasp_agents 0.2.3__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/PKG-INFO +3 -4
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/README.md +2 -3
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/pyproject.toml +1 -1
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/base_agent.py +1 -1
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/cloud_llm.py +10 -2
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/comm_agent.py +10 -10
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/llm_agent.py +70 -44
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/llm_agent_state.py +9 -9
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/completion_converters.py +7 -2
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/message_converters.py +13 -8
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/openai_llm.py +4 -6
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/tool_converters.py +9 -8
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/printer.py +1 -1
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/prompt_builder.py +63 -57
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/run_context.py +3 -3
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/tool_orchestrator.py +2 -2
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/message.py +1 -1
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/usage_tracker.py +2 -1
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/workflow/looped_agent.py +11 -7
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/workflow/sequential_agent.py +9 -6
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/workflow/workflow_agent.py +3 -2
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/.gitignore +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/LICENSE.md +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/__init__.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/agent_message.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/agent_message_pool.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/llm.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/openai/converters.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/io.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/typing/tool.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/utils.py +0 -0
- {grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/workflow/__init__.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.2.3
+Version: 0.2.5
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -168,9 +168,8 @@ Problem = str
 teacher = LLMAgent[Any, Problem, None](
     agent_id="teacher",
     llm=OpenAILLM(
-        model_name="gpt-4.1",
-        api_provider="openai",
-        llm_settings=OpenAILLMSettings(temperature=0.1),
+        model_name="openai:gpt-4.1",
+        llm_settings=OpenAILLMSettings(temperature=0.1)
     ),
     tools=[AskStudentTool()],
     max_turns=20,
README.md
@@ -152,9 +152,8 @@ Problem = str
 teacher = LLMAgent[Any, Problem, None](
     agent_id="teacher",
     llm=OpenAILLM(
-        model_name="gpt-4.1",
-        api_provider="openai",
-        llm_settings=OpenAILLMSettings(temperature=0.1),
+        model_name="openai:gpt-4.1",
+        llm_settings=OpenAILLMSettings(temperature=0.1)
    ),
     tools=[AskStudentTool()],
     max_turns=20,
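Both hunks change for the same reason: as of 0.2.5 the API provider is encoded as a prefix of model_name instead of being passed as a separate api_provider argument. A minimal construction sketch in the new style (import path assumed from the file list above, not taken from the README itself):

    from grasp_agents.openai import OpenAILLM, OpenAILLMSettings  # assumed import path

    # The provider is now part of the model name: "<provider>:<model>".
    llm = OpenAILLM(
        model_name="openai:gpt-4.1",
        llm_settings=OpenAILLMSettings(temperature=0.1),
    )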
src/grasp_agents/cloud_llm.py
@@ -106,7 +106,6 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | None = None,
         # Connection settings
-        api_provider: APIProvider = "openai",
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
@@ -134,7 +133,16 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
         )

         self._model_name = model_name
+
+        api_provider = model_name.split(":", 1)[0]
+        api_model_name = model_name.split(":", 1)[-1]
+        if api_provider not in PROVIDERS:
+            raise ValueError(
+                f"API provider '{api_provider}' is not supported. "
+                f"Supported providers are: {', '.join(PROVIDERS.keys())}"
+            )
         self._api_provider: APIProvider = api_provider
+        self._api_model_name: str = api_model_name

         self._struct_output_support: bool = any(
             fnmatch.fnmatch(self._model_name, pat)
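The constructor now derives both the provider and the provider-facing model name from a single "provider:model" string. The same split semantics in isolation (PROVIDERS here is a stand-in for the package's provider registry):

    # Sketch of the "provider:model" parsing added in 0.2.5.
    PROVIDERS = {"openai": ...}  # stand-in for the real registry

    def split_model_name(model_name: str) -> tuple[str, str]:
        provider = model_name.split(":", 1)[0]
        api_model_name = model_name.split(":", 1)[-1]
        if provider not in PROVIDERS:
            raise ValueError(f"API provider '{provider}' is not supported.")
        return provider, api_model_name

    assert split_model_name("openai:gpt-4.1") == ("openai", "gpt-4.1")
    # Caveat: a name without ":" yields provider == api_model_name == the whole string.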
@@ -284,7 +292,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
             and not message.tool_calls
         ):
             validate_obj_from_json_or_py_string(
-                message.content,
+                message.content or "",
                 adapter=self._response_format_pyd,
                 from_substring=True,
             )
src/grasp_agents/comm_agent.py
@@ -51,7 +51,7 @@ class CommunicatingAgent(
         self._in_type: type[InT]
         super().__init__(agent_id=agent_id, **kwargs)

-        self.…
+        self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
         self.recipient_ids = recipient_ids or []

         self._message_pool = message_pool or AgentMessagePool()
@@ -99,10 +99,10 @@ class CommunicatingAgent(
     @abstractmethod
     async def run(
         self,
-        …
+        chat_inputs: Any | None = None,
         *,
         ctx: RunContextWrapper[CtxT] | None = None,
-        …
+        in_message: AgentMessage[InT, AgentState] | None = None,
         entry_point: bool = False,
         forbid_state_change: bool = False,
         **kwargs: Any,
@@ -113,7 +113,7 @@ class CommunicatingAgent(
         self, ctx: RunContextWrapper[CtxT] | None = None, **run_kwargs: Any
     ) -> None:
         output_message = await self.run(
-            ctx=ctx,…
+            ctx=ctx, in_message=None, entry_point=True, **run_kwargs
         )
         await self.post_message(output_message)
@@ -140,8 +140,8 @@ class CommunicatingAgent(
         ctx: RunContextWrapper[CtxT] | None = None,
         **run_kwargs: Any,
     ) -> None:
-        …
-        out_message = await self.run(ctx=ctx,…
+        in_message = cast("AgentMessage[InT, AgentState]", message)
+        out_message = await self.run(ctx=ctx, in_message=in_message, **run_kwargs)

         if self._exit_condition(output_message=out_message, ctx=ctx):
             await self._message_pool.stop_all()
@@ -199,14 +199,14 @@ class CommunicatingAgent(
         inp: InT,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> OutT:
-        …
-        …
-            payloads=[…
+        in_args = in_type.model_validate(inp)
+        in_message = AgentMessage[in_type, AgentState](
+            payloads=[in_args],
             sender_id="<tool_user>",
             recipient_ids=[agent_instance.agent_id],
         )
         agent_result = await agent_instance.run(
-            …
+            in_message=in_message,
             entry_point=False,
             forbid_state_change=True,
             ctx=ctx,
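The new _in_args_type_adapter attribute and the tool path above both lean on pydantic's TypeAdapter, which validates and serializes arbitrary annotated types rather than only BaseModel subclasses; this is what lets InT be e.g. a plain str. A self-contained illustration using the standard pydantic v2 API (not grasp_agents code):

    from pydantic import TypeAdapter

    adapter = TypeAdapter(list[int])
    assert adapter.validate_python(["1", 2]) == [1, 2]  # lax coercion by default
    print(adapter.dump_json([1, 2]).decode("utf-8"))    # -> [1,2]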
src/grasp_agents/llm_agent.py
@@ -52,7 +52,7 @@ class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
         self,
         conversation: Conversation,
         *,
-        …
+        in_args: InT | None,
         batch_idx: int,
         ctx: RunContextWrapper[CtxT] | None,
     ) -> OutT: ...
@@ -74,8 +74,8 @@ class LLMAgent(
         # LLM
         llm: LLM[LLMSettings, Converters],
         # Input prompt template (combines user and received arguments)
-        …
-        …
+        in_prompt: LLMPrompt | None = None,
+        in_prompt_path: str | Path | None = None,
         # System prompt template
         sys_prompt: LLMPrompt | None = None,
         sys_prompt_path: str | Path | None = None,
@@ -119,13 +119,13 @@ class LLMAgent(

         # Prompt builder
         sys_prompt = get_prompt(prompt_text=sys_prompt, prompt_path=sys_prompt_path)
-        …
+        in_prompt = get_prompt(prompt_text=in_prompt, prompt_path=in_prompt_path)
         self._prompt_builder: PromptBuilder[InT, CtxT] = PromptBuilder[
             self.in_type, CtxT
         ](
             agent_id=self._agent_id,
             sys_prompt=sys_prompt,
-            …
+            in_prompt=in_prompt,
             sys_args_schema=sys_args_schema,
             usr_args_schema=usr_args_schema,
         )
@@ -159,14 +159,14 @@ class LLMAgent(
         return self._prompt_builder.sys_prompt

     @property
-    def …
-        return self._prompt_builder.…
+    def in_prompt(self) -> LLMPrompt | None:
+        return self._prompt_builder.in_prompt

     def _parse_output(
         self,
         conversation: Conversation,
         *,
-        …
+        in_args: InT | None = None,
         batch_idx: int = 0,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> OutT:
@@ -180,25 +180,48 @@ class LLMAgent(

         return self._parse_output_impl(
             conversation=conversation,
-            …
+            in_args=in_args,
             batch_idx=batch_idx,
             ctx=ctx,
         )

         return validate_obj_from_json_or_py_string(
-            str(conversation[-1].content),
+            str(conversation[-1].content or ""),
             adapter=self._out_type_adapter,
             from_substring=True,
         )

+    @staticmethod
+    def _validate_run_inputs(
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
+        in_message: AgentMessage[InT, AgentState] | None = None,
+        entry_point: bool = False,
+    ) -> None:
+        multiple_inputs_err_message = (
+            "Only one of chat_inputs, in_args, or in_message must be provided."
+        )
+        if chat_inputs is not None and in_args is not None:
+            raise ValueError(multiple_inputs_err_message)
+        if chat_inputs is not None and in_message is not None:
+            raise ValueError(multiple_inputs_err_message)
+        if in_args is not None and in_message is not None:
+            raise ValueError(multiple_inputs_err_message)
+
+        if entry_point and in_message is not None:
+            raise ValueError(
+                "Entry point agent cannot receive messages from other agents."
+            )
+
     @final
     async def run(
         self,
-        …
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         *,
-        …
-        …
+        in_message: AgentMessage[InT, AgentState] | None = None,
+        in_args: InT | Sequence[InT] | None = None,
         entry_point: bool = False,
+        ctx: RunContextWrapper[CtxT] | None = None,
         forbid_state_change: bool = False,
         **gen_kwargs: Any,  # noqa: ARG002
     ) -> AgentMessage[OutT, LLMAgentState]:
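The new _validate_run_inputs guard makes the three input channels of run() mutually exclusive. A condensed, standalone sketch of the same rule (equivalent to the pairwise checks above):

    # Equivalent condensed form of the guard added in 0.2.5.
    def validate_run_inputs(chat_inputs=None, in_args=None, in_message=None,
                            entry_point=False) -> None:
        if sum(x is not None for x in (chat_inputs, in_args, in_message)) > 1:
            raise ValueError(
                "Only one of chat_inputs, in_args, or in_message must be provided."
            )
        if entry_point and in_message is not None:
            raise ValueError(
                "Entry point agent cannot receive messages from other agents."
            )

    validate_run_inputs(chat_inputs="Solve 2 + 2")       # ok: direct chat input
    validate_run_inputs(in_args={"problem": "2 + 2"})    # ok: typed input args
    # validate_run_inputs(chat_inputs="hi", in_args={})  # would raise ValueError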
@@ -211,16 +234,12 @@ class LLMAgent(
         sys_args = run_args.sys
         usr_args = run_args.usr

-        …
-        …
-        …
-        …
-        …
-        …
-            "There must be no received message with user inputs"
-        )
-        …
-        cur_state = self.state.model_copy(deep=True)
+        self._validate_run_inputs(
+            chat_inputs=chat_inputs,
+            in_args=in_args,
+            in_message=in_message,
+            entry_point=entry_point,
+        )

         # 1. Make system prompt (can be None)
         formatted_sys_prompt = self._prompt_builder.make_sys_prompt(
@@ -229,12 +248,13 @@ class LLMAgent(

         # 2. Set agent state

-        …
+        cur_state = self.state.model_copy(deep=True)
+        in_state = in_message.sender_state if in_message else None
         prev_mh_len = len(cur_state.message_history)

-        state = LLMAgentState.…
+        state = LLMAgentState.from_cur_and_in_states(
             cur_state=cur_state,
-            …
+            in_state=in_state,
             sys_prompt=formatted_sys_prompt,
             strategy=self.set_state_strategy,
             set_agent_state_impl=self._set_agent_state_impl,
@@ -244,10 +264,16 @@ class LLMAgent(
         self._print_sys_msg(state=state, prev_mh_len=prev_mh_len, ctx=ctx)

         # 3. Make and add user messages (can be empty)
+        _in_args_batch: Sequence[InT] | None = None
+        if in_message is not None:
+            _in_args_batch = in_message.payloads
+        elif in_args is not None:
+            _in_args_batch = in_args if isinstance(in_args, Sequence) else [in_args]  # type: ignore[assignment]
+
         user_message_batch = self._prompt_builder.make_user_messages(
-            …
+            chat_inputs=chat_inputs,
             usr_args=usr_args,
-            …
+            in_args_batch=_in_args_batch,
             entry_point=entry_point,
             ctx=ctx,
         )
@@ -257,7 +283,7 @@ class LLMAgent(

         if not self.tools:
             # 4. Generate messages without tools
-            await self._tool_orchestrator.generate_once(…
+            await self._tool_orchestrator.generate_once(state=state, ctx=ctx)
         else:
             # 4. Run tool call loop (new messages are added to the message
             # history inside the loop)
@@ -265,14 +291,14 @@ class LLMAgent(

         # 5. Parse outputs
         batch_size = state.message_history.batch_size
-        …
+        in_args_batch = in_message.payloads if in_message else batch_size * [None]
         val_output_batch = [
             self._out_type_adapter.validate_python(
-                self._parse_output(conversation=conv,…
+                self._parse_output(conversation=conv, in_args=in_args, ctx=ctx)
             )
-            for conv,…
+            for conv, in_args in zip(
                 state.message_history.batched_conversations,
-                …
+                in_args_batch,
                 strict=False,
             )
         ]
@@ -285,12 +311,12 @@ class LLMAgent(
         interaction_record = InteractionRecord(
             source_id=self.agent_id,
             recipient_ids=recipient_ids,
-            …
+            chat_inputs=chat_inputs,
             sys_prompt=self.sys_prompt,
-            …
+            in_prompt=self.in_prompt,
             sys_args=sys_args,
             usr_args=usr_args,
-            …
+            in_args=(in_message.payloads if in_message is not None else None),
             outputs=val_output_batch,
             state=state,
         )
@@ -343,10 +369,10 @@ class LLMAgent(

         return func

-    def …
+    def format_in_args_handler(
         self, func: FormatInputArgsHandler[InT, CtxT]
     ) -> FormatInputArgsHandler[InT, CtxT]:
-        self._prompt_builder.…
+        self._prompt_builder.format_in_args_impl = func

         return func

@@ -385,8 +411,8 @@ class LLMAgent(
         if cur_cls._format_sys_args is not base_cls._format_sys_args:  # noqa: SLF001
             self._prompt_builder.format_sys_args_impl = self._format_sys_args

-        if cur_cls.…
-            self._prompt_builder.…
+        if cur_cls._format_in_args is not base_cls._format_in_args:  # noqa: SLF001
+            self._prompt_builder.format_in_args_impl = self._format_in_args

         if cur_cls._set_agent_state is not base_cls._set_agent_state:  # noqa: SLF001
             self._set_agent_state_impl = self._set_agent_state
@@ -412,23 +438,23 @@ class LLMAgent(
             "if it's intended to be used as the system arguments formatter."
         )

-    def …
+    def _format_in_args(
         self,
         *,
         usr_args: LLMPromptArgs,
-        …
+        in_args: InT,
         batch_idx: int = 0,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> LLMFormattedArgs:
         raise NotImplementedError(
-            "LLMAgent.…
+            "LLMAgent._format_in_args must be overridden by a subclass"
         )

     def _set_agent_state(
         self,
         cur_state: LLMAgentState,
         *,
-        …
+        in_state: AgentState | None,
         sys_prompt: LLMPrompt | None,
         ctx: RunContextWrapper[Any] | None,
     ) -> LLMAgentState:
src/grasp_agents/llm_agent_state.py
@@ -15,7 +15,7 @@ class SetAgentState(Protocol):
         self,
         cur_state: "LLMAgentState",
         *,
-        …
+        in_state: AgentState | None,
         sys_prompt: LLMPrompt | None,
         ctx: RunContextWrapper[Any] | None,
     ) -> "LLMAgentState": ...
@@ -29,11 +29,11 @@ class LLMAgentState(AgentState):
         return self.message_history.batch_size

     @classmethod
-    def …
+    def from_cur_and_in_states(
         cls,
         cur_state: "LLMAgentState",
         *,
-        …
+        in_state: Optional["AgentState"] = None,
         sys_prompt: LLMPrompt | None = None,
         strategy: SetAgentStateStrategy = "from_sender",
         set_agent_state_impl: SetAgentState | None = None,
@@ -50,13 +50,13 @@ class LLMAgentState(AgentState):
             upd_mh.reset(sys_prompt)

         elif strategy == "from_sender":
-            …
-            …
-            if …
+            in_mh = (
+                in_state.message_history
+                if in_state and isinstance(in_state, "LLMAgentState")
                 else None
             )
-            if …
-            …
+            if in_mh:
+                in_mh = deepcopy(in_mh)
             else:
                 upd_mh.reset(sys_prompt)

@@ -66,7 +66,7 @@ class LLMAgentState(AgentState):
         )
         return set_agent_state_impl(
             cur_state=cur_state,
-            …
+            in_state=in_state,
             sys_prompt=sys_prompt,
             ctx=ctx,
         )
src/grasp_agents/openai/completion_converters.py
@@ -14,17 +14,22 @@ def from_api_completion(
 ) -> Completion:
     choices: list[CompletionChoice] = []
     if api_completion.choices is None:  # type: ignore
-        # …
+        # Some providers return None for the choices when there is an error
         # TODO: add custom error types
         raise RuntimeError(
             f"Completion API error: {getattr(api_completion, 'error', None)}"
         )
     for api_choice in api_completion.choices:
         # TODO: currently no way to assign individual message usages when len(choices) > 1
+        finish_reason = api_choice.finish_reason
+        # Some providers return None for the message when finish_reason is other than "stop"
+        if api_choice.message is None:  # type: ignore
+            raise RuntimeError(
+                f"API returned None for message with finish_reason: {finish_reason}"
+            )
         message = from_api_assistant_message(
             api_choice.message, api_completion.usage, model_id=model_id
         )
-        finish_reason = api_choice.finish_reason
         choices.append(CompletionChoice(message=message, finish_reason=finish_reason))

     return Completion(choices=choices, model_id=model_id)
src/grasp_agents/openai/message_converters.py
@@ -51,11 +51,6 @@ def from_api_assistant_message(
     api_usage: ChatCompletionUsage | None = None,
     model_id: str | None = None,
 ) -> AssistantMessage:
-    content = api_message.content or ""
-    assert isinstance(content, str), (
-        "Only string content is currently supported in assistant messages"
-    )
-
     usage = None
     if api_usage is not None:
         reasoning_tokens = None
@@ -88,7 +83,7 @@ def from_api_assistant_message(
         ]

     return AssistantMessage(
-        content=content,
+        content=api_message.content,
         usage=usage,
         tool_calls=tool_calls,
         refusal=api_message.refusal,
@@ -113,12 +108,22 @@ def to_api_assistant_message(
             for tool_call in message.tool_calls
         ]

-    …
+    api_message = ChatCompletionAssistantMessageParam(
         role="assistant",
         content=message.content,
-        tool_calls=api_tool_calls…
+        tool_calls=api_tool_calls or [],
         refusal=message.refusal,
     )
+    if message.content is None and not api_tool_calls:
+        # Some API providers return None in the generated content without errors,
+        # even though None in the input content is not accepted.
+        api_message["content"] = "<empty>"
+    if api_tool_calls is None:
+        api_message.pop("tool_calls")
+    if message.refusal is None:
+        api_message.pop("refusal")
+
+    return api_message


 def from_api_system_message(
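The converter changes let assistant content be None end to end: incoming messages keep the provider's None instead of asserting on it, and outgoing messages substitute a placeholder because the chat API rejects None in the input. A runnable sketch of just that outgoing rule:

    # Outgoing-content rule from to_api_assistant_message above (sketch).
    def outgoing_content(content: str | None, has_tool_calls: bool) -> str | None:
        if content is None and not has_tool_calls:
            return "<empty>"  # chat APIs reject None input content
        return content

    assert outgoing_content(None, has_tool_calls=False) == "<empty>"
    assert outgoing_content(None, has_tool_calls=True) is None
    assert outgoing_content("hi", has_tool_calls=False) == "hi"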
src/grasp_agents/openai/openai_llm.py
@@ -7,7 +7,7 @@ from openai import AsyncOpenAI
 from openai._types import NOT_GIVEN  # type: ignore[import]
 from pydantic import BaseModel

-from ..cloud_llm import …
+from ..cloud_llm import CloudLLM, CloudLLMSettings
 from ..http_client import AsyncHTTPClientParams
 from ..rate_limiting.rate_limiter_chunked import RateLimiterC
 from ..typing.message import AssistantMessage, Conversation
@@ -69,7 +69,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
         response_format: type | None = None,
         # Connection settings
-        api_provider: APIProvider = "openai",
         async_http_client_params: (
             dict[str, Any] | AsyncHTTPClientParams | None
         ) = None,
@@ -92,7 +91,6 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
             converters=OpenAIConverters(),
             tools=tools,
             response_format=response_format,
-            api_provider=api_provider,
             async_http_client_params=async_http_client_params,
             rate_limiter=rate_limiter,
             rate_limiter_rpm=rate_limiter_rpm,
@@ -125,7 +123,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         tool_choice = api_tool_choice or NOT_GIVEN

         return await self._client.chat.completions.create(
-            model=self.…
+            model=self._api_model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -146,7 +144,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         response_format = api_response_format or NOT_GIVEN

         return await self._client.beta.chat.completions.parse(
-            model=self.…
+            model=self._api_model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
@@ -167,7 +165,7 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
         tool_choice = api_tool_choice or NOT_GIVEN

         return await self._client.chat.completions.create(
-            model=self.…
+            model=self._api_model_name,
             messages=api_messages,
             tools=tools,
             tool_choice=tool_choice,
src/grasp_agents/openai/tool_converters.py
@@ -15,15 +15,16 @@ from . import (
 def to_api_tool(
     tool: BaseTool[BaseModel, Any, Any],
 ) -> ChatCompletionToolParam:
-    …
-    …
-    …
-    …
-    …
-        parameters=tool.in_schema.model_json_schema(),
-        strict=tool.strict,
-        ),
+    function = ChatCompletionFunctionDefinition(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.in_schema.model_json_schema(),
+        strict=tool.strict,
     )
+    if tool.strict is None:
+        function.pop("strict")
+
+    return ChatCompletionToolParam(type="function", function=function)


 def to_api_tool_choice(
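The rewritten converter builds the function definition first so that an unset "strict" can be dropped before wrapping it in the tool param. The approximate shape of the result, as a plain dict (the exact TypedDict classes come from the openai package; the tool name here is illustrative, echoing the README's AskStudentTool):

    tool_param = {
        "type": "function",
        "function": {
            "name": "ask_student",
            "description": "Ask the student a question.",
            "parameters": {
                "type": "object",
                "properties": {"question": {"type": "string"}},
                "required": ["question"],
            },
            # "strict": True,  # present only when tool.strict is not None
        },
    }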
src/grasp_agents/printer.py
@@ -88,7 +88,7 @@ class Printer:

         role = message.role
         usage = message.usage if isinstance(message, AssistantMessage) else None
-        content_str = self.content_to_str(message.content, message.role)
+        content_str = self.content_to_str(message.content or "", message.role)

         if self.color_by == "agent_id":
             color = self.get_agent_color(agent_id)
src/grasp_agents/prompt_builder.py
@@ -35,7 +35,7 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
         self,
         *,
         usr_args: LLMPromptArgs,
-        …
+        in_args: InT,
         batch_idx: int,
         ctx: RunContextWrapper[CtxT] | None,
     ) -> LLMFormattedArgs: ...
@@ -48,7 +48,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
         self,
         agent_id: str,
         sys_prompt: LLMPrompt | None,
-        …
+        in_prompt: LLMPrompt | None,
         sys_args_schema: type[LLMPromptArgs],
         usr_args_schema: type[LLMPromptArgs],
     ):
@@ -57,13 +57,13 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):

         self._agent_id = agent_id
         self.sys_prompt = sys_prompt
-        self.…
+        self.in_prompt = in_prompt
         self.sys_args_schema = sys_args_schema
         self.usr_args_schema = usr_args_schema
         self.format_sys_args_impl: FormatSystemArgsHandler[CtxT] | None = None
-        self.…
+        self.format_in_args_impl: FormatInputArgsHandler[InT, CtxT] | None = None

-        self.…
+        self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)

     def _format_sys_args(
         self,
@@ -75,31 +75,31 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):

         return sys_args.model_dump(exclude_unset=True)

-    def …
+    def _format_in_args(
         self,
         *,
         usr_args: LLMPromptArgs,
-        …
+        in_args: InT,
         batch_idx: int = 0,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> LLMFormattedArgs:
-        if self.…
-            return self.…
-                usr_args=usr_args,…
+        if self.format_in_args_impl:
+            return self.format_in_args_impl(
+                usr_args=usr_args, in_args=in_args, batch_idx=batch_idx, ctx=ctx
             )

-        if not isinstance(…
+        if not isinstance(in_args, BaseModel) and in_args is not None:
             raise TypeError(
                 "Cannot apply default formatting to non-BaseModel received arguments."
             )

         usr_args_ = usr_args
-        …
+        in_args_ = DummySchema() if in_args is None else in_args

         usr_args_dump = usr_args_.model_dump(exclude_unset=True)
-        …
+        in_args_dump = in_args_.model_dump(exclude={"selected_recipient_ids"})

-        return usr_args_dump |…
+        return usr_args_dump | in_args_dump

     def make_sys_prompt(
         self,
@@ -118,17 +118,17 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
         return [UserMessage.from_text(text, model_id=self._agent_id)]

     def _usr_messages_from_content_parts(
-        self, content_parts: …
-    ) -> …
+        self, content_parts: Sequence[str | ImageData]
+    ) -> Sequence[UserMessage]:
         return [UserMessage.from_content_parts(content_parts, model_id=self._agent_id)]

-    def …
-        self,…
-    ) -> …
+    def _usr_messages_from_in_args(
+        self, in_args_batch: Sequence[InT]
+    ) -> Sequence[UserMessage]:
         return [
             UserMessage.from_text(
-                self.…
-                …
+                self._in_args_type_adapter.dump_json(
+                    inp,
                     exclude_unset=True,
                     indent=2,
                     exclude={"selected_recipient_ids"},
@@ -136,96 +136,102 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
                 ).decode("utf-8"),
                 model_id=self._agent_id,
             )
-            for …
+            for inp in in_args_batch
         ]

     def _usr_messages_from_prompt_template(
         self,
-        …
+        in_prompt: LLMPrompt,
         usr_args: UserRunArgs | None = None,
-        …
+        in_args_batch: Sequence[InT] | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> Sequence[UserMessage]:
-        usr_args_batch_,…
+        usr_args_batch_, in_args_batch_ = self._make_batched(usr_args, in_args_batch)

         val_usr_args_batch_ = [
             self.usr_args_schema.model_validate(u) for u in usr_args_batch_
         ]
-        …
-            self.…
+        val_in_args_batch_ = [
+            self._in_args_type_adapter.validate_python(inp) for inp in in_args_batch_
         ]

-        …
-            self.…
-                usr_args=val_usr_args,…
+        formatted_in_args_batch = [
+            self._format_in_args(
+                usr_args=val_usr_args, in_args=val_in_args, batch_idx=i, ctx=ctx
             )
-            for i, (val_usr_args,…
-                zip(val_usr_args_batch_,…
+            for i, (val_usr_args, val_in_args) in enumerate(
+                zip(val_usr_args_batch_, val_in_args_batch_, strict=False)
             )
         ]

         return [
             UserMessage.from_formatted_prompt(
-                prompt_template=…
+                prompt_template=in_prompt, prompt_args=in_args
             )
-            for …
+            for in_args in formatted_in_args_batch
         ]

     def make_user_messages(
         self,
-        …
+        chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         usr_args: UserRunArgs | None = None,
-        …
+        in_args_batch: Sequence[InT] | None = None,
         entry_point: bool = False,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> Sequence[UserMessage]:
         # 1) Direct user input (e.g. chat input)
-        if …
+        if chat_inputs is not None or entry_point:
             """
-            * If …
+            * If chat inputs are provided, use them instead of the predefined
               input prompt template
             * In a multi-agent system, the predefined input prompt is used to
               construct agent inputs using the combination of received
               and user arguments.
               However, the first agent run (entry point) has no received
-              messages, so we use the …
+              messages, so we use the chat inputs directly, if provided.
             """
-            if isinstance(…
-                return self._usr_messages_from_text(…
-            …
-            …
-            …
+            if isinstance(chat_inputs, LLMPrompt):
+                return self._usr_messages_from_text(chat_inputs)
+
+            if isinstance(chat_inputs, Sequence) and chat_inputs:
+                return self._usr_messages_from_content_parts(chat_inputs)

         # 2) No input prompt template + received args → raw JSON messages
-        if self.…
-            return self.…
+        if self.in_prompt is None and in_args_batch:
+            return self._usr_messages_from_in_args(in_args_batch)

         # 3) Input prompt template + any args → batch & format
-        if self.…
+        if self.in_prompt is not None:
+            if in_args_batch and not isinstance(in_args_batch[0], BaseModel):
+                raise TypeError(
+                    "Cannot use the input prompt template with "
+                    "non-BaseModel received arguments."
+                )
             return self._usr_messages_from_prompt_template(
-                …
+                in_prompt=self.in_prompt,
                 usr_args=usr_args,
-                …
+                in_args_batch=in_args_batch,
                 ctx=ctx,
             )
+
         return []

     def _make_batched(
         self,
         usr_args: UserRunArgs | None = None,
-        …
+        in_args_batch: Sequence[InT] | None = None,
     ) -> tuple[Sequence[LLMPromptArgs | DummySchema], Sequence[InT | DummySchema]]:
         usr_args_batch_ = (
             usr_args if isinstance(usr_args, list) else [usr_args or DummySchema()]
         )
-        …
+        in_args_batch_ = in_args_batch or [DummySchema()]

         # Broadcast singleton → match lengths
-        if len(usr_args_batch_) == 1 and len(…
-            usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in …
-        if len(…
-            …
-        if len(usr_args_batch_) != len(…
+        if len(usr_args_batch_) == 1 and len(in_args_batch_) > 1:
+            usr_args_batch_ = [deepcopy(usr_args_batch_[0]) for _ in in_args_batch_]
+        if len(in_args_batch_) == 1 and len(usr_args_batch_) > 1:
+            in_args_batch_ = [deepcopy(in_args_batch_[0]) for _ in usr_args_batch_]
+        if len(usr_args_batch_) != len(in_args_batch_):
             raise ValueError("User args and received args must have the same length")

-        return usr_args_batch_,…
+        return usr_args_batch_, in_args_batch_
src/grasp_agents/run_context.py
@@ -32,12 +32,12 @@ class InteractionRecord(BaseModel, Generic[InT, OutT, StateT]):
     source_id: str
     recipient_ids: Sequence[AgentID]
     state: StateT
-    …
+    chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None
     sys_prompt: LLMPrompt | None = None
-    …
+    in_prompt: LLMPrompt | None = None
     sys_args: SystemRunArgs | None = None
     usr_args: UserRunArgs | None = None
-    …
+    in_args: Sequence[InT] | None = None
     outputs: Sequence[OutT]

     model_config = ConfigDict(extra="forbid", frozen=True)
src/grasp_agents/tool_orchestrator.py
@@ -103,11 +103,11 @@ class ToolOrchestrator(Generic[CtxT]):

     async def generate_once(
         self,
-        …
+        state: LLMAgentState,
         tool_choice: ToolChoice | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
     ) -> Sequence[AssistantMessage]:
-        message_history = …
+        message_history = state.message_history
         message_batch = await self.llm.generate_message_batch(
             message_history, tool_choice=tool_choice
         )
src/grasp_agents/typing/message.py
@@ -61,7 +61,7 @@ class MessageBase(BaseModel):

 class AssistantMessage(MessageBase):
     role: Literal[Role.ASSISTANT] = Role.ASSISTANT
-    content: str
+    content: str | None
     usage: Usage | None = None
     tool_calls: Sequence[ToolCall] | None = None
     refusal: str | None = None
src/grasp_agents/usage_tracker.py
@@ -19,6 +19,7 @@ CostsDict: TypeAlias = dict[str, ModelCostsDict]


 class UsageTracker(BaseModel):
+    # TODO: specify different costs per provider:model, not just per model
     source_id: str
     costs_dict_path: str | Path = COSTS_DICT_PATH
     costs_dict: CostsDict | None = None
@@ -60,7 +61,7 @@ class UsageTracker(BaseModel):
         self, messages: Sequence[Message], model_name: str | None = None
     ) -> None:
         if model_name is not None and self.costs_dict is not None:
-            model_costs_dict = self.costs_dict.get(model_name)
+            model_costs_dict = self.costs_dict.get(model_name.split(":", 1)[-1])
         else:
             model_costs_dict = None

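Because model names now carry a provider prefix, the usage tracker strips it before the cost lookup, so a provider-qualified name resolves to the same costs_dict entry as the bare name (cost numbers below are illustrative, not from costs_dict.yaml):

    costs_dict = {"gpt-4.1": {"input": 2.0, "output": 8.0}}  # illustrative numbers
    model_name = "openai:gpt-4.1"
    assert costs_dict.get(model_name.split(":", 1)[-1]) == costs_dict["gpt-4.1"]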
src/grasp_agents/workflow/looped_agent.py
@@ -67,7 +67,8 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
     def _exit_workflow_loop(
         self,
         output_message: AgentMessage[OutT, Any],
-        …
+        *,
+        ctx: RunContextWrapper[CtxT] | None = None,
         **kwargs: Any,
     ) -> bool:
         if self._exit_workflow_loop_impl:
@@ -78,23 +79,25 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
     @final
     async def run(
         self,
-        …
+        chat_inputs: Any | None = None,
         *,
-        …
+        in_args: InT | Sequence[InT] | None = None,
+        in_message: AgentMessage[InT, Any] | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
         entry_point: bool = False,
         forbid_state_change: bool = False,
         **kwargs: Any,
     ) -> AgentMessage[OutT, AgentState]:
-        agent_message = …
+        agent_message = in_message
         num_iterations = 0
         exit_message: AgentMessage[OutT, Any] | None = None

        while True:
             for subagent in self.subagents:
                 agent_message = await subagent.run(
-                    …
-                    …
+                    chat_inputs=chat_inputs,
+                    in_args=in_args,
+                    in_message=agent_message,
                     entry_point=entry_point,
                     forbid_state_change=forbid_state_change,
                     ctx=ctx,
@@ -112,5 +115,6 @@ class LoopedWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
                 )
                 return exit_message

-            …
+            chat_inputs = None
+            in_args = None
             entry_point = False
src/grasp_agents/workflow/sequential_agent.py
@@ -36,25 +36,28 @@ class SequentialWorkflowAgent(WorkflowAgent[InT, OutT, CtxT], Generic[InT, OutT, CtxT]):
     @final
     async def run(
         self,
-        …
+        chat_inputs: Any | None = None,
         *,
-        …
+        in_args: InT | Sequence[InT] | None = None,
+        in_message: AgentMessage[InT, Any] | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
         entry_point: bool = False,
         forbid_state_change: bool = False,
         **kwargs: Any,
     ) -> AgentMessage[OutT, Any]:
-        agent_message = …
+        agent_message = in_message
         for subagent in self.subagents:
             agent_message = await subagent.run(
-                …
-                …
+                chat_inputs=chat_inputs,
+                in_args=in_args,
+                in_message=agent_message,
                 entry_point=entry_point,
                 forbid_state_change=forbid_state_change,
                 ctx=ctx,
                 **kwargs,
             )
-            …
+            chat_inputs = None
+            in_args = None
             entry_point = False

         return cast("AgentMessage[OutT, Any]", agent_message)
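Both workflow runners share the same pass-through behavior: external inputs (chat_inputs / in_args) reach only the first subagent, and every later subagent consumes the previous subagent's output message. A minimal sketch of that loop, assuming subagents expose the run() signature shown above:

    async def run_sequence(subagents, chat_inputs=None, in_args=None,
                           in_message=None, entry_point=False):
        agent_message = in_message
        for subagent in subagents:
            agent_message = await subagent.run(
                chat_inputs=chat_inputs,
                in_args=in_args,
                in_message=agent_message,
                entry_point=entry_point,
            )
            chat_inputs = None   # consumed by the first subagent only
            in_args = None
            entry_point = False
        return agent_message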
src/grasp_agents/workflow/workflow_agent.py
@@ -61,9 +61,10 @@ class WorkflowAgent(
     @abstractmethod
     async def run(
         self,
-        …
+        chat_inputs: Any | None = None,
         *,
-        …
+        in_args: InT | Sequence[InT] | None = None,
+        in_message: AgentMessage[InT, Any] | None = None,
         ctx: RunContextWrapper[CtxT] | None = None,
         entry_point: bool = False,
         forbid_state_change: bool = False,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{grasp_agents-0.2.3 → grasp_agents-0.2.5}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|