fabricatio 0.2.10.dev0__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. fabricatio/actions/article.py +55 -10
  2. fabricatio/actions/article_rag.py +297 -12
  3. fabricatio/actions/fs.py +25 -0
  4. fabricatio/actions/output.py +17 -3
  5. fabricatio/actions/rag.py +42 -20
  6. fabricatio/actions/rules.py +14 -3
  7. fabricatio/capabilities/extract.py +70 -0
  8. fabricatio/capabilities/rag.py +5 -2
  9. fabricatio/capabilities/rating.py +5 -2
  10. fabricatio/capabilities/task.py +16 -16
  11. fabricatio/config.py +9 -2
  12. fabricatio/decorators.py +43 -26
  13. fabricatio/fs/__init__.py +9 -2
  14. fabricatio/fs/readers.py +6 -10
  15. fabricatio/models/action.py +16 -11
  16. fabricatio/models/adv_kwargs_types.py +5 -12
  17. fabricatio/models/extra/aricle_rag.py +254 -0
  18. fabricatio/models/extra/article_base.py +56 -7
  19. fabricatio/models/extra/article_essence.py +8 -7
  20. fabricatio/models/extra/article_main.py +102 -6
  21. fabricatio/models/extra/problem.py +5 -1
  22. fabricatio/models/extra/rag.py +49 -23
  23. fabricatio/models/generic.py +43 -24
  24. fabricatio/models/kwargs_types.py +12 -3
  25. fabricatio/models/task.py +13 -1
  26. fabricatio/models/usages.py +10 -27
  27. fabricatio/parser.py +16 -12
  28. fabricatio/rust.cp312-win_amd64.pyd +0 -0
  29. fabricatio/rust.pyi +177 -63
  30. fabricatio/utils.py +50 -10
  31. fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
  32. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -12
  33. fabricatio-0.2.11.dist-info/RECORD +65 -0
  34. fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
  35. fabricatio-0.2.10.dev0.dist-info/RECORD +0 -62
  36. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
  37. {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,13 @@
1
1
  """A module containing the RAG (Retrieval-Augmented Generation) models."""
2
2
 
3
- from abc import ABCMeta, abstractmethod
4
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Self, Sequence
3
+ from abc import ABC
4
+ from functools import partial
5
+ from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Self, Sequence, Set
5
6
 
6
7
  from fabricatio.decorators import precheck_package
7
- from pydantic import BaseModel, ConfigDict, JsonValue
8
+ from fabricatio.models.generic import Vectorizable
9
+ from fabricatio.utils import ok
10
+ from pydantic import JsonValue
8
11
 
9
12
  if TYPE_CHECKING:
10
13
  from importlib.util import find_spec
@@ -15,14 +18,18 @@ if TYPE_CHECKING:
15
18
  from pymilvus import CollectionSchema
16
19
 
17
20
 
18
- class MilvusDataBase(BaseModel, metaclass=ABCMeta):
21
+ class MilvusDataBase(Vectorizable, ABC):
19
22
  """A base class for Milvus data."""
20
23
 
21
- model_config = ConfigDict(use_attribute_docstrings=True)
22
-
23
24
  primary_field_name: ClassVar[str] = "id"
24
-
25
+ """The name of the primary field in Milvus."""
25
26
  vector_field_name: ClassVar[str] = "vector"
27
+ """The name of the vector field in Milvus."""
28
+
29
+ index_type: ClassVar[str] = "FLAT"
30
+ """The type of index to be used in Milvus."""
31
+ metric_type: ClassVar[str] = "COSINE"
32
+ """The type of metric to be used in Milvus."""
26
33
 
27
34
  def prepare_insertion(self, vector: List[float]) -> Dict[str, Any]:
28
35
  """Prepares the data for insertion into Milvus.
@@ -32,11 +39,6 @@ class MilvusDataBase(BaseModel, metaclass=ABCMeta):
32
39
  """
33
40
  return {**self.model_dump(exclude_none=True, by_alias=True), self.vector_field_name: vector}
34
41
 
35
- @property
36
- @abstractmethod
37
- def to_vectorize(self) -> str:
38
- """The text representation of the data."""
39
-
40
42
  @classmethod
41
43
  @precheck_package(
42
44
  "pymilvus", "pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
@@ -50,23 +52,47 @@ class MilvusDataBase(BaseModel, metaclass=ABCMeta):
50
52
  FieldSchema(cls.vector_field_name, dtype=DataType.FLOAT_VECTOR, dim=dimension),
51
53
  ]
52
54
 
53
- type_mapping = {
54
- str: DataType.STRING,
55
- int: DataType.INT64,
56
- float: DataType.DOUBLE,
57
- JsonValue: DataType.JSON,
58
- # TODO add more mapping
59
- }
60
-
61
55
  for k, v in cls.model_fields.items():
62
56
  k: str
63
57
  v: FieldInfo
64
- fields.append(
65
- FieldSchema(k, dtype=type_mapping.get(v.annotation, DataType.UNKNOWN), description=v.description or "")
66
- )
58
+ schema = partial(FieldSchema, k, description=v.description or "")
59
+ anno = ok(v.annotation)
60
+
61
+ if anno == int:
62
+ fields.append(schema(dtype=DataType.INT64))
63
+ elif anno == str:
64
+ fields.append(schema(dtype=DataType.VARCHAR, max_length=65535))
65
+ elif anno == float:
66
+ fields.append(schema(dtype=DataType.DOUBLE))
67
+ elif anno == list[str] or anno == List[str] or anno == set[str] or anno == Set[str]:
68
+ fields.append(
69
+ schema(dtype=DataType.ARRAY, element_type=DataType.VARCHAR, max_length=65535, max_capacity=4096)
70
+ )
71
+ elif anno == list[int] or anno == List[int] or anno == set[int] or anno == Set[int]:
72
+ fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=4096))
73
+ elif anno == list[float] or anno == List[float] or anno == set[float] or anno == Set[float]:
74
+ fields.append(schema(dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=4096))
75
+ elif anno == JsonValue:
76
+ fields.append(schema(dtype=DataType.JSON))
77
+
78
+ else:
79
+ raise NotImplementedError(f"{k}:{anno} is not supported")
80
+
67
81
  return CollectionSchema(fields)
68
82
 
69
83
  @classmethod
70
84
  def from_sequence(cls, data: Sequence[Dict[str, Any]]) -> List[Self]:
71
85
  """Constructs a list of instances from a sequence of dictionaries."""
72
86
  return [cls(**d) for d in data]
87
+
88
+
89
+ class MilvusClassicModel(MilvusDataBase):
90
+ """A class representing a classic model stored in Milvus."""
91
+
92
+ text: str
93
+ """The text to be stored in Milvus."""
94
+ subject: str = ""
95
+ """The subject of the text."""
96
+
97
+ def _prepare_vectorization_inner(self) -> str:
98
+ return self.text
@@ -3,12 +3,11 @@
3
3
  from abc import ABC, abstractmethod
4
4
  from datetime import datetime
5
5
  from pathlib import Path
6
- from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
6
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
7
7
 
8
- import orjson
9
- import rtoml
8
+ import ujson
10
9
  from fabricatio.config import configs
11
- from fabricatio.fs.readers import MAGIKA, safe_text_read
10
+ from fabricatio.fs.readers import safe_text_read
12
11
  from fabricatio.journal import logger
13
12
  from fabricatio.parser import JsonCapture
14
13
  from fabricatio.rust import blake3_hash, detect_language
@@ -53,7 +52,7 @@ class Display(Base):
53
52
  Returns:
54
53
  str: JSON string with 1-level indentation for readability
55
54
  """
56
- return self.model_dump_json(indent=1,by_alias=True)
55
+ return self.model_dump_json(indent=1, by_alias=True)
57
56
 
58
57
  def compact(self) -> str:
59
58
  """Generate compact JSON representation.
@@ -118,6 +117,15 @@ class WordCount(Base):
118
117
  """Expected word count of this research component."""
119
118
 
120
119
 
120
+ class FromMapping(Base):
121
+ """Class that provides a method to generate a list of objects from a mapping."""
122
+
123
+ @classmethod
124
+ @abstractmethod
125
+ def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
126
+ """Generate a list of objects from a mapping."""
127
+
128
+
121
129
  class AsPrompt(Base):
122
130
  """Class that provides a method to generate a prompt from the model.
123
131
 
@@ -171,11 +179,14 @@ class WithRef[T](Base):
171
179
 
172
180
  @overload
173
181
  def update_ref[S: WithRef](self: S, reference: T) -> S: ...
182
+
174
183
  @overload
175
184
  def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S: ...
185
+
176
186
  @overload
177
187
  def update_ref[S: WithRef](self: S, reference: None = None) -> S: ...
178
- def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S: # noqa: PYI019
188
+
189
+ def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S:
179
190
  """Update the reference of the object.
180
191
 
181
192
  Args:
@@ -190,7 +201,7 @@ class WithRef[T](Base):
190
201
  self._reference = reference # pyright: ignore [reportAttributeAccessIssue]
191
202
  return self
192
203
 
193
- def derive[S: WithRef](self: S, reference: Any) -> S: # noqa: PYI019
204
+ def derive[S: WithRef](self: S, reference: Any) -> S:
194
205
  """Derive a new object from the current object.
195
206
 
196
207
  Args:
@@ -225,7 +236,7 @@ class PersistentAble(Base):
225
236
  - Hash generated from JSON content ensures uniqueness
226
237
  """
227
238
  p = Path(path)
228
- out = self.model_dump_json(indent=1,by_alias=True)
239
+ out = self.model_dump_json(indent=1, by_alias=True)
229
240
 
230
241
  # Generate a timestamp in the format YYYYMMDD_HHMMSS
231
242
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -299,16 +310,18 @@ class Language(Base):
299
310
  """Class that provides a language attribute."""
300
311
 
301
312
  @property
302
- def language(self)->str:
313
+ def language(self) -> str:
303
314
  """Get the language of the object."""
304
- if isinstance(self,Described):
315
+ if isinstance(self, Described):
305
316
  return detect_language(self.description)
306
- if isinstance(self,Titled):
317
+ if isinstance(self, Titled):
307
318
  return detect_language(self.title)
308
- if isinstance(self,Named):
319
+ if isinstance(self, Named):
309
320
  return detect_language(self.name)
310
321
 
311
322
  return detect_language(self.model_dump_json(by_alias=True))
323
+
324
+
312
325
  class ModelHash(Base):
313
326
  """Class that provides a hash value for the object.
314
327
 
@@ -454,10 +467,9 @@ class WithFormatedJsonSchema(Base):
454
467
  Returns:
455
468
  str: The JSON schema of the model in a formatted string.
456
469
  """
457
- return orjson.dumps(
458
- cls.model_json_schema(schema_generator=UnsortGenerate),
459
- option=orjson.OPT_INDENT_2,
460
- ).decode()
470
+ return ujson.dumps(
471
+ cls.model_json_schema(schema_generator=UnsortGenerate), indent=2, ensure_ascii=False, sort_keys=False
472
+ )
461
473
 
462
474
 
463
475
  class CreateJsonObjPrompt(WithFormatedJsonSchema):
@@ -469,9 +481,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
469
481
  @classmethod
470
482
  @overload
471
483
  def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
484
+
472
485
  @classmethod
473
486
  @overload
474
487
  def create_json_prompt(cls, requirement: str) -> str: ...
488
+
475
489
  @classmethod
476
490
  def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
477
491
  """Create the prompt for creating a JSON object with given requirement.
@@ -550,7 +564,7 @@ class FinalizedDumpAble(Base):
550
564
  Returns:
551
565
  str: The finalized dump of the object.
552
566
  """
553
- return self.model_dump_json(indent=1,by_alias=True)
567
+ return self.model_dump_json(indent=1, by_alias=True)
554
568
 
555
569
  def finalized_dump_to(self, path: str | Path) -> Self:
556
570
  """Finalize the dump of the object to a file.
@@ -638,6 +652,8 @@ class WithDependency(Base):
638
652
  Returns:
639
653
  str: The generated prompt for the task.
640
654
  """
655
+ from fabricatio.fs import MAGIKA
656
+
641
657
  return TEMPLATE_MANAGER.render_template(
642
658
  configs.templates.dependencies_template,
643
659
  {
@@ -662,8 +678,9 @@ class Vectorizable(Base):
662
678
  This class includes methods to prepare the model for vectorization, ensuring it fits within a specified token length.
663
679
  """
664
680
 
681
+ @abstractmethod
665
682
  def _prepare_vectorization_inner(self) -> str:
666
- return rtoml.dumps(self.model_dump())
683
+ """Prepare the model for vectorization."""
667
684
 
668
685
  @final
669
686
  def prepare_vectorization(self, max_length: Optional[int] = None) -> str:
@@ -681,8 +698,7 @@ class Vectorizable(Base):
681
698
  max_length = max_length or configs.embedding.max_sequence_length
682
699
  chunk = self._prepare_vectorization_inner()
683
700
  if max_length and (length := token_counter(text=chunk)) > max_length:
684
- logger.error(err := f"Chunk exceeds maximum sequence length {max_length}, got {length}, see {chunk}")
685
- raise ValueError(err)
701
+ raise ValueError(f"Chunk exceeds maximum sequence length {max_length}, got {length}, see \n{chunk}")
686
702
 
687
703
  return chunk
688
704
 
@@ -733,6 +749,12 @@ class ScopedConfig(Base):
733
749
  llm_rpm: Optional[PositiveInt] = None
734
750
  """The requests per minute of the LLM model."""
735
751
 
752
+ llm_presence_penalty: Optional[PositiveFloat] = None
753
+ """The presence penalty of the LLM model."""
754
+
755
+ llm_frequency_penalty: Optional[PositiveFloat] = None
756
+ """The frequency penalty of the LLM model."""
757
+
736
758
  embedding_api_endpoint: Optional[HttpUrl] = None
737
759
  """The OpenAI API endpoint."""
738
760
 
@@ -861,10 +883,7 @@ class Patch[T](ProposedAble):
861
883
  )
862
884
  my_schema["description"] = ref_cls.__doc__
863
885
 
864
- return orjson.dumps(
865
- my_schema,
866
- option=orjson.OPT_INDENT_2,
867
- ).decode()
886
+ return ujson.dumps(my_schema, indent=2, ensure_ascii=False, sort_keys=False)
868
887
 
869
888
 
870
889
  class SequencePatch[T](ProposedUpdateAble):
@@ -1,11 +1,18 @@
1
1
  """This module contains the types for the keyword arguments of the methods in the models module."""
2
2
 
3
- from typing import Any, Dict, List, Optional, Required, TypedDict
3
+ from typing import Any, Dict, List, NotRequired, Optional, Required, TypedDict
4
4
 
5
5
  from litellm.caching.caching import CacheMode
6
6
  from litellm.types.caching import CachingSupportedCallTypes
7
7
 
8
8
 
9
+ class ChunkKwargs(TypedDict):
10
+ """Configuration parameters for chunking operations."""
11
+
12
+ max_chunk_size: int
13
+ max_overlapping_rate: NotRequired[float]
14
+
15
+
9
16
  class EmbeddingKwargs(TypedDict, total=False):
10
17
  """Configuration parameters for text embedding operations.
11
18
 
@@ -26,7 +33,7 @@ class LLMKwargs(TypedDict, total=False):
26
33
  including generation parameters and caching options.
27
34
  """
28
35
 
29
- model: str
36
+ model: Optional[str]
30
37
  temperature: float
31
38
  stop: str | list[str]
32
39
  top_p: float
@@ -38,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
38
45
  no_store: bool # If store the response of this call to cache
39
46
  cache_ttl: int # how long the stored cache is alive, in seconds
40
47
  s_maxage: int # max accepted age of cached response, in seconds
48
+ presence_penalty: float
49
+ frequency_penalty: float
41
50
 
42
51
 
43
52
  class GenerateKwargs(LLMKwargs, total=False):
@@ -59,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
59
68
 
60
69
  default: Optional[T]
61
70
  max_validations: int
62
- co_extractor: GenerateKwargs
71
+
63
72
 
64
73
 
65
74
  class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
fabricatio/models/task.py CHANGED
@@ -4,7 +4,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
4
4
  """
5
5
 
6
6
  from asyncio import Queue
7
- from typing import Any, List, Optional, Self
7
+ from typing import Any, Dict, List, Optional, Self
8
8
 
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.constants import TaskStatus
@@ -50,6 +50,18 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
50
50
 
51
51
  _namespace: Event = PrivateAttr(default_factory=Event)
52
52
  """The namespace of the task as an event, which is generated from the namespace list."""
53
+ _extra_init_context: Dict = PrivateAttr(default_factory=dict)
54
+ """Extra initialization context for the task, which is designed to override the one of the Workflow."""
55
+
56
+ @property
57
+ def extra_init_context(self) -> Dict:
58
+ """Extra initialization context for the task, which is designed to override the one of the Workflow."""
59
+ return self._extra_init_context
60
+
61
+ def update_init_context(self, /, **kwargs) -> Self:
62
+ """Update the extra initialization context for the task."""
63
+ self.extra_init_context.update(kwargs)
64
+ return self
53
65
 
54
66
  def model_post_init(self, __context: Any) -> None:
55
67
  """Initialize the task with a namespace event."""
@@ -31,7 +31,7 @@ from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
31
31
 
32
32
  if configs.cache.enabled and configs.cache.type:
33
33
  litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
34
- logger.success(f"{configs.cache.type.name} Cache enabled")
34
+ logger.debug(f"{configs.cache.type.name} Cache enabled")
35
35
 
36
36
  ROUTER = Router(
37
37
  routing_strategy="usage-based-routing-v2",
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
63
63
  self._added_deployment = ROUTER.upsert_deployment(deployment)
64
64
  return ROUTER
65
65
 
66
- # noinspection PyTypeChecker,PydanticTypeChecker
66
+ # noinspection PyTypeChecker,PydanticTypeChecker,t
67
67
  async def aquery(
68
68
  self,
69
69
  messages: List[Dict[str, str]],
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
122
122
  "cache-ttl": kwargs.get("cache_ttl"),
123
123
  "s-maxage": kwargs.get("s_maxage"),
124
124
  },
125
+ presence_penalty=kwargs.get("presence_penalty")
126
+ or self.llm_presence_penalty
127
+ or configs.llm.presence_penalty,
128
+ frequency_penalty=kwargs.get("frequency_penalty")
129
+ or self.llm_frequency_penalty
130
+ or configs.llm.frequency_penalty,
125
131
  )
126
132
 
127
133
  async def ainvoke(
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
236
242
  validator: Callable[[str], T | None],
237
243
  default: T = ...,
238
244
  max_validations: PositiveInt = 2,
239
- co_extractor: Optional[GenerateKwargs] = None,
240
245
  **kwargs: Unpack[GenerateKwargs],
241
246
  ) -> T: ...
242
247
  @overload
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
246
251
  validator: Callable[[str], T | None],
247
252
  default: T = ...,
248
253
  max_validations: PositiveInt = 2,
249
- co_extractor: Optional[GenerateKwargs] = None,
250
254
  **kwargs: Unpack[GenerateKwargs],
251
255
  ) -> List[T]: ...
252
256
  @overload
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
256
260
  validator: Callable[[str], T | None],
257
261
  default: None = None,
258
262
  max_validations: PositiveInt = 2,
259
- co_extractor: Optional[GenerateKwargs] = None,
260
263
  **kwargs: Unpack[GenerateKwargs],
261
264
  ) -> Optional[T]: ...
262
265
 
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
267
270
  validator: Callable[[str], T | None],
268
271
  default: None = None,
269
272
  max_validations: PositiveInt = 2,
270
- co_extractor: Optional[GenerateKwargs] = None,
271
273
  **kwargs: Unpack[GenerateKwargs],
272
274
  ) -> List[Optional[T]]: ...
273
275
 
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
277
279
  validator: Callable[[str], T | None],
278
280
  default: Optional[T] = None,
279
281
  max_validations: PositiveInt = 3,
280
- co_extractor: Optional[GenerateKwargs] = None,
281
282
  **kwargs: Unpack[GenerateKwargs],
282
283
  ) -> Optional[T] | List[Optional[T]] | List[T] | T:
283
284
  """Asynchronously asks a question and validates the response using a given validator.
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
287
288
  validator (Callable[[str], T | None]): A function to validate the response.
288
289
  default (T | None): Default value to return if validation fails. Defaults to None.
289
290
  max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
290
- co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
291
291
  **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
292
292
 
293
293
  Returns:
294
- Optional[T] | List[Optional[T]] | List[T] | T: The validated response.
294
+ Optional[T] | List[T | None] | List[T] | T: The validated response.
295
295
  """
296
296
 
297
297
  async def _inner(q: str) -> Optional[T]:
298
298
  for lap in range(max_validations):
299
299
  try:
300
- if ((validated := validator(response := await self.aask(question=q, **kwargs))) is not None) or (
301
- co_extractor is not None
302
- and logger.debug("Co-extraction is enabled.") is None
303
- and (
304
- validated := validator(
305
- response := await self.aask(
306
- question=(
307
- TEMPLATE_MANAGER.render_template(
308
- configs.templates.co_validation_template,
309
- {"original_q": q, "original_a": response},
310
- )
311
- ),
312
- **co_extractor,
313
- )
314
- )
315
- )
316
- is not None
317
- ):
300
+ if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
318
301
  logger.debug(f"Successfully validated the response at {lap}th attempt.")
319
302
  return validated
320
303
 
fabricatio/parser.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """A module to parse text using regular expressions."""
2
2
 
3
+ import re
4
+ from functools import lru_cache
5
+ from re import Pattern, compile
3
6
  from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
4
7
 
5
- import orjson
6
- import regex
8
+ import ujson
7
9
  from json_repair import repair_json
8
10
  from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
9
- from regex import Pattern, compile
10
11
 
11
12
  from fabricatio.config import configs
12
13
  from fabricatio.journal import logger
@@ -25,7 +26,7 @@ class Capture(BaseModel):
25
26
  """The target groups to capture from the pattern."""
26
27
  pattern: str = Field(frozen=True)
27
28
  """The regular expression pattern to search for."""
28
- flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
29
+ flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
29
30
  """The flags to use when compiling the regular expression pattern."""
30
31
  capture_type: Optional[str] = None
31
32
  """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
@@ -49,7 +50,8 @@ class Capture(BaseModel):
49
50
  logger.debug("Applying json repair to text.")
50
51
  if isinstance(text, str):
51
52
  return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
52
- return [repair_json(item, ensure_ascii=False) for item in text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
53
+ return [repair_json(item, ensure_ascii=False) for item in
54
+ text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
53
55
  case _:
54
56
  return text # pyright: ignore [reportReturnType]
55
57
 
@@ -63,7 +65,7 @@ class Capture(BaseModel):
63
65
  str | None: The captured text if the pattern is found, otherwise None.
64
66
 
65
67
  """
66
- if (match :=self._compiled.match(text) or self._compiled.search(text) ) is None:
68
+ if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
67
69
  logger.debug(f"Capture Failed {type(text)}: \n{text}")
68
70
  return None
69
71
  groups = self.fix(match.groups())
@@ -94,12 +96,12 @@ class Capture(BaseModel):
94
96
  return None
95
97
 
96
98
  def validate_with[K, T, E](
97
- self,
98
- text: str,
99
- target_type: Type[T],
100
- elements_type: Optional[Type[E]] = None,
101
- length: Optional[int] = None,
102
- deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = orjson.loads,
99
+ self,
100
+ text: str,
101
+ target_type: Type[T],
102
+ elements_type: Optional[Type[E]] = None,
103
+ length: Optional[int] = None,
104
+ deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
103
105
  ) -> T | None:
104
106
  """Validate the given text using the pattern.
105
107
 
@@ -124,6 +126,7 @@ class Capture(BaseModel):
124
126
  return None
125
127
 
126
128
  @classmethod
129
+ @lru_cache(32)
127
130
  def capture_code_block(cls, language: str) -> Self:
128
131
  """Capture the first occurrence of a code block in the given text.
129
132
 
@@ -136,6 +139,7 @@ class Capture(BaseModel):
136
139
  return cls(pattern=f"```{language}(.*?)```", capture_type=language)
137
140
 
138
141
  @classmethod
142
+ @lru_cache(32)
139
143
  def capture_generic_block(cls, language: str) -> Self:
140
144
  """Capture the first occurrence of a generic code block in the given text.
141
145
 
Binary file