grasp_agents 0.5.10__py3-none-any.whl → 0.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/__init__.py +3 -0
- grasp_agents/cloud_llm.py +15 -15
- grasp_agents/generics_utils.py +1 -1
- grasp_agents/litellm/lite_llm.py +3 -0
- grasp_agents/llm_agent.py +63 -38
- grasp_agents/llm_agent_memory.py +1 -0
- grasp_agents/llm_policy_executor.py +40 -45
- grasp_agents/openai/openai_llm.py +4 -1
- grasp_agents/printer.py +153 -136
- grasp_agents/processors/base_processor.py +5 -3
- grasp_agents/processors/parallel_processor.py +2 -2
- grasp_agents/processors/processor.py +2 -2
- grasp_agents/prompt_builder.py +23 -7
- grasp_agents/run_context.py +2 -9
- grasp_agents/typing/tool.py +5 -3
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/METADATA +7 -20
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/RECORD +19 -19
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.10.dist-info → grasp_agents-0.5.12.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/printer.py
CHANGED
@@ -72,50 +72,61 @@ ColoringMode: TypeAlias = Literal["agent", "role"]
|
|
72
72
|
CompletionBlockType: TypeAlias = Literal["response", "thinking", "tool_call"]
|
73
73
|
|
74
74
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
) -> None:
|
79
|
-
self.color_by = color_by
|
80
|
-
self.msg_trunc_len = msg_trunc_len
|
81
|
-
self._current_message: str = ""
|
75
|
+
def stream_colored_text(new_colored_text: str) -> None:
|
76
|
+
sys.stdout.write(new_colored_text)
|
77
|
+
sys.stdout.flush()
|
82
78
|
|
83
|
-
@staticmethod
|
84
|
-
def get_role_color(role: Role) -> Color:
|
85
|
-
return ROLE_TO_COLOR[role]
|
86
79
|
|
87
|
-
|
88
|
-
|
80
|
+
def get_color(
|
81
|
+
agent_name: str = "", role: Role = Role.ASSISTANT, color_by: ColoringMode = "role"
|
82
|
+
) -> Color:
|
83
|
+
if color_by == "agent":
|
89
84
|
idx = int(
|
90
85
|
hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
|
91
86
|
16,
|
92
87
|
) % len(AVAILABLE_COLORS)
|
93
88
|
|
94
89
|
return AVAILABLE_COLORS[idx]
|
90
|
+
return ROLE_TO_COLOR[role]
|
91
|
+
|
95
92
|
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
return "\n".join(content_str_parts)
|
93
|
+
def content_to_str(content: Content | str | None, role: Role) -> str:
|
94
|
+
if role == Role.USER and isinstance(content, Content):
|
95
|
+
content_str_parts: list[str] = []
|
96
|
+
for content_part in content.parts:
|
97
|
+
if isinstance(content_part, ContentPartText):
|
98
|
+
content_str_parts.append(content_part.data.strip(" \n"))
|
99
|
+
elif content_part.data.type == "url":
|
100
|
+
content_str_parts.append(str(content_part.data.url))
|
101
|
+
elif content_part.data.type == "base64":
|
102
|
+
content_str_parts.append("<ENCODED_IMAGE>")
|
103
|
+
return "\n".join(content_str_parts)
|
108
104
|
|
109
|
-
|
105
|
+
assert isinstance(content, str | None)
|
110
106
|
|
111
|
-
|
107
|
+
return (content or "").strip(" \n")
|
112
108
|
|
113
|
-
@staticmethod
|
114
|
-
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
115
|
-
if len(content_str) > trunc_len:
|
116
|
-
return content_str[:trunc_len] + "[...]"
|
117
109
|
|
118
|
-
|
110
|
+
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
111
|
+
if len(content_str) > trunc_len:
|
112
|
+
return content_str[:trunc_len] + "[...]"
|
113
|
+
|
114
|
+
return content_str
|
115
|
+
|
116
|
+
|
117
|
+
class Printer:
|
118
|
+
def __init__(
|
119
|
+
self,
|
120
|
+
color_by: ColoringMode = "role",
|
121
|
+
msg_trunc_len: int = 20000,
|
122
|
+
output_to: Literal["print", "log"] = "print",
|
123
|
+
logging_level: Literal["info", "debug", "warning", "error"] = "info",
|
124
|
+
) -> None:
|
125
|
+
self.color_by: ColoringMode = color_by
|
126
|
+
self.msg_trunc_len = msg_trunc_len
|
127
|
+
self._current_message: str = ""
|
128
|
+
self._logging_level = logging_level
|
129
|
+
self._output_to = output_to
|
119
130
|
|
120
131
|
def print_message(
|
121
132
|
self,
|
@@ -128,11 +139,8 @@ class Printer:
|
|
128
139
|
raise ValueError(
|
129
140
|
"Usage information can only be printed for AssistantMessage"
|
130
141
|
)
|
131
|
-
|
132
|
-
|
133
|
-
self.get_agent_color(agent_name)
|
134
|
-
if self.color_by == "agent"
|
135
|
-
else self.get_role_color(message.role)
|
142
|
+
color = get_color(
|
143
|
+
agent_name=agent_name, role=message.role, color_by=self.color_by
|
136
144
|
)
|
137
145
|
log_kwargs = {"extra": {"color": color}}
|
138
146
|
|
@@ -144,13 +152,13 @@ class Printer:
|
|
144
152
|
out += f"<thinking>\n{thinking}\n</thinking>\n"
|
145
153
|
|
146
154
|
# Content
|
147
|
-
content =
|
155
|
+
content = content_to_str(message.content or "", message.role)
|
148
156
|
if content:
|
149
157
|
try:
|
150
158
|
content = json.dumps(json.loads(content), indent=2)
|
151
159
|
except Exception:
|
152
160
|
pass
|
153
|
-
content =
|
161
|
+
content = truncate_content_str(content, trunc_len=self.msg_trunc_len)
|
154
162
|
if isinstance(message, SystemMessage):
|
155
163
|
out += f"<system>\n{content}\n</system>\n"
|
156
164
|
elif isinstance(message, UserMessage):
|
@@ -176,7 +184,17 @@ class Printer:
|
|
176
184
|
|
177
185
|
out += f"\n------------------------------------\n{usage_str}\n"
|
178
186
|
|
179
|
-
|
187
|
+
if self._output_to == "log":
|
188
|
+
if self._logging_level == "debug":
|
189
|
+
logger.debug(out, **log_kwargs) # type: ignore
|
190
|
+
elif self._logging_level == "info":
|
191
|
+
logger.info(out, **log_kwargs) # type: ignore
|
192
|
+
elif self._logging_level == "warning":
|
193
|
+
logger.warning(out, **log_kwargs) # type: ignore
|
194
|
+
else:
|
195
|
+
logger.error(out, **log_kwargs) # type: ignore
|
196
|
+
else:
|
197
|
+
stream_colored_text(colored(out + "\n", color))
|
180
198
|
|
181
199
|
def print_messages(
|
182
200
|
self,
|
@@ -193,28 +211,101 @@ class Printer:
|
|
193
211
|
)
|
194
212
|
|
195
213
|
|
196
|
-
def stream_text(new_text: str, color: Color) -> None:
|
197
|
-
sys.stdout.write(colored(new_text, color))
|
198
|
-
sys.stdout.flush()
|
199
|
-
|
200
|
-
|
201
214
|
async def print_event_stream(
|
202
215
|
event_generator: AsyncIterator[Event[Any]],
|
203
216
|
color_by: ColoringMode = "role",
|
204
217
|
trunc_len: int = 10000,
|
205
218
|
) -> AsyncIterator[Event[Any]]:
|
206
|
-
|
219
|
+
def _make_chunk_text(event: CompletionChunkEvent[CompletionChunk]) -> str:
|
220
|
+
color = get_color(
|
221
|
+
agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
|
222
|
+
)
|
223
|
+
text = ""
|
224
|
+
|
225
|
+
if isinstance(event, CompletionStartEvent):
|
226
|
+
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
227
|
+
elif isinstance(event, ThinkingStartEvent):
|
228
|
+
text += "<thinking>\n"
|
229
|
+
elif isinstance(event, ResponseStartEvent):
|
230
|
+
text += "<response>\n"
|
231
|
+
elif isinstance(event, ToolCallStartEvent):
|
232
|
+
tc = event.data.tool_call
|
233
|
+
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
234
|
+
elif isinstance(event, AnnotationsStartEvent):
|
235
|
+
text += "<annotations>\n"
|
236
|
+
|
237
|
+
# if isinstance(event, CompletionEndEvent):
|
238
|
+
# text += f"\n</{event.proc_name}>\n"
|
239
|
+
if isinstance(event, ThinkingEndEvent):
|
240
|
+
text += "\n</thinking>\n"
|
241
|
+
elif isinstance(event, ResponseEndEvent):
|
242
|
+
text += "\n</response>\n"
|
243
|
+
elif isinstance(event, ToolCallEndEvent):
|
244
|
+
text += "\n</tool call>\n"
|
245
|
+
elif isinstance(event, AnnotationsEndEvent):
|
246
|
+
text += "\n</annotations>\n"
|
247
|
+
|
248
|
+
if isinstance(event, ThinkingChunkEvent):
|
249
|
+
thinking = event.data.thinking
|
250
|
+
if isinstance(thinking, str):
|
251
|
+
text += thinking
|
252
|
+
else:
|
253
|
+
text = "\n".join(
|
254
|
+
[block.get("thinking", "[redacted]") for block in thinking]
|
255
|
+
)
|
207
256
|
|
208
|
-
|
209
|
-
|
210
|
-
return Printer.get_agent_color(event.proc_name or "")
|
211
|
-
return Printer.get_role_color(role)
|
257
|
+
if isinstance(event, ResponseChunkEvent):
|
258
|
+
text += event.data.response
|
212
259
|
|
213
|
-
|
214
|
-
|
215
|
-
) -> None:
|
216
|
-
color = _get_color(event, Role.ASSISTANT)
|
260
|
+
if isinstance(event, ToolCallChunkEvent):
|
261
|
+
text += event.data.tool_call.tool_arguments or ""
|
217
262
|
|
263
|
+
if isinstance(event, AnnotationsChunkEvent):
|
264
|
+
text += "\n".join(
|
265
|
+
[
|
266
|
+
json.dumps(annotation, indent=2)
|
267
|
+
for annotation in event.data.annotations
|
268
|
+
]
|
269
|
+
)
|
270
|
+
|
271
|
+
return colored(text, color)
|
272
|
+
|
273
|
+
def _make_message_text(
|
274
|
+
event: MessageEvent[SystemMessage | UserMessage | ToolMessage],
|
275
|
+
) -> str:
|
276
|
+
message = event.data
|
277
|
+
role = message.role
|
278
|
+
content = content_to_str(message.content, role=role)
|
279
|
+
|
280
|
+
color = get_color(
|
281
|
+
agent_name=event.proc_name or "", role=role, color_by=color_by
|
282
|
+
)
|
283
|
+
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
284
|
+
|
285
|
+
if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
|
286
|
+
content = truncate_content_str(content, trunc_len=trunc_len)
|
287
|
+
|
288
|
+
if isinstance(event, SystemMessageEvent):
|
289
|
+
text += f"<system>\n{content}\n</system>\n"
|
290
|
+
|
291
|
+
elif isinstance(event, UserMessageEvent):
|
292
|
+
text += f"<input>\n{content}\n</input>\n"
|
293
|
+
|
294
|
+
elif isinstance(event, ToolMessageEvent):
|
295
|
+
message = event.data
|
296
|
+
try:
|
297
|
+
content = json.dumps(json.loads(content), indent=2)
|
298
|
+
except Exception:
|
299
|
+
pass
|
300
|
+
text += (
|
301
|
+
f"<tool result> [{message.tool_call_id}]\n{content}\n</tool result>\n"
|
302
|
+
)
|
303
|
+
|
304
|
+
return colored(text, color)
|
305
|
+
|
306
|
+
def _make_packet_text(
|
307
|
+
event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
|
308
|
+
) -> str:
|
218
309
|
if isinstance(event, WorkflowResultEvent):
|
219
310
|
src = "workflow"
|
220
311
|
elif isinstance(event, RunResultEvent):
|
@@ -222,6 +313,9 @@ async def print_event_stream(
|
|
222
313
|
else:
|
223
314
|
src = "processor"
|
224
315
|
|
316
|
+
color = get_color(
|
317
|
+
agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
|
318
|
+
)
|
225
319
|
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
226
320
|
|
227
321
|
if event.data.payloads:
|
@@ -237,7 +331,9 @@ async def print_event_stream(
|
|
237
331
|
text += f"{p_str}\n"
|
238
332
|
text += f"</{src} output>\n"
|
239
333
|
|
240
|
-
|
334
|
+
return colored(text, color)
|
335
|
+
|
336
|
+
# ------ Wrap event generator -------
|
241
337
|
|
242
338
|
async for event in event_generator:
|
243
339
|
yield event
|
@@ -245,91 +341,12 @@ async def print_event_stream(
|
|
245
341
|
if isinstance(event, CompletionChunkEvent) and isinstance(
|
246
342
|
event.data, CompletionChunk
|
247
343
|
):
|
248
|
-
|
249
|
-
|
250
|
-
text = ""
|
251
|
-
|
252
|
-
if isinstance(event, CompletionStartEvent):
|
253
|
-
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
254
|
-
elif isinstance(event, ThinkingStartEvent):
|
255
|
-
text += "<thinking>\n"
|
256
|
-
elif isinstance(event, ResponseStartEvent):
|
257
|
-
text += "<response>\n"
|
258
|
-
elif isinstance(event, ToolCallStartEvent):
|
259
|
-
tc = event.data.tool_call
|
260
|
-
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
261
|
-
elif isinstance(event, AnnotationsStartEvent):
|
262
|
-
text += "<annotations>\n"
|
263
|
-
|
264
|
-
# if isinstance(event, CompletionEndEvent):
|
265
|
-
# text += f"\n</{event.proc_name}>\n"
|
266
|
-
if isinstance(event, ThinkingEndEvent):
|
267
|
-
text += "\n</thinking>\n"
|
268
|
-
elif isinstance(event, ResponseEndEvent):
|
269
|
-
text += "\n</response>\n"
|
270
|
-
elif isinstance(event, ToolCallEndEvent):
|
271
|
-
text += "\n</tool call>\n"
|
272
|
-
elif isinstance(event, AnnotationsEndEvent):
|
273
|
-
text += "\n</annotations>\n"
|
274
|
-
|
275
|
-
if isinstance(event, ThinkingChunkEvent):
|
276
|
-
thinking = event.data.thinking
|
277
|
-
if isinstance(thinking, str):
|
278
|
-
text += thinking
|
279
|
-
else:
|
280
|
-
text = "\n".join(
|
281
|
-
[block.get("thinking", "[redacted]") for block in thinking]
|
282
|
-
)
|
283
|
-
|
284
|
-
if isinstance(event, ResponseChunkEvent):
|
285
|
-
text += event.data.response
|
286
|
-
|
287
|
-
if isinstance(event, ToolCallChunkEvent):
|
288
|
-
text += event.data.tool_call.tool_arguments or ""
|
289
|
-
|
290
|
-
if isinstance(event, AnnotationsChunkEvent):
|
291
|
-
text += "\n".join(
|
292
|
-
[
|
293
|
-
json.dumps(annotation, indent=2)
|
294
|
-
for annotation in event.data.annotations
|
295
|
-
]
|
296
|
-
)
|
297
|
-
|
298
|
-
stream_text(text, color)
|
344
|
+
stream_colored_text(_make_chunk_text(event))
|
299
345
|
|
300
346
|
if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
|
301
|
-
|
302
|
-
|
303
|
-
message = event.data
|
304
|
-
role = message.role
|
305
|
-
content = Printer.content_to_str(message.content, role=role)
|
306
|
-
color = _get_color(event, role)
|
307
|
-
|
308
|
-
text = f"\n<{event.proc_name}> [{event.call_id}]\n"
|
309
|
-
|
310
|
-
if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
|
311
|
-
content = Printer.truncate_content_str(content, trunc_len=trunc_len)
|
312
|
-
|
313
|
-
if isinstance(event, SystemMessageEvent):
|
314
|
-
text += f"<system>\n{content}\n</system>\n"
|
315
|
-
|
316
|
-
elif isinstance(event, UserMessageEvent):
|
317
|
-
text += f"<input>\n{content}\n</input>\n"
|
318
|
-
|
319
|
-
elif isinstance(event, ToolMessageEvent):
|
320
|
-
message = event.data
|
321
|
-
try:
|
322
|
-
content = json.dumps(json.loads(content), indent=2)
|
323
|
-
except Exception:
|
324
|
-
pass
|
325
|
-
text += (
|
326
|
-
f"<tool result> [{message.tool_call_id}]\n"
|
327
|
-
f"{content}\n</tool result>\n"
|
328
|
-
)
|
329
|
-
|
330
|
-
stream_text(text, color)
|
347
|
+
stream_colored_text(_make_message_text(event))
|
331
348
|
|
332
349
|
if isinstance(
|
333
350
|
event, (ProcPacketOutputEvent, WorkflowResultEvent, RunResultEvent)
|
334
351
|
):
|
335
|
-
|
352
|
+
stream_colored_text(_make_packet_text(event))
|
@@ -60,11 +60,13 @@ def with_retry(func: F) -> F:
|
|
60
60
|
logger.warning(f"{err_message}:\n{err}")
|
61
61
|
else:
|
62
62
|
logger.warning(f"{err_message} after retrying:\n{err}")
|
63
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
63
|
+
# raise ProcRunError(proc_name=self.name, call_id=call_id) from err
|
64
|
+
return None # type: ignore[return]
|
64
65
|
|
65
66
|
logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
|
66
67
|
# This part should not be reachable due to the raise in the loop
|
67
|
-
raise ProcRunError(proc_name=self.name, call_id=call_id)
|
68
|
+
# raise ProcRunError(proc_name=self.name, call_id=call_id)
|
69
|
+
return None # type: ignore[return]
|
68
70
|
|
69
71
|
return cast("F", wrapper)
|
70
72
|
|
@@ -106,7 +108,7 @@ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
|
|
106
108
|
|
107
109
|
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
|
108
110
|
def __call__(
|
109
|
-
self, output: _OutT_contra, ctx: RunContext[CtxT]
|
111
|
+
self, output: _OutT_contra, *, ctx: RunContext[CtxT]
|
110
112
|
) -> Sequence[ProcName] | None: ...
|
111
113
|
|
112
114
|
|
@@ -114,7 +114,7 @@ class ParallelProcessor(
|
|
114
114
|
ctx: RunContext[CtxT] | None = None,
|
115
115
|
) -> Packet[OutT]:
|
116
116
|
call_id = self._generate_call_id(call_id)
|
117
|
-
ctx = RunContext[CtxT](state=None)
|
117
|
+
ctx = ctx or RunContext[CtxT](state=None) # type: ignore
|
118
118
|
|
119
119
|
val_in_args = self._validate_inputs(
|
120
120
|
call_id=call_id,
|
@@ -223,7 +223,7 @@ class ParallelProcessor(
|
|
223
223
|
ctx: RunContext[CtxT] | None = None,
|
224
224
|
) -> AsyncIterator[Event[Any]]:
|
225
225
|
call_id = self._generate_call_id(call_id)
|
226
|
-
ctx = RunContext[CtxT](state=None)
|
226
|
+
ctx = ctx or RunContext[CtxT](state=None) # type: ignore
|
227
227
|
|
228
228
|
val_in_args = self._validate_inputs(
|
229
229
|
call_id=call_id,
|
@@ -105,7 +105,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
105
105
|
call_id: str | None = None,
|
106
106
|
ctx: RunContext[CtxT] | None = None,
|
107
107
|
) -> Packet[OutT]:
|
108
|
-
ctx = RunContext[CtxT](state=None)
|
108
|
+
ctx = ctx or RunContext[CtxT](state=None) # type: ignore
|
109
109
|
|
110
110
|
val_in_args, memory, call_id = self._preprocess(
|
111
111
|
chat_inputs=chat_inputs,
|
@@ -136,7 +136,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
|
|
136
136
|
call_id: str | None = None,
|
137
137
|
ctx: RunContext[CtxT] | None = None,
|
138
138
|
) -> AsyncIterator[Event[Any]]:
|
139
|
-
ctx = RunContext[CtxT](state=None)
|
139
|
+
ctx = ctx or RunContext[CtxT](state=None) # type: ignore
|
140
140
|
|
141
141
|
val_in_args, memory, call_id = self._preprocess(
|
142
142
|
chat_inputs=chat_inputs,
|
grasp_agents/prompt_builder.py
CHANGED
@@ -15,11 +15,22 @@ _InT_contra = TypeVar("_InT_contra", contravariant=True)
|
|
15
15
|
|
16
16
|
|
17
17
|
class SystemPromptBuilder(Protocol[CtxT]):
|
18
|
-
def __call__(
|
18
|
+
def __call__(
|
19
|
+
self,
|
20
|
+
*,
|
21
|
+
ctx: RunContext[CtxT],
|
22
|
+
call_id: str,
|
23
|
+
) -> str | None: ...
|
19
24
|
|
20
25
|
|
21
26
|
class InputContentBuilder(Protocol[_InT_contra, CtxT]):
|
22
|
-
def __call__(
|
27
|
+
def __call__(
|
28
|
+
self,
|
29
|
+
in_args: _InT_contra,
|
30
|
+
*,
|
31
|
+
ctx: RunContext[CtxT],
|
32
|
+
call_id: str,
|
33
|
+
) -> Content: ...
|
23
34
|
|
24
35
|
|
25
36
|
PromptArgumentType: TypeAlias = str | bool | int | ImageData
|
@@ -43,9 +54,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
43
54
|
self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
44
55
|
|
45
56
|
@final
|
46
|
-
def build_system_prompt(self, ctx: RunContext[CtxT]) -> str | None:
|
57
|
+
def build_system_prompt(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
|
47
58
|
if self.system_prompt_builder:
|
48
|
-
return self.system_prompt_builder(ctx=ctx)
|
59
|
+
return self.system_prompt_builder(ctx=ctx, call_id=call_id)
|
49
60
|
|
50
61
|
return self.sys_prompt
|
51
62
|
|
@@ -71,7 +82,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
71
82
|
|
72
83
|
@final
|
73
84
|
def _build_input_content(
|
74
|
-
self, in_args: InT | None, ctx: RunContext[CtxT]
|
85
|
+
self, in_args: InT | None, ctx: RunContext[CtxT], call_id: str
|
75
86
|
) -> Content:
|
76
87
|
if in_args is None and self._in_type is not type(None):
|
77
88
|
raise InputPromptBuilderError(
|
@@ -83,7 +94,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
83
94
|
|
84
95
|
val_in_args = self._validate_input_args(in_args)
|
85
96
|
if self.input_content_builder:
|
86
|
-
return self.input_content_builder(
|
97
|
+
return self.input_content_builder(
|
98
|
+
in_args=val_in_args, ctx=ctx, call_id=call_id
|
99
|
+
)
|
87
100
|
|
88
101
|
if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
|
89
102
|
val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
|
@@ -102,6 +115,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
102
115
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
103
116
|
*,
|
104
117
|
in_args: InT | None = None,
|
118
|
+
call_id: str,
|
105
119
|
ctx: RunContext[CtxT],
|
106
120
|
) -> UserMessage | None:
|
107
121
|
if chat_inputs is not None:
|
@@ -116,7 +130,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
116
130
|
return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
|
117
131
|
|
118
132
|
return UserMessage(
|
119
|
-
content=self._build_input_content(
|
133
|
+
content=self._build_input_content(
|
134
|
+
in_args=in_args, ctx=ctx, call_id=call_id
|
135
|
+
),
|
120
136
|
name=self._agent_name,
|
121
137
|
)
|
122
138
|
|
grasp_agents/run_context.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
-
from typing import
|
2
|
+
from typing import Generic, TypeVar
|
3
3
|
|
4
4
|
from pydantic import BaseModel, ConfigDict, Field
|
5
5
|
|
6
6
|
from grasp_agents.typing.completion import Completion
|
7
7
|
|
8
|
-
from .printer import
|
8
|
+
from .printer import Printer
|
9
9
|
from .typing.io import ProcName
|
10
10
|
from .usage_tracker import UsageTracker
|
11
11
|
|
@@ -19,13 +19,6 @@ class RunContext(BaseModel, Generic[CtxT]):
|
|
19
19
|
default_factory=lambda: defaultdict(list)
|
20
20
|
)
|
21
21
|
usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
|
22
|
-
|
23
22
|
printer: Printer | None = None
|
24
|
-
log_messages: bool = False
|
25
|
-
color_messages_by: ColoringMode = "role"
|
26
|
-
|
27
|
-
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
28
|
-
if self.log_messages:
|
29
|
-
self.printer = Printer(color_by=self.color_messages_by)
|
30
23
|
|
31
24
|
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
grasp_agents/typing/tool.py
CHANGED
@@ -64,20 +64,22 @@ class BaseTool(
|
|
64
64
|
self,
|
65
65
|
inp: _InT,
|
66
66
|
*,
|
67
|
-
call_id: str | None = None,
|
68
67
|
ctx: RunContext[CtxT] | None = None,
|
68
|
+
call_id: str | None = None,
|
69
69
|
) -> _OutT_co:
|
70
70
|
pass
|
71
71
|
|
72
72
|
async def __call__(
|
73
73
|
self,
|
74
74
|
*,
|
75
|
-
call_id: str | None = None,
|
76
75
|
ctx: RunContext[CtxT] | None = None,
|
76
|
+
call_id: str | None = None,
|
77
77
|
**kwargs: Any,
|
78
78
|
) -> _OutT_co:
|
79
|
+
# NOTE: validation is probably redundant here when tool inputs have been
|
80
|
+
# validated by the LLM already
|
79
81
|
input_args = TypeAdapter(self._in_type).validate_python(kwargs)
|
80
|
-
output = await self.run(input_args,
|
82
|
+
output = await self.run(input_args, ctx=ctx, call_id=call_id)
|
81
83
|
|
82
84
|
return TypeAdapter(self._out_type).validate_python(output)
|
83
85
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: grasp_agents
|
3
|
-
Version: 0.5.
|
3
|
+
Version: 0.5.12
|
4
4
|
Summary: Grasp Agents Library
|
5
5
|
License-File: LICENSE.md
|
6
6
|
Requires-Python: <4,>=3.11.4
|
@@ -110,24 +110,16 @@ Create a script, e.g., `problem_recommender.py`:
|
|
110
110
|
|
111
111
|
```python
|
112
112
|
import asyncio
|
113
|
-
from pathlib import Path
|
114
113
|
from typing import Any
|
115
114
|
|
116
115
|
from dotenv import load_dotenv
|
117
116
|
from pydantic import BaseModel, Field
|
118
117
|
|
119
|
-
from grasp_agents
|
118
|
+
from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
|
120
119
|
from grasp_agents.litellm import LiteLLM, LiteLLMSettings
|
121
|
-
from grasp_agents import LLMAgent, BaseTool, RunContext
|
122
|
-
|
123
|
-
load_dotenv()
|
124
120
|
|
125
121
|
|
126
|
-
|
127
|
-
setup_logging(
|
128
|
-
logs_file_path="grasp_agents_demo.log",
|
129
|
-
logs_config_path=Path().cwd() / "configs/logging/default.yaml",
|
130
|
-
)
|
122
|
+
load_dotenv()
|
131
123
|
|
132
124
|
sys_prompt_react = """
|
133
125
|
Your task is to suggest an exciting stats problem to the student.
|
@@ -162,13 +154,11 @@ Returns:
|
|
162
154
|
"""
|
163
155
|
|
164
156
|
|
165
|
-
class AskStudentTool(BaseTool[TeacherQuestion, StudentReply,
|
157
|
+
class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
|
166
158
|
name: str = "ask_student"
|
167
159
|
description: str = ask_student_tool_description
|
168
160
|
|
169
|
-
async def run(
|
170
|
-
self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
|
171
|
-
) -> StudentReply:
|
161
|
+
async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
|
172
162
|
return input(inp.question)
|
173
163
|
|
174
164
|
|
@@ -178,10 +168,7 @@ class Problem(BaseModel):
|
|
178
168
|
|
179
169
|
teacher = LLMAgent[None, Problem, None](
|
180
170
|
name="teacher",
|
181
|
-
llm=LiteLLM(
|
182
|
-
model_name="gpt-4.1",
|
183
|
-
llm_settings=LiteLLMSettings(temperature=0.5),
|
184
|
-
),
|
171
|
+
llm=LiteLLM(model_name="gpt-4.1"),
|
185
172
|
tools=[AskStudentTool()],
|
186
173
|
react_mode=True,
|
187
174
|
final_answer_as_tool_call=True,
|
@@ -189,7 +176,7 @@ teacher = LLMAgent[None, Problem, None](
|
|
189
176
|
)
|
190
177
|
|
191
178
|
async def main():
|
192
|
-
ctx = RunContext[None](
|
179
|
+
ctx = RunContext[None](printer=Printer())
|
193
180
|
out = await teacher.run("start", ctx=ctx)
|
194
181
|
print(out.payloads[0])
|
195
182
|
print(ctx.usage_tracker.total_usage)
|