fabricatio-0.2.0.dev14-cp312-cp312-win_amd64.whl → fabricatio-0.2.0.dev18-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fabricatio/models/generic.py CHANGED
@@ -1,25 +1,14 @@
  """This module defines generic classes for models in the Fabricatio library."""

  from pathlib import Path
- from typing import Callable, Dict, Iterable, List, Optional, Self, Union
+ from typing import List, Self

- import litellm
  import orjson
- from fabricatio._rust_instances import template_manager
- from fabricatio.config import configs
  from fabricatio.fs.readers import magika
- from fabricatio.models.utils import Messages
- from fabricatio.parser import JsonCapture
- from litellm.types.utils import Choices, ModelResponse, StreamingChoices
  from pydantic import (
      BaseModel,
      ConfigDict,
      Field,
-     HttpUrl,
-     NonNegativeFloat,
-     NonNegativeInt,
-     PositiveInt,
-     SecretStr,
  )


@@ -56,415 +45,6 @@ class WithBriefing(Named, Described):
          return f"{self.name}: {self.description}" if self.description else self.name


- class LLMUsage(Base):
-     """Class that manages LLM (Large Language Model) usage parameters and methods."""
-
-     llm_api_endpoint: Optional[HttpUrl] = None
-     """The OpenAI API endpoint."""
-
-     llm_api_key: Optional[SecretStr] = None
-     """The OpenAI API key."""
-
-     llm_timeout: Optional[PositiveInt] = None
-     """The timeout of the LLM model."""
-
-     llm_max_retries: Optional[PositiveInt] = None
-     """The maximum number of retries."""
-
-     llm_model: Optional[str] = None
-     """The LLM model name."""
-
-     llm_temperature: Optional[NonNegativeFloat] = None
-     """The temperature of the LLM model."""
-
-     llm_stop_sign: Optional[str | List[str]] = None
-     """The stop sign of the LLM model."""
-
-     llm_top_p: Optional[NonNegativeFloat] = None
-     """The top p of the LLM model."""
-
-     llm_generation_count: Optional[PositiveInt] = None
-     """The number of generations to generate."""
-
-     llm_stream: Optional[bool] = None
-     """Whether to stream the LLM model's response."""
-
-     llm_max_tokens: Optional[PositiveInt] = None
-     """The maximum number of tokens to generate."""
-
-     async def aquery(
-         self,
-         messages: List[Dict[str, str]],
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         n: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> ModelResponse:
-         """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
-
-         Args:
-             messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
-             model (str | None): The name of the model to use. If not provided, the default model will be used.
-             temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
-             stop (str | None): A sequence at which to stop the generation of the response.
-             top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
-             max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
-             n (PositiveInt | None): The number of responses to generate.
-             stream (bool | None): Whether to receive the response in a streaming fashion.
-             timeout (PositiveInt | None): The timeout duration for the request.
-             max_retries (PositiveInt | None): The maximum number of retries in case of failure.
-
-         Returns:
-             ModelResponse: An object containing the generated response and other metadata from the model.
-         """
-         # Call the underlying asynchronous completion function with the provided and default parameters
-         return await litellm.acompletion(
-             messages=messages,
-             model=model or self.llm_model or configs.llm.model,
-             temperature=temperature or self.llm_temperature or configs.llm.temperature,
-             stop=stop or self.llm_stop_sign or configs.llm.stop_sign,
-             top_p=top_p or self.llm_top_p or configs.llm.top_p,
-             max_tokens=max_tokens or self.llm_max_tokens or configs.llm.max_tokens,
-             n=n or self.llm_generation_count or configs.llm.generation_count,
-             stream=stream or self.llm_stream or configs.llm.stream,
-             timeout=timeout or self.llm_timeout or configs.llm.timeout,
-             max_retries=max_retries or self.llm_max_retries or configs.llm.max_retries,
-             api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
-             base_url=self.llm_api_endpoint.unicode_string()
-             if self.llm_api_endpoint
-             else configs.llm.api_endpoint.unicode_string(),
-         )
-
-     async def ainvoke(
-         self,
-         question: str,
-         system_message: str = "",
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         n: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> List[Choices | StreamingChoices]:
-         """Asynchronously invokes the language model with a question and optional system message.
-
-         Args:
-             question (str): The question to ask the model.
-             system_message (str): The system message to provide context to the model.
-             model (str | None): The name of the model to use. If not provided, the default model will be used.
-             temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
-             stop (str | None): A sequence at which to stop the generation of the response.
-             top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
-             max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
-             n (PositiveInt | None): The number of responses to generate.
-             stream (bool | None): Whether to receive the response in a streaming fashion.
-             timeout (PositiveInt | None): The timeout duration for the request.
-             max_retries (PositiveInt | None): The maximum number of retries in case of failure.
-
-         Returns:
-             List[Choices | StreamingChoices]: A list of choices or streaming choices from the model response.
-         """
-         return (
-             await self.aquery(
-                 messages=Messages().add_system_message(system_message).add_user_message(question),
-                 model=model,
-                 temperature=temperature,
-                 stop=stop,
-                 top_p=top_p,
-                 max_tokens=max_tokens,
-                 n=n,
-                 stream=stream,
-                 timeout=timeout,
-                 max_retries=max_retries,
-             )
-         ).choices
-
-     async def aask(
-         self,
-         question: str,
-         system_message: str = "",
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> str:
-         """Asynchronously asks the language model a question and returns the response content.
-
-         Args:
-             question (str): The question to ask the model.
-             system_message (str): The system message to provide context to the model.
-             model (str | None): The name of the model to use. If not provided, the default model will be used.
-             temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
-             stop (str | None): A sequence at which to stop the generation of the response.
-             top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
-             max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
-             stream (bool | None): Whether to receive the response in a streaming fashion.
-             timeout (PositiveInt | None): The timeout duration for the request.
-             max_retries (PositiveInt | None): The maximum number of retries in case of failure.
-
-         Returns:
-             str: The content of the model's response message.
-         """
-         return (
-             (
-                 await self.ainvoke(
-                     n=1,
-                     question=question,
-                     system_message=system_message,
-                     model=model,
-                     temperature=temperature,
-                     stop=stop,
-                     top_p=top_p,
-                     max_tokens=max_tokens,
-                     stream=stream,
-                     timeout=timeout,
-                     max_retries=max_retries,
-                 )
-             )
-             .pop()
-             .message.content
-         )
-
-     async def aask_validate[T](
-         self,
-         question: str,
-         validator: Callable[[str], T | None],
-         max_validations: PositiveInt = 2,
-         system_message: str = "",
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> T:
-         """Asynchronously ask a question and validate the response using a given validator.
-
-         Args:
-             question (str): The question to ask.
-             validator (Callable[[str], T | None]): A function to validate the response.
-             max_validations (PositiveInt): Maximum number of validation attempts.
-             system_message (str): System message to include in the request.
-             model (str | None): The model to use for the request.
-             temperature (NonNegativeFloat | None): Temperature setting for the request.
-             stop (str | None): Stop sequence for the request.
-             top_p (NonNegativeFloat | None): Top-p sampling parameter.
-             max_tokens (PositiveInt | None): Maximum number of tokens in the response.
-             stream (bool | None): Whether to stream the response.
-             timeout (PositiveInt | None): Timeout for the request.
-             max_retries (PositiveInt | None): Maximum number of retries for the request.
-
-         Returns:
-             T: The validated response.
-
-         Raises:
-             ValueError: If the response fails to validate after the maximum number of attempts.
-         """
-         for _ in range(max_validations):
-             if (
-                 response := await self.aask(
-                     question=question,
-                     system_message=system_message,
-                     model=model,
-                     temperature=temperature,
-                     stop=stop,
-                     top_p=top_p,
-                     max_tokens=max_tokens,
-                     stream=stream,
-                     timeout=timeout,
-                     max_retries=max_retries,
-                 )
-             ) and (validated := validator(response)):
-                 return validated
-         raise ValueError("Failed to validate the response.")
-
-     async def achoose[T: WithBriefing](
-         self,
-         instruction: str,
-         choices: List[T],
-         k: NonNegativeInt = 0,
-         max_validations: PositiveInt = 2,
-         system_message: str = "",
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> List[T]:
-         """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
-
-         Args:
-             instruction: The user-provided instruction/question description.
-             choices: A list of candidate options, requiring elements to have `name` and `briefing` fields.
-             k: The number of choices to select, 0 means infinite.
-             max_validations: Maximum number of validation failures, default is 2.
-             system_message: Custom system-level prompt, defaults to an empty string.
-             model: The name of the LLM model to use.
-             temperature: Sampling temperature to control randomness in generation.
-             stop: Stop condition string or list for generation.
-             top_p: Core sampling probability threshold.
-             max_tokens: Maximum token limit for the generated result.
-             stream: Whether to enable streaming response mode.
-             timeout: Request timeout in seconds.
-             max_retries: Maximum number of retries.
-
-         Returns:
-             List[T]: The final validated selection result list, with element types matching the input `choices`.
-
-         Important:
-             - Uses a template engine to generate structured prompts.
-             - Ensures response compliance through JSON parsing and format validation.
-             - Relies on `aask_validate` to implement retry mechanisms with validation.
-         """
-         prompt = template_manager.render_template(
-             "make_choice",
-             {
-                 "instruction": instruction,
-                 "options": [m.model_dump(include={"name", "briefing"}) for m in choices],
-                 "k": k,
-             },
-         )
-         names = [c.name for c in choices]
-
-         def _validate(response: str) -> List[T] | None:
-             ret = JsonCapture.convert_with(response, orjson.loads)
-             if not isinstance(ret, List) or len(ret) != k:
-                 return None
-             if any(n not in names for n in ret):
-                 return None
-             return ret
-
-         return await self.aask_validate(
-             question=prompt,
-             validator=_validate,
-             max_validations=max_validations,
-             system_message=system_message,
-             model=model,
-             temperature=temperature,
-             stop=stop,
-             top_p=top_p,
-             max_tokens=max_tokens,
-             stream=stream,
-             timeout=timeout,
-             max_retries=max_retries,
-         )
-
-     async def ajudge(
-         self,
-         prompt: str,
-         affirm_case: str = "",
-         deny_case: str = "",
-         max_validations: PositiveInt = 2,
-         system_message: str = "",
-         model: str | None = None,
-         temperature: NonNegativeFloat | None = None,
-         stop: str | List[str] | None = None,
-         top_p: NonNegativeFloat | None = None,
-         max_tokens: PositiveInt | None = None,
-         stream: bool | None = None,
-         timeout: PositiveInt | None = None,
-         max_retries: PositiveInt | None = None,
-     ) -> bool:
-         """Asynchronously judges a prompt using AI validation.
-
-         Args:
-             prompt (str): The input prompt to be judged.
-             affirm_case (str, optional): The affirmative case for the AI model. Defaults to "".
-             deny_case (str, optional): The negative case for the AI model. Defaults to "".
-             max_validations (PositiveInt, optional): Maximum number of validation attempts. Defaults to 2.
-             system_message (str, optional): System message for the AI model. Defaults to "".
-             model (str | None, optional): AI model to use. Defaults to None.
-             temperature (NonNegativeFloat | None, optional): Sampling temperature. Defaults to None.
-             stop (str | List[str] | None, optional): Stop sequences. Defaults to None.
-             top_p (NonNegativeFloat | None, optional): Nucleus sampling parameter. Defaults to None.
-             max_tokens (PositiveInt | None, optional): Maximum number of tokens to generate. Defaults to None.
-             stream (bool | None, optional): Whether to stream the response. Defaults to None.
-             timeout (PositiveInt | None, optional): Timeout in seconds. Defaults to None.
-             max_retries (PositiveInt | None, optional): Maximum number of retries. Defaults to None.
-
-         Returns:
-             bool: The judgment result (True or False) based on the AI's response.
-
-         Notes:
-             The method uses an internal validator to ensure the response is a boolean value.
-             If the response cannot be converted to a boolean, it will return None.
-         """
-
-         def _validate(response: str) -> bool | None:
-             ret = JsonCapture.convert_with(response, orjson.loads)
-             if not isinstance(ret, bool):
-                 return None
-             return ret
-
-         return await self.aask_validate(
-             question=template_manager.render_template(
-                 "make_judgment", {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case}
-             ),
-             validator=_validate,
-             max_validations=max_validations,
-             system_message=system_message,
-             model=model,
-             temperature=temperature,
-             stop=stop,
-             top_p=top_p,
-             max_tokens=max_tokens,
-             stream=stream,
-             timeout=timeout,
-             max_retries=max_retries,
-         )
-
-     def fallback_to(self, other: "LLMUsage") -> Self:
-         """Fallback to another instance's attribute values if the current instance's attributes are None.
-
-         Args:
-             other (LLMUsage): Another instance from which to copy attribute values.
-
-         Returns:
-             Self: The current instance, allowing for method chaining.
-         """
-         # Iterate over the attribute names and copy values from 'other' to 'self' where applicable
-         # noinspection PydanticTypeChecker,PyTypeChecker
-         for attr_name in LLMUsage.model_fields:
-             # Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
-             if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
-                 setattr(self, attr_name, attr)
-
-         # Return the current instance to allow for method chaining
-         return self
-
-     def hold_to(self, others: Union["LLMUsage", Iterable["LLMUsage"]]) -> Self:
-         """Hold to another instance's attribute values if the current instance's attributes are None.
-
-         Args:
-             others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
-
-         Returns:
-             Self: The current instance, allowing for method chaining.
-         """
-         for other in others:
-             # noinspection PyTypeChecker,PydanticTypeChecker
-             for attr_name in LLMUsage.model_fields:
-                 if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
-                     setattr(other, attr_name, attr)
-
-
  class WithJsonExample(Base):
      """Class that provides a JSON schema for the model."""

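Editor's note: the removed aquery resolves every parameter through the same three-level chain — explicit call argument, then the instance's llm_* field, then the global config. A minimal standalone sketch of that idiom (names here are illustrative, not from the package):

    # Each `or` falls through to the next source. Caveat of this idiom: a
    # falsy-but-valid explicit value such as temperature=0.0 is skipped too.
    def resolve(explicit: float | None, instance_field: float | None, config_default: float) -> float:
        return explicit or instance_field or config_default

    assert resolve(None, 0.7, 0.9) == 0.7  # instance field beats global config
    assert resolve(0.2, 0.7, 0.9) == 0.2   # explicit argument wins
    assert resolve(0.0, 0.7, 0.9) == 0.7   # 0.0 is falsy, so it is ignored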
@@ -0,0 +1,26 @@
+ """This module contains the types for the keyword arguments of the methods in the models module."""
+
+ from typing import List, NotRequired, TypedDict
+
+ from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
+
+
+ class LLMKwargs(TypedDict):
+     """A type representing the keyword arguments for the LLM (Large Language Model) usage."""
+
+     model: NotRequired[str]
+     temperature: NotRequired[NonNegativeFloat]
+     stop: NotRequired[str | List[str]]
+     top_p: NotRequired[NonNegativeFloat]
+     max_tokens: NotRequired[PositiveInt]
+     stream: NotRequired[bool]
+     timeout: NotRequired[PositiveInt]
+     max_retries: NotRequired[PositiveInt]
+
+
+ class ChooseKwargs(LLMKwargs):
+     """A type representing the keyword arguments for the choose method."""
+
+     max_validations: NotRequired[PositiveInt]
+     system_message: NotRequired[str]
+     k: NotRequired[NonNegativeInt]
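Editor's note: these TypedDicts bundle the per-call LLM parameters that the removed methods took individually. A rough sketch of how such a type is consumed; the ask function and the Unpack wiring are assumptions, not shown in this diff:

    from typing import Unpack  # typing.Unpack is available on Python 3.12, matching the cp312 wheel tag

    defaults: LLMKwargs = {"temperature": 0.2, "max_tokens": 256}

    async def ask(question: str, **kwargs: Unpack[LLMKwargs]) -> str:
        ...  # hypothetical consumer; would forward kwargs to the underlying completion call

    # Every key is NotRequired, so any subset of LLMKwargs type-checks:
    # await ask("hi", **defaults)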
fabricatio/models/role.py CHANGED
@@ -5,9 +5,10 @@ from typing import Any, Set
  from fabricatio.core import env
  from fabricatio.journal import logger
  from fabricatio.models.action import WorkFlow
+ from fabricatio.models.advanced import ProposeTask
  from fabricatio.models.events import Event
- from fabricatio.models.task import ProposeTask
- from fabricatio.models.tool import ToolBox, ToolBoxUsage
+ from fabricatio.models.tool import ToolBox
+ from fabricatio.models.usages import ToolBoxUsage
  from fabricatio.toolboxes import basic_toolboxes
  from pydantic import Field

fabricatio/models/task.py CHANGED
@@ -7,13 +7,11 @@ from asyncio import Queue
  from enum import Enum
  from typing import Any, List, Optional, Self

- from fabricatio._rust_instances import template_manager
  from fabricatio.core import env
  from fabricatio.journal import logger
  from fabricatio.models.events import Event, EventLike
- from fabricatio.models.generic import LLMUsage, WithBriefing, WithDependency, WithJsonExample
- from fabricatio.parser import JsonCapture
- from pydantic import Field, PrivateAttr, ValidationError
+ from fabricatio.models.generic import WithBriefing, WithDependency, WithJsonExample
+ from pydantic import Field, PrivateAttr


  class TaskStatus(Enum):
@@ -255,28 +253,3 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
              str: The briefing of the task.
          """
          return f"{super().briefing}\n{self.goal}"
-
-
- class ProposeTask(LLMUsage, WithBriefing):
-     """A class that proposes a task based on a prompt."""
-
-     async def propose(self, prompt: str) -> Task:
-         """Propose a task based on the provided prompt."""
-         assert prompt, "Prompt must be provided."
-
-         def _validate_json(response: str) -> None | Task:
-             try:
-                 cap = JsonCapture.capture(response)
-                 logger.debug(f"Response: \n{response}")
-                 logger.info(f"Captured JSON: \n{cap}")
-                 return Task.model_validate_json(cap)
-             except ValidationError as e:
-                 logger.error(f"Failed to parse task from JSON: {e}")
-                 return None
-
-         template_data = {"prompt": prompt, "json_example": Task.json_example()}
-         return await self.aask_validate(
-             template_manager.render_template("propose_task", template_data),
-             _validate_json,
-             system_message=f"# your personal briefing: \n{self.briefing}",
-         )
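Editor's note: the removed _validate_json illustrates the validator contract that aask_validate retries on — return a parsed value on success, None on failure. A self-contained sketch of the same contract with a stand-in model (Point is hypothetical; Task plays this role in the package):

    from pydantic import BaseModel, ValidationError

    class Point(BaseModel):
        x: int
        y: int

    def validate(response: str) -> Point | None:
        try:
            return Point.model_validate_json(response)  # parsed value on success
        except ValidationError:
            return None  # None signals the caller to retry

    assert validate('{"x": 1, "y": 2}') == Point(x=1, y=2)
    assert validate("not json") is None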
fabricatio/models/tool.py CHANGED
@@ -1,11 +1,16 @@
  """A module for defining tools and toolboxes."""

+ from importlib.machinery import ModuleSpec
+ from importlib.util import module_from_spec
  from inspect import iscoroutinefunction, signature
- from typing import Any, Callable, Iterable, List, Self, Set, Union
+ from sys import modules
+ from types import CodeType, ModuleType
+ from typing import Any, Callable, Dict, List, Optional, Self, overload

+ from fabricatio.config import configs
  from fabricatio.journal import logger
- from fabricatio.models.generic import Base, WithBriefing
- from pydantic import Field
+ from fabricatio.models.generic import WithBriefing
+ from pydantic import BaseModel, ConfigDict, Field


  class Tool[**P, R](WithBriefing):
@@ -23,7 +28,9 @@ class Tool[**P, R](WithBriefing):
      def model_post_init(self, __context: Any) -> None:
          """Initialize the tool with a name and a source function."""
          self.name = self.name or self.source.__name__
-         assert self.name, "The tool must have a name."
+
+         if not self.name:
+             raise RuntimeError("The tool must have a source function.")
          self.description = self.description or self.source.__doc__ or ""
          self.description = self.description.strip()

@@ -53,7 +60,7 @@ def _desc_wrapper(desc: str) -> str:
  class ToolBox(WithBriefing):
      """A class representing a collection of tools."""

-     tools: List[Tool] = Field(default_factory=list)
+     tools: List[Tool] = Field(default_factory=list, frozen=True)
      """A list of tools in the toolbox."""

      def collect_tool[**P, R](self, func: Callable[P, R]) -> Callable[P, R]:
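Editor's note: in pydantic v2, frozen=True on a field rejects reassignment of the attribute after construction, though it does not deep-freeze the list object itself. A small sketch of the behavior this change buys (Box is a stand-in model):

    from typing import List
    from pydantic import BaseModel, Field, ValidationError

    class Box(BaseModel):
        tools: List[str] = Field(default_factory=list, frozen=True)

    box = Box()
    box.tools.append("ok")       # in-place mutation of the list is still possible
    try:
        box.tools = ["swapped"]  # reassigning the frozen field raises
    except ValidationError:
        pass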
@@ -101,10 +108,14 @@ class ToolBox(WithBriefing):
              Tool: The tool instance with the specified name.

          Raises:
-             AssertionError: If no tool with the specified name is found.
+             ValueError: If no tool with the specified name is found.
          """
          tool = next((tool for tool in self.tools if tool.name == name), None)
-         assert tool, f"No tool named {name} found."
+         if tool is None:
+             err = f"No tool with the name {name} found in the toolbox."
+             logger.error(err)
+             raise ValueError(err)
+
          return tool

      def __hash__(self) -> int:
@@ -112,45 +123,50 @@ class ToolBox(WithBriefing):
          return hash(self.briefing)


- class ToolBoxUsage(Base):
-     """A class representing the usage of tools in a task."""
-
-     toolboxes: Set[ToolBox] = Field(default_factory=set)
-     """A set of toolboxes used by the instance."""
-
-     @property
-     def available_toolbox_names(self) -> List[str]:
-         """Return a list of available toolbox names."""
-         return [toolbox.name for toolbox in self.toolboxes]
-
-     def supply_tools_from(self, others: Union["ToolBoxUsage", Iterable["ToolBoxUsage"]]) -> Self:
-         """Supplies tools from other ToolUsage instances to this instance.
-
-         Args:
-             others ("ToolUsage" | Iterable["ToolUsage"]): A single ToolUsage instance or an iterable of ToolUsage instances
-                 from which to take tools.
-
-         Returns:
-             Self: The current ToolUsage instance with updated tools.
-         """
-         if isinstance(others, ToolBoxUsage):
-             others = [others]
-         for other in others:
-             self.toolboxes.update(other.toolboxes)
-         return self
-
-     def provide_tools_to(self, others: Union["ToolBoxUsage", Iterable["ToolBoxUsage"]]) -> Self:
-         """Provides tools from this instance to other ToolUsage instances.
-
-         Args:
-             others ("ToolUsage" | Iterable["ToolUsage"]): A single ToolUsage instance or an iterable of ToolUsage instances
-                 to which to provide tools.
-
-         Returns:
-             Self: The current ToolUsage instance.
-         """
-         if isinstance(others, ToolBoxUsage):
-             others = [others]
-         for other in others:
-             other.toolboxes.update(self.toolboxes)
-         return self
+ class ToolExecutor(BaseModel):
+     """A class representing a tool executor with a sequence of tools to execute."""
+
+     model_config = ConfigDict(use_attribute_docstrings=True)
+     execute_sequence: List[Tool] = Field(default_factory=list, frozen=True)
+     """The sequence of tools to execute."""
+
+     def inject_tools[M: ModuleType](self, module: Optional[M] = None) -> M:
+         """Inject the tools into the provided module."""
+         module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
+         for tool in self.execute_sequence:
+             setattr(module, tool.name, tool.invoke)
+         return module
+
+     def execute[C: Dict[str, Any]](self, source: CodeType, cxt: Optional[C] = None) -> C:
+         """Execute the sequence of tools with the provided context."""
+         modules[configs.toolbox.tool_module_name] = self.inject_tools()
+         exec(source, cxt)  # noqa: S102
+         modules.pop(configs.toolbox.tool_module_name)
+         return cxt
+
+     @overload
+     def take[C: Dict[str, Any]](self, keys: List[str], source: CodeType, cxt: Optional[C] = None) -> C:
+         """Check the output of the tools with the provided context."""
+         ...
+
+     @overload
+     def take[C: Dict[str, Any]](self, keys: str, source: CodeType, cxt: Optional[C] = None) -> Any:
+         """Check the output of the tools with the provided context."""
+         ...
+
+     def take[C: Dict[str, Any]](self, keys: List[str] | str, source: CodeType, cxt: Optional[C] = None) -> C | Any:
+         """Check the output of the tools with the provided context."""
+         cxt = self.execute(source, cxt)
+         if isinstance(keys, str):
+             return cxt[keys]
+         return {key: cxt[key] for key in keys}
+
+     @classmethod
+     def from_recipe(cls, recipe: List[str], toolboxes: List[ToolBox]) -> Self:
+         """Create a tool executor from a recipe and a list of toolboxes."""
+         tools = []
+         while tool_name := recipe.pop(0):
+             for toolbox in toolboxes:
+                 tools.append(toolbox[tool_name])
+
+         return cls(execute_sequence=tools)
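Editor's note: ToolExecutor works by registering a synthetic module in sys.modules, binding each tool's invoke onto it, exec-ing the compiled source against a context dict, and reading results back out of that dict. A rough usage sketch under assumptions the diff does not show (the configured module name and a my_toolbox containing an "add" tool are hypothetical). Note that from_recipe as written pops names until it hits a falsy one and would raise IndexError on an exhausted list, so the recipe below ends with an empty-string sentinel:

    # Assumes configs.toolbox.tool_module_name == "fabricatio_tool".
    source = compile(
        "from fabricatio_tool import add\nresult = add(1, 2)",  # import resolves via sys.modules
        "<recipe>",
        "exec",
    )
    executor = ToolExecutor.from_recipe(["add", ""], [my_toolbox])  # "" stops the pop loop
    total = executor.take("result", source, {})  # exec fills the dict; total == 3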