grasp_agents 0.4.7__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +4 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +234 -72
  26. grasp_agents/processor.py +228 -88
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.0.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py CHANGED
@@ -6,56 +6,44 @@ from typing import Any, ClassVar, Generic, cast, final
6
6
  from uuid import uuid4
7
7
 
8
8
  from pydantic import BaseModel, TypeAdapter
9
- from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
9
+ from pydantic import ValidationError as PydanticValidationError
10
10
 
11
- from .errors import InputValidationError
11
+ from .errors import ProcInputValidationError, ProcOutputValidationError
12
12
  from .generics_utils import AutoInstanceAttributesMixin
13
13
  from .memory import MemT
14
14
  from .packet import Packet
15
15
  from .run_context import CtxT, RunContext
16
- from .typing.events import Event, PacketEvent, ProcOutputEvent
16
+ from .typing.events import (
17
+ Event,
18
+ # ProcFinishEvent,
19
+ ProcPacketOutputEvent,
20
+ ProcPayloadOutputEvent,
21
+ # ProcStartEvent,
22
+ ProcStreamingErrorData,
23
+ ProcStreamingErrorEvent,
24
+ )
17
25
  from .typing.io import InT, OutT_co, ProcName
18
26
  from .typing.tool import BaseTool
27
+ from .utils import stream_concurrent
19
28
 
20
29
  logger = logging.getLogger(__name__)
21
30
 
22
31
 
23
- def retry_error_callback(retry_state: RetryCallState) -> None:
24
- exception = retry_state.outcome.exception() if retry_state.outcome else None
25
- if exception:
26
- if retry_state.attempt_number == 1:
27
- logger.warning(f"\nParallel run failed:\n{exception}")
28
- if retry_state.attempt_number > 1:
29
- logger.warning(f"\nParallel run failed after retrying:\n{exception}")
30
-
31
-
32
- def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
33
- exception = retry_state.outcome.exception() if retry_state.outcome else None
34
- logger.info(
35
- f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
36
- )
37
-
38
-
39
32
  class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
40
33
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
41
34
  0: "_in_type",
42
35
  1: "_out_type",
43
36
  }
44
37
 
45
- def __init__(
46
- self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
47
- ) -> None:
38
+ def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
48
39
  self._in_type: type[InT]
49
40
  self._out_type: type[OutT_co]
50
41
 
51
42
  super().__init__()
52
43
 
53
- self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
54
- self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
55
-
56
44
  self._name: ProcName = name
57
45
  self._memory: MemT
58
- self._num_par_run_retries: int = num_par_run_retries
46
+ self._max_retries: int = max_retries
59
47
 
60
48
  @property
61
49
  def in_type(self) -> type[InT]:
@@ -74,8 +62,13 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
74
62
  return self._memory
75
63
 
76
64
  @property
77
- def num_par_run_retries(self) -> int:
78
- return self._num_par_run_retries
65
+ def max_retries(self) -> int:
66
+ return self._max_retries
67
+
68
+ def _generate_call_id(self, call_id: str | None) -> str:
69
+ if call_id is None:
70
+ return str(uuid4())[:6] + "_" + self.name
71
+ return call_id
79
72
 
80
73
  def _validate_and_resolve_single_input(
81
74
  self,
@@ -87,18 +80,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
87
80
  "Only one of chat_inputs, in_args, or in_message must be provided."
88
81
  )
89
82
  if chat_inputs is not None and in_args is not None:
90
- raise InputValidationError(multiple_inputs_err_message)
83
+ raise ProcInputValidationError(multiple_inputs_err_message)
91
84
  if chat_inputs is not None and in_packet is not None:
92
- raise InputValidationError(multiple_inputs_err_message)
85
+ raise ProcInputValidationError(multiple_inputs_err_message)
93
86
  if in_args is not None and in_packet is not None:
94
- raise InputValidationError(multiple_inputs_err_message)
87
+ raise ProcInputValidationError(multiple_inputs_err_message)
95
88
 
96
89
  if in_packet is not None:
97
90
  if len(in_packet.payloads) != 1:
98
- raise InputValidationError(
91
+ raise ProcInputValidationError(
99
92
  "Single input runs require exactly one payload in in_packet."
100
93
  )
101
94
  return in_packet.payloads[0]
95
+
102
96
  return in_args
103
97
 
104
98
  def _validate_and_resolve_parallel_inputs(
@@ -108,33 +102,44 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
108
102
  in_args: Sequence[InT] | None,
109
103
  ) -> Sequence[InT]:
110
104
  if chat_inputs is not None:
111
- raise InputValidationError(
105
+ raise ProcInputValidationError(
112
106
  "chat_inputs are not supported in parallel runs. "
113
107
  "Use in_packet or in_args."
114
108
  )
115
109
  if in_packet is not None:
116
110
  if not in_packet.payloads:
117
- raise InputValidationError(
111
+ raise ProcInputValidationError(
118
112
  "Parallel runs require at least one input payload in in_packet."
119
113
  )
120
114
  return in_packet.payloads
121
115
  if in_args is not None:
122
116
  return in_args
123
- raise InputValidationError(
117
+ raise ProcInputValidationError(
124
118
  "Parallel runs require either in_packet or in_args to be provided."
125
119
  )
126
120
 
121
+ def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
122
+ try:
123
+ return [
124
+ TypeAdapter(self._out_type).validate_python(payload)
125
+ for payload in out_payloads
126
+ ]
127
+ except PydanticValidationError as err:
128
+ raise ProcOutputValidationError(
129
+ f"Output validation failed for processor {self.name}:\n{err}"
130
+ ) from err
131
+
127
132
  async def _process(
128
133
  self,
129
134
  chat_inputs: Any | None = None,
130
135
  *,
131
136
  in_args: InT | None = None,
132
137
  memory: MemT,
133
- run_id: str,
138
+ call_id: str,
134
139
  ctx: RunContext[CtxT] | None = None,
135
140
  ) -> Sequence[OutT_co]:
136
141
  if in_args is None:
137
- raise InputValidationError(
142
+ raise ProcInputValidationError(
138
143
  "Default implementation of _process requires in_args"
139
144
  )
140
145
 
@@ -146,35 +151,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
146
151
  *,
147
152
  in_args: InT | None = None,
148
153
  memory: MemT,
149
- run_id: str,
154
+ call_id: str,
150
155
  ctx: RunContext[CtxT] | None = None,
151
156
  ) -> AsyncIterator[Event[Any]]:
152
157
  if in_args is None:
153
- raise InputValidationError(
154
- "Default implementation of _process requires in_args"
158
+ raise ProcInputValidationError(
159
+ "Default implementation of _process_stream requires in_args"
155
160
  )
156
161
  outputs = cast("Sequence[OutT_co]", in_args)
157
162
  for out in outputs:
158
- yield ProcOutputEvent(data=out, name=self.name)
163
+ yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
159
164
 
160
- def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
161
- return [
162
- self._out_type_adapter.validate_python(payload) for payload in out_payloads
163
- ]
164
-
165
- def _generate_run_id(self, run_id: str | None) -> str:
166
- if run_id is None:
167
- return str(uuid4())[:6] + "_" + self.name
168
- return run_id
169
-
170
- async def _run_single(
165
+ async def _run_single_once(
171
166
  self,
172
167
  chat_inputs: Any | None = None,
173
168
  *,
174
169
  in_packet: Packet[InT] | None = None,
175
170
  in_args: InT | None = None,
176
171
  forgetful: bool = False,
177
- run_id: str | None = None,
172
+ call_id: str,
178
173
  ctx: RunContext[CtxT] | None = None,
179
174
  ) -> Packet[OutT_co]:
180
175
  resolved_in_args = self._validate_and_resolve_single_input(
@@ -185,15 +180,45 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
185
180
  chat_inputs=chat_inputs,
186
181
  in_args=resolved_in_args,
187
182
  memory=_memory,
188
- run_id=self._generate_run_id(run_id),
183
+ call_id=call_id,
189
184
  ctx=ctx,
190
185
  )
191
186
  val_outputs = self._validate_outputs(outputs)
192
187
 
193
188
  return Packet(payloads=val_outputs, sender=self.name)
194
189
 
195
- def _generate_par_run_id(self, run_id: str | None, idx: int) -> str:
196
- return f"{self._generate_run_id(run_id)}/{idx}"
190
+ async def _run_single(
191
+ self,
192
+ chat_inputs: Any | None = None,
193
+ *,
194
+ in_packet: Packet[InT] | None = None,
195
+ in_args: InT | None = None,
196
+ forgetful: bool = False,
197
+ call_id: str,
198
+ ctx: RunContext[CtxT] | None = None,
199
+ ) -> Packet[OutT_co] | None:
200
+ n_attempt = 0
201
+ while n_attempt <= self.max_retries:
202
+ try:
203
+ return await self._run_single_once(
204
+ chat_inputs=chat_inputs,
205
+ in_packet=in_packet,
206
+ in_args=in_args,
207
+ forgetful=forgetful,
208
+ call_id=call_id,
209
+ ctx=ctx,
210
+ )
211
+ except Exception as err:
212
+ n_attempt += 1
213
+ if n_attempt > self.max_retries:
214
+ if n_attempt == 1:
215
+ logger.warning(f"\nProcessor run failed:\n{err}")
216
+ if n_attempt > 1:
217
+ logger.warning(f"\nProcessor run failed after retrying:\n{err}")
218
+ return None
219
+ logger.warning(
220
+ f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
221
+ )
197
222
 
198
223
  async def _run_par(
199
224
  self,
@@ -201,27 +226,15 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
201
226
  *,
202
227
  in_packet: Packet[InT] | None = None,
203
228
  in_args: Sequence[InT] | None = None,
204
- run_id: str | None = None,
205
- forgetful: bool = False,
229
+ call_id: str,
206
230
  ctx: RunContext[CtxT] | None = None,
207
231
  ) -> Packet[OutT_co]:
208
232
  par_inputs = self._validate_and_resolve_parallel_inputs(
209
233
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
210
234
  )
211
-
212
- wrapped_func = retry(
213
- wait=wait_random_exponential(min=1, max=8),
214
- stop=stop_after_attempt(self._num_par_run_retries + 1),
215
- before_sleep=retry_before_sleep_callback,
216
- retry_error_callback=retry_error_callback,
217
- )(self._run_single)
218
-
219
235
  tasks = [
220
- wrapped_func(
221
- in_args=inp,
222
- forgetful=True,
223
- run_id=self._generate_par_run_id(run_id, idx),
224
- ctx=ctx,
236
+ self._run_single(
237
+ in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
225
238
  )
226
239
  for idx, inp in enumerate(par_inputs)
227
240
  ]
@@ -242,37 +255,38 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
242
255
  in_packet: Packet[InT] | None = None,
243
256
  in_args: InT | Sequence[InT] | None = None,
244
257
  forgetful: bool = False,
245
- run_id: str | None = None,
258
+ call_id: str | None = None,
246
259
  ctx: RunContext[CtxT] | None = None,
247
260
  ) -> Packet[OutT_co]:
261
+ call_id = self._generate_call_id(call_id)
262
+
248
263
  if (in_args is not None and isinstance(in_args, Sequence)) or (
249
264
  in_packet is not None and len(in_packet.payloads) > 1
250
265
  ):
251
266
  return await self._run_par(
252
267
  chat_inputs=chat_inputs,
253
268
  in_packet=in_packet,
254
- in_args=cast("Sequence[InT]", in_args),
255
- run_id=run_id,
256
- forgetful=forgetful,
269
+ in_args=cast("Sequence[InT] | None", in_args),
270
+ call_id=call_id,
257
271
  ctx=ctx,
258
272
  )
259
- return await self._run_single(
273
+ return await self._run_single( # type: ignore[return]
260
274
  chat_inputs=chat_inputs,
261
275
  in_packet=in_packet,
262
- in_args=in_args,
276
+ in_args=cast("InT | None", in_args),
263
277
  forgetful=forgetful,
264
- run_id=run_id,
278
+ call_id=call_id,
265
279
  ctx=ctx,
266
280
  )
267
281
 
268
- async def run_stream(
282
+ async def _run_single_stream_once(
269
283
  self,
270
284
  chat_inputs: Any | None = None,
271
285
  *,
272
286
  in_packet: Packet[InT] | None = None,
273
287
  in_args: InT | None = None,
274
288
  forgetful: bool = False,
275
- run_id: str | None = None,
289
+ call_id: str,
276
290
  ctx: RunContext[CtxT] | None = None,
277
291
  ) -> AsyncIterator[Event[Any]]:
278
292
  resolved_in_args = self._validate_and_resolve_single_input(
@@ -281,23 +295,149 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
281
295
 
282
296
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
283
297
 
284
- outputs: Sequence[OutT_co] = []
285
- async for output_event in self._process_stream(
298
+ outputs: list[OutT_co] = []
299
+ async for event in self._process_stream(
286
300
  chat_inputs=chat_inputs,
287
301
  in_args=resolved_in_args,
288
302
  memory=_memory,
289
- run_id=self._generate_run_id(run_id),
303
+ call_id=call_id,
290
304
  ctx=ctx,
291
305
  ):
292
- if isinstance(output_event, ProcOutputEvent):
293
- outputs.append(output_event.data)
294
- else:
295
- yield output_event
306
+ if isinstance(event, ProcPayloadOutputEvent):
307
+ outputs.append(event.data)
308
+ yield event
296
309
 
297
310
  val_outputs = self._validate_outputs(outputs)
298
311
  out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
299
312
 
300
- yield PacketEvent(data=out_packet, name=self.name)
313
+ yield ProcPacketOutputEvent(
314
+ data=out_packet, proc_name=self.name, call_id=call_id
315
+ )
316
+
317
+ async def _run_single_stream(
318
+ self,
319
+ chat_inputs: Any | None = None,
320
+ *,
321
+ in_packet: Packet[InT] | None = None,
322
+ in_args: InT | None = None,
323
+ forgetful: bool = False,
324
+ call_id: str,
325
+ ctx: RunContext[CtxT] | None = None,
326
+ ) -> AsyncIterator[Event[Any]]:
327
+ n_attempt = 0
328
+ while n_attempt <= self.max_retries:
329
+ try:
330
+ async for event in self._run_single_stream_once(
331
+ chat_inputs=chat_inputs,
332
+ in_packet=in_packet,
333
+ in_args=in_args,
334
+ forgetful=forgetful,
335
+ call_id=call_id,
336
+ ctx=ctx,
337
+ ):
338
+ yield event
339
+
340
+ return
341
+
342
+ except Exception as err:
343
+ err_data = ProcStreamingErrorData(error=err, call_id=call_id)
344
+ yield ProcStreamingErrorEvent(
345
+ data=err_data, proc_name=self.name, call_id=call_id
346
+ )
347
+
348
+ n_attempt += 1
349
+ if n_attempt > self.max_retries:
350
+ if n_attempt == 1:
351
+ logger.warning(f"\nStreaming processor run failed:\n{err}")
352
+ if n_attempt > 1:
353
+ logger.warning(
354
+ f"\nStreaming processor run failed after retrying:\n{err}"
355
+ )
356
+ return
357
+
358
+ logger.warning(
359
+ "\nStreaming processor run failed "
360
+ f"(retry attempt {n_attempt}):\n{err}"
361
+ )
362
+
363
+ async def _run_par_stream(
364
+ self,
365
+ chat_inputs: Any | None = None,
366
+ *,
367
+ in_packet: Packet[InT] | None = None,
368
+ in_args: Sequence[InT] | None = None,
369
+ call_id: str,
370
+ ctx: RunContext[CtxT] | None = None,
371
+ ) -> AsyncIterator[Event[Any]]:
372
+ par_inputs = self._validate_and_resolve_parallel_inputs(
373
+ chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
374
+ )
375
+ streams = [
376
+ self._run_single_stream(
377
+ in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
378
+ )
379
+ for idx, inp in enumerate(par_inputs)
380
+ ]
381
+
382
+ out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
383
+ range(len(streams)), None
384
+ )
385
+
386
+ async for idx, event in stream_concurrent(streams):
387
+ if isinstance(event, ProcPacketOutputEvent):
388
+ out_packets_map[idx] = event.data
389
+ else:
390
+ yield event
391
+
392
+ out_packet = Packet( # type: ignore[return]
393
+ payloads=[
394
+ (out_packet.payloads[0] if out_packet else None)
395
+ for out_packet in out_packets_map.values()
396
+ ],
397
+ sender=self.name,
398
+ )
399
+
400
+ yield ProcPacketOutputEvent(
401
+ data=out_packet, proc_name=self.name, call_id=call_id
402
+ )
403
+
404
+ async def run_stream(
405
+ self,
406
+ chat_inputs: Any | None = None,
407
+ *,
408
+ in_packet: Packet[InT] | None = None,
409
+ in_args: InT | Sequence[InT] | None = None,
410
+ forgetful: bool = False,
411
+ call_id: str | None = None,
412
+ ctx: RunContext[CtxT] | None = None,
413
+ ) -> AsyncIterator[Event[Any]]:
414
+ call_id = self._generate_call_id(call_id)
415
+
416
+ # yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
417
+
418
+ if (in_args is not None and isinstance(in_args, Sequence)) or (
419
+ in_packet is not None and len(in_packet.payloads) > 1
420
+ ):
421
+ stream = self._run_par_stream(
422
+ chat_inputs=chat_inputs,
423
+ in_packet=in_packet,
424
+ in_args=cast("Sequence[InT] | None", in_args),
425
+ call_id=call_id,
426
+ ctx=ctx,
427
+ )
428
+ else:
429
+ stream = self._run_single_stream(
430
+ chat_inputs=chat_inputs,
431
+ in_packet=in_packet,
432
+ in_args=cast("InT | None", in_args),
433
+ forgetful=forgetful,
434
+ call_id=call_id,
435
+ ctx=ctx,
436
+ )
437
+ async for event in stream:
438
+ yield event
439
+
440
+ # yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
301
441
 
302
442
  @final
303
443
  def as_tool(
@@ -20,7 +20,7 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
20
20
  sys_args: LLMPromptArgs | None,
21
21
  *,
22
22
  ctx: RunContext[CtxT] | None,
23
- ) -> str: ...
23
+ ) -> str | None: ...
24
24
 
25
25
 
26
26
  class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
@@ -110,7 +110,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
110
110
 
111
111
  return Content.from_text(json.dumps(combined_args, indent=2))
112
112
 
113
- def make_user_message(
113
+ def make_input_message(
114
114
  self,
115
115
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
116
116
  in_args: InT | None = None,
@@ -1,8 +1,7 @@
1
1
  from collections import defaultdict
2
- from collections.abc import Mapping
3
2
  from typing import Any, Generic, TypeVar
4
3
 
5
- from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
4
+ from pydantic import BaseModel, ConfigDict, Field
6
5
 
7
6
  from grasp_agents.typing.completion import Completion
8
7
 
@@ -25,29 +24,21 @@ class RunContext(BaseModel, Generic[CtxT]):
25
24
  state: CtxT | None = None
26
25
 
27
26
  run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
28
- completions: Mapping[ProcName, list[Completion]] = Field(
27
+
28
+ is_streaming: bool = False
29
+ result: Any | None = None
30
+
31
+ completions: dict[ProcName, list[Completion]] = Field(
29
32
  default_factory=lambda: defaultdict(list)
30
33
  )
34
+ usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
31
35
 
36
+ printer: Printer | None = None
32
37
  print_messages: bool = False
33
38
  color_messages_by: ColoringMode = "role"
34
39
 
35
- _usage_tracker: UsageTracker = PrivateAttr()
36
- _printer: Printer = PrivateAttr()
37
-
38
40
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
39
- self._usage_tracker = UsageTracker()
40
- self._printer = Printer(
41
- print_messages=self.print_messages,
42
- color_by=self.color_messages_by,
43
- )
41
+ if self.print_messages:
42
+ self.printer = Printer(color_by=self.color_messages_by)
44
43
 
45
- @property
46
- def usage_tracker(self) -> UsageTracker:
47
- return self._usage_tracker
48
-
49
- @property
50
- def printer(self) -> Printer:
51
- return self._printer
52
-
53
- model_config = ConfigDict(extra="forbid")
44
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
grasp_agents/runner.py ADDED
@@ -0,0 +1,42 @@
1
+ from collections.abc import AsyncIterator, Sequence
2
+ from typing import Any, Generic
3
+
4
+ from .comm_processor import CommProcessor
5
+ from .run_context import CtxT, RunContext
6
+ from .typing.events import Event
7
+
8
+
9
+ class Runner(Generic[CtxT]):
10
+ def __init__(
11
+ self,
12
+ start_proc: CommProcessor[Any, Any, Any, CtxT],
13
+ procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
14
+ ctx: RunContext[CtxT] | None = None,
15
+ ) -> None:
16
+ if start_proc not in procs:
17
+ raise ValueError(
18
+ f"Start processor {start_proc.name} must be in the list of processors: "
19
+ f"{', '.join(proc.name for proc in procs)}"
20
+ )
21
+ self._start_proc = start_proc
22
+ self._procs = procs
23
+ self._ctx = ctx or RunContext[CtxT]()
24
+
25
+ @property
26
+ def ctx(self) -> RunContext[CtxT]:
27
+ return self._ctx
28
+
29
+ async def run(self, **run_args: Any) -> Any:
30
+ self._ctx.is_streaming = False
31
+ for proc in self._procs:
32
+ proc.start_listening(ctx=self._ctx, **run_args)
33
+ await self._start_proc.run(**run_args, ctx=self._ctx)
34
+
35
+ return self._ctx.result
36
+
37
+ async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
38
+ self._ctx.is_streaming = True
39
+ for proc in self._procs:
40
+ proc.start_listening(ctx=self._ctx, **run_args)
41
+ async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
42
+ yield event
@@ -1,8 +1,9 @@
1
1
  import time
2
- from typing import Literal, TypeAlias
2
+ from typing import Any, Literal, TypeAlias
3
3
  from uuid import uuid4
4
4
 
5
- from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
5
+ from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
6
+ from openai.types.chat.chat_completion import ChoiceLogprobs as OpenAIChoiceLogprobs
6
7
  from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
7
8
 
8
9
  from .message import AssistantMessage
@@ -22,6 +23,7 @@ class Usage(BaseModel):
22
23
  def __add__(self, add_usage: "Usage") -> "Usage":
23
24
  input_tokens = self.input_tokens + add_usage.input_tokens
24
25
  output_tokens = self.output_tokens + add_usage.output_tokens
26
+
25
27
  if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
26
28
  reasoning_tokens = (self.reasoning_tokens or 0) + (
27
29
  add_usage.reasoning_tokens or 0
@@ -34,11 +36,11 @@ class Usage(BaseModel):
34
36
  else:
35
37
  cached_tokens = None
36
38
 
37
- cost = (
38
- (self.cost or 0.0) + add_usage.cost
39
- if (add_usage.cost is not None)
40
- else None
41
- )
39
+ if self.cost is not None or add_usage.cost is not None:
40
+ cost = (self.cost or 0.0) + (add_usage.cost or 0.0)
41
+ else:
42
+ cost = None
43
+
42
44
  return Usage(
43
45
  input_tokens=input_tokens,
44
46
  output_tokens=output_tokens,
@@ -52,7 +54,9 @@ class CompletionChoice(BaseModel):
52
54
  message: AssistantMessage
53
55
  finish_reason: FinishReason | None
54
56
  index: int
55
- logprobs: CompletionChoiceLogprobs | None = None
57
+ logprobs: OpenAIChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
58
+ # LiteLLM-specific fields
59
+ provider_specific_fields: dict[str, Any] | None = None
56
60
 
57
61
 
58
62
  class CompletionError(BaseModel):
@@ -64,12 +68,15 @@ class CompletionError(BaseModel):
64
68
  class Completion(BaseModel):
65
69
  id: str = Field(default_factory=lambda: str(uuid4())[:8])
66
70
  created: int = Field(default_factory=lambda: int(time.time()))
67
- model: str
71
+ model: str | None
68
72
  name: str | None = None
69
73
  system_fingerprint: str | None = None
70
74
  choices: list[CompletionChoice]
71
75
  usage: Usage | None = None
72
76
  error: CompletionError | None = None
77
+ # LiteLLM-specific fields
78
+ response_ms: float | None = None
79
+ hidden_params: dict[str, Any] | None = None
73
80
 
74
81
  @property
75
82
  def messages(self) -> list[AssistantMessage]: