grasp_agents 0.5.6__py3-none-any.whl → 0.5.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/printer.py CHANGED
@@ -9,14 +9,29 @@ from pydantic import BaseModel
9
9
  from termcolor import colored
10
10
  from termcolor._types import Color
11
11
 
12
+ from grasp_agents.typing.completion_chunk import CompletionChunk
12
13
  from grasp_agents.typing.events import (
14
+ AnnotationsChunkEvent,
15
+ AnnotationsEndEvent,
16
+ AnnotationsStartEvent,
13
17
  CompletionChunkEvent,
18
+ # CompletionEndEvent,
19
+ CompletionStartEvent,
14
20
  Event,
15
21
  GenMessageEvent,
16
22
  MessageEvent,
17
23
  ProcPacketOutputEvent,
24
+ ResponseChunkEvent,
25
+ ResponseEndEvent,
26
+ ResponseStartEvent,
18
27
  RunResultEvent,
19
28
  SystemMessageEvent,
29
+ ThinkingChunkEvent,
30
+ ThinkingEndEvent,
31
+ ThinkingStartEvent,
32
+ ToolCallChunkEvent,
33
+ ToolCallEndEvent,
34
+ ToolCallStartEvent,
20
35
  ToolMessageEvent,
21
36
  UserMessageEvent,
22
37
  WorkflowResultEvent,
@@ -24,7 +39,14 @@ from grasp_agents.typing.events import (
24
39
 
25
40
  from .typing.completion import Usage
26
41
  from .typing.content import Content, ContentPartText
27
- from .typing.message import AssistantMessage, Message, Role, SystemMessage, UserMessage
42
+ from .typing.message import (
43
+ AssistantMessage,
44
+ Message,
45
+ Role,
46
+ SystemMessage,
47
+ ToolMessage,
48
+ UserMessage,
49
+ )
28
50
 
29
51
  logger = logging.getLogger(__name__)
30
52
 
@@ -72,7 +94,7 @@ class Printer:
72
94
  return AVAILABLE_COLORS[idx]
73
95
 
74
96
  @staticmethod
75
- def content_to_str(content: Content | str, role: Role) -> str:
97
+ def content_to_str(content: Content | str | None, role: Role) -> str:
76
98
  if role == Role.USER and isinstance(content, Content):
77
99
  content_str_parts: list[str] = []
78
100
  for content_part in content.parts:
@@ -84,9 +106,9 @@ class Printer:
84
106
  content_str_parts.append("<ENCODED_IMAGE>")
85
107
  return "\n".join(content_str_parts)
86
108
 
87
- assert isinstance(content, str)
109
+ assert isinstance(content, str | None)
88
110
 
89
- return content.strip(" \n")
111
+ return (content or "").strip(" \n")
90
112
 
91
113
  @staticmethod
92
114
  def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
@@ -179,36 +201,10 @@ def stream_text(new_text: str, color: Color) -> None:
179
201
  async def print_event_stream(
180
202
  event_generator: AsyncIterator[Event[Any]],
181
203
  color_by: ColoringMode = "role",
182
- trunc_len: int = 1000,
204
+ trunc_len: int = 10000,
183
205
  ) -> AsyncIterator[Event[Any]]:
184
- prev_chunk_id: str | None = None
185
- thinking_open = False
186
- response_open = False
187
- open_tool_calls: set[str] = set()
188
-
189
206
  color = Printer.get_role_color(Role.ASSISTANT)
190
207
 
191
- def _close_blocks(
192
- _thinking_open: bool, _response_open: bool, color: Color
193
- ) -> tuple[bool, bool]:
194
- closing_text = ""
195
- while open_tool_calls:
196
- open_tool_calls.pop()
197
- closing_text += "\n</tool call>\n"
198
-
199
- if _thinking_open:
200
- closing_text += "\n</thinking>\n"
201
- _thinking_open = False
202
-
203
- if _response_open:
204
- closing_text += "\n</response>\n"
205
- _response_open = False
206
-
207
- if closing_text:
208
- stream_text(closing_text, color)
209
-
210
- return _thinking_open, _response_open
211
-
212
208
  def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
213
209
  if color_by == "agent":
214
210
  return Printer.get_agent_color(event.proc_name or "")
@@ -232,14 +228,7 @@ async def print_event_stream(
232
228
  text += f"<{src} output>\n"
233
229
  for p in event.data.payloads:
234
230
  if isinstance(p, BaseModel):
235
- # for field_info in type(p).model_fields.values():
236
- # if field_info.exclude:
237
- # field_info.exclude = False
238
- # break
239
- # type(p).model_rebuild(force=True)
240
231
  p_str = p.model_dump_json(indent=2)
241
- # field_info.exclude = True # type: ignore
242
- # type(p).model_rebuild(force=True)
243
232
  else:
244
233
  try:
245
234
  p_str = json.dumps(p, indent=2)
@@ -253,56 +242,64 @@ async def print_event_stream(
253
242
  async for event in event_generator:
254
243
  yield event
255
244
 
256
- if isinstance(event, CompletionChunkEvent):
257
- delta = event.data.choices[0].delta
258
- chunk_id = event.data.id
259
- new_completion = chunk_id != prev_chunk_id
245
+ if isinstance(event, CompletionChunkEvent) and isinstance(
246
+ event.data, CompletionChunk
247
+ ):
260
248
  color = _get_color(event, Role.ASSISTANT)
261
249
 
262
250
  text = ""
263
251
 
264
- if new_completion:
265
- thinking_open, response_open = _close_blocks(
266
- thinking_open, response_open, color
267
- )
252
+ if isinstance(event, CompletionStartEvent):
268
253
  text += f"\n<{event.proc_name}> [{event.call_id}]\n"
269
-
270
- if delta.reasoning_content:
271
- if not thinking_open:
272
- text += "<thinking>\n"
273
- thinking_open = True
274
- text += delta.reasoning_content
275
- elif thinking_open:
254
+ elif isinstance(event, ThinkingStartEvent):
255
+ text += "<thinking>\n"
256
+ elif isinstance(event, ResponseStartEvent):
257
+ text += "<response>\n"
258
+ elif isinstance(event, ToolCallStartEvent):
259
+ tc = event.data.tool_call
260
+ text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
261
+ elif isinstance(event, AnnotationsStartEvent):
262
+ text += "<annotations>\n"
263
+
264
+ # if isinstance(event, CompletionEndEvent):
265
+ # text += f"\n</{event.proc_name}>\n"
266
+ if isinstance(event, ThinkingEndEvent):
276
267
  text += "\n</thinking>\n"
277
- thinking_open = False
278
-
279
- if delta.content:
280
- if not response_open:
281
- text += "<response>\n"
282
- response_open = True
283
- text += delta.content
284
- elif response_open:
268
+ elif isinstance(event, ResponseEndEvent):
285
269
  text += "\n</response>\n"
286
- response_open = False
287
-
288
- if delta.tool_calls:
289
- for tc in delta.tool_calls:
290
- if tc.id and tc.id not in open_tool_calls:
291
- open_tool_calls.add(tc.id) # type: ignore
292
- text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
293
-
294
- if tc.tool_arguments:
295
- text += tc.tool_arguments
270
+ elif isinstance(event, ToolCallEndEvent):
271
+ text += "\n</tool call>\n"
272
+ elif isinstance(event, AnnotationsEndEvent):
273
+ text += "\n</annotations>\n"
274
+
275
+ if isinstance(event, ThinkingChunkEvent):
276
+ thinking = event.data.thinking
277
+ if isinstance(thinking, str):
278
+ text += thinking
279
+ else:
280
+ text = "\n".join(
281
+ [block.get("thinking", "[redacted]") for block in thinking]
282
+ )
283
+
284
+ if isinstance(event, ResponseChunkEvent):
285
+ text += event.data.response
286
+
287
+ if isinstance(event, ToolCallChunkEvent):
288
+ text += event.data.tool_call.tool_arguments or ""
289
+
290
+ if isinstance(event, AnnotationsChunkEvent):
291
+ text += "\n".join(
292
+ [
293
+ json.dumps(annotation, indent=2)
294
+ for annotation in event.data.annotations
295
+ ]
296
+ )
296
297
 
297
298
  stream_text(text, color)
298
- prev_chunk_id = chunk_id
299
-
300
- else:
301
- thinking_open, response_open = _close_blocks(
302
- thinking_open, response_open, color
303
- )
304
299
 
305
300
  if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
301
+ assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
302
+
306
303
  message = event.data
307
304
  role = message.role
308
305
  content = Printer.content_to_str(message.content, role=role)
@@ -320,6 +317,7 @@ async def print_event_stream(
320
317
  text += f"<input>\n{content}\n</input>\n"
321
318
 
322
319
  elif isinstance(event, ToolMessageEvent):
320
+ message = event.data
323
321
  try:
324
322
  content = json.dumps(json.loads(content), indent=2)
325
323
  except Exception:
@@ -1,16 +1,13 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from collections.abc import AsyncIterator, Sequence
4
- from typing import Any, ClassVar, Generic, cast
5
-
4
+ from typing import Any, ClassVar, Generic, cast
6
5
 
7
6
  from ..errors import PacketRoutingError
8
7
  from ..memory import MemT
9
8
  from ..packet import Packet
10
9
  from ..run_context import CtxT, RunContext
11
- from ..typing.events import (
12
- Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
13
- )
10
+ from ..typing.events import Event, ProcPacketOutputEvent, ProcPayloadOutputEvent
14
11
  from ..typing.io import InT, OutT
15
12
  from ..utils import stream_concurrent
16
13
  from .base_processor import BaseProcessor, with_retry, with_retry_stream
@@ -18,7 +15,9 @@ from .base_processor import BaseProcessor, with_retry, with_retry_stream
18
15
  logger = logging.getLogger(__name__)
19
16
 
20
17
 
21
- class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]):
18
+ class ParallelProcessor(
19
+ BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]
20
+ ):
22
21
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
23
22
  0: "_in_type",
24
23
  1: "_out_type",
@@ -33,7 +32,7 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
33
32
  call_id: str,
34
33
  ctx: RunContext[CtxT] | None = None,
35
34
  ) -> OutT:
36
- return cast(OutT, in_args)
35
+ return cast("OutT", in_args)
37
36
 
38
37
  async def _process_stream(
39
38
  self,
@@ -44,7 +43,7 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
44
43
  call_id: str,
45
44
  ctx: RunContext[CtxT] | None = None,
46
45
  ) -> AsyncIterator[Event[Any]]:
47
- output = cast(OutT, in_args)
46
+ output = cast("OutT", in_args)
48
47
  yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
49
48
 
50
49
  def _validate_parallel_recipients(
@@ -59,7 +58,7 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
59
58
  message="Parallel runs must return the same recipients "
60
59
  f"[proc_name={self.name}; call_id={call_id}]",
61
60
  )
62
-
61
+
63
62
  @with_retry
64
63
  async def _run_single(
65
64
  self,
@@ -86,7 +85,6 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
86
85
 
87
86
  return Packet(payloads=[val_output], sender=self.name, recipients=recipients)
88
87
 
89
-
90
88
  async def _run_parallel(
91
89
  self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT] | None = None
92
90
  ) -> Packet[OutT]:
@@ -125,8 +123,10 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
125
123
  )
126
124
 
127
125
  if val_in_args and len(val_in_args) > 1:
128
- return await self._run_parallel(in_args=val_in_args, call_id=call_id, ctx=ctx)
129
-
126
+ return await self._run_parallel(
127
+ in_args=val_in_args, call_id=call_id, ctx=ctx
128
+ )
129
+
130
130
  return await self._run_single(
131
131
  chat_inputs=chat_inputs,
132
132
  in_args=val_in_args[0] if val_in_args else None,
@@ -231,7 +231,9 @@ class ParallelProcessor(BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT,
231
231
  )
232
232
 
233
233
  if val_in_args and len(val_in_args) > 1:
234
- stream = self._run_parallel_stream(in_args=val_in_args, call_id=call_id, ctx=ctx)
234
+ stream = self._run_parallel_stream(
235
+ in_args=val_in_args, call_id=call_id, ctx=ctx
236
+ )
235
237
  else:
236
238
  stream = self._run_single_stream(
237
239
  chat_inputs=chat_inputs,
grasp_agents/runner.py CHANGED
@@ -6,7 +6,7 @@ from typing import Any, Generic
6
6
  from .errors import RunnerError
7
7
  from .packet import Packet, StartPacket
8
8
  from .packet_pool import END_PROC_NAME, PacketPool
9
- from .processors.processor import Processor
9
+ from .processors.base_processor import BaseProcessor
10
10
  from .run_context import CtxT, RunContext
11
11
  from .typing.events import Event, ProcPacketOutputEvent, RunResultEvent
12
12
  from .typing.io import OutT
@@ -17,8 +17,8 @@ logger = logging.getLogger(__name__)
17
17
  class Runner(Generic[OutT, CtxT]):
18
18
  def __init__(
19
19
  self,
20
- entry_proc: Processor[Any, Any, Any, CtxT],
21
- procs: Sequence[Processor[Any, Any, Any, CtxT]],
20
+ entry_proc: BaseProcessor[Any, Any, Any, CtxT],
21
+ procs: Sequence[BaseProcessor[Any, Any, Any, CtxT]],
22
22
  ctx: RunContext[CtxT] | None = None,
23
23
  ) -> None:
24
24
  if entry_proc not in procs:
@@ -34,7 +34,6 @@ class Runner(Generic[OutT, CtxT]):
34
34
  self._entry_proc = entry_proc
35
35
  self._procs = procs
36
36
  self._ctx = ctx or RunContext[CtxT]()
37
- self._packet_pool: PacketPool[CtxT] = PacketPool()
38
37
 
39
38
  @property
40
39
  def ctx(self) -> RunContext[CtxT]:
@@ -49,9 +48,10 @@ class Runner(Generic[OutT, CtxT]):
49
48
 
50
49
  async def _packet_handler(
51
50
  self,
52
- proc: Processor[Any, Any, Any, CtxT],
53
- pool: PacketPool[CtxT],
54
51
  packet: Packet[Any],
52
+ *,
53
+ proc: BaseProcessor[Any, Any, Any, CtxT],
54
+ pool: PacketPool,
55
55
  ctx: RunContext[CtxT],
56
56
  **run_kwargs: Any,
57
57
  ) -> None:
@@ -72,9 +72,10 @@ class Runner(Generic[OutT, CtxT]):
72
72
 
73
73
  async def _packet_handler_stream(
74
74
  self,
75
- proc: Processor[Any, Any, Any, CtxT],
76
- pool: PacketPool[CtxT],
77
75
  packet: Packet[Any],
76
+ *,
77
+ proc: BaseProcessor[Any, Any, Any, CtxT],
78
+ pool: PacketPool,
78
79
  ctx: RunContext[CtxT],
79
80
  **run_kwargs: Any,
80
81
  ) -> None:
@@ -99,18 +100,18 @@ class Runner(Generic[OutT, CtxT]):
99
100
 
100
101
  await pool.post(out_packet)
101
102
 
102
- async def run(
103
- self,
104
- chat_input: Any = "start",
105
- **run_args: Any,
106
- ) -> Packet[OutT]:
107
- async with PacketPool[CtxT]() as pool:
103
+ async def run(self, chat_input: Any = "start", **run_args: Any) -> Packet[OutT]:
104
+ async with PacketPool() as pool:
108
105
  for proc in self._procs:
109
106
  pool.register_packet_handler(
110
107
  proc_name=proc.name,
111
- handler=partial(self._packet_handler, proc, pool),
112
- ctx=self._ctx,
113
- **run_args,
108
+ handler=partial(
109
+ self._packet_handler,
110
+ proc=proc,
111
+ pool=pool,
112
+ ctx=self._ctx,
113
+ **run_args,
114
+ ),
114
115
  )
115
116
  await pool.post(
116
117
  StartPacket[Any](
@@ -120,17 +121,19 @@ class Runner(Generic[OutT, CtxT]):
120
121
  return await pool.final_result()
121
122
 
122
123
  async def run_stream(
123
- self,
124
- chat_input: Any = "start",
125
- **run_args: Any,
124
+ self, chat_input: Any = "start", **run_args: Any
126
125
  ) -> AsyncIterator[Event[Any]]:
127
- async with PacketPool[CtxT]() as pool:
126
+ async with PacketPool() as pool:
128
127
  for proc in self._procs:
129
128
  pool.register_packet_handler(
130
129
  proc_name=proc.name,
131
- handler=partial(self._packet_handler_stream, proc, pool),
132
- ctx=self._ctx,
133
- **run_args,
130
+ handler=partial(
131
+ self._packet_handler_stream,
132
+ proc=proc,
133
+ pool=pool,
134
+ ctx=self._ctx,
135
+ **run_args,
136
+ ),
134
137
  )
135
138
  await pool.post(
136
139
  StartPacket[Any](