fabricatio 0.2.5.dev4-cp312-cp312-win_amd64.whl → 0.2.5.dev5-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/actions/rag.py +1 -1
- fabricatio/capabilities/propose.py +14 -20
- fabricatio/capabilities/rating.py +41 -36
- fabricatio/capabilities/review.py +8 -9
- fabricatio/capabilities/task.py +7 -8
- fabricatio/config.py +8 -4
- fabricatio/fs/readers.py +1 -1
- fabricatio/journal.py +1 -0
- fabricatio/models/action.py +1 -1
- fabricatio/models/events.py +6 -4
- fabricatio/models/extra.py +19 -16
- fabricatio/models/generic.py +14 -1
- fabricatio/models/kwargs_types.py +70 -72
- fabricatio/models/tool.py +4 -4
- fabricatio/models/usages.py +67 -68
- fabricatio/parser.py +26 -5
- {fabricatio-0.2.5.dev4.data → fabricatio-0.2.5.dev5.data}/scripts/tdown.exe +0 -0
- {fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.5.dev5.dist-info}/METADATA +2 -1
- fabricatio-0.2.5.dev5.dist-info/RECORD +41 -0
- fabricatio-0.2.5.dev4.dist-info/RECORD +0 -41
- {fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.5.dev5.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.5.dev5.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/kwargs_types.py
CHANGED
@@ -1,92 +1,90 @@
 """This module contains the types for the keyword arguments of the methods in the models module."""

-from typing import Any,
+from typing import Any, TypedDict

 from litellm.caching.caching import CacheMode
 from litellm.types.caching import CachingSupportedCallTypes
-from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt


-class CollectionSimpleConfigKwargs(TypedDict):
+class CollectionSimpleConfigKwargs(TypedDict, total=False):
     """Configuration parameters for a vector collection.

     These arguments are typically used when configuring connections to vector databases.
     """

-    dimension:
-    timeout:
+    dimension: int
+    timeout: float


-class FetchKwargs(TypedDict):
+class FetchKwargs(TypedDict, total=False):
     """Arguments for fetching data from vector collections.

     Controls how data is retrieved from vector databases, including filtering
     and result limiting parameters.
     """

-    collection_name:
-    similarity_threshold:
-    result_per_query:
+    collection_name: str
+    similarity_threshold: float
+    result_per_query: int


-class EmbeddingKwargs(TypedDict):
+class EmbeddingKwargs(TypedDict, total=False):
     """Configuration parameters for text embedding operations.

     These settings control the behavior of embedding models that convert text
     to vector representations.
     """

-    model:
-    dimensions:
-    timeout:
-    caching:
+    model: str
+    dimensions: int
+    timeout: int
+    caching: bool


-class LLMKwargs(TypedDict):
+class LLMKwargs(TypedDict, total=False):
     """Configuration parameters for language model inference.

     These arguments control the behavior of large language model calls,
     including generation parameters and caching options.
     """

-    model:
-    temperature:
-    stop:
-    top_p:
-    max_tokens:
-    stream:
-    timeout:
-    max_retries:
-    no_cache:
-    no_store:
-    cache_ttl:
-    s_maxage:
-
-
-class
-    """Arguments for content
+    model: str
+    temperature: float
+    stop: str | list[str]
+    top_p: float
+    max_tokens: int
+    stream: bool
+    timeout: int
+    max_retries: int
+    no_cache: bool # if the req uses cache in this call
+    no_store: bool # If store the response of this call to cache
+    cache_ttl: int # how long the stored cache is alive, in seconds
+    s_maxage: int # max accepted age of cached response, in seconds
+
+
+class GenerateKwargs(LLMKwargs, total=False):
+    """Arguments for content generation operations.

-    Extends LLMKwargs with additional parameters specific to
-    such as
+    Extends LLMKwargs with additional parameters specific to generation tasks,
+    such as the number of generated items and the system message.
     """

-
-    max_validations: NotRequired[PositiveInt]
+    system_message: str


-
-
-    """Arguments for content generation operations.
+class ValidateKwargs[T](GenerateKwargs, total=False):
+    """Arguments for content validation operations.

-    Extends
-
+    Extends LLMKwargs with additional parameters specific to validation tasks,
+    such as limiting the number of validation attempts.
     """

-
+    default: T
+    max_validations: int


 # noinspection PyTypedDict
-class ReviewKwargs[T](
+class ReviewKwargs[T](ValidateKwargs[T], total=False):
     """Arguments for content review operations.

     Extends GenerateKwargs with parameters for evaluating content against
@@ -94,18 +92,18 @@ class ReviewKwargs[T](GenerateKwargs[T]):
     """

     topic: str
-    criteria:
+    criteria: set[str]


 # noinspection PyTypedDict
-class ChooseKwargs[T](
+class ChooseKwargs[T](ValidateKwargs[T], total=False):
     """Arguments for selection operations.

     Extends GenerateKwargs with parameters for selecting among options,
     such as the number of items to choose.
     """

-    k:
+    k: int


 class CacheKwargs(TypedDict, total=False):
@@ -115,35 +113,35 @@ class CacheKwargs(TypedDict, total=False):
     including in-memory, Redis, S3, and vector database caching options.
     """

-    mode:
-    host:
-    port:
-    password:
-    namespace:
-    ttl:
-    default_in_memory_ttl:
-    default_in_redis_ttl:
-    similarity_threshold:
-    supported_call_types:
+    mode: CacheMode # when default_on cache is always on, when default_off cache is opt in
+    host: str
+    port: str
+    password: str
+    namespace: str
+    ttl: float
+    default_in_memory_ttl: float
+    default_in_redis_ttl: float
+    similarity_threshold: float
+    supported_call_types: list[CachingSupportedCallTypes]
     # s3 Bucket, boto3 configuration
-    s3_bucket_name:
-    s3_region_name:
-    s3_api_version:
-    s3_use_ssl:
-    s3_verify:
-    s3_endpoint_url:
-    s3_aws_access_key_id:
-    s3_aws_secret_access_key:
-    s3_aws_session_token:
-    s3_config:
-    s3_path:
+    s3_bucket_name: str
+    s3_region_name: str
+    s3_api_version: str
+    s3_use_ssl: bool
+    s3_verify: bool | str
+    s3_endpoint_url: str
+    s3_aws_access_key_id: str
+    s3_aws_secret_access_key: str
+    s3_aws_session_token: str
+    s3_config: Any
+    s3_path: str
     redis_semantic_cache_use_async: bool
     redis_semantic_cache_embedding_model: str
-    redis_flush_size:
-    redis_startup_nodes:
+    redis_flush_size: int
+    redis_startup_nodes: list
     disk_cache_dir: Any
-    qdrant_api_base:
-    qdrant_api_key:
-    qdrant_collection_name:
-    qdrant_quantization_config:
+    qdrant_api_base: str
+    qdrant_api_key: str
+    qdrant_collection_name: str
+    qdrant_quantization_config: str
     qdrant_semantic_cache_embedding_model: str
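The recurring change in this file is the switch to `TypedDict(..., total=False)`, which makes every declared key optional so these classes can be consumed through `typing.Unpack` as partial keyword bundles. A minimal sketch of that consumption pattern (the trimmed `LLMKwargs` body and the `ask` function below are illustrative stand-ins, not the package's code):

from typing import TypedDict, Unpack


class LLMKwargs(TypedDict, total=False):
    # With total=False every key is optional; callers supply any subset.
    model: str
    temperature: float
    max_tokens: int


def ask(question: str, **kwargs: Unpack[LLMKwargs]) -> str:
    # Only the keys the caller actually passed are present in kwargs.
    model = kwargs.get("model", "default-model")
    return f"[{model}] {question}"


print(ask("ping"))                    # no kwargs at all
print(ask("ping", temperature=0.2))   # a partial subset type-checks cleanly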
fabricatio/models/tool.py
CHANGED
@@ -4,7 +4,7 @@ from importlib.machinery import ModuleSpec
 from importlib.util import module_from_spec
 from inspect import iscoroutinefunction, signature
 from types import CodeType, ModuleType
-from typing import Any, Callable, Dict, List, Optional, Self, overload
+from typing import Any, Callable, Dict, List, Optional, Self, cast, overload

 from fabricatio.config import configs
 from fabricatio.decorators import logging_execution_info, use_temp_module
@@ -136,7 +136,7 @@ class ToolExecutor(BaseModel):

     def inject_tools[M: ModuleType](self, module: Optional[M] = None) -> M:
         """Inject the tools into the provided module or default."""
-        module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None))
+        module = module or cast(M, module_from_spec(spec=ModuleSpec(name=configs.toolbox.tool_module_name, loader=None)))
         for tool in self.candidates:
             logger.debug(f"Injecting tool: {tool.name}")
             setattr(module, tool.name, tool.invoke)
@@ -144,7 +144,7 @@ class ToolExecutor(BaseModel):

     def inject_data[M: ModuleType](self, module: Optional[M] = None) -> M:
         """Inject the data into the provided module or default."""
-        module = module or module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None))
+        module = module or cast(M,module_from_spec(spec=ModuleSpec(name=configs.toolbox.data_module_name, loader=None)))
         for key, value in self.data.items():
             logger.debug(f"Injecting data: {key}")
             setattr(module, key, value)
@@ -184,6 +184,6 @@ class ToolExecutor(BaseModel):
         tools = []
         while tool_name := recipe.pop(0):
             for toolbox in toolboxes:
-                tools.append(toolbox
+                tools.append(toolbox.get(tool_name))

         return cls(candidates=tools)
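The new `cast(M, ...)` calls address a typing mismatch rather than runtime behavior: `module_from_spec` is annotated as returning a plain `ModuleType`, while `inject_tools`/`inject_data` are generic over `M` and promise to return `M`. A hedged sketch of the same pattern in isolation (the `ensure_module` helper and the module name "scratch" are illustrative, not part of the package):

from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec
from types import ModuleType
from typing import Optional, cast


def ensure_module[M: ModuleType](module: Optional[M] = None) -> M:
    # module_from_spec() is typed as ModuleType, not M, so the fallback value
    # is cast to satisfy the declared return type; runtime behavior is unchanged.
    return module or cast(M, module_from_spec(ModuleSpec(name="scratch", loader=None)))


scratch = ensure_module()
scratch.answer = 42  # attributes can be injected onto the fresh module
print(scratch.answer)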
fabricatio/models/usages.py
CHANGED
@@ -1,7 +1,7 @@
 """This module contains classes that manage the usage of language models and tools in tasks."""

 from asyncio import gather
-from typing import Callable, Dict, Iterable, List, Optional, Self, Set, Type, Union, Unpack, overload
+from typing import Callable, Dict, Iterable, List, Optional, Self, Sequence, Set, Type, Union, Unpack, overload

 import asyncstdlib
 import litellm
@@ -9,7 +9,7 @@ from fabricatio._rust_instances import template_manager
 from fabricatio.config import configs
 from fabricatio.journal import logger
 from fabricatio.models.generic import ScopedConfig, WithBriefing
-from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
+from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs, ValidateKwargs
 from fabricatio.models.task import Task
 from fabricatio.models.tool import Tool, ToolBox
 from fabricatio.models.utils import Messages
@@ -20,12 +20,13 @@ from litellm.types.utils import (
     EmbeddingResponse,
     ModelResponse,
     StreamingChoices,
+    TextChoices,
 )
 from litellm.utils import CustomStreamWrapper
 from more_itertools import duplicates_everseen
 from pydantic import Field, NonNegativeInt, PositiveInt

-if configs.cache.enabled:
+if configs.cache.enabled and configs.cache.type:
     litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
     logger.success(f"{configs.cache.type.name} Cache enabled")

@@ -42,7 +43,7 @@ class LLMUsage(ScopedConfig):
         messages: List[Dict[str, str]],
         n: PositiveInt | None = None,
         **kwargs: Unpack[LLMKwargs],
-    ) -> ModelResponse
+    ) -> ModelResponse:
         """Asynchronously queries the language model to generate a response based on the provided messages and parameters.

         Args:
@@ -81,7 +82,7 @@ class LLMUsage(ScopedConfig):
         system_message: str = "",
         n: PositiveInt | None = None,
         **kwargs: Unpack[LLMKwargs],
-    ) ->
+    ) -> Sequence[TextChoices | Choices | StreamingChoices]:
         """Asynchronously invokes the language model with a question and optional system message.

         Args:
@@ -101,13 +102,14 @@ class LLMUsage(ScopedConfig):
         if isinstance(resp, ModelResponse):
             return resp.choices
         if isinstance(resp, CustomStreamWrapper):
-            if not configs.debug.streaming_visible:
-                return
+            if not configs.debug.streaming_visible and (pack := stream_chunk_builder(await asyncstdlib.list())):
+                return pack.choices
             chunks = []
             async for chunk in resp:
                 chunks.append(chunk)
                 print(chunk.choices[0].delta.content or "", end="") # noqa: T201
-
+            if pack := stream_chunk_builder(chunks):
+                return pack.choices
         logger.critical(err := f"Unexpected response type: {type(resp)}")
         raise ValueError(err)

@@ -166,15 +168,15 @@ class LLMUsage(ScopedConfig):
                         for q, sm in zip(q_seq, sm_seq, strict=True)
                     ]
                 )
-                return [r.
+                return [r[0].message.content for r in res]
             case (list(q_seq), str(sm)):
                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for q in q_seq])
-                return [r.
+                return [r[0].message.content for r in res]
             case (str(q), list(sm_seq)):
                 res = await gather(*[self.ainvoke(n=1, question=q, system_message=sm, **kwargs) for sm in sm_seq])
-                return [r.
+                return [r[0].message.content for r in res]
             case (str(q), str(sm)):
-                return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))
+                return ((await self.ainvoke(n=1, question=q, system_message=sm, **kwargs))[0]).message.content
             case _:
                 raise RuntimeError("Should not reach here.")

@@ -185,8 +187,7 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: T,
         max_validations: PositiveInt = 2,
-
-        **kwargs: Unpack[LLMKwargs],
+        **kwargs: Unpack[GenerateKwargs],
     ) -> T: ...
     @overload
     async def aask_validate[T](
@@ -195,19 +196,36 @@ class LLMUsage(ScopedConfig):
         validator: Callable[[str], T | None],
         default: None = None,
         max_validations: PositiveInt = 2,
-
-        **kwargs: Unpack[LLMKwargs],
+        **kwargs: Unpack[GenerateKwargs],
     ) -> Optional[T]: ...

+    @overload
     async def aask_validate[T](
         self,
-        question: str,
+        question: List[str],
+        validator: Callable[[str], T | None],
+        default: None = None,
+        max_validations: PositiveInt = 2,
+        **kwargs: Unpack[GenerateKwargs],
+    ) -> List[Optional[T]]: ...
+    @overload
+    async def aask_validate[T](
+        self,
+        question: List[str],
+        validator: Callable[[str], T | None],
+        default: T,
+        max_validations: PositiveInt = 2,
+        **kwargs: Unpack[GenerateKwargs],
+    ) -> List[T]: ...
+
+    async def aask_validate[T](
+        self,
+        question: str | List[str],
         validator: Callable[[str], T | None],
         default: Optional[T] = None,
         max_validations: PositiveInt = 2,
-
-
-    ) -> Optional[T]:
+        **kwargs: Unpack[GenerateKwargs],
+    ) -> Optional[T] | List[Optional[T]] | List[T] | T:
         """Asynchronously asks a question and validates the response using a given validator.

         Args:
@@ -215,59 +233,42 @@ class LLMUsage(ScopedConfig):
             validator (Callable[[str], T | None]): A function to validate the response.
             default (T | None): Default value to return if validation fails. Defaults to None.
             max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 2.
-            system_message (str): System message to include in the request. Defaults to an empty string.
             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             T: The validated response.

         """
-        for i in range(max_validations):
-            if (
-                response := await self.aask(
-                    question=question,
-                    system_message=system_message,
-                    **kwargs,
-                )
-            ) and (validated := validator(response)):
-                logger.debug(f"Successfully validated the response at {i}th attempt.")
-                return validated
-            kwargs["no_cache"] = True
-            logger.debug("Closed the cache for the next attempt")
-        if default is None:
-            logger.error(f"Failed to validate the response after {max_validations} attempts.")
-        return default
-
-    async def aask_validate_batch[T](
-        self,
-        questions: List[str],
-        validator: Callable[[str], T | None],
-        **kwargs: Unpack[GenerateKwargs[T]],
-    ) -> List[T]:
-        """Asynchronously asks a batch of questions and validates the responses using a given validator.

-
-
-
-
-
-
-
-
-
-
-
-
+        async def _inner(q: str) -> Optional[T]:
+            for lap in range(max_validations):
+                try:
+                    if (response := await self.aask(question=q, **kwargs)) and (validated := validator(response)):
+                        logger.debug(f"Successfully validated the response at {lap}th attempt.")
+                        return validated
+                except Exception as e: # noqa: BLE001
+                    logger.error(f"Error during validation: \n{e}")
+                    break
+                kwargs["no_cache"] = True
+                logger.debug("Closed the cache for the next attempt")
+            if default is None:
+                logger.error(f"Failed to validate the response after {max_validations} attempts.")
+            return default
+
+        if isinstance(question, str):
+            return await _inner(question)
+
+        return await gather(*[_inner(q) for q in question])

     async def aliststr(
-        self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[
+        self, requirement: str, k: NonNegativeInt = 0, **kwargs: Unpack[ValidateKwargs[List[str]]]
     ) -> List[str]:
         """Asynchronously generates a list of strings based on a given requirement.

         Args:
             requirement (str): The requirement for the list of strings.
             k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
-            **kwargs (Unpack[
+            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             List[str]: The validated response as a list of strings.
@@ -299,12 +300,12 @@ class LLMUsage(ScopedConfig):
             **kwargs,
         )

-    async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[
+    async def awhich_pathstr(self, requirement: str, **kwargs: Unpack[ValidateKwargs[List[str]]]) -> str:
         """Asynchronously generates a single path string based on a given requirement.

         Args:
             requirement (str): The requirement for the list of strings.
-            **kwargs (Unpack[
+            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             str: The validated response as a single string.
@@ -322,7 +323,7 @@ class LLMUsage(ScopedConfig):
         instruction: str,
         choices: List[T],
         k: NonNegativeInt = 0,
-        **kwargs: Unpack[
+        **kwargs: Unpack[ValidateKwargs[List[T]]],
     ) -> List[T]:
         """Asynchronously executes a multi-choice decision-making process, generating a prompt based on the instruction and options, and validates the returned selection results.

@@ -330,7 +331,7 @@ class LLMUsage(ScopedConfig):
             instruction (str): The user-provided instruction/question description.
             choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
             k (NonNegativeInt): The number of choices to select, 0 means infinite. Defaults to 0.
-            **kwargs (Unpack[
+            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             List[T]: The final validated selection result list, with element types matching the input `choices`.
@@ -373,14 +374,14 @@ class LLMUsage(ScopedConfig):
         self,
         instruction: str,
         choices: List[T],
-        **kwargs: Unpack[
+        **kwargs: Unpack[ValidateKwargs[List[T]]],
     ) -> T:
         """Asynchronously picks a single choice from a list of options using AI validation.

         Args:
             instruction (str): The user-provided instruction/question description.
             choices (List[T]): A list of candidate options, requiring elements to have `name` and `briefing` fields.
-            **kwargs (Unpack[
+            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             T: The single selected item from the choices list.
@@ -402,7 +403,7 @@ class LLMUsage(ScopedConfig):
         prompt: str,
         affirm_case: str = "",
         deny_case: str = "",
-        **kwargs: Unpack[
+        **kwargs: Unpack[ValidateKwargs[bool]],
     ) -> bool:
         """Asynchronously judges a prompt using AI validation.

@@ -410,7 +411,7 @@ class LLMUsage(ScopedConfig):
             prompt (str): The input prompt to be judged.
             affirm_case (str): The affirmative case for the AI model. Defaults to an empty string.
             deny_case (str): The negative case for the AI model. Defaults to an empty string.
-            **kwargs (Unpack[
+            **kwargs (Unpack[ValidateKwargs]): Additional keyword arguments for the LLM usage.

         Returns:
             bool: The judgment result (True or False) based on the AI's response.
@@ -516,7 +517,6 @@ class ToolBoxUsage(LLMUsage):
     async def choose_toolboxes(
         self,
         task: Task,
-        system_message: str = "",
         **kwargs: Unpack[ChooseKwargs[List[ToolBox]]],
     ) -> List[ToolBox]:
         """Asynchronously executes a multi-choice decision-making process to choose toolboxes.
@@ -535,7 +535,6 @@ class ToolBoxUsage(LLMUsage):
         return await self.achoose(
             instruction=task.briefing,
             choices=list(self.toolboxes),
-            system_message=system_message,
             **kwargs,
         )

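The net effect on `aask_validate` is that the old `aask_validate_batch` is folded in: the same method now takes either a single question or a list of questions, retries up to `max_validations` times, and runs list inputs concurrently through `asyncio.gather`. A hedged usage sketch of the new call shape (the validator and the `usage` object are placeholders; only the signatures are taken from the diff):

def as_positive_int(raw: str) -> int | None:
    """Illustrative validator: accept the reply only if it is a positive integer."""
    try:
        value = int(raw.strip())
    except ValueError:
        return None
    return value if value > 0 else None


# Single question -> one validated value (or `default`, which falls back to None):
#     count = await usage.aask_validate("Reply with one number.", as_positive_int)
# List of questions -> a list of validated values, gathered concurrently:
#     counts = await usage.aask_validate(["Q1", "Q2"], as_positive_int, default=0)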
fabricatio/parser.py
CHANGED
@@ -1,12 +1,14 @@
 """A module to parse text using regular expressions."""

-from typing import Any, Callable, Optional, Self, Tuple, Type
+from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type

 import orjson
 import regex
+from json_repair import repair_json
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
 from regex import Pattern, compile

+from fabricatio.config import configs
 from fabricatio.journal import logger


@@ -25,12 +27,31 @@ class Capture(BaseModel):
     """The regular expression pattern to search for."""
     flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
     """The flags to use when compiling the regular expression pattern."""
+    capture_type: Optional[str] = None
+    """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
     _compiled: Pattern = PrivateAttr()

     def model_post_init(self, __context: Any) -> None:
         """Initialize the compiled pattern."""
         self._compiled = compile(self.pattern, self.flags)

+    def fix[T](self, text: str | Iterable[str]|T) -> str | List[str]|T:
+        """Fix the text using the pattern.
+
+        Args:
+            text (str | List[str]): The text to fix.
+
+        Returns:
+            str | List[str]: The fixed text with the same type as input.
+        """
+        match self.capture_type:
+            case "json":
+                if isinstance(text, str):
+                    return repair_json(text,ensure_ascii=False)
+                return [repair_json(item) for item in text]
+            case _:
+                return text
+
     def capture(self, text: str) -> Tuple[str, ...] | str | None:
         """Capture the first occurrence of the pattern in the given text.

@@ -44,12 +65,12 @@ class Capture(BaseModel):
         match = self._compiled.search(text)
         if match is None:
             return None
-
+        groups = self.fix(match.groups()) if configs.general.use_json_repair else match.groups()
         if self.target_groups:
-            cap = tuple(
+            cap = tuple(groups[g - 1] for g in self.target_groups)
             logger.debug(f"Captured text: {'\n\n'.join(cap)}")
             return cap
-        cap =
+        cap = groups[0]
         logger.debug(f"Captured text: \n{cap}")
         return cap

@@ -111,7 +132,7 @@ class Capture(BaseModel):
         Returns:
             Self: The instance of the class with the captured code block.
         """
-        return cls(pattern=f"```{language}\n(.*?)\n```")
+        return cls(pattern=f"```{language}\n(.*?)\n```", capture_type=language)


 JsonCapture = Capture.capture_code_block("json")
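The parser change wires `json_repair` into `Capture`: code blocks created with `capture_code_block("json")` now carry `capture_type="json"`, so captured groups can be passed through `repair_json` before downstream parsing (gated by the `use_json_repair` config flag referenced in the diff). A small sketch of the repair step on its own (assumes the `json-repair` dependency added in METADATA below; the sample string is illustrative):

from json_repair import repair_json

# A model reply that is almost, but not quite, valid JSON (trailing comma).
broken = '{"name": "fabricatio", "ok": true,}'

# repair_json returns a corrected JSON string; ensure_ascii=False mirrors the
# keyword used in Capture.fix in the diff above.
print(repair_json(broken, ensure_ascii=False))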
Binary file

{fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.5.dev5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fabricatio
-Version: 0.2.5.dev4
+Version: 0.2.5.dev5
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: 3.12
@@ -11,6 +11,7 @@ Classifier: Typing :: Typed
 Requires-Dist: appdirs>=1.4.4
 Requires-Dist: asyncio>=3.4.3
 Requires-Dist: asyncstdlib>=3.13.0
+Requires-Dist: json-repair>=0.39.1
 Requires-Dist: litellm>=1.60.0
 Requires-Dist: loguru>=0.7.3
 Requires-Dist: magika>=0.5.1