grasp_agents 0.5.5__tar.gz → 0.5.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/PKG-INFO +1 -1
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/pyproject.toml +1 -1
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/__init__.py +5 -1
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/cloud_llm.py +11 -5
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/llm.py +151 -2
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/llm_agent.py +18 -7
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/llm_policy_executor.py +9 -3
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/packet_pool.py +22 -37
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/printer.py +75 -74
- grasp_agents-0.5.5/src/grasp_agents/processor.py → grasp_agents-0.5.8/src/grasp_agents/processors/base_processor.py +89 -287
- grasp_agents-0.5.8/src/grasp_agents/processors/parallel_processor.py +246 -0
- grasp_agents-0.5.8/src/grasp_agents/processors/processor.py +161 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/runner.py +46 -24
- grasp_agents-0.5.8/src/grasp_agents/typing/completion_chunk.py +506 -0
- grasp_agents-0.5.8/src/grasp_agents/typing/events.py +376 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/workflow/looped_workflow.py +35 -27
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/workflow/sequential_workflow.py +14 -3
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/workflow/workflow_processor.py +21 -15
- grasp_agents-0.5.5/src/grasp_agents/typing/completion_chunk.py +0 -207
- grasp_agents-0.5.5/src/grasp_agents/typing/events.py +0 -166
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/.gitignore +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/LICENSE.md +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/README.md +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/costs_dict.yaml +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/errors.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/generics_utils.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/grasp_logging.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/http_client.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/__init__.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/completion_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/lite_llm.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/litellm/message_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/llm_agent_memory.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/memory.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/__init__.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/completion_chunk_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/completion_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/content_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/message_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/openai_llm.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/openai/tool_converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/packet.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/prompt_builder.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/__init__.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/types.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/rate_limiting/utils.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/run_context.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/__init__.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/completion.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/content.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/converters.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/io.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/message.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/typing/tool.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/usage_tracker.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/utils.py +0 -0
- {grasp_agents-0.5.5 → grasp_agents-0.5.8}/src/grasp_agents/workflow/__init__.py +0 -0
@@ -6,7 +6,9 @@ from .llm_agent import LLMAgent
|
|
6
6
|
from .llm_agent_memory import LLMAgentMemory
|
7
7
|
from .memory import Memory
|
8
8
|
from .packet import Packet
|
9
|
-
from .
|
9
|
+
from .processors.base_processor import BaseProcessor
|
10
|
+
from .processors.parallel_processor import ParallelProcessor
|
11
|
+
from .processors.processor import Processor
|
10
12
|
from .run_context import RunContext
|
11
13
|
from .typing.completion import Completion
|
12
14
|
from .typing.content import Content, ImageData
|
@@ -17,6 +19,7 @@ from .typing.tool import BaseTool
|
|
17
19
|
__all__ = [
|
18
20
|
"LLM",
|
19
21
|
"AssistantMessage",
|
22
|
+
"BaseProcessor",
|
20
23
|
"BaseTool",
|
21
24
|
"Completion",
|
22
25
|
"Content",
|
@@ -29,6 +32,7 @@ __all__ = [
|
|
29
32
|
"Messages",
|
30
33
|
"Packet",
|
31
34
|
"Packet",
|
35
|
+
"ParallelProcessor",
|
32
36
|
"ProcName",
|
33
37
|
"Processor",
|
34
38
|
"RunContext",
|
@@ -13,7 +13,7 @@ from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
|
|
13
13
|
from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
|
14
14
|
from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
|
15
15
|
from .typing.completion import Completion
|
16
|
-
from .typing.completion_chunk import CompletionChoice
|
16
|
+
from .typing.completion_chunk import CompletionChoice, CompletionChunk
|
17
17
|
from .typing.events import (
|
18
18
|
CompletionChunkEvent,
|
19
19
|
CompletionEvent,
|
@@ -52,7 +52,9 @@ class CloudLLMSettings(LLMSettings, total=False):
|
|
52
52
|
LLMRateLimiter = RateLimiterC[
|
53
53
|
Messages,
|
54
54
|
AssistantMessage
|
55
|
-
| AsyncIterator[
|
55
|
+
| AsyncIterator[
|
56
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
57
|
+
],
|
56
58
|
]
|
57
59
|
|
58
60
|
|
@@ -274,7 +276,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
274
276
|
n_choices: int | None = None,
|
275
277
|
proc_name: str | None = None,
|
276
278
|
call_id: str | None = None,
|
277
|
-
) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
|
279
|
+
) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
|
278
280
|
completion_kwargs = self._make_completion_kwargs(
|
279
281
|
conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
|
280
282
|
)
|
@@ -284,7 +286,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
284
286
|
api_stream = self._get_completion_stream(**completion_kwargs)
|
285
287
|
api_stream = cast("AsyncIterator[Any]", api_stream)
|
286
288
|
|
287
|
-
async def iterator() -> AsyncIterator[
|
289
|
+
async def iterator() -> AsyncIterator[
|
290
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent
|
291
|
+
]:
|
288
292
|
api_completion_chunks: list[Any] = []
|
289
293
|
|
290
294
|
async for api_completion_chunk in api_stream:
|
@@ -318,7 +322,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
|
|
318
322
|
n_choices: int | None = None,
|
319
323
|
proc_name: str | None = None,
|
320
324
|
call_id: str | None = None,
|
321
|
-
) -> AsyncIterator[
|
325
|
+
) -> AsyncIterator[
|
326
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
327
|
+
]:
|
322
328
|
n_attempt = 0
|
323
329
|
while n_attempt <= self.max_response_retries:
|
324
330
|
try:
|
@@ -7,6 +7,7 @@ from uuid import uuid4
|
|
7
7
|
from pydantic import BaseModel
|
8
8
|
from typing_extensions import TypedDict
|
9
9
|
|
10
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
10
11
|
from grasp_agents.utils import (
|
11
12
|
validate_obj_from_json_or_py_string,
|
12
13
|
validate_tagged_objs_from_json_or_py_string,
|
@@ -19,13 +20,41 @@ from .errors import (
|
|
19
20
|
)
|
20
21
|
from .typing.completion import Completion
|
21
22
|
from .typing.converters import Converters
|
22
|
-
from .typing.events import
|
23
|
+
from .typing.events import (
|
24
|
+
AnnotationsChunkEvent,
|
25
|
+
AnnotationsEndEvent,
|
26
|
+
AnnotationsStartEvent,
|
27
|
+
CompletionChunkEvent,
|
28
|
+
CompletionEndEvent,
|
29
|
+
CompletionEvent,
|
30
|
+
CompletionStartEvent,
|
31
|
+
LLMStateChangeEvent,
|
32
|
+
LLMStreamingErrorEvent,
|
33
|
+
# RefusalChunkEvent,
|
34
|
+
ResponseChunkEvent,
|
35
|
+
ResponseEndEvent,
|
36
|
+
ResponseStartEvent,
|
37
|
+
ThinkingChunkEvent,
|
38
|
+
ThinkingEndEvent,
|
39
|
+
ThinkingStartEvent,
|
40
|
+
ToolCallChunkEvent,
|
41
|
+
ToolCallEndEvent,
|
42
|
+
ToolCallStartEvent,
|
43
|
+
)
|
23
44
|
from .typing.message import Messages
|
24
45
|
from .typing.tool import BaseTool, ToolChoice
|
25
46
|
|
26
47
|
logger = logging.getLogger(__name__)
|
27
48
|
|
28
49
|
|
50
|
+
LLMStreamGenerator = AsyncIterator[
|
51
|
+
CompletionChunkEvent[CompletionChunk]
|
52
|
+
| CompletionEvent
|
53
|
+
| LLMStateChangeEvent[Any]
|
54
|
+
| LLMStreamingErrorEvent
|
55
|
+
]
|
56
|
+
|
57
|
+
|
29
58
|
class LLMSettings(TypedDict, total=False):
|
30
59
|
max_completion_tokens: int | None
|
31
60
|
temperature: float | None
|
@@ -156,6 +185,124 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
|
|
156
185
|
tool_name, tool_arguments
|
157
186
|
) from exc
|
158
187
|
|
188
|
+
@staticmethod
|
189
|
+
async def postprocess_event_stream(
|
190
|
+
stream: LLMStreamGenerator,
|
191
|
+
) -> LLMStreamGenerator:
|
192
|
+
prev_completion_id: str | None = None
|
193
|
+
chunk_op_evt: CompletionChunkEvent[CompletionChunk] | None = None
|
194
|
+
response_op_evt: ResponseChunkEvent | None = None
|
195
|
+
thinking_op_evt: ThinkingChunkEvent | None = None
|
196
|
+
annotations_op_evt: AnnotationsChunkEvent | None = None
|
197
|
+
tool_calls_op_evt: ToolCallChunkEvent | None = None
|
198
|
+
|
199
|
+
def _close_open_events() -> list[LLMStateChangeEvent[Any]]:
|
200
|
+
nonlocal \
|
201
|
+
chunk_op_evt, \
|
202
|
+
thinking_op_evt, \
|
203
|
+
tool_calls_op_evt, \
|
204
|
+
response_op_evt, \
|
205
|
+
annotations_op_evt
|
206
|
+
|
207
|
+
events: list[LLMStateChangeEvent[Any]] = []
|
208
|
+
|
209
|
+
if tool_calls_op_evt:
|
210
|
+
events.append(ToolCallEndEvent.from_chunk_event(tool_calls_op_evt))
|
211
|
+
|
212
|
+
if response_op_evt:
|
213
|
+
events.append(ResponseEndEvent.from_chunk_event(response_op_evt))
|
214
|
+
|
215
|
+
if thinking_op_evt:
|
216
|
+
events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
|
217
|
+
|
218
|
+
if annotations_op_evt:
|
219
|
+
events.append(AnnotationsEndEvent.from_chunk_event(annotations_op_evt))
|
220
|
+
|
221
|
+
if chunk_op_evt:
|
222
|
+
events.append(CompletionEndEvent.from_chunk_event(chunk_op_evt))
|
223
|
+
|
224
|
+
chunk_op_evt = None
|
225
|
+
thinking_op_evt = None
|
226
|
+
tool_calls_op_evt = None
|
227
|
+
response_op_evt = None
|
228
|
+
annotations_op_evt = None
|
229
|
+
|
230
|
+
return events
|
231
|
+
|
232
|
+
async for event in stream:
|
233
|
+
if isinstance(event, CompletionChunkEvent) and not isinstance(
|
234
|
+
event, LLMStateChangeEvent
|
235
|
+
):
|
236
|
+
chunk = event.data
|
237
|
+
if len(chunk.choices) != 1:
|
238
|
+
raise ValueError(
|
239
|
+
"Expected exactly one choice in completion chunk, "
|
240
|
+
f"got {len(chunk.choices)}"
|
241
|
+
)
|
242
|
+
|
243
|
+
new_completion = chunk.id != prev_completion_id
|
244
|
+
|
245
|
+
if new_completion:
|
246
|
+
for close_event in _close_open_events():
|
247
|
+
yield close_event
|
248
|
+
|
249
|
+
chunk_op_evt = event
|
250
|
+
yield CompletionStartEvent.from_chunk_event(event)
|
251
|
+
|
252
|
+
sub_events = event.split_into_specialized()
|
253
|
+
|
254
|
+
for sub_event in sub_events:
|
255
|
+
if isinstance(sub_event, ThinkingChunkEvent):
|
256
|
+
if not thinking_op_evt:
|
257
|
+
thinking_op_evt = sub_event
|
258
|
+
yield ThinkingStartEvent.from_chunk_event(sub_event)
|
259
|
+
yield sub_event
|
260
|
+
elif thinking_op_evt:
|
261
|
+
yield ThinkingEndEvent.from_chunk_event(thinking_op_evt)
|
262
|
+
thinking_op_evt = None
|
263
|
+
|
264
|
+
if isinstance(sub_event, ToolCallChunkEvent):
|
265
|
+
tc = sub_event.data.tool_call
|
266
|
+
if tc.id:
|
267
|
+
# Tool call ID is not None only for the first chunk of a tool call
|
268
|
+
if tool_calls_op_evt:
|
269
|
+
yield ToolCallEndEvent.from_chunk_event(
|
270
|
+
tool_calls_op_evt
|
271
|
+
)
|
272
|
+
tool_calls_op_evt = None
|
273
|
+
tool_calls_op_evt = sub_event
|
274
|
+
yield ToolCallStartEvent.from_chunk_event(sub_event)
|
275
|
+
yield sub_event
|
276
|
+
elif tool_calls_op_evt:
|
277
|
+
yield ToolCallEndEvent.from_chunk_event(tool_calls_op_evt)
|
278
|
+
tool_calls_op_evt = None
|
279
|
+
|
280
|
+
if isinstance(sub_event, ResponseChunkEvent):
|
281
|
+
if not response_op_evt:
|
282
|
+
response_op_evt = sub_event
|
283
|
+
yield ResponseStartEvent.from_chunk_event(sub_event)
|
284
|
+
yield sub_event
|
285
|
+
elif response_op_evt:
|
286
|
+
yield ResponseEndEvent.from_chunk_event(response_op_evt)
|
287
|
+
response_op_evt = None
|
288
|
+
|
289
|
+
if isinstance(sub_event, AnnotationsChunkEvent):
|
290
|
+
if not annotations_op_evt:
|
291
|
+
annotations_op_evt = sub_event
|
292
|
+
yield AnnotationsStartEvent.from_chunk_event(sub_event)
|
293
|
+
yield sub_event
|
294
|
+
elif annotations_op_evt:
|
295
|
+
yield AnnotationsEndEvent.from_chunk_event(annotations_op_evt)
|
296
|
+
annotations_op_evt = None
|
297
|
+
|
298
|
+
prev_completion_id = chunk.id
|
299
|
+
|
300
|
+
else:
|
301
|
+
for close_event in _close_open_events():
|
302
|
+
yield close_event
|
303
|
+
|
304
|
+
yield event
|
305
|
+
|
159
306
|
@abstractmethod
|
160
307
|
async def generate_completion(
|
161
308
|
self,
|
@@ -177,7 +324,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
|
|
177
324
|
n_choices: int | None = None,
|
178
325
|
proc_name: str | None = None,
|
179
326
|
call_id: str | None = None,
|
180
|
-
) -> AsyncIterator[
|
327
|
+
) -> AsyncIterator[
|
328
|
+
CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
|
329
|
+
]:
|
181
330
|
pass
|
182
331
|
|
183
332
|
@abstractmethod
|
@@ -11,7 +11,7 @@ from .llm_policy_executor import (
|
|
11
11
|
MemoryManager,
|
12
12
|
ToolCallLoopTerminator,
|
13
13
|
)
|
14
|
-
from .
|
14
|
+
from .processors.parallel_processor import ParallelProcessor
|
15
15
|
from .prompt_builder import (
|
16
16
|
InputContentBuilder,
|
17
17
|
PromptBuilder,
|
@@ -46,7 +46,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
|
|
46
46
|
|
47
47
|
|
48
48
|
class LLMAgent(
|
49
|
-
|
49
|
+
ParallelProcessor[InT, OutT, LLMAgentMemory, CtxT],
|
50
50
|
Generic[InT, OutT, CtxT],
|
51
51
|
):
|
52
52
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
@@ -196,6 +196,20 @@ class LLMAgent(
|
|
196
196
|
|
197
197
|
return system_message, input_message
|
198
198
|
|
199
|
+
def _parse_output_default(
|
200
|
+
self,
|
201
|
+
conversation: Messages,
|
202
|
+
*,
|
203
|
+
in_args: InT | None = None,
|
204
|
+
ctx: RunContext[CtxT] | None = None,
|
205
|
+
) -> OutT:
|
206
|
+
return validate_obj_from_json_or_py_string(
|
207
|
+
str(conversation[-1].content or ""),
|
208
|
+
schema=self._out_type,
|
209
|
+
from_substring=False,
|
210
|
+
strip_language_markdown=True,
|
211
|
+
)
|
212
|
+
|
199
213
|
def _parse_output(
|
200
214
|
self,
|
201
215
|
conversation: Messages,
|
@@ -208,11 +222,8 @@ class LLMAgent(
|
|
208
222
|
conversation=conversation, in_args=in_args, ctx=ctx
|
209
223
|
)
|
210
224
|
|
211
|
-
return
|
212
|
-
|
213
|
-
schema=self._out_type,
|
214
|
-
from_substring=False,
|
215
|
-
strip_language_markdown=True,
|
225
|
+
return self._parse_output_default(
|
226
|
+
conversation=conversation, in_args=in_args, ctx=ctx
|
216
227
|
)
|
217
228
|
|
218
229
|
async def _process(
|
@@ -7,6 +7,8 @@ from typing import Any, Generic, Protocol, final
|
|
7
7
|
|
8
8
|
from pydantic import BaseModel
|
9
9
|
|
10
|
+
from grasp_agents.typing.completion_chunk import CompletionChunk
|
11
|
+
|
10
12
|
from .errors import AgentFinalAnswerError
|
11
13
|
from .llm import LLM, LLMSettings
|
12
14
|
from .llm_agent_memory import LLMAgentMemory
|
@@ -149,19 +151,23 @@ class LLMPolicyExecutor(Generic[CtxT]):
|
|
149
151
|
tool_choice: ToolChoice | None = None,
|
150
152
|
ctx: RunContext[CtxT] | None = None,
|
151
153
|
) -> AsyncIterator[
|
152
|
-
CompletionChunkEvent
|
154
|
+
CompletionChunkEvent[CompletionChunk]
|
153
155
|
| CompletionEvent
|
154
156
|
| GenMessageEvent
|
155
157
|
| LLMStreamingErrorEvent
|
156
158
|
]:
|
157
159
|
completion: Completion | None = None
|
158
|
-
|
160
|
+
|
161
|
+
llm_event_stream = self.llm.generate_completion_stream(
|
159
162
|
memory.message_history,
|
160
163
|
tool_choice=tool_choice,
|
161
164
|
n_choices=1,
|
162
165
|
proc_name=self.agent_name,
|
163
166
|
call_id=call_id,
|
164
|
-
)
|
167
|
+
)
|
168
|
+
llm_event_stream_post = self.llm.postprocess_event_stream(llm_event_stream) # type: ignore[assignment]
|
169
|
+
|
170
|
+
async for event in llm_event_stream_post:
|
165
171
|
if isinstance(event, CompletionEvent):
|
166
172
|
completion = event.data
|
167
173
|
yield event
|
@@ -2,10 +2,9 @@ import asyncio
|
|
2
2
|
import logging
|
3
3
|
from collections.abc import AsyncIterator
|
4
4
|
from types import TracebackType
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, Literal, Protocol, TypeVar
|
6
6
|
|
7
7
|
from .packet import Packet
|
8
|
-
from .run_context import CtxT, RunContext
|
9
8
|
from .typing.events import Event
|
10
9
|
from .typing.io import ProcName
|
11
10
|
|
@@ -18,24 +17,21 @@ END_PROC_NAME: Literal["*END*"] = "*END*"
|
|
18
17
|
_PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
|
19
18
|
|
20
19
|
|
21
|
-
class PacketHandler(Protocol[_PayloadT_contra
|
20
|
+
class PacketHandler(Protocol[_PayloadT_contra]):
|
22
21
|
async def __call__(
|
23
|
-
self,
|
24
|
-
packet: Packet[_PayloadT_contra],
|
25
|
-
ctx: RunContext[CtxT],
|
26
|
-
**kwargs: Any,
|
22
|
+
self, packet: Packet[_PayloadT_contra], **kwargs: Any
|
27
23
|
) -> None: ...
|
28
24
|
|
29
25
|
|
30
|
-
class PacketPool
|
26
|
+
class PacketPool:
|
31
27
|
def __init__(self) -> None:
|
32
28
|
self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
|
33
|
-
self._packet_handlers: dict[ProcName, PacketHandler[Any
|
29
|
+
self._packet_handlers: dict[ProcName, PacketHandler[Any]] = {}
|
34
30
|
self._task_group: asyncio.TaskGroup | None = None
|
35
31
|
|
36
32
|
self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
|
37
33
|
|
38
|
-
self._final_result_fut: asyncio.Future[Packet[Any]]
|
34
|
+
self._final_result_fut: asyncio.Future[Packet[Any]]
|
39
35
|
|
40
36
|
self._stopping = False
|
41
37
|
self._stopped_evt = asyncio.Event()
|
@@ -44,9 +40,8 @@ class PacketPool(Generic[CtxT]):
|
|
44
40
|
|
45
41
|
async def post(self, packet: Packet[Any]) -> None:
|
46
42
|
if packet.recipients == [END_PROC_NAME]:
|
47
|
-
|
48
|
-
|
49
|
-
fut.set_result(packet)
|
43
|
+
if not self._final_result_fut.done():
|
44
|
+
self._final_result_fut.set_result(packet)
|
50
45
|
await self.shutdown()
|
51
46
|
return
|
52
47
|
|
@@ -54,26 +49,14 @@ class PacketPool(Generic[CtxT]):
|
|
54
49
|
queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
|
55
50
|
await queue.put(packet)
|
56
51
|
|
57
|
-
def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
|
58
|
-
fut = self._final_result_fut
|
59
|
-
if fut is None:
|
60
|
-
fut = asyncio.get_running_loop().create_future()
|
61
|
-
self._final_result_fut = fut
|
62
|
-
return fut
|
63
|
-
|
64
52
|
async def final_result(self) -> Packet[Any]:
|
65
|
-
fut = self._ensure_final_future()
|
66
53
|
try:
|
67
|
-
return await
|
54
|
+
return await self._final_result_fut
|
68
55
|
finally:
|
69
56
|
await self.shutdown()
|
70
57
|
|
71
58
|
def register_packet_handler(
|
72
|
-
self,
|
73
|
-
proc_name: ProcName,
|
74
|
-
handler: PacketHandler[Any, CtxT],
|
75
|
-
ctx: RunContext[CtxT],
|
76
|
-
**run_kwargs: Any,
|
59
|
+
self, proc_name: ProcName, handler: PacketHandler[Any]
|
77
60
|
) -> None:
|
78
61
|
if self._stopping:
|
79
62
|
raise RuntimeError("PacketPool is stopping/stopped")
|
@@ -83,17 +66,19 @@ class PacketPool(Generic[CtxT]):
|
|
83
66
|
|
84
67
|
if self._task_group is not None:
|
85
68
|
self._task_group.create_task(
|
86
|
-
self._handle_packets(proc_name
|
69
|
+
self._handle_packets(proc_name),
|
87
70
|
name=f"packet-handler:{proc_name}",
|
88
71
|
)
|
89
72
|
|
90
73
|
async def push_event(self, event: Event[Any]) -> None:
|
91
74
|
await self._event_queue.put(event)
|
92
75
|
|
93
|
-
async def __aenter__(self) -> "PacketPool
|
76
|
+
async def __aenter__(self) -> "PacketPool":
|
94
77
|
self._task_group = asyncio.TaskGroup()
|
95
78
|
await self._task_group.__aenter__()
|
96
79
|
|
80
|
+
self._final_result_fut = asyncio.get_running_loop().create_future()
|
81
|
+
|
97
82
|
return self
|
98
83
|
|
99
84
|
async def __aexit__(
|
@@ -115,9 +100,7 @@ class PacketPool(Generic[CtxT]):
|
|
115
100
|
|
116
101
|
return False
|
117
102
|
|
118
|
-
async def _handle_packets(
|
119
|
-
self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
|
120
|
-
) -> None:
|
103
|
+
async def _handle_packets(self, proc_name: ProcName) -> None:
|
121
104
|
queue = self._packet_queues[proc_name]
|
122
105
|
handler = self._packet_handlers[proc_name]
|
123
106
|
|
@@ -125,16 +108,19 @@ class PacketPool(Generic[CtxT]):
|
|
125
108
|
packet = await queue.get()
|
126
109
|
if packet is None:
|
127
110
|
break
|
111
|
+
|
112
|
+
if self._final_result_fut.done():
|
113
|
+
continue
|
114
|
+
|
128
115
|
try:
|
129
|
-
await handler(packet
|
116
|
+
await handler(packet)
|
130
117
|
except asyncio.CancelledError:
|
131
118
|
raise
|
132
119
|
except Exception as err:
|
133
120
|
logger.exception("Error handling packet for %s", proc_name)
|
134
121
|
self._errors.append(err)
|
135
|
-
|
136
|
-
|
137
|
-
fut.set_exception(err)
|
122
|
+
if not self._final_result_fut.done():
|
123
|
+
self._final_result_fut.set_exception(err)
|
138
124
|
await self.shutdown()
|
139
125
|
raise
|
140
126
|
|
@@ -154,6 +140,5 @@ class PacketPool(Generic[CtxT]):
|
|
154
140
|
await self._event_queue.put(None)
|
155
141
|
for queue in self._packet_queues.values():
|
156
142
|
await queue.put(None)
|
157
|
-
|
158
143
|
finally:
|
159
144
|
self._stopped_evt.set()
|