grasp_agents 0.5.5__py3-none-any.whl → 0.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/__init__.py CHANGED
@@ -6,7 +6,9 @@ from .llm_agent import LLMAgent
6
6
  from .llm_agent_memory import LLMAgentMemory
7
7
  from .memory import Memory
8
8
  from .packet import Packet
9
- from .processor import Processor
9
+ from .processors.base_processor import BaseProcessor
10
+ from .processors.parallel_processor import ParallelProcessor
11
+ from .processors.processor import Processor
10
12
  from .run_context import RunContext
11
13
  from .typing.completion import Completion
12
14
  from .typing.content import Content, ImageData
@@ -17,6 +19,7 @@ from .typing.tool import BaseTool
17
19
  __all__ = [
18
20
  "LLM",
19
21
  "AssistantMessage",
22
+ "BaseProcessor",
20
23
  "BaseTool",
21
24
  "Completion",
22
25
  "Content",
@@ -29,6 +32,7 @@ __all__ = [
29
32
  "Messages",
30
33
  "Packet",
31
34
  "Packet",
35
+ "ParallelProcessor",
32
36
  "ProcName",
33
37
  "Processor",
34
38
  "RunContext",
grasp_agents/cloud_llm.py CHANGED
@@ -13,7 +13,7 @@ from .http_client import AsyncHTTPClientParams, create_simple_async_httpx_client
13
13
  from .llm import LLM, ConvertT_co, LLMSettings, SettingsT_co
14
14
  from .rate_limiting.rate_limiter_chunked import RateLimiterC, limit_rate
15
15
  from .typing.completion import Completion
16
- from .typing.completion_chunk import CompletionChoice
16
+ from .typing.completion_chunk import CompletionChoice, CompletionChunk
17
17
  from .typing.events import (
18
18
  CompletionChunkEvent,
19
19
  CompletionEvent,
@@ -52,7 +52,9 @@ class CloudLLMSettings(LLMSettings, total=False):
52
52
  LLMRateLimiter = RateLimiterC[
53
53
  Messages,
54
54
  AssistantMessage
55
- | AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent],
55
+ | AsyncIterator[
56
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
57
+ ],
56
58
  ]
57
59
 
58
60
 
@@ -274,7 +276,7 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
274
276
  n_choices: int | None = None,
275
277
  proc_name: str | None = None,
276
278
  call_id: str | None = None,
277
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
279
+ ) -> AsyncIterator[CompletionChunkEvent[CompletionChunk] | CompletionEvent]:
278
280
  completion_kwargs = self._make_completion_kwargs(
279
281
  conversation=conversation, tool_choice=tool_choice, n_choices=n_choices
280
282
  )
@@ -284,7 +286,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
284
286
  api_stream = self._get_completion_stream(**completion_kwargs)
285
287
  api_stream = cast("AsyncIterator[Any]", api_stream)
286
288
 
287
- async def iterator() -> AsyncIterator[CompletionChunkEvent | CompletionEvent]:
289
+ async def iterator() -> AsyncIterator[
290
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent
291
+ ]:
288
292
  api_completion_chunks: list[Any] = []
289
293
 
290
294
  async for api_completion_chunk in api_stream:
@@ -318,7 +322,9 @@ class CloudLLM(LLM[SettingsT_co, ConvertT_co], Generic[SettingsT_co, ConvertT_co
318
322
  n_choices: int | None = None,
319
323
  proc_name: str | None = None,
320
324
  call_id: str | None = None,
321
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
325
+ ) -> AsyncIterator[
326
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
327
+ ]:
322
328
  n_attempt = 0
323
329
  while n_attempt <= self.max_response_retries:
324
330
  try:
grasp_agents/llm.py CHANGED
@@ -7,6 +7,7 @@ from uuid import uuid4
7
7
  from pydantic import BaseModel
8
8
  from typing_extensions import TypedDict
9
9
 
10
+ from grasp_agents.typing.completion_chunk import CompletionChunk
10
11
  from grasp_agents.utils import (
11
12
  validate_obj_from_json_or_py_string,
12
13
  validate_tagged_objs_from_json_or_py_string,
@@ -19,13 +20,41 @@ from .errors import (
19
20
  )
20
21
  from .typing.completion import Completion
21
22
  from .typing.converters import Converters
22
- from .typing.events import CompletionChunkEvent, CompletionEvent, LLMStreamingErrorEvent
23
+ from .typing.events import (
24
+ AnnotationsChunkEvent,
25
+ AnnotationsEndEvent,
26
+ AnnotationsStartEvent,
27
+ CompletionChunkEvent,
28
+ CompletionEndEvent,
29
+ CompletionEvent,
30
+ CompletionStartEvent,
31
+ LLMStateChangeEvent,
32
+ LLMStreamingErrorEvent,
33
+ # RefusalChunkEvent,
34
+ ResponseChunkEvent,
35
+ ResponseEndEvent,
36
+ ResponseStartEvent,
37
+ ThinkingChunkEvent,
38
+ ThinkingEndEvent,
39
+ ThinkingStartEvent,
40
+ ToolCallChunkEvent,
41
+ ToolCallEndEvent,
42
+ ToolCallStartEvent,
43
+ )
23
44
  from .typing.message import Messages
24
45
  from .typing.tool import BaseTool, ToolChoice
25
46
 
26
47
  logger = logging.getLogger(__name__)
27
48
 
28
49
 
50
+ LLMStreamGenerator = AsyncIterator[
51
+ CompletionChunkEvent[CompletionChunk]
52
+ | CompletionEvent
53
+ | LLMStateChangeEvent[Any]
54
+ | LLMStreamingErrorEvent
55
+ ]
56
+
57
+
29
58
  class LLMSettings(TypedDict, total=False):
30
59
  max_completion_tokens: int | None
31
60
  temperature: float | None
@@ -156,6 +185,124 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
156
185
  tool_name, tool_arguments
157
186
  ) from exc
158
187
 
188
+ @staticmethod
189
+ async def postprocess_event_stream(
190
+ stream: LLMStreamGenerator,
191
+ ) -> LLMStreamGenerator:
192
+ prev_completion_id: str | None = None
193
+ chunk_op_evt: CompletionChunkEvent[CompletionChunk] | None = None
194
+ response_op_evt: ResponseChunkEvent | None = None
195
+ thinking_op_evt: ThinkingChunkEvent | None = None
196
+ annotations_op_evt: AnnotationsChunkEvent | None = None
197
+ tool_calls_op_evt: ToolCallChunkEvent | None = None
198
+
199
+ def _close_open_events() -> list[LLMStateChangeEvent[Any]]:
200
+ nonlocal \
201
+ chunk_op_evt, \
202
+ thinking_op_evt, \
203
+ tool_calls_op_evt, \
204
+ response_op_evt, \
205
+ annotations_op_evt
206
+
207
+ events: list[LLMStateChangeEvent[Any]] = []
208
+
209
+ if tool_calls_op_evt:
210
+ events.append(ToolCallEndEvent.from_chunk_event(tool_calls_op_evt))
211
+
212
+ if response_op_evt:
213
+ events.append(ResponseEndEvent.from_chunk_event(response_op_evt))
214
+
215
+ if thinking_op_evt:
216
+ events.append(ThinkingEndEvent.from_chunk_event(thinking_op_evt))
217
+
218
+ if annotations_op_evt:
219
+ events.append(AnnotationsEndEvent.from_chunk_event(annotations_op_evt))
220
+
221
+ if chunk_op_evt:
222
+ events.append(CompletionEndEvent.from_chunk_event(chunk_op_evt))
223
+
224
+ chunk_op_evt = None
225
+ thinking_op_evt = None
226
+ tool_calls_op_evt = None
227
+ response_op_evt = None
228
+ annotations_op_evt = None
229
+
230
+ return events
231
+
232
+ async for event in stream:
233
+ if isinstance(event, CompletionChunkEvent) and not isinstance(
234
+ event, LLMStateChangeEvent
235
+ ):
236
+ chunk = event.data
237
+ if len(chunk.choices) != 1:
238
+ raise ValueError(
239
+ "Expected exactly one choice in completion chunk, "
240
+ f"got {len(chunk.choices)}"
241
+ )
242
+
243
+ new_completion = chunk.id != prev_completion_id
244
+
245
+ if new_completion:
246
+ for close_event in _close_open_events():
247
+ yield close_event
248
+
249
+ chunk_op_evt = event
250
+ yield CompletionStartEvent.from_chunk_event(event)
251
+
252
+ sub_events = event.split_into_specialized()
253
+
254
+ for sub_event in sub_events:
255
+ if isinstance(sub_event, ThinkingChunkEvent):
256
+ if not thinking_op_evt:
257
+ thinking_op_evt = sub_event
258
+ yield ThinkingStartEvent.from_chunk_event(sub_event)
259
+ yield sub_event
260
+ elif thinking_op_evt:
261
+ yield ThinkingEndEvent.from_chunk_event(thinking_op_evt)
262
+ thinking_op_evt = None
263
+
264
+ if isinstance(sub_event, ToolCallChunkEvent):
265
+ tc = sub_event.data.tool_call
266
+ if tc.id:
267
+ # Tool call ID is not None only for the first chunk of a tool call
268
+ if tool_calls_op_evt:
269
+ yield ToolCallEndEvent.from_chunk_event(
270
+ tool_calls_op_evt
271
+ )
272
+ tool_calls_op_evt = None
273
+ tool_calls_op_evt = sub_event
274
+ yield ToolCallStartEvent.from_chunk_event(sub_event)
275
+ yield sub_event
276
+ elif tool_calls_op_evt:
277
+ yield ToolCallEndEvent.from_chunk_event(tool_calls_op_evt)
278
+ tool_calls_op_evt = None
279
+
280
+ if isinstance(sub_event, ResponseChunkEvent):
281
+ if not response_op_evt:
282
+ response_op_evt = sub_event
283
+ yield ResponseStartEvent.from_chunk_event(sub_event)
284
+ yield sub_event
285
+ elif response_op_evt:
286
+ yield ResponseEndEvent.from_chunk_event(response_op_evt)
287
+ response_op_evt = None
288
+
289
+ if isinstance(sub_event, AnnotationsChunkEvent):
290
+ if not annotations_op_evt:
291
+ annotations_op_evt = sub_event
292
+ yield AnnotationsStartEvent.from_chunk_event(sub_event)
293
+ yield sub_event
294
+ elif annotations_op_evt:
295
+ yield AnnotationsEndEvent.from_chunk_event(annotations_op_evt)
296
+ annotations_op_evt = None
297
+
298
+ prev_completion_id = chunk.id
299
+
300
+ else:
301
+ for close_event in _close_open_events():
302
+ yield close_event
303
+
304
+ yield event
305
+
159
306
  @abstractmethod
160
307
  async def generate_completion(
161
308
  self,
@@ -177,7 +324,9 @@ class LLM(ABC, Generic[SettingsT_co, ConvertT_co]):
177
324
  n_choices: int | None = None,
178
325
  proc_name: str | None = None,
179
326
  call_id: str | None = None,
180
- ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | LLMStreamingErrorEvent]:
327
+ ) -> AsyncIterator[
328
+ CompletionChunkEvent[CompletionChunk] | CompletionEvent | LLMStreamingErrorEvent
329
+ ]:
181
330
  pass
182
331
 
183
332
  @abstractmethod
grasp_agents/llm_agent.py CHANGED
@@ -11,7 +11,7 @@ from .llm_policy_executor import (
11
11
  MemoryManager,
12
12
  ToolCallLoopTerminator,
13
13
  )
14
- from .processor import Processor
14
+ from .processors.parallel_processor import ParallelProcessor
15
15
  from .prompt_builder import (
16
16
  InputContentBuilder,
17
17
  PromptBuilder,
@@ -46,7 +46,7 @@ class OutputParser(Protocol[_InT_contra, _OutT_co, CtxT]):
46
46
 
47
47
 
48
48
  class LLMAgent(
49
- Processor[InT, OutT, LLMAgentMemory, CtxT],
49
+ ParallelProcessor[InT, OutT, LLMAgentMemory, CtxT],
50
50
  Generic[InT, OutT, CtxT],
51
51
  ):
52
52
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
@@ -196,6 +196,20 @@ class LLMAgent(
196
196
 
197
197
  return system_message, input_message
198
198
 
199
+ def _parse_output_default(
200
+ self,
201
+ conversation: Messages,
202
+ *,
203
+ in_args: InT | None = None,
204
+ ctx: RunContext[CtxT] | None = None,
205
+ ) -> OutT:
206
+ return validate_obj_from_json_or_py_string(
207
+ str(conversation[-1].content or ""),
208
+ schema=self._out_type,
209
+ from_substring=False,
210
+ strip_language_markdown=True,
211
+ )
212
+
199
213
  def _parse_output(
200
214
  self,
201
215
  conversation: Messages,
@@ -208,11 +222,8 @@ class LLMAgent(
208
222
  conversation=conversation, in_args=in_args, ctx=ctx
209
223
  )
210
224
 
211
- return validate_obj_from_json_or_py_string(
212
- str(conversation[-1].content or ""),
213
- schema=self._out_type,
214
- from_substring=False,
215
- strip_language_markdown=True,
225
+ return self._parse_output_default(
226
+ conversation=conversation, in_args=in_args, ctx=ctx
216
227
  )
217
228
 
218
229
  async def _process(
@@ -7,6 +7,8 @@ from typing import Any, Generic, Protocol, final
7
7
 
8
8
  from pydantic import BaseModel
9
9
 
10
+ from grasp_agents.typing.completion_chunk import CompletionChunk
11
+
10
12
  from .errors import AgentFinalAnswerError
11
13
  from .llm import LLM, LLMSettings
12
14
  from .llm_agent_memory import LLMAgentMemory
@@ -149,19 +151,23 @@ class LLMPolicyExecutor(Generic[CtxT]):
149
151
  tool_choice: ToolChoice | None = None,
150
152
  ctx: RunContext[CtxT] | None = None,
151
153
  ) -> AsyncIterator[
152
- CompletionChunkEvent
154
+ CompletionChunkEvent[CompletionChunk]
153
155
  | CompletionEvent
154
156
  | GenMessageEvent
155
157
  | LLMStreamingErrorEvent
156
158
  ]:
157
159
  completion: Completion | None = None
158
- async for event in self.llm.generate_completion_stream( # type: ignore[no-untyped-call]
160
+
161
+ llm_event_stream = self.llm.generate_completion_stream(
159
162
  memory.message_history,
160
163
  tool_choice=tool_choice,
161
164
  n_choices=1,
162
165
  proc_name=self.agent_name,
163
166
  call_id=call_id,
164
- ):
167
+ )
168
+ llm_event_stream_post = self.llm.postprocess_event_stream(llm_event_stream) # type: ignore[assignment]
169
+
170
+ async for event in llm_event_stream_post:
165
171
  if isinstance(event, CompletionEvent):
166
172
  completion = event.data
167
173
  yield event
@@ -2,10 +2,9 @@ import asyncio
2
2
  import logging
3
3
  from collections.abc import AsyncIterator
4
4
  from types import TracebackType
5
- from typing import Any, Generic, Literal, Protocol, TypeVar
5
+ from typing import Any, Literal, Protocol, TypeVar
6
6
 
7
7
  from .packet import Packet
8
- from .run_context import CtxT, RunContext
9
8
  from .typing.events import Event
10
9
  from .typing.io import ProcName
11
10
 
@@ -18,24 +17,21 @@ END_PROC_NAME: Literal["*END*"] = "*END*"
18
17
  _PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
19
18
 
20
19
 
21
- class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
20
+ class PacketHandler(Protocol[_PayloadT_contra]):
22
21
  async def __call__(
23
- self,
24
- packet: Packet[_PayloadT_contra],
25
- ctx: RunContext[CtxT],
26
- **kwargs: Any,
22
+ self, packet: Packet[_PayloadT_contra], **kwargs: Any
27
23
  ) -> None: ...
28
24
 
29
25
 
30
- class PacketPool(Generic[CtxT]):
26
+ class PacketPool:
31
27
  def __init__(self) -> None:
32
28
  self._packet_queues: dict[ProcName, asyncio.Queue[Packet[Any] | None]] = {}
33
- self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
29
+ self._packet_handlers: dict[ProcName, PacketHandler[Any]] = {}
34
30
  self._task_group: asyncio.TaskGroup | None = None
35
31
 
36
32
  self._event_queue: asyncio.Queue[Event[Any] | None] = asyncio.Queue()
37
33
 
38
- self._final_result_fut: asyncio.Future[Packet[Any]] | None = None
34
+ self._final_result_fut: asyncio.Future[Packet[Any]]
39
35
 
40
36
  self._stopping = False
41
37
  self._stopped_evt = asyncio.Event()
@@ -44,9 +40,8 @@ class PacketPool(Generic[CtxT]):
44
40
 
45
41
  async def post(self, packet: Packet[Any]) -> None:
46
42
  if packet.recipients == [END_PROC_NAME]:
47
- fut = self._ensure_final_future()
48
- if not fut.done():
49
- fut.set_result(packet)
43
+ if not self._final_result_fut.done():
44
+ self._final_result_fut.set_result(packet)
50
45
  await self.shutdown()
51
46
  return
52
47
 
@@ -54,26 +49,14 @@ class PacketPool(Generic[CtxT]):
54
49
  queue = self._packet_queues.setdefault(recipient_id, asyncio.Queue())
55
50
  await queue.put(packet)
56
51
 
57
- def _ensure_final_future(self) -> asyncio.Future[Packet[Any]]:
58
- fut = self._final_result_fut
59
- if fut is None:
60
- fut = asyncio.get_running_loop().create_future()
61
- self._final_result_fut = fut
62
- return fut
63
-
64
52
  async def final_result(self) -> Packet[Any]:
65
- fut = self._ensure_final_future()
66
53
  try:
67
- return await fut
54
+ return await self._final_result_fut
68
55
  finally:
69
56
  await self.shutdown()
70
57
 
71
58
  def register_packet_handler(
72
- self,
73
- proc_name: ProcName,
74
- handler: PacketHandler[Any, CtxT],
75
- ctx: RunContext[CtxT],
76
- **run_kwargs: Any,
59
+ self, proc_name: ProcName, handler: PacketHandler[Any]
77
60
  ) -> None:
78
61
  if self._stopping:
79
62
  raise RuntimeError("PacketPool is stopping/stopped")
@@ -83,17 +66,19 @@ class PacketPool(Generic[CtxT]):
83
66
 
84
67
  if self._task_group is not None:
85
68
  self._task_group.create_task(
86
- self._handle_packets(proc_name, ctx=ctx, **run_kwargs),
69
+ self._handle_packets(proc_name),
87
70
  name=f"packet-handler:{proc_name}",
88
71
  )
89
72
 
90
73
  async def push_event(self, event: Event[Any]) -> None:
91
74
  await self._event_queue.put(event)
92
75
 
93
- async def __aenter__(self) -> "PacketPool[CtxT]":
76
+ async def __aenter__(self) -> "PacketPool":
94
77
  self._task_group = asyncio.TaskGroup()
95
78
  await self._task_group.__aenter__()
96
79
 
80
+ self._final_result_fut = asyncio.get_running_loop().create_future()
81
+
97
82
  return self
98
83
 
99
84
  async def __aexit__(
@@ -115,9 +100,7 @@ class PacketPool(Generic[CtxT]):
115
100
 
116
101
  return False
117
102
 
118
- async def _handle_packets(
119
- self, proc_name: ProcName, ctx: RunContext[CtxT], **run_kwargs: Any
120
- ) -> None:
103
+ async def _handle_packets(self, proc_name: ProcName) -> None:
121
104
  queue = self._packet_queues[proc_name]
122
105
  handler = self._packet_handlers[proc_name]
123
106
 
@@ -125,16 +108,19 @@ class PacketPool(Generic[CtxT]):
125
108
  packet = await queue.get()
126
109
  if packet is None:
127
110
  break
111
+
112
+ if self._final_result_fut.done():
113
+ continue
114
+
128
115
  try:
129
- await handler(packet, ctx=ctx, **run_kwargs)
116
+ await handler(packet)
130
117
  except asyncio.CancelledError:
131
118
  raise
132
119
  except Exception as err:
133
120
  logger.exception("Error handling packet for %s", proc_name)
134
121
  self._errors.append(err)
135
- fut = self._final_result_fut
136
- if fut and not fut.done():
137
- fut.set_exception(err)
122
+ if not self._final_result_fut.done():
123
+ self._final_result_fut.set_exception(err)
138
124
  await self.shutdown()
139
125
  raise
140
126
 
@@ -154,6 +140,5 @@ class PacketPool(Generic[CtxT]):
154
140
  await self._event_queue.put(None)
155
141
  for queue in self._packet_queues.values():
156
142
  await queue.put(None)
157
-
158
143
  finally:
159
144
  self._stopped_evt.set()