grasp_agents 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +11 -5
- grasp_agents/llm.py +146 -1
- grasp_agents/llm_policy_executor.py +9 -3
- grasp_agents/packet_pool.py +23 -43
- grasp_agents/printer.py +75 -77
- grasp_agents/processors/parallel_processor.py +15 -13
- grasp_agents/runner.py +27 -24
- grasp_agents/typing/completion_chunk.py +302 -3
- grasp_agents/typing/events.py +256 -50
- {grasp_agents-0.5.6.dist-info → grasp_agents-0.5.8.dist-info}/METADATA +1 -1
- {grasp_agents-0.5.6.dist-info → grasp_agents-0.5.8.dist-info}/RECORD +13 -13
- {grasp_agents-0.5.6.dist-info → grasp_agents-0.5.8.dist-info}/WHEEL +0 -0
- {grasp_agents-0.5.6.dist-info → grasp_agents-0.5.8.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/cloud_llm.py
CHANGED
@@ -13,7 +13,7 @@ from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
|
|
13
13
|
from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
|
14
14
|
from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
|
15
15
|
from .typing.completion import Completion
|
16
|
-
from .typing.completion_chunk import CompletionChoice
|
16
|
+
from .typing.completion_chunk import CompletionChoice, CompletionChunk
|
17
17
|
from .typing.events import (
|
18
18
|
CompletionChunkEvent,
|
19
19
|
CompletionEvent,
|
@@ -52,7 +52,9 @@ class CloudLLMSettings(LLMSettings, total=False):
|
|
52
52
|
LLMRateLimiter = RateLimiterC[
|
53
53
|
Messages,
|
54
54
|
AssistantMessage
|
55
|
-
| AsyncIterator[
|
55
|
+
| AsyncIterator[
|
56
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
57
|
+
],
|
56
58
|
]
|
57
59
|
|
58
60
|
|
@@ -274,7 +276,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
274
276
|
n_choices: int | None = None,
|
275
277
|
proc_name: str | None = None,
|
276
278
|
call_id: str | None = None,
|
277
|
-
) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
|
279
|
+
) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
|
278
280
|
completion_kwargs = self._make_completion_kwargs(
|
279
281
|
conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
|
280
282
|
)
|
@@ -284,7 +286,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
284
286
|
api_stream = self._get_completion_stream(**completion_kwargs)
|
285
287
|
api_stream = cast("AsyncIterator[Any]", api_stream)
|
286
288
|
|
287
|
-
async def iterator() -> AsyncIterator[
|
289
|
+
async def iterator() -> AsyncIterator[
|
290
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent
|
291
|
+
]:
|
288
292
|
api_completion_chunks: list[Any] = []
|
289
293
|
|
290
294
|
async for api_completion_chunk in api_stream:
|
@@ -318,7 +322,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
318
322
|
n_choices: int | None = None,
|
319
323
|
proc_name: str | None = None,
|
320
324
|
call_id: str | None = None,
|
321
|
-
) -> AsyncIterator[
|
325
|
+
) -> AsyncIterator[
|
326
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
327
|
+
]:
|
322
328
|
n_attempt = 0
|
323
329
|
while n_attempt <= self.max_response_retries:
|
324
330
|
try:
|
grasp_agents/llm.py
CHANGED
@@ -7,6 +7,7 @@ from uuid import uuid4
|
|
7
7
|
from pydantic import BaseModel
|
8
8
|
from typing_extensions import TypedDict
|
9
9
|
|
10
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
10
11
|
from grasp_agents.utils import (
|
11
12
|
validate_obj_from_json_or_py_string,
|
12
13
|
validate_tagged_objs_from_json_or_py_string,
|
@@ -20,9 +21,25 @@ from .errors import (
|
|
20
21
|
from .typing.completion import Completion
|
21
22
|
from .typing.converters import Converters
|
22
23
|
from .typing.events import (
|
24
|
+
AnnotationsChunkEvent,
|
25
|
+
AnnotationsEndEvent,
|
26
|
+
AnnotationsStartEvent,
|
23
27
|
CompletionChunkEvent,
|
28
|
+
CompletionEndEvent,
|
24
29
|
CompletionEvent,
|
30
|
+
CompletionStartEvent,
|
31
|
+
LLMStateChangeEvent,
|
25
32
|
LLMStreamingErrorEvent,
|
33
|
+
# RefusalChunkEvent,
|
34
|
+
ResponseChunkEvent,
|
35
|
+
ResponseEndEvent,
|
36
|
+
ResponseStartEvent,
|
37
|
+
ThinkingChunkEvent,
|
38
|
+
ThinkingEndEvent,
|
39
|
+
ThinkingStartEvent,
|
40
|
+
ToolCallChunkEvent,
|
41
|
+
ToolCallEndEvent,
|
42
|
+
ToolCallStartEvent,
|
26
43
|
)
|
27
44
|
from .typing.message import Messages
|
28
45
|
from .typing.tool import BaseTool, ToolChoice
|
@@ -30,6 +47,14 @@ from .typing.tool import BaseTool, ToolChoice
|
|
30
47
|
logger = logging.getLogger(__name__)
|
31
48
|
|
32
49
|
|
50
|
+
LLMStreamGenerator = AsyncIterator[
|
51
|
+
CompletionChunkEvent[CompletionChunk]
|
52
|
+
| CompletionEvent
|
53
|
+
| LLMStateChangeEvent[Any]
|
54
|
+
| LLMStreamingErrorEvent
|
55
|
+
]
|
56
|
+
|
57
|
+
|
33
58
|
class LLMSettings(TypedDict, total=False):
|
34
59
|
max_completion_tokens: int | None
|
35
60
|
temperature: float | None
|
@@ -160,6 +185,124 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
|
|
160
185
|
tool_name, tool_arguments
|
161
186
|
) from exc
|
162
187
|
|
188
|
+
@staticmethod
|
189
|
+
async def postprocess_event_stream(
|
190
|
+
stream: LLMStreamGenerator,
|
191
|
+
) -> LLMStreamGenerator:
|
192
|
+
prev_completion_id: str | None = None
|
193
|
+
chunk_op_evt: CompletionChunkEvent[CompletionChunk] | None = None
|
194
|
+
response_op_evt: ResponseChunkEvent | None = None
|
195
|
+
thinking_op_evt: ThinkingChunkEvent | None = None
|
196
|
+
annotations_op_evt: AnnotationsChunkEvent | None = None
|
197
|
+
tool_calls_op_evt: ToolCallChunkEvent | None = None
|
198
|
+
|
199
|
+
def _close_open_events() -> list[LLMStateChangeEvent[Any]]:
|
200
|
+
nonlocal \
|
201
|
+
chunk_op_evt, \
|
202
|
+
thinking_op_evt, \
|
203
|
+
tool_calls_op_evt, \
|
204
|
+
response_op_evt, \
|
205
|
+
annotations_op_evt
|
206
|
+
|
207
|
+
events: list[LLMStateChangeEvent[Any]] = []
|
208
|
+
|
209
|
+
if tool_calls_op_evt:
|
210
|
+
events.append(ToolCallEndEvent.from_chunk_event(tool_calls_op_evt))
|
211
|
+
|
212
|
+
if response_op_evt:
|
213
|
+
events.append(ResponseEndEvent.from_chunk_event(response_op_evt))
|
214
|
+
|
215
|
+
if thinking_op_evt:
|
216
|
+
events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
|
217
|
+
|
218
|
+
if annotations_op_evt:
|
219
|
+
events.append(AnnotationsEndEvent.from_chunk_event(annotations_op_evt))
|
220
|
+
|
221
|
+
if chunk_op_evt:
|
222
|
+
events.append(CompletionEndEvent.from_chunk_event(chunk_op_evt))
|
223
|
+
|
224
|
+
chunk_op_evt = None
|
225
|
+
thinking_op_evt = None
|
226
|
+
tool_calls_op_evt = None
|
227
|
+
response_op_evt = None
|
228
|
+
annotations_op_evt = None
|
229
|
+
|
230
|
+
return events
|
231
|
+
|
232
|
+
async for event in stream:
|
233
|
+
if isinstance(event, CompletionChunkEvent) and not isinstance(
|
234
|
+
event, LLMStateChangeEvent
|
235
|
+
):
|
236
|
+
chunk = event.data
|
237
|
+
if len(chunk.choices) != 1:
|
238
|
+
raise ValueError(
|
239
|
+
"Expected exactly one choice in completion chunk, "
|
240
|
+
f"got {len(chunk.choices)}"
|
241
|
+
)
|
242
|
+
|
243
|
+
new_completion = chunk.id != prev_completion_id
|
244
|
+
|
245
|
+
if new_completion:
|
246
|
+
for close_event in _close_open_events():
|
247
|
+
yield close_event
|
248
|
+
|
249
|
+
chunk_op_evt = event
|
250
|
+
yield CompletionStartEvent.from_chunk_event(event)
|
251
|
+
|
252
|
+
sub_events = event.split_into_specialized()
|
253
|
+
|
254
|
+
for sub_event in sub_events:
|
255
|
+
if isinstance(sub_event, ThinkingChunkEvent):
|
256
|
+
if not thinking_op_evt:
|
257
|
+
thinking_op_evt = sub_event
|
258
|
+
yield ThinkingStartEvent.from_chunk_event(sub_event)
|
259
|
+
yield sub_event
|
260
|
+
elif thinking_op_evt:
|
261
|
+
yield ThinkingEndEvent.from_chunk_event(thinking_op_evt)
|
262
|
+
thinking_op_evt = None
|
263
|
+
|
264
|
+
if isinstance(sub_event, ToolCallChunkEvent):
|
265
|
+
tc = sub_event.data.tool_call
|
266
|
+
if tc.id:
|
267
|
+
# Tool call ID is not None only for the first chunk of a tool call
|
268
|
+
if tool_calls_op_evt:
|
269
|
+
yield ToolCallEndEvent.from_chunk_event(
|
270
|
+
tool_calls_op_evt
|
271
|
+
)
|
272
|
+
tool_calls_op_evt = None
|
273
|
+
tool_calls_op_evt = sub_event
|
274
|
+
yield ToolCallStartEvent.from_chunk_event(sub_event)
|
275
|
+
yield sub_event
|
276
|
+
elif tool_calls_op_evt:
|
277
|
+
yield ToolCallEndEvent.from_chunk_event(tool_calls_op_evt)
|
278
|
+
tool_calls_op_evt = None
|
279
|
+
|
280
|
+
if isinstance(sub_event, ResponseChunkEvent):
|
281
|
+
if not response_op_evt:
|
282
|
+
response_op_evt = sub_event
|
283
|
+
yield ResponseStartEvent.from_chunk_event(sub_event)
|
284
|
+
yield sub_event
|
285
|
+
elif response_op_evt:
|
286
|
+
yield ResponseEndEvent.from_chunk_event(response_op_evt)
|
287
|
+
response_op_evt = None
|
288
|
+
|
289
|
+
if isinstance(sub_event, AnnotationsChunkEvent):
|
290
|
+
if not annotations_op_evt:
|
291
|
+
annotations_op_evt = sub_event
|
292
|
+
yield AnnotationsStartEvent.from_chunk_event(sub_event)
|
293
|
+
yield sub_event
|
294
|
+
elif annotations_op_evt:
|
295
|
+
yield AnnotationsEndEvent.from_chunk_event(annotations_op_evt)
|
296
|
+
annotations_op_evt = None
|
297
|
+
|
298
|
+
prev_completion_id = chunk.id
|
299
|
+
|
300
|
+
else:
|
301
|
+
for close_event in _close_open_events():
|
302
|
+
yield close_event
|
303
|
+
|
304
|
+
yield event
|
305
|
+
|
163
306
|
@abstractmethod
|
164
307
|
async def generate_completion(
|
165
308
|
self,
|
@@ -181,7 +324,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
|
|
181
324
|
n_choices: int | None = None,
|
182
325
|
proc_name: str | None = None,
|
183
326
|
call_id: str | None = None,
|
184
|
-
) -> AsyncIterator[
|
327
|
+
) -> AsyncIterator[
|
328
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
329
|
+
]:
|
185
330
|
pass
|
186
331
|
|
187
332
|
@abstractmethod
|
grasp_agents/llm_policy_executor.py
CHANGED
@@ -7,6 +7,8 @@ from typing import Any, Generic, Protocol, final
|
|
7
7
|
|
8
8
|
from pydantic import BaseModel
|
9
9
|
|
10
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
11
|
+
|
10
12
|
from .errors import AgentFinalAnswerError
|
11
13
|
from .llm import LLM, LLMSettings
|
12
14
|
from .llm_agent_memory import LLMAgentMemory
|
@@ -149,19 +151,23 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
149
151
|
tool_choice: ToolChoice | None = None,
|
150
152
|
ctx: RunContext[CtxT] | None = None,
|
151
153
|
) -> AsyncIterator[
|
152
|
-
CompletionChunkEvent
|
154
|
+
CompletionChunkEvent[CompletionChunk]
|
153
155
|
| CompletionEvent
|
154
156
|
| GenMessageEvent
|
155
157
|
| LLMStreamingErrorEvent
|
156
158
|
]:
|
157
159
|
completion: Completion | None = None
|
158
|
-
|
160
|
+
|
161
|
+
llm_event_stream = self.llm.generate_completion_stream(
|
159
162
|
memory.message_history,
|
160
163
|
tool_choice=tool_choice,
|
161
164
|
n_choices=1,
|
162
165
|
proc_name=self.agent_name,
|
163
166
|
call_id=call_id,
|
164
|
-
)
|
167
|
+
)
|
168
|
+
llm_event_stream_post = self.llm.postprocess_event_stream(llm_event_stream) # type: ignore[assignment]
|
169
|
+
|
170
|
+
async for event in llm_event_stream_post:
|
165
171
|
if isinstance(event, CompletionEvent):
|
166
172
|
completion = event.data
|
167
173
|
yield event
|
grasp_agents/packet_pool.py
CHANGED
@@ -2,10 +2,9 @@ import asyncio
|
|
2
2
|
import logging
|
3
3
|
from collections.abc import AsyncIterator
|
4
4
|
from types import TracebackType
|
5
|
-
from typing import Any, Generic, Literal, Protocol, TypeVar
|
5
|
+
from typing import Any, Literal, Protocol, TypeVar
|
6
6
|
|
7
7
|
from .packet import Packet
|
8
|
-
from .run_context import CtxT, RunContext
|
9
8
|
from .typing.events import Event
|
10
9
|
from .typing.io import ProcName
|
11
10
|
|
@@ -18,24 +17,21 @@ END_PROC_NAME: Literal["*END*"] = "*END*"
|
|
18
17
|
_PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
|
19
18
|
|
20
19
|
|
21
|
-
class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
|
20
|
+
class PacketHandler(Protocol[_PayloadT_contra]):
|
22
21
|
async def __call__(
|
23
|
-
self,
|
24
|
-
packet: Packet[_PayloadT_contra],
|
25
|
-
ctx: RunContext[CtxT],
|
26
|
-
**kwargs: Any,
|
22
|
+
self, packet: Packet[_PayloadT_contra], **kwargs: Any
|
27
23
|
) -> None: ...
|
28
24
|
|
29
25
|
|
30
|
-
class PacketPool(Generic[CtxT]):
|
26
|
+
class PacketPool:
|
31
27
|
def __init__(self) -> None:
|
32
28
|
self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
|
33
|
-
self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
|
29
|
+
self._packet_handlers: dict[ProcName, PacketHandler[Any]] = {}
|
34
30
|
self._task_group: asyncio.TaskGroup | None = None
|
35
31
|
|
36
32
|
self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
|
37
33
|
|
38
|
-
self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
|
34
|
+
self._final_result_fut: asyncio.Future[Packet[Any]]
|
39
35
|
|
40
36
|
self._stopping = False
|
41
37
|
self._stopped_evt = asyncio.Event()
|
@@ -44,9 +40,8 @@ class PacketPool(Generic[CtxT]):
|
|
44
40
|
|
45
41
|
async def post(self, packet: Packet[Any]) -> None:
|
46
42
|
if packet.recipients == [END_PROC_NAME]:
|
47
|
-
|
48
|
-
|
49
|
-
fut.set_result(packet)
|
43
|
+
if not self._final_result_fut.done():
|
44
|
+
self._final_result_fut.set_result(packet)
|
50
45
|
await self.shutdown()
|
51
46
|
return
|
52
47
|
|
@@ -54,31 +49,14 @@ class PacketPool(Generic[CtxT]):
|
|
54
49
|
queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
|
55
50
|
await queue.put(packet)
|
56
51
|
|
57
|
-
def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
|
58
|
-
fut = self._final_result_fut
|
59
|
-
if fut is None:
|
60
|
-
fut = asyncio.get_running_loop().create_future()
|
61
|
-
self._final_result_fut = fut
|
62
|
-
return fut
|
63
|
-
|
64
52
|
async def final_result(self) -> Packet[Any]:
|
65
|
-
fut = self._ensure_final_future()
|
66
53
|
try:
|
67
|
-
return await fut
|
54
|
+
return await self._final_result_fut
|
68
55
|
finally:
|
69
56
|
await self.shutdown()
|
70
57
|
|
71
|
-
@property
|
72
|
-
def final_result_ready(self) -> bool:
|
73
|
-
fut = self._final_result_fut
|
74
|
-
return fut is not None and fut.done()
|
75
|
-
|
76
58
|
def register_packet_handler(
|
77
|
-
self,
|
78
|
-
proc_name: ProcName,
|
79
|
-
handler: PacketHandler[Any, CtxT],
|
80
|
-
ctx: RunContext[CtxT],
|
81
|
-
**run_kwargs: Any,
|
59
|
+
self, proc_name: ProcName, handler: PacketHandler[Any]
|
82
60
|
) -> None:
|
83
61
|
if self._stopping:
|
84
62
|
raise RuntimeError("PacketPool is stopping/stopped")
|
@@ -88,17 +66,19 @@ class PacketPool(Generic[CtxT]):
|
|
88
66
|
|
89
67
|
if self._task_group is not None:
|
90
68
|
self._task_group.create_task(
|
91
|
-
self._handle_packets(proc_name
|
69
|
+
self._handle_packets(proc_name),
|
92
70
|
name=f"packet-handler:{proc_name}",
|
93
71
|
)
|
94
72
|
|
95
73
|
async def push_event(self, event: Event[Any]) -> None:
|
96
74
|
await self._event_queue.put(event)
|
97
75
|
|
98
|
-
async def __aenter__(self) -> "PacketPool[CtxT]":
|
76
|
+
async def __aenter__(self) -> "PacketPool":
|
99
77
|
self._task_group = asyncio.TaskGroup()
|
100
78
|
await self._task_group.__aenter__()
|
101
79
|
|
80
|
+
self._final_result_fut = asyncio.get_running_loop().create_future()
|
81
|
+
|
102
82
|
return self
|
103
83
|
|
104
84
|
async def __aexit__(
|
@@ -120,26 +100,27 @@ class PacketPool(Generic[CtxT]):
|
|
120
100
|
|
121
101
|
return False
|
122
102
|
|
123
|
-
async def _handle_packets(
|
124
|
-
self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
|
125
|
-
) -> None:
|
103
|
+
async def _handle_packets(self, proc_name: ProcName) -> None:
|
126
104
|
queue = self._packet_queues[proc_name]
|
127
105
|
handler = self._packet_handlers[proc_name]
|
128
106
|
|
129
|
-
while
|
107
|
+
while True:
|
130
108
|
packet = await queue.get()
|
131
109
|
if packet is None:
|
132
110
|
break
|
111
|
+
|
112
|
+
if self._final_result_fut.done():
|
113
|
+
continue
|
114
|
+
|
133
115
|
try:
|
134
|
-
await handler(packet
|
116
|
+
await handler(packet)
|
135
117
|
except asyncio.CancelledError:
|
136
118
|
raise
|
137
119
|
except Exception as err:
|
138
120
|
logger.exception("Error handling packet for %s", proc_name)
|
139
121
|
self._errors.append(err)
|
140
|
-
|
141
|
-
|
142
|
-
fut.set_exception(err)
|
122
|
+
if not self._final_result_fut.done():
|
123
|
+
self._final_result_fut.set_exception(err)
|
143
124
|
await self.shutdown()
|
144
125
|
raise
|
145
126
|
|
@@ -159,6 +140,5 @@ class PacketPool(Generic[CtxT]):
|
|
159
140
|
await self._event_queue.put(None)
|
160
141
|
for queue in self._packet_queues.values():
|
161
142
|
await queue.put(None)
|
162
|
-
|
163
143
|
finally:
|
164
144
|
self._stopped_evt.set()
|
grasp_agents/printer.py
CHANGED
@@ -9,14 +9,29 @@ from pydantic import BaseModel
|
|
9
9
|
from termcolor import colored
|
10
10
|
from termcolor._types import Color
|
11
11
|
|
12
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
12
13
|
from grasp_agents.typing.events import (
|
14
|
+
AnnotationsChunkEvent,
|
15
|
+
AnnotationsEndEvent,
|
16
|
+
AnnotationsStartEvent,
|
13
17
|
CompletionChunkEvent,
|
18
|
+
# CompletionEndEvent,
|
19
|
+
CompletionStartEvent,
|
14
20
|
Event,
|
15
21
|
GenMessageEvent,
|
16
22
|
MessageEvent,
|
17
23
|
ProcPacketOutputEvent,
|
24
|
+
ResponseChunkEvent,
|
25
|
+
ResponseEndEvent,
|
26
|
+
ResponseStartEvent,
|
18
27
|
RunResultEvent,
|
19
28
|
SystemMessageEvent,
|
29
|
+
ThinkingChunkEvent,
|
30
|
+
ThinkingEndEvent,
|
31
|
+
ThinkingStartEvent,
|
32
|
+
ToolCallChunkEvent,
|
33
|
+
ToolCallEndEvent,
|
34
|
+
ToolCallStartEvent,
|
20
35
|
ToolMessageEvent,
|
21
36
|
UserMessageEvent,
|
22
37
|
WorkflowResultEvent,
|
@@ -24,7 +39,14 @@ from grasp_agents.typing.events import (
|
|
24
39
|
|
25
40
|
from .typing.completion import Usage
|
26
41
|
from .typing.content import Content, ContentPartText
|
27
|
-
from .typing.message import
|
42
|
+
from .typing.message import (
|
43
|
+
AssistantMessage,
|
44
|
+
Message,
|
45
|
+
Role,
|
46
|
+
SystemMessage,
|
47
|
+
ToolMessage,
|
48
|
+
UserMessage,
|
49
|
+
)
|
28
50
|
|
29
51
|
logger = logging.getLogger(__name__)
|
30
52
|
|
@@ -72,7 +94,7 @@ class Printer:
|
|
72
94
|
return AVAILABLE_COLORS[idx]
|
73
95
|
|
74
96
|
@staticmethod
|
75
|
-
def content_to_str(content: Content | str, role: Role) -> str:
|
97
|
+
def content_to_str(content: Content | str | None, role: Role) -> str:
|
76
98
|
if role == Role.USER and isinstance(content, Content):
|
77
99
|
content_str_parts: list[str] = []
|
78
100
|
for content_part in content.parts:
|
@@ -84,9 +106,9 @@ class Printer:
|
|
84
106
|
content_str_parts.append("<ENCODED_IMAGE>")
|
85
107
|
return "\n".join(content_str_parts)
|
86
108
|
|
87
|
-
assert isinstance(content, str)
|
109
|
+
assert isinstance(content, str | None)
|
88
110
|
|
89
|
-
return content.strip(" \n")
|
111
|
+
return (content or "").strip(" \n")
|
90
112
|
|
91
113
|
@staticmethod
|
92
114
|
def truncate_content_str(content_str: str, trunc_len: int = 2000) -> str:
|
@@ -179,36 +201,10 @@ def stream_text(new_text: str, color: Color) -> None:
|
|
179
201
|
async def print_event_stream(
|
180
202
|
event_generator: AsyncIterator[Event[Any]],
|
181
203
|
color_by: ColoringMode = "role",
|
182
|
-
trunc_len: int =
|
204
|
+
trunc_len: int = 10000,
|
183
205
|
) -> AsyncIterator[Event[Any]]:
|
184
|
-
prev_chunk_id: str | None = None
|
185
|
-
thinking_open = False
|
186
|
-
response_open = False
|
187
|
-
open_tool_calls: set[str] = set()
|
188
|
-
|
189
206
|
color = Printer.get_role_color(Role.ASSISTANT)
|
190
207
|
|
191
|
-
def _close_blocks(
|
192
|
-
_thinking_open: bool, _response_open: bool, color: Color
|
193
|
-
) -> tuple[bool, bool]:
|
194
|
-
closing_text = ""
|
195
|
-
while open_tool_calls:
|
196
|
-
open_tool_calls.pop()
|
197
|
-
closing_text += "\n</tool call>\n"
|
198
|
-
|
199
|
-
if _thinking_open:
|
200
|
-
closing_text += "\n</thinking>\n"
|
201
|
-
_thinking_open = False
|
202
|
-
|
203
|
-
if _response_open:
|
204
|
-
closing_text += "\n</response>\n"
|
205
|
-
_response_open = False
|
206
|
-
|
207
|
-
if closing_text:
|
208
|
-
stream_text(closing_text, color)
|
209
|
-
|
210
|
-
return _thinking_open, _response_open
|
211
|
-
|
212
208
|
def _get_color(event: Event[Any], role: Role = Role.ASSISTANT) -> Color:
|
213
209
|
if color_by == "agent":
|
214
210
|
return Printer.get_agent_color(event.proc_name or "")
|
@@ -232,14 +228,7 @@ async def print_event_stream(
|
|
232
228
|
text += f"<{src} output>\n"
|
233
229
|
for p in event.data.payloads:
|
234
230
|
if isinstance(p, BaseModel):
|
235
|
-
# for field_info in type(p).model_fields.values():
|
236
|
-
# if field_info.exclude:
|
237
|
-
# field_info.exclude = False
|
238
|
-
# break
|
239
|
-
# type(p).model_rebuild(force=True)
|
240
231
|
p_str = p.model_dump_json(indent=2)
|
241
|
-
# field_info.exclude = True # type: ignore
|
242
|
-
# type(p).model_rebuild(force=True)
|
243
232
|
else:
|
244
233
|
try:
|
245
234
|
p_str = json.dumps(p, indent=2)
|
@@ -253,56 +242,64 @@ async def print_event_stream(
|
|
253
242
|
async for event in event_generator:
|
254
243
|
yield event
|
255
244
|
|
256
|
-
if isinstance(event, CompletionChunkEvent)
|
257
|
-
|
258
|
-
|
259
|
-
new_completion = chunk_id != prev_chunk_id
|
245
|
+
if isinstance(event, CompletionChunkEvent) and isinstance(
|
246
|
+
event.data, CompletionChunk
|
247
|
+
):
|
260
248
|
color = _get_color(event, Role.ASSISTANT)
|
261
249
|
|
262
250
|
text = ""
|
263
251
|
|
264
|
-
if
|
265
|
-
thinking_open, response_open = _close_blocks(
|
266
|
-
thinking_open, response_open, color
|
267
|
-
)
|
252
|
+
if isinstance(event, CompletionStartEvent):
|
268
253
|
text += f"\n<{event.proc_name}> [{event.call_id}]\n"
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
254
|
+
elif isinstance(event, ThinkingStartEvent):
|
255
|
+
text += "<thinking>\n"
|
256
|
+
elif isinstance(event, ResponseStartEvent):
|
257
|
+
text += "<response>\n"
|
258
|
+
elif isinstance(event, ToolCallStartEvent):
|
259
|
+
tc = event.data.tool_call
|
260
|
+
text += f"<tool call> {tc.tool_name} [{tc.id}]\n"
|
261
|
+
elif isinstance(event, AnnotationsStartEvent):
|
262
|
+
text += "<annotations>\n"
|
263
|
+
|
264
|
+
# if isinstance(event, CompletionEndEvent):
|
265
|
+
# text += f"\n</{event.proc_name}>\n"
|
266
|
+
if isinstance(event, ThinkingEndEvent):
|
276
267
|
text += "\n</thinking>\n"
|
277
|
-
|
278
|
-
|
279
|
-
if delta.content:
|
280
|
-
if not response_open:
|
281
|
-
text += "<response>\n"
|
282
|
-
response_open = True
|
283
|
-
text += delta.content
|
284
|
-
elif response_open:
|
268
|
+
elif isinstance(event, ResponseEndEvent):
|
285
269
|
text += "\n</response>\n"
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
270
|
+
elif isinstance(event, ToolCallEndEvent):
|
271
|
+
text += "\n</tool call>\n"
|
272
|
+
elif isinstance(event, AnnotationsEndEvent):
|
273
|
+
text += "\n</annotations>\n"
|
274
|
+
|
275
|
+
if isinstance(event, ThinkingChunkEvent):
|
276
|
+
thinking = event.data.thinking
|
277
|
+
if isinstance(thinking, str):
|
278
|
+
text += thinking
|
279
|
+
else:
|
280
|
+
text = "\n".join(
|
281
|
+
[block.get("thinking", "[redacted]") for block in thinking]
|
282
|
+
)
|
283
|
+
|
284
|
+
if isinstance(event, ResponseChunkEvent):
|
285
|
+
text += event.data.response
|
286
|
+
|
287
|
+
if isinstance(event, ToolCallChunkEvent):
|
288
|
+
text += event.data.tool_call.tool_arguments or ""
|
289
|
+
|
290
|
+
if isinstance(event, AnnotationsChunkEvent):
|
291
|
+
text += "\n".join(
|
292
|
+
[
|
293
|
+
json.dumps(annotation, indent=2)
|
294
|
+
for annotation in event.data.annotations
|
295
|
+
]
|
296
|
+
)
|
296
297
|
|
297
298
|
stream_text(text, color)
|
298
|
-
prev_chunk_id = chunk_id
|
299
|
-
|
300
|
-
else:
|
301
|
-
thinking_open, response_open = _close_blocks(
|
302
|
-
thinking_open, response_open, color
|
303
|
-
)
|
304
299
|
|
305
300
|
if isinstance(event, MessageEvent) and not isinstance(event, GenMessageEvent):
|
301
|
+
assert isinstance(event.data, (SystemMessage | UserMessage | ToolMessage))
|
302
|
+
|
306
303
|
message = event.data
|
307
304
|
role = message.role
|
308
305
|
content = Printer.content_to_str(message.content, role=role)
|
@@ -320,6 +317,7 @@ async def print_event_stream(
|
|
320
317
|
text += f"<input>\n{content}\n</input>\n"
|
321
318
|
|
322
319
|
elif isinstance(event, ToolMessageEvent):
|
320
|
+
message = event.data
|
323
321
|
try:
|
324
322
|
content = json.dumps(json.loads(content), indent=2)
|
325
323
|
except Exception:
|