grasp_agents 0.3.11__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
grasp_agents/processor.py CHANGED
@@ -1,40 +1,64 @@
1
+ import asyncio
2
+ import logging
1
3
  from abc import ABC
2
4
  from collections.abc import AsyncIterator, Sequence
3
5
  from typing import Any, ClassVar, Generic, cast, final
6
+ from uuid import uuid4
4
7
 
5
8
  from pydantic import BaseModel, TypeAdapter
9
+ from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
6
10
 
11
+ from .errors import InputValidationError
7
12
  from .generics_utils import AutoInstanceAttributesMixin
13
+ from .memory import MemT
8
14
  from .packet import Packet
9
15
  from .run_context import CtxT, RunContext
10
16
  from .typing.events import Event, PacketEvent, ProcOutputEvent
11
- from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
17
+ from .typing.io import InT, OutT_co, ProcName
12
18
  from .typing.tool import BaseTool
13
19
 
20
+ logger = logging.getLogger(__name__)
14
21
 
15
- class Processor(
16
- AutoInstanceAttributesMixin, ABC, Generic[InT_contra, OutT_co, MemT_co, CtxT]
17
- ):
22
+
23
+ def retry_error_callback(retry_state: RetryCallState) -> None:
24
+ exception = retry_state.outcome.exception() if retry_state.outcome else None
25
+ if exception:
26
+ if retry_state.attempt_number == 1:
27
+ logger.warning(f"\nParallel run failed:\n{exception}")
28
+ if retry_state.attempt_number > 1:
29
+ logger.warning(f"\nParallel run failed after retrying:\n{exception}")
30
+
31
+
32
+ def retry_before_sleep_callback(retry_state: RetryCallState) -> None:
33
+ exception = retry_state.outcome.exception() if retry_state.outcome else None
34
+ logger.info(
35
+ f"\nRetrying parallel run (attempt {retry_state.attempt_number}):\n{exception}"
36
+ )
37
+
38
+
39
+ class Processor(AutoInstanceAttributesMixin, ABC, Generic[InT, OutT_co, MemT, CtxT]):
18
40
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
19
41
  0: "_in_type",
20
42
  1: "_out_type",
21
43
  }
22
44
 
23
- def __init__(self, name: ProcName, **kwargs: Any) -> None:
24
- self._in_type: type[InT_contra]
45
+ def __init__(
46
+ self, name: ProcName, num_par_run_retries: int = 0, **kwargs: Any
47
+ ) -> None:
48
+ self._in_type: type[InT]
25
49
  self._out_type: type[OutT_co]
26
50
 
27
51
  super().__init__()
28
52
 
29
- self._in_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
53
+ self._in_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
30
54
  self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)
31
55
 
32
56
  self._name: ProcName = name
33
- self._memory: MemT_co
57
+ self._memory: MemT
58
+ self._num_par_run_retries: int = num_par_run_retries
34
59
 
35
60
  @property
36
- def in_type(self) -> type[InT_contra]: # type: ignore[reportInvalidTypeVarUse]
37
- # Exposing the type of a contravariant variable only, should be type safe
61
+ def in_type(self) -> type[InT]:
38
62
  return self._in_type
39
63
 
40
64
  @property
@@ -46,48 +70,73 @@ class Processor(
46
70
  return self._name
47
71
 
48
72
  @property
49
- def memory(self) -> MemT_co:
73
+ def memory(self) -> MemT:
50
74
  return self._memory
51
75
 
52
- def _validate_and_resolve_inputs(
76
+ @property
77
+ def num_par_run_retries(self) -> int:
78
+ return self._num_par_run_retries
79
+
80
+ def _validate_and_resolve_single_input(
53
81
  self,
54
82
  chat_inputs: Any | None = None,
55
- in_packet: Packet[InT_contra] | None = None,
56
- in_args: InT_contra | Sequence[InT_contra] | None = None,
57
- ) -> Sequence[InT_contra] | None:
83
+ in_packet: Packet[InT] | None = None,
84
+ in_args: InT | None = None,
85
+ ) -> InT | None:
58
86
  multiple_inputs_err_message = (
59
87
  "Only one of chat_inputs, in_args, or in_message must be provided."
60
88
  )
61
89
  if chat_inputs is not None and in_args is not None:
62
- raise ValueError(multiple_inputs_err_message)
90
+ raise InputValidationError(multiple_inputs_err_message)
63
91
  if chat_inputs is not None and in_packet is not None:
64
- raise ValueError(multiple_inputs_err_message)
92
+ raise InputValidationError(multiple_inputs_err_message)
65
93
  if in_args is not None and in_packet is not None:
66
- raise ValueError(multiple_inputs_err_message)
94
+ raise InputValidationError(multiple_inputs_err_message)
67
95
 
68
- resolved_in_args: Sequence[InT_contra] | None = None
69
96
  if in_packet is not None:
70
- resolved_in_args = in_packet.payloads
71
- elif isinstance(in_args, self._in_type):
72
- resolved_in_args = cast("Sequence[InT_contra]", [in_args])
73
- elif in_args is None:
74
- resolved_in_args = in_args
75
- else:
76
- resolved_in_args = cast("Sequence[InT_contra]", in_args)
97
+ if len(in_packet.payloads) != 1:
98
+ raise InputValidationError(
99
+ "Single input runs require exactly one payload in in_packet."
100
+ )
101
+ return in_packet.payloads[0]
102
+ return in_args
77
103
 
78
- return resolved_in_args
104
+ def _validate_and_resolve_parallel_inputs(
105
+ self,
106
+ chat_inputs: Any | None,
107
+ in_packet: Packet[InT] | None,
108
+ in_args: Sequence[InT] | None,
109
+ ) -> Sequence[InT]:
110
+ if chat_inputs is not None:
111
+ raise InputValidationError(
112
+ "chat_inputs are not supported in parallel runs. "
113
+ "Use in_packet or in_args."
114
+ )
115
+ if in_packet is not None:
116
+ if not in_packet.payloads:
117
+ raise InputValidationError(
118
+ "Parallel runs require at least one input payload in in_packet."
119
+ )
120
+ return in_packet.payloads
121
+ if in_args is not None:
122
+ return in_args
123
+ raise InputValidationError(
124
+ "Parallel runs require either in_packet or in_args to be provided."
125
+ )
79
126
 
80
127
  async def _process(
81
128
  self,
82
129
  chat_inputs: Any | None = None,
83
130
  *,
84
- in_args: Sequence[InT_contra] | None = None,
85
- forgetful: bool = False,
131
+ in_args: InT | None = None,
132
+ memory: MemT,
133
+ run_id: str,
86
134
  ctx: RunContext[CtxT] | None = None,
87
135
  ) -> Sequence[OutT_co]:
88
- assert in_args is not None, (
89
- "Default implementation of _process requires in_args"
90
- )
136
+ if in_args is None:
137
+ raise InputValidationError(
138
+ "Default implementation of _process requires in_args"
139
+ )
91
140
 
92
141
  return cast("Sequence[OutT_co]", in_args)
93
142
 
@@ -95,13 +144,15 @@ class Processor(
95
144
  self,
96
145
  chat_inputs: Any | None = None,
97
146
  *,
98
- in_args: Sequence[InT_contra] | None = None,
99
- forgetful: bool = False,
147
+ in_args: InT | None = None,
148
+ memory: MemT,
149
+ run_id: str,
100
150
  ctx: RunContext[CtxT] | None = None,
101
151
  ) -> AsyncIterator[Event[Any]]:
102
- assert in_args is not None, (
103
- "Default implementation of _process requires in_args"
104
- )
152
+ if in_args is None:
153
+ raise InputValidationError(
154
+ "Default implementation of _process requires in_args"
155
+ )
105
156
  outputs = cast("Sequence[OutT_co]", in_args)
106
157
  for out in outputs:
107
158
  yield ProcOutputEvent(data=out, name=self.name)
@@ -111,46 +162,129 @@ class Processor(
111
162
  self._out_type_adapter.validate_python(payload) for payload in out_payloads
112
163
  ]
113
164
 
114
- async def run(
165
+ def _generate_run_id(self, run_id: str | None) -> str:
166
+ if run_id is None:
167
+ return str(uuid4())[:6] + "_" + self.name
168
+ return run_id
169
+
170
+ async def _run_single(
115
171
  self,
116
172
  chat_inputs: Any | None = None,
117
173
  *,
118
- in_packet: Packet[InT_contra] | None = None,
119
- in_args: InT_contra | Sequence[InT_contra] | None = None,
174
+ in_packet: Packet[InT] | None = None,
175
+ in_args: InT | None = None,
120
176
  forgetful: bool = False,
177
+ run_id: str | None = None,
121
178
  ctx: RunContext[CtxT] | None = None,
122
179
  ) -> Packet[OutT_co]:
123
- resolved_in_args = self._validate_and_resolve_inputs(
180
+ resolved_in_args = self._validate_and_resolve_single_input(
124
181
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
125
182
  )
183
+ _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
126
184
  outputs = await self._process(
127
185
  chat_inputs=chat_inputs,
128
186
  in_args=resolved_in_args,
129
- forgetful=forgetful,
187
+ memory=_memory,
188
+ run_id=self._generate_run_id(run_id),
130
189
  ctx=ctx,
131
190
  )
132
191
  val_outputs = self._validate_outputs(outputs)
133
192
 
134
193
  return Packet(payloads=val_outputs, sender=self.name)
135
194
 
195
+ def _generate_par_run_id(self, run_id: str | None, idx: int) -> str:
196
+ return f"{self._generate_run_id(run_id)}/{idx}"
197
+
198
+ async def _run_par(
199
+ self,
200
+ chat_inputs: Any | None = None,
201
+ *,
202
+ in_packet: Packet[InT] | None = None,
203
+ in_args: Sequence[InT] | None = None,
204
+ run_id: str | None = None,
205
+ forgetful: bool = False,
206
+ ctx: RunContext[CtxT] | None = None,
207
+ ) -> Packet[OutT_co]:
208
+ par_inputs = self._validate_and_resolve_parallel_inputs(
209
+ chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
210
+ )
211
+
212
+ wrapped_func = retry(
213
+ wait=wait_random_exponential(min=1, max=8),
214
+ stop=stop_after_attempt(self._num_par_run_retries + 1),
215
+ before_sleep=retry_before_sleep_callback,
216
+ retry_error_callback=retry_error_callback,
217
+ )(self._run_single)
218
+
219
+ tasks = [
220
+ wrapped_func(
221
+ in_args=inp,
222
+ forgetful=True,
223
+ run_id=self._generate_par_run_id(run_id, idx),
224
+ ctx=ctx,
225
+ )
226
+ for idx, inp in enumerate(par_inputs)
227
+ ]
228
+ out_packets = await asyncio.gather(*tasks)
229
+
230
+ return Packet( # type: ignore[return]
231
+ payloads=[
232
+ (out_packet.payloads[0] if out_packet else None)
233
+ for out_packet in out_packets
234
+ ],
235
+ sender=self.name,
236
+ )
237
+
238
+ async def run(
239
+ self,
240
+ chat_inputs: Any | None = None,
241
+ *,
242
+ in_packet: Packet[InT] | None = None,
243
+ in_args: InT | Sequence[InT] | None = None,
244
+ forgetful: bool = False,
245
+ run_id: str | None = None,
246
+ ctx: RunContext[CtxT] | None = None,
247
+ ) -> Packet[OutT_co]:
248
+ if isinstance(in_args, Sequence):
249
+ return await self._run_par(
250
+ chat_inputs=chat_inputs,
251
+ in_packet=in_packet,
252
+ in_args=cast("Sequence[InT]", in_args),
253
+ run_id=run_id,
254
+ forgetful=forgetful,
255
+ ctx=ctx,
256
+ )
257
+ return await self._run_single(
258
+ chat_inputs=chat_inputs,
259
+ in_packet=in_packet,
260
+ in_args=in_args,
261
+ forgetful=forgetful,
262
+ run_id=run_id,
263
+ ctx=ctx,
264
+ )
265
+
136
266
  async def run_stream(
137
267
  self,
138
268
  chat_inputs: Any | None = None,
139
269
  *,
140
- in_packet: Packet[InT_contra] | None = None,
141
- in_args: InT_contra | Sequence[InT_contra] | None = None,
270
+ in_packet: Packet[InT] | None = None,
271
+ in_args: InT | None = None,
142
272
  forgetful: bool = False,
273
+ run_id: str | None = None,
143
274
  ctx: RunContext[CtxT] | None = None,
144
275
  ) -> AsyncIterator[Event[Any]]:
145
- resolved_in_args = self._validate_and_resolve_inputs(
276
+ resolved_in_args = self._validate_and_resolve_single_input(
146
277
  chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
147
278
  )
148
279
 
280
+ _memory = self.memory.model_copy(deep=True) if forgetful else self.memory
281
+
149
282
  outputs: Sequence[OutT_co] = []
150
283
  async for output_event in self._process_stream(
151
284
  chat_inputs=chat_inputs,
152
285
  in_args=resolved_in_args,
153
- forgetful=forgetful,
286
+ memory=_memory,
287
+ run_id=self._generate_run_id(run_id),
154
288
  ctx=ctx,
155
289
  ):
156
290
  if isinstance(output_event, ProcOutputEvent):
@@ -166,7 +300,7 @@ class Processor(
166
300
  @final
167
301
  def as_tool(
168
302
  self, tool_name: str, tool_description: str
169
- ) -> BaseTool[InT_contra, OutT_co, Any]: # type: ignore[override]
303
+ ) -> BaseTool[InT, OutT_co, Any]: # type: ignore[override]
170
304
  # TODO: stream tools
171
305
  processor_instance = self
172
306
  in_type = processor_instance.in_type
@@ -182,7 +316,7 @@ class Processor(
182
316
  description: str = tool_description
183
317
 
184
318
  async def run(
185
- self, inp: InT_contra, ctx: RunContext[CtxT] | None = None
319
+ self, inp: InT, ctx: RunContext[CtxT] | None = None
186
320
  ) -> OutT_co:
187
321
  result = await processor_instance.run(
188
322
  in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
@@ -1,15 +1,18 @@
1
1
  import json
2
2
  from collections.abc import Mapping, Sequence
3
- from typing import ClassVar, Generic, Protocol, TypeAlias
3
+ from typing import ClassVar, Generic, Protocol, TypeAlias, TypeVar
4
4
 
5
5
  from pydantic import BaseModel, TypeAdapter
6
6
 
7
+ from .errors import InputPromptBuilderError, SystemPromptBuilderError
7
8
  from .generics_utils import AutoInstanceAttributesMixin
8
9
  from .run_context import CtxT, RunContext
9
10
  from .typing.content import Content, ImageData
10
- from .typing.io import InT_contra, LLMPrompt, LLMPromptArgs
11
+ from .typing.io import InT, LLMPrompt, LLMPromptArgs
11
12
  from .typing.message import UserMessage
12
13
 
14
+ _InT_contra = TypeVar("_InT_contra", contravariant=True)
15
+
13
16
 
14
17
  class MakeSystemPromptHandler(Protocol[CtxT]):
15
18
  def __call__(
@@ -20,13 +23,12 @@ class MakeSystemPromptHandler(Protocol[CtxT]):
20
23
  ) -> str: ...
21
24
 
22
25
 
23
- class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
26
+ class MakeInputContentHandler(Protocol[_InT_contra, CtxT]):
24
27
  def __call__(
25
28
  self,
26
29
  *,
27
- in_args: InT_contra | None,
30
+ in_args: _InT_contra | None,
28
31
  usr_args: LLMPromptArgs | None,
29
- batch_idx: int,
30
32
  ctx: RunContext[CtxT] | None,
31
33
  ) -> Content: ...
32
34
 
@@ -34,7 +36,7 @@ class MakeInputContentHandler(Protocol[InT_contra, CtxT]):
34
36
  PromptArgumentType: TypeAlias = str | bool | int | ImageData
35
37
 
36
38
 
37
- class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
39
+ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
38
40
  _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {0: "_in_type"}
39
41
 
40
42
  def __init__(
@@ -45,7 +47,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
45
47
  sys_args_schema: type[LLMPromptArgs] | None = None,
46
48
  usr_args_schema: type[LLMPromptArgs] | None = None,
47
49
  ):
48
- self._in_type: type[InT_contra]
50
+ self._in_type: type[InT]
49
51
  super().__init__()
50
52
 
51
53
  self._agent_name = agent_name
@@ -54,11 +56,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
54
56
  self.sys_args_schema = sys_args_schema
55
57
  self.usr_args_schema = usr_args_schema
56
58
  self.make_system_prompt_impl: MakeSystemPromptHandler[CtxT] | None = None
57
- self.make_input_content_impl: (
58
- MakeInputContentHandler[InT_contra, CtxT] | None
59
- ) = None
59
+ self.make_input_content_impl: MakeInputContentHandler[InT, CtxT] | None = None
60
60
 
61
- self._in_args_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
61
+ self._in_args_type_adapter: TypeAdapter[InT] = TypeAdapter(self._in_type)
62
62
 
63
63
  def make_system_prompt(
64
64
  self, sys_args: LLMPromptArgs | None = None, ctx: RunContext[CtxT] | None = None
@@ -71,9 +71,9 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
71
71
  if self.sys_args_schema is not None:
72
72
  val_sys_args = self.sys_args_schema.model_validate(sys_args)
73
73
  else:
74
- raise TypeError(
75
- "System prompt template is set, but system arguments schema is not "
76
- "provided."
74
+ raise SystemPromptBuilderError(
75
+ "System prompt template and arguments is set, but system arguments "
76
+ "schema is not provided."
77
77
  )
78
78
 
79
79
  if self.make_system_prompt_impl:
@@ -86,9 +86,8 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
86
86
  def make_input_content(
87
87
  self,
88
88
  *,
89
- in_args: InT_contra | None,
89
+ in_args: InT | None,
90
90
  usr_args: LLMPromptArgs | None,
91
- batch_idx: int = 0,
92
91
  ctx: RunContext[CtxT] | None = None,
93
92
  ) -> Content:
94
93
  val_in_args, val_usr_args = self._validate_prompt_args(
@@ -97,7 +96,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
97
96
 
98
97
  if self.make_input_content_impl:
99
98
  return self.make_input_content_impl(
100
- in_args=val_in_args, usr_args=val_usr_args, batch_idx=batch_idx, ctx=ctx
99
+ in_args=val_in_args, usr_args=val_usr_args, ctx=ctx
101
100
  )
102
101
 
103
102
  combined_args = self._combine_args(in_args=val_in_args, usr_args=val_usr_args)
@@ -106,56 +105,46 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
106
105
 
107
106
  if self.in_prompt_template is not None:
108
107
  return Content.from_formatted_prompt(
109
- self.in_prompt_template, prompt_args=combined_args
108
+ self.in_prompt_template, **combined_args
110
109
  )
111
110
 
112
111
  return Content.from_text(json.dumps(combined_args, indent=2))
113
112
 
114
- def make_user_messages(
113
+ def make_user_message(
115
114
  self,
116
115
  chat_inputs: LLMPrompt | Sequence[str | ImageData] | None = None,
117
- in_args_batch: Sequence[InT_contra] | None = None,
116
+ in_args: InT | None = None,
118
117
  usr_args: LLMPromptArgs | None = None,
119
118
  ctx: RunContext[CtxT] | None = None,
120
- ) -> Sequence[UserMessage]:
119
+ ) -> UserMessage | None:
120
+ if chat_inputs is None and in_args is None and usr_args is None:
121
+ return None
122
+
121
123
  if chat_inputs:
122
124
  if isinstance(chat_inputs, LLMPrompt):
123
- return self._usr_messages_from_text(chat_inputs)
124
- return self._usr_messages_from_content_parts(chat_inputs)
125
+ return UserMessage.from_text(chat_inputs, name=self._agent_name)
126
+ return UserMessage.from_content_parts(chat_inputs, name=self._agent_name)
125
127
 
126
- in_content_batch = [
127
- self.make_input_content(
128
- in_args=in_args, usr_args=usr_args, batch_idx=i, ctx=ctx
129
- )
130
- for i, in_args in enumerate(in_args_batch or [None])
131
- ]
132
- return [
133
- UserMessage(content=in_content, name=self._agent_name)
134
- for in_content in in_content_batch
135
- ]
136
-
137
- def _usr_messages_from_text(self, text: str) -> list[UserMessage]:
138
- return [UserMessage.from_text(text, name=self._agent_name)]
128
+ in_content = self.make_input_content(
129
+ in_args=in_args, usr_args=usr_args, ctx=ctx
130
+ )
139
131
 
140
- def _usr_messages_from_content_parts(
141
- self, content_parts: Sequence[str | ImageData]
142
- ) -> list[UserMessage]:
143
- return [UserMessage.from_content_parts(content_parts, name=self._agent_name)]
132
+ return UserMessage(content=in_content, name=self._agent_name)
144
133
 
145
134
  def _validate_prompt_args(
146
135
  self,
147
136
  *,
148
- in_args: InT_contra | None,
137
+ in_args: InT | None,
149
138
  usr_args: LLMPromptArgs | None,
150
- ) -> tuple[InT_contra | None, LLMPromptArgs | None]:
139
+ ) -> tuple[InT | None, LLMPromptArgs | None]:
151
140
  val_usr_args = usr_args
152
141
  if usr_args is not None:
153
142
  if self.in_prompt_template is None:
154
- raise TypeError(
143
+ raise InputPromptBuilderError(
155
144
  "Input prompt template is not set, but user arguments are provided."
156
145
  )
157
146
  if self.usr_args_schema is None:
158
- raise TypeError(
147
+ raise InputPromptBuilderError(
159
148
  "User arguments schema is not provided, but user arguments are "
160
149
  "given."
161
150
  )
@@ -167,12 +156,12 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
167
156
  if isinstance(val_in_args, BaseModel):
168
157
  has_image = self._has_image_data(val_in_args)
169
158
  if has_image and self.in_prompt_template is None:
170
- raise TypeError(
159
+ raise InputPromptBuilderError(
171
160
  "BaseModel input arguments contain ImageData, but input prompt "
172
161
  "template is not set. Cannot format input arguments."
173
162
  )
174
163
  elif self.in_prompt_template is not None:
175
- raise TypeError(
164
+ raise InputPromptBuilderError(
176
165
  "Cannot use the input prompt template with "
177
166
  "non-BaseModel input arguments."
178
167
  )
@@ -189,9 +178,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
189
178
  return contains_image_data
190
179
 
191
180
  @staticmethod
192
- def _format_pydantic_prompt_args(
193
- inp: BaseModel,
194
- ) -> dict[str, PromptArgumentType]:
181
+ def _format_pydantic_prompt_args(inp: BaseModel) -> dict[str, PromptArgumentType]:
195
182
  formatted_args: dict[str, PromptArgumentType] = {}
196
183
  for field in type(inp).model_fields:
197
184
  if field == "selected_recipients":
@@ -200,18 +187,17 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT_contra, CtxT]):
200
187
  val = getattr(inp, field)
201
188
  if isinstance(val, (int, str, bool, ImageData)):
202
189
  formatted_args[field] = val
203
- elif isinstance(val, BaseModel):
204
- formatted_args[field] = val.model_dump_json(indent=2, warnings="error")
205
190
  else:
206
- raise TypeError(
207
- f"Field '{field}' in prompt arguments must be of type "
208
- "int, str, bool, BaseModel, or ImageData."
191
+ formatted_args[field] = (
192
+ TypeAdapter(type(val)) # type: ignore[return-value]
193
+ .dump_json(val, indent=2, warnings="error")
194
+ .decode("utf-8")
209
195
  )
210
196
 
211
197
  return formatted_args
212
198
 
213
199
  def _combine_args(
214
- self, *, in_args: InT_contra | None, usr_args: LLMPromptArgs | None
200
+ self, *, in_args: InT | None, usr_args: LLMPromptArgs | None
215
201
  ) -> Mapping[str, PromptArgumentType] | str:
216
202
  fmt_usr_args = self._format_pydantic_prompt_args(usr_args) if usr_args else {}
217
203
 
@@ -1,7 +1,6 @@
1
1
  from collections import defaultdict
2
2
  from collections.abc import Mapping
3
3
  from typing import Any, Generic, TypeVar
4
- from uuid import uuid4
5
4
 
6
5
  from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
7
6
 
@@ -23,8 +22,6 @@ CtxT = TypeVar("CtxT")
23
22
 
24
23
 
25
24
  class RunContext(BaseModel, Generic[CtxT]):
26
- run_id: str = Field(default_factory=lambda: str(uuid4())[:8], frozen=True)
27
-
28
25
  state: CtxT | None = None
29
26
 
30
27
  run_args: dict[ProcName, RunArgs] = Field(default_factory=dict)
@@ -39,9 +36,8 @@ class RunContext(BaseModel, Generic[CtxT]):
39
36
  _printer: Printer = PrivateAttr()
40
37
 
41
38
  def model_post_init(self, context: Any) -> None: # noqa: ARG002
42
- self._usage_tracker = UsageTracker(source_id=self.run_id)
39
+ self._usage_tracker = UsageTracker()
43
40
  self._printer = Printer(
44
- source_id=self.run_id,
45
41
  print_messages=self.print_messages,
46
42
  color_by=self.color_messages_by,
47
43
  )
@@ -11,6 +11,7 @@ from openai.types.chat.chat_completion_token_logprob import (
11
11
  )
12
12
  from pydantic import BaseModel, Field
13
13
 
14
+ from ..errors import CombineCompletionChunksError
14
15
  from .completion import (
15
16
  Completion,
16
17
  CompletionChoice,
@@ -54,21 +55,25 @@ class CompletionChunk(BaseModel):
54
55
 
55
56
  def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
56
57
  if not chunks:
57
- raise ValueError("Cannot combine an empty list of completion chunks.")
58
+ raise CombineCompletionChunksError(
59
+ "Cannot combine an empty list of completion chunks."
60
+ )
58
61
 
59
62
  model_list = {chunk.model for chunk in chunks}
60
63
  if len(model_list) > 1:
61
- raise ValueError("All chunks must have the same model.")
64
+ raise CombineCompletionChunksError("All chunks must have the same model.")
62
65
  model = model_list.pop()
63
66
 
64
67
  name_list = {chunk.name for chunk in chunks}
65
68
  if len(name_list) > 1:
66
- raise ValueError("All chunks must have the same name.")
69
+ raise CombineCompletionChunksError("All chunks must have the same name.")
67
70
  name = name_list.pop()
68
71
 
69
72
  system_fingerprints_list = {chunk.system_fingerprint for chunk in chunks}
70
73
  if len(system_fingerprints_list) > 1:
71
- raise ValueError("All chunks must have the same system fingerprint.")
74
+ raise CombineCompletionChunksError(
75
+ "All chunks must have the same system fingerprint."
76
+ )
72
77
  system_fingerprint = system_fingerprints_list.pop()
73
78
 
74
79
  created_list = [chunk.created for chunk in chunks]
@@ -128,7 +133,7 @@ def combine_completion_chunks(chunks: list[CompletionChunk]) -> Completion:
128
133
  or _tool_call.tool_name is None
129
134
  or _tool_call.tool_arguments is None
130
135
  ):
131
- raise ValueError(
136
+ raise CombineCompletionChunksError(
132
137
  "Completion chunk tool calls must have id, tool_name, "
133
138
  "and tool_arguments set."
134
139
  )
@@ -1,6 +1,6 @@
1
1
  import base64
2
2
  import re
3
- from collections.abc import Iterable, Mapping
3
+ from collections.abc import Iterable
4
4
  from enum import StrEnum
5
5
  from pathlib import Path
6
6
  from typing import Annotated, Any, Literal, TypeAlias
@@ -66,7 +66,8 @@ class Content(BaseModel):
66
66
  def from_formatted_prompt(
67
67
  cls,
68
68
  prompt_template: str,
69
- prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
69
+ /,
70
+ **prompt_args: str | int | bool | ImageData | None,
70
71
  ) -> "Content":
71
72
  prompt_args = prompt_args or {}
72
73
  image_args = {
grasp_agents/typing/io.py CHANGED
@@ -5,12 +5,12 @@ from pydantic import BaseModel
5
5
  ProcName: TypeAlias = str
6
6
 
7
7
 
8
+ InT = TypeVar("InT")
9
+ OutT_co = TypeVar("OutT_co", covariant=True)
10
+
11
+
8
12
  class LLMPromptArgs(BaseModel):
9
13
  pass
10
14
 
11
15
 
12
- InT_contra = TypeVar("InT_contra", contravariant=True)
13
- OutT_co = TypeVar("OutT_co", covariant=True)
14
- MemT_co = TypeVar("MemT_co", covariant=True)
15
-
16
16
  LLMPrompt: TypeAlias = str