grasp_agents 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .llm_agent import LLMAgent
6
6
  from .llm_agent_memory import LLMAgentMemory
7
7
  from .memory import Memory
8
8
  from .packet import Packet
9
+ from .printer import Printer, print_event_stream
9
10
  from .processors.base_processor import BaseProcessor
10
11
  from .processors.parallel_processor import ParallelProcessor
11
12
  from .processors.processor import Processor
@@ -33,9 +34,11 @@ __all__ = [
33
34
  "Packet",
34
35
  "Packet",
35
36
  "ParallelProcessor",
37
+ "Printer",
36
38
  "ProcName",
37
39
  "Processor",
38
40
  "RunContext",
39
41
  "SystemMessage",
40
42
  "UserMessage",
43
+ "print_event_stream",
41
44
  ]
grasp_agents/cloud_llm.py CHANGED
@@ -69,6 +69,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
69
69
  0 # LLM response retries: try to regenerate to pass validation
70
70
  )
71
71
  apply_response_schema_via_provider: bool = False
72
+ apply_tool_call_schema_via_provider: bool = False
72
73
  async_http_client: httpx.AsyncClient | None = None
73
74
  async_http_client_params: dict[str, Any] | AsyncHTTPClientParams | None = None
74
75
 
@@ -79,6 +80,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
79
80
  f"{self.rate_limiter.rpm} RPM"
80
81
  )
81
82
 
83
+ if self.apply_response_schema_via_provider:
84
+ object.__setattr__(self, "apply_tool_call_schema_via_provider", True)
85
+
82
86
  if self.async_http_client is None and self.async_http_client_params is not None:
83
87
  object.__setattr__(
84
88
  self,
@@ -99,7 +103,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
99
103
  api_tools = None
100
104
  api_tool_choice = None
101
105
  if tools:
102
- strict = True if self.apply_response_schema_via_provider else None
106
+ strict = True if self.apply_tool_call_schema_via_provider else None
103
107
  api_tools = [
104
108
  self.converters.to_tool(t, strict=strict) for t in tools.values()
105
109
  ]
@@ -174,8 +178,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
174
178
  response_schema=response_schema,
175
179
  response_schema_by_xml_tag=response_schema_by_xml_tag,
176
180
  )
177
- if tools is not None:
178
- self._validate_tool_calls(completion, tools=tools)
181
+ if not self.apply_tool_call_schema_via_provider and tools is not None:
182
+ self._validate_tool_calls(completion, tools=tools)
179
183
 
180
184
  return completion
181
185
 
@@ -207,17 +211,16 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
207
211
 
208
212
  if n_attempt > self.max_response_retries:
209
213
  if n_attempt == 1:
210
- logger.warning(f"\nCloudLLM completion request failed:\n{err}")
214
+ logger.warning(f"\nCloudLLM completion failed:\n{err}")
211
215
  if n_attempt > 1:
212
216
  logger.warning(
213
- f"\nCloudLLM completion request failed after retrying:\n{err}"
217
+ f"\nCloudLLM completion failed after retrying:\n{err}"
214
218
  )
215
219
  raise err
216
220
  # return make_refusal_completion(self._model_name, err)
217
221
 
218
222
  logger.warning(
219
- f"\nCloudLLM completion request failed (retry attempt {n_attempt}):"
220
- f"\n{err}"
223
+ f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
221
224
  )
222
225
 
223
226
  return make_refusal_completion(
@@ -281,8 +284,8 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
281
284
  response_schema=response_schema,
282
285
  response_schema_by_xml_tag=response_schema_by_xml_tag,
283
286
  )
284
- if tools is not None:
285
- self._validate_tool_calls(completion, tools=tools)
287
+ if not self.apply_tool_call_schema_via_provider and tools is not None:
288
+ self._validate_tool_calls(completion, tools=tools)
286
289
 
287
290
  return iterator()
288
291
 
@@ -326,11 +329,10 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
326
329
  n_attempt += 1
327
330
  if n_attempt > self.max_response_retries:
328
331
  if n_attempt == 1:
329
- logger.warning(f"\nCloudLLM completion request failed:\n{err}")
332
+ logger.warning(f"\nCloudLLM completion failed:\n{err}")
330
333
  if n_attempt > 1:
331
334
  logger.warning(
332
- "\nCloudLLM completion request failed after "
333
- f"retrying:\n{err}"
335
+ f"\nCloudLLM completion failed after retrying:\n{err}"
334
336
  )
335
337
  refusal_completion = make_refusal_completion(
336
338
  self.model_name, err
@@ -344,6 +346,5 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
344
346
  # return
345
347
 
346
348
  logger.warning(
347
- "\nCloudLLM completion request failed "
348
- f"(retry attempt {n_attempt}):\n{err}"
349
+ f"\nCloudLLM completion failed (retry attempt {n_attempt}):\n{err}"
349
350
  )
@@ -159,7 +159,7 @@ class AutoInstanceAttributesMixin:
159
159
  attr_type = resolved_attr_types[attr_name]
160
160
  # attr_type = None if _attr_type is type(None) else _attr_type
161
161
  else:
162
- attr_type = Any
162
+ attr_type = object
163
163
 
164
164
  if attr_name in pyd_private:
165
165
  pyd_private[attr_name] = attr_type
@@ -74,23 +74,22 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
74
74
  )
75
75
 
76
76
  _api_provider = self.api_provider
77
-
78
- if self.model_name in litellm.get_valid_models(): # type: ignore[no-untyped-call]
77
+ try:
79
78
  _, provider_name, _, _ = litellm.get_llm_provider(self.model_name) # type: ignore[no-untyped-call]
80
79
  _api_provider = APIProvider(name=provider_name)
81
- elif self.api_provider is not None:
82
- self._lite_llm_completion_params["api_key"] = self.api_provider.get(
83
- "api_key"
84
- )
85
- self._lite_llm_completion_params["api_base"] = self.api_provider.get(
86
- "api_base"
87
- )
88
- elif self.api_provider is None:
89
- raise ValueError(
90
- f"Model '{self.model_name}' is not supported by LiteLLM and no API provider "
91
- "was specified. Please provide a valid API provider or use a different "
92
- "model."
93
- )
80
+ except Exception as exc:
81
+ if self.api_provider is not None:
82
+ self._lite_llm_completion_params["api_key"] = self.api_provider.get(
83
+ "api_key"
84
+ )
85
+ self._lite_llm_completion_params["api_base"] = self.api_provider.get(
86
+ "api_base"
87
+ )
88
+ else:
89
+ raise ValueError(
90
+ f"Failed to retrieve a LiteLLM supported API provider for model "
91
+ f"'{self.model_name}' and no custom API provider was specified."
92
+ ) from exc
94
93
 
95
94
  if self.llm_settings is not None:
96
95
  stream_options = self.llm_settings.get("stream_options") or {}
@@ -149,6 +148,9 @@ class LiteLLM(CloudLLM[LiteLLMSettings, LiteLLMConverters]):
149
148
  n_choices: int | None = None,
150
149
  **api_llm_settings: Any,
151
150
  ) -> LiteLLMCompletion:
151
+ if api_llm_settings and api_llm_settings.get("stream_options"):
152
+ api_llm_settings.pop("stream_options")
153
+
152
154
  completion = await litellm.acompletion( # type: ignore[no-untyped-call]
153
155
  model=self.model_name,
154
156
  messages=api_messages,
grasp_agents/llm_agent.py CHANGED
@@ -310,12 +310,9 @@ class LLMAgent(
310
310
  yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
311
311
 
312
312
  def _print_messages(
313
- self,
314
- messages: Sequence[Message],
315
- ctx: RunContext[CtxT],
316
- call_id: str,
313
+ self, messages: Sequence[Message], ctx: RunContext[CtxT], call_id: str
317
314
  ) -> None:
318
- if ctx and ctx.printer:
315
+ if ctx.printer:
319
316
  ctx.printer.print_messages(messages, agent_name=self.name, call_id=call_id)
320
317
 
321
318
  # -- Override these methods in subclasses if needed --
@@ -161,14 +161,11 @@ class LLMPolicyExecutor(Generic[CtxT]):
161
161
  response_schema_by_xml_tag=self.response_schema_by_xml_tag,
162
162
  tools=self.tools,
163
163
  tool_choice=tool_choice,
164
- n_choices=1,
165
164
  proc_name=self.agent_name,
166
165
  call_id=call_id,
167
166
  )
168
167
  memory.update(completion.messages)
169
- self._process_completion(
170
- completion, ctx=ctx, call_id=call_id, print_messages=True
171
- )
168
+ self._process_completion(completion, ctx=ctx, call_id=call_id)
172
169
 
173
170
  return completion.messages[0]
174
171
 
@@ -193,7 +190,6 @@ class LLMPolicyExecutor(Generic[CtxT]):
193
190
  response_schema_by_xml_tag=self.response_schema_by_xml_tag,
194
191
  tools=self.tools,
195
192
  tool_choice=tool_choice,
196
- n_choices=1,
197
193
  proc_name=self.agent_name,
198
194
  call_id=call_id,
199
195
  )
@@ -212,9 +208,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
212
208
 
213
209
  memory.update(completion.messages)
214
210
 
215
- self._process_completion(
216
- completion, print_messages=True, ctx=ctx, call_id=call_id
217
- )
211
+ self._process_completion(completion, ctx=ctx, call_id=call_id)
218
212
 
219
213
  async def call_tools(
220
214
  self,
@@ -237,7 +231,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
237
231
 
238
232
  memory.update(tool_messages)
239
233
 
240
- if ctx and ctx.printer:
234
+ if ctx.printer:
241
235
  ctx.printer.print_messages(
242
236
  tool_messages, agent_name=self.agent_name, call_id=call_id
243
237
  )
@@ -283,7 +277,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
283
277
  "Exceeded the maximum number of turns: provide a final answer now!"
284
278
  )
285
279
  memory.update([user_message])
286
- if ctx and ctx.printer:
280
+ if ctx.printer:
287
281
  ctx.printer.print_messages(
288
282
  [user_message], agent_name=self.agent_name, call_id=call_id
289
283
  )
@@ -309,7 +303,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
309
303
  yield UserMessageEvent(
310
304
  proc_name=self.agent_name, call_id=call_id, data=user_message
311
305
  )
312
- if ctx and ctx.printer:
306
+ if ctx.printer:
313
307
  ctx.printer.print_messages(
314
308
  [user_message], agent_name=self.agent_name, call_id=call_id
315
309
  )
@@ -507,12 +501,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
507
501
  return FinalAnswerTool()
508
502
 
509
503
  def _process_completion(
510
- self,
511
- completion: Completion,
512
- *,
513
- print_messages: bool = False,
514
- ctx: RunContext[CtxT],
515
- call_id: str,
504
+ self, completion: Completion, *, ctx: RunContext[CtxT], call_id: str
516
505
  ) -> None:
517
506
  ctx.completions[self.agent_name].append(completion)
518
507
  ctx.usage_tracker.update(
@@ -520,7 +509,7 @@ class LLMPolicyExecutor(Generic[CtxT]):
520
509
  completions=[completion],
521
510
  model_name=self.llm.model_name,
522
511
  )
523
- if ctx.printer and print_messages:
512
+ if ctx.printer:
524
513
  usages = [None] * (len(completion.messages) - 1) + [completion.usage]
525
514
  ctx.printer.print_messages(
526
515
  completion.messages,
@@ -172,6 +172,9 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
172
172
  response_format = api_response_schema or NOT_GIVEN
173
173
  n = n_choices or NOT_GIVEN
174
174
 
175
+ if api_llm_settings and api_llm_settings.get("stream_options"):
176
+ api_llm_settings.pop("stream_options")
177
+
175
178
  if self.apply_response_schema_via_provider:
176
179
  return await self.client.beta.chat.completions.parse(
177
180
  model=self.model_name,
grasp_agents/printer.py CHANGED
@@ -72,50 +72,61 @@ ColoringMode: TypeAlias = Literal["agent", "role"]
72
72
  CompletionBlockType: TypeAlias = Literal["response", "thinking", "tool_call"]
73
73
 
74
74
 
75
- class Printer:
76
- def __init__(
77
- self, color_by: ColoringMode = "role", msg_trunc_len: int = 20000
78
- ) -> None:
79
- self.color_by = color_by
80
- self.msg_trunc_len = msg_trunc_len
81
- self._current_message: str = ""
75
+ def stream_colored_text(new_colored_text: str) -> None:
76
+ sys.stdout.write(new_colored_text)
77
+ sys.stdout.flush()
82
78
 
83
- @staticmethod
84
- def get_role_color(role: Role) -> Color:
85
- return ROLE_TO_COLOR[role]
86
79
 
87
- @staticmethod
88
- def get_agent_color(agent_name: str) -> Color:
80
+ def get_color(
81
+ agent_name: str = "", role: Role = Role.ASSISTANT, color_by: ColoringMode = "role"
82
+ ) -> Color:
83
+ if color_by == "agent":
89
84
  idx = int(
90
85
  hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
91
86
  16,
92
87
  ) % len(AVAILABLE_COLORS)
93
88
 
94
89
  return AVAILABLE_COLORS[idx]
90
+ return ROLE_TO_COLOR[role]
91
+
95
92
 
96
- @staticmethod
97
- def content_to_str(content: Content | str | None, role: Role) -> str:
98
- if role == Role.USER and isinstance(content, Content):
99
- content_str_parts: list[str] = []
100
- for content_part in content.parts:
101
- if isinstance(content_part, ContentPartText):
102
- content_str_parts.append(content_part.data.strip(" \n"))
103
- elif content_part.data.type == "url":
104
- content_str_parts.append(str(content_part.data.url))
105
- elif content_part.data.type == "base64":
106
- content_str_parts.append("<ENCODED_IMAGE>")
107
- return "\n".join(content_str_parts)
93
+ def content_to_str(content: Content | str | None, role: Role) -> str:
94
+ if role == Role.USER and isinstance(content, Content):
95
+ content_str_parts: list[str] = []
96
+ for content_part in content.parts:
97
+ if isinstance(content_part, ContentPartText):
98
+ content_str_parts.append(content_part.data.strip(" \n"))
99
+ elif content_part.data.type == "url":
100
+ content_str_parts.append(str(content_part.data.url))
101
+ elif content_part.data.type == "base64":
102
+ content_str_parts.append("<ENCODED_IMAGE>")
103
+ return "\n".join(content_str_parts)
108
104
 
109
- assert isinstance(content, str | None)
105
+ assert isinstance(content, str | None)
110
106
 
111
- return (content or "").strip(" \n")
107
+ return (content or "").strip(" \n")
112
108
 
113
- @staticmethod
114
- def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
115
- if len(content_str) > trunc_len:
116
- return content_str[:trunc_len] + "[...]"
117
109
 
118
- return content_str
110
+ def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
111
+ if len(content_str) > trunc_len:
112
+ return content_str[:trunc_len] + "[...]"
113
+
114
+ return content_str
115
+
116
+
117
+ class Printer:
118
+ def __init__(
119
+ self,
120
+ color_by: ColoringMode = "role",
121
+ msg_trunc_len: int = 20000,
122
+ output_to: Literal["print", "log"] = "print",
123
+ logging_level: Literal["info", "debug", "warning", "error"] = "info",
124
+ ) -> None:
125
+ self.color_by: ColoringMode = color_by
126
+ self.msg_trunc_len = msg_trunc_len
127
+ self._current_message: str = ""
128
+ self._logging_level = logging_level
129
+ self._output_to = output_to
119
130
 
120
131
  def print_message(
121
132
  self,
@@ -128,11 +139,8 @@ class Printer:
128
139
  raise ValueError(
129
140
  "Usage information can only be printed for AssistantMessage"
130
141
  )
131
-
132
- color = (
133
- self.get_agent_color(agent_name)
134
- if self.color_by == "agent"
135
- else self.get_role_color(message.role)
142
+ color = get_color(
143
+ agent_name=agent_name, role=message.role, color_by=self.color_by
136
144
  )
137
145
  log_kwargs = {"extra": {"color": color}}
138
146
 
@@ -144,13 +152,13 @@ class Printer:
144
152
  out += f"<thinking>\n{thinking}\n</thinking>\n"
145
153
 
146
154
  # Content
147
- content = self.content_to_str(message.content or "", message.role)
155
+ content = content_to_str(message.content or "", message.role)
148
156
  if content:
149
157
  try:
150
158
  content = json.dumps(json.loads(content), indent=2)
151
159
  except Exception:
152
160
  pass
153
- content = self.truncate_content_str(content, trunc_len=self.msg_trunc_len)
161
+ content = truncate_content_str(content, trunc_len=self.msg_trunc_len)
154
162
  if isinstance(message, SystemMessage):
155
163
  out += f"<system>\n{content}\n</system>\n"
156
164
  elif isinstance(message, UserMessage):
@@ -176,7 +184,17 @@ class Printer:
176
184
 
177
185
  out += f"\n------------------------------------\n{usage_str}\n"
178
186
 
179
- logger.debug(out, **log_kwargs) # type: ignore
187
+ if self._output_to == "log":
188
+ if self._logging_level == "debug":
189
+ logger.debug(out, **log_kwargs) # type: ignore
190
+ elif self._logging_level == "info":
191
+ logger.info(out, **log_kwargs) # type: ignore
192
+ elif self._logging_level == "warning":
193
+ logger.warning(out, **log_kwargs) # type: ignore
194
+ else:
195
+ logger.error(out, **log_kwargs) # type: ignore
196
+ else:
197
+ stream_colored_text(colored(out + "\n", color))
180
198
 
181
199
  def print_messages(
182
200
  self,
@@ -193,28 +211,101 @@ class Printer:
193
211
  )
194
212
 
195
213
 
196
- def stream_text(new_text: str, color: Color) -> None:
197
- sys.stdout.write(colored(new_text, color))
198
- sys.stdout.flush()
199
-
200
-
201
214
  async def print_event_stream(
202
215
  event_generator: AsyncIterator[Event[Any]],
203
216
  color_by: ColoringMode = "role",
204
217
  trunc_len: int = 10000,
205
218
  ) -> AsyncIterator[Event[Any]]:
206
- color = Printer.get_role_color(Role.ASSISTANT)
219
+ def _make_chunk_text(event: CompletionChunkEvent[CompletionChunk]) -> str:
220
+ color = get_color(
221
+ agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
222
+ )
223
+ text = ""
224
+
225
+ if isinstance(event, CompletionStartEvent):
226
+ text += f"\n<{event.proc_name}> [{event.call_id}]\n"
227
+ elif isinstance(event, ThinkingStartEvent):
228
+ text += "<thinking>\n"
229
+ elif isinstance(event, ResponseStartEvent):
230
+ text += "<response>\n"
231
+ elif isinstance(event, ToolCallStartEvent):
232
+ tc = event.data.tool_call
233
+ text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
234
+ elif isinstance(event, AnnotationsStartEvent):
235
+ text += "<annotations>\n"
236
+
237
+ # if isinstance(event, CompletionEndEvent):
238
+ # text += f"\n</{event.proc_name}>\n"
239
+ if isinstance(event, ThinkingEndEvent):
240
+ text += "\n</thinking>\n"
241
+ elif isinstance(event, ResponseEndEvent):
242
+ text += "\n</response>\n"
243
+ elif isinstance(event, ToolCallEndEvent):
244
+ text += "\n</tool call>\n"
245
+ elif isinstance(event, AnnotationsEndEvent):
246
+ text += "\n</annotations>\n"
247
+
248
+ if isinstance(event, ThinkingChunkEvent):
249
+ thinking = event.data.thinking
250
+ if isinstance(thinking, str):
251
+ text += thinking
252
+ else:
253
+ text = "\n".join(
254
+ [block.get("thinking", "[redacted]") for block in thinking]
255
+ )
207
256
 
208
- def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
209
- if color_by == "agent":
210
- return Printer.get_agent_color(event.proc_name or "")
211
- return Printer.get_role_color(role)
257
+ if isinstance(event, ResponseChunkEvent):
258
+ text += event.data.response
212
259
 
213
- def _print_packet(
214
- event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
215
- ) -> None:
216
- color = _get_color(event, Role.ASSISTANT)
260
+ if isinstance(event, ToolCallChunkEvent):
261
+ text += event.data.tool_call.tool_arguments or ""
217
262
 
263
+ if isinstance(event, AnnotationsChunkEvent):
264
+ text += "\n".join(
265
+ [
266
+ json.dumps(annotation, indent=2)
267
+ for annotation in event.data.annotations
268
+ ]
269
+ )
270
+
271
+ return colored(text, color)
272
+
273
+ def _make_message_text(
274
+ event: MessageEvent[SystemMessage | UserMessage | ToolMessage],
275
+ ) -> str:
276
+ message = event.data
277
+ role = message.role
278
+ content = content_to_str(message.content, role=role)
279
+
280
+ color = get_color(
281
+ agent_name=event.proc_name or "", role=role, color_by=color_by
282
+ )
283
+ text = f"\n<{event.proc_name}> [{event.call_id}]\n"
284
+
285
+ if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
286
+ content = truncate_content_str(content, trunc_len=trunc_len)
287
+
288
+ if isinstance(event, SystemMessageEvent):
289
+ text += f"<system>\n{content}\n</system>\n"
290
+
291
+ elif isinstance(event, UserMessageEvent):
292
+ text += f"<input>\n{content}\n</input>\n"
293
+
294
+ elif isinstance(event, ToolMessageEvent):
295
+ message = event.data
296
+ try:
297
+ content = json.dumps(json.loads(content), indent=2)
298
+ except Exception:
299
+ pass
300
+ text += (
301
+ f"<tool result> [{message.tool_call_id}]\n{content}\n</tool result>\n"
302
+ )
303
+
304
+ return colored(text, color)
305
+
306
+ def _make_packet_text(
307
+ event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
308
+ ) -> str:
218
309
  if isinstance(event, WorkflowResultEvent):
219
310
  src = "workflow"
220
311
  elif isinstance(event, RunResultEvent):
@@ -222,6 +313,9 @@ async def print_event_stream(
222
313
  else:
223
314
  src = "processor"
224
315
 
316
+ color = get_color(
317
+ agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
318
+ )
225
319
  text = f"\n<{event.proc_name}> [{event.call_id}]\n"
226
320
 
227
321
  if event.data.payloads:
@@ -237,7 +331,9 @@ async def print_event_stream(
237
331
  text += f"{p_str}\n"
238
332
  text += f"</{src} output>\n"
239
333
 
240
- stream_text(text, color)
334
+ return colored(text, color)
335
+
336
+ # ------ Wrap event generator -------
241
337
 
242
338
  async for event in event_generator:
243
339
  yield event
@@ -245,91 +341,12 @@ async def print_event_stream(
245
341
  if isinstance(event, CompletionChunkEvent) and isinstance(
246
342
  event.data, CompletionChunk
247
343
  ):
248
- color = _get_color(event, Role.ASSISTANT)
249
-
250
- text = ""
251
-
252
- if isinstance(event, CompletionStartEvent):
253
- text += f"\n<{event.proc_name}> [{event.call_id}]\n"
254
- elif isinstance(event, ThinkingStartEvent):
255
- text += "<thinking>\n"
256
- elif isinstance(event, ResponseStartEvent):
257
- text += "<response>\n"
258
- elif isinstance(event, ToolCallStartEvent):
259
- tc = event.data.tool_call
260
- text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
261
- elif isinstance(event, AnnotationsStartEvent):
262
- text += "<annotations>\n"
263
-
264
- # if isinstance(event, CompletionEndEvent):
265
- # text += f"\n</{event.proc_name}>\n"
266
- if isinstance(event, ThinkingEndEvent):
267
- text += "\n</thinking>\n"
268
- elif isinstance(event, ResponseEndEvent):
269
- text += "\n</response>\n"
270
- elif isinstance(event, ToolCallEndEvent):
271
- text += "\n</tool call>\n"
272
- elif isinstance(event, AnnotationsEndEvent):
273
- text += "\n</annotations>\n"
274
-
275
- if isinstance(event, ThinkingChunkEvent):
276
- thinking = event.data.thinking
277
- if isinstance(thinking, str):
278
- text += thinking
279
- else:
280
- text = "\n".join(
281
- [block.get("thinking", "[redacted]") for block in thinking]
282
- )
283
-
284
- if isinstance(event, ResponseChunkEvent):
285
- text += event.data.response
286
-
287
- if isinstance(event, ToolCallChunkEvent):
288
- text += event.data.tool_call.tool_arguments or ""
289
-
290
- if isinstance(event, AnnotationsChunkEvent):
291
- text += "\n".join(
292
- [
293
- json.dumps(annotation, indent=2)
294
- for annotation in event.data.annotations
295
- ]
296
- )
297
-
298
- stream_text(text, color)
344
+ stream_colored_text(_make_chunk_text(event))
299
345
 
300
346
  if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
301
- assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
302
-
303
- message = event.data
304
- role = message.role
305
- content = Printer.content_to_str(message.content, role=role)
306
- color = _get_color(event, role)
307
-
308
- text = f"\n<{event.proc_name}> [{event.call_id}]\n"
309
-
310
- if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
311
- content = Printer.truncate_content_str(content, trunc_len=trunc_len)
312
-
313
- if isinstance(event, SystemMessageEvent):
314
- text += f"<system>\n{content}\n</system>\n"
315
-
316
- elif isinstance(event, UserMessageEvent):
317
- text += f"<input>\n{content}\n</input>\n"
318
-
319
- elif isinstance(event, ToolMessageEvent):
320
- message = event.data
321
- try:
322
- content = json.dumps(json.loads(content), indent=2)
323
- except Exception:
324
- pass
325
- text += (
326
- f"<tool result> [{message.tool_call_id}]\n"
327
- f"{content}\n</tool result>\n"
328
- )
329
-
330
- stream_text(text, color)
347
+ stream_colored_text(_make_message_text(event))
331
348
 
332
349
  if isinstance(
333
350
  event, (ProcPacketOutputEvent, WorkflowResultEvent, RunResultEvent)
334
351
  ):
335
- _print_packet(event)
352
+ stream_colored_text(_make_packet_text(event))
@@ -60,11 +60,13 @@ def with_retry(func: F) -> F:
60
60
  logger.warning(f"{err_message}:\n{err}")
61
61
  else:
62
62
  logger.warning(f"{err_message} after retrying:\n{err}")
63
- raise ProcRunError(proc_name=self.name, call_id=call_id) from err
63
+ # raise ProcRunError(proc_name=self.name, call_id=call_id) from err
64
+ return None # type: ignore[return]
64
65
 
65
66
  logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
66
67
  # This part should not be reachable due to the raise in the loop
67
- raise ProcRunError(proc_name=self.name, call_id=call_id)
68
+ # raise ProcRunError(proc_name=self.name, call_id=call_id)
69
+ return None # type: ignore[return]
68
70
 
69
71
  return cast("F", wrapper)
70
72
 
@@ -1,11 +1,11 @@
1
1
  from collections import defaultdict
2
- from typing import Any, Generic, TypeVar
2
+ from typing import Generic, TypeVar
3
3
 
4
4
  from pydantic import BaseModel, ConfigDict, Field
5
5
 
6
6
  from grasp_agents.typing.completion import Completion
7
7
 
8
- from .printer import ColoringMode, Printer
8
+ from .printer import Printer
9
9
  from .typing.io import ProcName
10
10
  from .usage_tracker import UsageTracker
11
11
 
@@ -19,13 +19,6 @@ class RunContext(BaseModel, Generic[CtxT]):
19
19
  default_factory=lambda: defaultdict(list)
20
20
  )
21
21
  usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
22
-
23
22
  printer: Printer | None = None
24
- log_messages: bool = False
25
- color_messages_by: ColoringMode = "role"
26
-
27
- def model_post_init(self, context: Any) -> None: # noqa: ARG002
28
- if self.log_messages:
29
- self.printer = Printer(color_by=self.color_messages_by)
30
23
 
31
24
  model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.5.11
3
+ Version: 0.5.13
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -110,24 +110,16 @@ Create a script, e.g., `problem_recommender.py`:
110
110
 
111
111
  ```python
112
112
  import asyncio
113
- from pathlib import Path
114
113
  from typing import Any
115
114
 
116
115
  from dotenv import load_dotenv
117
116
  from pydantic import BaseModel, Field
118
117
 
119
- from grasp_agents.grasp_logging import setup_logging
118
+ from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
120
119
  from grasp_agents.litellm import LiteLLM, LiteLLMSettings
121
- from grasp_agents import LLMAgent, BaseTool, RunContext
122
-
123
- load_dotenv()
124
120
 
125
121
 
126
- # Configure the logger to output to the console and/or a file
127
- setup_logging(
128
- logs_file_path="grasp_agents_demo.log",
129
- logs_config_path=Path().cwd() / "configs/logging/default.yaml",
130
- )
122
+ load_dotenv()
131
123
 
132
124
  sys_prompt_react = """
133
125
  Your task is to suggest an exciting stats problem to the student.
@@ -162,7 +154,7 @@ Returns:
162
154
  """
163
155
 
164
156
 
165
- class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
157
+ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
166
158
  name: str = "ask_student"
167
159
  description: str = ask_student_tool_description
168
160
 
@@ -176,11 +168,7 @@ class Problem(BaseModel):
176
168
 
177
169
  teacher = LLMAgent[None, Problem, None](
178
170
  name="teacher",
179
- llm=LiteLLM(
180
- model_name="gpt-4.1",
181
- # model_name="claude-sonnet-4-20250514",
182
- # llm_settings=LiteLLMSettings(reasoning_effort="low"),
183
- ),
171
+ llm=LiteLLM(model_name="gpt-4.1"),
184
172
  tools=[AskStudentTool()],
185
173
  react_mode=True,
186
174
  final_answer_as_tool_call=True,
@@ -188,7 +176,7 @@ teacher = LLMAgent[None, Problem, None](
188
176
  )
189
177
 
190
178
  async def main():
191
- ctx = RunContext[None](log_messages=True)
179
+ ctx = RunContext[None](printer=Printer())
192
180
  out = await teacher.run("start", ctx=ctx)
193
181
  print(out.payloads[0])
194
182
  print(ctx.usage_tracker.total_usage)
@@ -1,20 +1,20 @@
1
- grasp_agents/__init__.py,sha256=Z3a_j2Etiap9H6lvE8-PQP_OIGMUcHNPeJAJO12B8kY,1031
2
- grasp_agents/cloud_llm.py,sha256=vwI6gpLOsFqN4KtaTOo75xw8t7uRtdVrYGjopEDmQBw,13091
1
+ grasp_agents/__init__.py,sha256=0pRU10xjcpuPdCitYtPK_bJVSUZ89FD4Jsmv1DJ_0GY,1121
2
+ grasp_agents/cloud_llm.py,sha256=NwKr1XwpJP-px9xZDI1rngHuDmpSbtJ61VynLacbtNo,13237
3
3
  grasp_agents/costs_dict.yaml,sha256=2MFNWtkv5W5WSCcv1Cj13B1iQLVv5Ot9pS_KW2Gu2DA,2510
4
4
  grasp_agents/errors.py,sha256=K-22TCM1Klhsej47Rg5eTqnGiGPaXgKOpdOZZ7cPipw,4633
5
- grasp_agents/generics_utils.py,sha256=5Pw3I9dlnKC2VGqYKC4ZZUO3Z_vTNT-NPFovNfPkl6I,6542
5
+ grasp_agents/generics_utils.py,sha256=HTX7G8eoylR-zMKz7JDKC-QnDjwFLqMNMjiI8XoGEow,6545
6
6
  grasp_agents/grasp_logging.py,sha256=H1GYhXdQvVkmauFDZ-KDwvVmPQHZUUm9sRqX_ObK2xI,1111
7
7
  grasp_agents/http_client.py,sha256=Es8NXGDkp4Nem7g24-jW0KFGA9Hp_o2Cv3cOvjup-iU,859
8
8
  grasp_agents/llm.py,sha256=IeV2QpR4AldVP3THzSETEnsaDx3DYz5HM6dkikSpy4o,10684
9
- grasp_agents/llm_agent.py,sha256=F_ou0pfdztqZzd2yU1jZZZVzcyhsLXfE_i0c4y2fZIQ,14123
9
+ grasp_agents/llm_agent.py,sha256=Ig4YUsxPGvC444ZGl1eDIwfy46gmyygOe7mleEq0ZYM,14090
10
10
  grasp_agents/llm_agent_memory.py,sha256=XmOT2G8RG5AHd0LR3WuK7VbD-KFFfThmJnuZK2iU3Fs,1856
11
- grasp_agents/llm_policy_executor.py,sha256=r0UxwjnVzTBQqLlwvZZ_JL0wl6ZebCgxkcz6I4GdmrM,18136
11
+ grasp_agents/llm_policy_executor.py,sha256=AVUuulV4QYQG86hHrBf7ldfThAGx3MPnZJdotnxwO8U,17886
12
12
  grasp_agents/memory.py,sha256=keHuNEZNSxHT9FKpMohHOCNi7UAz_oRIc91IQEuzaWE,1162
13
13
  grasp_agents/packet.py,sha256=EmE-W4ZSMVZoqClECGFe7OGqrT4FSJ8IVGICrdjtdEY,1462
14
14
  grasp_agents/packet_pool.py,sha256=AF7ZMYY1U6ppNLEn6o0R8QXyWmcLQGcju7_TYQpAudg,4443
15
- grasp_agents/printer.py,sha256=wVNCaR9mbFKyzYdT8YpYD1JQqRqHdLtdfiZrwYxaM6Y,11132
15
+ grasp_agents/printer.py,sha256=9hw4XOoYlj6mZ3hcH04sxbQQ3rCZqixrn0RIwiXwGh8,11637
16
16
  grasp_agents/prompt_builder.py,sha256=wNPphkW8RL8501jV4Z7ncsN_sxBDR9Ax7eILLHr-OYg,6110
17
- grasp_agents/run_context.py,sha256=7qVs0T5rLvINmtlXqOoyy2Hu9xPzuFDbcVR6R93NF-0,951
17
+ grasp_agents/run_context.py,sha256=0kWvOKBzQzx6FbdtDVAoeCOmiRGssN6X4n8YPX_oLBY,687
18
18
  grasp_agents/runner.py,sha256=JL2wSKahbPYVd56NRB09cwco43sjhZPI4XYFCZyOXOA,5173
19
19
  grasp_agents/usage_tracker.py,sha256=ZQfVUUpG0C89hyPWT_JgXnjQOxoYmumcQ9t-aCfcMo8,3561
20
20
  grasp_agents/utils.py,sha256=qKmGBwrQHw1-BgqRLuGTPKGs3J_zbrpk3nxnP1iZBiQ,6152
@@ -22,7 +22,7 @@ grasp_agents/litellm/__init__.py,sha256=wD8RZBYokFDfbS9Cs7nO_zKb3w7RIVwEGj7g2D5C
22
22
  grasp_agents/litellm/completion_chunk_converters.py,sha256=J5PPxzoTBqkvKQnCoBxQxJo7Q8Xfl9cbv2GRZox8Cjo,2689
23
23
  grasp_agents/litellm/completion_converters.py,sha256=JQ7XvQwwc-biFqVMcRO61SL5VGs_SkUvAhUz1QD7EmU,2516
24
24
  grasp_agents/litellm/converters.py,sha256=XjePHii578sXP26Fyhnv0XfwJ3cNTp5PraggTsvcBXo,4778
25
- grasp_agents/litellm/lite_llm.py,sha256=2XsPB-BbM-Y2xNxsKmO0JOJOD_UYj6ndGMjfLkGPAK4,8279
25
+ grasp_agents/litellm/lite_llm.py,sha256=IXSplFiOksN-DjxkhZmB0eEGYJGr06IVR_ktE-2MVr8,8341
26
26
  grasp_agents/litellm/message_converters.py,sha256=PsGLIJEcAeEoluHIh-utEufJ_9WeMYzXkwnR-8jyULQ,2037
27
27
  grasp_agents/openai/__init__.py,sha256=xaRnblUskiLvypIhMe4NRp9dxCG-gNR7dPiugUbPbhE,4717
28
28
  grasp_agents/openai/completion_chunk_converters.py,sha256=3MnMskdlp7ycsggc1ok1XpCHaP4Us2rLYaxImPLw1eI,2573
@@ -30,9 +30,9 @@ grasp_agents/openai/completion_converters.py,sha256=UlDeQSl0AEFUS-QI5e8rrjfmXZoj
30
30
  grasp_agents/openai/content_converters.py,sha256=sMsZhoatuL_8t0IdVaGWIVZLB4nyi1ajD61GewQmeY4,2503
31
31
  grasp_agents/openai/converters.py,sha256=RKOfMbIJmfFQ7ot0RGR6wrdMbR6_L7PB0UZwxwgM88g,4691
32
32
  grasp_agents/openai/message_converters.py,sha256=fhSN81uK51EGbLyM2-f0MvPX_UBrMy7SF3JQPo-dkXg,4686
33
- grasp_agents/openai/openai_llm.py,sha256=QjxrZ4fM_FX3ncBjehUjWPCCiI62u_W2XDi7nth1WrY,9737
33
+ grasp_agents/openai/openai_llm.py,sha256=DCvoF7MM_OID2UhTkNcBLeEdf8Ej2SQ-6FhgRGe8l1E,9861
34
34
  grasp_agents/openai/tool_converters.py,sha256=rNH5t2Wir9nuy8Ei0jaxNuzDaXGqTLmLz3VyrnJhyn0,1196
35
- grasp_agents/processors/base_processor.py,sha256=BQ2k8dJY0jTMmidXZdK7JLO2YIQkmkp5boF1fT1o6uQ,10838
35
+ grasp_agents/processors/base_processor.py,sha256=2ks_jpeye0XW9k3BcAhD9XO8NctspOV2py9ZAD-dO6c,10942
36
36
  grasp_agents/processors/parallel_processor.py,sha256=BOXRlPaZ-hooz0hHctqiW_5ldR-yDPYjFxuP7fAbZCI,7911
37
37
  grasp_agents/processors/processor.py,sha256=35MtYKrKtCZZMhV-U1DXBXtCNbCvZGaiiXo_5a3tI6s,5249
38
38
  grasp_agents/rate_limiting/__init__.py,sha256=KRgtF_E7R3YfA2cpYcFcZ7wycV0pWVJ0xRQC7YhiIEQ,158
@@ -52,7 +52,7 @@ grasp_agents/workflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
52
52
  grasp_agents/workflow/looped_workflow.py,sha256=WHp9O3Za2sBVfY_BLOdvPvtY20XsjZQaWSO2-oAFvOY,6806
53
53
  grasp_agents/workflow/sequential_workflow.py,sha256=e3BIWzy_2novmEWNwIteyMbrzvl1-evHrTBE3r3SpU8,3648
54
54
  grasp_agents/workflow/workflow_processor.py,sha256=DwHz70UOTp9dkbtzH9KE5LkGcT1RdHV7Hdiby0Bu9tw,3535
55
- grasp_agents-0.5.11.dist-info/METADATA,sha256=BkVyEN63RzGsCIJCnm5S38EI2ua9NcbPmr3lRCmWPGs,7021
56
- grasp_agents-0.5.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
57
- grasp_agents-0.5.11.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
58
- grasp_agents-0.5.11.dist-info/RECORD,,
55
+ grasp_agents-0.5.13.dist-info/METADATA,sha256=3ElnYJXh0BeigwCOvFoy4TxT-znLGjEP1rn6rDJQik0,6633
56
+ grasp_agents-0.5.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
57
+ grasp_agents-0.5.13.dist-info/licenses/LICENSE.md,sha256=-nNNdWqGB8gJ2O-peFQ2Irshv5tW5pHKyTcYkwvH7CE,1201
58
+ grasp_agents-0.5.13.dist-info/RECORD,,