grasp_agents 0.5.11__py3-none-any.whl → 0.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +3 -0
- grasp_agents/cloud_llm.py +15 -14
- grasp_agents/generics_utils.py +1 -1
- grasp_agents/litellm/lite_llm.py +3 -0
- grasp_agents/llm_agent.py +2 -5
- grasp_agents/llm_policy_executor.py +7 -18
- grasp_agents/openai/openai_llm.py +3 -0
- grasp_agents/printer.py +153 -136
- grasp_agents/processors/base_processor.py +4 -2
- grasp_agents/run_context.py +2 -9
- {grasp_agents-0.5.11.dist-info → grasp_agents-0.5.12.dist-info}/METADATA +6 -18
- {grasp_agents-0.5.11.dist-info → grasp_agents-0.5.12.dist-info}/RECORD +14 -14
- {grasp_agents-0.5.11.dist-info → grasp_agents-0.5.12.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.11.dist-info → grasp_agents-0.5.12.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .llm_agent import LLMAgent
|
|
6
6
|
from .llm_agent_memory import LLMAgentMemory
|
7
7
|
from .memory import Memory
|
8
8
|
from .packet import Packet
|
9
|
+
from .printer import Printer, print_event_stream
|
9
10
|
from .processors.base_processor import BaseProcessor
|
10
11
|
from .processors.parallel_processor import ParallelProcessor
|
11
12
|
from .processors.processor import Processor
|
@@ -33,9 +34,11 @@ __all__ = [
|
|
33
34
|
"Packet",
|
34
35
|
"Packet",
|
35
36
|
"ParallelProcessor",
|
37
|
+
"Printer",
|
36
38
|
"ProcName",
|
37
39
|
"Processor",
|
38
40
|
"RunContext",
|
39
41
|
"SystemMessage",
|
40
42
|
"UserMessage",
|
43
|
+
"print_event_stream",
|
41
44
|
]
|
grasp_agents/cloud_llm.py
CHANGED
@@ -69,6 +69,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
69
69
|
0 # LLM response retries: try to regenerate to pass validation
|
70
70
|
)
|
71
71
|
apply_response_schema_via_provider: bool = False
|
72
|
+
apply_tool_call_schema_via_provider: bool = False
|
72
73
|
async_http_client: httpx.AsyncClient | None = None
|
73
74
|
async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
|
74
75
|
|
@@ -79,6 +80,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
79
80
|
f"{self.rate_limiter.rpm} RPM"
|
80
81
|
)
|
81
82
|
|
83
|
+
if self.apply_response_schema_via_provider:
|
84
|
+
object.__setattr__(self, "apply_tool_call_schema_via_provider", True)
|
85
|
+
|
82
86
|
if self.async_http_client is None and self.async_http_client_params is not None:
|
83
87
|
object.__setattr__(
|
84
88
|
self,
|
@@ -99,7 +103,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
99
103
|
api_tools = None
|
100
104
|
api_tool_choice = None
|
101
105
|
if tools:
|
102
|
-
strict = True if self.
|
106
|
+
strict = True if self.apply_tool_call_schema_via_provider else None
|
103
107
|
api_tools = [
|
104
108
|
self.converters.to_tool(t, strict=strict) for t in tools.values()
|
105
109
|
]
|
@@ -174,8 +178,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
174
178
|
response_schema=response_schema,
|
175
179
|
response_schema_by_xml_tag=response_schema_by_xml_tag,
|
176
180
|
)
|
177
|
-
|
178
|
-
|
181
|
+
if not self.apply_tool_call_schema_via_provider and tools is not None:
|
182
|
+
self._validate_tool_calls(completion, tools=tools)
|
179
183
|
|
180
184
|
return completion
|
181
185
|
|
@@ -207,17 +211,16 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
207
211
|
|
208
212
|
if n_attempt > self.max_response_retries:
|
209
213
|
if n_attempt == 1:
|
210
|
-
logger.warning(f"\nCloudLLM completion
|
214
|
+
logger.warning(f"\nCloudLLM completion failed:\n{err}")
|
211
215
|
if n_attempt > 1:
|
212
216
|
logger.warning(
|
213
|
-
f"\nCloudLLM completion
|
217
|
+
f"\nCloudLLM completion failed after retrying:\n{err}"
|
214
218
|
)
|
215
219
|
raise err
|
216
220
|
# return make_refusal_completion(self._model_name, err)
|
217
221
|
|
218
222
|
logger.warning(
|
219
|
-
f"\nCloudLLM completion
|
220
|
-
f"\n{err}"
|
223
|
+
f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
|
221
224
|
)
|
222
225
|
|
223
226
|
return make_refusal_completion(
|
@@ -281,8 +284,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
281
284
|
response_schema=response_schema,
|
282
285
|
response_schema_by_xml_tag=response_schema_by_xml_tag,
|
283
286
|
)
|
284
|
-
|
285
|
-
|
287
|
+
if not self.apply_tool_call_schema_via_provider and tools is not None:
|
288
|
+
self._validate_tool_calls(completion, tools=tools)
|
286
289
|
|
287
290
|
return iterator()
|
288
291
|
|
@@ -326,11 +329,10 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
326
329
|
n_attempt += 1
|
327
330
|
if n_attempt > self.max_response_retries:
|
328
331
|
if n_attempt == 1:
|
329
|
-
logger.warning(f"\nCloudLLM completion
|
332
|
+
logger.warning(f"\nCloudLLM completion failed:\n{err}")
|
330
333
|
if n_attempt > 1:
|
331
334
|
logger.warning(
|
332
|
-
"\nCloudLLM completion
|
333
|
-
f"retrying:\n{err}"
|
335
|
+
f"\nCloudLLM completion failed after retrying:\n{err}"
|
334
336
|
)
|
335
337
|
refusal_completion = make_refusal_completion(
|
336
338
|
self.model_name, err
|
@@ -344,6 +346,5 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
344
346
|
# return
|
345
347
|
|
346
348
|
logger.warning(
|
347
|
-
"\nCloudLLM completion
|
348
|
-
f"(retry attempt {n_attempt}):\n{err}"
|
349
|
+
f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
|
349
350
|
)
|
grasp_agents/generics_utils.py
CHANGED
@@ -159,7 +159,7 @@ class AutoInstanceAttributesMixin:
|
|
159
159
|
attr_type = resolved_attr_types[attr_name]
|
160
160
|
# attr_type = None if _attr_type is type(None) else _attr_type
|
161
161
|
else:
|
162
|
-
attr_type =
|
162
|
+
attr_type = object
|
163
163
|
|
164
164
|
if attr_name in pyd_private:
|
165
165
|
pyd_private[attr_name] = attr_type
|
grasp_agents/litellm/lite_llm.py
CHANGED
@@ -149,6 +149,9 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
|
|
149
149
|
n_choices: int | None = None,
|
150
150
|
**api_llm_settings: Any,
|
151
151
|
) -> LiteLLMCompletion:
|
152
|
+
if api_llm_settings and api_llm_settings.get("stream_options"):
|
153
|
+
api_llm_settings.pop("stream_options")
|
154
|
+
|
152
155
|
completion = await litellm.acompletion( # type: ignore[no-untyped-call]
|
153
156
|
model=self.model_name,
|
154
157
|
messages=api_messages,
|
grasp_agents/llm_agent.py
CHANGED
@@ -310,12 +310,9 @@ class LLMAgent(
|
|
310
310
|
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
|
311
311
|
|
312
312
|
def _print_messages(
|
313
|
-
self,
|
314
|
-
messages: Sequence[Message],
|
315
|
-
ctx: RunContext[CtxT],
|
316
|
-
call_id: str,
|
313
|
+
self, messages: Sequence[Message], ctx: RunContext[CtxT], call_id: str
|
317
314
|
) -> None:
|
318
|
-
if ctx
|
315
|
+
if ctx.printer:
|
319
316
|
ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
|
320
317
|
|
321
318
|
# -- Override these methods in subclasses if needed --
|
@@ -161,14 +161,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
161
161
|
response_schema_by_xml_tag=self.response_schema_by_xml_tag,
|
162
162
|
tools=self.tools,
|
163
163
|
tool_choice=tool_choice,
|
164
|
-
n_choices=1,
|
165
164
|
proc_name=self.agent_name,
|
166
165
|
call_id=call_id,
|
167
166
|
)
|
168
167
|
memory.update(completion.messages)
|
169
|
-
self._process_completion(
|
170
|
-
completion, ctx=ctx, call_id=call_id, print_messages=True
|
171
|
-
)
|
168
|
+
self._process_completion(completion, ctx=ctx, call_id=call_id)
|
172
169
|
|
173
170
|
return completion.messages[0]
|
174
171
|
|
@@ -193,7 +190,6 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
193
190
|
response_schema_by_xml_tag=self.response_schema_by_xml_tag,
|
194
191
|
tools=self.tools,
|
195
192
|
tool_choice=tool_choice,
|
196
|
-
n_choices=1,
|
197
193
|
proc_name=self.agent_name,
|
198
194
|
call_id=call_id,
|
199
195
|
)
|
@@ -212,9 +208,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
212
208
|
|
213
209
|
memory.update(completion.messages)
|
214
210
|
|
215
|
-
self._process_completion(
|
216
|
-
completion, print_messages=True, ctx=ctx, call_id=call_id
|
217
|
-
)
|
211
|
+
self._process_completion(completion, ctx=ctx, call_id=call_id)
|
218
212
|
|
219
213
|
async def call_tools(
|
220
214
|
self,
|
@@ -237,7 +231,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
237
231
|
|
238
232
|
memory.update(tool_messages)
|
239
233
|
|
240
|
-
if ctx
|
234
|
+
if ctx.printer:
|
241
235
|
ctx.printer.print_messages(
|
242
236
|
tool_messages, agent_name=self.agent_name, call_id=call_id
|
243
237
|
)
|
@@ -283,7 +277,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
283
277
|
"Exceeded the maximum number of turns: provide a final answer now!"
|
284
278
|
)
|
285
279
|
memory.update([user_message])
|
286
|
-
if ctx
|
280
|
+
if ctx.printer:
|
287
281
|
ctx.printer.print_messages(
|
288
282
|
[user_message], agent_name=self.agent_name, call_id=call_id
|
289
283
|
)
|
@@ -309,7 +303,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
309
303
|
yield UserMessageEvent(
|
310
304
|
proc_name=self.agent_name, call_id=call_id, data=user_message
|
311
305
|
)
|
312
|
-
if ctx
|
306
|
+
if ctx.printer:
|
313
307
|
ctx.printer.print_messages(
|
314
308
|
[user_message], agent_name=self.agent_name, call_id=call_id
|
315
309
|
)
|
@@ -507,12 +501,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
507
501
|
return FinalAnswerTool()
|
508
502
|
|
509
503
|
def _process_completion(
|
510
|
-
self,
|
511
|
-
completion: Completion,
|
512
|
-
*,
|
513
|
-
print_messages: bool = False,
|
514
|
-
ctx: RunContext[CtxT],
|
515
|
-
call_id: str,
|
504
|
+
self, completion: Completion, *, ctx: RunContext[CtxT], call_id: str
|
516
505
|
) -> None:
|
517
506
|
ctx.completions[self.agent_name].append(completion)
|
518
507
|
ctx.usage_tracker.update(
|
@@ -520,7 +509,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
520
509
|
completions=[completion],
|
521
510
|
model_name=self.llm.model_name,
|
522
511
|
)
|
523
|
-
if ctx.printer
|
512
|
+
if ctx.printer:
|
524
513
|
usages = [None] * (len(completion.messages) - 1) + [completion.usage]
|
525
514
|
ctx.printer.print_messages(
|
526
515
|
completion.messages,
|
@@ -172,6 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
|
|
172
172
|
response_format = api_response_schema or NOT_GIVEN
|
173
173
|
n = n_choices or NOT_GIVEN
|
174
174
|
|
175
|
+
if api_llm_settings and api_llm_settings.get("stream_options"):
|
176
|
+
api_llm_settings.pop("stream_options")
|
177
|
+
|
175
178
|
if self.apply_response_schema_via_provider:
|
176
179
|
return await self.client.beta.chat.completions.parse(
|
177
180
|
model=self.model_name,
|
grasp_agents/printer.py
CHANGED
@@ -72,50 +72,61 @@ ColoringMode: TypeAlias = Literal["agent", "role"]
|
|
72
72
|
CompletionBlockType: TypeAlias = Literal["response", "thinking", "tool_call"]
|
73
73
|
|
74
74
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
) -> None:
|
79
|
-
self.color_by = color_by
|
80
|
-
self.msg_trunc_len = msg_trunc_len
|
81
|
-
self._current_message: str = ""
|
75
|
+
def stream_colored_text(new_colored_text: str) -> None:
|
76
|
+
sys.stdout.write(new_colored_text)
|
77
|
+
sys.stdout.flush()
|
82
78
|
|
83
|
-
@staticmethod
|
84
|
-
def get_role_color(role: Role) -> Color:
|
85
|
-
return ROLE_TO_COLOR[role]
|
86
79
|
|
87
|
-
|
88
|
-
|
80
|
+
def get_color(
|
81
|
+
agent_name: str = "", role: Role = Role.ASSISTANT, color_by: ColoringMode = "role"
|
82
|
+
) -> Color:
|
83
|
+
if color_by == "agent":
|
89
84
|
idx = int(
|
90
85
|
hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
|
91
86
|
16,
|
92
87
|
) % len(AVAILABLE_COLORS)
|
93
88
|
|
94
89
|
return AVAILABLE_COLORS[idx]
|
90
|
+
return ROLE_TO_COLOR[role]
|
91
|
+
|
95
92
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
return "\n".join(content_str_parts)
|
93
|
+
def content_to_str(content: Content | str | None, role: Role) -> str:
|
94
|
+
if role == Role.USER and isinstance(content, Content):
|
95
|
+
content_str_parts: list[str] = []
|
96
|
+
for content_part in content.parts:
|
97
|
+
if isinstance(content_part, ContentPartText):
|
98
|
+
content_str_parts.append(content_part.data.strip(" \n"))
|
99
|
+
elif content_part.data.type == "url":
|
100
|
+
content_str_parts.append(str(content_part.data.url))
|
101
|
+
elif content_part.data.type == "base64":
|
102
|
+
content_str_parts.append("<ENCODED_IMAGE>")
|
103
|
+
return "\n".join(content_str_parts)
|
108
104
|
|
109
|
-
|
105
|
+
assert isinstance(content, str | None)
|
110
106
|
|
111
|
-
|
107
|
+
return (content or "").strip(" \n")
|
112
108
|
|
113
|
-
@staticmethod
|
114
|
-
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
115
|
-
if len(content_str) > trunc_len:
|
116
|
-
return content_str[:trunc_len] + "[...]"
|
117
109
|
|
118
|
-
|
110
|
+
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
111
|
+
if len(content_str) > trunc_len:
|
112
|
+
return content_str[:trunc_len] + "[...]"
|
113
|
+
|
114
|
+
return content_str
|
115
|
+
|
116
|
+
|
117
|
+
class Printer:
|
118
|
+
def __init__(
|
119
|
+
self,
|
120
|
+
color_by: ColoringMode = "role",
|
121
|
+
msg_trunc_len: int = 20000,
|
122
|
+
output_to: Literal["print", "log"] = "print",
|
123
|
+
logging_level: Literal["info", "debug", "warning", "error"] = "info",
|
124
|
+
) -> None:
|
125
|
+
self.color_by: ColoringMode = color_by
|
126
|
+
self.msg_trunc_len = msg_trunc_len
|
127
|
+
self._current_message: str = ""
|
128
|
+
self._logging_level = logging_level
|
129
|
+
self._output_to = output_to
|
119
130
|
|
120
131
|
def print_message(
|
121
132
|
self,
|
@@ -128,11 +139,8 @@ class Printer:
|
|
128
139
|
raise ValueError(
|
129
140
|
"Usage information can only be printed for AssistantMessage"
|
130
141
|
)
|
131
|
-
|
132
|
-
|
133
|
-
self.get_agent_color(agent_name)
|
134
|
-
if self.color_by == "agent"
|
135
|
-
else self.get_role_color(message.role)
|
142
|
+
color = get_color(
|
143
|
+
agent_name=agent_name, role=message.role, color_by=self.color_by
|
136
144
|
)
|
137
145
|
log_kwargs = {"extra": {"color": color}}
|
138
146
|
|
@@ -144,13 +152,13 @@ class Printer:
|
|
144
152
|
out += f"<thinking>\n{thinking}\n</thinking>\n"
|
145
153
|
|
146
154
|
# Content
|
147
|
-
content =
|
155
|
+
content = content_to_str(message.content or "", message.role)
|
148
156
|
if content:
|
149
157
|
try:
|
150
158
|
content = json.dumps(json.loads(content), indent=2)
|
151
159
|
except Exception:
|
152
160
|
pass
|
153
|
-
content =
|
161
|
+
content = truncate_content_str(content, trunc_len=self.msg_trunc_len)
|
154
162
|
if isinstance(message, SystemMessage):
|
155
163
|
out += f"<system>\n{content}\n</system>\n"
|
156
164
|
elif isinstance(message, UserMessage):
|
@@ -176,7 +184,17 @@ class Printer:
|
|
176
184
|
|
177
185
|
out += f"\n------------------------------------\n{usage_str}\n"
|
178
186
|
|
179
|
-
|
187
|
+
if self._output_to == "log":
|
188
|
+
if self._logging_level == "debug":
|
189
|
+
logger.debug(out, **log_kwargs) # type: ignore
|
190
|
+
elif self._logging_level == "info":
|
191
|
+
logger.info(out, **log_kwargs) # type: ignore
|
192
|
+
elif self._logging_level == "warning":
|
193
|
+
logger.warning(out, **log_kwargs) # type: ignore
|
194
|
+
else:
|
195
|
+
logger.error(out, **log_kwargs) # type: ignore
|
196
|
+
else:
|
197
|
+
stream_colored_text(colored(out + "\n", color))
|
180
198
|
|
181
199
|
def print_messages(
|
182
200
|
self,
|
@@ -193,28 +211,101 @@ class Printer:
|
|
193
211
|
)
|
194
212
|
|
195
213
|
|
196
|
-
def stream_text(new_text: str, color: Color) -> None:
|
197
|
-
sys.stdout.write(colored(new_text, color))
|
198
|
-
sys.stdout.flush()
|
199
|
-
|
200
|
-
|
201
214
|
async def print_event_stream(
|
202
215
|
event_generator: AsyncIterator[Event[Any]],
|
203
216
|
color_by: ColoringMode = "role",
|
204
217
|
trunc_len: int = 10000,
|
205
218
|
) -> AsyncIterator[Event[Any]]:
|
206
|
-
|
219
|
+
def _make_chunk_text(event: CompletionChunkEvent[CompletionChunk]) -> str:
|
220
|
+
color = get_color(
|
221
|
+
agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
|
222
|
+
)
|
223
|
+
text = ""
|
224
|
+
|
225
|
+
if isinstance(event, CompletionStartEvent):
|
226
|
+
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
227
|
+
elif isinstance(event, ThinkingStartEvent):
|
228
|
+
text += "<thinking>\n"
|
229
|
+
elif isinstance(event, ResponseStartEvent):
|
230
|
+
text += "<response>\n"
|
231
|
+
elif isinstance(event, ToolCallStartEvent):
|
232
|
+
tc = event.data.tool_call
|
233
|
+
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
234
|
+
elif isinstance(event, AnnotationsStartEvent):
|
235
|
+
text += "<annotations>\n"
|
236
|
+
|
237
|
+
# if isinstance(event, CompletionEndEvent):
|
238
|
+
# text += f"\n</{event.proc_name}>\n"
|
239
|
+
if isinstance(event, ThinkingEndEvent):
|
240
|
+
text += "\n</thinking>\n"
|
241
|
+
elif isinstance(event, ResponseEndEvent):
|
242
|
+
text += "\n</response>\n"
|
243
|
+
elif isinstance(event, ToolCallEndEvent):
|
244
|
+
text += "\n</tool call>\n"
|
245
|
+
elif isinstance(event, AnnotationsEndEvent):
|
246
|
+
text += "\n</annotations>\n"
|
247
|
+
|
248
|
+
if isinstance(event, ThinkingChunkEvent):
|
249
|
+
thinking = event.data.thinking
|
250
|
+
if isinstance(thinking, str):
|
251
|
+
text += thinking
|
252
|
+
else:
|
253
|
+
text = "\n".join(
|
254
|
+
[block.get("thinking", "[redacted]") for block in thinking]
|
255
|
+
)
|
207
256
|
|
208
|
-
|
209
|
-
|
210
|
-
return Printer.get_agent_color(event.proc_name or "")
|
211
|
-
return Printer.get_role_color(role)
|
257
|
+
if isinstance(event, ResponseChunkEvent):
|
258
|
+
text += event.data.response
|
212
259
|
|
213
|
-
|
214
|
-
|
215
|
-
) -> None:
|
216
|
-
color = _get_color(event, Role.ASSISTANT)
|
260
|
+
if isinstance(event, ToolCallChunkEvent):
|
261
|
+
text += event.data.tool_call.tool_arguments or ""
|
217
262
|
|
263
|
+
if isinstance(event, AnnotationsChunkEvent):
|
264
|
+
text += "\n".join(
|
265
|
+
[
|
266
|
+
json.dumps(annotation, indent=2)
|
267
|
+
for annotation in event.data.annotations
|
268
|
+
]
|
269
|
+
)
|
270
|
+
|
271
|
+
return colored(text, color)
|
272
|
+
|
273
|
+
def _make_message_text(
|
274
|
+
event: MessageEvent[SystemMessage | UserMessage | ToolMessage],
|
275
|
+
) -> str:
|
276
|
+
message = event.data
|
277
|
+
role = message.role
|
278
|
+
content = content_to_str(message.content, role=role)
|
279
|
+
|
280
|
+
color = get_color(
|
281
|
+
agent_name=event.proc_name or "", role=role, color_by=color_by
|
282
|
+
)
|
283
|
+
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
284
|
+
|
285
|
+
if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
|
286
|
+
content = truncate_content_str(content, trunc_len=trunc_len)
|
287
|
+
|
288
|
+
if isinstance(event, SystemMessageEvent):
|
289
|
+
text += f"<system>\n{content}\n</system>\n"
|
290
|
+
|
291
|
+
elif isinstance(event, UserMessageEvent):
|
292
|
+
text += f"<input>\n{content}\n</input>\n"
|
293
|
+
|
294
|
+
elif isinstance(event, ToolMessageEvent):
|
295
|
+
message = event.data
|
296
|
+
try:
|
297
|
+
content = json.dumps(json.loads(content), indent=2)
|
298
|
+
except Exception:
|
299
|
+
pass
|
300
|
+
text += (
|
301
|
+
f"<tool result> [{message.tool_call_id}]\n{content}\n</tool result>\n"
|
302
|
+
)
|
303
|
+
|
304
|
+
return colored(text, color)
|
305
|
+
|
306
|
+
def _make_packet_text(
|
307
|
+
event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
|
308
|
+
) -> str:
|
218
309
|
if isinstance(event, WorkflowResultEvent):
|
219
310
|
src = "workflow"
|
220
311
|
elif isinstance(event, RunResultEvent):
|
@@ -222,6 +313,9 @@ async def print_event_stream(
|
|
222
313
|
else:
|
223
314
|
src = "processor"
|
224
315
|
|
316
|
+
color = get_color(
|
317
|
+
agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
|
318
|
+
)
|
225
319
|
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
226
320
|
|
227
321
|
if event.data.payloads:
|
@@ -237,7 +331,9 @@ async def print_event_stream(
|
|
237
331
|
text += f"{p_str}\n"
|
238
332
|
text += f"</{src} output>\n"
|
239
333
|
|
240
|
-
|
334
|
+
return colored(text, color)
|
335
|
+
|
336
|
+
# ------ Wrap event generator -------
|
241
337
|
|
242
338
|
async for event in event_generator:
|
243
339
|
yield event
|
@@ -245,91 +341,12 @@ async def print_event_stream(
|
|
245
341
|
if isinstance(event, CompletionChunkEvent) and isinstance(
|
246
342
|
event.data, CompletionChunk
|
247
343
|
):
|
248
|
-
|
249
|
-
|
250
|
-
text = ""
|
251
|
-
|
252
|
-
if isinstance(event, CompletionStartEvent):
|
253
|
-
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
254
|
-
elif isinstance(event, ThinkingStartEvent):
|
255
|
-
text += "<thinking>\n"
|
256
|
-
elif isinstance(event, ResponseStartEvent):
|
257
|
-
text += "<response>\n"
|
258
|
-
elif isinstance(event, ToolCallStartEvent):
|
259
|
-
tc = event.data.tool_call
|
260
|
-
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
261
|
-
elif isinstance(event, AnnotationsStartEvent):
|
262
|
-
text += "<annotations>\n"
|
263
|
-
|
264
|
-
# if isinstance(event, CompletionEndEvent):
|
265
|
-
# text += f"\n</{event.proc_name}>\n"
|
266
|
-
if isinstance(event, ThinkingEndEvent):
|
267
|
-
text += "\n</thinking>\n"
|
268
|
-
elif isinstance(event, ResponseEndEvent):
|
269
|
-
text += "\n</response>\n"
|
270
|
-
elif isinstance(event, ToolCallEndEvent):
|
271
|
-
text += "\n</tool call>\n"
|
272
|
-
elif isinstance(event, AnnotationsEndEvent):
|
273
|
-
text += "\n</annotations>\n"
|
274
|
-
|
275
|
-
if isinstance(event, ThinkingChunkEvent):
|
276
|
-
thinking = event.data.thinking
|
277
|
-
if isinstance(thinking, str):
|
278
|
-
text += thinking
|
279
|
-
else:
|
280
|
-
text = "\n".join(
|
281
|
-
[block.get("thinking", "[redacted]") for block in thinking]
|
282
|
-
)
|
283
|
-
|
284
|
-
if isinstance(event, ResponseChunkEvent):
|
285
|
-
text += event.data.response
|
286
|
-
|
287
|
-
if isinstance(event, ToolCallChunkEvent):
|
288
|
-
text += event.data.tool_call.tool_arguments or ""
|
289
|
-
|
290
|
-
if isinstance(event, AnnotationsChunkEvent):
|
291
|
-
text += "\n".join(
|
292
|
-
[
|
293
|
-
json.dumps(annotation, indent=2)
|
294
|
-
for annotation in event.data.annotations
|
295
|
-
]
|
296
|
-
)
|
297
|
-
|
298
|
-
stream_text(text, color)
|
344
|
+
stream_colored_text(_make_chunk_text(event))
|
299
345
|
|
300
346
|
if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
|
301
|
-
|
302
|
-
|
303
|
-
message = event.data
|
304
|
-
role = message.role
|
305
|
-
content = Printer.content_to_str(message.content, role=role)
|
306
|
-
color = _get_color(event, role)
|
307
|
-
|
308
|
-
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
309
|
-
|
310
|
-
if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
|
311
|
-
content = Printer.truncate_content_str(content, trunc_len=trunc_len)
|
312
|
-
|
313
|
-
if isinstance(event, SystemMessageEvent):
|
314
|
-
text += f"<system>\n{content}\n</system>\n"
|
315
|
-
|
316
|
-
elif isinstance(event, UserMessageEvent):
|
317
|
-
text += f"<input>\n{content}\n</input>\n"
|
318
|
-
|
319
|
-
elif isinstance(event, ToolMessageEvent):
|
320
|
-
message = event.data
|
321
|
-
try:
|
322
|
-
content = json.dumps(json.loads(content), indent=2)
|
323
|
-
except Exception:
|
324
|
-
pass
|
325
|
-
text += (
|
326
|
-
f"<tool result> [{message.tool_call_id}]\n"
|
327
|
-
f"{content}\n</tool result>\n"
|
328
|
-
)
|
329
|
-
|
330
|
-
stream_text(text, color)
|
347
|
+
stream_colored_text(_make_message_text(event))
|
331
348
|
|
332
349
|
if isinstance(
|
333
350
|
event, (ProcPacketOutputEvent, WorkflowResultEvent, RunResultEvent)
|
334
351
|
):
|
335
|
-
|
352
|
+
stream_colored_text(_make_packet_text(event))
|
@@ -60,11 +60,13 @@ def with_retry(func: F) -> F:
|
|
60
60
|
logger.warning(f"{err_message}:\n{err}")
|
61
61
|
else:
|
62
62
|
logger.warning(f"{err_message} after retrying:\n{err}")
|
63
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
63
|
+
# raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
64
|
+
return None # type: ignore[return]
|
64
65
|
|
65
66
|
logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
|
66
67
|
# This part should not be reachable due to the raise in the loop
|
67
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
68
|
+
# raise ProcRunError(proc_name=self.name, call_id=call_id)
|
69
|
+
return None # type: ignore[return]
|
68
70
|
|
69
71
|
return cast("F", wrapper)
|
70
72
|
|
grasp_agents/run_context.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
-
from typing import
|
2
|
+
from typing import Generic, TypeVar
|
3
3
|
|
4
4
|
from pydantic import BaseModel, ConfigDict, Field
|
5
5
|
|
6
6
|
from grasp_agents.typing.completion import Completion
|
7
7
|
|
8
|
-
from .printer import
|
8
|
+
from .printer import Printer
|
9
9
|
from .typing.io import ProcName
|
10
10
|
from .usage_tracker import UsageTracker
|
11
11
|
|
@@ -19,13 +19,6 @@ class RunContext(BaseModel, Generic[CtxT]):
|
|
19
19
|
default_factory=lambda: defaultdict(list)
|
20
20
|
)
|
21
21
|
usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
|
22
|
-
|
23
22
|
printer: Printer | None = None
|
24
|
-
log_messages: bool = False
|
25
|
-
color_messages_by: ColoringMode = "role"
|
26
|
-
|
27
|
-
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
28
|
-
if self.log_messages:
|
29
|
-
self.printer = Printer(color_by=self.color_messages_by)
|
30
23
|
|
31
24
|
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: grasp_agents
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.12
|
4
4
|
Summary: Grasp Agents Library
|
5
5
|
License-File: LICENSE.md
|
6
6
|
Requires-Python: <4,>=3.11.4
|
@@ -110,24 +110,16 @@ Create a script, e.g., `problem_recommender.py`:
|
|
110
110
|
|
111
111
|
```python
|
112
112
|
import asyncio
|
113
|
-
from pathlib import Path
|
114
113
|
from typing import Any
|
115
114
|
|
116
115
|
from dotenv import load_dotenv
|
117
116
|
from pydantic import BaseModel, Field
|
118
117
|
|
119
|
-
from grasp_agents
|
118
|
+
from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
|
120
119
|
from grasp_agents.litellm import LiteLLM, LiteLLMSettings
|
121
|
-
from grasp_agents import LLMAgent, BaseTool, RunContext
|
122
|
-
|
123
|
-
load_dotenv()
|
124
120
|
|
125
121
|
|
126
|
-
|
127
|
-
setup_logging(
|
128
|
-
logs_file_path="grasp_agents_demo.log",
|
129
|
-
logs_config_path=Path().cwd() / "configs/logging/default.yaml",
|
130
|
-
)
|
122
|
+
load_dotenv()
|
131
123
|
|
132
124
|
sys_prompt_react = """
|
133
125
|
Your task is to suggest an exciting stats problem to the student.
|
@@ -162,7 +154,7 @@ Returns:
|
|
162
154
|
"""
|
163
155
|
|
164
156
|
|
165
|
-
class AskStudentTool(BaseTool[TeacherQuestion, StudentReply,
|
157
|
+
class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
|
166
158
|
name: str = "ask_student"
|
167
159
|
description: str = ask_student_tool_description
|
168
160
|
|
@@ -176,11 +168,7 @@ class Problem(BaseModel):
|
|
176
168
|
|
177
169
|
teacher = LLMAgent[None, Problem, None](
|
178
170
|
name="teacher",
|
179
|
-
llm=LiteLLM(
|
180
|
-
model_name="gpt-4.1",
|
181
|
-
# model_name="claude-sonnet-4-20250514",
|
182
|
-
# llm_settings=LiteLLMSettings(reasoning_effort="low"),
|
183
|
-
),
|
171
|
+
llm=LiteLLM(model_name="gpt-4.1"),
|
184
172
|
tools=[AskStudentTool()],
|
185
173
|
react_mode=True,
|
186
174
|
final_answer_as_tool_call=True,
|
@@ -188,7 +176,7 @@ teacher = LLMAgent[None, Problem, None](
|
|
188
176
|
)
|
189
177
|
|
190
178
|
async def main():
|
191
|
-
ctx = RunContext[None](
|
179
|
+
ctx = RunContext[None](printer=Printer())
|
192
180
|
out = await teacher.run("start", ctx=ctx)
|
193
181
|
print(out.payloads[0])
|
194
182
|
print(ctx.usage_tracker.total_usage)
|
@@ -1,20 +1,20 @@
|
|
1
|
-
grasp_agents/__init__.py,sha256=
|
2
|
-
grasp_agents/cloud_llm.py,sha256=
|
1
|
+
grasp_agents/__init__.py,sha256=0pRU10xjcpuPdCitYtPK_bJVSUZ89FD4Jsmv1DJ_0GY,1121
|
2
|
+
grasp_agents/cloud_llm.py,sha256=NwKr1XwpJP-px9xZDI1rngHuDmpSbtJ61VynLacbtNo,13237
|
3
3
|
grasp_agents/costs_dict.yaml,sha256=2MFNWtkv5W5WSCcv1Cj13B1iQLVv5Ot9pS_KW2Gu2DA,2510
|
4
4
|
grasp_agents/errors.py,sha256=K-22TCM1Klhsej47Rg5eTqnGiGPaXgKOpdOZZ7cPipw,4633
|
5
|
-
grasp_agents/generics_utils.py,sha256=
|
5
|
+
grasp_agents/generics_utils.py,sha256=HTX7G8eoylR-zMKz7JDKC-QnDjwFLqMNMjiI8XoGEow,6545
|
6
6
|
grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
|
7
7
|
grasp_agents/http_client.py,sha256=Es8NXGDkp4Nem7g24-jW0KFGA9Hp_o2Cv3cOvjup-iU,859
|
8
8
|
grasp_agents/llm.py,sha256=IeV2QpR4AldVP3THzSETEnsaDx3DYz5HM6dkikSpy4o,10684
|
9
|
-
grasp_agents/llm_agent.py,sha256=
|
9
|
+
grasp_agents/llm_agent.py,sha256=Ig4YUsxPGvC444ZGl1eDIwfy46gmyygOe7mleEq0ZYM,14090
|
10
10
|
grasp_agents/llm_agent_memory.py,sha256=XmOT2G8RG5AHd0LR3WuK7VbD-KFFfThmJnuZK2iU3Fs,1856
|
11
|
-
grasp_agents/llm_policy_executor.py,sha256=
|
11
|
+
grasp_agents/llm_policy_executor.py,sha256=AVUuulV4QYQG86hHrBf7ldfThAGx3MPnZJdotnxwO8U,17886
|
12
12
|
grasp_agents/memory.py,sha256=keHuNEZNSxHT9FKpMohHOCNi7UAz_oRIc91IQEuzaWE,1162
|
13
13
|
grasp_agents/packet.py,sha256=EmE-W4ZSMVZoqClECGFe7OGqrT4FSJ8IVGICrdjtdEY,1462
|
14
14
|
grasp_agents/packet_pool.py,sha256=AF7ZMYY1U6ppNLEn6o0R8QXyWmcLQGcju7_TYQpAudg,4443
|
15
|
-
grasp_agents/printer.py,sha256=
|
15
|
+
grasp_agents/printer.py,sha256=9hw4XOoYlj6mZ3hcH04sxbQQ3rCZqixrn0RIwiXwGh8,11637
|
16
16
|
grasp_agents/prompt_builder.py,sha256=wNPphkW8RL8501jV4Z7ncsN_sxBDR9Ax7eILLHr-OYg,6110
|
17
|
-
grasp_agents/run_context.py,sha256=
|
17
|
+
grasp_agents/run_context.py,sha256=0kWvOKBzQzx6FbdtDVAoeCOmiRGssN6X4n8YPX_oLBY,687
|
18
18
|
grasp_agents/runner.py,sha256=JL2wSKahbPYVd56NRB09cwco43sjhZPI4XYFCZyOXOA,5173
|
19
19
|
grasp_agents/usage_tracker.py,sha256=ZQfVUUpG0C89hyPWT_JgXnjQOxoYmumcQ9t-aCfcMo8,3561
|
20
20
|
grasp_agents/utils.py,sha256=qKmGBwrQHw1-BgqRLuGTPKGs3J_zbrpk3nxnP1iZBiQ,6152
|
@@ -22,7 +22,7 @@ grasp_agents/litellm/__init__.py,sha256=wD8RZBYokFDfbS9Cs7nO_zKb3w7RIVwEGj7g2D5C
|
|
22
22
|
grasp_agents/litellm/completion_chunk_converters.py,sha256=J5PPxzoTBqkvKQnCoBxQxJo7Q8Xfl9cbv2GRZox8Cjo,2689
|
23
23
|
grasp_agents/litellm/completion_converters.py,sha256=JQ7XvQwwc-biFqVMcRO61SL5VGs_SkUvAhUz1QD7EmU,2516
|
24
24
|
grasp_agents/litellm/converters.py,sha256=XjePHii578sXP26Fyhnv0XfwJ3cNTp5PraggTsvcBXo,4778
|
25
|
-
grasp_agents/litellm/lite_llm.py,sha256=
|
25
|
+
grasp_agents/litellm/lite_llm.py,sha256=6Am1rWZ1LWWog8tjSSoN6sFtOqSi4roX4QNC9FLw5Gs,8403
|
26
26
|
grasp_agents/litellm/message_converters.py,sha256=PsGLIJEcAeEoluHIh-utEufJ_9WeMYzXkwnR-8jyULQ,2037
|
27
27
|
grasp_agents/openai/__init__.py,sha256=xaRnblUskiLvypIhMe4NRp9dxCG-gNR7dPiugUbPbhE,4717
|
28
28
|
grasp_agents/openai/completion_chunk_converters.py,sha256=3MnMskdlp7ycsggc1ok1XpCHaP4Us2rLYaxImPLw1eI,2573
|
@@ -30,9 +30,9 @@ grasp_agents/openai/completion_converters.py,sha256=UlDeQSl0AEFUS-QI5e8rrjfmXZoj
|
|
30
30
|
grasp_agents/openai/content_converters.py,sha256=sMsZhoatuL_8t0IdVaGWIVZLB4nyi1ajD61GewQmeY4,2503
|
31
31
|
grasp_agents/openai/converters.py,sha256=RKOfMbIJmfFQ7ot0RGR6wrdMbR6_L7PB0UZwxwgM88g,4691
|
32
32
|
grasp_agents/openai/message_converters.py,sha256=fhSN81uK51EGbLyM2-f0MvPX_UBrMy7SF3JQPo-dkXg,4686
|
33
|
-
grasp_agents/openai/openai_llm.py,sha256=
|
33
|
+
grasp_agents/openai/openai_llm.py,sha256=DCvoF7MM_OID2UhTkNcBLeEdf8Ej2SQ-6FhgRGe8l1E,9861
|
34
34
|
grasp_agents/openai/tool_converters.py,sha256=rNH5t2Wir9nuy8Ei0jaxNuzDaXGqTLmLz3VyrnJhyn0,1196
|
35
|
-
grasp_agents/processors/base_processor.py,sha256=
|
35
|
+
grasp_agents/processors/base_processor.py,sha256=2ks_jpeye0XW9k3BcAhD9XO8NctspOV2py9ZAD-dO6c,10942
|
36
36
|
grasp_agents/processors/parallel_processor.py,sha256=BOXRlPaZ-hooz0hHctqiW_5ldR-yDPYjFxuP7fAbZCI,7911
|
37
37
|
grasp_agents/processors/processor.py,sha256=35MtYKrKtCZZMhV-U1DXBXtCNbCvZGaiiXo_5a3tI6s,5249
|
38
38
|
grasp_agents/rate_limiting/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
|
@@ -52,7 +52,7 @@ grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
|
|
52
52
|
grasp_agents/workflow/looped_workflow.py,sha256=WHp9O3Za2sBVfY_BLOdvPvtY20XsjZQaWSO2-oAFvOY,6806
|
53
53
|
grasp_agents/workflow/sequential_workflow.py,sha256=e3BIWzy_2novmEWNwIteyMbrzvl1-evHrTBE3r3SpU8,3648
|
54
54
|
grasp_agents/workflow/workflow_processor.py,sha256=DwHz70UOTp9dkbtzH9KE5LkGcT1RdHV7Hdiby0Bu9tw,3535
|
55
|
-
grasp_agents-0.5.
|
56
|
-
grasp_agents-0.5.
|
57
|
-
grasp_agents-0.5.
|
58
|
-
grasp_agents-0.5.
|
55
|
+
grasp_agents-0.5.12.dist-info/METADATA,sha256=zfsWyMd6aeCxdXiTGWUM4ZuOnA3xMAbgfeLXKYpNdcI,6633
|
56
|
+
grasp_agents-0.5.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
57
|
+
grasp_agents-0.5.12.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
|
58
|
+
grasp_agents-0.5.12.dist-info/RECORD,,
|
File without changes
|
File without changes
|