grasp_agents 0.5.6__tar.gz → 0.5.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/PKG-INFO +1 -1
  2. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/pyproject.toml +1 -1
  3. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/cloud_llm.py +11 -5
  4. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/llm.py +146 -1
  5. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/llm_policy_executor.py +9 -3
  6. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/packet_pool.py +23 -43
  7. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/printer.py +75 -77
  8. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/processors/parallel_processor.py +15 -13
  9. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/runner.py +27 -24
  10. grasp_agents-0.5.8/src/grasp_agents/typing/completion_chunk.py +506 -0
  11. grasp_agents-0.5.8/src/grasp_agents/typing/events.py +376 -0
  12. grasp_agents-0.5.6/src/grasp_agents/typing/completion_chunk.py +0 -207
  13. grasp_agents-0.5.6/src/grasp_agents/typing/events.py +0 -170
  14. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/.gitignore +0 -0
  15. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/LICENSE.md +0 -0
  16. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/README.md +0 -0
  17. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/__init__.py +0 -0
  18. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/costs_dict.yaml +0 -0
  19. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/errors.py +0 -0
  20. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/generics_utils.py +0 -0
  21. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/grasp_logging.py +0 -0
  22. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/http_client.py +0 -0
  23. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/__init__.py +0 -0
  24. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
  25. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/completion_converters.py +0 -0
  26. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/converters.py +0 -0
  27. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/lite_llm.py +0 -0
  28. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/litellm/message_converters.py +0 -0
  29. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/llm_agent.py +0 -0
  30. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/llm_agent_memory.py +0 -0
  31. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/memory.py +0 -0
  32. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/__init__.py +0 -0
  33. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
  34. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/completion_converters.py +0 -0
  35. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/content_converters.py +0 -0
  36. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/converters.py +0 -0
  37. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/message_converters.py +0 -0
  38. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/openai_llm.py +0 -0
  39. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/openai/tool_converters.py +0 -0
  40. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/packet.py +0 -0
  41. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/processors/base_processor.py +0 -0
  42. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/processors/processor.py +0 -0
  43. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/prompt_builder.py +0 -0
  44. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  45. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
  46. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/types.py +0 -0
  47. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/utils.py +0 -0
  48. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/run_context.py +0 -0
  49. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/__init__.py +0 -0
  50. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/completion.py +0 -0
  51. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/content.py +0 -0
  52. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/converters.py +0 -0
  53. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/io.py +0 -0
  54. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/message.py +0 -0
  55. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/typing/tool.py +0 -0
  56. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/usage_tracker.py +0 -0
  57. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/utils.py +0 -0
  58. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/workflow/__init__.py +0 -0
  59. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/workflow/looped_workflow.py +0 -0
  60. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/workflow/sequential_workflow.py +0 -0
  61. {grasp_agents-0.5.6 → grasp_agents-0.5.8}/src/grasp_agents/workflow/workflow_processor.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: grasp_agents
3
- Version: 0.5.6
3
+ Version: 0.5.8
4
4
  Summary: Grasp Agents Library
5
5
  License-File: LICENSE.md
6
6
  Requires-Python: <4,>=3.11.4
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "grasp_agents"
3
- version = "0.5.6"
3
+ version = "0.5.8"
4
4
  description = "Grasp Agents Library"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11.4,<4"
@@ -13,7 +13,7 @@ from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
13
13
  from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
14
14
  from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
15
15
  from .typing.completion import Completion
16
- from .typing.completion_chunk import CompletionChoice
16
+ from .typing.completion_chunk import CompletionChoice, CompletionChunk
17
17
  from .typing.events import (
18
18
  CompletionChunkEvent,
19
19
  CompletionEvent,
@@ -52,7 +52,9 @@ class CloudLLMSettings(LLMSettings, total=False):
52
52
  LLMRateLimiter = RateLimiterC[
53
53
  Messages,
54
54
  AssistantMessage
55
- | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
55
+ | AsyncIterator[
56
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
57
+ ],
56
58
  ]
57
59
 
58
60
 
@@ -274,7 +276,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
274
276
  n_choices: int | None = None,
275
277
  proc_name: str | None = None,
276
278
  call_id: str | None = None,
277
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
279
+ ) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
278
280
  completion_kwargs = self._make_completion_kwargs(
279
281
  conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
280
282
  )
@@ -284,7 +286,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
284
286
  api_stream = self._get_completion_stream(**completion_kwargs)
285
287
  api_stream = cast("AsyncIterator[Any]", api_stream)
286
288
 
287
- async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
289
+ async def iterator() -> AsyncIterator[
290
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent
291
+ ]:
288
292
  api_completion_chunks: list[Any] = []
289
293
 
290
294
  async for api_completion_chunk in api_stream:
@@ -318,7 +322,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
318
322
  n_choices: int | None = None,
319
323
  proc_name: str | None = None,
320
324
  call_id: str | None = None,
321
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
325
+ ) -> AsyncIterator[
326
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
327
+ ]:
322
328
  n_attempt = 0
323
329
  while n_attempt <= self.max_response_retries:
324
330
  try:
@@ -7,6 +7,7 @@ from uuid import uuid4
7
7
  from pydantic import BaseModel
8
8
  from typing_extensions import TypedDict
9
9
 
10
+ from grasp_agents.typing.completion_chunk import CompletionChunk
10
11
  from grasp_agents.utils import (
11
12
  validate_obj_from_json_or_py_string,
12
13
  validate_tagged_objs_from_json_or_py_string,
@@ -20,9 +21,25 @@ from .errors import (
20
21
  from .typing.completion import Completion
21
22
  from .typing.converters import Converters
22
23
  from .typing.events import (
24
+ AnnotationsChunkEvent,
25
+ AnnotationsEndEvent,
26
+ AnnotationsStartEvent,
23
27
  CompletionChunkEvent,
28
+ CompletionEndEvent,
24
29
  CompletionEvent,
30
+ CompletionStartEvent,
31
+ LLMStateChangeEvent,
25
32
  LLMStreamingErrorEvent,
33
+ # RefusalChunkEvent,
34
+ ResponseChunkEvent,
35
+ ResponseEndEvent,
36
+ ResponseStartEvent,
37
+ ThinkingChunkEvent,
38
+ ThinkingEndEvent,
39
+ ThinkingStartEvent,
40
+ ToolCallChunkEvent,
41
+ ToolCallEndEvent,
42
+ ToolCallStartEvent,
26
43
  )
27
44
  from .typing.message import Messages
28
45
  from .typing.tool import BaseTool, ToolChoice
@@ -30,6 +47,14 @@ from .typing.tool import BaseTool, ToolChoice
30
47
  logger = logging.getLogger(__name__)
31
48
 
32
49
 
50
+ LLMStreamGenerator = AsyncIterator[
51
+ CompletionChunkEvent[CompletionChunk]
52
+ | CompletionEvent
53
+ | LLMStateChangeEvent[Any]
54
+ | LLMStreamingErrorEvent
55
+ ]
56
+
57
+
33
58
  class LLMSettings(TypedDict, total=False):
34
59
  max_completion_tokens: int | None
35
60
  temperature: float | None
@@ -160,6 +185,124 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
160
185
  tool_name, tool_arguments
161
186
  ) from exc
162
187
 
188
+ @staticmethod
189
+ async def postprocess_event_stream(
190
+ stream: LLMStreamGenerator,
191
+ ) -> LLMStreamGenerator:
192
+ prev_completion_id: str | None = None
193
+ chunk_op_evt: CompletionChunkEvent[CompletionChunk] | None = None
194
+ response_op_evt: ResponseChunkEvent | None = None
195
+ thinking_op_evt: ThinkingChunkEvent | None = None
196
+ annotations_op_evt: AnnotationsChunkEvent | None = None
197
+ tool_calls_op_evt: ToolCallChunkEvent | None = None
198
+
199
+ def _close_open_events() -> list[LLMStateChangeEvent[Any]]:
200
+ nonlocal \
201
+ chunk_op_evt, \
202
+ thinking_op_evt, \
203
+ tool_calls_op_evt, \
204
+ response_op_evt, \
205
+ annotations_op_evt
206
+
207
+ events: list[LLMStateChangeEvent[Any]] = []
208
+
209
+ if tool_calls_op_evt:
210
+ events.append(ToolCallEndEvent.from_chunk_event(tool_calls_op_evt))
211
+
212
+ if response_op_evt:
213
+ events.append(ResponseEndEvent.from_chunk_event(response_op_evt))
214
+
215
+ if thinking_op_evt:
216
+ events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
217
+
218
+ if annotations_op_evt:
219
+ events.append(AnnotationsEndEvent.from_chunk_event(annotations_op_evt))
220
+
221
+ if chunk_op_evt:
222
+ events.append(CompletionEndEvent.from_chunk_event(chunk_op_evt))
223
+
224
+ chunk_op_evt = None
225
+ thinking_op_evt = None
226
+ tool_calls_op_evt = None
227
+ response_op_evt = None
228
+ annotations_op_evt = None
229
+
230
+ return events
231
+
232
+ async for event in stream:
233
+ if isinstance(event, CompletionChunkEvent) and not isinstance(
234
+ event, LLMStateChangeEvent
235
+ ):
236
+ chunk = event.data
237
+ if len(chunk.choices) != 1:
238
+ raise ValueError(
239
+ "Expected exactly one choice in completion chunk, "
240
+ f"got {len(chunk.choices)}"
241
+ )
242
+
243
+ new_completion = chunk.id != prev_completion_id
244
+
245
+ if new_completion:
246
+ for close_event in _close_open_events():
247
+ yield close_event
248
+
249
+ chunk_op_evt = event
250
+ yield CompletionStartEvent.from_chunk_event(event)
251
+
252
+ sub_events = event.split_into_specialized()
253
+
254
+ for sub_event in sub_events:
255
+ if isinstance(sub_event, ThinkingChunkEvent):
256
+ if not thinking_op_evt:
257
+ thinking_op_evt = sub_event
258
+ yield ThinkingStartEvent.from_chunk_event(sub_event)
259
+ yield sub_event
260
+ elif thinking_op_evt:
261
+ yield ThinkingEndEvent.from_chunk_event(thinking_op_evt)
262
+ thinking_op_evt = None
263
+
264
+ if isinstance(sub_event, ToolCallChunkEvent):
265
+ tc = sub_event.data.tool_call
266
+ if tc.id:
267
+ # Tool call ID is not None only for the first chunk of a tool call
268
+ if tool_calls_op_evt:
269
+ yield ToolCallEndEvent.from_chunk_event(
270
+ tool_calls_op_evt
271
+ )
272
+ tool_calls_op_evt = None
273
+ tool_calls_op_evt = sub_event
274
+ yield ToolCallStartEvent.from_chunk_event(sub_event)
275
+ yield sub_event
276
+ elif tool_calls_op_evt:
277
+ yield ToolCallEndEvent.from_chunk_event(tool_calls_op_evt)
278
+ tool_calls_op_evt = None
279
+
280
+ if isinstance(sub_event, ResponseChunkEvent):
281
+ if not response_op_evt:
282
+ response_op_evt = sub_event
283
+ yield ResponseStartEvent.from_chunk_event(sub_event)
284
+ yield sub_event
285
+ elif response_op_evt:
286
+ yield ResponseEndEvent.from_chunk_event(response_op_evt)
287
+ response_op_evt = None
288
+
289
+ if isinstance(sub_event, AnnotationsChunkEvent):
290
+ if not annotations_op_evt:
291
+ annotations_op_evt = sub_event
292
+ yield AnnotationsStartEvent.from_chunk_event(sub_event)
293
+ yield sub_event
294
+ elif annotations_op_evt:
295
+ yield AnnotationsEndEvent.from_chunk_event(annotations_op_evt)
296
+ annotations_op_evt = None
297
+
298
+ prev_completion_id = chunk.id
299
+
300
+ else:
301
+ for close_event in _close_open_events():
302
+ yield close_event
303
+
304
+ yield event
305
+
163
306
  @abstractmethod
164
307
  async def generate_completion(
165
308
  self,
@@ -181,7 +324,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
181
324
  n_choices: int | None = None,
182
325
  proc_name: str | None = None,
183
326
  call_id: str | None = None,
184
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
327
+ ) -> AsyncIterator[
328
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
329
+ ]:
185
330
  pass
186
331
 
187
332
  @abstractmethod
@@ -7,6 +7,8 @@ from typing import Any, Generic, Protocol, final
7
7
 
8
8
  from pydantic import BaseModel
9
9
 
10
+ from grasp_agents.typing.completion_chunk import CompletionChunk
11
+
10
12
  from .errors import AgentFinalAnswerError
11
13
  from .llm import LLM, LLMSettings
12
14
  from .llm_agent_memory import LLMAgentMemory
@@ -149,19 +151,23 @@ class LLMPolicyExecutor(Generic[CtxT]):
149
151
  tool_choice: ToolChoice | None = None,
150
152
  ctx: RunContext[CtxT] | None = None,
151
153
  ) -> AsyncIterator[
152
- CompletionChunkEvent
154
+ CompletionChunkEvent[CompletionChunk]
153
155
  | CompletionEvent
154
156
  | GenMessageEvent
155
157
  | LLMStreamingErrorEvent
156
158
  ]:
157
159
  completion: Completion | None = None
158
- async for event in self.llm.generate_completion_stream( # type: ignore[no-untyped-call]
160
+
161
+ llm_event_stream = self.llm.generate_completion_stream(
159
162
  memory.message_history,
160
163
  tool_choice=tool_choice,
161
164
  n_choices=1,
162
165
  proc_name=self.agent_name,
163
166
  call_id=call_id,
164
- ):
167
+ )
168
+ llm_event_stream_post = self.llm.postprocess_event_stream(llm_event_stream) # type: ignore[assignment]
169
+
170
+ async for event in llm_event_stream_post:
165
171
  if isinstance(event, CompletionEvent):
166
172
  completion = event.data
167
173
  yield event
@@ -2,10 +2,9 @@ import asyncio
2
2
  import logging
3
3
  from collections.abc import AsyncIterator
4
4
  from types import TracebackType
5
- from typing import Any, Generic, Literal, Protocol, TypeVar
5
+ from typing import Any, Literal, Protocol, TypeVar
6
6
 
7
7
  from .packet import Packet
8
- from .run_context import CtxT, RunContext
9
8
  from .typing.events import Event
10
9
  from .typing.io import ProcName
11
10
 
@@ -18,24 +17,21 @@ END_PROC_NAME: Literal["*END*"] = "*END*"
18
17
  _PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
19
18
 
20
19
 
21
- class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
20
+ class PacketHandler(Protocol[_PayloadT_contra]):
22
21
  async def __call__(
23
- self,
24
- packet: Packet[_PayloadT_contra],
25
- ctx: RunContext[CtxT],
26
- **kwargs: Any,
22
+ self, packet: Packet[_PayloadT_contra], **kwargs: Any
27
23
  ) -> None: ...
28
24
 
29
25
 
30
- class PacketPool(Generic[CtxT]):
26
+ class PacketPool:
31
27
  def __init__(self) -> None:
32
28
  self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
33
- self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
29
+ self._packet_handlers: dict[ProcName, PacketHandler[Any]] = {}
34
30
  self._task_group: asyncio.TaskGroup | None = None
35
31
 
36
32
  self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
37
33
 
38
- self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
34
+ self._final_result_fut: asyncio.Future[Packet[Any]]
39
35
 
40
36
  self._stopping = False
41
37
  self._stopped_evt = asyncio.Event()
@@ -44,9 +40,8 @@ class PacketPool(Generic[CtxT]):
44
40
 
45
41
  async def post(self, packet: Packet[Any]) -> None:
46
42
  if packet.recipients == [END_PROC_NAME]:
47
- fut = self._ensure_final_future()
48
- if not fut.done():
49
- fut.set_result(packet)
43
+ if not self._final_result_fut.done():
44
+ self._final_result_fut.set_result(packet)
50
45
  await self.shutdown()
51
46
  return
52
47
 
@@ -54,31 +49,14 @@ class PacketPool(Generic[CtxT]):
54
49
  queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
55
50
  await queue.put(packet)
56
51
 
57
- def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
58
- fut = self._final_result_fut
59
- if fut is None:
60
- fut = asyncio.get_running_loop().create_future()
61
- self._final_result_fut = fut
62
- return fut
63
-
64
52
  async def final_result(self) -> Packet[Any]:
65
- fut = self._ensure_final_future()
66
53
  try:
67
- return await fut
54
+ return await self._final_result_fut
68
55
  finally:
69
56
  await self.shutdown()
70
57
 
71
- @property
72
- def final_result_ready(self) -> bool:
73
- fut = self._final_result_fut
74
- return fut is not None and fut.done()
75
-
76
58
  def register_packet_handler(
77
- self,
78
- proc_name: ProcName,
79
- handler: PacketHandler[Any, CtxT],
80
- ctx: RunContext[CtxT],
81
- **run_kwargs: Any,
59
+ self, proc_name: ProcName, handler: PacketHandler[Any]
82
60
  ) -> None:
83
61
  if self._stopping:
84
62
  raise RuntimeError("PacketPool is stopping/stopped")
@@ -88,17 +66,19 @@ class PacketPool(Generic[CtxT]):
88
66
 
89
67
  if self._task_group is not None:
90
68
  self._task_group.create_task(
91
- self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
69
+ self._handle_packets(proc_name),
92
70
  name=f"packet-handler:{proc_name}",
93
71
  )
94
72
 
95
73
  async def push_event(self, event: Event[Any]) -> None:
96
74
  await self._event_queue.put(event)
97
75
 
98
- async def __aenter__(self) -> "PacketPool[CtxT]":
76
+ async def __aenter__(self) -> "PacketPool":
99
77
  self._task_group = asyncio.TaskGroup()
100
78
  await self._task_group.__aenter__()
101
79
 
80
+ self._final_result_fut = asyncio.get_running_loop().create_future()
81
+
102
82
  return self
103
83
 
104
84
  async def __aexit__(
@@ -120,26 +100,27 @@ class PacketPool(Generic[CtxT]):
120
100
 
121
101
  return False
122
102
 
123
- async def _handle_packets(
124
- self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
125
- ) -> None:
103
+ async def _handle_packets(self, proc_name: ProcName) -> None:
126
104
  queue = self._packet_queues[proc_name]
127
105
  handler = self._packet_handlers[proc_name]
128
106
 
129
- while not self.final_result_ready:
107
+ while True:
130
108
  packet = await queue.get()
131
109
  if packet is None:
132
110
  break
111
+
112
+ if self._final_result_fut.done():
113
+ continue
114
+
133
115
  try:
134
- await handler(packet, ctx=ctx, **run_kwargs)
116
+ await handler(packet)
135
117
  except asyncio.CancelledError:
136
118
  raise
137
119
  except Exception as err:
138
120
  logger.exception("Error handling packet for %s", proc_name)
139
121
  self._errors.append(err)
140
- fut = self._final_result_fut
141
- if fut and not fut.done():
142
- fut.set_exception(err)
122
+ if not self._final_result_fut.done():
123
+ self._final_result_fut.set_exception(err)
143
124
  await self.shutdown()
144
125
  raise
145
126
 
@@ -159,6 +140,5 @@ class PacketPool(Generic[CtxT]):
159
140
  await self._event_queue.put(None)
160
141
  for queue in self._packet_queues.values():
161
142
  await queue.put(None)
162
-
163
143
  finally:
164
144
  self._stopped_evt.set()
@@ -9,14 +9,29 @@ from pydantic import BaseModel
9
9
  from termcolor import colored
10
10
  from termcolor._types import Color
11
11
 
12
+ from grasp_agents.typing.completion_chunk import CompletionChunk
12
13
  from grasp_agents.typing.events import (
14
+ AnnotationsChunkEvent,
15
+ AnnotationsEndEvent,
16
+ AnnotationsStartEvent,
13
17
  CompletionChunkEvent,
18
+ # CompletionEndEvent,
19
+ CompletionStartEvent,
14
20
  Event,
15
21
  GenMessageEvent,
16
22
  MessageEvent,
17
23
  ProcPacketOutputEvent,
24
+ ResponseChunkEvent,
25
+ ResponseEndEvent,
26
+ ResponseStartEvent,
18
27
  RunResultEvent,
19
28
  SystemMessageEvent,
29
+ ThinkingChunkEvent,
30
+ ThinkingEndEvent,
31
+ ThinkingStartEvent,
32
+ ToolCallChunkEvent,
33
+ ToolCallEndEvent,
34
+ ToolCallStartEvent,
20
35
  ToolMessageEvent,
21
36
  UserMessageEvent,
22
37
  WorkflowResultEvent,
@@ -24,7 +39,14 @@ from grasp_agents.typing.events import (
24
39
 
25
40
  from .typing.completion import Usage
26
41
  from .typing.content import Content, ContentPartText
27
- from .typing.message import AssistantMessage, Message, Role, SystemMessage, UserMessage
42
+ from .typing.message import (
43
+ AssistantMessage,
44
+ Message,
45
+ Role,
46
+ SystemMessage,
47
+ ToolMessage,
48
+ UserMessage,
49
+ )
28
50
 
29
51
  logger = logging.getLogger(__name__)
30
52
 
@@ -72,7 +94,7 @@ class Printer:
72
94
  return AVAILABLE_COLORS[idx]
73
95
 
74
96
  @staticmethod
75
- def content_to_str(content: Content | str, role: Role) -> str:
97
+ def content_to_str(content: Content | str | None, role: Role) -> str:
76
98
  if role == Role.USER and isinstance(content, Content):
77
99
  content_str_parts: list[str] = []
78
100
  for content_part in content.parts:
@@ -84,9 +106,9 @@ class Printer:
84
106
  content_str_parts.append("<ENCODED_IMAGE>")
85
107
  return "\n".join(content_str_parts)
86
108
 
87
- assert isinstance(content, str)
109
+ assert isinstance(content, str | None)
88
110
 
89
- return content.strip(" \n")
111
+ return (content or "").strip(" \n")
90
112
 
91
113
  @staticmethod
92
114
  def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
@@ -179,36 +201,10 @@ def stream_text(new_text: str, color: Color) -> None:
179
201
  async def print_event_stream(
180
202
  event_generator: AsyncIterator[Event[Any]],
181
203
  color_by: ColoringMode = "role",
182
- trunc_len: int = 1000,
204
+ trunc_len: int = 10000,
183
205
  ) -> AsyncIterator[Event[Any]]:
184
- prev_chunk_id: str | None = None
185
- thinking_open = False
186
- response_open = False
187
- open_tool_calls: set[str] = set()
188
-
189
206
  color = Printer.get_role_color(Role.ASSISTANT)
190
207
 
191
- def _close_blocks(
192
- _thinking_open: bool, _response_open: bool, color: Color
193
- ) -> tuple[bool, bool]:
194
- closing_text = ""
195
- while open_tool_calls:
196
- open_tool_calls.pop()
197
- closing_text += "\n</tool call>\n"
198
-
199
- if _thinking_open:
200
- closing_text += "\n</thinking>\n"
201
- _thinking_open = False
202
-
203
- if _response_open:
204
- closing_text += "\n</response>\n"
205
- _response_open = False
206
-
207
- if closing_text:
208
- stream_text(closing_text, color)
209
-
210
- return _thinking_open, _response_open
211
-
212
208
  def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
213
209
  if color_by == "agent":
214
210
  return Printer.get_agent_color(event.proc_name or "")
@@ -232,14 +228,7 @@ async def print_event_stream(
232
228
  text += f"<{src} output>\n"
233
229
  for p in event.data.payloads:
234
230
  if isinstance(p, BaseModel):
235
- # for field_info in type(p).model_fields.values():
236
- # if field_info.exclude:
237
- # field_info.exclude = False
238
- # break
239
- # type(p).model_rebuild(force=True)
240
231
  p_str = p.model_dump_json(indent=2)
241
- # field_info.exclude = True # type: ignore
242
- # type(p).model_rebuild(force=True)
243
232
  else:
244
233
  try:
245
234
  p_str = json.dumps(p, indent=2)
@@ -253,56 +242,64 @@ async def print_event_stream(
253
242
  async for event in event_generator:
254
243
  yield event
255
244
 
256
- if isinstance(event, CompletionChunkEvent):
257
- delta = event.data.choices[0].delta
258
- chunk_id = event.data.id
259
- new_completion = chunk_id != prev_chunk_id
245
+ if isinstance(event, CompletionChunkEvent) and isinstance(
246
+ event.data, CompletionChunk
247
+ ):
260
248
  color = _get_color(event, Role.ASSISTANT)
261
249
 
262
250
  text = ""
263
251
 
264
- if new_completion:
265
- thinking_open, response_open = _close_blocks(
266
- thinking_open, response_open, color
267
- )
252
+ if isinstance(event, CompletionStartEvent):
268
253
  text += f"\n<{event.proc_name}> [{event.call_id}]\n"
269
-
270
- if delta.reasoning_content:
271
- if not thinking_open:
272
- text += "<thinking>\n"
273
- thinking_open = True
274
- text += delta.reasoning_content
275
- elif thinking_open:
254
+ elif isinstance(event, ThinkingStartEvent):
255
+ text += "<thinking>\n"
256
+ elif isinstance(event, ResponseStartEvent):
257
+ text += "<response>\n"
258
+ elif isinstance(event, ToolCallStartEvent):
259
+ tc = event.data.tool_call
260
+ text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
261
+ elif isinstance(event, AnnotationsStartEvent):
262
+ text += "<annotations>\n"
263
+
264
+ # if isinstance(event, CompletionEndEvent):
265
+ # text += f"\n</{event.proc_name}>\n"
266
+ if isinstance(event, ThinkingEndEvent):
276
267
  text += "\n</thinking>\n"
277
- thinking_open = False
278
-
279
- if delta.content:
280
- if not response_open:
281
- text += "<response>\n"
282
- response_open = True
283
- text += delta.content
284
- elif response_open:
268
+ elif isinstance(event, ResponseEndEvent):
285
269
  text += "\n</response>\n"
286
- response_open = False
287
-
288
- if delta.tool_calls:
289
- for tc in delta.tool_calls:
290
- if tc.id and tc.id not in open_tool_calls:
291
- open_tool_calls.add(tc.id) # type: ignore
292
- text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
293
-
294
- if tc.tool_arguments:
295
- text += tc.tool_arguments
270
+ elif isinstance(event, ToolCallEndEvent):
271
+ text += "\n</tool call>\n"
272
+ elif isinstance(event, AnnotationsEndEvent):
273
+ text += "\n</annotations>\n"
274
+
275
+ if isinstance(event, ThinkingChunkEvent):
276
+ thinking = event.data.thinking
277
+ if isinstance(thinking, str):
278
+ text += thinking
279
+ else:
280
+ text = "\n".join(
281
+ [block.get("thinking", "[redacted]") for block in thinking]
282
+ )
283
+
284
+ if isinstance(event, ResponseChunkEvent):
285
+ text += event.data.response
286
+
287
+ if isinstance(event, ToolCallChunkEvent):
288
+ text += event.data.tool_call.tool_arguments or ""
289
+
290
+ if isinstance(event, AnnotationsChunkEvent):
291
+ text += "\n".join(
292
+ [
293
+ json.dumps(annotation, indent=2)
294
+ for annotation in event.data.annotations
295
+ ]
296
+ )
296
297
 
297
298
  stream_text(text, color)
298
- prev_chunk_id = chunk_id
299
-
300
- else:
301
- thinking_open, response_open = _close_blocks(
302
- thinking_open, response_open, color
303
- )
304
299
 
305
300
  if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
301
+ assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
302
+
306
303
  message = event.data
307
304
  role = message.role
308
305
  content = Printer.content_to_str(message.content, role=role)
@@ -320,6 +317,7 @@ async def print_event_stream(
320
317
  text += f"<input>\n{content}\n</input>\n"
321
318
 
322
319
  elif isinstance(event, ToolMessageEvent):
320
+ message = event.data
323
321
  try:
324
322
  content = json.dumps(json.loads(content), indent=2)
325
323
  except Exception: