fabricatio 0.2.10.dev0__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +55 -10
- fabricatio/actions/article_rag.py +297 -12
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +42 -20
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +70 -0
- fabricatio/capabilities/rag.py +5 -2
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +43 -26
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/action.py +16 -11
- fabricatio/models/adv_kwargs_types.py +5 -12
- fabricatio/models/extra/aricle_rag.py +254 -0
- fabricatio/models/extra/article_base.py +56 -7
- fabricatio/models/extra/article_essence.py +8 -7
- fabricatio/models/extra/article_main.py +102 -6
- fabricatio/models/extra/problem.py +5 -1
- fabricatio/models/extra/rag.py +49 -23
- fabricatio/models/generic.py +43 -24
- fabricatio/models/kwargs_types.py +12 -3
- fabricatio/models/task.py +13 -1
- fabricatio/models/usages.py +10 -27
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +177 -63
- fabricatio/utils.py +50 -10
- fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -12
- fabricatio-0.2.11.dist-info/RECORD +65 -0
- fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.10.dev0.dist-info/RECORD +0 -62
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
fabricatio/models/extra/rag.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1
1
|
"""A module containing the RAG (Retrieval-Augmented Generation) models."""
|
2
2
|
|
3
|
-
from abc import
|
4
|
-
from
|
3
|
+
from abc import ABC
|
4
|
+
from functools import partial
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Self, Sequence, Set
|
5
6
|
|
6
7
|
from fabricatio.decorators import precheck_package
|
7
|
-
from
|
8
|
+
from fabricatio.models.generic import Vectorizable
|
9
|
+
from fabricatio.utils import ok
|
10
|
+
from pydantic import JsonValue
|
8
11
|
|
9
12
|
if TYPE_CHECKING:
|
10
13
|
from importlib.util import find_spec
|
@@ -15,14 +18,18 @@ if TYPE_CHECKING:
|
|
15
18
|
from pymilvus import CollectionSchema
|
16
19
|
|
17
20
|
|
18
|
-
class MilvusDataBase(
|
21
|
+
class MilvusDataBase(Vectorizable, ABC):
|
19
22
|
"""A base class for Milvus data."""
|
20
23
|
|
21
|
-
model_config = ConfigDict(use_attribute_docstrings=True)
|
22
|
-
|
23
24
|
primary_field_name: ClassVar[str] = "id"
|
24
|
-
|
25
|
+
"""The name of the primary field in Milvus."""
|
25
26
|
vector_field_name: ClassVar[str] = "vector"
|
27
|
+
"""The name of the vector field in Milvus."""
|
28
|
+
|
29
|
+
index_type: ClassVar[str] = "FLAT"
|
30
|
+
"""The type of index to be used in Milvus."""
|
31
|
+
metric_type: ClassVar[str] = "COSINE"
|
32
|
+
"""The type of metric to be used in Milvus."""
|
26
33
|
|
27
34
|
def prepare_insertion(self, vector: List[float]) -> Dict[str, Any]:
|
28
35
|
"""Prepares the data for insertion into Milvus.
|
@@ -32,11 +39,6 @@ class MilvusDataBase(BaseModel, metaclass=ABCMeta):
|
|
32
39
|
"""
|
33
40
|
return {**self.model_dump(exclude_none=True, by_alias=True), self.vector_field_name: vector}
|
34
41
|
|
35
|
-
@property
|
36
|
-
@abstractmethod
|
37
|
-
def to_vectorize(self) -> str:
|
38
|
-
"""The text representation of the data."""
|
39
|
-
|
40
42
|
@classmethod
|
41
43
|
@precheck_package(
|
42
44
|
"pymilvus", "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
|
@@ -50,23 +52,47 @@ class MilvusDataBase(BaseModel, metaclass=ABCMeta):
|
|
50
52
|
FieldSchema(cls.vector_field_name, dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
51
53
|
]
|
52
54
|
|
53
|
-
type_mapping = {
|
54
|
-
str: DataType.STRING,
|
55
|
-
int: DataType.INT64,
|
56
|
-
float: DataType.DOUBLE,
|
57
|
-
JsonValue: DataType.JSON,
|
58
|
-
# TODO add more mapping
|
59
|
-
}
|
60
|
-
|
61
55
|
for k, v in cls.model_fields.items():
|
62
56
|
k: str
|
63
57
|
v: FieldInfo
|
64
|
-
|
65
|
-
|
66
|
-
|
58
|
+
schema = partial(FieldSchema, k, description=v.description or "")
|
59
|
+
anno = ok(v.annotation)
|
60
|
+
|
61
|
+
if anno == int:
|
62
|
+
fields.append(schema(dtype=DataType.INT64))
|
63
|
+
elif anno == str:
|
64
|
+
fields.append(schema(dtype=DataType.VARCHAR, max_length=65535))
|
65
|
+
elif anno == float:
|
66
|
+
fields.append(schema(dtype=DataType.DOUBLE))
|
67
|
+
elif anno == list[str] or anno == List[str] or anno == set[str] or anno == Set[str]:
|
68
|
+
fields.append(
|
69
|
+
schema(dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=65535, max_capacity=4096)
|
70
|
+
)
|
71
|
+
elif anno == list[int] or anno == List[int] or anno == set[int] or anno == Set[int]:
|
72
|
+
fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=4096))
|
73
|
+
elif anno == list[float] or anno == List[float] or anno == set[float] or anno == Set[float]:
|
74
|
+
fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=4096))
|
75
|
+
elif anno == JsonValue:
|
76
|
+
fields.append(schema(dtype=DataType.JSON))
|
77
|
+
|
78
|
+
else:
|
79
|
+
raise NotImplementedError(f"{k}:{anno} is not supported")
|
80
|
+
|
67
81
|
return CollectionSchema(fields)
|
68
82
|
|
69
83
|
@classmethod
|
70
84
|
def from_sequence(cls, data: Sequence[Dict[str, Any]]) -> List[Self]:
|
71
85
|
"""Constructs a list of instances from a sequence of dictionaries."""
|
72
86
|
return [cls(**d) for d in data]
|
87
|
+
|
88
|
+
|
89
|
+
class MilvusClassicModel(MilvusDataBase):
|
90
|
+
"""A class representing a classic model stored in Milvus."""
|
91
|
+
|
92
|
+
text: str
|
93
|
+
"""The text to be stored in Milvus."""
|
94
|
+
subject: str = ""
|
95
|
+
"""The subject of the text."""
|
96
|
+
|
97
|
+
def _prepare_vectorization_inner(self) -> str:
|
98
|
+
return self.text
|
fabricatio/models/generic.py
CHANGED
@@ -3,12 +3,11 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from datetime import datetime
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
|
6
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
|
7
7
|
|
8
|
-
import
|
9
|
-
import rtoml
|
8
|
+
import ujson
|
10
9
|
from fabricatio.config import configs
|
11
|
-
from fabricatio.fs.readers import
|
10
|
+
from fabricatio.fs.readers import safe_text_read
|
12
11
|
from fabricatio.journal import logger
|
13
12
|
from fabricatio.parser import JsonCapture
|
14
13
|
from fabricatio.rust import blake3_hash, detect_language
|
@@ -53,7 +52,7 @@ class Display(Base):
|
|
53
52
|
Returns:
|
54
53
|
str: JSON string with 1-level indentation for readability
|
55
54
|
"""
|
56
|
-
return self.model_dump_json(indent=1,by_alias=True)
|
55
|
+
return self.model_dump_json(indent=1, by_alias=True)
|
57
56
|
|
58
57
|
def compact(self) -> str:
|
59
58
|
"""Generate compact JSON representation.
|
@@ -118,6 +117,15 @@ class WordCount(Base):
|
|
118
117
|
"""Expected word count of this research component."""
|
119
118
|
|
120
119
|
|
120
|
+
class FromMapping(Base):
|
121
|
+
"""Class that provides a method to generate a list of objects from a mapping."""
|
122
|
+
|
123
|
+
@classmethod
|
124
|
+
@abstractmethod
|
125
|
+
def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
|
126
|
+
"""Generate a list of objects from a mapping."""
|
127
|
+
|
128
|
+
|
121
129
|
class AsPrompt(Base):
|
122
130
|
"""Class that provides a method to generate a prompt from the model.
|
123
131
|
|
@@ -171,11 +179,14 @@ class WithRef[T](Base):
|
|
171
179
|
|
172
180
|
@overload
|
173
181
|
def update_ref[S: WithRef](self: S, reference: T) -> S: ...
|
182
|
+
|
174
183
|
@overload
|
175
184
|
def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S: ...
|
185
|
+
|
176
186
|
@overload
|
177
187
|
def update_ref[S: WithRef](self: S, reference: None = None) -> S: ...
|
178
|
-
|
188
|
+
|
189
|
+
def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S:
|
179
190
|
"""Update the reference of the object.
|
180
191
|
|
181
192
|
Args:
|
@@ -190,7 +201,7 @@ class WithRef[T](Base):
|
|
190
201
|
self._reference = reference # pyright: ignore [reportAttributeAccessIssue]
|
191
202
|
return self
|
192
203
|
|
193
|
-
def derive[S: WithRef](self: S, reference: Any) -> S:
|
204
|
+
def derive[S: WithRef](self: S, reference: Any) -> S:
|
194
205
|
"""Derive a new object from the current object.
|
195
206
|
|
196
207
|
Args:
|
@@ -225,7 +236,7 @@ class PersistentAble(Base):
|
|
225
236
|
- Hash generated from JSON content ensures uniqueness
|
226
237
|
"""
|
227
238
|
p = Path(path)
|
228
|
-
out = self.model_dump_json(indent=1,by_alias=True)
|
239
|
+
out = self.model_dump_json(indent=1, by_alias=True)
|
229
240
|
|
230
241
|
# Generate a timestamp in the format YYYYMMDD_HHMMSS
|
231
242
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
@@ -299,16 +310,18 @@ class Language(Base):
|
|
299
310
|
"""Class that provides a language attribute."""
|
300
311
|
|
301
312
|
@property
|
302
|
-
def language(self)->str:
|
313
|
+
def language(self) -> str:
|
303
314
|
"""Get the language of the object."""
|
304
|
-
if isinstance(self,Described):
|
315
|
+
if isinstance(self, Described):
|
305
316
|
return detect_language(self.description)
|
306
|
-
if isinstance(self,Titled):
|
317
|
+
if isinstance(self, Titled):
|
307
318
|
return detect_language(self.title)
|
308
|
-
if isinstance(self,Named):
|
319
|
+
if isinstance(self, Named):
|
309
320
|
return detect_language(self.name)
|
310
321
|
|
311
322
|
return detect_language(self.model_dump_json(by_alias=True))
|
323
|
+
|
324
|
+
|
312
325
|
class ModelHash(Base):
|
313
326
|
"""Class that provides a hash value for the object.
|
314
327
|
|
@@ -454,10 +467,9 @@ class WithFormatedJsonSchema(Base):
|
|
454
467
|
Returns:
|
455
468
|
str: The JSON schema of the model in a formatted string.
|
456
469
|
"""
|
457
|
-
return
|
458
|
-
cls.model_json_schema(schema_generator=UnsortGenerate),
|
459
|
-
|
460
|
-
).decode()
|
470
|
+
return ujson.dumps(
|
471
|
+
cls.model_json_schema(schema_generator=UnsortGenerate), indent=2, ensure_ascii=False, sort_keys=False
|
472
|
+
)
|
461
473
|
|
462
474
|
|
463
475
|
class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
@@ -469,9 +481,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
|
469
481
|
@classmethod
|
470
482
|
@overload
|
471
483
|
def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
|
484
|
+
|
472
485
|
@classmethod
|
473
486
|
@overload
|
474
487
|
def create_json_prompt(cls, requirement: str) -> str: ...
|
488
|
+
|
475
489
|
@classmethod
|
476
490
|
def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
|
477
491
|
"""Create the prompt for creating a JSON object with given requirement.
|
@@ -550,7 +564,7 @@ class FinalizedDumpAble(Base):
|
|
550
564
|
Returns:
|
551
565
|
str: The finalized dump of the object.
|
552
566
|
"""
|
553
|
-
return self.model_dump_json(indent=1,by_alias=True)
|
567
|
+
return self.model_dump_json(indent=1, by_alias=True)
|
554
568
|
|
555
569
|
def finalized_dump_to(self, path: str | Path) -> Self:
|
556
570
|
"""Finalize the dump of the object to a file.
|
@@ -638,6 +652,8 @@ class WithDependency(Base):
|
|
638
652
|
Returns:
|
639
653
|
str: The generated prompt for the task.
|
640
654
|
"""
|
655
|
+
from fabricatio.fs import MAGIKA
|
656
|
+
|
641
657
|
return TEMPLATE_MANAGER.render_template(
|
642
658
|
configs.templates.dependencies_template,
|
643
659
|
{
|
@@ -662,8 +678,9 @@ class Vectorizable(Base):
|
|
662
678
|
This class includes methods to prepare the model for vectorization, ensuring it fits within a specified token length.
|
663
679
|
"""
|
664
680
|
|
681
|
+
@abstractmethod
|
665
682
|
def _prepare_vectorization_inner(self) -> str:
|
666
|
-
|
683
|
+
"""Prepare the model for vectorization."""
|
667
684
|
|
668
685
|
@final
|
669
686
|
def prepare_vectorization(self, max_length: Optional[int] = None) -> str:
|
@@ -681,8 +698,7 @@ class Vectorizable(Base):
|
|
681
698
|
max_length = max_length or configs.embedding.max_sequence_length
|
682
699
|
chunk = self._prepare_vectorization_inner()
|
683
700
|
if max_length and (length := token_counter(text=chunk)) > max_length:
|
684
|
-
|
685
|
-
raise ValueError(err)
|
701
|
+
raise ValueError(f"Chunk exceeds maximum sequence length {max_length}, got {length}, see \n{chunk}")
|
686
702
|
|
687
703
|
return chunk
|
688
704
|
|
@@ -733,6 +749,12 @@ class ScopedConfig(Base):
|
|
733
749
|
llm_rpm: Optional[PositiveInt] = None
|
734
750
|
"""The requests per minute of the LLM model."""
|
735
751
|
|
752
|
+
llm_presence_penalty: Optional[PositiveFloat] = None
|
753
|
+
"""The presence penalty of the LLM model."""
|
754
|
+
|
755
|
+
llm_frequency_penalty: Optional[PositiveFloat] = None
|
756
|
+
"""The frequency penalty of the LLM model."""
|
757
|
+
|
736
758
|
embedding_api_endpoint: Optional[HttpUrl] = None
|
737
759
|
"""The OpenAI API endpoint."""
|
738
760
|
|
@@ -861,10 +883,7 @@ class Patch[T](ProposedAble):
|
|
861
883
|
)
|
862
884
|
my_schema["description"] = ref_cls.__doc__
|
863
885
|
|
864
|
-
return
|
865
|
-
my_schema,
|
866
|
-
option=orjson.OPT_INDENT_2,
|
867
|
-
).decode()
|
886
|
+
return ujson.dumps(my_schema, indent=2, ensure_ascii=False, sort_keys=False)
|
868
887
|
|
869
888
|
|
870
889
|
class SequencePatch[T](ProposedUpdateAble):
|
@@ -1,11 +1,18 @@
|
|
1
1
|
"""This module contains the types for the keyword arguments of the methods in the models module."""
|
2
2
|
|
3
|
-
from typing import Any, Dict, List, Optional, Required, TypedDict
|
3
|
+
from typing import Any, Dict, List, NotRequired, Optional, Required, TypedDict
|
4
4
|
|
5
5
|
from litellm.caching.caching import CacheMode
|
6
6
|
from litellm.types.caching import CachingSupportedCallTypes
|
7
7
|
|
8
8
|
|
9
|
+
class ChunkKwargs(TypedDict):
|
10
|
+
"""Configuration parameters for chunking operations."""
|
11
|
+
|
12
|
+
max_chunk_size: int
|
13
|
+
max_overlapping_rate: NotRequired[float]
|
14
|
+
|
15
|
+
|
9
16
|
class EmbeddingKwargs(TypedDict, total=False):
|
10
17
|
"""Configuration parameters for text embedding operations.
|
11
18
|
|
@@ -26,7 +33,7 @@ class LLMKwargs(TypedDict, total=False):
|
|
26
33
|
including generation parameters and caching options.
|
27
34
|
"""
|
28
35
|
|
29
|
-
model: str
|
36
|
+
model: Optional[str]
|
30
37
|
temperature: float
|
31
38
|
stop: str | list[str]
|
32
39
|
top_p: float
|
@@ -38,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
|
|
38
45
|
no_store: bool # If store the response of this call to cache
|
39
46
|
cache_ttl: int # how long the stored cache is alive, in seconds
|
40
47
|
s_maxage: int # max accepted age of cached response, in seconds
|
48
|
+
presence_penalty: float
|
49
|
+
frequency_penalty: float
|
41
50
|
|
42
51
|
|
43
52
|
class GenerateKwargs(LLMKwargs, total=False):
|
@@ -59,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
|
|
59
68
|
|
60
69
|
default: Optional[T]
|
61
70
|
max_validations: int
|
62
|
-
|
71
|
+
|
63
72
|
|
64
73
|
|
65
74
|
class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
|
fabricatio/models/task.py
CHANGED
@@ -4,7 +4,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
from asyncio import Queue
|
7
|
-
from typing import Any, List, Optional, Self
|
7
|
+
from typing import Any, Dict, List, Optional, Self
|
8
8
|
|
9
9
|
from fabricatio.config import configs
|
10
10
|
from fabricatio.constants import TaskStatus
|
@@ -50,6 +50,18 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
|
|
50
50
|
|
51
51
|
_namespace: Event = PrivateAttr(default_factory=Event)
|
52
52
|
"""The namespace of the task as an event, which is generated from the namespace list."""
|
53
|
+
_extra_init_context: Dict = PrivateAttr(default_factory=dict)
|
54
|
+
"""Extra initialization context for the task, which is designed to override the one of the Workflow."""
|
55
|
+
|
56
|
+
@property
|
57
|
+
def extra_init_context(self) -> Dict:
|
58
|
+
"""Extra initialization context for the task, which is designed to override the one of the Workflow."""
|
59
|
+
return self._extra_init_context
|
60
|
+
|
61
|
+
def update_init_context(self, /, **kwargs) -> Self:
|
62
|
+
"""Update the extra initialization context for the task."""
|
63
|
+
self.extra_init_context.update(kwargs)
|
64
|
+
return self
|
53
65
|
|
54
66
|
def model_post_init(self, __context: Any) -> None:
|
55
67
|
"""Initialize the task with a namespace event."""
|
fabricatio/models/usages.py
CHANGED
@@ -31,7 +31,7 @@ from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
|
|
31
31
|
|
32
32
|
if configs.cache.enabled and configs.cache.type:
|
33
33
|
litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
|
34
|
-
logger.
|
34
|
+
logger.debug(f"{configs.cache.type.name} Cache enabled")
|
35
35
|
|
36
36
|
ROUTER = Router(
|
37
37
|
routing_strategy="usage-based-routing-v2",
|
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
|
|
63
63
|
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
64
64
|
return ROUTER
|
65
65
|
|
66
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
66
|
+
# noinspection PyTypeChecker,PydanticTypeChecker,t
|
67
67
|
async def aquery(
|
68
68
|
self,
|
69
69
|
messages: List[Dict[str, str]],
|
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
|
|
122
122
|
"cache-ttl": kwargs.get("cache_ttl"),
|
123
123
|
"s-maxage": kwargs.get("s_maxage"),
|
124
124
|
},
|
125
|
+
presence_penalty=kwargs.get("presence_penalty")
|
126
|
+
or self.llm_presence_penalty
|
127
|
+
or configs.llm.presence_penalty,
|
128
|
+
frequency_penalty=kwargs.get("frequency_penalty")
|
129
|
+
or self.llm_frequency_penalty
|
130
|
+
or configs.llm.frequency_penalty,
|
125
131
|
)
|
126
132
|
|
127
133
|
async def ainvoke(
|
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
|
|
236
242
|
validator: Callable[[str], T | None],
|
237
243
|
default: T = ...,
|
238
244
|
max_validations: PositiveInt = 2,
|
239
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
240
245
|
**kwargs: Unpack[GenerateKwargs],
|
241
246
|
) -> T: ...
|
242
247
|
@overload
|
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
|
|
246
251
|
validator: Callable[[str], T | None],
|
247
252
|
default: T = ...,
|
248
253
|
max_validations: PositiveInt = 2,
|
249
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
250
254
|
**kwargs: Unpack[GenerateKwargs],
|
251
255
|
) -> List[T]: ...
|
252
256
|
@overload
|
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
|
|
256
260
|
validator: Callable[[str], T | None],
|
257
261
|
default: None = None,
|
258
262
|
max_validations: PositiveInt = 2,
|
259
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
260
263
|
**kwargs: Unpack[GenerateKwargs],
|
261
264
|
) -> Optional[T]: ...
|
262
265
|
|
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
|
|
267
270
|
validator: Callable[[str], T | None],
|
268
271
|
default: None = None,
|
269
272
|
max_validations: PositiveInt = 2,
|
270
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
271
273
|
**kwargs: Unpack[GenerateKwargs],
|
272
274
|
) -> List[Optional[T]]: ...
|
273
275
|
|
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
|
|
277
279
|
validator: Callable[[str], T | None],
|
278
280
|
default: Optional[T] = None,
|
279
281
|
max_validations: PositiveInt = 3,
|
280
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
281
282
|
**kwargs: Unpack[GenerateKwargs],
|
282
283
|
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
283
284
|
"""Asynchronously asks a question and validates the response using a given validator.
|
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
|
|
287
288
|
validator (Callable[[str], T | None]): A function to validate the response.
|
288
289
|
default (T | None): Default value to return if validation fails. Defaults to None.
|
289
290
|
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
|
290
|
-
co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
|
291
291
|
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
292
292
|
|
293
293
|
Returns:
|
294
|
-
Optional[T] | List[
|
294
|
+
Optional[T] | List[T | None] | List[T] | T: The validated response.
|
295
295
|
"""
|
296
296
|
|
297
297
|
async def _inner(q: str) -> Optional[T]:
|
298
298
|
for lap in range(max_validations):
|
299
299
|
try:
|
300
|
-
if (
|
301
|
-
co_extractor is not None
|
302
|
-
and logger.debug("Co-extraction is enabled.") is None
|
303
|
-
and (
|
304
|
-
validated := validator(
|
305
|
-
response := await self.aask(
|
306
|
-
question=(
|
307
|
-
TEMPLATE_MANAGER.render_template(
|
308
|
-
configs.templates.co_validation_template,
|
309
|
-
{"original_q": q, "original_a": response},
|
310
|
-
)
|
311
|
-
),
|
312
|
-
**co_extractor,
|
313
|
-
)
|
314
|
-
)
|
315
|
-
)
|
316
|
-
is not None
|
317
|
-
):
|
300
|
+
if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
|
318
301
|
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
319
302
|
return validated
|
320
303
|
|
fabricatio/parser.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
"""A module to parse text using regular expressions."""
|
2
2
|
|
3
|
+
import re
|
4
|
+
from functools import lru_cache
|
5
|
+
from re import Pattern, compile
|
3
6
|
from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
|
4
7
|
|
5
|
-
import
|
6
|
-
import regex
|
8
|
+
import ujson
|
7
9
|
from json_repair import repair_json
|
8
10
|
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
|
9
|
-
from regex import Pattern, compile
|
10
11
|
|
11
12
|
from fabricatio.config import configs
|
12
13
|
from fabricatio.journal import logger
|
@@ -25,7 +26,7 @@ class Capture(BaseModel):
|
|
25
26
|
"""The target groups to capture from the pattern."""
|
26
27
|
pattern: str = Field(frozen=True)
|
27
28
|
"""The regular expression pattern to search for."""
|
28
|
-
flags: PositiveInt = Field(default=
|
29
|
+
flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
|
29
30
|
"""The flags to use when compiling the regular expression pattern."""
|
30
31
|
capture_type: Optional[str] = None
|
31
32
|
"""The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
|
@@ -49,7 +50,8 @@ class Capture(BaseModel):
|
|
49
50
|
logger.debug("Applying json repair to text.")
|
50
51
|
if isinstance(text, str):
|
51
52
|
return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
|
52
|
-
return [repair_json(item, ensure_ascii=False) for item in
|
53
|
+
return [repair_json(item, ensure_ascii=False) for item in
|
54
|
+
text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
|
53
55
|
case _:
|
54
56
|
return text # pyright: ignore [reportReturnType]
|
55
57
|
|
@@ -63,7 +65,7 @@ class Capture(BaseModel):
|
|
63
65
|
str | None: The captured text if the pattern is found, otherwise None.
|
64
66
|
|
65
67
|
"""
|
66
|
-
if (match :=self._compiled.match(text) or self._compiled.search(text)
|
68
|
+
if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
|
67
69
|
logger.debug(f"Capture Failed {type(text)}: \n{text}")
|
68
70
|
return None
|
69
71
|
groups = self.fix(match.groups())
|
@@ -94,12 +96,12 @@ class Capture(BaseModel):
|
|
94
96
|
return None
|
95
97
|
|
96
98
|
def validate_with[K, T, E](
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
99
|
+
self,
|
100
|
+
text: str,
|
101
|
+
target_type: Type[T],
|
102
|
+
elements_type: Optional[Type[E]] = None,
|
103
|
+
length: Optional[int] = None,
|
104
|
+
deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
|
103
105
|
) -> T | None:
|
104
106
|
"""Validate the given text using the pattern.
|
105
107
|
|
@@ -124,6 +126,7 @@ class Capture(BaseModel):
|
|
124
126
|
return None
|
125
127
|
|
126
128
|
@classmethod
|
129
|
+
@lru_cache(32)
|
127
130
|
def capture_code_block(cls, language: str) -> Self:
|
128
131
|
"""Capture the first occurrence of a code block in the given text.
|
129
132
|
|
@@ -136,6 +139,7 @@ class Capture(BaseModel):
|
|
136
139
|
return cls(pattern=f"```{language}(.*?)```", capture_type=language)
|
137
140
|
|
138
141
|
@classmethod
|
142
|
+
@lru_cache(32)
|
139
143
|
def capture_generic_block(cls, language: str) -> Self:
|
140
144
|
"""Capture the first occurrence of a generic code block in the given text.
|
141
145
|
|
Binary file
|