grasp_agents 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. grasp_agents/cloud_llm.py +191 -224
  2. grasp_agents/comm_processor.py +101 -100
  3. grasp_agents/errors.py +69 -9
  4. grasp_agents/litellm/__init__.py +106 -0
  5. grasp_agents/litellm/completion_chunk_converters.py +68 -0
  6. grasp_agents/litellm/completion_converters.py +72 -0
  7. grasp_agents/litellm/converters.py +138 -0
  8. grasp_agents/litellm/lite_llm.py +210 -0
  9. grasp_agents/litellm/message_converters.py +66 -0
  10. grasp_agents/llm.py +84 -49
  11. grasp_agents/llm_agent.py +136 -120
  12. grasp_agents/llm_agent_memory.py +3 -3
  13. grasp_agents/llm_policy_executor.py +167 -174
  14. grasp_agents/memory.py +23 -0
  15. grasp_agents/openai/__init__.py +24 -9
  16. grasp_agents/openai/completion_chunk_converters.py +6 -6
  17. grasp_agents/openai/completion_converters.py +12 -14
  18. grasp_agents/openai/content_converters.py +1 -3
  19. grasp_agents/openai/converters.py +6 -8
  20. grasp_agents/openai/message_converters.py +21 -3
  21. grasp_agents/openai/openai_llm.py +155 -103
  22. grasp_agents/openai/tool_converters.py +4 -6
  23. grasp_agents/packet.py +5 -2
  24. grasp_agents/packet_pool.py +14 -13
  25. grasp_agents/printer.py +233 -73
  26. grasp_agents/processor.py +229 -91
  27. grasp_agents/prompt_builder.py +2 -2
  28. grasp_agents/run_context.py +11 -20
  29. grasp_agents/runner.py +42 -0
  30. grasp_agents/typing/completion.py +16 -9
  31. grasp_agents/typing/completion_chunk.py +51 -22
  32. grasp_agents/typing/events.py +95 -19
  33. grasp_agents/typing/message.py +25 -1
  34. grasp_agents/typing/tool.py +2 -0
  35. grasp_agents/usage_tracker.py +31 -37
  36. grasp_agents/utils.py +95 -84
  37. grasp_agents/workflow/looped_workflow.py +60 -11
  38. grasp_agents/workflow/sequential_workflow.py +43 -11
  39. grasp_agents/workflow/workflow_processor.py +25 -24
  40. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/METADATA +7 -6
  41. grasp_agents-0.5.1.dist-info/RECORD +57 -0
  42. grasp_agents-0.4.7.dist-info/RECORD +0 -50
  43. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/WHEEL +0 -0
  44. {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py CHANGED
@@ -6,56 +6,42 @@ from typing import Any, ClassVar, Generic, cast, final
6
6
  from uuid import uuid4
7
7
 
8
8
  from pydantic import BaseModel, TypeAdapter
9
- from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
9
+ from pydantic import ValidationError as PydanticValidationError
10
10
 
11
- from .errors import InputValidationError
11
+ from .errors import ProcInputValidationError, ProcOutputValidationError
12
12
  from .generics_utils import AutoInstanceAttributesMixin
13
- from .memory import MemT
13
+ from .memory import DummyMemory, MemT
14
14
  from .packet import Packet
15
15
  from .run_context import CtxT, RunContext
16
- from .typing.events import Event, PacketEvent, ProcOutputEvent
16
+ from .typing.events import (
17
+ Event,
18
+ ProcPacketOutputEvent,
19
+ ProcPayloadOutputEvent,
20
+ ProcStreamingErrorData,
21
+ ProcStreamingErrorEvent,
22
+ )
17
23
  from .typing.io import InT, OutT_co, ProcName
18
24
  from .typing.tool import BaseTool
25
+ from .utils import stream_concurrent
19
26
 
20
27
  logger = logging.getLogger(__name__)
21
28
 
22
29
 
23
- def retry_error_callback(retry_state: RetryCallState) -> None:
24
- exception = retry_state.outcome.exception() if retry_state.outcome else None
25
- if exception:
26
- if retry_state.attempt_number == 1:
27
- logger.warning(f"\nParallel run failed:\n{exception}")
28
- if retry_state.attempt_number > 1:
29
- logger.warning(f"\nParallel run failed after retrying:\n{exception}")
30
-
31
-
32
- def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
33
- exception = retry_state.outcome.exception() if retry_state.outcome else None
34
- logger.info(
35
- f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
36
- )
37
-
38
-
39
30
  class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
40
31
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
41
32
  0: "_in_type",
42
33
  1: "_out_type",
43
34
  }
44
35
 
45
- def __init__(
46
- self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
47
- ) -> None:
36
+ def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
48
37
  self._in_type: type[InT]
49
38
  self._out_type: type[OutT_co]
50
39
 
51
40
  super().__init__()
52
41
 
53
- self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
54
- self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
55
-
56
42
  self._name: ProcName = name
57
- self._memory: MemT
58
- self._num_par_run_retries: int = num_par_run_retries
43
+ self._memory: MemT = cast("MemT", DummyMemory())
44
+ self._max_retries: int = max_retries
59
45
 
60
46
  @property
61
47
  def in_type(self) -> type[InT]:
@@ -74,8 +60,13 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
74
60
  return self._memory
75
61
 
76
62
  @property
77
- def num_par_run_retries(self) -> int:
78
- return self._num_par_run_retries
63
+ def max_retries(self) -> int:
64
+ return self._max_retries
65
+
66
+ def _generate_call_id(self, call_id: str | None) -> str:
67
+ if call_id is None:
68
+ return str(uuid4())[:6] + "_" + self.name
69
+ return call_id
79
70
 
80
71
  def _validate_and_resolve_single_input(
81
72
  self,
@@ -87,18 +78,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
87
78
  "Only one of chat_inputs, in_args, or in_message must be provided."
88
79
  )
89
80
  if chat_inputs is not None and in_args is not None:
90
- raise InputValidationError(multiple_inputs_err_message)
81
+ raise ProcInputValidationError(multiple_inputs_err_message)
91
82
  if chat_inputs is not None and in_packet is not None:
92
- raise InputValidationError(multiple_inputs_err_message)
83
+ raise ProcInputValidationError(multiple_inputs_err_message)
93
84
  if in_args is not None and in_packet is not None:
94
- raise InputValidationError(multiple_inputs_err_message)
85
+ raise ProcInputValidationError(multiple_inputs_err_message)
95
86
 
96
87
  if in_packet is not None:
97
88
  if len(in_packet.payloads) != 1:
98
- raise InputValidationError(
89
+ raise ProcInputValidationError(
99
90
  "Single input runs require exactly one payload in in_packet."
100
91
  )
101
92
  return in_packet.payloads[0]
93
+
102
94
  return in_args
103
95
 
104
96
  def _validate_and_resolve_parallel_inputs(
@@ -108,33 +100,44 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
108
100
  in_args: Sequence[InT] | None,
109
101
  ) -> Sequence[InT]:
110
102
  if chat_inputs is not None:
111
- raise InputValidationError(
103
+ raise ProcInputValidationError(
112
104
  "chat_inputs are not supported in parallel runs. "
113
105
  "Use in_packet or in_args."
114
106
  )
115
107
  if in_packet is not None:
116
108
  if not in_packet.payloads:
117
- raise InputValidationError(
109
+ raise ProcInputValidationError(
118
110
  "Parallel runs require at least one input payload in in_packet."
119
111
  )
120
112
  return in_packet.payloads
121
113
  if in_args is not None:
122
114
  return in_args
123
- raise InputValidationError(
115
+ raise ProcInputValidationError(
124
116
  "Parallel runs require either in_packet or in_args to be provided."
125
117
  )
126
118
 
119
+ def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
120
+ try:
121
+ return [
122
+ TypeAdapter(self._out_type).validate_python(payload)
123
+ for payload in out_payloads
124
+ ]
125
+ except PydanticValidationError as err:
126
+ raise ProcOutputValidationError(
127
+ f"Output validation failed for processor {self.name}:\n{err}"
128
+ ) from err
129
+
127
130
  async def _process(
128
131
  self,
129
132
  chat_inputs: Any | None = None,
130
133
  *,
131
134
  in_args: InT | None = None,
132
135
  memory: MemT,
133
- run_id: str,
136
+ call_id: str,
134
137
  ctx: RunContext[CtxT] | None = None,
135
138
  ) -> Sequence[OutT_co]:
136
139
  if in_args is None:
137
- raise InputValidationError(
140
+ raise ProcInputValidationError(
138
141
  "Default implementation of _process requires in_args"
139
142
  )
140
143
 
@@ -146,54 +149,75 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
146
149
  *,
147
150
  in_args: InT | None = None,
148
151
  memory: MemT,
149
- run_id: str,
152
+ call_id: str,
150
153
  ctx: RunContext[CtxT] | None = None,
151
154
  ) -> AsyncIterator[Event[Any]]:
152
155
  if in_args is None:
153
- raise InputValidationError(
154
- "Default implementation of _process requires in_args"
156
+ raise ProcInputValidationError(
157
+ "Default implementation of _process_stream requires in_args"
155
158
  )
156
159
  outputs = cast("Sequence[OutT_co]", in_args)
157
160
  for out in outputs:
158
- yield ProcOutputEvent(data=out, name=self.name)
159
-
160
- def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
161
- return [
162
- self._out_type_adapter.validate_python(payload) for payload in out_payloads
163
- ]
164
-
165
- def _generate_run_id(self, run_id: str | None) -> str:
166
- if run_id is None:
167
- return str(uuid4())[:6] + "_" + self.name
168
- return run_id
161
+ yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
169
162
 
170
- async def _run_single(
163
+ async def _run_single_once(
171
164
  self,
172
165
  chat_inputs: Any | None = None,
173
166
  *,
174
167
  in_packet: Packet[InT] | None = None,
175
168
  in_args: InT | None = None,
176
169
  forgetful: bool = False,
177
- run_id: str | None = None,
170
+ call_id: str,
178
171
  ctx: RunContext[CtxT] | None = None,
179
172
  ) -> Packet[OutT_co]:
180
173
  resolved_in_args = self._validate_and_resolve_single_input(
181
174
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
182
175
  )
183
176
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
177
+
184
178
  outputs = await self._process(
185
179
  chat_inputs=chat_inputs,
186
180
  in_args=resolved_in_args,
187
181
  memory=_memory,
188
- run_id=self._generate_run_id(run_id),
182
+ call_id=call_id,
189
183
  ctx=ctx,
190
184
  )
191
185
  val_outputs = self._validate_outputs(outputs)
192
186
 
193
187
  return Packet(payloads=val_outputs, sender=self.name)
194
188
 
195
- def _generate_par_run_id(self, run_id: str | None, idx: int) -> str:
196
- return f"{self._generate_run_id(run_id)}/{idx}"
189
+ async def _run_single(
190
+ self,
191
+ chat_inputs: Any | None = None,
192
+ *,
193
+ in_packet: Packet[InT] | None = None,
194
+ in_args: InT | None = None,
195
+ forgetful: bool = False,
196
+ call_id: str,
197
+ ctx: RunContext[CtxT] | None = None,
198
+ ) -> Packet[OutT_co] | None:
199
+ n_attempt = 0
200
+ while n_attempt <= self.max_retries:
201
+ try:
202
+ return await self._run_single_once(
203
+ chat_inputs=chat_inputs,
204
+ in_packet=in_packet,
205
+ in_args=in_args,
206
+ forgetful=forgetful,
207
+ call_id=call_id,
208
+ ctx=ctx,
209
+ )
210
+ except Exception as err:
211
+ n_attempt += 1
212
+ if n_attempt > self.max_retries:
213
+ if n_attempt == 1:
214
+ logger.warning(f"\nProcessor run failed:\n{err}")
215
+ if n_attempt > 1:
216
+ logger.warning(f"\nProcessor run failed after retrying:\n{err}")
217
+ return None
218
+ logger.warning(
219
+ f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
220
+ )
197
221
 
198
222
  async def _run_par(
199
223
  self,
@@ -201,27 +225,15 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
201
225
  *,
202
226
  in_packet: Packet[InT] | None = None,
203
227
  in_args: Sequence[InT] | None = None,
204
- run_id: str | None = None,
205
- forgetful: bool = False,
228
+ call_id: str,
206
229
  ctx: RunContext[CtxT] | None = None,
207
230
  ) -> Packet[OutT_co]:
208
231
  par_inputs = self._validate_and_resolve_parallel_inputs(
209
232
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
210
233
  )
211
-
212
- wrapped_func = retry(
213
- wait=wait_random_exponential(min=1, max=8),
214
- stop=stop_after_attempt(self._num_par_run_retries + 1),
215
- before_sleep=retry_before_sleep_callback,
216
- retry_error_callback=retry_error_callback,
217
- )(self._run_single)
218
-
219
234
  tasks = [
220
- wrapped_func(
221
- in_args=inp,
222
- forgetful=True,
223
- run_id=self._generate_par_run_id(run_id, idx),
224
- ctx=ctx,
235
+ self._run_single(
236
+ in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
225
237
  )
226
238
  for idx, inp in enumerate(par_inputs)
227
239
  ]
@@ -242,62 +254,188 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
242
254
  in_packet: Packet[InT] | None = None,
243
255
  in_args: InT | Sequence[InT] | None = None,
244
256
  forgetful: bool = False,
245
- run_id: str | None = None,
257
+ call_id: str | None = None,
246
258
  ctx: RunContext[CtxT] | None = None,
247
259
  ) -> Packet[OutT_co]:
260
+ call_id = self._generate_call_id(call_id)
261
+
248
262
  if (in_args is not None and isinstance(in_args, Sequence)) or (
249
263
  in_packet is not None and len(in_packet.payloads) > 1
250
264
  ):
251
265
  return await self._run_par(
252
266
  chat_inputs=chat_inputs,
253
267
  in_packet=in_packet,
254
- in_args=cast("Sequence[InT]", in_args),
255
- run_id=run_id,
256
- forgetful=forgetful,
268
+ in_args=cast("Sequence[InT] | None", in_args),
269
+ call_id=call_id,
257
270
  ctx=ctx,
258
271
  )
259
- return await self._run_single(
272
+ return await self._run_single( # type: ignore[return]
260
273
  chat_inputs=chat_inputs,
261
274
  in_packet=in_packet,
262
- in_args=in_args,
275
+ in_args=cast("InT | None", in_args),
263
276
  forgetful=forgetful,
264
- run_id=run_id,
277
+ call_id=call_id,
265
278
  ctx=ctx,
266
279
  )
267
280
 
268
- async def run_stream(
281
+ async def _run_single_stream_once(
269
282
  self,
270
283
  chat_inputs: Any | None = None,
271
284
  *,
272
285
  in_packet: Packet[InT] | None = None,
273
286
  in_args: InT | None = None,
274
287
  forgetful: bool = False,
275
- run_id: str | None = None,
288
+ call_id: str,
276
289
  ctx: RunContext[CtxT] | None = None,
277
290
  ) -> AsyncIterator[Event[Any]]:
278
291
  resolved_in_args = self._validate_and_resolve_single_input(
279
292
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
280
293
  )
281
-
282
294
  _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
283
295
 
284
- outputs: Sequence[OutT_co] = []
285
- async for output_event in self._process_stream(
296
+ outputs: list[OutT_co] = []
297
+ async for event in self._process_stream(
286
298
  chat_inputs=chat_inputs,
287
299
  in_args=resolved_in_args,
288
300
  memory=_memory,
289
- run_id=self._generate_run_id(run_id),
301
+ call_id=call_id,
290
302
  ctx=ctx,
291
303
  ):
292
- if isinstance(output_event, ProcOutputEvent):
293
- outputs.append(output_event.data)
294
- else:
295
- yield output_event
304
+ if isinstance(event, ProcPayloadOutputEvent):
305
+ outputs.append(event.data)
306
+ yield event
296
307
 
297
308
  val_outputs = self._validate_outputs(outputs)
298
309
  out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
299
310
 
300
- yield PacketEvent(data=out_packet, name=self.name)
311
+ yield ProcPacketOutputEvent(
312
+ data=out_packet, proc_name=self.name, call_id=call_id
313
+ )
314
+
315
+ async def _run_single_stream(
316
+ self,
317
+ chat_inputs: Any | None = None,
318
+ *,
319
+ in_packet: Packet[InT] | None = None,
320
+ in_args: InT | None = None,
321
+ forgetful: bool = False,
322
+ call_id: str,
323
+ ctx: RunContext[CtxT] | None = None,
324
+ ) -> AsyncIterator[Event[Any]]:
325
+ n_attempt = 0
326
+ while n_attempt <= self.max_retries:
327
+ try:
328
+ async for event in self._run_single_stream_once(
329
+ chat_inputs=chat_inputs,
330
+ in_packet=in_packet,
331
+ in_args=in_args,
332
+ forgetful=forgetful,
333
+ call_id=call_id,
334
+ ctx=ctx,
335
+ ):
336
+ yield event
337
+
338
+ return
339
+
340
+ except Exception as err:
341
+ err_data = ProcStreamingErrorData(error=err, call_id=call_id)
342
+ yield ProcStreamingErrorEvent(
343
+ data=err_data, proc_name=self.name, call_id=call_id
344
+ )
345
+
346
+ n_attempt += 1
347
+ if n_attempt > self.max_retries:
348
+ if n_attempt == 1:
349
+ logger.warning(f"\nStreaming processor run failed:\n{err}")
350
+ if n_attempt > 1:
351
+ logger.warning(
352
+ f"\nStreaming processor run failed after retrying:\n{err}"
353
+ )
354
+ return
355
+
356
+ logger.warning(
357
+ "\nStreaming processor run failed "
358
+ f"(retry attempt {n_attempt}):\n{err}"
359
+ )
360
+
361
+ async def _run_par_stream(
362
+ self,
363
+ chat_inputs: Any | None = None,
364
+ *,
365
+ in_packet: Packet[InT] | None = None,
366
+ in_args: Sequence[InT] | None = None,
367
+ call_id: str,
368
+ ctx: RunContext[CtxT] | None = None,
369
+ ) -> AsyncIterator[Event[Any]]:
370
+ par_inputs = self._validate_and_resolve_parallel_inputs(
371
+ chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
372
+ )
373
+ streams = [
374
+ self._run_single_stream(
375
+ in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
376
+ )
377
+ for idx, inp in enumerate(par_inputs)
378
+ ]
379
+
380
+ out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
381
+ range(len(streams)), None
382
+ )
383
+
384
+ async for idx, event in stream_concurrent(streams):
385
+ if isinstance(event, ProcPacketOutputEvent):
386
+ out_packets_map[idx] = event.data
387
+ else:
388
+ yield event
389
+
390
+ out_packet = Packet( # type: ignore[return]
391
+ payloads=[
392
+ (out_packet.payloads[0] if out_packet else None)
393
+ for out_packet in out_packets_map.values()
394
+ ],
395
+ sender=self.name,
396
+ )
397
+
398
+ yield ProcPacketOutputEvent(
399
+ data=out_packet, proc_name=self.name, call_id=call_id
400
+ )
401
+
402
+ async def run_stream(
403
+ self,
404
+ chat_inputs: Any | None = None,
405
+ *,
406
+ in_packet: Packet[InT] | None = None,
407
+ in_args: InT | Sequence[InT] | None = None,
408
+ forgetful: bool = False,
409
+ call_id: str | None = None,
410
+ ctx: RunContext[CtxT] | None = None,
411
+ ) -> AsyncIterator[Event[Any]]:
412
+ call_id = self._generate_call_id(call_id)
413
+
414
+ # yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
415
+
416
+ if (in_args is not None and isinstance(in_args, Sequence)) or (
417
+ in_packet is not None and len(in_packet.payloads) > 1
418
+ ):
419
+ stream = self._run_par_stream(
420
+ chat_inputs=chat_inputs,
421
+ in_packet=in_packet,
422
+ in_args=cast("Sequence[InT] | None", in_args),
423
+ call_id=call_id,
424
+ ctx=ctx,
425
+ )
426
+ else:
427
+ stream = self._run_single_stream(
428
+ chat_inputs=chat_inputs,
429
+ in_packet=in_packet,
430
+ in_args=cast("InT | None", in_args),
431
+ forgetful=forgetful,
432
+ call_id=call_id,
433
+ ctx=ctx,
434
+ )
435
+ async for event in stream:
436
+ yield event
437
+
438
+ # yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
301
439
 
302
440
  @final
303
441
  def as_tool(
@@ -20,7 +20,7 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
20
20
  sys_args: LLMPromptArgs | None,
21
21
  *,
22
22
  ctx: RunContext[CtxT] | None,
23
- ) -> str: ...
23
+ ) -> str | None: ...
24
24
 
25
25
 
26
26
  class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
@@ -110,7 +110,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
110
110
 
111
111
  return Content.from_text(json.dumps(combined_args, indent=2))
112
112
 
113
- def make_user_message(
113
+ def make_input_message(
114
114
  self,
115
115
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
116
116
  in_args: InT | None = None,
@@ -1,8 +1,7 @@
1
1
  from collections import defaultdict
2
- from collections.abc import Mapping
3
2
  from typing import Any, Generic, TypeVar
4
3
 
5
- from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
4
+ from pydantic import BaseModel, ConfigDict, Field
6
5
 
7
6
  from grasp_agents.typing.completion import Completion
8
7
 
@@ -25,29 +24,21 @@ class RunContext(BaseModel, Generic[CtxT]):
25
24
  state: CtxT | None = None
26
25
 
27
26
  run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
28
- completions: Mapping[ProcName, list[Completion]] = Field(
27
+
28
+ is_streaming: bool = False
29
+ result: Any | None = None
30
+
31
+ completions: dict[ProcName, list[Completion]] = Field(
29
32
  default_factory=lambda: defaultdict(list)
30
33
  )
34
+ usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
31
35
 
36
+ printer: Printer | None = None
32
37
  print_messages: bool = False
33
38
  color_messages_by: ColoringMode = "role"
34
39
 
35
- _usage_tracker: UsageTracker = PrivateAttr()
36
- _printer: Printer = PrivateAttr()
37
-
38
40
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
39
- self._usage_tracker = UsageTracker()
40
- self._printer = Printer(
41
- print_messages=self.print_messages,
42
- color_by=self.color_messages_by,
43
- )
41
+ if self.print_messages:
42
+ self.printer = Printer(color_by=self.color_messages_by)
44
43
 
45
- @property
46
- def usage_tracker(self) -> UsageTracker:
47
- return self._usage_tracker
48
-
49
- @property
50
- def printer(self) -> Printer:
51
- return self._printer
52
-
53
- model_config = ConfigDict(extra="forbid")
44
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
grasp_agents/runner.py ADDED
@@ -0,0 +1,42 @@
1
+ from collections.abc import AsyncIterator, Sequence
2
+ from typing import Any, Generic
3
+
4
+ from .comm_processor import CommProcessor
5
+ from .run_context import CtxT, RunContext
6
+ from .typing.events import Event
7
+
8
+
9
+ class Runner(Generic[CtxT]):
10
+ def __init__(
11
+ self,
12
+ start_proc: CommProcessor[Any, Any, Any, CtxT],
13
+ procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
14
+ ctx: RunContext[CtxT] | None = None,
15
+ ) -> None:
16
+ if start_proc not in procs:
17
+ raise ValueError(
18
+ f"Start processor {start_proc.name} must be in the list of processors: "
19
+ f"{', '.join(proc.name for proc in procs)}"
20
+ )
21
+ self._start_proc = start_proc
22
+ self._procs = procs
23
+ self._ctx = ctx or RunContext[CtxT]()
24
+
25
+ @property
26
+ def ctx(self) -> RunContext[CtxT]:
27
+ return self._ctx
28
+
29
+ async def run(self, **run_args: Any) -> Any:
30
+ self._ctx.is_streaming = False
31
+ for proc in self._procs:
32
+ proc.start_listening(ctx=self._ctx, **run_args)
33
+ await self._start_proc.run(**run_args, ctx=self._ctx)
34
+
35
+ return self._ctx.result
36
+
37
+ async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
38
+ self._ctx.is_streaming = True
39
+ for proc in self._procs:
40
+ proc.start_listening(ctx=self._ctx, **run_args)
41
+ async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
42
+ yield event
@@ -1,8 +1,9 @@
1
1
  import time
2
- from typing import Literal, TypeAlias
2
+ from typing import Any, Literal, TypeAlias
3
3
  from uuid import uuid4
4
4
 
5
- from openai.types.chat.chat_completion import ChoiceLogprobs as CompletionChoiceLogprobs
5
+ from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
6
+ from openai.types.chat.chat_completion import ChoiceLogprobs as OpenAIChoiceLogprobs
6
7
  from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
7
8
 
8
9
  from .message import AssistantMessage
@@ -22,6 +23,7 @@ class Usage(BaseModel):
22
23
  def __add__(self, add_usage: "Usage") -> "Usage":
23
24
  input_tokens = self.input_tokens + add_usage.input_tokens
24
25
  output_tokens = self.output_tokens + add_usage.output_tokens
26
+
25
27
  if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
26
28
  reasoning_tokens = (self.reasoning_tokens or 0) + (
27
29
  add_usage.reasoning_tokens or 0
@@ -34,11 +36,11 @@ class Usage(BaseModel):
34
36
  else:
35
37
  cached_tokens = None
36
38
 
37
- cost = (
38
- (self.cost or 0.0) + add_usage.cost
39
- if (add_usage.cost is not None)
40
- else None
41
- )
39
+ if self.cost is not None or add_usage.cost is not None:
40
+ cost = (self.cost or 0.0) + (add_usage.cost or 0.0)
41
+ else:
42
+ cost = None
43
+
42
44
  return Usage(
43
45
  input_tokens=input_tokens,
44
46
  output_tokens=output_tokens,
@@ -52,7 +54,9 @@ class CompletionChoice(BaseModel):
52
54
  message: AssistantMessage
53
55
  finish_reason: FinishReason | None
54
56
  index: int
55
- logprobs: CompletionChoiceLogprobs | None = None
57
+ logprobs: OpenAIChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
58
+ # LiteLLM-specific fields
59
+ provider_specific_fields: dict[str, Any] | None = None
56
60
 
57
61
 
58
62
  class CompletionError(BaseModel):
@@ -64,12 +68,15 @@ class CompletionError(BaseModel):
64
68
  class Completion(BaseModel):
65
69
  id: str = Field(default_factory=lambda: str(uuid4())[:8])
66
70
  created: int = Field(default_factory=lambda: int(time.time()))
67
- model: str
71
+ model: str | None
68
72
  name: str | None = None
69
73
  system_fingerprint: str | None = None
70
74
  choices: list[CompletionChoice]
71
75
  usage: Usage | None = None
72
76
  error: CompletionError | None = None
77
+ # LiteLLM-specific fields
78
+ response_ms: float | None = None
79
+ hidden_params: dict[str, Any] | None = None
73
80
 
74
81
  @property
75
82
  def messages(self) -> list[AssistantMessage]: