grasp_agents 0.5.10__py3-none-any.whl → 0.5.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/printer.py CHANGED
@@ -72,50 +72,61 @@ ColoringMode: TypeAlias = Literal["agent", "role"]
72
72
  CompletionBlockType: TypeAlias = Literal["response", "thinking", "tool_call"]
73
73
 
74
74
 
75
- class Printer:
76
- def __init__(
77
- self, color_by: ColoringMode = "role", msg_trunc_len: int = 20000
78
- ) -> None:
79
- self.color_by = color_by
80
- self.msg_trunc_len = msg_trunc_len
81
- self._current_message: str = ""
75
+ def stream_colored_text(new_colored_text: str) -> None:
76
+ sys.stdout.write(new_colored_text)
77
+ sys.stdout.flush()
82
78
 
83
- @staticmethod
84
- def get_role_color(role: Role) -> Color:
85
- return ROLE_TO_COLOR[role]
86
79
 
87
- @staticmethod
88
- def get_agent_color(agent_name: str) -> Color:
80
+ def get_color(
81
+ agent_name: str = "", role: Role = Role.ASSISTANT, color_by: ColoringMode = "role"
82
+ ) -> Color:
83
+ if color_by == "agent":
89
84
  idx = int(
90
85
  hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
91
86
  16,
92
87
  ) % len(AVAILABLE_COLORS)
93
88
 
94
89
  return AVAILABLE_COLORS[idx]
90
+ return ROLE_TO_COLOR[role]
91
+
95
92
 
96
- @staticmethod
97
- def content_to_str(content: Content | str | None, role: Role) -> str:
98
- if role == Role.USER and isinstance(content, Content):
99
- content_str_parts: list[str] = []
100
- for content_part in content.parts:
101
- if isinstance(content_part, ContentPartText):
102
- content_str_parts.append(content_part.data.strip(" \n"))
103
- elif content_part.data.type == "url":
104
- content_str_parts.append(str(content_part.data.url))
105
- elif content_part.data.type == "base64":
106
- content_str_parts.append("<ENCODED_IMAGE>")
107
- return "\n".join(content_str_parts)
93
+ def content_to_str(content: Content | str | None, role: Role) -> str:
94
+ if role == Role.USER and isinstance(content, Content):
95
+ content_str_parts: list[str] = []
96
+ for content_part in content.parts:
97
+ if isinstance(content_part, ContentPartText):
98
+ content_str_parts.append(content_part.data.strip(" \n"))
99
+ elif content_part.data.type == "url":
100
+ content_str_parts.append(str(content_part.data.url))
101
+ elif content_part.data.type == "base64":
102
+ content_str_parts.append("<ENCODED_IMAGE>")
103
+ return "\n".join(content_str_parts)
108
104
 
109
- assert isinstance(content, str | None)
105
+ assert isinstance(content, str | None)
110
106
 
111
- return (content or "").strip(" \n")
107
+ return (content or "").strip(" \n")
112
108
 
113
- @staticmethod
114
- def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
115
- if len(content_str) > trunc_len:
116
- return content_str[:trunc_len] + "[...]"
117
109
 
118
- return content_str
110
+ def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
111
+ if len(content_str) > trunc_len:
112
+ return content_str[:trunc_len] + "[...]"
113
+
114
+ return content_str
115
+
116
+
117
+ class Printer:
118
+ def __init__(
119
+ self,
120
+ color_by: ColoringMode = "role",
121
+ msg_trunc_len: int = 20000,
122
+ output_to: Literal["print", "log"] = "print",
123
+ logging_level: Literal["info", "debug", "warning", "error"] = "info",
124
+ ) -> None:
125
+ self.color_by: ColoringMode = color_by
126
+ self.msg_trunc_len = msg_trunc_len
127
+ self._current_message: str = ""
128
+ self._logging_level = logging_level
129
+ self._output_to = output_to
119
130
 
120
131
  def print_message(
121
132
  self,
@@ -128,11 +139,8 @@ class Printer:
128
139
  raise ValueError(
129
140
  "Usage information can only be printed for AssistantMessage"
130
141
  )
131
-
132
- color = (
133
- self.get_agent_color(agent_name)
134
- if self.color_by == "agent"
135
- else self.get_role_color(message.role)
142
+ color = get_color(
143
+ agent_name=agent_name, role=message.role, color_by=self.color_by
136
144
  )
137
145
  log_kwargs = {"extra": {"color": color}}
138
146
 
@@ -144,13 +152,13 @@ class Printer:
144
152
  out += f"<thinking>\n{thinking}\n</thinking>\n"
145
153
 
146
154
  # Content
147
- content = self.content_to_str(message.content or "", message.role)
155
+ content = content_to_str(message.content or "", message.role)
148
156
  if content:
149
157
  try:
150
158
  content = json.dumps(json.loads(content), indent=2)
151
159
  except Exception:
152
160
  pass
153
- content = self.truncate_content_str(content, trunc_len=self.msg_trunc_len)
161
+ content = truncate_content_str(content, trunc_len=self.msg_trunc_len)
154
162
  if isinstance(message, SystemMessage):
155
163
  out += f"<system>\n{content}\n</system>\n"
156
164
  elif isinstance(message, UserMessage):
@@ -176,7 +184,17 @@ class Printer:
176
184
 
177
185
  out += f"\n------------------------------------\n{usage_str}\n"
178
186
 
179
- logger.debug(out, **log_kwargs) # type: ignore
187
+ if self._output_to == "log":
188
+ if self._logging_level == "debug":
189
+ logger.debug(out, **log_kwargs) # type: ignore
190
+ elif self._logging_level == "info":
191
+ logger.info(out, **log_kwargs) # type: ignore
192
+ elif self._logging_level == "warning":
193
+ logger.warning(out, **log_kwargs) # type: ignore
194
+ else:
195
+ logger.error(out, **log_kwargs) # type: ignore
196
+ else:
197
+ stream_colored_text(colored(out + "\n", color))
180
198
 
181
199
  def print_messages(
182
200
  self,
@@ -193,28 +211,101 @@ class Printer:
193
211
  )
194
212
 
195
213
 
196
- def stream_text(new_text: str, color: Color) -> None:
197
- sys.stdout.write(colored(new_text, color))
198
- sys.stdout.flush()
199
-
200
-
201
214
  async def print_event_stream(
202
215
  event_generator: AsyncIterator[Event[Any]],
203
216
  color_by: ColoringMode = "role",
204
217
  trunc_len: int = 10000,
205
218
  ) -> AsyncIterator[Event[Any]]:
206
- color = Printer.get_role_color(Role.ASSISTANT)
219
+ def _make_chunk_text(event: CompletionChunkEvent[CompletionChunk]) -> str:
220
+ color = get_color(
221
+ agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
222
+ )
223
+ text = ""
224
+
225
+ if isinstance(event, CompletionStartEvent):
226
+ text += f"\n<{event.proc_name}> [{event.call_id}]\n"
227
+ elif isinstance(event, ThinkingStartEvent):
228
+ text += "<thinking>\n"
229
+ elif isinstance(event, ResponseStartEvent):
230
+ text += "<response>\n"
231
+ elif isinstance(event, ToolCallStartEvent):
232
+ tc = event.data.tool_call
233
+ text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
234
+ elif isinstance(event, AnnotationsStartEvent):
235
+ text += "<annotations>\n"
236
+
237
+ # if isinstance(event, CompletionEndEvent):
238
+ # text += f"\n</{event.proc_name}>\n"
239
+ if isinstance(event, ThinkingEndEvent):
240
+ text += "\n</thinking>\n"
241
+ elif isinstance(event, ResponseEndEvent):
242
+ text += "\n</response>\n"
243
+ elif isinstance(event, ToolCallEndEvent):
244
+ text += "\n</tool call>\n"
245
+ elif isinstance(event, AnnotationsEndEvent):
246
+ text += "\n</annotations>\n"
247
+
248
+ if isinstance(event, ThinkingChunkEvent):
249
+ thinking = event.data.thinking
250
+ if isinstance(thinking, str):
251
+ text += thinking
252
+ else:
253
+ text = "\n".join(
254
+ [block.get("thinking", "[redacted]") for block in thinking]
255
+ )
207
256
 
208
- def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
209
- if color_by == "agent":
210
- return Printer.get_agent_color(event.proc_name or "")
211
- return Printer.get_role_color(role)
257
+ if isinstance(event, ResponseChunkEvent):
258
+ text += event.data.response
212
259
 
213
- def _print_packet(
214
- event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
215
- ) -> None:
216
- color = _get_color(event, Role.ASSISTANT)
260
+ if isinstance(event, ToolCallChunkEvent):
261
+ text += event.data.tool_call.tool_arguments or ""
217
262
 
263
+ if isinstance(event, AnnotationsChunkEvent):
264
+ text += "\n".join(
265
+ [
266
+ json.dumps(annotation, indent=2)
267
+ for annotation in event.data.annotations
268
+ ]
269
+ )
270
+
271
+ return colored(text, color)
272
+
273
+ def _make_message_text(
274
+ event: MessageEvent[SystemMessage | UserMessage | ToolMessage],
275
+ ) -> str:
276
+ message = event.data
277
+ role = message.role
278
+ content = content_to_str(message.content, role=role)
279
+
280
+ color = get_color(
281
+ agent_name=event.proc_name or "", role=role, color_by=color_by
282
+ )
283
+ text = f"\n<{event.proc_name}> [{event.call_id}]\n"
284
+
285
+ if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
286
+ content = truncate_content_str(content, trunc_len=trunc_len)
287
+
288
+ if isinstance(event, SystemMessageEvent):
289
+ text += f"<system>\n{content}\n</system>\n"
290
+
291
+ elif isinstance(event, UserMessageEvent):
292
+ text += f"<input>\n{content}\n</input>\n"
293
+
294
+ elif isinstance(event, ToolMessageEvent):
295
+ message = event.data
296
+ try:
297
+ content = json.dumps(json.loads(content), indent=2)
298
+ except Exception:
299
+ pass
300
+ text += (
301
+ f"<tool result> [{message.tool_call_id}]\n{content}\n</tool result>\n"
302
+ )
303
+
304
+ return colored(text, color)
305
+
306
+ def _make_packet_text(
307
+ event: ProcPacketOutputEvent | WorkflowResultEvent | RunResultEvent,
308
+ ) -> str:
218
309
  if isinstance(event, WorkflowResultEvent):
219
310
  src = "workflow"
220
311
  elif isinstance(event, RunResultEvent):
@@ -222,6 +313,9 @@ async def print_event_stream(
222
313
  else:
223
314
  src = "processor"
224
315
 
316
+ color = get_color(
317
+ agent_name=event.proc_name or "", role=Role.ASSISTANT, color_by=color_by
318
+ )
225
319
  text = f"\n<{event.proc_name}> [{event.call_id}]\n"
226
320
 
227
321
  if event.data.payloads:
@@ -237,7 +331,9 @@ async def print_event_stream(
237
331
  text += f"{p_str}\n"
238
332
  text += f"</{src} output>\n"
239
333
 
240
- stream_text(text, color)
334
+ return colored(text, color)
335
+
336
+ # ------ Wrap event generator -------
241
337
 
242
338
  async for event in event_generator:
243
339
  yield event
@@ -245,91 +341,12 @@ async def print_event_stream(
245
341
  if isinstance(event, CompletionChunkEvent) and isinstance(
246
342
  event.data, CompletionChunk
247
343
  ):
248
- color = _get_color(event, Role.ASSISTANT)
249
-
250
- text = ""
251
-
252
- if isinstance(event, CompletionStartEvent):
253
- text += f"\n<{event.proc_name}> [{event.call_id}]\n"
254
- elif isinstance(event, ThinkingStartEvent):
255
- text += "<thinking>\n"
256
- elif isinstance(event, ResponseStartEvent):
257
- text += "<response>\n"
258
- elif isinstance(event, ToolCallStartEvent):
259
- tc = event.data.tool_call
260
- text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
261
- elif isinstance(event, AnnotationsStartEvent):
262
- text += "<annotations>\n"
263
-
264
- # if isinstance(event, CompletionEndEvent):
265
- # text += f"\n</{event.proc_name}>\n"
266
- if isinstance(event, ThinkingEndEvent):
267
- text += "\n</thinking>\n"
268
- elif isinstance(event, ResponseEndEvent):
269
- text += "\n</response>\n"
270
- elif isinstance(event, ToolCallEndEvent):
271
- text += "\n</tool call>\n"
272
- elif isinstance(event, AnnotationsEndEvent):
273
- text += "\n</annotations>\n"
274
-
275
- if isinstance(event, ThinkingChunkEvent):
276
- thinking = event.data.thinking
277
- if isinstance(thinking, str):
278
- text += thinking
279
- else:
280
- text = "\n".join(
281
- [block.get("thinking", "[redacted]") for block in thinking]
282
- )
283
-
284
- if isinstance(event, ResponseChunkEvent):
285
- text += event.data.response
286
-
287
- if isinstance(event, ToolCallChunkEvent):
288
- text += event.data.tool_call.tool_arguments or ""
289
-
290
- if isinstance(event, AnnotationsChunkEvent):
291
- text += "\n".join(
292
- [
293
- json.dumps(annotation, indent=2)
294
- for annotation in event.data.annotations
295
- ]
296
- )
297
-
298
- stream_text(text, color)
344
+ stream_colored_text(_make_chunk_text(event))
299
345
 
300
346
  if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
301
- assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
302
-
303
- message = event.data
304
- role = message.role
305
- content = Printer.content_to_str(message.content, role=role)
306
- color = _get_color(event, role)
307
-
308
- text = f"\n<{event.proc_name}> [{event.call_id}]\n"
309
-
310
- if isinstance(event, (SystemMessageEvent, UserMessageEvent)):
311
- content = Printer.truncate_content_str(content, trunc_len=trunc_len)
312
-
313
- if isinstance(event, SystemMessageEvent):
314
- text += f"<system>\n{content}\n</system>\n"
315
-
316
- elif isinstance(event, UserMessageEvent):
317
- text += f"<input>\n{content}\n</input>\n"
318
-
319
- elif isinstance(event, ToolMessageEvent):
320
- message = event.data
321
- try:
322
- content = json.dumps(json.loads(content), indent=2)
323
- except Exception:
324
- pass
325
- text += (
326
- f"<tool result> [{message.tool_call_id}]\n"
327
- f"{content}\n</tool result>\n"
328
- )
329
-
330
- stream_text(text, color)
347
+ stream_colored_text(_make_message_text(event))
331
348
 
332
349
  if isinstance(
333
350
  event, (ProcPacketOutputEvent, WorkflowResultEvent, RunResultEvent)
334
351
  ):
335
- _print_packet(event)
352
+ stream_colored_text(_make_packet_text(event))
@@ -60,11 +60,13 @@ def with_retry(func: F) -> F:
60
60
  logger.warning(f"{err_message}:\n{err}")
61
61
  else:
62
62
  logger.warning(f"{err_message} after retrying:\n{err}")
63
- raise ProcRunError(proc_name=self.name, call_id=call_id) from err
63
+ # raise ProcRunError(proc_name=self.name, call_id=call_id) from err
64
+ return None # type: ignore[return]
64
65
 
65
66
  logger.warning(f"{err_message} (retry attempt {n_attempt + 1}):\n{err}")
66
67
  # This part should not be reachable due to the raise in the loop
67
- raise ProcRunError(proc_name=self.name, call_id=call_id)
68
+ # raise ProcRunError(proc_name=self.name, call_id=call_id)
69
+ return None # type: ignore[return]
68
70
 
69
71
  return cast("F", wrapper)
70
72
 
@@ -106,7 +108,7 @@ _OutT_contra = TypeVar("_OutT_contra", contravariant=True)
106
108
 
107
109
  class RecipientSelector(Protocol[_OutT_contra, CtxT]):
108
110
  def __call__(
109
- self, output: _OutT_contra, ctx: RunContext[CtxT]
111
+ self, output: _OutT_contra, *, ctx: RunContext[CtxT]
110
112
  ) -> Sequence[ProcName] | None: ...
111
113
 
112
114
 
@@ -114,7 +114,7 @@ class ParallelProcessor(
114
114
  ctx: RunContext[CtxT] | None = None,
115
115
  ) -> Packet[OutT]:
116
116
  call_id = self._generate_call_id(call_id)
117
- ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
117
+ ctx = ctx or RunContext[CtxT](state=None) # type: ignore
118
118
 
119
119
  val_in_args = self._validate_inputs(
120
120
  call_id=call_id,
@@ -223,7 +223,7 @@ class ParallelProcessor(
223
223
  ctx: RunContext[CtxT] | None = None,
224
224
  ) -> AsyncIterator[Event[Any]]:
225
225
  call_id = self._generate_call_id(call_id)
226
- ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
226
+ ctx = ctx or RunContext[CtxT](state=None) # type: ignore
227
227
 
228
228
  val_in_args = self._validate_inputs(
229
229
  call_id=call_id,
@@ -105,7 +105,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
105
105
  call_id: str | None = None,
106
106
  ctx: RunContext[CtxT] | None = None,
107
107
  ) -> Packet[OutT]:
108
- ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
108
+ ctx = ctx or RunContext[CtxT](state=None) # type: ignore
109
109
 
110
110
  val_in_args, memory, call_id = self._preprocess(
111
111
  chat_inputs=chat_inputs,
@@ -136,7 +136,7 @@ class Processor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, C
136
136
  call_id: str | None = None,
137
137
  ctx: RunContext[CtxT] | None = None,
138
138
  ) -> AsyncIterator[Event[Any]]:
139
- ctx = RunContext[CtxT](state=None) if ctx is None else ctx # type: ignore
139
+ ctx = ctx or RunContext[CtxT](state=None) # type: ignore
140
140
 
141
141
  val_in_args, memory, call_id = self._preprocess(
142
142
  chat_inputs=chat_inputs,
@@ -15,11 +15,22 @@ _InT_contra = TypeVar("_InT_contra", contravariant=True)
15
15
 
16
16
 
17
17
  class SystemPromptBuilder(Protocol[CtxT]):
18
- def __call__(self, ctx: RunContext[CtxT]) -> str | None: ...
18
+ def __call__(
19
+ self,
20
+ *,
21
+ ctx: RunContext[CtxT],
22
+ call_id: str,
23
+ ) -> str | None: ...
19
24
 
20
25
 
21
26
  class InputContentBuilder(Protocol[_InT_contra, CtxT]):
22
- def __call__(self, in_args: _InT_contra, ctx: RunContext[CtxT]) -> Content: ...
27
+ def __call__(
28
+ self,
29
+ in_args: _InT_contra,
30
+ *,
31
+ ctx: RunContext[CtxT],
32
+ call_id: str,
33
+ ) -> Content: ...
23
34
 
24
35
 
25
36
  PromptArgumentType: TypeAlias = str | bool | int | ImageData
@@ -43,9 +54,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
43
54
  self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
44
55
 
45
56
  @final
46
- def build_system_prompt(self, ctx: RunContext[CtxT]) -> str | None:
57
+ def build_system_prompt(self, ctx: RunContext[CtxT], call_id: str) -> str | None:
47
58
  if self.system_prompt_builder:
48
- return self.system_prompt_builder(ctx=ctx)
59
+ return self.system_prompt_builder(ctx=ctx, call_id=call_id)
49
60
 
50
61
  return self.sys_prompt
51
62
 
@@ -71,7 +82,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
71
82
 
72
83
  @final
73
84
  def _build_input_content(
74
- self, in_args: InT | None, ctx: RunContext[CtxT]
85
+ self, in_args: InT | None, ctx: RunContext[CtxT], call_id: str
75
86
  ) -> Content:
76
87
  if in_args is None and self._in_type is not type(None):
77
88
  raise InputPromptBuilderError(
@@ -83,7 +94,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
83
94
 
84
95
  val_in_args = self._validate_input_args(in_args)
85
96
  if self.input_content_builder:
86
- return self.input_content_builder(in_args=val_in_args, ctx=ctx)
97
+ return self.input_content_builder(
98
+ in_args=val_in_args, ctx=ctx, call_id=call_id
99
+ )
87
100
 
88
101
  if issubclass(self._in_type, BaseModel) and isinstance(val_in_args, BaseModel):
89
102
  val_in_args_map = self._format_pydantic_prompt_args(val_in_args)
@@ -102,6 +115,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
102
115
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
103
116
  *,
104
117
  in_args: InT | None = None,
118
+ call_id: str,
105
119
  ctx: RunContext[CtxT],
106
120
  ) -> UserMessage | None:
107
121
  if chat_inputs is not None:
@@ -116,7 +130,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
116
130
  return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
117
131
 
118
132
  return UserMessage(
119
- content=self._build_input_content(in_args=in_args, ctx=ctx),
133
+ content=self._build_input_content(
134
+ in_args=in_args, ctx=ctx, call_id=call_id
135
+ ),
120
136
  name=self._agent_name,
121
137
  )
122
138
 
@@ -1,11 +1,11 @@
1
1
  from collections import defaultdict
2
- from typing import Any, Generic, TypeVar
2
+ from typing import Generic, TypeVar
3
3
 
4
4
  from pydantic import BaseModel, ConfigDict, Field
5
5
 
6
6
  from grasp_agents.typing.completion import Completion
7
7
 
8
- from .printer import ColoringMode, Printer
8
+ from .printer import Printer
9
9
  from .typing.io import ProcName
10
10
  from .usage_tracker import UsageTracker
11
11
 
@@ -19,13 +19,6 @@ class RunContext(BaseModel, Generic[CtxT]):
19
19
  default_factory=lambda: defaultdict(list)
20
20
  )
21
21
  usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
22
-
23
22
  printer: Printer | None = None
24
- log_messages: bool = False
25
- color_messages_by: ColoringMode = "role"
26
-
27
- def model_post_init(self, context: Any) -> None: # noqa: ARG002
28
- if self.log_messages:
29
- self.printer = Printer(color_by=self.color_messages_by)
30
23
 
31
24
  model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
@@ -64,20 +64,22 @@ class BaseTool(
64
64
  self,
65
65
  inp: _InT,
66
66
  *,
67
- call_id: str | None = None,
68
67
  ctx: RunContext[CtxT] | None = None,
68
+ call_id: str | None = None,
69
69
  ) -> _OutT_co:
70
70
  pass
71
71
 
72
72
  async def __call__(
73
73
  self,
74
74
  *,
75
- call_id: str | None = None,
76
75
  ctx: RunContext[CtxT] | None = None,
76
+ call_id: str | None = None,
77
77
  **kwargs: Any,
78
78
  ) -> _OutT_co:
79
+ # NOTE: validation is probably redundant here when tool inputs have been
80
+ # validated by the LLM already
79
81
  input_args = TypeAdapter(self._in_type).validate_python(kwargs)
80
- output = await self.run(input_args, call_id=call_id, ctx=ctx)
82
+ output = await self.run(input_args, ctx=ctx, call_id=call_id)
81
83
 
82
84
  return TypeAdapter(self._out_type).validate_python(output)
83
85
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.5.10
3
+ Version: 0.5.12
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -110,24 +110,16 @@ Create a script, e.g., `problem_recommender.py`:
110
110
 
111
111
  ```python
112
112
  import asyncio
113
- from pathlib import Path
114
113
  from typing import Any
115
114
 
116
115
  from dotenv import load_dotenv
117
116
  from pydantic import BaseModel, Field
118
117
 
119
- from grasp_agents.grasp_logging import setup_logging
118
+ from grasp_agents import LLMAgent, BaseTool, RunContext, Printer
120
119
  from grasp_agents.litellm import LiteLLM, LiteLLMSettings
121
- from grasp_agents import LLMAgent, BaseTool, RunContext
122
-
123
- load_dotenv()
124
120
 
125
121
 
126
- # Configure the logger to output to the console and/or a file
127
- setup_logging(
128
- logs_file_path="grasp_agents_demo.log",
129
- logs_config_path=Path().cwd() / "configs/logging/default.yaml",
130
- )
122
+ load_dotenv()
131
123
 
132
124
  sys_prompt_react = """
133
125
  Your task is to suggest an exciting stats problem to the student.
@@ -162,13 +154,11 @@ Returns:
162
154
  """
163
155
 
164
156
 
165
- class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, Any]):
157
+ class AskStudentTool(BaseTool[TeacherQuestion, StudentReply, None]):
166
158
  name: str = "ask_student"
167
159
  description: str = ask_student_tool_description
168
160
 
169
- async def run(
170
- self, inp: TeacherQuestion, ctx: RunContext[Any] | None = None
171
- ) -> StudentReply:
161
+ async def run(self, inp: TeacherQuestion, **kwargs: Any) -> StudentReply:
172
162
  return input(inp.question)
173
163
 
174
164
 
@@ -178,10 +168,7 @@ class Problem(BaseModel):
178
168
 
179
169
  teacher = LLMAgent[None, Problem, None](
180
170
  name="teacher",
181
- llm=LiteLLM(
182
- model_name="gpt-4.1",
183
- llm_settings=LiteLLMSettings(temperature=0.5),
184
- ),
171
+ llm=LiteLLM(model_name="gpt-4.1"),
185
172
  tools=[AskStudentTool()],
186
173
  react_mode=True,
187
174
  final_answer_as_tool_call=True,
@@ -189,7 +176,7 @@ teacher = LLMAgent[None, Problem, None](
189
176
  )
190
177
 
191
178
  async def main():
192
- ctx = RunContext[None](log_messages=True)
179
+ ctx = RunContext[None](printer=Printer())
193
180
  out = await teacher.run("start", ctx=ctx)
194
181
  print(out.payloads[0])
195
182
  print(ctx.usage_tracker.total_usage)