grasp_agents 0.2.0.tar.gz → 0.2.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/LICENSE.md +1 -1
  2. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/PKG-INFO +1 -1
  3. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/pyproject.toml +1 -1
  4. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/cloud_llm.py +1 -0
  5. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm_agent.py +70 -21
  6. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm_agent_state.py +9 -7
  7. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/prompt_builder.py +11 -4
  8. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/tool_orchestrator.py +8 -8
  9. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/content.py +2 -2
  10. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/io.py +3 -2
  11. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/message.py +2 -2
  12. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/tool.py +0 -7
  13. grasp_agents-0.2.1/src/grasp_agents/utils.py +194 -0
  14. grasp_agents-0.2.0/src/grasp_agents/utils.py +0 -187
  15. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/.gitignore +0 -0
  16. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/README.md +0 -0
  17. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/agent_message.py +0 -0
  18. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/agent_message_pool.py +0 -0
  19. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/base_agent.py +0 -0
  20. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/comm_agent.py +0 -0
  21. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/costs_dict.yaml +0 -0
  22. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/generics_utils.py +0 -0
  23. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/grasp_logging.py +0 -0
  24. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/http_client.py +0 -0
  25. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/llm.py +0 -0
  26. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/memory.py +0 -0
  27. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/__init__.py +0 -0
  28. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/completion_converters.py +0 -0
  29. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/content_converters.py +0 -0
  30. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/converters.py +0 -0
  31. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/message_converters.py +0 -0
  32. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/openai_llm.py +0 -0
  33. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/openai/tool_converters.py +0 -0
  34. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/printer.py +0 -0
  35. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/__init__.py +0 -0
  36. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/rate_limiter_chunked.py +0 -0
  37. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/types.py +0 -0
  38. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/rate_limiting/utils.py +0 -0
  39. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/run_context.py +0 -0
  40. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/__init__.py +0 -0
  41. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/completion.py +0 -0
  42. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/typing/converters.py +0 -0
  43. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/usage_tracker.py +0 -0
  44. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/__init__.py +0 -0
  45. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/looped_agent.py +0 -0
  46. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/sequential_agent.py +0 -0
  47. {grasp_agents-0.2.0 → grasp_agents-0.2.1}/src/grasp_agents/workflow/workflow_agent.py +0 -0
LICENSE.md
@@ -8,6 +8,6 @@ Package production dependencies are licensed under the following terms:
  | dotenv | 0.9.9 | BSD-3-Clause license | https://github.com/pedroburon/dotenv |
  | httpx | 0.28.1 | BSD License | https://github.com/encode/httpx |
  | openai | 1.77.0 | Apache Software License | https://github.com/openai/openai-python |
- | tenacity | 9.1.2 | Apache Software License | https://github.com/jd/tenacity |
+ | tenacity | 8.5.0 | Apache Software License | https://github.com/jd/tenacity |
  | termcolor | 2.5.0 | MIT License | https://github.com/termcolor/termcolor |
  | tqdm | 4.67.1 | MIT License; Mozilla Public License 2.0 (MPL 2.0) | https://tqdm.github.io |
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: grasp_agents
- Version: 0.2.0
+ Version: 0.2.1
  Summary: Grasp Agents Library
  License-File: LICENSE.md
  Requires-Python: <4,>=3.11.4
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "grasp_agents"
- version = "0.2.0"
+ version = "0.2.1"
  description = "Grasp Agents Library"
  readme = "README.md"
  requires-python = ">=3.11.4,<4"
src/grasp_agents/cloud_llm.py
@@ -286,6 +286,7 @@ class CloudLLM(LLM[SettingsT, ConvertT], Generic[SettingsT, ConvertT]):
          validate_obj_from_json_or_py_string(
              message.content,
              adapter=self._response_format_pyd,
+             from_substring=True,
          )

      async def generate_completion_stream(
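For reference, `validate_obj_from_json_or_py_string` with the new `from_substring=True` flag (defined in the rewritten `utils.py` further below) first extracts the first decodable JSON/Python literal from surrounding prose and then validates it. A minimal sketch, with a made-up Pydantic model standing in for the LLM response format:

    from pydantic import BaseModel, TypeAdapter

    from grasp_agents.utils import validate_obj_from_json_or_py_string


    class Answer(BaseModel):  # hypothetical response-format model
        score: int


    raw = 'Sure! Here is the result: {"score": 7}. Let me know if you need more.'

    # from_substring=True lets validation succeed even though the JSON is embedded in prose.
    answer = validate_obj_from_json_or_py_string(
        raw, adapter=TypeAdapter(Answer), from_substring=True
    )
    print(answer)  # -> score=7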
src/grasp_agents/llm_agent.py
@@ -10,7 +10,7 @@ from .comm_agent import CommunicatingAgent
  from .llm import LLM, LLMSettings
  from .llm_agent_state import (
      LLMAgentState,
-     MakeCustomAgentState,
+     SetAgentState,
      SetAgentStateStrategy,
  )
  from .prompt_builder import (
@@ -25,7 +25,11 @@ from .run_context import (
      SystemRunArgs,
      UserRunArgs,
  )
- from .tool_orchestrator import ExitToolCallLoopHandler, ToolOrchestrator
+ from .tool_orchestrator import (
+     ExitToolCallLoopHandler,
+     ManageAgentStateHandler,
+     ToolOrchestrator,
+ )
  from .typing.content import ImageData
  from .typing.converters import Converters
  from .typing.io import (
@@ -43,9 +47,14 @@ from .typing.tool import BaseTool
  from .utils import get_prompt, validate_obj_from_json_or_py_string


- class ParseOutputHandler(Protocol[OutT, CtxT]):
+ class ParseOutputHandler(Protocol[InT, OutT, CtxT]):
      def __call__(
-         self, *args: Any, ctx: RunContextWrapper[CtxT] | None, **kwargs: Any
+         self,
+         conversation: Conversation,
+         *,
+         rcv_args: InT | None,
+         batch_idx: int,
+         ctx: RunContextWrapper[CtxT] | None,
      ) -> OutT: ...


@@ -91,9 +100,15 @@ class LLMAgent(
          # Agent state
          self._state: LLMAgentState = LLMAgentState()
          self.set_state_strategy: SetAgentStateStrategy = set_state_strategy
-         self._make_custom_agent_state_impl: MakeCustomAgentState | None = None
+         self._set_agent_state_impl: SetAgentState | None = None

          # Tool orchestrator
+
+         self._using_default_llm_response_format: bool = False
+         if llm.response_format is None and tools is None:
+             llm.response_format = self.out_type
+             self._using_default_llm_response_format = True
+
          self._tool_orchestrator: ToolOrchestrator[CtxT] = ToolOrchestrator[CtxT](
              agent_id=self.agent_id,
              llm=llm,
@@ -152,16 +167,28 @@ class LLMAgent(
          conversation: Conversation,
          *,
          rcv_args: InT | None = None,
+         batch_idx: int = 0,
          ctx: RunContextWrapper[CtxT] | None = None,
-         **kwargs: Any,
      ) -> OutT:
          if self._parse_output_impl:
+             if self._using_default_llm_response_format:
+                 # When using custom output parsing, the required LLM response format
+                 # can differ from the final agent output type ->
+                 # set it back to None unless it was specified explicitly at init.
+                 self._tool_orchestrator.llm.response_format = None
+                 self._using_default_llm_response_format = False
+
              return self._parse_output_impl(
-                 conversation=conversation, rcv_args=rcv_args, ctx=ctx, **kwargs
+                 conversation=conversation,
+                 rcv_args=rcv_args,
+                 batch_idx=batch_idx,
+                 ctx=ctx,
              )

          return validate_obj_from_json_or_py_string(
-             str(conversation[-1].content), self._out_type_adapter
+             str(conversation[-1].content),
+             adapter=self._out_type_adapter,
+             from_substring=True,
          )

      @final
@@ -210,7 +237,7 @@ class LLMAgent(
              rcv_state=rcv_state,
              sys_prompt=formatted_sys_prompt,
              strategy=self.set_state_strategy,
-             make_custom_state_impl=self._make_custom_agent_state_impl,
+             set_agent_state_impl=self._set_agent_state_impl,
              ctx=ctx,
          )

@@ -234,7 +261,7 @@ class LLMAgent(
          else:
              # 4. Run tool call loop (new messages are added to the message
              # history inside the loop)
-             await self._tool_orchestrator.run_loop(agent_state=state, ctx=ctx)
+             await self._tool_orchestrator.run_loop(state=state, ctx=ctx)

          # 5. Parse outputs
          batch_size = state.message_history.batch_size
@@ -324,15 +351,13 @@ class LLMAgent(
          return func

      def parse_output_handler(
-         self, func: ParseOutputHandler[OutT, CtxT]
-     ) -> ParseOutputHandler[OutT, CtxT]:
+         self, func: ParseOutputHandler[InT, OutT, CtxT]
+     ) -> ParseOutputHandler[InT, OutT, CtxT]:
          self._parse_output_impl = func

          return func

-     def make_custom_agent_state_handler(
-         self, func: MakeCustomAgentState
-     ) -> MakeCustomAgentState:
+     def set_agent_state_handler(self, func: SetAgentState) -> SetAgentState:
          self._make_custom_agent_state_impl = func

          return func
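Custom output parsers now receive explicit keyword-only arguments instead of `*args`/`**kwargs`. A sketch of a handler matching the new `ParseOutputHandler` protocol; the output model is made up, and the `agent` it would be registered on is hypothetical:

    from typing import Any

    from pydantic import BaseModel


    class Report(BaseModel):  # hypothetical agent output type (OutT)
        summary: str


    def parse_report(
        conversation: Any, *, rcv_args: Any, batch_idx: int, ctx: Any
    ) -> Report:
        # Only the final assistant message of this batch item's conversation is used here.
        return Report(summary=str(conversation[-1].content).strip())


    # Registration on a concrete agent instance would look like:
    # agent.parse_output_handler(parse_report)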
@@ -344,6 +369,13 @@ class LLMAgent(

          return func

+     def manage_agent_state_handler(
+         self, func: ManageAgentStateHandler[CtxT]
+     ) -> ManageAgentStateHandler[CtxT]:
+         self._tool_orchestrator.manage_agent_state_impl = func
+
+         return func
+
      # -- Override these methods in subclasses if needed --

      def _register_overridden_handlers(self) -> None:
@@ -356,15 +388,18 @@ class LLMAgent(
          if cur_cls._format_inp_args is not base_cls._format_inp_args:  # noqa: SLF001
              self._prompt_builder.format_inp_args_impl = self._format_inp_args

-         if cur_cls._make_custom_agent_state is not base_cls._make_custom_agent_state:  # noqa: SLF001
-             self._make_custom_agent_state_impl = self._make_custom_agent_state
+         if cur_cls._set_agent_state is not base_cls._set_agent_state:  # noqa: SLF001
+             self._set_agent_state_impl = self._set_agent_state
+
+         if cur_cls._manage_agent_state is not base_cls._manage_agent_state:  # noqa: SLF001
+             self._tool_orchestrator.manage_agent_state_impl = self._manage_agent_state

          if (
              cur_cls._tool_call_loop_exit is not base_cls._tool_call_loop_exit  # noqa: SLF001
          ):
              self._tool_orchestrator.exit_tool_call_loop_impl = self._tool_call_loop_exit

-         self._parse_output_impl: ParseOutputHandler[OutT, CtxT] | None = None
+         self._parse_output_impl: ParseOutputHandler[InT, OutT, CtxT] | None = None

      def _format_sys_args(
          self,
@@ -381,21 +416,24 @@ class LLMAgent(
          self,
          usr_args: LLMPromptArgs,
          rcv_args: InT,
+         *,
+         batch_idx: int = 0,
          ctx: RunContextWrapper[CtxT] | None = None,
      ) -> LLMFormattedArgs:
          raise NotImplementedError(
              "LLMAgent._format_inp_args must be overridden by a subclass"
          )

-     def _make_custom_agent_state(
+     def _set_agent_state(
          self,
-         cur_state: LLMAgentState | None,
+         cur_state: LLMAgentState,
+         *,
          rcv_state: AgentState | None,
          sys_prompt: LLMPrompt | None,
          ctx: RunContextWrapper[Any] | None,
      ) -> LLMAgentState:
          raise NotImplementedError(
-             "LLMAgent._make_custom_agent_state_handler must be overridden by a subclass"
+             "LLMAgent._set_agent_state_handler must be overridden by a subclass"
          )

      def _tool_call_loop_exit(
@@ -408,3 +446,14 @@ class LLMAgent(
          raise NotImplementedError(
              "LLMAgent._tool_call_loop_exit must be overridden by a subclass"
          )
+
+     def _manage_agent_state(
+         self,
+         state: LLMAgentState,
+         *,
+         ctx: RunContextWrapper[CtxT] | None = None,
+         **kwargs: Any,
+     ) -> None:
+         raise NotImplementedError(
+             "LLMAgent._manage_agent_state must be overridden by a subclass"
+         )
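One way to hook into the new `_manage_agent_state` method is to override it in a subclass, which `_register_overridden_handlers` picks up as shown above. This is a sketch under assumptions (import paths, an untyped subclass, and the `num_turns` kwarg passed by the orchestrator):

    from typing import Any

    from grasp_agents.llm_agent import LLMAgent
    from grasp_agents.llm_agent_state import LLMAgentState


    class MyAgent(LLMAgent):  # generic parameters omitted for brevity
        def _manage_agent_state(
            self,
            state: LLMAgentState,
            *,
            ctx: Any | None = None,
            **kwargs: Any,
        ) -> None:
            # The tool orchestrator calls this on every loop iteration and passes num_turns.
            num_turns = kwargs.get("num_turns", 0)
            if num_turns and num_turns % 5 == 0:
                print(f"tool-call loop reached turn {num_turns}")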
src/grasp_agents/llm_agent_state.py
@@ -10,10 +10,11 @@ from .typing.io import AgentState, LLMPrompt
  SetAgentStateStrategy = Literal["keep", "reset", "from_sender", "custom"]


- class MakeCustomAgentState(Protocol):
+ class SetAgentState(Protocol):
      def __call__(
          self,
-         cur_state: Optional["LLMAgentState"],
+         cur_state: "LLMAgentState",
+         *,
          rcv_state: AgentState | None,
          sys_prompt: LLMPrompt | None,
          ctx: RunContextWrapper[Any] | None,
@@ -30,11 +31,12 @@ class LLMAgentState(AgentState):
      @classmethod
      def from_cur_and_rcv_states(
          cls,
-         cur_state: Optional["LLMAgentState"] = None,
+         cur_state: "LLMAgentState",
+         *,
          rcv_state: Optional["AgentState"] = None,
          sys_prompt: LLMPrompt | None = None,
          strategy: SetAgentStateStrategy = "from_sender",
-         make_custom_state_impl: MakeCustomAgentState | None = None,
+         set_agent_state_impl: SetAgentState | None = None,
          ctx: RunContextWrapper[Any] | None = None,
      ) -> "LLMAgentState":
          upd_mh = cur_state.message_history if cur_state else None
@@ -59,10 +61,10 @@ class LLMAgentState(AgentState):
              upd_mh.reset(sys_prompt)

          elif strategy == "custom":
-             assert make_custom_state_impl is not None, (
-                 "Custom message history handler implementation is not provided."
+             assert set_agent_state_impl is not None, (
+                 "Agent state setter implementation is not provided."
              )
-             return make_custom_state_impl(
+             return set_agent_state_impl(
                  cur_state=cur_state,
                  rcv_state=rcv_state,
                  sys_prompt=sys_prompt,
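A sketch of a `SetAgentState` implementation matching the updated keyword-only protocol; it would be registered via the agent's `set_agent_state_handler` decorator (or by overriding `_set_agent_state`) and only takes effect with `set_state_strategy="custom"`. The body is illustrative:

    from typing import Any

    from grasp_agents.llm_agent_state import LLMAgentState
    from grasp_agents.typing.io import AgentState, LLMPrompt


    def set_state(
        cur_state: LLMAgentState,
        *,
        rcv_state: AgentState | None,
        sys_prompt: LLMPrompt | None,
        ctx: Any | None,
    ) -> LLMAgentState:
        # Keep the current state, but re-seed its message history with the freshly
        # formatted system prompt whenever one is provided.
        if sys_prompt is not None:
            cur_state.message_history.reset(sys_prompt)
        return cur_state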
src/grasp_agents/prompt_builder.py
@@ -35,6 +35,8 @@ class FormatInputArgsHandler(Protocol[InT, CtxT]):
          self,
          usr_args: LLMPromptArgs,
          rcv_args: InT,
+         *,
+         batch_idx: int,
          ctx: RunContextWrapper[CtxT] | None,
      ) -> LLMFormattedArgs: ...

@@ -77,11 +79,13 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
          self,
          usr_args: LLMPromptArgs,
          rcv_args: InT,
+         *,
+         batch_idx: int = 0,
          ctx: RunContextWrapper[CtxT] | None = None,
      ) -> LLMFormattedArgs:
          if self.format_inp_args_impl:
              return self.format_inp_args_impl(
-                 usr_args=usr_args, rcv_args=rcv_args, ctx=ctx
+                 usr_args=usr_args, rcv_args=rcv_args, batch_idx=batch_idx, ctx=ctx
              )

          if not isinstance(rcv_args, BaseModel) and rcv_args is not None:
@@ -100,6 +104,7 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
      def make_sys_prompt(
          self,
          sys_args: LLMPromptArgs,
+         *,
          ctx: RunContextWrapper[CtxT] | None,
      ) -> LLMPrompt | None:
          if self.sys_prompt is None:
@@ -151,9 +156,11 @@ class PromptBuilder(AutoInstanceAttributesMixin, Generic[InT, CtxT]):
          ]

          formatted_inp_args_batch = [
-             self._format_inp_args(usr_args=val_usr_args, rcv_args=val_rcv_args, ctx=ctx)
-             for val_usr_args, val_rcv_args in zip(
-                 val_usr_args_batch_, val_rcv_args_batch_, strict=False
+             self._format_inp_args(
+                 usr_args=val_usr_args, rcv_args=val_rcv_args, batch_idx=i, ctx=ctx
+             )
+             for i, (val_usr_args, val_rcv_args) in enumerate(
+                 zip(val_usr_args_batch_, val_rcv_args_batch_, strict=False)
              )
          ]
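Because input-args formatting is now batch-aware, a custom formatter can vary its output per batch item via `batch_idx`. A minimal sketch matching the updated `FormatInputArgsHandler` signature, with loosely typed arguments:

    from typing import Any


    def format_inp_args(
        usr_args: Any,  # LLMPromptArgs for this run
        rcv_args: Any,  # validated received arguments for this batch item
        *,
        batch_idx: int,
        ctx: Any | None,
    ) -> dict[str, Any]:
        # batch_idx is the position of this item within the input batch.
        return {"item_number": batch_idx + 1, "query": str(rcv_args)}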
src/grasp_agents/tool_orchestrator.py
@@ -29,7 +29,7 @@ class ExitToolCallLoopHandler(Protocol[CtxT]):
  class ManageAgentStateHandler(Protocol[CtxT]):
      def __call__(
          self,
-         agent_state: LLMAgentState,
+         state: LLMAgentState,
          *,
          ctx: RunContextWrapper[CtxT] | None,
          **kwargs: Any,
@@ -93,13 +93,13 @@ class ToolOrchestrator(Generic[CtxT]):

      def _manage_agent_state(
          self,
-         agent_state: LLMAgentState,
+         state: LLMAgentState,
          *,
          ctx: RunContextWrapper[CtxT] | None = None,
          **kwargs: Any,
      ) -> None:
          if self.manage_agent_state_impl:
-             self.manage_agent_state_impl(agent_state=agent_state, ctx=ctx, **kwargs)
+             self.manage_agent_state_impl(state=state, ctx=ctx, **kwargs)

      async def generate_once(
          self,
@@ -119,10 +119,10 @@ class ToolOrchestrator(Generic[CtxT]):

      async def run_loop(
          self,
-         agent_state: LLMAgentState,
+         state: LLMAgentState,
          ctx: RunContextWrapper[CtxT] | None = None,
      ) -> None:
-         message_history = agent_state.message_history
+         message_history = state.message_history
          assert message_history.batch_size == 1, (
              "Batch size must be 1 for tool call loop"
          )
@@ -131,13 +131,13 @@ class ToolOrchestrator(Generic[CtxT]):

          tool_choice = "none" if self._react_mode else "auto"
          gen_message_batch = await self.generate_once(
-             agent_state, tool_choice=tool_choice, ctx=ctx
+             state, tool_choice=tool_choice, ctx=ctx
          )

          turns = 0

          while True:
-             self._manage_agent_state(agent_state=agent_state, ctx=ctx, num_turns=turns)
+             self._manage_agent_state(state=state, ctx=ctx, num_turns=turns)

              if self._exit_tool_call_loop(
                  message_history.batched_conversations[0], ctx=ctx, num_turns=turns
@@ -156,7 +156,7 @@ class ToolOrchestrator(Generic[CtxT]):

              tool_choice = "none" if (self._react_mode and msg.tool_calls) else "auto"
              gen_message_batch = await self.generate_once(
-                 agent_state, tool_choice=tool_choice, ctx=ctx
+                 state, tool_choice=tool_choice, ctx=ctx
              )

              turns += 1
src/grasp_agents/typing/content.py
@@ -1,6 +1,6 @@
  import base64
  import re
- from collections.abc import Iterable
+ from collections.abc import Iterable, Mapping
  from enum import StrEnum
  from pathlib import Path
  from typing import Annotated, Any, Literal, TypeAlias
@@ -66,7 +66,7 @@ class Content(BaseModel):
      def from_formatted_prompt(
          cls,
          prompt_template: str,
-         prompt_args: dict[str, str | ImageData] | None = None,
+         prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
      ) -> "Content":
          prompt_args = prompt_args or {}
          image_args = {
src/grasp_agents/typing/io.py
@@ -1,3 +1,4 @@
+ from collections.abc import Mapping
  from typing import TypeAlias, TypeVar

  from pydantic import BaseModel
@@ -20,5 +21,5 @@ OutT = TypeVar("OutT", covariant=True)  # noqa: PLC0105
  StateT = TypeVar("StateT", bound=AgentState, covariant=True)  # noqa: PLC0105

  LLMPrompt: TypeAlias = str
- LLMFormattedSystemArgs: TypeAlias = dict[str, str]
- LLMFormattedArgs: TypeAlias = dict[str, str | ImageData]
+ LLMFormattedSystemArgs: TypeAlias = Mapping[str, str | int | bool]
+ LLMFormattedArgs: TypeAlias = Mapping[str, str | int | bool | ImageData]
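The widened aliases mean formatted prompt arguments can be any read-only mapping and may now carry `int`/`bool` values alongside strings and images; the keys below are arbitrary examples:

    from grasp_agents.typing.io import LLMFormattedArgs, LLMFormattedSystemArgs

    sys_args: LLMFormattedSystemArgs = {"role_name": "reviewer", "max_items": 3}
    usr_args: LLMFormattedArgs = {"query": "summarize", "verbose": True, "top_k": 5}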
src/grasp_agents/typing/message.py
@@ -1,5 +1,5 @@
  import json
- from collections.abc import Hashable, Sequence
+ from collections.abc import Hashable, Mapping, Sequence
  from enum import StrEnum
  from typing import Annotated, Any, Literal, TypeAlias
  from uuid import uuid4
@@ -79,7 +79,7 @@ class UserMessage(MessageBase):
      def from_formatted_prompt(
          cls,
          prompt_template: str,
-         prompt_args: dict[str, str | ImageData] | None = None,
+         prompt_args: Mapping[str, str | int | bool | ImageData] | None = None,
          model_id: str | None = None,
      ) -> "UserMessage":
          content = Content.from_formatted_prompt(
src/grasp_agents/typing/tool.py
@@ -1,8 +1,6 @@
  from __future__ import annotations

- import asyncio
  from abc import ABC, abstractmethod
- from collections.abc import Sequence
  from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar

  from pydantic import BaseModel, PrivateAttr, TypeAdapter
@@ -60,11 +58,6 @@ class BaseTool(
      ) -> _ToolOutT:
          pass

-     async def run_batch(
-         self, inp_batch: Sequence[_ToolInT], ctx: RunContextWrapper[CtxT] | None = None
-     ) -> Sequence[_ToolOutT]:
-         return await asyncio.gather(*[self.run(inp, ctx=ctx) for inp in inp_batch])
-
      async def __call__(
          self, ctx: RunContextWrapper[CtxT] | None = None, **kwargs: Any
      ) -> _ToolOutT:
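With `BaseTool.run_batch` removed, callers that fanned a tool out over an input batch can gather the per-item `run` coroutines themselves (or use `asyncio_gather_with_pbar` from `utils.py`). A sketch with a toy stand-in for a tool, since a real `BaseTool` subclass needs the package's generics:

    import asyncio
    from typing import Any


    class EchoTool:  # toy stand-in; real tools subclass BaseTool and implement run(inp, ctx=...)
        async def run(self, inp: str, ctx: Any | None = None) -> str:
            return inp.upper()


    async def run_batch(tool: EchoTool, inp_batch: list[str]) -> list[str]:
        # Equivalent of the removed BaseTool.run_batch helper.
        return await asyncio.gather(*[tool.run(inp, ctx=None) for inp in inp_batch])


    print(asyncio.run(run_batch(EchoTool(), ["a", "b"])))  # ['A', 'B']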
src/grasp_agents/utils.py (added in 0.2.1)
@@ -0,0 +1,194 @@
+ import ast
+ import asyncio
+ import json
+ import re
+ from collections.abc import Coroutine, Mapping
+ from datetime import datetime
+ from logging import getLogger
+ from pathlib import Path
+ from typing import Any, TypeVar
+
+ from pydantic import (
+     GetCoreSchemaHandler,
+     TypeAdapter,
+     ValidationError,
+ )
+ from pydantic_core import core_schema
+ from tqdm.autonotebook import tqdm
+
+ logger = getLogger(__name__)
+
+ _JSON_START_RE = re.compile(r"[{\[]")
+
+ T = TypeVar("T")
+
+
+ def extract_json_substring(text: str) -> str | None:
+     decoder = json.JSONDecoder()
+     for match in _JSON_START_RE.finditer(text):
+         start = match.start()
+         try:
+             _, end = decoder.raw_decode(text, idx=start)
+             return text[start:end]
+         except ValueError:
+             continue
+
+     return None
+
+
+ def parse_json_or_py_string(
+     s: str, return_none_on_failure: bool = False
+ ) -> dict[str, Any] | list[Any] | None:
+     s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
+     try:
+         return ast.literal_eval(s_fmt)
+     except (ValueError, SyntaxError):
+         try:
+             return json.loads(s_fmt)
+         except json.JSONDecodeError as exc:
+             if return_none_on_failure:
+                 return None
+             raise ValueError(
+                 "Invalid JSON/Python string - Both ast.literal_eval and json.loads "
+                 f"failed to parse the following response:\n{s}"
+             ) from exc
+
+
+ def parse_json_or_py_substring(
+     json_str: str, return_none_on_failure: bool = False
+ ) -> dict[str, Any] | list[Any] | None:
+     return parse_json_or_py_string(
+         extract_json_substring(json_str) or "", return_none_on_failure
+     )
+
+
+ def validate_obj_from_json_or_py_string(
+     s: str, adapter: TypeAdapter[T], from_substring: bool = False
+ ) -> T:
+     try:
+         if from_substring:
+             parsed = parse_json_or_py_substring(s, return_none_on_failure=True)
+         else:
+             parsed = parse_json_or_py_string(s, return_none_on_failure=True)
+         if parsed is None:
+             parsed = s
+         return adapter.validate_python(parsed)
+     except (json.JSONDecodeError, ValidationError) as exc:
+         raise ValueError(
+             f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}",  # type: ignore[arg-type]
+         ) from exc
+
+
+ def extract_xml_list(text: str) -> list[str]:
+     pattern = re.compile(r"<(chunk_\d+)>(.*?)</\1>", re.DOTALL)
+
+     chunks: list[str] = []
+     for match in pattern.finditer(text):
+         content = match.group(2).strip()
+         chunks.append(content)
+     return chunks
+
+
+ def build_marker_json_parser_type(
+     marker_to_model: Mapping[str, type],
+ ) -> type:
+     """
+     Return a Pydantic-compatible *type* that, when given a **str**, searches for
+     the first marker substring and validates the JSON that follows with the
+     corresponding Pydantic model.
+
+     If no marker is found, the raw string is returned unchanged.
+
+     Example:
+     -------
+     >>> Todo = build_marker_json_parser_type({'```json': MyModel})
+     >>> Todo.validate('```json {"a": 1}')
+     MyModel(a=1)
+
+     """
+
+     class MarkerParsedOutput:
+         """String → (Model | str) parser generated by build_marker_json_parser_type."""
+
+         @classmethod
+         def __get_pydantic_core_schema__(
+             cls,
+             _source_type: Any,
+             _handler: GetCoreSchemaHandler,
+         ) -> core_schema.CoreSchema:
+             def _validate(value: Any) -> Any:
+                 if not isinstance(value, str):
+                     raise TypeError("MarkerParsedOutput expects a string")
+
+                 for marker, model in marker_to_model.items():
+                     if marker in value:
+                         adapter = TypeAdapter[Any](model)
+                         return validate_obj_from_json_or_py_string(
+                             value, adapter=adapter, from_substring=True
+                         )
+
+                 return value
+
+             return core_schema.no_info_after_validator_function(
+                 _validate, core_schema.any_schema()
+             )
+
+         @classmethod
+         def __get_pydantic_json_schema__(
+             cls,
+             schema: core_schema.CoreSchema,
+             handler: GetCoreSchemaHandler,
+         ):
+             return handler(schema)
+
+     unique_suffix = "_".join(sorted(marker_to_model))[:40]
+     MarkerParsedOutput.__name__ = f"MarkerParsedOutput_{unique_suffix}"
+
+     return MarkerParsedOutput
+
+
+ def read_txt(file_path: str | Path, encoding: str = "utf-8") -> str:
+     return Path(file_path).read_text(encoding=encoding)
+
+
+ def read_contents_from_file(
+     file_path: str | Path,
+     binary_mode: bool = False,
+ ) -> str | bytes:
+     try:
+         if binary_mode:
+             return Path(file_path).read_bytes()
+         return Path(file_path).read_text()
+     except FileNotFoundError:
+         logger.error(f"File {file_path} not found.")
+         return ""
+
+
+ def get_prompt(prompt_text: str | None, prompt_path: str | Path | None) -> str | None:
+     if prompt_text is None:
+         return read_contents_from_file(prompt_path) if prompt_path is not None else None  # type: ignore[arg-type]
+
+     return prompt_text
+
+
+ async def asyncio_gather_with_pbar(
+     *corouts: Coroutine[Any, Any, Any],
+     no_tqdm: bool = False,
+     desc: str | None = None,
+ ) -> list[Any]:
+     pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
+
+     async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
+         result = await coro
+         pbar.update(1)
+         return result
+
+     wrapped_tasks = [run_and_update(c) for c in corouts]
+     results = await asyncio.gather(*wrapped_tasks)
+     pbar.close()
+
+     return results
+
+
+ def get_timestamp() -> str:
+     return datetime.now().strftime("%Y%m%d_%H%M%S")
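A short usage sketch for the new `build_marker_json_parser_type`, following its docstring; the marker and model are arbitrary:

    from pydantic import BaseModel, TypeAdapter

    from grasp_agents.utils import build_marker_json_parser_type


    class Todo(BaseModel):
        task: str
        done: bool


    TodoOrStr = build_marker_json_parser_type({"<TODO>": Todo})
    adapter = TypeAdapter(TodoOrStr)

    print(adapter.validate_python('<TODO> {"task": "ship 0.2.1", "done": false}'))
    # -> Todo(task='ship 0.2.1', done=False)
    print(adapter.validate_python("plain text without a marker"))  # returned unchanged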
src/grasp_agents/utils.py (0.2.0 version, removed)
@@ -1,187 +0,0 @@
- import ast
- import asyncio
- import json
- import re
- from collections.abc import Coroutine
- from datetime import datetime
- from logging import getLogger
- from pathlib import Path
- from typing import Any, TypeVar
-
- from pydantic import (
-     BaseModel,
-     GetCoreSchemaHandler,
-     TypeAdapter,
-     ValidationError,
- )
- from pydantic_core import core_schema
- from tqdm.autonotebook import tqdm
-
- logger = getLogger(__name__)
-
- T = TypeVar("T")
-
-
- def filter_fields(data: dict[str, Any], model: type[BaseModel]) -> dict[str, Any]:
-     return {key: data[key] for key in model.model_fields if key in data}
-
-
- def read_txt(file_path: str) -> str:
-     return Path(file_path).read_text()
-
-
- def format_json_string(text: str) -> str:
-     decoder = json.JSONDecoder()
-     text = text.replace("\n", "")
-     length = len(text)
-     i = 0
-     while i < length:
-         ch = text[i]
-         if ch in "{[":
-             try:
-                 _, end = decoder.raw_decode(text[i:])
-                 return text[i : i + end]
-             except ValueError:
-                 pass
-         i += 1
-
-     return text
-
-
- def parse_json_or_py_string(
-     json_str: str, return_none_on_failure: bool = False
- ) -> dict[str, Any] | list[Any] | None:
-     try:
-         json_response = ast.literal_eval(json_str)
-     except (ValueError, SyntaxError):
-         try:
-             json_response = json.loads(json_str)
-         except json.JSONDecodeError as exc:
-             if return_none_on_failure:
-                 return None
-             raise ValueError(
-                 "Invalid JSON - Both ast.literal_eval and json.loads "
-                 f"failed to parse the following response:\n{json_str}"
-             ) from exc
-
-     return json_response
-
-
- def extract_json(
-     json_str: str, return_none_on_failure: bool = False
- ) -> dict[str, Any] | list[Any] | None:
-     return parse_json_or_py_string(format_json_string(json_str), return_none_on_failure)
-
-
- def validate_obj_from_json_or_py_string(s: str, adapter: TypeAdapter[T]) -> T:
-     s_fmt = re.sub(r"```[a-zA-Z0-9]*\n|```", "", s).strip()
-     try:
-         parsed = json.loads(s_fmt)
-         return adapter.validate_python(parsed)
-     except (json.JSONDecodeError, ValidationError):
-         try:
-             return adapter.validate_python(s_fmt)
-         except ValidationError as exc:
-             raise ValueError(
-                 f"Invalid JSON or Python string:\n{s}\nExpected type: {adapter._type}",  # type: ignore[arg-type]
-             ) from exc
-
-
- def extract_xml_list(text: str) -> list[str]:
-     pattern = re.compile(r"<(chunk_\d+)>(.*?)</\1>", re.DOTALL)
-
-     chunks: list[str] = []
-     for match in pattern.finditer(text):
-         content = match.group(2).strip()
-         chunks.append(content)
-     return chunks
-
-
- def make_conditional_parsed_output_type(
-     response_format: type, marker: str = "<DONE>"
- ) -> type:
-     class ParsedOutput:
-         """
-         * Accepts any **str**.
-         * If the string contains `marker`, it must contain a valid JSON for
-           `response_format` → we return that a response_format instance.
-         * Otherwise we leave the string untouched.
-         """
-
-         @classmethod
-         def __get_pydantic_core_schema__(
-             cls,
-             _source_type: Any,
-             _handler: GetCoreSchemaHandler,
-         ) -> core_schema.CoreSchema:
-             def validator(v: Any) -> Any:
-                 if isinstance(v, str) and marker in v:
-                     v_json_str = format_json_string(v)
-                     response_format_adapter = TypeAdapter[Any](response_format)
-
-                     return response_format_adapter.validate_json(v_json_str)
-
-                 return v
-
-             return core_schema.no_info_after_validator_function(
-                 validator, core_schema.any_schema()
-             )
-
-         @classmethod
-         def __get_pydantic_json_schema__(
-             cls, core_schema: core_schema.CoreSchema, handler: GetCoreSchemaHandler
-         ):
-             return handler(core_schema)
-
-     return ParsedOutput
-
-
- def read_contents_from_file(
-     file_path: str | Path,
-     binary_mode: bool = False,
- ) -> str | bytes:
-     """Reads and returns contents of file"""
-     try:
-         if binary_mode:
-             with open(file_path, "rb") as file:
-                 return file.read()
-         else:
-             with open(file_path) as file:
-                 return file.read()
-     except FileNotFoundError:
-         logger.error(f"File {file_path} not found.")
-         return ""
-
-
- def get_prompt(prompt_text: str | None, prompt_path: str | Path | None) -> str | None:
-     if prompt_text is None:
-         prompt = (
-             read_contents_from_file(prompt_path) if prompt_path is not None else None
-         )
-     else:
-         prompt = prompt_text
-
-     return prompt  # type: ignore[assignment]
-
-
- async def asyncio_gather_with_pbar(
-     *corouts: Coroutine[Any, Any, Any],
-     no_tqdm: bool = False,
-     desc: str | None = None,
- ) -> list[Any]:
-     pbar = tqdm(total=len(corouts), desc=desc, disable=no_tqdm)
-
-     async def run_and_update(coro: Coroutine[Any, Any, Any]) -> Any:
-         result = await coro
-         pbar.update(1)
-         return result
-
-     wrapped_tasks = [run_and_update(c) for c in corouts]
-     results = await asyncio.gather(*wrapped_tasks)
-     pbar.close()
-
-     return results
-
-
- def get_timestamp() -> str:
-     return datetime.now().strftime("%Y%m%d_%H%M%S")
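A rough old-to-new correspondence for the replaced helpers (a best-effort reading of the two versions, not an official migration guide):

    # 0.2.0 helper                        -> 0.2.1 counterpart
    # format_json_string                  -> extract_json_substring
    # extract_json                        -> parse_json_or_py_substring
    # make_conditional_parsed_output_type -> build_marker_json_parser_type
    # validate_obj_from_json_or_py_string -> same name, new from_substring flag
    # filter_fields, BaseTool.run_batch   -> removed without direct replacements
    from pydantic import TypeAdapter

    from grasp_agents.utils import (
        parse_json_or_py_substring,
        validate_obj_from_json_or_py_string,
    )

    print(parse_json_or_py_substring('prefix {"a": 1} suffix'))  # {'a': 1}
    print(validate_obj_from_json_or_py_string("[1, 2, 3]", adapter=TypeAdapter(list[int])))  # [1, 2, 3]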
The remaining 33 files (items 15–47 above) are unchanged between 0.2.0 and 0.2.1.