fabricatio-0.2.5.dev5-cp312-cp312-win_amd64.whl → fabricatio-0.2.6.dev1-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +6 -6
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/_rust_instances.py +1 -1
- fabricatio/actions/article.py +4 -4
- fabricatio/actions/output.py +1 -3
- fabricatio/actions/rag.py +2 -2
- fabricatio/capabilities/correct.py +103 -0
- fabricatio/capabilities/rag.py +20 -17
- fabricatio/capabilities/rating.py +8 -8
- fabricatio/capabilities/review.py +53 -16
- fabricatio/capabilities/task.py +3 -3
- fabricatio/config.py +37 -6
- fabricatio/fs/__init__.py +2 -2
- fabricatio/fs/readers.py +1 -1
- fabricatio/journal.py +0 -7
- fabricatio/models/action.py +4 -4
- fabricatio/models/generic.py +40 -10
- fabricatio/models/kwargs_types.py +2 -2
- fabricatio/models/role.py +2 -2
- fabricatio/models/task.py +2 -2
- fabricatio/models/usages.py +70 -23
- fabricatio/models/utils.py +21 -0
- fabricatio/parser.py +13 -5
- {fabricatio-0.2.5.dev5.data → fabricatio-0.2.6.dev1.data}/scripts/tdown.exe +0 -0
- {fabricatio-0.2.5.dev5.dist-info → fabricatio-0.2.6.dev1.dist-info}/METADATA +3 -3
- fabricatio-0.2.6.dev1.dist-info/RECORD +42 -0
- fabricatio-0.2.5.dev5.dist-info/RECORD +0 -41
- {fabricatio-0.2.5.dev5.dist-info → fabricatio-0.2.6.dev1.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.5.dev5.dist-info → fabricatio-0.2.6.dev1.dist-info}/licenses/LICENSE +0 -0
fabricatio/config.py
CHANGED
@@ -19,8 +19,6 @@ from pydantic import (
 )
 from pydantic_settings import (
     BaseSettings,
-    DotEnvSettingsSource,
-    EnvSettingsSource,
     PydanticBaseSettingsSource,
     PyprojectTomlConfigSettingsSource,
     SettingsConfigDict,
@@ -68,7 +66,7 @@ class LLMConfig(BaseModel):
     temperature: NonNegativeFloat = Field(default=1.0)
     """The temperature of the LLM model. Controls randomness in generation. Set to 1.0 as per request."""
 
-    stop_sign: str | List[str] = Field(default_factory=lambda
+    stop_sign: str | List[str] = Field(default_factory=lambda: ["\n\n\n", "User:"])
     """The stop sign of the LLM model. No default stop sign specified."""
 
     top_p: NonNegativeFloat = Field(default=0.35)
@@ -83,6 +81,12 @@ class LLMConfig(BaseModel):
     max_tokens: PositiveInt = Field(default=8192)
     """The maximum number of tokens to generate. Set to 8192 as per request."""
 
+    rpm: Optional[PositiveInt] = Field(default=100)
+    """The rate limit of the LLM model in requests per minute. None means not checked."""
+
+    tpm: Optional[PositiveInt] = Field(default=1000000)
+    """The rate limit of the LLM model in tokens per minute. None means not checked."""
+
 
 class EmbeddingConfig(BaseModel):
     """Embedding configuration class."""
@@ -222,6 +226,12 @@ class TemplateConfig(BaseModel):
     review_string_template: str = Field(default="review_string")
     """The name of the review string template which will be used to review a string."""
 
+    generic_string_template: str = Field(default="generic_string")
+    """The name of the generic string template which will be used to review a string."""
+
+    correct_template: str = Field(default="correct")
+    """The name of the correct template which will be used to correct a string."""
+
 
 class MagikaConfig(BaseModel):
     """Magika configuration class."""
@@ -285,6 +295,19 @@ class CacheConfig(BaseModel):
     """Whether to enable cache."""
 
 
+class RoutingConfig(BaseModel):
+    """Routing configuration class."""
+
+    model_config = ConfigDict(use_attribute_docstrings=True)
+
+    allowed_fails: Optional[int] = 1
+    """The number of allowed fails before the routing is considered failed."""
+    retry_after: int = 15
+    """The time in seconds to wait before retrying the routing after a fail."""
+    cooldown_time: Optional[int] = 120
+    """The time in seconds to wait before retrying the routing after a cooldown."""
+
+
 class Settings(BaseSettings):
     """Application settings class.
 
@@ -310,6 +333,9 @@ class Settings(BaseSettings):
     llm: LLMConfig = Field(default_factory=LLMConfig)
     """LLM Configuration"""
 
+    routing: RoutingConfig = Field(default_factory=RoutingConfig)
+    """Routing Configuration"""
+
     embedding: EmbeddingConfig = Field(default_factory=EmbeddingConfig)
     """Embedding Configuration"""
 
@@ -348,6 +374,9 @@ class Settings(BaseSettings):
     ) -> tuple[PydanticBaseSettingsSource, ...]:
         """Customize settings sources.
 
+        This method customizes the settings sources used by the application. It returns a tuple of settings sources, including
+        the dotenv settings source, environment settings source, a custom TomlConfigSettingsSource, and a custom
+
         Args:
             settings_cls (type[BaseSettings]): The settings class.
             init_settings (PydanticBaseSettingsSource): Initial settings source.
@@ -359,10 +388,12 @@ class Settings(BaseSettings):
             tuple[PydanticBaseSettingsSource, ...]: A tuple of settings sources.
         """
         return (
-
-
-
+            init_settings,
+            dotenv_settings,
+            env_settings,
+            file_secret_settings,
             PyprojectTomlConfigSettingsSource(settings_cls),
+            TomlConfigSettingsSource(settings_cls),
         )
 
 
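For readers adjusting their own settings, here is a minimal sketch (not taken from the package) that mirrors the `RoutingConfig` fields this release introduces; the real class lives in `fabricatio/config.py` and is exposed on `Settings` as the new `routing` section.

```python
# Minimal sketch mirroring the RoutingConfig fields added in this release.
# Field names and defaults come from the diff above; everything else is illustrative.
from typing import Optional

from pydantic import BaseModel, ConfigDict


class RoutingConfigSketch(BaseModel):
    """Stand-in for fabricatio.config.RoutingConfig."""

    model_config = ConfigDict(use_attribute_docstrings=True)

    allowed_fails: Optional[int] = 1
    """Fails tolerated before a deployment is considered failed."""
    retry_after: int = 15
    """Seconds to wait before retrying after a fail."""
    cooldown_time: Optional[int] = 120
    """Seconds to wait before retrying after a cooldown."""


print(RoutingConfigSketch().model_dump())
# {'allowed_fails': 1, 'retry_after': 15, 'cooldown_time': 120}
```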
fabricatio/fs/__init__.py
CHANGED
@@ -11,9 +11,10 @@ from fabricatio.fs.curd import (
     move_file,
     tree,
 )
-from fabricatio.fs.readers import
+from fabricatio.fs.readers import MAGIKA, safe_json_read, safe_text_read
 
 __all__ = [
+    "MAGIKA",
     "absolute_path",
     "copy_file",
     "create_directory",
@@ -21,7 +22,6 @@ __all__ = [
     "delete_file",
     "dump_text",
     "gather_files",
-    "magika",
     "move_file",
     "safe_json_read",
     "safe_text_read",
fabricatio/fs/readers.py
CHANGED
@@ -9,7 +9,7 @@ from magika import Magika
 from fabricatio.config import configs
 from fabricatio.journal import logger
 
-
+MAGIKA = Magika(model_dir=configs.magika.model_dir)
 
 
 def safe_text_read(path: Path | str) -> str:
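The module-level Magika instance is renamed to `MAGIKA` and re-exported from `fabricatio.fs`. A hedged usage sketch, based on how `models/generic.py` calls it later in this diff (the file path below is just a placeholder):

```python
# Hedged sketch: identify a file's content type via the renamed MAGIKA instance.
# The call pattern mirrors MAGIKA.identify_path(...).output.description used in models/generic.py.
from pathlib import Path

from fabricatio.fs import MAGIKA  # re-exported by the updated fs/__init__.py

path = Path("pyproject.toml")  # any existing file; placeholder here
if path.exists():
    result = MAGIKA.identify_path(path)
    print(result.output.description)  # human-readable description of the detected content type
```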
fabricatio/journal.py
CHANGED
@@ -19,10 +19,3 @@ logger.add(
 logger.add(sys.stderr, level=configs.debug.log_level)
 
 __all__ = ["logger"]
-if __name__ == "__main__":
-    logger.debug("This is a trace message.")
-    logger.info("This is an information message.")
-    logger.success("This is a success message.")
-    logger.warning("This is a warning message.")
-    logger.error("This is an error message.")
-    logger.critical("This is a critical message.")
fabricatio/models/action.py
CHANGED
@@ -3,9 +3,9 @@
 import traceback
 from abc import abstractmethod
 from asyncio import Queue, create_task
-from typing import Any, Dict, Self, Tuple, Type, Union,
+from typing import Any, Dict, Self, Tuple, Type, Union, final
 
-from fabricatio.capabilities.
+from fabricatio.capabilities.correct import Correct
 from fabricatio.capabilities.task import HandleTask, ProposeTask
 from fabricatio.journal import logger
 from fabricatio.models.generic import WithBriefing
@@ -14,7 +14,7 @@ from fabricatio.models.usages import ToolBoxUsage
 from pydantic import Field, PrivateAttr
 
 
-class Action(HandleTask, ProposeTask,
+class Action(HandleTask, ProposeTask, Correct):
     """Class that represents an action to be executed in a workflow."""
 
     name: str = Field(default="")
@@ -37,7 +37,7 @@ class Action(HandleTask, ProposeTask, Review):
         self.description = self.description or self.__class__.__doc__ or ""
 
     @abstractmethod
-    async def _execute(self, **cxt
+    async def _execute(self, **cxt) -> Any:
         """Execute the action with the provided arguments.
 
         Args:
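`Action` now mixes in the new `Correct` capability instead of `Review`, and the abstract `_execute` gains an explicit `-> Any` return annotation. A hedged sketch of a concrete action under the new signature (the `text` context key is illustrative, not from the diff):

```python
# Hedged sketch of a concrete Action; only `name`, `description`, and `_execute`
# are taken from the diff, the rest is illustrative.
from typing import Any

from fabricatio.models.action import Action


class TruncateText(Action):
    """Truncate the text found under the `text` context key."""

    async def _execute(self, text: str = "", **cxt) -> Any:
        # A real action could also call the Correct capability that Action now inherits.
        return text[:200]
```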
fabricatio/models/generic.py
CHANGED
@@ -2,13 +2,13 @@
 
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Union, final
+from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Union, final, overload
 
 import orjson
 from fabricatio._rust import blake3_hash
-from fabricatio._rust_instances import
+from fabricatio._rust_instances import TEMPLATE_MANAGER
 from fabricatio.config import configs
-from fabricatio.fs.readers import
+from fabricatio.fs.readers import MAGIKA, safe_text_read
 from fabricatio.journal import logger
 from fabricatio.parser import JsonCapture
 from pydantic import (
@@ -40,6 +40,14 @@ class Display(Base):
         """
         return self.model_dump_json(indent=1)
 
+    def compact(self) -> str:
+        """Display the model in a compact JSON string.
+
+        Returns:
+            str: The compact JSON string of the model.
+        """
+        return self.model_dump_json()
+
 
 class Named(Base):
     """Class that includes a name attribute."""
@@ -100,7 +108,15 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
     """Class that provides a prompt for creating a JSON object."""
 
     @classmethod
-
+    @overload
+    def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
+
+    @classmethod
+    @overload
+    def create_json_prompt(cls, requirement: str) -> str: ...
+
+    @classmethod
+    def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
         """Create the prompt for creating a JSON object with given requirement.
 
         Args:
@@ -109,10 +125,18 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
         Returns:
             str: The prompt for creating a JSON object with given requirement.
         """
-
-
-
-
+        if isinstance(requirement, str):
+            return TEMPLATE_MANAGER.render_template(
+                configs.templates.create_json_obj_template,
+                {"requirement": requirement, "json_schema": cls.formated_json_schema()},
+            )
+        return [
+            TEMPLATE_MANAGER.render_template(
+                configs.templates.create_json_obj_template,
+                {"requirement": r, "json_schema": cls.formated_json_schema()},
+            )
+            for r in requirement
+        ]
 
 
 class InstantiateFromString(Base):
@@ -231,13 +255,13 @@ class WithDependency(Base):
         Returns:
             str: The generated prompt for the task.
         """
-        return
+        return TEMPLATE_MANAGER.render_template(
             configs.templates.dependencies_template,
             {
                 (pth := Path(p)).name: {
                     "path": pth.as_posix(),
                     "exists": pth.exists(),
-                    "description": (identity :=
+                    "description": (identity := MAGIKA.identify_path(pth)).output.description,
                     "size": f"{pth.stat().st_size / (1024 * 1024) if pth.exists() and pth.is_file() else 0:.3f} MB",
                     "content": (text := safe_text_read(pth)),
                     "lines": len(text.splitlines()),
@@ -307,6 +331,12 @@ class ScopedConfig(Base):
     llm_max_tokens: Optional[PositiveInt] = None
    """The maximum number of tokens to generate."""
 
+    llm_tpm: Optional[PositiveInt] = None
+    """The tokens per minute of the LLM model."""
+
+    llm_rpm: Optional[PositiveInt] = None
+    """The requests per minute of the LLM model."""
+
     embedding_api_endpoint: Optional[HttpUrl] = None
     """The OpenAI API endpoint."""
 
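`CreateJsonObjPrompt.create_json_prompt` is now overloaded so a single requirement yields a single prompt and a list yields a list. A standalone sketch of the same dispatch pattern (the function name and prompt text here are illustrative, not fabricatio APIs):

```python
# Standalone illustration of the @overload pattern used by create_json_prompt:
# one implementation, typed to return str for str input and List[str] for list input.
from typing import List, overload


@overload
def render_prompt(requirement: List[str]) -> List[str]: ...
@overload
def render_prompt(requirement: str) -> str: ...
def render_prompt(requirement: str | List[str]) -> str | List[str]:
    if isinstance(requirement, str):
        return f"Create a JSON object that satisfies: {requirement}"
    return [f"Create a JSON object that satisfies: {r}" for r in requirement]


print(render_prompt("a user profile"))
print(render_prompt(["a user profile", "an order record"]))
```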
fabricatio/models/kwargs_types.py
CHANGED
@@ -12,7 +12,7 @@ class CollectionSimpleConfigKwargs(TypedDict, total=False):
     These arguments are typically used when configuring connections to vector databases.
     """
 
-    dimension: int
+    dimension: int | None
     timeout: float
 
 
@@ -23,7 +23,7 @@ class FetchKwargs(TypedDict, total=False):
     and result limiting parameters.
     """
 
-    collection_name: str
+    collection_name: str | None
     similarity_threshold: float
     result_per_query: int
 
fabricatio/models/role.py
CHANGED
@@ -2,7 +2,7 @@
 
 from typing import Any, Self, Set
 
-from fabricatio.capabilities.
+from fabricatio.capabilities.correct import Correct
 from fabricatio.capabilities.task import HandleTask, ProposeTask
 from fabricatio.core import env
 from fabricatio.journal import logger
@@ -12,7 +12,7 @@ from fabricatio.models.tool import ToolBox
 from pydantic import Field
 
 
-class Role(ProposeTask, HandleTask,
+class Role(ProposeTask, HandleTask, Correct):
     """Class that represents a role with a registry of events and workflows."""
 
     registry: dict[Event | str, WorkFlow] = Field(default_factory=dict)
fabricatio/models/task.py
CHANGED
@@ -6,7 +6,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
 from asyncio import Queue
 from typing import Any, List, Optional, Self
 
-from fabricatio._rust_instances import
+from fabricatio._rust_instances import TEMPLATE_MANAGER
 from fabricatio.config import configs
 from fabricatio.core import env
 from fabricatio.journal import logger
@@ -253,7 +253,7 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
         Returns:
             str: The briefing of the task.
         """
-        return
+        return TEMPLATE_MANAGER.render_template(
             configs.templates.task_briefing_template,
             self.model_dump(),
         )
fabricatio/models/usages.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set
 
 import asyncstdlib
 import litellm
-from fabricatio._rust_instances import
+from fabricatio._rust_instances import TEMPLATE_MANAGER
 from fabricatio.config import configs
 from fabricatio.journal import logger
 from fabricatio.models.generic import ScopedConfig, WithBriefing
@@ -13,8 +13,9 @@ from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, Genera
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
 from fabricatio.models.utils import Messages
-from fabricatio.parser import JsonCapture
-from litellm import stream_chunk_builder
+from fabricatio.parser import GenericCapture, JsonCapture
+from litellm import Router, stream_chunk_builder
+from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
 from litellm.types.utils import (
     Choices,
     EmbeddingResponse,
@@ -22,7 +23,7 @@ from litellm.types.utils import (
     StreamingChoices,
     TextChoices,
 )
-from litellm.utils import CustomStreamWrapper
+from litellm.utils import CustomStreamWrapper  # pyright: ignore [reportPrivateImportUsage]
 from more_itertools import duplicates_everseen
 from pydantic import Field, NonNegativeInt, PositiveInt
 
@@ -30,20 +31,33 @@ if configs.cache.enabled and configs.cache.type:
     litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
     logger.success(f"{configs.cache.type.name} Cache enabled")
 
+ROUTER = Router(
+    routing_strategy="usage-based-routing-v2",
+    allowed_fails=configs.routing.allowed_fails,
+    retry_after=configs.routing.retry_after,
+    cooldown_time=configs.routing.cooldown_time,
+)
+
 
 class LLMUsage(ScopedConfig):
     """Class that manages LLM (Large Language Model) usage parameters and methods."""
 
+    def _deploy(self, deployment: Deployment) -> Router:
+        """Add a deployment to the router."""
+        self._added_deployment = ROUTER.upsert_deployment(deployment)
+        return ROUTER
+
     @classmethod
     def _scoped_model(cls) -> Type["LLMUsage"]:
         return LLMUsage
 
+    # noinspection PyTypeChecker,PydanticTypeChecker
     async def aquery(
         self,
         messages: List[Dict[str, str]],
         n: PositiveInt | None = None,
         **kwargs: Unpack[LLMKwargs],
-    ) -> ModelResponse:
+    ) -> ModelResponse | CustomStreamWrapper:
         """Asynchronously queries the language model to generate a response based on the provided messages and parameters.
 
         Args:
@@ -55,19 +69,33 @@ class LLMUsage(ScopedConfig):
             ModelResponse | CustomStreamWrapper: An object containing the generated response and other metadata from the model.
         """
         # Call the underlying asynchronous completion function with the provided and default parameters
-
+        # noinspection PyTypeChecker,PydanticTypeChecker
+
+        return await self._deploy(
+            Deployment(
+                model_name=(m_name := kwargs.get("model") or self.llm_model or configs.llm.model),
+                litellm_params=(
+                    p := LiteLLM_Params(
+                        api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
+                        api_base=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
+                        model=m_name,
+                        tpm=self.llm_tpm or configs.llm.tpm,
+                        rpm=self.llm_rpm or configs.llm.rpm,
+                        max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
+                        timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
+                    )
+                ),
+                model_info=ModelInfo(id=hash(m_name + p.model_dump_json(exclude_none=True))),
+            )
+        ).acompletion(
             messages=messages,
             n=n or self.llm_generation_count or configs.llm.generation_count,
-            model=
+            model=m_name,
             temperature=kwargs.get("temperature") or self.llm_temperature or configs.llm.temperature,
             stop=kwargs.get("stop") or self.llm_stop_sign or configs.llm.stop_sign,
             top_p=kwargs.get("top_p") or self.llm_top_p or configs.llm.top_p,
             max_tokens=kwargs.get("max_tokens") or self.llm_max_tokens or configs.llm.max_tokens,
             stream=kwargs.get("stream") or self.llm_stream or configs.llm.stream,
-            timeout=kwargs.get("timeout") or self.llm_timeout or configs.llm.timeout,
-            max_retries=kwargs.get("max_retries") or self.llm_max_retries or configs.llm.max_retries,
-            api_key=(self.llm_api_key or configs.llm.api_key).get_secret_value(),
-            base_url=(self.llm_api_endpoint or configs.llm.api_endpoint).unicode_string(),
             cache={
                 "no-cache": kwargs.get("no_cache"),
                 "no-store": kwargs.get("no_store"),
@@ -192,31 +220,31 @@ class LLMUsage(ScopedConfig):
     @overload
     async def aask_validate[T](
         self,
-        question: str,
+        question: List[str],
         validator: Callable[[str], T | None],
-        default:
+        default: T,
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
-    ) ->
-
+    ) -> List[T]: ...
     @overload
     async def aask_validate[T](
         self,
-        question:
+        question: str,
         validator: Callable[[str], T | None],
         default: None = None,
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
-    ) ->
+    ) -> Optional[T]: ...
+
     @overload
     async def aask_validate[T](
         self,
         question: List[str],
         validator: Callable[[str], T | None],
-        default:
+        default: None = None,
         max_validations: PositiveInt = 2,
         **kwargs: Unpack[GenerateKwargs],
-    ) -> List[T]: ...
+    ) -> List[Optional[T]]: ...
 
     async def aask_validate[T](
         self,
@@ -274,7 +302,7 @@ class LLMUsage(ScopedConfig):
             List[str]: The validated response as a list of strings.
         """
         return await self.aask_validate(
-
+            TEMPLATE_MANAGER.render_template(
                 configs.templates.liststr_template,
                 {"requirement": requirement, "k": k},
             ),
@@ -293,7 +321,7 @@ class LLMUsage(ScopedConfig):
             List[str]: The validated response as a list of strings.
         """
         return await self.aliststr(
-
+            TEMPLATE_MANAGER.render_template(
                 configs.templates.pathstr_template,
                 {"requirement": requirement},
             ),
@@ -318,6 +346,25 @@ class LLMUsage(ScopedConfig):
             )
         ).pop()
 
+    async def ageneric_string(self, requirement: str, **kwargs: Unpack[ValidateKwargs[str]]) -> str:
+        """Asynchronously generates a generic string based on a given requirement.
+
+        Args:
+            requirement (str): The requirement for the string.
+            **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
+
+        Returns:
+            str: The generated string.
+        """
+        return await self.aask_validate(
+            TEMPLATE_MANAGER.render_template(
+                configs.templates.generic_string_template,
+                {"requirement": requirement, "language": GenericCapture.capture_type},
+            ),
+            validator=lambda resp: GenericCapture.capture(resp),
+            **kwargs,
+        )
+
     async def achoose[T: WithBriefing](
         self,
         instruction: str,
@@ -344,7 +391,7 @@ class LLMUsage(ScopedConfig):
         if dup := duplicates_everseen(choices, key=lambda x: x.name):
             logger.error(err := f"Redundant choices: {dup}")
             raise ValueError(err)
-        prompt =
+        prompt = TEMPLATE_MANAGER.render_template(
             configs.templates.make_choice_template,
             {
                 "instruction": instruction,
@@ -417,7 +464,7 @@ class LLMUsage(ScopedConfig):
             bool: The judgment result (True or False) based on the AI's response.
         """
         return await self.aask_validate(
-            question=
+            question=TEMPLATE_MANAGER.render_template(
                 configs.templates.make_judgment_template,
                 {"prompt": prompt, "affirm_case": affirm_case, "deny_case": deny_case},
             ),
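Completions now go through a shared litellm `Router` configured for usage-based routing, with per-deployment `rpm`/`tpm` limits taken from the new config fields. A hedged sketch of the same flow outside fabricatio, mirroring the calls shown in the diff (model name, endpoint, and key are placeholders):

```python
# Hedged sketch of the routing flow introduced in LLMUsage.aquery: register a Deployment
# on a usage-based Router, then send the completion through the router.
from litellm import Router
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo

router = Router(
    routing_strategy="usage-based-routing-v2",
    allowed_fails=1,      # mirrors configs.routing.allowed_fails
    retry_after=15,       # mirrors configs.routing.retry_after
    cooldown_time=120,    # mirrors configs.routing.cooldown_time
)

params = LiteLLM_Params(
    model="openai/gpt-4o-mini",             # placeholder model name
    api_base="https://api.example.com/v1",  # placeholder endpoint
    api_key="sk-placeholder",               # placeholder key
    tpm=1_000_000,                          # mirrors the new llm.tpm default
    rpm=100,                                # mirrors the new llm.rpm default
)
router.upsert_deployment(
    Deployment(
        model_name="openai/gpt-4o-mini",
        litellm_params=params,
        model_info=ModelInfo(id=hash("openai/gpt-4o-mini" + params.model_dump_json(exclude_none=True))),
    )
)


async def main() -> None:
    response = await router.acompletion(
        model="openai/gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response)

# run with asyncio.run(main()) against a reachable endpoint and valid key
```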
fabricatio/models/utils.py
CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Self
 
 from pydantic import BaseModel, ConfigDict, Field
+from questionary import text
 
 
 class Message(BaseModel):
@@ -144,3 +145,23 @@ class TaskStatus(Enum):
     Finished = "finished"
     Failed = "failed"
     Cancelled = "cancelled"
+
+
+async def ask_edit(
+    text_seq: List[str],
+) -> List[str]:
+    """Asks the user to edit a list of texts.
+
+    Args:
+        text_seq (List[str]): A list of texts to be edited.
+
+    Returns:
+        List[str]: A list of edited texts.
+        If the user does not edit a text, it will not be included in the returned list.
+    """
+    res = []
+    for i, t in enumerate(text_seq):
+        edited = await text(f"[{i}] ", default=t).ask_async()
+        if edited:
+            res.append(edited)
+    return res
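A hedged usage sketch for the new `ask_edit` helper: each string is opened in a questionary text prompt, and any entry the user clears is dropped from the result.

```python
# Hedged usage sketch for ask_edit; the input strings are placeholders.
import asyncio

from fabricatio.models.utils import ask_edit


async def main() -> None:
    drafts = ["first draft paragraph", "second draft paragraph"]
    edited = await ask_edit(drafts)
    print(edited)  # only the texts the user kept (non-empty) are returned


if __name__ == "__main__":
    asyncio.run(main())  # interactive; needs a real terminal
```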
fabricatio/parser.py
CHANGED
@@ -35,7 +35,7 @@ class Capture(BaseModel):
         """Initialize the compiled pattern."""
         self._compiled = compile(self.pattern, self.flags)
 
-    def fix[T](self, text: str | Iterable[str]|T) -> str | List[str]|T:
+    def fix[T](self, text: str | Iterable[str] | T) -> str | List[str] | T:
         """Fix the text using the pattern.
 
         Args:
@@ -47,8 +47,8 @@ class Capture(BaseModel):
         match self.capture_type:
             case "json":
                 if isinstance(text, str):
-                    return repair_json(text,ensure_ascii=False)
-                return [repair_json(item) for item in text]
+                    return repair_json(text, ensure_ascii=False)
+                return [repair_json(item, ensure_ascii=False) for item in text]
             case _:
                 return text
 
@@ -134,8 +134,16 @@ class Capture(BaseModel):
         """
         return cls(pattern=f"```{language}\n(.*?)\n```", capture_type=language)
 
+    @classmethod
+    def capture_generic_block(cls, language: str) -> Self:
+        """Capture the first occurrence of a generic code block in the given text.
+
+        Returns:
+            Self: The instance of the class with the captured code block.
+        """
+        return cls(pattern=f"--- Start of {language} ---\n(.*?)\n--- end of {language} ---", capture_type=language)
+
 
 JsonCapture = Capture.capture_code_block("json")
 PythonCapture = Capture.capture_code_block("python")
-
-CodeBlockCapture = Capture(pattern="```.*?\n(.*?)\n```")
+GenericCapture = Capture.capture_generic_block("String")
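`CodeBlockCapture` is replaced by `GenericCapture`, which looks for explicit start/end markers instead of a fenced block. A standalone check of that delimiter pattern with plain `re` (the DOTALL flag is an assumption about how the class matches multi-line bodies):

```python
# Standalone check of the delimiter pattern behind capture_generic_block("String"):
# it captures whatever sits between the start and end markers.
import re

pattern = re.compile(r"--- Start of String ---\n(.*?)\n--- end of String ---", re.DOTALL)

reply = (
    "Sure, here is the corrected text:\n"
    "--- Start of String ---\n"
    "A corrected paragraph goes here.\n"
    "--- end of String ---\n"
)

match = pattern.search(reply)
print(match.group(1) if match else None)  # -> A corrected paragraph goes here.
```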
Binary file
{fabricatio-0.2.5.dev5.dist-info → fabricatio-0.2.6.dev1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fabricatio
-Version: 0.2.
+Version: 0.2.6.dev1
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: 3.12
@@ -176,7 +176,7 @@ if __name__ == "__main__":
 ### Template Management and Rendering
 
 ```python
-from fabricatio._rust_instances import
+from fabricatio._rust_instances import TEMPLATE_MANAGER
 
 template_name = "claude-xml.hbs"
 data = {
@@ -185,7 +185,7 @@ data = {
     "files": [{"path": "file1.py", "code": "print('Hello')"}],
 }
 
-rendered_template =
+rendered_template = TEMPLATE_MANAGER.render_template(template_name, data)
 print(rendered_template)
 ```
 