grasp_agents-0.5.10.tar.gz → grasp_agents-0.5.12.tar.gz
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/PKG-INFO +7 -20
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/README.md +6 -19
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/pyproject.toml +1 -1
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/__init__.py +3 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/cloud_llm.py +15 -15
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/generics_utils.py +1 -1
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/lite_llm.py +3 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/llm_agent.py +63 -38
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/llm_agent_memory.py +1 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/llm_policy_executor.py +40 -45
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/openai_llm.py +4 -1
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/printer.py +153 -136
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/processors/base_processor.py +5 -3
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/processors/parallel_processor.py +2 -2
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/processors/processor.py +2 -2
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/prompt_builder.py +23 -7
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/run_context.py +2 -9
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/tool.py +5 -3
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/.gitignore +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/LICENSE.md +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/errors.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/__init__.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/completion_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/message_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/llm.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/completion_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/message_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/openai/tool_converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/packet.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/packet_pool.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/runner.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/completion_chunk.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/events.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/io.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/typing/message.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/utils.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/workflow/__init__.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/workflow/looped_workflow.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/workflow/sequential_workflow.py +0 -0
- {grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/workflow/workflow_processor.py +0 -0
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: grasp_agents
-Version: 0.5.10
+Version: 0.5.12
 Summary: Grasp Agents Library
 License-File: LICENSE.md
 Requires-Python: <4,>=3.11.4
@@ -110,24 +110,16 @@ Create a script, e.g., `problem_recommender.py`:
 
 ```python
 import asyncio
-from pathlib import Path
 from typing import Any
 
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
-from grasp_agents
+from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
 from grasp_agents.litellm import LiteLLM, LiteLLMSettings
-from grasp_agents import LLMAgent, BaseTool, RunContext
-
-load_dotenv()
 
 
-
-setup_logging(
-    logs_file_path="grasp_agents_demo.log",
-    logs_config_path=Path().cwd() / "configs/logging/default.yaml",
-)
+load_dotenv()
 
 sys_prompt_react = """
 Your task is to suggest an exciting stats problem to the student.
@@ -162,13 +154,11 @@ Returns:
 """
 
 
-class AskStudentTool(BaseTool[TeacherQuestion, StudentReply,
+class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
     name: str = "ask_student"
    description: str = ask_student_tool_description
 
-    async def run(
-        self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
-    ) -> StudentReply:
+    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
         return input(inp.question)
 
 
@@ -178,10 +168,7 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=LiteLLM(
-        model_name="gpt-4.1",
-        llm_settings=LiteLLMSettings(temperature=0.5),
-    ),
+    llm=LiteLLM(model_name="gpt-4.1"),
     tools=[AskStudentTool()],
     react_mode=True,
     final_answer_as_tool_call=True,
@@ -189,7 +176,7 @@ teacher = LLMAgent[None, Problem, None](
 )
 
 async def main():
-    ctx = RunContext[None](
+    ctx = RunContext[None](printer=Printer())
     out = await teacher.run("start", ctx=ctx)
     print(out.payloads[0])
     print(ctx.usage_tracker.total_usage)
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/README.md

@@ -93,24 +93,16 @@ Create a script, e.g., `problem_recommender.py`:
 
 ```python
 import asyncio
-from pathlib import Path
 from typing import Any
 
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field
 
-from grasp_agents
+from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
 from grasp_agents.litellm import LiteLLM, LiteLLMSettings
-from grasp_agents import LLMAgent, BaseTool, RunContext
-
-load_dotenv()
 
 
-
-setup_logging(
-    logs_file_path="grasp_agents_demo.log",
-    logs_config_path=Path().cwd() / "configs/logging/default.yaml",
-)
+load_dotenv()
 
 sys_prompt_react = """
 Your task is to suggest an exciting stats problem to the student.
@@ -145,13 +137,11 @@ Returns:
 """
 
 
-class AskStudentTool(BaseTool[TeacherQuestion, StudentReply,
+class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
     name: str = "ask_student"
     description: str = ask_student_tool_description
 
-    async def run(
-        self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
-    ) -> StudentReply:
+    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
         return input(inp.question)
 
 
@@ -161,10 +151,7 @@ class Problem(BaseModel):
 
 teacher = LLMAgent[None, Problem, None](
     name="teacher",
-    llm=LiteLLM(
-        model_name="gpt-4.1",
-        llm_settings=LiteLLMSettings(temperature=0.5),
-    ),
+    llm=LiteLLM(model_name="gpt-4.1"),
     tools=[AskStudentTool()],
     react_mode=True,
     final_answer_as_tool_call=True,
@@ -172,7 +159,7 @@ teacher = LLMAgent[None, Problem, None](
 )
 
 async def main():
-    ctx = RunContext[None](
+    ctx = RunContext[None](printer=Printer())
     out = await teacher.run("start", ctx=ctx)
     print(out.payloads[0])
     print(ctx.usage_tracker.total_usage)
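Read together, the README hunks above amount to the updated quickstart below. This is a hedged reconstruction: the `TeacherQuestion`, `StudentReply`, and `Problem` definitions, the full prompt text, and the `sys_prompt` wiring sit between the shown hunks, so those parts are illustrative assumptions rather than the README's exact text.

```python
import asyncio
from typing import Any

from dotenv import load_dotenv
from pydantic import BaseModel, Field

from grasp_agents import BaseTool, LLMAgent, Printer, RunContext
from grasp_agents.litellm import LiteLLM, LiteLLMSettings  # noqa: F401

load_dotenv()

sys_prompt_react = """
Your task is to suggest an exciting stats problem to the student.
"""  # shortened; the README's full ReAct prompt continues here

ask_student_tool_description = "Ask the student a question and return their reply."


# Stub I/O models: the README defines these between the shown hunks.
class TeacherQuestion(BaseModel):
    question: str = Field(..., description="The question to ask the student.")


StudentReply = str


class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
    name: str = "ask_student"
    description: str = ask_student_tool_description

    # New in 0.5.12: tools take **kwargs instead of an explicit ctx argument.
    async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
        return input(inp.question)


class Problem(BaseModel):
    problem: str


teacher = LLMAgent[None, Problem, None](
    name="teacher",
    llm=LiteLLM(model_name="gpt-4.1"),
    tools=[AskStudentTool()],
    react_mode=True,
    final_answer_as_tool_call=True,
    sys_prompt=sys_prompt_react,  # assumed wiring; not visible in the hunks
)


async def main():
    # New in 0.5.12: console output is opt-in via a Printer on the run context.
    ctx = RunContext[None](printer=Printer())
    out = await teacher.run("start", ctx=ctx)
    print(out.payloads[0])
    print(ctx.usage_tracker.total_usage)


asyncio.run(main())
```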
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/__init__.py

@@ -6,6 +6,7 @@ from .llm_agent import LLMAgent
 from .llm_agent_memory import LLMAgentMemory
 from .memory import Memory
 from .packet import Packet
+from .printer import Printer, print_event_stream
 from .processors.base_processor import BaseProcessor
 from .processors.parallel_processor import ParallelProcessor
 from .processors.processor import Processor
@@ -33,9 +34,11 @@ __all__ = [
     "Packet",
     "Packet",
     "ParallelProcessor",
+    "Printer",
     "ProcName",
     "Processor",
     "RunContext",
     "SystemMessage",
     "UserMessage",
+    "print_event_stream",
 ]
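With these re-exports, 0.5.12 code can pull the printing utilities straight from the package root. A minimal sketch; constructing `Printer()` with no arguments and attaching it to a `RunContext` is grounded in the README hunks above, while `print_event_stream`'s signature is not shown in this diff, so it is only imported here:

```python
from grasp_agents import Printer, RunContext, print_event_stream  # noqa: F401

# Attach a printer to a run context, as in the updated README quickstart.
ctx = RunContext[None](printer=Printer())
```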
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/cloud_llm.py

@@ -61,7 +61,6 @@ LLMRateLimiter = RateLimiterC[
 
 @dataclass(frozen=True)
 class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co]):
-    # Make this field keyword-only to avoid ordering issues with inherited defaulted fields
     api_provider: APIProvider | None = None
     llm_settings: SettingsT_co | None = None
     rate_limiter: LLMRateLimiter | None = None
@@ -70,6 +69,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         0  # LLM response retries: try to regenerate to pass validation
     )
     apply_response_schema_via_provider: bool = False
+    apply_tool_call_schema_via_provider: bool = False
     async_http_client: httpx.AsyncClient | None = None
     async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
 
@@ -80,6 +80,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 f"{self.rate_limiter.rpm} RPM"
             )
 
+        if self.apply_response_schema_via_provider:
+            object.__setattr__(self, "apply_tool_call_schema_via_provider", True)
+
         if self.async_http_client is None and self.async_http_client_params is not None:
             object.__setattr__(
                 self,
@@ -100,7 +103,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
         api_tools = None
         api_tool_choice = None
         if tools:
-            strict = True if self.
+            strict = True if self.apply_tool_call_schema_via_provider else None
             api_tools = [
                 self.converters.to_tool(t, strict=strict) for t in tools.values()
             ]
@@ -175,8 +178,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-
-
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return completion
 
@@ -208,17 +211,16 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
 
             if n_attempt > self.max_response_retries:
                 if n_attempt == 1:
-                    logger.warning(f"\nCloudLLM completion
+                    logger.warning(f"\nCloudLLM completion failed:\n{err}")
                 if n_attempt > 1:
                     logger.warning(
-                        f"\nCloudLLM completion
+                        f"\nCloudLLM completion failed after retrying:\n{err}"
                     )
                 raise err
                 # return make_refusal_completion(self._model_name, err)
 
             logger.warning(
-                f"\nCloudLLM completion
-                f"\n{err}"
+                f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
             )
 
             return make_refusal_completion(
@@ -282,8 +284,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
             response_schema=response_schema,
             response_schema_by_xml_tag=response_schema_by_xml_tag,
         )
-
-
+        if not self.apply_tool_call_schema_via_provider and tools is not None:
+            self._validate_tool_calls(completion, tools=tools)
 
         return iterator()
 
@@ -327,11 +329,10 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 n_attempt += 1
                 if n_attempt > self.max_response_retries:
                     if n_attempt == 1:
-                        logger.warning(f"\nCloudLLM completion
+                        logger.warning(f"\nCloudLLM completion failed:\n{err}")
                     if n_attempt > 1:
                         logger.warning(
-                            "\nCloudLLM completion
-                            f"retrying:\n{err}"
+                            f"\nCloudLLM completion failed after retrying:\n{err}"
                         )
                     refusal_completion = make_refusal_completion(
                         self.model_name, err
@@ -345,6 +346,5 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
                 # return
 
                 logger.warning(
-                    "\nCloudLLM completion
-                    f"(retry attempt {n_attempt}):\n{err}"
+                    f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
                 )
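Two behavioral threads run through these hunks: a new `apply_tool_call_schema_via_provider` flag that is force-enabled whenever `apply_response_schema_via_provider` is set, and client-side tool-call validation (`_validate_tool_calls`) that now runs only when the provider is not enforcing the schema. Because `CloudLLM` is a frozen dataclass, the flag coupling has to bypass immutability with `object.__setattr__`. A self-contained sketch of that pattern; the class below is illustrative, and only the two field names come from the diff:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SchemaFlags:
    apply_response_schema_via_provider: bool = False
    apply_tool_call_schema_via_provider: bool = False

    def __post_init__(self) -> None:
        # frozen=True blocks normal attribute assignment, so the coupled
        # default is written via object.__setattr__, as in CloudLLM.
        if self.apply_response_schema_via_provider:
            object.__setattr__(self, "apply_tool_call_schema_via_provider", True)


flags = SchemaFlags(apply_response_schema_via_provider=True)
assert flags.apply_tool_call_schema_via_provider is True
```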
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/generics_utils.py

@@ -159,7 +159,7 @@ class AutoInstanceAttributesMixin:
             attr_type = resolved_attr_types[attr_name]
             # attr_type = None if _attr_type is type(None) else _attr_type
         else:
-            attr_type =
+            attr_type = object
 
         if attr_name in pyd_private:
             pyd_private[attr_name] = attr_type
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/litellm/lite_llm.py

@@ -149,6 +149,9 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
         n_choices: int | None = None,
         **api_llm_settings: Any,
     ) -> LiteLLMCompletion:
+        if api_llm_settings and api_llm_settings.get("stream_options"):
+            api_llm_settings.pop("stream_options")
+
         completion = await litellm.acompletion(  # type: ignore[no-untyped-call]
             model=self.model_name,
             messages=api_messages,
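The added guard strips `stream_options` from pass-through settings before the non-streaming `litellm.acompletion` call, presumably because that option only applies to streaming requests. A tiny standalone sketch of the same guard; the settings values are made up:

```python
api_llm_settings = {"temperature": 0.2, "stream_options": {"include_usage": True}}

# Mirrors the added guard: drop stream_options when present, since this
# code path issues a non-streaming completion request.
if api_llm_settings and api_llm_settings.get("stream_options"):
    api_llm_settings.pop("stream_options")

assert "stream_options" not in api_llm_settings
```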
{grasp_agents-0.5.10 → grasp_agents-0.5.12}/src/grasp_agents/llm_agent.py

@@ -42,6 +42,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
         *,
         in_args: _InT_contra | None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> _OutT_co: ...
 
 
@@ -169,10 +170,15 @@ class LLMAgent(
         in_args: InT | None = None,
         sys_prompt: LLMPrompt | None = None,
         ctx: RunContext[Any],
+        call_id: str,
     ) -> None:
         if self.memory_preparator:
             return self.memory_preparator(
-                memory=memory,
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
     def _memorize_inputs(
@@ -182,8 +188,11 @@ class LLMAgent(
         chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> tuple[SystemMessage | None, UserMessage | None]:
-        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+        formatted_sys_prompt = self._prompt_builder.build_system_prompt(
+            ctx=ctx, call_id=call_id
+        )
 
         system_message: SystemMessage | None = None
         if self._reset_memory_on_run or memory.is_empty:
@@ -192,24 +201,22 @@ class LLMAgent(
             system_message = cast("SystemMessage", memory.message_history[0])
         else:
             self._prepare_memory(
-                memory=memory,
+                memory=memory,
+                in_args=in_args,
+                sys_prompt=formatted_sys_prompt,
+                ctx=ctx,
+                call_id=call_id,
             )
 
         input_message = self._prompt_builder.build_input_message(
-            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx
+            chat_inputs=chat_inputs, in_args=in_args, ctx=ctx, call_id=call_id
         )
         if input_message:
             memory.update([input_message])
 
         return system_message, input_message
 
-    def
-        self,
-        conversation: Messages,
-        *,
-        in_args: InT | None = None,
-        ctx: RunContext[CtxT],
-    ) -> OutT:
+    def parse_output_default(self, conversation: Messages) -> OutT:
         return validate_obj_from_json_or_py_string(
             str(conversation[-1].content or ""),
             schema=self._out_type,
@@ -223,15 +230,14 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         if self.output_parser:
             return self.output_parser(
-                conversation=conversation, in_args=in_args, ctx=ctx
+                conversation=conversation, in_args=in_args, ctx=ctx, call_id=call_id
             )
 
-        return self.
-            conversation=conversation, in_args=in_args, ctx=ctx
-        )
+        return self.parse_output_default(conversation)
 
     async def _process(
         self,
@@ -239,24 +245,28 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> OutT:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message],
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
         if input_message:
-            self._print_messages([input_message],
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
 
-        await self._policy_executor.execute(memory,
+        await self._policy_executor.execute(memory, ctx=ctx, call_id=call_id)
 
         return self._parse_output(
-            conversation=memory.message_history,
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
 
     async def _process_stream(
@@ -265,43 +275,44 @@ class LLMAgent(
         *,
         in_args: InT | None = None,
         memory: LLMAgentMemory,
-        call_id: str,
         ctx: RunContext[CtxT],
+        call_id: str,
     ) -> AsyncIterator[Event[Any]]:
         system_message, input_message = self._memorize_inputs(
             memory=memory,
             chat_inputs=chat_inputs,
             in_args=in_args,
             ctx=ctx,
+            call_id=call_id,
         )
         if system_message:
-            self._print_messages([system_message],
+            self._print_messages([system_message], ctx=ctx, call_id=call_id)
             yield SystemMessageEvent(
                 data=system_message, proc_name=self.name, call_id=call_id
             )
         if input_message:
-            self._print_messages([input_message],
+            self._print_messages([input_message], ctx=ctx, call_id=call_id)
             yield UserMessageEvent(
                 data=input_message, proc_name=self.name, call_id=call_id
             )
 
         async for event in self._policy_executor.execute_stream(
-            memory,
+            memory, ctx=ctx, call_id=call_id
         ):
             yield event
 
         output = self._parse_output(
-            conversation=memory.message_history,
+            conversation=memory.message_history,
+            in_args=in_args,
+            ctx=ctx,
+            call_id=call_id,
         )
         yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
 
     def _print_messages(
-        self,
-        messages: Sequence[Message],
-        call_id: str,
-        ctx: RunContext[CtxT],
+        self, messages: Sequence[Message], ctx: RunContext[CtxT], call_id: str
     ) -> None:
-        if ctx
+        if ctx.printer:
             ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
 
     # -- Override these methods in subclasses if needed --
@@ -328,31 +339,45 @@ class LLMAgent(
         if cur_cls.memory_manager is not base_cls.memory_manager:
             self._policy_executor.memory_manager = self.memory_manager
 
-    def system_prompt_builder(self, ctx: RunContext[CtxT]) -> str | None:
+    def system_prompt_builder(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
         if self._prompt_builder.system_prompt_builder is not None:
-            return self._prompt_builder.system_prompt_builder(ctx=ctx)
+            return self._prompt_builder.system_prompt_builder(ctx=ctx, call_id=call_id)
         raise NotImplementedError("System prompt builder is not implemented.")
 
-    def input_content_builder(
+    def input_content_builder(
+        self, in_args: InT, ctx: RunContext[CtxT], call_id: str
+    ) -> Content:
         if self._prompt_builder.input_content_builder is not None:
-            return self._prompt_builder.input_content_builder(
+            return self._prompt_builder.input_content_builder(
+                in_args=in_args, ctx=ctx, call_id=call_id
+            )
         raise NotImplementedError("Input content builder is not implemented.")
 
     def tool_call_loop_terminator(
-        self,
+        self,
+        conversation: Messages,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
     ) -> bool:
         if self._policy_executor.tool_call_loop_terminator is not None:
             return self._policy_executor.tool_call_loop_terminator(
-                conversation=conversation, ctx=ctx, **kwargs
+                conversation=conversation, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Tool call loop terminator is not implemented.")
 
     def memory_manager(
-        self,
+        self,
+        memory: LLMAgentMemory,
+        *,
+        ctx: RunContext[CtxT],
+        call_id: str,
+        **kwargs: Any,
    ) -> None:
         if self._policy_executor.memory_manager is not None:
             return self._policy_executor.memory_manager(
-                memory=memory, ctx=ctx, **kwargs
+                memory=memory, ctx=ctx, call_id=call_id, **kwargs
             )
         raise NotImplementedError("Memory manager is not implemented.")
 
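The through-line of the `llm_agent.py` changes is that every customization hook (output parser, memory preparator, prompt builders, tool-call loop terminator, memory manager) now receives a keyword-only `call_id`. A hedged sketch of a custom output parser against the updated `OutputParser` protocol shown above; the import path for `Messages` and the parsing logic itself are assumptions for illustration:

```python
from typing import Any

from grasp_agents import RunContext
from grasp_agents.typing.message import Messages  # assumed import path


def my_output_parser(
    conversation: Messages,
    *,
    in_args: Any | None,
    ctx: RunContext[Any],
    call_id: str,  # new in 0.5.12: identifies the specific agent call
) -> str:
    # Tag the final message content with the call id, e.g. to correlate
    # parsed outputs with printed or streamed events from the same call.
    return f"[{call_id}] {conversation[-1].content or ''}"
```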