grasp_agents 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -218
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +4 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +234 -72
- grasp_agents/processor.py +228 -88
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/METADATA +7 -6
- grasp_agents-0.5.0.dist-info/RECORD +57 -0
- grasp_agents-0.4.6.dist-info/RECORD +0 -50
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.6.dist-info → grasp_agents-0.5.0.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py
CHANGED
@@ -6,56 +6,44 @@ from typing import Any, ClassVar, Generic, cast, final
|
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
8
|
from pydantic import BaseModel, TypeAdapter
|
9
|
-
from
|
9
|
+
from pydantic import ValidationError as PydanticValidationError
|
10
10
|
|
11
|
-
from .errors import
|
11
|
+
from .errors import ProcInputValidationError, ProcOutputValidationError
|
12
12
|
from .generics_utils import AutoInstanceAttributesMixin
|
13
13
|
from .memory import MemT
|
14
14
|
from .packet import Packet
|
15
15
|
from .run_context import CtxT, RunContext
|
16
|
-
from .typing.events import
|
16
|
+
from .typing.events import (
|
17
|
+
Event,
|
18
|
+
# ProcFinishEvent,
|
19
|
+
ProcPacketOutputEvent,
|
20
|
+
ProcPayloadOutputEvent,
|
21
|
+
# ProcStartEvent,
|
22
|
+
ProcStreamingErrorData,
|
23
|
+
ProcStreamingErrorEvent,
|
24
|
+
)
|
17
25
|
from .typing.io import InT, OutT_co, ProcName
|
18
26
|
from .typing.tool import BaseTool
|
27
|
+
from .utils import stream_concurrent
|
19
28
|
|
20
29
|
logger = logging.getLogger(__name__)
|
21
30
|
|
22
31
|
|
23
|
-
def retry_error_callback(retry_state: RetryCallState) -> None:
|
24
|
-
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
25
|
-
if exception:
|
26
|
-
if retry_state.attempt_number == 1:
|
27
|
-
logger.warning(f"\nParallel run failed:\n{exception}")
|
28
|
-
if retry_state.attempt_number > 1:
|
29
|
-
logger.warning(f"\nParallel run failed after retrying:\n{exception}")
|
30
|
-
|
31
|
-
|
32
|
-
def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
|
33
|
-
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
34
|
-
logger.info(
|
35
|
-
f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
32
|
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
|
40
33
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
41
34
|
0: "_in_type",
|
42
35
|
1: "_out_type",
|
43
36
|
}
|
44
37
|
|
45
|
-
def __init__(
|
46
|
-
self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
|
47
|
-
) -> None:
|
38
|
+
def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
|
48
39
|
self._in_type: type[InT]
|
49
40
|
self._out_type: type[OutT_co]
|
50
41
|
|
51
42
|
super().__init__()
|
52
43
|
|
53
|
-
self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
54
|
-
self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
|
55
|
-
|
56
44
|
self._name: ProcName = name
|
57
45
|
self._memory: MemT
|
58
|
-
self.
|
46
|
+
self._max_retries: int = max_retries
|
59
47
|
|
60
48
|
@property
|
61
49
|
def in_type(self) -> type[InT]:
|
@@ -74,8 +62,13 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
74
62
|
return self._memory
|
75
63
|
|
76
64
|
@property
|
77
|
-
def
|
78
|
-
return self.
|
65
|
+
def max_retries(self) -> int:
|
66
|
+
return self._max_retries
|
67
|
+
|
68
|
+
def _generate_call_id(self, call_id: str | None) -> str:
|
69
|
+
if call_id is None:
|
70
|
+
return str(uuid4())[:6] + "_" + self.name
|
71
|
+
return call_id
|
79
72
|
|
80
73
|
def _validate_and_resolve_single_input(
|
81
74
|
self,
|
@@ -87,18 +80,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
87
80
|
"Only one of chat_inputs, in_args, or in_message must be provided."
|
88
81
|
)
|
89
82
|
if chat_inputs is not None and in_args is not None:
|
90
|
-
raise
|
83
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
91
84
|
if chat_inputs is not None and in_packet is not None:
|
92
|
-
raise
|
85
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
93
86
|
if in_args is not None and in_packet is not None:
|
94
|
-
raise
|
87
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
95
88
|
|
96
89
|
if in_packet is not None:
|
97
90
|
if len(in_packet.payloads) != 1:
|
98
|
-
raise
|
91
|
+
raise ProcInputValidationError(
|
99
92
|
"Single input runs require exactly one payload in in_packet."
|
100
93
|
)
|
101
94
|
return in_packet.payloads[0]
|
95
|
+
|
102
96
|
return in_args
|
103
97
|
|
104
98
|
def _validate_and_resolve_parallel_inputs(
|
@@ -108,33 +102,44 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
108
102
|
in_args: Sequence[InT] | None,
|
109
103
|
) -> Sequence[InT]:
|
110
104
|
if chat_inputs is not None:
|
111
|
-
raise
|
105
|
+
raise ProcInputValidationError(
|
112
106
|
"chat_inputs are not supported in parallel runs. "
|
113
107
|
"Use in_packet or in_args."
|
114
108
|
)
|
115
109
|
if in_packet is not None:
|
116
110
|
if not in_packet.payloads:
|
117
|
-
raise
|
111
|
+
raise ProcInputValidationError(
|
118
112
|
"Parallel runs require at least one input payload in in_packet."
|
119
113
|
)
|
120
114
|
return in_packet.payloads
|
121
115
|
if in_args is not None:
|
122
116
|
return in_args
|
123
|
-
raise
|
117
|
+
raise ProcInputValidationError(
|
124
118
|
"Parallel runs require either in_packet or in_args to be provided."
|
125
119
|
)
|
126
120
|
|
121
|
+
def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
|
122
|
+
try:
|
123
|
+
return [
|
124
|
+
TypeAdapter(self._out_type).validate_python(payload)
|
125
|
+
for payload in out_payloads
|
126
|
+
]
|
127
|
+
except PydanticValidationError as err:
|
128
|
+
raise ProcOutputValidationError(
|
129
|
+
f"Output validation failed for processor {self.name}:\n{err}"
|
130
|
+
) from err
|
131
|
+
|
127
132
|
async def _process(
|
128
133
|
self,
|
129
134
|
chat_inputs: Any | None = None,
|
130
135
|
*,
|
131
136
|
in_args: InT | None = None,
|
132
137
|
memory: MemT,
|
133
|
-
|
138
|
+
call_id: str,
|
134
139
|
ctx: RunContext[CtxT] | None = None,
|
135
140
|
) -> Sequence[OutT_co]:
|
136
141
|
if in_args is None:
|
137
|
-
raise
|
142
|
+
raise ProcInputValidationError(
|
138
143
|
"Default implementation of _process requires in_args"
|
139
144
|
)
|
140
145
|
|
@@ -146,35 +151,25 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
146
151
|
*,
|
147
152
|
in_args: InT | None = None,
|
148
153
|
memory: MemT,
|
149
|
-
|
154
|
+
call_id: str,
|
150
155
|
ctx: RunContext[CtxT] | None = None,
|
151
156
|
) -> AsyncIterator[Event[Any]]:
|
152
157
|
if in_args is None:
|
153
|
-
raise
|
154
|
-
"Default implementation of
|
158
|
+
raise ProcInputValidationError(
|
159
|
+
"Default implementation of _process_stream requires in_args"
|
155
160
|
)
|
156
161
|
outputs = cast("Sequence[OutT_co]", in_args)
|
157
162
|
for out in outputs:
|
158
|
-
yield
|
163
|
+
yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
|
159
164
|
|
160
|
-
def
|
161
|
-
return [
|
162
|
-
self._out_type_adapter.validate_python(payload) for payload in out_payloads
|
163
|
-
]
|
164
|
-
|
165
|
-
def _generate_run_id(self, run_id: str | None) -> str:
|
166
|
-
if run_id is None:
|
167
|
-
return str(uuid4())[:6] + "_" + self.name
|
168
|
-
return run_id
|
169
|
-
|
170
|
-
async def _run_single(
|
165
|
+
async def _run_single_once(
|
171
166
|
self,
|
172
167
|
chat_inputs: Any | None = None,
|
173
168
|
*,
|
174
169
|
in_packet: Packet[InT] | None = None,
|
175
170
|
in_args: InT | None = None,
|
176
171
|
forgetful: bool = False,
|
177
|
-
|
172
|
+
call_id: str,
|
178
173
|
ctx: RunContext[CtxT] | None = None,
|
179
174
|
) -> Packet[OutT_co]:
|
180
175
|
resolved_in_args = self._validate_and_resolve_single_input(
|
@@ -185,15 +180,45 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
185
180
|
chat_inputs=chat_inputs,
|
186
181
|
in_args=resolved_in_args,
|
187
182
|
memory=_memory,
|
188
|
-
|
183
|
+
call_id=call_id,
|
189
184
|
ctx=ctx,
|
190
185
|
)
|
191
186
|
val_outputs = self._validate_outputs(outputs)
|
192
187
|
|
193
188
|
return Packet(payloads=val_outputs, sender=self.name)
|
194
189
|
|
195
|
-
def
|
196
|
-
|
190
|
+
async def _run_single(
|
191
|
+
self,
|
192
|
+
chat_inputs: Any | None = None,
|
193
|
+
*,
|
194
|
+
in_packet: Packet[InT] | None = None,
|
195
|
+
in_args: InT | None = None,
|
196
|
+
forgetful: bool = False,
|
197
|
+
call_id: str,
|
198
|
+
ctx: RunContext[CtxT] | None = None,
|
199
|
+
) -> Packet[OutT_co] | None:
|
200
|
+
n_attempt = 0
|
201
|
+
while n_attempt <= self.max_retries:
|
202
|
+
try:
|
203
|
+
return await self._run_single_once(
|
204
|
+
chat_inputs=chat_inputs,
|
205
|
+
in_packet=in_packet,
|
206
|
+
in_args=in_args,
|
207
|
+
forgetful=forgetful,
|
208
|
+
call_id=call_id,
|
209
|
+
ctx=ctx,
|
210
|
+
)
|
211
|
+
except Exception as err:
|
212
|
+
n_attempt += 1
|
213
|
+
if n_attempt > self.max_retries:
|
214
|
+
if n_attempt == 1:
|
215
|
+
logger.warning(f"\nProcessor run failed:\n{err}")
|
216
|
+
if n_attempt > 1:
|
217
|
+
logger.warning(f"\nProcessor run failed after retrying:\n{err}")
|
218
|
+
return None
|
219
|
+
logger.warning(
|
220
|
+
f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
|
221
|
+
)
|
197
222
|
|
198
223
|
async def _run_par(
|
199
224
|
self,
|
@@ -201,27 +226,15 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
201
226
|
*,
|
202
227
|
in_packet: Packet[InT] | None = None,
|
203
228
|
in_args: Sequence[InT] | None = None,
|
204
|
-
|
205
|
-
forgetful: bool = False,
|
229
|
+
call_id: str,
|
206
230
|
ctx: RunContext[CtxT] | None = None,
|
207
231
|
) -> Packet[OutT_co]:
|
208
232
|
par_inputs = self._validate_and_resolve_parallel_inputs(
|
209
233
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
210
234
|
)
|
211
|
-
|
212
|
-
wrapped_func = retry(
|
213
|
-
wait=wait_random_exponential(min=1, max=8),
|
214
|
-
stop=stop_after_attempt(self._num_par_run_retries + 1),
|
215
|
-
before_sleep=retry_before_sleep_callback,
|
216
|
-
retry_error_callback=retry_error_callback,
|
217
|
-
)(self._run_single)
|
218
|
-
|
219
235
|
tasks = [
|
220
|
-
|
221
|
-
in_args=inp,
|
222
|
-
forgetful=True,
|
223
|
-
run_id=self._generate_par_run_id(run_id, idx),
|
224
|
-
ctx=ctx,
|
236
|
+
self._run_single(
|
237
|
+
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
225
238
|
)
|
226
239
|
for idx, inp in enumerate(par_inputs)
|
227
240
|
]
|
@@ -242,37 +255,38 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
242
255
|
in_packet: Packet[InT] | None = None,
|
243
256
|
in_args: InT | Sequence[InT] | None = None,
|
244
257
|
forgetful: bool = False,
|
245
|
-
|
258
|
+
call_id: str | None = None,
|
246
259
|
ctx: RunContext[CtxT] | None = None,
|
247
260
|
) -> Packet[OutT_co]:
|
261
|
+
call_id = self._generate_call_id(call_id)
|
262
|
+
|
248
263
|
if (in_args is not None and isinstance(in_args, Sequence)) or (
|
249
264
|
in_packet is not None and len(in_packet.payloads) > 1
|
250
265
|
):
|
251
266
|
return await self._run_par(
|
252
267
|
chat_inputs=chat_inputs,
|
253
268
|
in_packet=in_packet,
|
254
|
-
in_args=cast("Sequence[InT]", in_args),
|
255
|
-
|
256
|
-
forgetful=forgetful,
|
269
|
+
in_args=cast("Sequence[InT] | None", in_args),
|
270
|
+
call_id=call_id,
|
257
271
|
ctx=ctx,
|
258
272
|
)
|
259
|
-
return await self._run_single(
|
273
|
+
return await self._run_single( # type: ignore[return]
|
260
274
|
chat_inputs=chat_inputs,
|
261
275
|
in_packet=in_packet,
|
262
|
-
in_args=in_args,
|
276
|
+
in_args=cast("InT | None", in_args),
|
263
277
|
forgetful=forgetful,
|
264
|
-
|
278
|
+
call_id=call_id,
|
265
279
|
ctx=ctx,
|
266
280
|
)
|
267
281
|
|
268
|
-
async def
|
282
|
+
async def _run_single_stream_once(
|
269
283
|
self,
|
270
284
|
chat_inputs: Any | None = None,
|
271
285
|
*,
|
272
286
|
in_packet: Packet[InT] | None = None,
|
273
287
|
in_args: InT | None = None,
|
274
288
|
forgetful: bool = False,
|
275
|
-
|
289
|
+
call_id: str,
|
276
290
|
ctx: RunContext[CtxT] | None = None,
|
277
291
|
) -> AsyncIterator[Event[Any]]:
|
278
292
|
resolved_in_args = self._validate_and_resolve_single_input(
|
@@ -281,23 +295,149 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
281
295
|
|
282
296
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
283
297
|
|
284
|
-
outputs:
|
285
|
-
async for
|
298
|
+
outputs: list[OutT_co] = []
|
299
|
+
async for event in self._process_stream(
|
286
300
|
chat_inputs=chat_inputs,
|
287
301
|
in_args=resolved_in_args,
|
288
302
|
memory=_memory,
|
289
|
-
|
303
|
+
call_id=call_id,
|
290
304
|
ctx=ctx,
|
291
305
|
):
|
292
|
-
if isinstance(
|
293
|
-
outputs.append(
|
294
|
-
|
295
|
-
yield output_event
|
306
|
+
if isinstance(event, ProcPayloadOutputEvent):
|
307
|
+
outputs.append(event.data)
|
308
|
+
yield event
|
296
309
|
|
297
310
|
val_outputs = self._validate_outputs(outputs)
|
298
311
|
out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
|
299
312
|
|
300
|
-
yield
|
313
|
+
yield ProcPacketOutputEvent(
|
314
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
315
|
+
)
|
316
|
+
|
317
|
+
async def _run_single_stream(
|
318
|
+
self,
|
319
|
+
chat_inputs: Any | None = None,
|
320
|
+
*,
|
321
|
+
in_packet: Packet[InT] | None = None,
|
322
|
+
in_args: InT | None = None,
|
323
|
+
forgetful: bool = False,
|
324
|
+
call_id: str,
|
325
|
+
ctx: RunContext[CtxT] | None = None,
|
326
|
+
) -> AsyncIterator[Event[Any]]:
|
327
|
+
n_attempt = 0
|
328
|
+
while n_attempt <= self.max_retries:
|
329
|
+
try:
|
330
|
+
async for event in self._run_single_stream_once(
|
331
|
+
chat_inputs=chat_inputs,
|
332
|
+
in_packet=in_packet,
|
333
|
+
in_args=in_args,
|
334
|
+
forgetful=forgetful,
|
335
|
+
call_id=call_id,
|
336
|
+
ctx=ctx,
|
337
|
+
):
|
338
|
+
yield event
|
339
|
+
|
340
|
+
return
|
341
|
+
|
342
|
+
except Exception as err:
|
343
|
+
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
344
|
+
yield ProcStreamingErrorEvent(
|
345
|
+
data=err_data, proc_name=self.name, call_id=call_id
|
346
|
+
)
|
347
|
+
|
348
|
+
n_attempt += 1
|
349
|
+
if n_attempt > self.max_retries:
|
350
|
+
if n_attempt == 1:
|
351
|
+
logger.warning(f"\nStreaming processor run failed:\n{err}")
|
352
|
+
if n_attempt > 1:
|
353
|
+
logger.warning(
|
354
|
+
f"\nStreaming processor run failed after retrying:\n{err}"
|
355
|
+
)
|
356
|
+
return
|
357
|
+
|
358
|
+
logger.warning(
|
359
|
+
"\nStreaming processor run failed "
|
360
|
+
f"(retry attempt {n_attempt}):\n{err}"
|
361
|
+
)
|
362
|
+
|
363
|
+
async def _run_par_stream(
|
364
|
+
self,
|
365
|
+
chat_inputs: Any | None = None,
|
366
|
+
*,
|
367
|
+
in_packet: Packet[InT] | None = None,
|
368
|
+
in_args: Sequence[InT] | None = None,
|
369
|
+
call_id: str,
|
370
|
+
ctx: RunContext[CtxT] | None = None,
|
371
|
+
) -> AsyncIterator[Event[Any]]:
|
372
|
+
par_inputs = self._validate_and_resolve_parallel_inputs(
|
373
|
+
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
374
|
+
)
|
375
|
+
streams = [
|
376
|
+
self._run_single_stream(
|
377
|
+
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
378
|
+
)
|
379
|
+
for idx, inp in enumerate(par_inputs)
|
380
|
+
]
|
381
|
+
|
382
|
+
out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
|
383
|
+
range(len(streams)), None
|
384
|
+
)
|
385
|
+
|
386
|
+
async for idx, event in stream_concurrent(streams):
|
387
|
+
if isinstance(event, ProcPacketOutputEvent):
|
388
|
+
out_packets_map[idx] = event.data
|
389
|
+
else:
|
390
|
+
yield event
|
391
|
+
|
392
|
+
out_packet = Packet( # type: ignore[return]
|
393
|
+
payloads=[
|
394
|
+
(out_packet.payloads[0] if out_packet else None)
|
395
|
+
for out_packet in out_packets_map.values()
|
396
|
+
],
|
397
|
+
sender=self.name,
|
398
|
+
)
|
399
|
+
|
400
|
+
yield ProcPacketOutputEvent(
|
401
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
402
|
+
)
|
403
|
+
|
404
|
+
async def run_stream(
|
405
|
+
self,
|
406
|
+
chat_inputs: Any | None = None,
|
407
|
+
*,
|
408
|
+
in_packet: Packet[InT] | None = None,
|
409
|
+
in_args: InT | Sequence[InT] | None = None,
|
410
|
+
forgetful: bool = False,
|
411
|
+
call_id: str | None = None,
|
412
|
+
ctx: RunContext[CtxT] | None = None,
|
413
|
+
) -> AsyncIterator[Event[Any]]:
|
414
|
+
call_id = self._generate_call_id(call_id)
|
415
|
+
|
416
|
+
# yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
|
417
|
+
|
418
|
+
if (in_args is not None and isinstance(in_args, Sequence)) or (
|
419
|
+
in_packet is not None and len(in_packet.payloads) > 1
|
420
|
+
):
|
421
|
+
stream = self._run_par_stream(
|
422
|
+
chat_inputs=chat_inputs,
|
423
|
+
in_packet=in_packet,
|
424
|
+
in_args=cast("Sequence[InT] | None", in_args),
|
425
|
+
call_id=call_id,
|
426
|
+
ctx=ctx,
|
427
|
+
)
|
428
|
+
else:
|
429
|
+
stream = self._run_single_stream(
|
430
|
+
chat_inputs=chat_inputs,
|
431
|
+
in_packet=in_packet,
|
432
|
+
in_args=cast("InT | None", in_args),
|
433
|
+
forgetful=forgetful,
|
434
|
+
call_id=call_id,
|
435
|
+
ctx=ctx,
|
436
|
+
)
|
437
|
+
async for event in stream:
|
438
|
+
yield event
|
439
|
+
|
440
|
+
# yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
|
301
441
|
|
302
442
|
@final
|
303
443
|
def as_tool(
|
grasp_agents/prompt_builder.py
CHANGED
@@ -20,7 +20,7 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
|
|
20
20
|
sys_args: LLMPromptArgs | None,
|
21
21
|
*,
|
22
22
|
ctx: RunContext[CtxT] | None,
|
23
|
-
) -> str: ...
|
23
|
+
) -> str | None: ...
|
24
24
|
|
25
25
|
|
26
26
|
class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
|
@@ -110,7 +110,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
110
110
|
|
111
111
|
return Content.from_text(json.dumps(combined_args, indent=2))
|
112
112
|
|
113
|
-
def
|
113
|
+
def make_input_message(
|
114
114
|
self,
|
115
115
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
116
116
|
in_args: InT | None = None,
|
grasp_agents/run_context.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
-
from collections.abc import Mapping
|
3
2
|
from typing import Any, Generic, TypeVar
|
4
3
|
|
5
|
-
from pydantic import BaseModel, ConfigDict, Field
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field
|
6
5
|
|
7
6
|
from grasp_agents.typing.completion import Completion
|
8
7
|
|
@@ -25,29 +24,21 @@ class RunContext(BaseModel, Generic[CtxT]):
|
|
25
24
|
state: CtxT | None = None
|
26
25
|
|
27
26
|
run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
|
28
|
-
|
27
|
+
|
28
|
+
is_streaming: bool = False
|
29
|
+
result: Any | None = None
|
30
|
+
|
31
|
+
completions: dict[ProcName, list[Completion]] = Field(
|
29
32
|
default_factory=lambda: defaultdict(list)
|
30
33
|
)
|
34
|
+
usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
|
31
35
|
|
36
|
+
printer: Printer | None = None
|
32
37
|
print_messages: bool = False
|
33
38
|
color_messages_by: ColoringMode = "role"
|
34
39
|
|
35
|
-
_usage_tracker: UsageTracker = PrivateAttr()
|
36
|
-
_printer: Printer = PrivateAttr()
|
37
|
-
|
38
40
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
39
|
-
self.
|
40
|
-
|
41
|
-
print_messages=self.print_messages,
|
42
|
-
color_by=self.color_messages_by,
|
43
|
-
)
|
41
|
+
if self.print_messages:
|
42
|
+
self.printer = Printer(color_by=self.color_messages_by)
|
44
43
|
|
45
|
-
|
46
|
-
def usage_tracker(self) -> UsageTracker:
|
47
|
-
return self._usage_tracker
|
48
|
-
|
49
|
-
@property
|
50
|
-
def printer(self) -> Printer:
|
51
|
-
return self._printer
|
52
|
-
|
53
|
-
model_config = ConfigDict(extra="forbid")
|
44
|
+
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
grasp_agents/runner.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
from collections.abc import AsyncIterator, Sequence
|
2
|
+
from typing import Any, Generic
|
3
|
+
|
4
|
+
from .comm_processor import CommProcessor
|
5
|
+
from .run_context import CtxT, RunContext
|
6
|
+
from .typing.events import Event
|
7
|
+
|
8
|
+
|
9
|
+
class Runner(Generic[CtxT]):
|
10
|
+
def __init__(
|
11
|
+
self,
|
12
|
+
start_proc: CommProcessor[Any, Any, Any, CtxT],
|
13
|
+
procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
|
14
|
+
ctx: RunContext[CtxT] | None = None,
|
15
|
+
) -> None:
|
16
|
+
if start_proc not in procs:
|
17
|
+
raise ValueError(
|
18
|
+
f"Start processor {start_proc.name} must be in the list of processors: "
|
19
|
+
f"{', '.join(proc.name for proc in procs)}"
|
20
|
+
)
|
21
|
+
self._start_proc = start_proc
|
22
|
+
self._procs = procs
|
23
|
+
self._ctx = ctx or RunContext[CtxT]()
|
24
|
+
|
25
|
+
@property
|
26
|
+
def ctx(self) -> RunContext[CtxT]:
|
27
|
+
return self._ctx
|
28
|
+
|
29
|
+
async def run(self, **run_args: Any) -> Any:
|
30
|
+
self._ctx.is_streaming = False
|
31
|
+
for proc in self._procs:
|
32
|
+
proc.start_listening(ctx=self._ctx, **run_args)
|
33
|
+
await self._start_proc.run(**run_args, ctx=self._ctx)
|
34
|
+
|
35
|
+
return self._ctx.result
|
36
|
+
|
37
|
+
async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
|
38
|
+
self._ctx.is_streaming = True
|
39
|
+
for proc in self._procs:
|
40
|
+
proc.start_listening(ctx=self._ctx, **run_args)
|
41
|
+
async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
|
42
|
+
yield event
|
grasp_agents/typing/completion.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import time
|
2
|
-
from typing import Literal, TypeAlias
|
2
|
+
from typing import Any, Literal, TypeAlias
|
3
3
|
from uuid import uuid4
|
4
4
|
|
5
|
-
from
|
5
|
+
from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
|
6
|
+
from openai.types.chat.chat_completion import ChoiceLogprobs as OpenAIChoiceLogprobs
|
6
7
|
from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
|
7
8
|
|
8
9
|
from .message import AssistantMessage
|
@@ -22,6 +23,7 @@ class Usage(BaseModel):
|
|
22
23
|
def __add__(self, add_usage: "Usage") -> "Usage":
|
23
24
|
input_tokens = self.input_tokens + add_usage.input_tokens
|
24
25
|
output_tokens = self.output_tokens + add_usage.output_tokens
|
26
|
+
|
25
27
|
if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
|
26
28
|
reasoning_tokens = (self.reasoning_tokens or 0) + (
|
27
29
|
add_usage.reasoning_tokens or 0
|
@@ -34,11 +36,11 @@ class Usage(BaseModel):
|
|
34
36
|
else:
|
35
37
|
cached_tokens = None
|
36
38
|
|
37
|
-
cost
|
38
|
-
(self.cost or 0.0) + add_usage.cost
|
39
|
-
|
40
|
-
|
41
|
-
|
39
|
+
if self.cost is not None or add_usage.cost is not None:
|
40
|
+
cost = (self.cost or 0.0) + (add_usage.cost or 0.0)
|
41
|
+
else:
|
42
|
+
cost = None
|
43
|
+
|
42
44
|
return Usage(
|
43
45
|
input_tokens=input_tokens,
|
44
46
|
output_tokens=output_tokens,
|
@@ -52,7 +54,9 @@ class CompletionChoice(BaseModel):
|
|
52
54
|
message: AssistantMessage
|
53
55
|
finish_reason: FinishReason | None
|
54
56
|
index: int
|
55
|
-
logprobs:
|
57
|
+
logprobs: OpenAIChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
|
58
|
+
# LiteLLM-specific fields
|
59
|
+
provider_specific_fields: dict[str, Any] | None = None
|
56
60
|
|
57
61
|
|
58
62
|
class CompletionError(BaseModel):
|
@@ -64,12 +68,15 @@ class CompletionError(BaseModel):
|
|
64
68
|
class Completion(BaseModel):
|
65
69
|
id: str = Field(default_factory=lambda: str(uuid4())[:8])
|
66
70
|
created: int = Field(default_factory=lambda: int(time.time()))
|
67
|
-
model: str
|
71
|
+
model: str | None
|
68
72
|
name: str | None = None
|
69
73
|
system_fingerprint: str | None = None
|
70
74
|
choices: list[CompletionChoice]
|
71
75
|
usage: Usage | None = None
|
72
76
|
error: CompletionError | None = None
|
77
|
+
# LiteLLM-specific fields
|
78
|
+
response_ms: float | None = None
|
79
|
+
hidden_params: dict[str, Any] | None = None
|
73
80
|
|
74
81
|
@property
|
75
82
|
def messages(self) -> list[AssistantMessage]:
|