grasp_agents 0.4.7__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_agents/cloud_llm.py +191 -224
- grasp_agents/comm_processor.py +101 -100
- grasp_agents/errors.py +69 -9
- grasp_agents/litellm/__init__.py +106 -0
- grasp_agents/litellm/completion_chunk_converters.py +68 -0
- grasp_agents/litellm/completion_converters.py +72 -0
- grasp_agents/litellm/converters.py +138 -0
- grasp_agents/litellm/lite_llm.py +210 -0
- grasp_agents/litellm/message_converters.py +66 -0
- grasp_agents/llm.py +84 -49
- grasp_agents/llm_agent.py +136 -120
- grasp_agents/llm_agent_memory.py +3 -3
- grasp_agents/llm_policy_executor.py +167 -174
- grasp_agents/memory.py +23 -0
- grasp_agents/openai/__init__.py +24 -9
- grasp_agents/openai/completion_chunk_converters.py +6 -6
- grasp_agents/openai/completion_converters.py +12 -14
- grasp_agents/openai/content_converters.py +1 -3
- grasp_agents/openai/converters.py +6 -8
- grasp_agents/openai/message_converters.py +21 -3
- grasp_agents/openai/openai_llm.py +155 -103
- grasp_agents/openai/tool_converters.py +4 -6
- grasp_agents/packet.py +5 -2
- grasp_agents/packet_pool.py +14 -13
- grasp_agents/printer.py +233 -73
- grasp_agents/processor.py +229 -91
- grasp_agents/prompt_builder.py +2 -2
- grasp_agents/run_context.py +11 -20
- grasp_agents/runner.py +42 -0
- grasp_agents/typing/completion.py +16 -9
- grasp_agents/typing/completion_chunk.py +51 -22
- grasp_agents/typing/events.py +95 -19
- grasp_agents/typing/message.py +25 -1
- grasp_agents/typing/tool.py +2 -0
- grasp_agents/usage_tracker.py +31 -37
- grasp_agents/utils.py +95 -84
- grasp_agents/workflow/looped_workflow.py +60 -11
- grasp_agents/workflow/sequential_workflow.py +43 -11
- grasp_agents/workflow/workflow_processor.py +25 -24
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/METADATA +7 -6
- grasp_agents-0.5.1.dist-info/RECORD +57 -0
- grasp_agents-0.4.7.dist-info/RECORD +0 -50
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/WHEEL +0 -0
- {grasp_agents-0.4.7.dist-info → grasp_agents-0.5.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/processor.py
CHANGED
@@ -6,56 +6,42 @@ from typing import Any, ClassVar, Generic, cast, final
|
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
8
|
from pydantic import BaseModel, TypeAdapter
|
9
|
-
from
|
9
|
+
from pydantic import ValidationError as PydanticValidationError
|
10
10
|
|
11
|
-
from .errors import
|
11
|
+
from .errors import ProcInputValidationError, ProcOutputValidationError
|
12
12
|
from .generics_utils import AutoInstanceAttributesMixin
|
13
|
-
from .memory import MemT
|
13
|
+
from .memory import DummyMemory, MemT
|
14
14
|
from .packet import Packet
|
15
15
|
from .run_context import CtxT, RunContext
|
16
|
-
from .typing.events import
|
16
|
+
from .typing.events import (
|
17
|
+
Event,
|
18
|
+
ProcPacketOutputEvent,
|
19
|
+
ProcPayloadOutputEvent,
|
20
|
+
ProcStreamingErrorData,
|
21
|
+
ProcStreamingErrorEvent,
|
22
|
+
)
|
17
23
|
from .typing.io import InT, OutT_co, ProcName
|
18
24
|
from .typing.tool import BaseTool
|
25
|
+
from .utils import stream_concurrent
|
19
26
|
|
20
27
|
logger = logging.getLogger(__name__)
|
21
28
|
|
22
29
|
|
23
|
-
def retry_error_callback(retry_state: RetryCallState) -> None:
|
24
|
-
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
25
|
-
if exception:
|
26
|
-
if retry_state.attempt_number == 1:
|
27
|
-
logger.warning(f"\nParallel run failed:\n{exception}")
|
28
|
-
if retry_state.attempt_number > 1:
|
29
|
-
logger.warning(f"\nParallel run failed after retrying:\n{exception}")
|
30
|
-
|
31
|
-
|
32
|
-
def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
|
33
|
-
exception = retry_state.outcome.exception() if retry_state.outcome else None
|
34
|
-
logger.info(
|
35
|
-
f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
30
|
class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
|
40
31
|
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
|
41
32
|
0: "_in_type",
|
42
33
|
1: "_out_type",
|
43
34
|
}
|
44
35
|
|
45
|
-
def __init__(
|
46
|
-
self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
|
47
|
-
) -> None:
|
36
|
+
def __init__(self, name: ProcName, max_retries: int = 0, **kwargs: Any) -> None:
|
48
37
|
self._in_type: type[InT]
|
49
38
|
self._out_type: type[OutT_co]
|
50
39
|
|
51
40
|
super().__init__()
|
52
41
|
|
53
|
-
self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
|
54
|
-
self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
|
55
|
-
|
56
42
|
self._name: ProcName = name
|
57
|
-
self._memory: MemT
|
58
|
-
self.
|
43
|
+
self._memory: MemT = cast("MemT", DummyMemory())
|
44
|
+
self._max_retries: int = max_retries
|
59
45
|
|
60
46
|
@property
|
61
47
|
def in_type(self) -> type[InT]:
|
@@ -74,8 +60,13 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
74
60
|
return self._memory
|
75
61
|
|
76
62
|
@property
|
77
|
-
def
|
78
|
-
return self.
|
63
|
+
def max_retries(self) -> int:
|
64
|
+
return self._max_retries
|
65
|
+
|
66
|
+
def _generate_call_id(self, call_id: str | None) -> str:
|
67
|
+
if call_id is None:
|
68
|
+
return str(uuid4())[:6] + "_" + self.name
|
69
|
+
return call_id
|
79
70
|
|
80
71
|
def _validate_and_resolve_single_input(
|
81
72
|
self,
|
@@ -87,18 +78,19 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
87
78
|
"Only one of chat_inputs, in_args, or in_message must be provided."
|
88
79
|
)
|
89
80
|
if chat_inputs is not None and in_args is not None:
|
90
|
-
raise
|
81
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
91
82
|
if chat_inputs is not None and in_packet is not None:
|
92
|
-
raise
|
83
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
93
84
|
if in_args is not None and in_packet is not None:
|
94
|
-
raise
|
85
|
+
raise ProcInputValidationError(multiple_inputs_err_message)
|
95
86
|
|
96
87
|
if in_packet is not None:
|
97
88
|
if len(in_packet.payloads) != 1:
|
98
|
-
raise
|
89
|
+
raise ProcInputValidationError(
|
99
90
|
"Single input runs require exactly one payload in in_packet."
|
100
91
|
)
|
101
92
|
return in_packet.payloads[0]
|
93
|
+
|
102
94
|
return in_args
|
103
95
|
|
104
96
|
def _validate_and_resolve_parallel_inputs(
|
@@ -108,33 +100,44 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
108
100
|
in_args: Sequence[InT] | None,
|
109
101
|
) -> Sequence[InT]:
|
110
102
|
if chat_inputs is not None:
|
111
|
-
raise
|
103
|
+
raise ProcInputValidationError(
|
112
104
|
"chat_inputs are not supported in parallel runs. "
|
113
105
|
"Use in_packet or in_args."
|
114
106
|
)
|
115
107
|
if in_packet is not None:
|
116
108
|
if not in_packet.payloads:
|
117
|
-
raise
|
109
|
+
raise ProcInputValidationError(
|
118
110
|
"Parallel runs require at least one input payload in in_packet."
|
119
111
|
)
|
120
112
|
return in_packet.payloads
|
121
113
|
if in_args is not None:
|
122
114
|
return in_args
|
123
|
-
raise
|
115
|
+
raise ProcInputValidationError(
|
124
116
|
"Parallel runs require either in_packet or in_args to be provided."
|
125
117
|
)
|
126
118
|
|
119
|
+
def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
|
120
|
+
try:
|
121
|
+
return [
|
122
|
+
TypeAdapter(self._out_type).validate_python(payload)
|
123
|
+
for payload in out_payloads
|
124
|
+
]
|
125
|
+
except PydanticValidationError as err:
|
126
|
+
raise ProcOutputValidationError(
|
127
|
+
f"Output validation failed for processor {self.name}:\n{err}"
|
128
|
+
) from err
|
129
|
+
|
127
130
|
async def _process(
|
128
131
|
self,
|
129
132
|
chat_inputs: Any | None = None,
|
130
133
|
*,
|
131
134
|
in_args: InT | None = None,
|
132
135
|
memory: MemT,
|
133
|
-
|
136
|
+
call_id: str,
|
134
137
|
ctx: RunContext[CtxT] | None = None,
|
135
138
|
) -> Sequence[OutT_co]:
|
136
139
|
if in_args is None:
|
137
|
-
raise
|
140
|
+
raise ProcInputValidationError(
|
138
141
|
"Default implementation of _process requires in_args"
|
139
142
|
)
|
140
143
|
|
@@ -146,54 +149,75 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
146
149
|
*,
|
147
150
|
in_args: InT | None = None,
|
148
151
|
memory: MemT,
|
149
|
-
|
152
|
+
call_id: str,
|
150
153
|
ctx: RunContext[CtxT] | None = None,
|
151
154
|
) -> AsyncIterator[Event[Any]]:
|
152
155
|
if in_args is None:
|
153
|
-
raise
|
154
|
-
"Default implementation of
|
156
|
+
raise ProcInputValidationError(
|
157
|
+
"Default implementation of _process_stream requires in_args"
|
155
158
|
)
|
156
159
|
outputs = cast("Sequence[OutT_co]", in_args)
|
157
160
|
for out in outputs:
|
158
|
-
yield
|
159
|
-
|
160
|
-
def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
|
161
|
-
return [
|
162
|
-
self._out_type_adapter.validate_python(payload) for payload in out_payloads
|
163
|
-
]
|
164
|
-
|
165
|
-
def _generate_run_id(self, run_id: str | None) -> str:
|
166
|
-
if run_id is None:
|
167
|
-
return str(uuid4())[:6] + "_" + self.name
|
168
|
-
return run_id
|
161
|
+
yield ProcPayloadOutputEvent(data=out, proc_name=self.name, call_id=call_id)
|
169
162
|
|
170
|
-
async def
|
163
|
+
async def _run_single_once(
|
171
164
|
self,
|
172
165
|
chat_inputs: Any | None = None,
|
173
166
|
*,
|
174
167
|
in_packet: Packet[InT] | None = None,
|
175
168
|
in_args: InT | None = None,
|
176
169
|
forgetful: bool = False,
|
177
|
-
|
170
|
+
call_id: str,
|
178
171
|
ctx: RunContext[CtxT] | None = None,
|
179
172
|
) -> Packet[OutT_co]:
|
180
173
|
resolved_in_args = self._validate_and_resolve_single_input(
|
181
174
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
182
175
|
)
|
183
176
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
177
|
+
|
184
178
|
outputs = await self._process(
|
185
179
|
chat_inputs=chat_inputs,
|
186
180
|
in_args=resolved_in_args,
|
187
181
|
memory=_memory,
|
188
|
-
|
182
|
+
call_id=call_id,
|
189
183
|
ctx=ctx,
|
190
184
|
)
|
191
185
|
val_outputs = self._validate_outputs(outputs)
|
192
186
|
|
193
187
|
return Packet(payloads=val_outputs, sender=self.name)
|
194
188
|
|
195
|
-
def
|
196
|
-
|
189
|
+
async def _run_single(
|
190
|
+
self,
|
191
|
+
chat_inputs: Any | None = None,
|
192
|
+
*,
|
193
|
+
in_packet: Packet[InT] | None = None,
|
194
|
+
in_args: InT | None = None,
|
195
|
+
forgetful: bool = False,
|
196
|
+
call_id: str,
|
197
|
+
ctx: RunContext[CtxT] | None = None,
|
198
|
+
) -> Packet[OutT_co] | None:
|
199
|
+
n_attempt = 0
|
200
|
+
while n_attempt <= self.max_retries:
|
201
|
+
try:
|
202
|
+
return await self._run_single_once(
|
203
|
+
chat_inputs=chat_inputs,
|
204
|
+
in_packet=in_packet,
|
205
|
+
in_args=in_args,
|
206
|
+
forgetful=forgetful,
|
207
|
+
call_id=call_id,
|
208
|
+
ctx=ctx,
|
209
|
+
)
|
210
|
+
except Exception as err:
|
211
|
+
n_attempt += 1
|
212
|
+
if n_attempt > self.max_retries:
|
213
|
+
if n_attempt == 1:
|
214
|
+
logger.warning(f"\nProcessor run failed:\n{err}")
|
215
|
+
if n_attempt > 1:
|
216
|
+
logger.warning(f"\nProcessor run failed after retrying:\n{err}")
|
217
|
+
return None
|
218
|
+
logger.warning(
|
219
|
+
f"\nProcessor run failed (retry attempt {n_attempt}):\n{err}"
|
220
|
+
)
|
197
221
|
|
198
222
|
async def _run_par(
|
199
223
|
self,
|
@@ -201,27 +225,15 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
201
225
|
*,
|
202
226
|
in_packet: Packet[InT] | None = None,
|
203
227
|
in_args: Sequence[InT] | None = None,
|
204
|
-
|
205
|
-
forgetful: bool = False,
|
228
|
+
call_id: str,
|
206
229
|
ctx: RunContext[CtxT] | None = None,
|
207
230
|
) -> Packet[OutT_co]:
|
208
231
|
par_inputs = self._validate_and_resolve_parallel_inputs(
|
209
232
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
210
233
|
)
|
211
|
-
|
212
|
-
wrapped_func = retry(
|
213
|
-
wait=wait_random_exponential(min=1, max=8),
|
214
|
-
stop=stop_after_attempt(self._num_par_run_retries + 1),
|
215
|
-
before_sleep=retry_before_sleep_callback,
|
216
|
-
retry_error_callback=retry_error_callback,
|
217
|
-
)(self._run_single)
|
218
|
-
|
219
234
|
tasks = [
|
220
|
-
|
221
|
-
in_args=inp,
|
222
|
-
forgetful=True,
|
223
|
-
run_id=self._generate_par_run_id(run_id, idx),
|
224
|
-
ctx=ctx,
|
235
|
+
self._run_single(
|
236
|
+
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
225
237
|
)
|
226
238
|
for idx, inp in enumerate(par_inputs)
|
227
239
|
]
|
@@ -242,62 +254,188 @@ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, Ct
|
|
242
254
|
in_packet: Packet[InT] | None = None,
|
243
255
|
in_args: InT | Sequence[InT] | None = None,
|
244
256
|
forgetful: bool = False,
|
245
|
-
|
257
|
+
call_id: str | None = None,
|
246
258
|
ctx: RunContext[CtxT] | None = None,
|
247
259
|
) -> Packet[OutT_co]:
|
260
|
+
call_id = self._generate_call_id(call_id)
|
261
|
+
|
248
262
|
if (in_args is not None and isinstance(in_args, Sequence)) or (
|
249
263
|
in_packet is not None and len(in_packet.payloads) > 1
|
250
264
|
):
|
251
265
|
return await self._run_par(
|
252
266
|
chat_inputs=chat_inputs,
|
253
267
|
in_packet=in_packet,
|
254
|
-
in_args=cast("Sequence[InT]", in_args),
|
255
|
-
|
256
|
-
forgetful=forgetful,
|
268
|
+
in_args=cast("Sequence[InT] | None", in_args),
|
269
|
+
call_id=call_id,
|
257
270
|
ctx=ctx,
|
258
271
|
)
|
259
|
-
return await self._run_single(
|
272
|
+
return await self._run_single( # type: ignore[return]
|
260
273
|
chat_inputs=chat_inputs,
|
261
274
|
in_packet=in_packet,
|
262
|
-
in_args=in_args,
|
275
|
+
in_args=cast("InT | None", in_args),
|
263
276
|
forgetful=forgetful,
|
264
|
-
|
277
|
+
call_id=call_id,
|
265
278
|
ctx=ctx,
|
266
279
|
)
|
267
280
|
|
268
|
-
async def
|
281
|
+
async def _run_single_stream_once(
|
269
282
|
self,
|
270
283
|
chat_inputs: Any | None = None,
|
271
284
|
*,
|
272
285
|
in_packet: Packet[InT] | None = None,
|
273
286
|
in_args: InT | None = None,
|
274
287
|
forgetful: bool = False,
|
275
|
-
|
288
|
+
call_id: str,
|
276
289
|
ctx: RunContext[CtxT] | None = None,
|
277
290
|
) -> AsyncIterator[Event[Any]]:
|
278
291
|
resolved_in_args = self._validate_and_resolve_single_input(
|
279
292
|
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
280
293
|
)
|
281
|
-
|
282
294
|
_memory = self.memory.model_copy(deep=True) if forgetful else self.memory
|
283
295
|
|
284
|
-
outputs:
|
285
|
-
async for
|
296
|
+
outputs: list[OutT_co] = []
|
297
|
+
async for event in self._process_stream(
|
286
298
|
chat_inputs=chat_inputs,
|
287
299
|
in_args=resolved_in_args,
|
288
300
|
memory=_memory,
|
289
|
-
|
301
|
+
call_id=call_id,
|
290
302
|
ctx=ctx,
|
291
303
|
):
|
292
|
-
if isinstance(
|
293
|
-
outputs.append(
|
294
|
-
|
295
|
-
yield output_event
|
304
|
+
if isinstance(event, ProcPayloadOutputEvent):
|
305
|
+
outputs.append(event.data)
|
306
|
+
yield event
|
296
307
|
|
297
308
|
val_outputs = self._validate_outputs(outputs)
|
298
309
|
out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)
|
299
310
|
|
300
|
-
yield
|
311
|
+
yield ProcPacketOutputEvent(
|
312
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
313
|
+
)
|
314
|
+
|
315
|
+
async def _run_single_stream(
|
316
|
+
self,
|
317
|
+
chat_inputs: Any | None = None,
|
318
|
+
*,
|
319
|
+
in_packet: Packet[InT] | None = None,
|
320
|
+
in_args: InT | None = None,
|
321
|
+
forgetful: bool = False,
|
322
|
+
call_id: str,
|
323
|
+
ctx: RunContext[CtxT] | None = None,
|
324
|
+
) -> AsyncIterator[Event[Any]]:
|
325
|
+
n_attempt = 0
|
326
|
+
while n_attempt <= self.max_retries:
|
327
|
+
try:
|
328
|
+
async for event in self._run_single_stream_once(
|
329
|
+
chat_inputs=chat_inputs,
|
330
|
+
in_packet=in_packet,
|
331
|
+
in_args=in_args,
|
332
|
+
forgetful=forgetful,
|
333
|
+
call_id=call_id,
|
334
|
+
ctx=ctx,
|
335
|
+
):
|
336
|
+
yield event
|
337
|
+
|
338
|
+
return
|
339
|
+
|
340
|
+
except Exception as err:
|
341
|
+
err_data = ProcStreamingErrorData(error=err, call_id=call_id)
|
342
|
+
yield ProcStreamingErrorEvent(
|
343
|
+
data=err_data, proc_name=self.name, call_id=call_id
|
344
|
+
)
|
345
|
+
|
346
|
+
n_attempt += 1
|
347
|
+
if n_attempt > self.max_retries:
|
348
|
+
if n_attempt == 1:
|
349
|
+
logger.warning(f"\nStreaming processor run failed:\n{err}")
|
350
|
+
if n_attempt > 1:
|
351
|
+
logger.warning(
|
352
|
+
f"\nStreaming processor run failed after retrying:\n{err}"
|
353
|
+
)
|
354
|
+
return
|
355
|
+
|
356
|
+
logger.warning(
|
357
|
+
"\nStreaming processor run failed "
|
358
|
+
f"(retry attempt {n_attempt}):\n{err}"
|
359
|
+
)
|
360
|
+
|
361
|
+
async def _run_par_stream(
|
362
|
+
self,
|
363
|
+
chat_inputs: Any | None = None,
|
364
|
+
*,
|
365
|
+
in_packet: Packet[InT] | None = None,
|
366
|
+
in_args: Sequence[InT] | None = None,
|
367
|
+
call_id: str,
|
368
|
+
ctx: RunContext[CtxT] | None = None,
|
369
|
+
) -> AsyncIterator[Event[Any]]:
|
370
|
+
par_inputs = self._validate_and_resolve_parallel_inputs(
|
371
|
+
chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
|
372
|
+
)
|
373
|
+
streams = [
|
374
|
+
self._run_single_stream(
|
375
|
+
in_args=inp, forgetful=True, call_id=f"{call_id}/{idx}", ctx=ctx
|
376
|
+
)
|
377
|
+
for idx, inp in enumerate(par_inputs)
|
378
|
+
]
|
379
|
+
|
380
|
+
out_packets_map: dict[int, Packet[OutT_co] | None] = dict.fromkeys(
|
381
|
+
range(len(streams)), None
|
382
|
+
)
|
383
|
+
|
384
|
+
async for idx, event in stream_concurrent(streams):
|
385
|
+
if isinstance(event, ProcPacketOutputEvent):
|
386
|
+
out_packets_map[idx] = event.data
|
387
|
+
else:
|
388
|
+
yield event
|
389
|
+
|
390
|
+
out_packet = Packet( # type: ignore[return]
|
391
|
+
payloads=[
|
392
|
+
(out_packet.payloads[0] if out_packet else None)
|
393
|
+
for out_packet in out_packets_map.values()
|
394
|
+
],
|
395
|
+
sender=self.name,
|
396
|
+
)
|
397
|
+
|
398
|
+
yield ProcPacketOutputEvent(
|
399
|
+
data=out_packet, proc_name=self.name, call_id=call_id
|
400
|
+
)
|
401
|
+
|
402
|
+
async def run_stream(
|
403
|
+
self,
|
404
|
+
chat_inputs: Any | None = None,
|
405
|
+
*,
|
406
|
+
in_packet: Packet[InT] | None = None,
|
407
|
+
in_args: InT | Sequence[InT] | None = None,
|
408
|
+
forgetful: bool = False,
|
409
|
+
call_id: str | None = None,
|
410
|
+
ctx: RunContext[CtxT] | None = None,
|
411
|
+
) -> AsyncIterator[Event[Any]]:
|
412
|
+
call_id = self._generate_call_id(call_id)
|
413
|
+
|
414
|
+
# yield ProcStartEvent(proc_name=self.name, call_id=call_id, data=None)
|
415
|
+
|
416
|
+
if (in_args is not None and isinstance(in_args, Sequence)) or (
|
417
|
+
in_packet is not None and len(in_packet.payloads) > 1
|
418
|
+
):
|
419
|
+
stream = self._run_par_stream(
|
420
|
+
chat_inputs=chat_inputs,
|
421
|
+
in_packet=in_packet,
|
422
|
+
in_args=cast("Sequence[InT] | None", in_args),
|
423
|
+
call_id=call_id,
|
424
|
+
ctx=ctx,
|
425
|
+
)
|
426
|
+
else:
|
427
|
+
stream = self._run_single_stream(
|
428
|
+
chat_inputs=chat_inputs,
|
429
|
+
in_packet=in_packet,
|
430
|
+
in_args=cast("InT | None", in_args),
|
431
|
+
forgetful=forgetful,
|
432
|
+
call_id=call_id,
|
433
|
+
ctx=ctx,
|
434
|
+
)
|
435
|
+
async for event in stream:
|
436
|
+
yield event
|
437
|
+
|
438
|
+
# yield ProcFinishEvent(proc_name=self.name, call_id=call_id, data=None)
|
301
439
|
|
302
440
|
@final
|
303
441
|
def as_tool(
|
grasp_agents/prompt_builder.py
CHANGED
@@ -20,7 +20,7 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
|
|
20
20
|
sys_args: LLMPromptArgs | None,
|
21
21
|
*,
|
22
22
|
ctx: RunContext[CtxT] | None,
|
23
|
-
) -> str: ...
|
23
|
+
) -> str | None: ...
|
24
24
|
|
25
25
|
|
26
26
|
class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
|
@@ -110,7 +110,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
|
|
110
110
|
|
111
111
|
return Content.from_text(json.dumps(combined_args, indent=2))
|
112
112
|
|
113
|
-
def
|
113
|
+
def make_input_message(
|
114
114
|
self,
|
115
115
|
chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
|
116
116
|
in_args: InT | None = None,
|
grasp_agents/run_context.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1
1
|
from collections import defaultdict
|
2
|
-
from collections.abc import Mapping
|
3
2
|
from typing import Any, Generic, TypeVar
|
4
3
|
|
5
|
-
from pydantic import BaseModel, ConfigDict, Field
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field
|
6
5
|
|
7
6
|
from grasp_agents.typing.completion import Completion
|
8
7
|
|
@@ -25,29 +24,21 @@ class RunContext(BaseModel, Generic[CtxT]):
|
|
25
24
|
state: CtxT | None = None
|
26
25
|
|
27
26
|
run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
|
28
|
-
|
27
|
+
|
28
|
+
is_streaming: bool = False
|
29
|
+
result: Any | None = None
|
30
|
+
|
31
|
+
completions: dict[ProcName, list[Completion]] = Field(
|
29
32
|
default_factory=lambda: defaultdict(list)
|
30
33
|
)
|
34
|
+
usage_tracker: UsageTracker = Field(default_factory=UsageTracker)
|
31
35
|
|
36
|
+
printer: Printer | None = None
|
32
37
|
print_messages: bool = False
|
33
38
|
color_messages_by: ColoringMode = "role"
|
34
39
|
|
35
|
-
_usage_tracker: UsageTracker = PrivateAttr()
|
36
|
-
_printer: Printer = PrivateAttr()
|
37
|
-
|
38
40
|
def model_post_init(self, context: Any) -> None: # noqa: ARG002
|
39
|
-
self.
|
40
|
-
|
41
|
-
print_messages=self.print_messages,
|
42
|
-
color_by=self.color_messages_by,
|
43
|
-
)
|
41
|
+
if self.print_messages:
|
42
|
+
self.printer = Printer(color_by=self.color_messages_by)
|
44
43
|
|
45
|
-
|
46
|
-
def usage_tracker(self) -> UsageTracker:
|
47
|
-
return self._usage_tracker
|
48
|
-
|
49
|
-
@property
|
50
|
-
def printer(self) -> Printer:
|
51
|
-
return self._printer
|
52
|
-
|
53
|
-
model_config = ConfigDict(extra="forbid")
|
44
|
+
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
grasp_agents/runner.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
from collections.abc import AsyncIterator, Sequence
|
2
|
+
from typing import Any, Generic
|
3
|
+
|
4
|
+
from .comm_processor import CommProcessor
|
5
|
+
from .run_context import CtxT, RunContext
|
6
|
+
from .typing.events import Event
|
7
|
+
|
8
|
+
|
9
|
+
class Runner(Generic[CtxT]):
|
10
|
+
def __init__(
|
11
|
+
self,
|
12
|
+
start_proc: CommProcessor[Any, Any, Any, CtxT],
|
13
|
+
procs: Sequence[CommProcessor[Any, Any, Any, CtxT]],
|
14
|
+
ctx: RunContext[CtxT] | None = None,
|
15
|
+
) -> None:
|
16
|
+
if start_proc not in procs:
|
17
|
+
raise ValueError(
|
18
|
+
f"Start processor {start_proc.name} must be in the list of processors: "
|
19
|
+
f"{', '.join(proc.name for proc in procs)}"
|
20
|
+
)
|
21
|
+
self._start_proc = start_proc
|
22
|
+
self._procs = procs
|
23
|
+
self._ctx = ctx or RunContext[CtxT]()
|
24
|
+
|
25
|
+
@property
|
26
|
+
def ctx(self) -> RunContext[CtxT]:
|
27
|
+
return self._ctx
|
28
|
+
|
29
|
+
async def run(self, **run_args: Any) -> Any:
|
30
|
+
self._ctx.is_streaming = False
|
31
|
+
for proc in self._procs:
|
32
|
+
proc.start_listening(ctx=self._ctx, **run_args)
|
33
|
+
await self._start_proc.run(**run_args, ctx=self._ctx)
|
34
|
+
|
35
|
+
return self._ctx.result
|
36
|
+
|
37
|
+
async def run_stream(self, **run_args: Any) -> AsyncIterator[Event[Any]]:
|
38
|
+
self._ctx.is_streaming = True
|
39
|
+
for proc in self._procs:
|
40
|
+
proc.start_listening(ctx=self._ctx, **run_args)
|
41
|
+
async for event in self._start_proc.run_stream(**run_args, ctx=self._ctx):
|
42
|
+
yield event
|
@@ -1,8 +1,9 @@
|
|
1
1
|
import time
|
2
|
-
from typing import Literal, TypeAlias
|
2
|
+
from typing import Any, Literal, TypeAlias
|
3
3
|
from uuid import uuid4
|
4
4
|
|
5
|
-
from
|
5
|
+
from litellm.types.utils import ChoiceLogprobs as LiteLLMChoiceLogprobs
|
6
|
+
from openai.types.chat.chat_completion import ChoiceLogprobs as OpenAIChoiceLogprobs
|
6
7
|
from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt
|
7
8
|
|
8
9
|
from .message import AssistantMessage
|
@@ -22,6 +23,7 @@ class Usage(BaseModel):
|
|
22
23
|
def __add__(self, add_usage: "Usage") -> "Usage":
|
23
24
|
input_tokens = self.input_tokens + add_usage.input_tokens
|
24
25
|
output_tokens = self.output_tokens + add_usage.output_tokens
|
26
|
+
|
25
27
|
if self.reasoning_tokens is not None or add_usage.reasoning_tokens is not None:
|
26
28
|
reasoning_tokens = (self.reasoning_tokens or 0) + (
|
27
29
|
add_usage.reasoning_tokens or 0
|
@@ -34,11 +36,11 @@ class Usage(BaseModel):
|
|
34
36
|
else:
|
35
37
|
cached_tokens = None
|
36
38
|
|
37
|
-
cost
|
38
|
-
(self.cost or 0.0) + add_usage.cost
|
39
|
-
|
40
|
-
|
41
|
-
|
39
|
+
if self.cost is not None or add_usage.cost is not None:
|
40
|
+
cost = (self.cost or 0.0) + (add_usage.cost or 0.0)
|
41
|
+
else:
|
42
|
+
cost = None
|
43
|
+
|
42
44
|
return Usage(
|
43
45
|
input_tokens=input_tokens,
|
44
46
|
output_tokens=output_tokens,
|
@@ -52,7 +54,9 @@ class CompletionChoice(BaseModel):
|
|
52
54
|
message: AssistantMessage
|
53
55
|
finish_reason: FinishReason | None
|
54
56
|
index: int
|
55
|
-
logprobs:
|
57
|
+
logprobs: OpenAIChoiceLogprobs | LiteLLMChoiceLogprobs | Any | None = None
|
58
|
+
# LiteLLM-specific fields
|
59
|
+
provider_specific_fields: dict[str, Any] | None = None
|
56
60
|
|
57
61
|
|
58
62
|
class CompletionError(BaseModel):
|
@@ -64,12 +68,15 @@ class CompletionError(BaseModel):
|
|
64
68
|
class Completion(BaseModel):
|
65
69
|
id: str = Field(default_factory=lambda: str(uuid4())[:8])
|
66
70
|
created: int = Field(default_factory=lambda: int(time.time()))
|
67
|
-
model: str
|
71
|
+
model: str | None
|
68
72
|
name: str | None = None
|
69
73
|
system_fingerprint: str | None = None
|
70
74
|
choices: list[CompletionChoice]
|
71
75
|
usage: Usage | None = None
|
72
76
|
error: CompletionError | None = None
|
77
|
+
# LiteLLM-specific fields
|
78
|
+
response_ms: float | None = None
|
79
|
+
hidden_params: dict[str, Any] | None = None
|
73
80
|
|
74
81
|
@property
|
75
82
|
def messages(self) -> list[AssistantMessage]:
|