grasp_agents 0.3.11__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
  import asyncio
  import json
  from collections.abc import AsyncIterator, Coroutine, Sequence
- from itertools import chain, starmap
+ from itertools import starmap
  from logging import getLogger
  from typing import Any, ClassVar, Generic, Protocol, TypeVar

@@ -132,62 +132,58 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          if self.manage_memory_impl:
              self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)

-     async def generate_message_batch(
+     async def generate_messages(
          self,
          memory: LLMAgentMemory,
+         run_id: str,
          tool_choice: ToolChoice | None = None,
          ctx: RunContext[CtxT] | None = None,
      ) -> Sequence[AssistantMessage]:
-         completion_batch = await self.llm.generate_completion_batch(
+         completion = await self.llm.generate_completion(
              memory.message_history, tool_choice=tool_choice
          )
-         message_batch = list(
-             chain.from_iterable([c.messages for c in completion_batch])
-         )
-         memory.update(message_batch=message_batch)
+         memory.update(completion.messages)

          if ctx is not None:
-             ctx.completions[self.agent_name].extend(completion_batch)
-             self._track_usage(self.agent_name, completion_batch, ctx=ctx)
-             self._print_completions(completion_batch, ctx=ctx)
+             ctx.completions[self.agent_name].append(completion)
+             self._track_usage(self.agent_name, completion, ctx=ctx)
+             self._print_completion(completion, run_id=run_id, ctx=ctx)

-         return message_batch
+         return completion.messages

-     async def generate_message_stream(
+     async def generate_messages_stream(
          self,
          memory: LLMAgentMemory,
+         run_id: str,
          tool_choice: ToolChoice | None = None,
          ctx: RunContext[CtxT] | None = None,
      ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
          message_hist = memory.message_history
-         if memory.message_history.batch_size > 1:
-             raise ValueError("Batch size must be 1 when streaming completions.")
-         conversation = message_hist.conversations[0]

          completion: Completion | None = None
          async for event in await self.llm.generate_completion_stream(
-             conversation, tool_choice=tool_choice
+             message_hist, tool_choice=tool_choice
          ):
              yield event
              if isinstance(event, CompletionEvent):
                  completion = event.data
-
          if completion is None:
              raise RuntimeError("No completion generated during stream.")

-         memory.update(message_batch=completion.messages)
+         memory.update(completion.messages)

          for message in completion.messages:
              yield GenMessageEvent(name=self.agent_name, data=message)

          if ctx is not None:
-             self._track_usage(self.agent_name, [completion], ctx=ctx)
+             self._track_usage(self.agent_name, completion, ctx=ctx)
              ctx.completions[self.agent_name].append(completion)

      async def call_tools(
          self,
          calls: Sequence[ToolCall],
          memory: LLMAgentMemory,
+         run_id: str,
          ctx: RunContext[CtxT] | None = None,
      ) -> Sequence[ToolMessage]:
          corouts: list[Coroutine[Any, Any, BaseModel]] = []
@@ -198,12 +194,14 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

          outs = await asyncio.gather(*corouts)
          tool_messages = list(
-             starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=False))
+             starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=True))
          )
          memory.update(tool_messages)

          if ctx is not None:
-             ctx.printer.print_llm_messages(tool_messages, agent_name=self.agent_name)
+             ctx.printer.print_llm_messages(
+                 tool_messages, agent_name=self.agent_name, run_id=run_id
+             )

          return tool_messages

@@ -211,10 +209,13 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          self,
          calls: Sequence[ToolCall],
          memory: LLMAgentMemory,
+         run_id: str,
          ctx: RunContext[CtxT] | None = None,
      ) -> AsyncIterator[ToolMessageEvent]:
-         tool_messages = await self.call_tools(calls, memory=memory, ctx=ctx)
-         for tool_message, call in zip(tool_messages, calls, strict=False):
+         tool_messages = await self.call_tools(
+             calls, memory=memory, run_id=run_id, ctx=ctx
+         )
+         for tool_message, call in zip(tool_messages, calls, strict=True):
              yield ToolMessageEvent(name=call.tool_name, data=tool_message)

      def _extract_final_answer_from_tool_calls(
@@ -233,7 +234,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          return final_answer_message

      async def _generate_final_answer(
-         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+         self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
      ) -> AssistantMessage:
          assert self._final_answer_tool_name is not None

@@ -242,11 +243,15 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          )
          memory.update([user_message])
          if ctx is not None:
-             ctx.printer.print_llm_messages([user_message], agent_name=self.agent_name)
+             ctx.printer.print_llm_messages(
+                 [user_message], agent_name=self.agent_name, run_id=run_id
+             )

          tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
          gen_message = (
-             await self.generate_message_batch(memory, tool_choice=tool_choice, ctx=ctx)
+             await self.generate_messages(
+                 memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
+             )
          )[0]

          final_answer_message = self._extract_final_answer_from_tool_calls(
@@ -260,7 +265,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          return final_answer_message

      async def _generate_final_answer_stream(
-         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+         self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
      ) -> AsyncIterator[Event[Any]]:
          assert self._final_answer_tool_name is not None

@@ -272,8 +277,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

          tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
          event: Event[Any] | None = None
-         async for event in self.generate_message_stream(
-             memory, tool_choice=tool_choice, ctx=ctx
+         async for event in self.generate_messages_stream(
+             memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
          ):
              yield event

@@ -289,7 +294,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          yield GenMessageEvent(name=self.agent_name, data=final_answer_message)

      async def execute(
-         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+         self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
      ) -> AssistantMessage | Sequence[AssistantMessage]:
          # 1. Generate the first message:
          # In ReAct mode, we generate the first message without tool calls
@@ -297,26 +302,24 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          tool_choice: ToolChoice | None = None
          if self.tools:
              tool_choice = "none" if self._react_mode else "auto"
-         gen_message_batch = await self.generate_message_batch(
-             memory, tool_choice=tool_choice, ctx=ctx
+         gen_messages = await self.generate_messages(
+             memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
          )
          if not self.tools:
-             return gen_message_batch
+             return gen_messages

-         if memory.message_history.batch_size > 1:
-             raise ValueError("Batch size must be 1 for tool call loop.")
-         gen_message = gen_message_batch[0]
+         if len(gen_messages) > 1:
+             raise ValueError("n_choices must be 1 when executing the tool call loop.")
+         gen_message = gen_messages[0]
          turns = 0

          while True:
-             conversation = memory.message_history.conversations[0]
-
              # 2. Check if we should exit the tool call loop

              # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
              # to determine whether to exit the loop.
              if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-                 conversation, ctx=ctx, num_turns=turns
+                 memory.message_history, ctx=ctx, num_turns=turns
              ):
                  return gen_message

@@ -336,7 +339,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                  # tool call.
                  # Otherwise, we simply return the last generated message.
                  if self._final_answer_tool_name is not None:
-                     final_answer = await self._generate_final_answer(memory, ctx=ctx)
+                     final_answer = await self._generate_final_answer(
+                         memory, run_id=run_id, ctx=ctx
+                     )
                  else:
                      final_answer = gen_message
                  logger.info(
@@ -347,7 +352,9 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
              # 3. Call tools if there are any tool calls in the generated message.

              if gen_message.tool_calls:
-                 await self.call_tools(gen_message.tool_calls, memory=memory, ctx=ctx)
+                 await self.call_tools(
+                     gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
+                 )

              # Apply the memory management function if provided.
              self._manage_memory(memory, ctx=ctx, num_turns=turns)
@@ -363,23 +370,20 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                  "none" if (self._react_mode and gen_message.tool_calls) else "required"
              )
              gen_message = (
-                 await self.generate_message_batch(
-                     memory, tool_choice=tool_choice, ctx=ctx
+                 await self.generate_messages(
+                     memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
                  )
              )[0]

              turns += 1

      async def execute_stream(
-         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+         self, memory: LLMAgentMemory, run_id: str, ctx: RunContext[CtxT] | None = None
      ) -> AsyncIterator[Event[Any]]:
-         if memory.message_history.batch_size > 1:
-             raise ValueError("Batch size must be 1 when streaming.")
-
          tool_choice: ToolChoice = "none" if self._react_mode else "auto"
          gen_message: AssistantMessage | None = None
-         async for event in self.generate_message_stream(
-             memory, tool_choice=tool_choice, ctx=ctx
+         async for event in self.generate_messages_stream(
+             memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
          ):
              yield event
              if isinstance(event, GenMessageEvent):
@@ -389,10 +393,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
          turns = 0

          while True:
-             conversation = memory.message_history.conversations[0]
-
              if self._final_answer_tool_name is None and self._exit_tool_call_loop(
-                 conversation, ctx=ctx, num_turns=turns
+                 memory.message_history, ctx=ctx, num_turns=turns
              ):
                  return

@@ -409,7 +411,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
              if turns >= self.max_turns:
                  if self._final_answer_tool_name is not None:
                      async for event in self._generate_final_answer_stream(
-                         memory, ctx=ctx
+                         memory, run_id=run_id, ctx=ctx
                      ):
                          yield event
                  logger.info(
@@ -422,7 +424,7 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
                  yield ToolCallEvent(name=self.agent_name, data=tool_call)

              async for tool_message_event in self.call_tools_stream(
-                 gen_message.tool_calls, memory=memory, ctx=ctx
+                 gen_message.tool_calls, memory=memory, run_id=run_id, ctx=ctx
              ):
                  yield tool_message_event

@@ -431,8 +433,8 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
              tool_choice = (
                  "none" if (self._react_mode and gen_message.tool_calls) else "required"
              )
-             async for event in self.generate_message_stream(
-                 memory, tool_choice=tool_choice, ctx=ctx
+             async for event in self.generate_messages_stream(
+                 memory, tool_choice=tool_choice, run_id=run_id, ctx=ctx
              ):
                  yield event
                  if isinstance(event, GenMessageEvent):
@@ -443,12 +445,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT
      def _track_usage(
          self,
          agent_name: str,
-         completion_batch: Sequence[Completion],
+         completion: Completion,
          ctx: RunContext[CtxT],
      ) -> None:
          ctx.usage_tracker.update(
              agent_name=agent_name,
-             completions=completion_batch,
+             completions=[completion],
              model_name=self.llm.model_name,
          )

@@ -473,11 +475,12 @@ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT

          return FinalAnswerTool()

-     def _print_completions(
-         self, completion_batch: Sequence[Completion], ctx: RunContext[CtxT]
+     def _print_completion(
+         self, completion: Completion, run_id: str, ctx: RunContext[CtxT]
      ) -> None:
-         messages = [c.messages[0] for c in completion_batch]
-         usages = [c.usage for c in completion_batch]
          ctx.printer.print_llm_messages(
-             messages, usages=usages, agent_name=self.agent_name
+             completion.messages,
+             usages=[completion.usage],
+             agent_name=self.agent_name,
+             run_id=run_id,
          )
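
Note: every public entry point of the policy executor shown above (`generate_messages`, `generate_messages_stream`, `call_tools`, `call_tools_stream`, `execute`, `execute_stream`) now takes a required `run_id` in 0.4.0. A minimal calling sketch, assuming an already-constructed executor, memory, and run context; the uuid-based id and the variable names are illustrative, not taken from the package:

```python
import uuid


async def run_policy(executor, memory, ctx):
    # run_id is only a correlation string used for printing and tracking;
    # generating it with uuid4 is an assumption, not something the package mandates.
    run_id = str(uuid.uuid4())
    final_answer = await executor.execute(memory, run_id=run_id, ctx=ctx)
    # Streaming variant threads the same run_id through every emitted event:
    # async for event in executor.execute_stream(memory, run_id=run_id, ctx=ctx): ...
    return final_answer
```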
grasp_agents/memory.py CHANGED
@@ -1,10 +1,12 @@
  from abc import ABC, abstractmethod
- from typing import Any
+ from typing import Any, TypeVar

  from pydantic import BaseModel, ConfigDict

  from .run_context import RunContext

+ MemT = TypeVar("MemT", bound="Memory")
+

  class Memory(BaseModel, ABC):
      @abstractmethod
@@ -1,3 +1,4 @@
+ from ..errors import CompletionError
  from ..typing.completion_chunk import (
      CompletionChunk,
      CompletionChunkChoice,
@@ -12,7 +13,7 @@ def from_api_completion_chunk(
      api_completion_chunk: OpenAICompletionChunk, name: str | None = None
  ) -> CompletionChunk:
      if api_completion_chunk.choices is None:  # type: ignore
-         raise RuntimeError(
+         raise CompletionError(
              f"Completion chunk API error: "
              f"{getattr(api_completion_chunk, 'error', None)}"
          )
@@ -24,12 +25,12 @@
      finish_reason = api_choice.finish_reason

      if api_choice.delta is None:  # type: ignore
-         raise RuntimeError(
+         raise CompletionError(
              "API returned None for delta content in completion chunk "
              f"with finish_reason: {finish_reason}."
          )
      # if api_choice.delta.content is None:
-     #     raise RuntimeError(
+     #     raise CompletionError(
      #         "API returned None for delta content in completion chunk "
      #         f"with finish_reason: {finish_reason}."
      #     )
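
Since chunk conversion now raises the library's `CompletionError` instead of a bare `RuntimeError`, callers can narrow their error handling around the converter shown above. A hedged sketch; the absolute import path is inferred from the relative `..errors` import, and the skip-and-log handling is an assumption:

```python
import logging

from grasp_agents.errors import CompletionError  # path assumed from `..errors` above

logger = logging.getLogger(__name__)


def safe_convert(api_chunk):
    # from_api_completion_chunk is the converter defined in the module above.
    try:
        return from_api_completion_chunk(api_chunk)
    except CompletionError as exc:
        # Skip malformed chunks instead of aborting the whole stream (assumed policy).
        logger.warning("dropping malformed completion chunk: %s", exc)
        return None
```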
@@ -3,6 +3,7 @@ from collections.abc import AsyncIterator, Iterable, Mapping
  from copy import deepcopy
  from typing import Any, Literal, NamedTuple

+ import httpx
  from openai import AsyncOpenAI
  from openai._types import NOT_GIVEN  # type: ignore[import]
  from openai.lib.streaming.chat import (
@@ -11,7 +12,7 @@ from openai.lib.streaming.chat import (
  from openai.lib.streaming.chat import ChunkEvent as OpenAIChunkEvent
  from pydantic import BaseModel

- from ..cloud_llm import CloudLLM, CloudLLMSettings
+ from ..cloud_llm import APIProvider, CloudLLM, CloudLLMSettings
  from ..http_client import AsyncHTTPClientParams
  from ..rate_limiting.rate_limiter_chunked import RateLimiterC
  from ..typing.message import AssistantMessage, Messages
@@ -57,8 +58,6 @@ class OpenAILLMSettings(CloudLLMSettings, total=False):
      store: bool | None
      user: str

-     strict_tool_args: bool
-
      # response_format: (
      #     OpenAIResponseFormatText
      #     | OpenAIResponseFormatJSONSchema
@@ -71,16 +70,18 @@ class OpenAILLM(CloudLLM[OpenAILLMSettings, OpenAIConverters]):
          self,
          # Base LLM args
          model_name: str,
-         model_id: str | None = None,
          llm_settings: OpenAILLMSettings | None = None,
          tools: list[BaseTool[BaseModel, Any, Any]] | None = None,
          response_format: type | Mapping[str, type] | None = None,
+         model_id: str | None = None,
+         # Custom LLM provider
+         api_provider: APIProvider | None = None,
          # Connection settings
+         async_http_client: httpx.AsyncClient | None = None,
          async_http_client_params: (
              dict[str, Any] | AsyncHTTPClientParams | None
          ) = None,
          async_openai_client_params: dict[str, Any] | None = None,
-         client: AsyncOpenAI | None = None,
          # Rate limiting
          rate_limiter: (RateLimiterC[Messages, AssistantMessage] | None) = None,
          rate_limiter_rpm: float | None = None,
@@ -88,9 +89,6 @@
          rate_limiter_max_concurrency: int = 300,
          # Retries
          num_generation_retries: int = 0,
-         # Disable tqdm for batch processing
-         no_tqdm: bool = True,
-         **kwargs: Any,
      ) -> None:
          super().__init__(
              model_name=model_name,
@@ -99,33 +97,29 @@
              converters=OpenAIConverters(),
              tools=tools,
              response_format=response_format,
+             api_provider=api_provider,
+             async_http_client=async_http_client,
              async_http_client_params=async_http_client_params,
              rate_limiter=rate_limiter,
              rate_limiter_rpm=rate_limiter_rpm,
              rate_limiter_chunk_size=rate_limiter_chunk_size,
              rate_limiter_max_concurrency=rate_limiter_max_concurrency,
              num_generation_retries=num_generation_retries,
-             no_tqdm=no_tqdm,
-             **kwargs,
          )

          self._tool_call_settings = {
-             "strict": self._llm_settings.pop("strict_tool_args", None)
+             "strict": self._llm_settings.get("use_struct_outputs", False),
          }

          _async_openai_client_params = deepcopy(async_openai_client_params or {})
          if self._async_http_client is not None:
              _async_openai_client_params["http_client"] = self._async_http_client

-         # TODO: context manager for async client
-         if client:
-             self._client = client
-         else:
-             self._client: AsyncOpenAI = AsyncOpenAI(
-                 base_url=self._base_url,
-                 api_key=self._api_key,
-                 **_async_openai_client_params,
-             )
+         self._client: AsyncOpenAI = AsyncOpenAI(
+             base_url=self._base_url,
+             api_key=self._api_key,
+             **_async_openai_client_params,
+         )

      async def _get_completion(
          self,
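
The constructor rework above drops the `client`, `no_tqdm`, and `**kwargs` arguments in favor of `api_provider` and `async_http_client`, and the `strict` tool-call setting now follows `use_struct_outputs` instead of the removed `strict_tool_args`. A construction sketch based on the new signature; the model name, timeout, the `use_struct_outputs` settings key, and the omitted import of `OpenAILLM`/`OpenAILLMSettings` are placeholders or assumptions, not values taken from the package:

```python
import httpx

# OpenAILLM and OpenAILLMSettings come from the module shown in this diff;
# the exact import path is not visible here, so it is omitted.
llm = OpenAILLM(
    model_name="gpt-4o-mini",  # placeholder model name
    llm_settings=OpenAILLMSettings(use_struct_outputs=True),  # key assumed from the .get() above
    async_http_client=httpx.AsyncClient(timeout=30.0),  # new in 0.4.0
    api_provider=None,  # or an APIProvider for a custom endpoint
    num_generation_retries=2,
)
```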
@@ -17,7 +17,6 @@ def to_api_tool(
      tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
  ) -> OpenAIToolParam:
      if strict:
-         # Enforce strict mode for pydantic models
          return pydantic_function_tool(
              model=tool.in_type, name=tool.name, description=tool.description
          )
@@ -80,7 +80,7 @@ class PacketPool(Generic[CtxT]):
          try:
              await task
          except asyncio.CancelledError:
-             logger.debug(f"{processor_name} exited")
+             logger.info(f"{processor_name} exited")

          self._tasks.pop(processor_name, None)
          self._queues.pop(processor_name, None)
grasp_agents/printer.py CHANGED
@@ -36,12 +36,10 @@ AVAILABLE_COLORS: list[Color] = [
  class Printer:
      def __init__(
          self,
-         source_id: str,
          color_by: ColoringMode = "role",
          msg_trunc_len: int = 20000,
          print_messages: bool = False,
      ) -> None:
-         self.source_id = source_id
          self.color_by = color_by
          self.msg_trunc_len = msg_trunc_len
          self.print_messages = print_messages
@@ -84,7 +82,7 @@
          return content_str

      def print_llm_message(
-         self, message: Message, agent_name: str, usage: Usage | None = None
+         self, message: Message, agent_name: str, run_id: str, usage: Usage | None = None
      ) -> None:
          if not self.print_messages:
              return
@@ -106,8 +104,7 @@

          # Print message title

-         out = f"\n<{agent_name}>"
-         out += "[" + role.value.upper() + "]"
+         out = f"\n[agent: {agent_name} | role: {role.value} | run: {run_id}]"

          if isinstance(message, ToolMessage):
              out += f"\n{message.name} | {message.tool_call_id}"
@@ -159,6 +156,7 @@
          self,
          messages: Sequence[Message],
          agent_name: str,
+         run_id: str,
          usages: Sequence[Usage | None] | None = None,
      ) -> None:
          if not self.print_messages:
@@ -167,4 +165,6 @@
          _usages: Sequence[Usage | None] = usages or [None] * len(messages)

          for _message, _usage in zip(messages, _usages, strict=False):
-             self.print_llm_message(_message, usage=_usage, agent_name=agent_name)
+             self.print_llm_message(
+                 _message, usage=_usage, agent_name=agent_name, run_id=run_id
+             )
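
With `source_id` removed from `Printer.__init__` and `run_id` added to the print methods, per-run attribution now happens at call time rather than at construction. A small sketch of the new call shape; the message sequence and the run id are placeholders supplied by the caller:

```python
printer = Printer(color_by="role", print_messages=True)
# messages: Sequence[Message]; run_id: any correlation string chosen by the caller
printer.print_llm_messages(messages, agent_name="summarizer", run_id=run_id)
```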