fabricatio 0.2.0.dev14__cp312-cp312-win_amd64.whl → 0.2.0.dev18__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +11 -11
- fabricatio/_rust_instances.py +1 -1
- fabricatio/actions/communication.py +2 -0
- fabricatio/actions/transmission.py +2 -0
- fabricatio/config.py +40 -14
- fabricatio/decorators.py +44 -7
- fabricatio/fs/curd.py +110 -0
- fabricatio/fs/readers.py +2 -0
- fabricatio/models/action.py +5 -4
- fabricatio/models/advanced.py +119 -0
- fabricatio/models/events.py +4 -2
- fabricatio/models/generic.py +1 -421
- fabricatio/models/kwargs_types.py +26 -0
- fabricatio/models/role.py +3 -2
- fabricatio/models/task.py +2 -29
- fabricatio/models/tool.py +65 -49
- fabricatio/models/usages.py +456 -0
- fabricatio/parser.py +1 -1
- fabricatio/toolboxes/fs.py +14 -0
- fabricatio/toolboxes/task.py +2 -0
- {fabricatio-0.2.0.dev14.data → fabricatio-0.2.0.dev18.data}/scripts/tdown.exe +0 -0
- {fabricatio-0.2.0.dev14.dist-info → fabricatio-0.2.0.dev18.dist-info}/METADATA +6 -1
- fabricatio-0.2.0.dev18.dist-info/RECORD +35 -0
- fabricatio-0.2.0.dev14.dist-info/RECORD +0 -30
- {fabricatio-0.2.0.dev14.dist-info → fabricatio-0.2.0.dev18.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.0.dev14.dist-info → fabricatio-0.2.0.dev18.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/generic.py
CHANGED
@@ -1,25 +1,14 @@
|
|
1
1
|
"""This module defines generic classes for models in the Fabricatio library."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import
|
4
|
+
from typing import List, Self
|
5
5
|
|
6
|
-
import litellm
|
7
6
|
import orjson
|
8
|
-
from fabricatio._rust_instances import template_manager
|
9
|
-
from fabricatio.config import configs
|
10
7
|
from fabricatio.fs.readers import magika
|
11
|
-
from fabricatio.models.utils import Messages
|
12
|
-
from fabricatio.parser import JsonCapture
|
13
|
-
from litellm.types.utils import Choices, ModelResponse, StreamingChoices
|
14
8
|
from pydantic import (
|
15
9
|
BaseModel,
|
16
10
|
ConfigDict,
|
17
11
|
Field,
|
18
|
-
HttpUrl,
|
19
|
-
NonNegativeFloat,
|
20
|
-
NonNegativeInt,
|
21
|
-
PositiveInt,
|
22
|
-
SecretStr,
|
23
12
|
)
|
24
13
|
|
25
14
|
|
@@ -56,415 +45,6 @@ class WithBriefing(Named, Described):
|
|
56
45
|
return f"{self.name}: {self.description}" if self.description else self.name
|
57
46
|
|
58
47
|
|
59
|
-
class LLMUsage(Base):
|
60
|
-
"""Class that manages LLM (Large Language Model) usage parameters and methods."""
|
61
|
-
|
62
|
-
llm_api_endpoint: Optional[HttpUrl] = None
|
63
|
-
"""The OpenAI API endpoint."""
|
64
|
-
|
65
|
-
llm_api_key: Optional[SecretStr] = None
|
66
|
-
"""The OpenAI API key."""
|
67
|
-
|
68
|
-
llm_timeout: Optional[PositiveInt] = None
|
69
|
-
"""The timeout of the LLM model."""
|
70
|
-
|
71
|
-
llm_max_retries: Optional[PositiveInt] = None
|
72
|
-
"""The maximum number of retries."""
|
73
|
-
|
74
|
-
llm_model: Optional[str] = None
|
75
|
-
"""The LLM model name."""
|
76
|
-
|
77
|
-
llm_temperature: Optional[NonNegativeFloat] = None
|
78
|
-
"""The temperature of the LLM model."""
|
79
|
-
|
80
|
-
llm_stop_sign: Optional[str | List[str]] = None
|
81
|
-
"""The stop sign of the LLM model."""
|
82
|
-
|
83
|
-
llm_top_p: Optional[NonNegativeFloat] = None
|
84
|
-
"""The top p of the LLM model."""
|
85
|
-
|
86
|
-
llm_generation_count: Optional[PositiveInt] = None
|
87
|
-
"""The number of generations to generate."""
|
88
|
-
|
89
|
-
llm_stream: Optional[bool] = None
|
90
|
-
"""Whether to stream the LLM model's response."""
|
91
|
-
|
92
|
-
llm_max_tokens: Optional[PositiveInt] = None
|
93
|
-
"""The maximum number of tokens to generate."""
|
94
|
-
|
95
|
-
async def aquery(
|
96
|
-
self,
|
97
|
-
messages: List[Dict[str, str]],
|
98
|
-
model: str | None = None,
|
99
|
-
temperature: NonNegativeFloat | None = None,
|
100
|
-
stop: str | List[str] | None = None,
|
101
|
-
top_p: NonNegativeFloat | None = None,
|
102
|
-
max_tokens: PositiveInt | None = None,
|
103
|
-
n: PositiveInt | None = None,
|
104
|
-
stream: bool | None = None,
|
105
|
-
timeout: PositiveInt | None = None,
|
106
|
-
max_retries: PositiveInt | None = None,
|
107
|
-
) -> ModelResponse:
|
108
|
-
"""Asynchronously queries the language model to generate a response based on the provided messages and parameters.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
messages (List[Dict[str, str]]): A list of messages, where each message is a dictionary containing the role and content of the message.
|
112
|
-
model (str | None): The name of the model to use. If not provided, the default model will be used.
|
113
|
-
temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
|
114
|
-
stop (str | None): A sequence at which to stop the generation of the response.
|
115
|
-
top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
|
116
|
-
max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
|
117
|
-
n (PositiveInt | None): The number of responses to generate.
|
118
|
-
stream (bool | None): Whether to receive the response in a streaming fashion.
|
119
|
-
timeout (PositiveInt | None): The timeout duration for the request.
|
120
|
-
max_retries (PositiveInt | None): The maximum number of retries in case of failure.
|
121
|
-
|
122
|
-
Returns:
|
123
|
-
ModelResponse: An object containing the generated response and other metadata from the model.
|
124
|
-
"""
|
125
|
-
# Call the underlying asynchronous completion function with the provided and default parameters
|
126
|
-
return await litellm.acompletion(
|
127
|
-
messages=messages,
|
128
|
-
model=model or self.llm_model or configs.llm.model,
|
129
|
-
temperature=temperature or self.llm_temperature or configs.llm.temperature,
|
130
|
-
stop=stop or self.llm_stop_sign or configs.llm.stop_sign,
|
131
|
-
top_p=top_p or self.llm_top_p or configs.llm.top_p,
|
132
|
-
max_tokens=max_tokens or self.llm_max_tokens or configs.llm.max_tokens,
|
133
|
-
n=n or self.llm_generation_count or configs.llm.generation_count,
|
134
|
-
stream=stream or self.llm_stream or configs.llm.stream,
|
135
|
-
timeout=timeout or self.llm_timeout or configs.llm.timeout,
|
136
|
-
max_retries=max_retries or self.llm_max_retries or configs.llm.max_retries,
|
137
|
-
api_key=self.llm_api_key.get_secret_value() if self.llm_api_key else configs.llm.api_key.get_secret_value(),
|
138
|
-
base_url=self.llm_api_endpoint.unicode_string()
|
139
|
-
if self.llm_api_endpoint
|
140
|
-
else configs.llm.api_endpoint.unicode_string(),
|
141
|
-
)
|
142
|
-
|
143
|
-
async def ainvoke(
|
144
|
-
self,
|
145
|
-
question: str,
|
146
|
-
system_message: str = "",
|
147
|
-
model: str | None = None,
|
148
|
-
temperature: NonNegativeFloat | None = None,
|
149
|
-
stop: str | List[str] | None = None,
|
150
|
-
top_p: NonNegativeFloat | None = None,
|
151
|
-
max_tokens: PositiveInt | None = None,
|
152
|
-
n: PositiveInt | None = None,
|
153
|
-
stream: bool | None = None,
|
154
|
-
timeout: PositiveInt | None = None,
|
155
|
-
max_retries: PositiveInt | None = None,
|
156
|
-
) -> List[Choices | StreamingChoices]:
|
157
|
-
"""Asynchronously invokes the language model with a question and optional system message.
|
158
|
-
|
159
|
-
Args:
|
160
|
-
question (str): The question to ask the model.
|
161
|
-
system_message (str): The system message to provide context to the model.
|
162
|
-
model (str | None): The name of the model to use. If not provided, the default model will be used.
|
163
|
-
temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
|
164
|
-
stop (str | None): A sequence at which to stop the generation of the response.
|
165
|
-
top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
|
166
|
-
max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
|
167
|
-
n (PositiveInt | None): The number of responses to generate.
|
168
|
-
stream (bool | None): Whether to receive the response in a streaming fashion.
|
169
|
-
timeout (PositiveInt | None): The timeout duration for the request.
|
170
|
-
max_retries (PositiveInt | None): The maximum number of retries in case of failure.
|
171
|
-
|
172
|
-
Returns:
|
173
|
-
List[Choices | StreamingChoices]: A list of choices or streaming choices from the model response.
|
174
|
-
"""
|
175
|
-
return (
|
176
|
-
await self.aquery(
|
177
|
-
messages=Messages().add_system_message(system_message).add_user_message(question),
|
178
|
-
model=model,
|
179
|
-
temperature=temperature,
|
180
|
-
stop=stop,
|
181
|
-
top_p=top_p,
|
182
|
-
max_tokens=max_tokens,
|
183
|
-
n=n,
|
184
|
-
stream=stream,
|
185
|
-
timeout=timeout,
|
186
|
-
max_retries=max_retries,
|
187
|
-
)
|
188
|
-
).choices
|
189
|
-
|
190
|
-
async def aask(
|
191
|
-
self,
|
192
|
-
question: str,
|
193
|
-
system_message: str = "",
|
194
|
-
model: str | None = None,
|
195
|
-
temperature: NonNegativeFloat | None = None,
|
196
|
-
stop: str | List[str] | None = None,
|
197
|
-
top_p: NonNegativeFloat | None = None,
|
198
|
-
max_tokens: PositiveInt | None = None,
|
199
|
-
stream: bool | None = None,
|
200
|
-
timeout: PositiveInt | None = None,
|
201
|
-
max_retries: PositiveInt | None = None,
|
202
|
-
) -> str:
|
203
|
-
"""Asynchronously asks the language model a question and returns the response content.
|
204
|
-
|
205
|
-
Args:
|
206
|
-
question (str): The question to ask the model.
|
207
|
-
system_message (str): The system message to provide context to the model.
|
208
|
-
model (str | None): The name of the model to use. If not provided, the default model will be used.
|
209
|
-
temperature (NonNegativeFloat | None): Controls the randomness of the output. Lower values make the output more deterministic.
|
210
|
-
stop (str | None): A sequence at which to stop the generation of the response.
|
211
|
-
top_p (NonNegativeFloat | None): Controls the diversity of the output through nucleus sampling.
|
212
|
-
max_tokens (PositiveInt | None): The maximum number of tokens to generate in the response.
|
213
|
-
stream (bool | None): Whether to receive the response in a streaming fashion.
|
214
|
-
timeout (PositiveInt | None): The timeout duration for the request.
|
215
|
-
max_retries (PositiveInt | None): The maximum number of retries in case of failure.
|
216
|
-
|
217
|
-
Returns:
|
218
|
-
str: The content of the model's response message.
|
219
|
-
"""
|
220
|
-
return (
|
221
|
-
(
|
222
|
-
await self.ainvoke(
|
223
|
-
n=1,
|
224
|
-
question=question,
|
225
|
-
system_message=system_message,
|
226
|
-
model=model,
|
227
|
-
temperature=temperature,
|
228
|
-
stop=stop,
|
229
|
-
top_p=top_p,
|
230
|
-
max_tokens=max_tokens,
|
231
|
-
stream=stream,
|
232
|
-
timeout=timeout,
|
233
|
-
max_retries=max_retries,
|
234
|
-
)
|
235
|
-
)
|
236
|
-
.pop()
|
237
|
-
.message.content
|
238
|
-
)
|
239
|
-
|
240
|
-
async def aask_validate[T](
|
241
|
-
self,
|
242
|
-
question: str,
|
243
|
-
validator: Callable[[str], T | None],
|
244
|
-
max_validations: PositiveInt = 2,
|
245
|
-
system_message: str = "",
|
246
|
-
model: str | None = None,
|
247
|
-
temperature: NonNegativeFloat | None = None,
|
248
|
-
stop: str | List[str] | None = None,
|
249
|
-
top_p: NonNegativeFloat | None = None,
|
250
|
-
max_tokens: PositiveInt | None = None,
|
251
|
-
stream: bool | None = None,
|
252
|
-
timeout: PositiveInt | None = None,
|
253
|
-
max_retries: PositiveInt | None = None,
|
254
|
-
) -> T:
|
255
|
-
"""Asynchronously ask a question and validate the response using a given validator.
|
256
|
-
|
257
|
-
Args:
|
258
|
-
question (str): The question to ask.
|
259
|
-
validator (Callable[[str], T | None]): A function to validate the response.
|
260
|
-
max_validations (PositiveInt): Maximum number of validation attempts.
|
261
|
-
system_message (str): System message to include in the request.
|
262
|
-
model (str | None): The model to use for the request.
|
263
|
-
temperature (NonNegativeFloat | None): Temperature setting for the request.
|
264
|
-
stop (str | None): Stop sequence for the request.
|
265
|
-
top_p (NonNegativeFloat | None): Top-p sampling parameter.
|
266
|
-
max_tokens (PositiveInt | None): Maximum number of tokens in the response.
|
267
|
-
stream (bool | None): Whether to stream the response.
|
268
|
-
timeout (PositiveInt | None): Timeout for the request.
|
269
|
-
max_retries (PositiveInt | None): Maximum number of retries for the request.
|
270
|
-
|
271
|
-
Returns:
|
272
|
-
T: The validated response.
|
273
|
-
|
274
|
-
Raises:
|
275
|
-
ValueError: If the response fails to validate after the maximum number of attempts.
|
276
|
-
"""
|
277
|
-
for _ in range(max_validations):
|
278
|
-
if (
|
279
|
-
response := await self.aask(
|
280
|
-
question=question,
|
281
|
-
system_message=system_message,
|
282
|
-
model=model,
|
283
|
-
temperature=temperature,
|
284
|
-
stop=stop,
|
285
|
-
top_p=top_p,
|
286
|
-
max_tokens=max_tokens,
|
287
|
-
stream=stream,
|
288
|
-
timeout=timeout,
|
289
|
-
max_retries=max_retries,
|
290
|
-
)
|
291
|
-
) and (validated := validator(response)):
|
292
|
-
return validated
|
293
|
-
raise ValueError("Failed to validate the response.")
|
294
|
-
|
295
|
-
async def achoose[T: WithBriefing](
|
296
|
-
self,
|
297
|
-
instruction: str,
|
298
|
-
choices: List[T],
|
299
|
-
k: NonNegativeInt = 0,
|
300
|
-
max_validations: PositiveInt = 2,
|
301
|
-
system_message: str = "",
|
302
|
-
model: str | None = None,
|
303
|
-
temperature: NonNegativeFloat | None = None,
|
304
|
-
stop: str | List[str] | None = None,
|
305
|
-
top_p: NonNegativeFloat | None = None,
|
306
|
-
max_tokens: PositiveInt | None = None,
|
307
|
-
stream: bool | None = None,
|
308
|
-
timeout: PositiveInt | None = None,
|
309
|
-
max_retries: PositiveInt | None = None,
|
310
|
-
) -> List[T]:
|
311
|
-
"""Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.
|
312
|
-
|
313
|
-
Args:
|
314
|
-
instruction: The user-provided instruction/question description.
|
315
|
-
choices: A list of candidate options, requiring elements to have `name` and `briefing` fields.
|
316
|
-
k: The number of choices to select, 0 means infinite.
|
317
|
-
max_validations: Maximum number of validation failures, default is 2.
|
318
|
-
system_message: Custom system-level prompt, defaults to an empty string.
|
319
|
-
model: The name of the LLM model to use.
|
320
|
-
temperature: Sampling temperature to control randomness in generation.
|
321
|
-
stop: Stop condition string or list for generation.
|
322
|
-
top_p: Core sampling probability threshold.
|
323
|
-
max_tokens: Maximum token limit for the generated result.
|
324
|
-
stream: Whether to enable streaming response mode.
|
325
|
-
timeout: Request timeout in seconds.
|
326
|
-
max_retries: Maximum number of retries.
|
327
|
-
|
328
|
-
Returns:
|
329
|
-
List[T]: The final validated selection result list, with element types matching the input `choices`.
|
330
|
-
|
331
|
-
Important:
|
332
|
-
- Uses a template engine to generate structured prompts.
|
333
|
-
- Ensures response compliance through JSON parsing and format validation.
|
334
|
-
- Relies on `aask_validate` to implement retry mechanisms with validation.
|
335
|
-
"""
|
336
|
-
prompt = template_manager.render_template(
|
337
|
-
"make_choice",
|
338
|
-
{
|
339
|
-
"instruction": instruction,
|
340
|
-
"options": [m.model_dump(include={"name", "briefing"}) for m in choices],
|
341
|
-
"k": k,
|
342
|
-
},
|
343
|
-
)
|
344
|
-
names = [c.name for c in choices]
|
345
|
-
|
346
|
-
def _validate(response: str) -> List[T] | None:
|
347
|
-
ret = JsonCapture.convert_with(response, orjson.loads)
|
348
|
-
if not isinstance(ret, List) or len(ret) != k:
|
349
|
-
return None
|
350
|
-
if any(n not in names for n in ret):
|
351
|
-
return None
|
352
|
-
return ret
|
353
|
-
|
354
|
-
return await self.aask_validate(
|
355
|
-
question=prompt,
|
356
|
-
validator=_validate,
|
357
|
-
max_validations=max_validations,
|
358
|
-
system_message=system_message,
|
359
|
-
model=model,
|
360
|
-
temperature=temperature,
|
361
|
-
stop=stop,
|
362
|
-
top_p=top_p,
|
363
|
-
max_tokens=max_tokens,
|
364
|
-
stream=stream,
|
365
|
-
timeout=timeout,
|
366
|
-
max_retries=max_retries,
|
367
|
-
)
|
368
|
-
|
369
|
-
async def ajudge(
|
370
|
-
self,
|
371
|
-
prompt: str,
|
372
|
-
affirm_case: str = "",
|
373
|
-
deny_case: str = "",
|
374
|
-
max_validations: PositiveInt = 2,
|
375
|
-
system_message: str = "",
|
376
|
-
model: str | None = None,
|
377
|
-
temperature: NonNegativeFloat | None = None,
|
378
|
-
stop: str | List[str] | None = None,
|
379
|
-
top_p: NonNegativeFloat | None = None,
|
380
|
-
max_tokens: PositiveInt | None = None,
|
381
|
-
stream: bool | None = None,
|
382
|
-
timeout: PositiveInt | None = None,
|
383
|
-
max_retries: PositiveInt | None = None,
|
384
|
-
) -> bool:
|
385
|
-
"""Asynchronously judges a prompt using AI validation.
|
386
|
-
|
387
|
-
Args:
|
388
|
-
prompt (str): The input prompt to be judged.
|
389
|
-
affirm_case (str, optional): The affirmative case for the AI model. Defaults to "".
|
390
|
-
deny_case (str, optional): The negative case for the AI model. Defaults to "".
|
391
|
-
max_validations (PositiveInt, optional): Maximum number of validation attempts. Defaults to 2.
|
392
|
-
system_message (str, optional): System message for the AI model. Defaults to "".
|
393
|
-
model (str | None, optional): AI model to use. Defaults to None.
|
394
|
-
temperature (NonNegativeFloat | None, optional): Sampling temperature. Defaults to None.
|
395
|
-
stop (str | List[str] | None, optional): Stop sequences. Defaults to None.
|
396
|
-
top_p (NonNegativeFloat | None, optional): Nucleus sampling parameter. Defaults to None.
|
397
|
-
max_tokens (PositiveInt | None, optional): Maximum number of tokens to generate. Defaults to None.
|
398
|
-
stream (bool | None, optional): Whether to stream the response. Defaults to None.
|
399
|
-
timeout (PositiveInt | None, optional): Timeout in seconds. Defaults to None.
|
400
|
-
max_retries (PositiveInt | None, optional): Maximum number of retries. Defaults to None.
|
401
|
-
|
402
|
-
Returns:
|
403
|
-
bool: The judgment result (True or False) based on the AI's response.
|
404
|
-
|
405
|
-
Notes:
|
406
|
-
The method uses an internal validator to ensure the response is a boolean value.
|
407
|
-
If the response cannot be converted to a boolean, it will return None.
|
408
|
-
"""
|
409
|
-
|
410
|
-
def _validate(response: str) -> bool | None:
|
411
|
-
ret = JsonCapture.convert_with(response, orjson.loads)
|
412
|
-
if not isinstance(ret, bool):
|
413
|
-
return None
|
414
|
-
return ret
|
415
|
-
|
416
|
-
return await self.aask_validate(
|
417
|
-
question=template_manager.render_template(
|
418
|
-
"make_judgment", {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case}
|
419
|
-
),
|
420
|
-
validator=_validate,
|
421
|
-
max_validations=max_validations,
|
422
|
-
system_message=system_message,
|
423
|
-
model=model,
|
424
|
-
temperature=temperature,
|
425
|
-
stop=stop,
|
426
|
-
top_p=top_p,
|
427
|
-
max_tokens=max_tokens,
|
428
|
-
stream=stream,
|
429
|
-
timeout=timeout,
|
430
|
-
max_retries=max_retries,
|
431
|
-
)
|
432
|
-
|
433
|
-
def fallback_to(self, other: "LLMUsage") -> Self:
|
434
|
-
"""Fallback to another instance's attribute values if the current instance's attributes are None.
|
435
|
-
|
436
|
-
Args:
|
437
|
-
other (LLMUsage): Another instance from which to copy attribute values.
|
438
|
-
|
439
|
-
Returns:
|
440
|
-
Self: The current instance, allowing for method chaining.
|
441
|
-
"""
|
442
|
-
# Iterate over the attribute names and copy values from 'other' to 'self' where applicable
|
443
|
-
# noinspection PydanticTypeChecker,PyTypeChecker
|
444
|
-
for attr_name in LLMUsage.model_fields:
|
445
|
-
# Copy the attribute value from 'other' to 'self' only if 'self' has None and 'other' has a non-None value
|
446
|
-
if getattr(self, attr_name) is None and (attr := getattr(other, attr_name)) is not None:
|
447
|
-
setattr(self, attr_name, attr)
|
448
|
-
|
449
|
-
# Return the current instance to allow for method chaining
|
450
|
-
return self
|
451
|
-
|
452
|
-
def hold_to(self, others: Union["LLMUsage", Iterable["LLMUsage"]]) -> Self:
|
453
|
-
"""Hold to another instance's attribute values if the current instance's attributes are None.
|
454
|
-
|
455
|
-
Args:
|
456
|
-
others (LLMUsage | Iterable[LLMUsage]): Another instance or iterable of instances from which to copy attribute values.
|
457
|
-
|
458
|
-
Returns:
|
459
|
-
Self: The current instance, allowing for method chaining.
|
460
|
-
"""
|
461
|
-
for other in others:
|
462
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
463
|
-
for attr_name in LLMUsage.model_fields:
|
464
|
-
if (attr := getattr(self, attr_name)) is not None and getattr(other, attr_name) is None:
|
465
|
-
setattr(other, attr_name, attr)
|
466
|
-
|
467
|
-
|
468
48
|
class WithJsonExample(Base):
|
469
49
|
"""Class that provides a JSON schema for the model."""
|
470
50
|
|
@@ -0,0 +1,26 @@
|
|
1
|
+
"""This module contains the types for the keyword arguments of the methods in the models module."""
|
2
|
+
|
3
|
+
from typing import List, NotRequired, TypedDict
|
4
|
+
|
5
|
+
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
|
6
|
+
|
7
|
+
|
8
|
+
class LLMKwargs(TypedDict):
    """A type representing the keyword arguments for the LLM (Large Language Model) usage."""

    # Every key is NotRequired, so callers may pass any subset of these options.
    # The key set matches the tuning parameters of the earlier `LLMUsage.aask`
    # (everything except `question` / `system_message`); presumably it is meant to be
    # splatted as **kwargs into the LLM usage methods — TODO confirm against models/usages.py.
    model: NotRequired[str]  # Name of the LLM model to use.
    temperature: NotRequired[NonNegativeFloat]  # Sampling temperature; lower is more deterministic.
    stop: NotRequired[str | List[str]]  # Stop sequence(s) at which generation ends.
    top_p: NotRequired[NonNegativeFloat]  # Nucleus-sampling probability threshold.
    max_tokens: NotRequired[PositiveInt]  # Maximum number of tokens to generate.
    stream: NotRequired[bool]  # Whether to request a streaming response.
    timeout: NotRequired[PositiveInt]  # Request timeout (unit assumed seconds — confirm with the backend).
    max_retries: NotRequired[PositiveInt]  # Maximum number of retries on request failure.
|
19
|
+
|
20
|
+
|
21
|
+
class ChooseKwargs(LLMKwargs):
    """A type representing the keyword arguments for the choose method."""

    # Extends LLMKwargs with the options specific to validated multi-choice selection;
    # all keys remain optional (NotRequired).
    max_validations: NotRequired[PositiveInt]  # Maximum number of ask-and-validate attempts.
    system_message: NotRequired[str]  # System prompt sent alongside the question.
    # Number of options to select. The earlier `achoose` docstring said 0 means "no limit";
    # NOTE(review): its validator compared len(result) != k, so verify how 0 is handled now.
    k: NotRequired[NonNegativeInt]
|
fabricatio/models/role.py
CHANGED
@@ -5,9 +5,10 @@ from typing import Any, Set
|
|
5
5
|
from fabricatio.core import env
|
6
6
|
from fabricatio.journal import logger
|
7
7
|
from fabricatio.models.action import WorkFlow
|
8
|
+
from fabricatio.models.advanced import ProposeTask
|
8
9
|
from fabricatio.models.events import Event
|
9
|
-
from fabricatio.models.
|
10
|
-
from fabricatio.models.
|
10
|
+
from fabricatio.models.tool import ToolBox
|
11
|
+
from fabricatio.models.usages import ToolBoxUsage
|
11
12
|
from fabricatio.toolboxes import basic_toolboxes
|
12
13
|
from pydantic import Field
|
13
14
|
|
fabricatio/models/task.py
CHANGED
@@ -7,13 +7,11 @@ from asyncio import Queue
|
|
7
7
|
from enum import Enum
|
8
8
|
from typing import Any, List, Optional, Self
|
9
9
|
|
10
|
-
from fabricatio._rust_instances import template_manager
|
11
10
|
from fabricatio.core import env
|
12
11
|
from fabricatio.journal import logger
|
13
12
|
from fabricatio.models.events import Event, EventLike
|
14
|
-
from fabricatio.models.generic import
|
15
|
-
from
|
16
|
-
from pydantic import Field, PrivateAttr, ValidationError
|
13
|
+
from fabricatio.models.generic import WithBriefing, WithDependency, WithJsonExample
|
14
|
+
from pydantic import Field, PrivateAttr
|
17
15
|
|
18
16
|
|
19
17
|
class TaskStatus(Enum):
|
@@ -255,28 +253,3 @@ class Task[T](WithBriefing, WithJsonExample, WithDependency):
|
|
255
253
|
str: The briefing of the task.
|
256
254
|
"""
|
257
255
|
return f"{super().briefing}\n{self.goal}"
|
258
|
-
|
259
|
-
|
260
|
-
class ProposeTask(LLMUsage, WithBriefing):
|
261
|
-
"""A class that proposes a task based on a prompt."""
|
262
|
-
|
263
|
-
async def propose(self, prompt: str) -> Task:
|
264
|
-
"""Propose a task based on the provided prompt."""
|
265
|
-
assert prompt, "Prompt must be provided."
|
266
|
-
|
267
|
-
def _validate_json(response: str) -> None | Task:
|
268
|
-
try:
|
269
|
-
cap = JsonCapture.capture(response)
|
270
|
-
logger.debug(f"Response: \n{response}")
|
271
|
-
logger.info(f"Captured JSON: \n{cap}")
|
272
|
-
return Task.model_validate_json(cap)
|
273
|
-
except ValidationError as e:
|
274
|
-
logger.error(f"Failed to parse task from JSON: {e}")
|
275
|
-
return None
|
276
|
-
|
277
|
-
template_data = {"prompt": prompt, "json_example": Task.json_example()}
|
278
|
-
return await self.aask_validate(
|
279
|
-
template_manager.render_template("propose_task", template_data),
|
280
|
-
_validate_json,
|
281
|
-
system_message=f"# your personal briefing: \n{self.briefing}",
|
282
|
-
)
|
fabricatio/models/tool.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1
1
|
"""A module for defining tools and toolboxes."""
|
2
2
|
|
3
|
+
from importlib.machinery import ModuleSpec
|
4
|
+
from importlib.util import module_from_spec
|
3
5
|
from inspect import iscoroutinefunction, signature
|
4
|
-
from
|
6
|
+
from sys import modules
|
7
|
+
from types import CodeType, ModuleType
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, Self, overload
|
5
9
|
|
10
|
+
from fabricatio.config import configs
|
6
11
|
from fabricatio.journal import logger
|
7
|
-
from fabricatio.models.generic import
|
8
|
-
from pydantic import Field
|
12
|
+
from fabricatio.models.generic import WithBriefing
|
13
|
+
from pydantic import BaseModel, ConfigDict, Field
|
9
14
|
|
10
15
|
|
11
16
|
class Tool[**P, R](WithBriefing):
|
@@ -23,7 +28,9 @@ class Tool[**P, R](WithBriefing):
|
|
23
28
|
def model_post_init(self, __context: Any) -> None:
|
24
29
|
"""Initialize the tool with a name and a source function."""
|
25
30
|
self.name = self.name or self.source.__name__
|
26
|
-
|
31
|
+
|
32
|
+
if not self.name:
|
33
|
+
raise RuntimeError("The tool must have a source function.")
|
27
34
|
self.description = self.description or self.source.__doc__ or ""
|
28
35
|
self.description = self.description.strip()
|
29
36
|
|
@@ -53,7 +60,7 @@ def _desc_wrapper(desc: str) -> str:
|
|
53
60
|
class ToolBox(WithBriefing):
|
54
61
|
"""A class representing a collection of tools."""
|
55
62
|
|
56
|
-
tools: List[Tool] = Field(default_factory=list)
|
63
|
+
tools: List[Tool] = Field(default_factory=list, frozen=True)
|
57
64
|
"""A list of tools in the toolbox."""
|
58
65
|
|
59
66
|
def collect_tool[**P, R](self, func: Callable[P, R]) -> Callable[P, R]:
|
@@ -101,10 +108,14 @@ class ToolBox(WithBriefing):
|
|
101
108
|
Tool: The tool instance with the specified name.
|
102
109
|
|
103
110
|
Raises:
|
104
|
-
|
111
|
+
ValueError: If no tool with the specified name is found.
|
105
112
|
"""
|
106
113
|
tool = next((tool for tool in self.tools if tool.name == name), None)
|
107
|
-
|
114
|
+
if tool is None:
|
115
|
+
err = f"No tool with the name {name} found in the toolbox."
|
116
|
+
logger.error(err)
|
117
|
+
raise ValueError(err)
|
118
|
+
|
108
119
|
return tool
|
109
120
|
|
110
121
|
def __hash__(self) -> int:
|
@@ -112,45 +123,50 @@ class ToolBox(WithBriefing):
|
|
112
123
|
return hash(self.briefing)
|
113
124
|
|
114
125
|
|
115
|
-
class
|
116
|
-
"""A class representing
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
def
|
123
|
-
"""
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
def
|
143
|
-
"""
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
126
|
+
class ToolExecutor(BaseModel):
|
127
|
+
"""A class representing a tool executor with a sequence of tools to execute."""
|
128
|
+
|
129
|
+
model_config = ConfigDict(use_attribute_docstrings=True)
|
130
|
+
execute_sequence: List[Tool] = Field(default_factory=list, frozen=True)
|
131
|
+
"""The sequence of tools to execute."""
|
132
|
+
|
133
|
+
def inject_tools[M: ModuleType](self, module: Optional[M] = None) -> M:
|
134
|
+
"""Inject the tools into the provided module."""
|
135
|
+
module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
|
136
|
+
for tool in self.execute_sequence:
|
137
|
+
setattr(module, tool.name, tool.invoke)
|
138
|
+
return module
|
139
|
+
|
140
|
+
def execute[C: Dict[str, Any]](self, source: CodeType, cxt: Optional[C] = None) -> C:
|
141
|
+
"""Execute the sequence of tools with the provided context."""
|
142
|
+
modules[configs.toolbox.tool_module_name] = self.inject_tools()
|
143
|
+
exec(source, cxt) # noqa: S102
|
144
|
+
modules.pop(configs.toolbox.tool_module_name)
|
145
|
+
return cxt
|
146
|
+
|
147
|
+
@overload
|
148
|
+
def take[C: Dict[str, Any]](self, keys: List[str], source: CodeType, cxt: Optional[C] = None) -> C:
|
149
|
+
"""Check the output of the tools with the provided context."""
|
150
|
+
...
|
151
|
+
|
152
|
+
@overload
|
153
|
+
def take[C: Dict[str, Any]](self, keys: str, source: CodeType, cxt: Optional[C] = None) -> Any:
|
154
|
+
"""Check the output of the tools with the provided context."""
|
155
|
+
...
|
156
|
+
|
157
|
+
def take[C: Dict[str, Any]](self, keys: List[str] | str, source: CodeType, cxt: Optional[C] = None) -> C | Any:
|
158
|
+
"""Check the output of the tools with the provided context."""
|
159
|
+
cxt = self.execute(source, cxt)
|
160
|
+
if isinstance(keys, str):
|
161
|
+
return cxt[keys]
|
162
|
+
return {key: cxt[key] for key in keys}
|
163
|
+
|
164
|
+
@classmethod
|
165
|
+
def from_recipe(cls, recipe: List[str], toolboxes: List[ToolBox]) -> Self:
|
166
|
+
"""Create a tool executor from a recipe and a list of toolboxes."""
|
167
|
+
tools = []
|
168
|
+
while tool_name := recipe.pop(0):
|
169
|
+
for toolbox in toolboxes:
|
170
|
+
tools.append(toolbox[tool_name])
|
171
|
+
|
172
|
+
return cls(execute_sequence=tools)
|