fabricatio 0.2.10.dev1__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,7 @@ from fabricatio.models.generic import (
18
18
  Titled,
19
19
  WordCount,
20
20
  )
21
+ from fabricatio.rust import comment
21
22
  from pydantic import Field
22
23
 
23
24
 
@@ -29,11 +30,9 @@ class ReferringType(StrEnum):
29
30
  SUBSECTION = "subsection"
30
31
 
31
32
 
32
-
33
33
  type RefKey = Tuple[str, Optional[str], Optional[str]]
34
34
 
35
35
 
36
-
37
36
  class ArticleMetaData(SketchedAble, Described, WordCount, Titled, Language):
38
37
  """Metadata for an article component."""
39
38
 
@@ -47,7 +46,16 @@ class ArticleMetaData(SketchedAble, Described, WordCount, Titled, Language):
47
46
  aims: List[str]
48
47
  """List of writing aims of the research component in academic style."""
49
48
 
50
-
49
+ @property
50
+ def typst_metadata_comment(self) -> str:
51
+ """Generates a comment for the metadata of the article component."""
52
+ return (
53
+ (comment(f"Desc:\n{self.description}\n") if self.description else "")
54
+ + (comment(f"Aims:\n{'\n '.join(self.aims)}\n") if self.aims else "")
55
+ + (comment(f"Expected Word Count:{self.expected_word_count}") if self.expected_word_count else "")
56
+ if self.expected_word_count
57
+ else ""
58
+ )
51
59
 
52
60
 
53
61
  class ArticleOutlineBase(
@@ -92,7 +100,7 @@ class SubSectionBase(ArticleOutlineBase):
92
100
 
93
101
  def to_typst_code(self) -> str:
94
102
  """Converts the component into a Typst code snippet for rendering."""
95
- return f"=== {self.title}\n"
103
+ return f"=== {self.title}\n{self.typst_metadata_comment}\n"
96
104
 
97
105
  def introspect(self) -> str:
98
106
  """Introspects the article subsection outline."""
@@ -117,7 +125,9 @@ class SectionBase[T: SubSectionBase](ArticleOutlineBase):
117
125
  Returns:
118
126
  str: The formatted Typst code snippet.
119
127
  """
120
- return f"== {self.title}\n" + "\n\n".join(subsec.to_typst_code() for subsec in self.subsections)
128
+ return f"== {self.title}\n{self.typst_metadata_comment}\n" + "\n\n".join(
129
+ subsec.to_typst_code() for subsec in self.subsections
130
+ )
121
131
 
122
132
  def resolve_update_conflict(self, other: Self) -> str:
123
133
  """Resolve update errors in the article outline."""
@@ -160,7 +170,9 @@ class ChapterBase[T: SectionBase](ArticleOutlineBase):
160
170
 
161
171
  def to_typst_code(self) -> str:
162
172
  """Converts the chapter into a Typst formatted code snippet for rendering."""
163
- return f"= {self.title}\n" + "\n\n".join(sec.to_typst_code() for sec in self.sections)
173
+ return f"= {self.title}\n{self.typst_metadata_comment}\n" + "\n\n".join(
174
+ sec.to_typst_code() for sec in self.sections
175
+ )
164
176
 
165
177
  def resolve_update_conflict(self, other: Self) -> str:
166
178
  """Resolve update errors in the article outline."""
@@ -302,4 +314,41 @@ class ArticleBase[T: ChapterBase](FinalizedDumpAble, AsPrompt, WordCount, Descri
302
314
  === Implementation Details
303
315
  == Evaluation Protocol
304
316
  """
305
- return "\n\n".join(a.to_typst_code() for a in self.chapters)
317
+ return (
318
+ comment(
319
+ f"Title:{self.title}\n"
320
+ + (f"Desc:\n{self.description}\n" if self.description else "")
321
+ + f"Word Count:{self.expected_word_count}"
322
+ if self.expected_word_count
323
+ else ""
324
+ )
325
+ + "\n\n"
326
+ + "\n\n".join(a.to_typst_code() for a in self.chapters)
327
+ )
328
+
329
+ def avg_chap_wordcount[S](self:S) -> S:
330
+ """Set all chap have same word count sum up to be `self.expected_word_count`."""
331
+ avg = int(self.expected_word_count / len(self.chapters))
332
+ for c in self.chapters:
333
+ c.expected_word_count = avg
334
+ return self
335
+
336
+ def avg_sec_wordcount[S](self:S) -> S:
337
+ """Set all sec have same word count sum up to be `self.expected_word_count`."""
338
+ for c in self.chapters:
339
+ avg = int(c.expected_word_count / len(c.sections))
340
+ for s in c.sections:
341
+ s.expected_word_count = avg
342
+ return self
343
+
344
+ def avg_subsec_wordcount[S](self:S) -> S:
345
+ """Set all subsec have same word count sum up to be `self.expected_word_count`."""
346
+ for _, s in self.iter_sections():
347
+ avg = int(s.expected_word_count / len(s.subsections))
348
+ for ss in s.subsections:
349
+ ss.expected_word_count = avg
350
+ return self
351
+
352
+ def avg_wordcount_recursive(self) -> Self:
353
+ """Set all chap, sec, subsec have same word count sum up to be `self.expected_word_count`."""
354
+ return self.avg_chap_wordcount().avg_sec_wordcount().avg_sec_wordcount()
@@ -2,6 +2,7 @@
2
2
 
3
3
  from typing import Dict, Generator, List, Self, Tuple, override
4
4
 
5
+ from fabricatio.decorators import precheck_package
5
6
  from fabricatio.fs.readers import extract_sections
6
7
  from fabricatio.journal import logger
7
8
  from fabricatio.models.extra.article_base import (
@@ -11,11 +12,15 @@ from fabricatio.models.extra.article_base import (
11
12
  SubSectionBase,
12
13
  )
13
14
  from fabricatio.models.extra.article_outline import (
15
+ ArticleChapterOutline,
14
16
  ArticleOutline,
17
+ ArticleSectionOutline,
18
+ ArticleSubsectionOutline,
15
19
  )
16
20
  from fabricatio.models.generic import Described, PersistentAble, SequencePatch, SketchedAble, WithRef, WordCount
17
- from fabricatio.rust import word_count
18
- from pydantic import Field
21
+ from fabricatio.rust import convert_all_block_tex, convert_all_inline_tex, word_count
22
+ from fabricatio.utils import fallback_kwargs
23
+ from pydantic import Field, NonNegativeInt
19
24
 
20
25
  PARAGRAPH_SEP = "// - - -"
21
26
 
@@ -23,6 +28,9 @@ PARAGRAPH_SEP = "// - - -"
23
28
  class Paragraph(SketchedAble, WordCount, Described):
24
29
  """Structured academic paragraph blueprint for controlled content generation."""
25
30
 
31
+ expected_word_count: NonNegativeInt = 0
32
+ """The expected word count of this paragraph, 0 means not specified"""
33
+
26
34
  description: str = Field(
27
35
  alias="elaboration",
28
36
  description=Described.model_fields["description"].description,
@@ -85,7 +93,7 @@ class ArticleSubsection(SubSectionBase):
85
93
  Returns:
86
94
  str: Typst code snippet for rendering.
87
95
  """
88
- return f"=== {self.title}\n" + f"\n{PARAGRAPH_SEP}\n".join(p.content for p in self.paragraphs)
96
+ return super().to_typst_code() + f"\n\n{PARAGRAPH_SEP}\n\n".join(p.content for p in self.paragraphs)
89
97
 
90
98
  @classmethod
91
99
  def from_typst_code(cls, title: str, body: str) -> Self:
@@ -153,10 +161,74 @@ class Article(
153
161
  "Original Article": self.display(),
154
162
  }
155
163
 
164
+ def convert_tex(self) -> Self:
165
+ """Convert tex to typst code."""
166
+ for _, _, subsec in self.iter_subsections():
167
+ for p in subsec.paragraphs:
168
+ p.content = convert_all_inline_tex(p.content)
169
+ p.content = convert_all_block_tex(p.content)
170
+ return self
171
+
172
+ def fix_wrapper(self) -> Self:
173
+ """Fix wrapper."""
174
+ for _, _, subsec in self.iter_subsections():
175
+ for p in subsec.paragraphs:
176
+ p.content = (
177
+ p.content.replace(r" \( ", "$")
178
+ .replace(r" \) ", "$")
179
+ .replace("\\[\n", "$$\n")
180
+ .replace("\n\\]", "\n$$")
181
+ )
182
+ return self
183
+
156
184
  @override
157
185
  def iter_subsections(self) -> Generator[Tuple[ArticleChapter, ArticleSection, ArticleSubsection], None, None]:
158
186
  return super().iter_subsections() # pyright: ignore [reportReturnType]
159
187
 
188
+ def extrac_outline(self) -> ArticleOutline:
189
+ """Extract outline from article."""
190
+ # Create an empty list to hold chapter outlines
191
+ chapters = []
192
+
193
+ # Iterate through each chapter in the article
194
+ for chapter in self.chapters:
195
+ # Create an empty list to hold section outlines
196
+ sections = []
197
+
198
+ # Iterate through each section in the chapter
199
+ for section in chapter.sections:
200
+ # Create an empty list to hold subsection outlines
201
+ subsections = []
202
+
203
+ # Iterate through each subsection in the section
204
+ for subsection in section.subsections:
205
+ # Create a subsection outline and add it to the list
206
+ subsections.append(
207
+ ArticleSubsectionOutline(**subsection.model_dump(exclude={"paragraphs"}, by_alias=True))
208
+ )
209
+
210
+ # Create a section outline and add it to the list
211
+ sections.append(
212
+ ArticleSectionOutline(
213
+ **section.model_dump(exclude={"subsections"}, by_alias=True),
214
+ subsections=subsections,
215
+ )
216
+ )
217
+
218
+ # Create a chapter outline and add it to the list
219
+ chapters.append(
220
+ ArticleChapterOutline(
221
+ **chapter.model_dump(exclude={"sections"}, by_alias=True),
222
+ sections=sections,
223
+ )
224
+ )
225
+
226
+ # Create and return the article outline
227
+ return ArticleOutline(
228
+ **self.model_dump(exclude={"chapters"}, by_alias=True),
229
+ chapters=chapters,
230
+ )
231
+
160
232
  @classmethod
161
233
  def from_outline(cls, outline: ArticleOutline) -> "Article":
162
234
  """Generates an article from the given outline.
@@ -194,13 +266,37 @@ class Article(
194
266
  return article
195
267
 
196
268
  @classmethod
197
- def from_typst_code(cls, title: str, body: str) -> Self:
269
+ def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
198
270
  """Generates an article from the given Typst code."""
199
271
  return cls(
200
272
  chapters=[
201
273
  ArticleChapter.from_typst_code(*pack) for pack in extract_sections(body, level=1, section_char="=")
202
274
  ],
203
275
  heading=title,
204
- expected_word_count=word_count(body),
205
- abstract="",
276
+ **fallback_kwargs(
277
+ kwargs,
278
+ expected_word_count=word_count(body),
279
+ abstract="",
280
+ ),
206
281
  )
282
+
283
+ @classmethod
284
+ def from_mixed_source(cls, article_outline: ArticleOutline, typst_code: str) -> Self:
285
+ """Generates an article from the given outline and Typst code."""
286
+ self = cls.from_typst_code(article_outline.title, typst_code)
287
+ self.expected_word_count = article_outline.expected_word_count
288
+ self.description = article_outline.description
289
+ for a, o in zip(self.iter_dfs(), article_outline.iter_dfs(), strict=True):
290
+ a.update_metadata(o)
291
+ return self.update_ref(article_outline)
292
+
293
+ @precheck_package(
294
+ "questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
295
+ )
296
+ async def edit_titles(self) -> Self:
297
+ """Edits the titles of the article."""
298
+ from questionary import text
299
+
300
+ for a in self.iter_dfs():
301
+ a.title = await text(f"Edit `{a.title}`.", default=a.title).ask_async() or a.title
302
+ return self
@@ -7,7 +7,6 @@ from fabricatio.journal import logger
7
7
  from fabricatio.models.generic import SketchedAble, WithBriefing
8
8
  from fabricatio.utils import ask_edit
9
9
  from pydantic import Field
10
- from questionary import Choice, checkbox, text
11
10
  from rich import print as r_print
12
11
 
13
12
 
@@ -74,6 +73,9 @@ class ProblemSolutions(SketchedAble):
74
73
  return len(self.solutions) > 0
75
74
 
76
75
  async def edit_problem(self) -> Self:
76
+ """Interactively edit the problem description."""
77
+ from questionary import text
78
+
77
79
  """Interactively edit the problem description."""
78
80
  self.problem = Problem.model_validate_strings(
79
81
  await text("Please edit the problem below:", default=self.problem.display()).ask_async()
@@ -127,6 +129,8 @@ class Improvement(SketchedAble):
127
129
  Returns:
128
130
  Self: The current instance with filtered problems and solutions.
129
131
  """
132
+ from questionary import Choice, checkbox
133
+
130
134
  # Choose the problems to retain
131
135
  chosen_ones: List[ProblemSolutions] = await checkbox(
132
136
  "Please choose the problems you want to retain.(Default: retain all)",
@@ -3,11 +3,11 @@
3
3
  from abc import ABC, abstractmethod
4
4
  from datetime import datetime
5
5
  from pathlib import Path
6
- from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
6
+ from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
7
7
 
8
- import orjson
8
+ import ujson
9
9
  from fabricatio.config import configs
10
- from fabricatio.fs.readers import MAGIKA, safe_text_read
10
+ from fabricatio.fs.readers import safe_text_read
11
11
  from fabricatio.journal import logger
12
12
  from fabricatio.parser import JsonCapture
13
13
  from fabricatio.rust import blake3_hash, detect_language
@@ -117,6 +117,15 @@ class WordCount(Base):
117
117
  """Expected word count of this research component."""
118
118
 
119
119
 
120
+ class FromMapping(Base):
121
+ """Class that provides a method to generate a list of objects from a mapping."""
122
+
123
+ @classmethod
124
+ @abstractmethod
125
+ def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
126
+ """Generate a list of objects from a mapping."""
127
+
128
+
120
129
  class AsPrompt(Base):
121
130
  """Class that provides a method to generate a prompt from the model.
122
131
 
@@ -170,11 +179,14 @@ class WithRef[T](Base):
170
179
 
171
180
  @overload
172
181
  def update_ref[S: WithRef](self: S, reference: T) -> S: ...
182
+
173
183
  @overload
174
184
  def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S: ...
185
+
175
186
  @overload
176
187
  def update_ref[S: WithRef](self: S, reference: None = None) -> S: ...
177
- def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S: # noqa: PYI019
188
+
189
+ def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S:
178
190
  """Update the reference of the object.
179
191
 
180
192
  Args:
@@ -189,7 +201,7 @@ class WithRef[T](Base):
189
201
  self._reference = reference # pyright: ignore [reportAttributeAccessIssue]
190
202
  return self
191
203
 
192
- def derive[S: WithRef](self: S, reference: Any) -> S: # noqa: PYI019
204
+ def derive[S: WithRef](self: S, reference: Any) -> S:
193
205
  """Derive a new object from the current object.
194
206
 
195
207
  Args:
@@ -455,10 +467,9 @@ class WithFormatedJsonSchema(Base):
455
467
  Returns:
456
468
  str: The JSON schema of the model in a formatted string.
457
469
  """
458
- return orjson.dumps(
459
- cls.model_json_schema(schema_generator=UnsortGenerate),
460
- option=orjson.OPT_INDENT_2,
461
- ).decode()
470
+ return ujson.dumps(
471
+ cls.model_json_schema(schema_generator=UnsortGenerate), indent=2, ensure_ascii=False, sort_keys=False
472
+ )
462
473
 
463
474
 
464
475
  class CreateJsonObjPrompt(WithFormatedJsonSchema):
@@ -470,9 +481,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
470
481
  @classmethod
471
482
  @overload
472
483
  def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
484
+
473
485
  @classmethod
474
486
  @overload
475
487
  def create_json_prompt(cls, requirement: str) -> str: ...
488
+
476
489
  @classmethod
477
490
  def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
478
491
  """Create the prompt for creating a JSON object with given requirement.
@@ -639,6 +652,8 @@ class WithDependency(Base):
639
652
  Returns:
640
653
  str: The generated prompt for the task.
641
654
  """
655
+ from fabricatio.fs import MAGIKA
656
+
642
657
  return TEMPLATE_MANAGER.render_template(
643
658
  configs.templates.dependencies_template,
644
659
  {
@@ -734,6 +749,12 @@ class ScopedConfig(Base):
734
749
  llm_rpm: Optional[PositiveInt] = None
735
750
  """The requests per minute of the LLM model."""
736
751
 
752
+ llm_presence_penalty: Optional[PositiveFloat] = None
753
+ """The presence penalty of the LLM model."""
754
+
755
+ llm_frequency_penalty: Optional[PositiveFloat] = None
756
+ """The frequency penalty of the LLM model."""
757
+
737
758
  embedding_api_endpoint: Optional[HttpUrl] = None
738
759
  """The OpenAI API endpoint."""
739
760
 
@@ -862,10 +883,7 @@ class Patch[T](ProposedAble):
862
883
  )
863
884
  my_schema["description"] = ref_cls.__doc__
864
885
 
865
- return orjson.dumps(
866
- my_schema,
867
- option=orjson.OPT_INDENT_2,
868
- ).decode()
886
+ return ujson.dumps(my_schema, indent=2, ensure_ascii=False, sort_keys=False)
869
887
 
870
888
 
871
889
  class SequencePatch[T](ProposedUpdateAble):
@@ -33,7 +33,7 @@ class LLMKwargs(TypedDict, total=False):
33
33
  including generation parameters and caching options.
34
34
  """
35
35
 
36
- model: str
36
+ model: Optional[str]
37
37
  temperature: float
38
38
  stop: str | list[str]
39
39
  top_p: float
@@ -45,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
45
45
  no_store: bool # If store the response of this call to cache
46
46
  cache_ttl: int # how long the stored cache is alive, in seconds
47
47
  s_maxage: int # max accepted age of cached response, in seconds
48
+ presence_penalty: float
49
+ frequency_penalty: float
48
50
 
49
51
 
50
52
  class GenerateKwargs(LLMKwargs, total=False):
@@ -66,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
66
68
 
67
69
  default: Optional[T]
68
70
  max_validations: int
69
- co_extractor: GenerateKwargs
71
+
70
72
 
71
73
 
72
74
  class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
fabricatio/models/task.py CHANGED
@@ -4,7 +4,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
4
4
  """
5
5
 
6
6
  from asyncio import Queue
7
- from typing import Any, List, Optional, Self
7
+ from typing import Any, Dict, List, Optional, Self
8
8
 
9
9
  from fabricatio.config import configs
10
10
  from fabricatio.constants import TaskStatus
@@ -50,6 +50,18 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
50
50
 
51
51
  _namespace: Event = PrivateAttr(default_factory=Event)
52
52
  """The namespace of the task as an event, which is generated from the namespace list."""
53
+ _extra_init_context: Dict = PrivateAttr(default_factory=dict)
54
+ """Extra initialization context for the task, which is designed to override the one of the Workflow."""
55
+
56
+ @property
57
+ def extra_init_context(self) -> Dict:
58
+ """Extra initialization context for the task, which is designed to override the one of the Workflow."""
59
+ return self._extra_init_context
60
+
61
+ def update_init_context(self, /, **kwargs) -> Self:
62
+ """Update the extra initialization context for the task."""
63
+ self.extra_init_context.update(kwargs)
64
+ return self
53
65
 
54
66
  def model_post_init(self, __context: Any) -> None:
55
67
  """Initialize the task with a namespace event."""
@@ -31,7 +31,7 @@ from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
31
31
 
32
32
  if configs.cache.enabled and configs.cache.type:
33
33
  litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
34
- logger.success(f"{configs.cache.type.name} Cache enabled")
34
+ logger.debug(f"{configs.cache.type.name} Cache enabled")
35
35
 
36
36
  ROUTER = Router(
37
37
  routing_strategy="usage-based-routing-v2",
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
63
63
  self._added_deployment = ROUTER.upsert_deployment(deployment)
64
64
  return ROUTER
65
65
 
66
- # noinspection PyTypeChecker,PydanticTypeChecker
66
+ # noinspection PyTypeChecker,PydanticTypeChecker,t
67
67
  async def aquery(
68
68
  self,
69
69
  messages: List[Dict[str, str]],
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
122
122
  "cache-ttl": kwargs.get("cache_ttl"),
123
123
  "s-maxage": kwargs.get("s_maxage"),
124
124
  },
125
+ presence_penalty=kwargs.get("presence_penalty")
126
+ or self.llm_presence_penalty
127
+ or configs.llm.presence_penalty,
128
+ frequency_penalty=kwargs.get("frequency_penalty")
129
+ or self.llm_frequency_penalty
130
+ or configs.llm.frequency_penalty,
125
131
  )
126
132
 
127
133
  async def ainvoke(
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
236
242
  validator: Callable[[str], T | None],
237
243
  default: T = ...,
238
244
  max_validations: PositiveInt = 2,
239
- co_extractor: Optional[GenerateKwargs] = None,
240
245
  **kwargs: Unpack[GenerateKwargs],
241
246
  ) -> T: ...
242
247
  @overload
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
246
251
  validator: Callable[[str], T | None],
247
252
  default: T = ...,
248
253
  max_validations: PositiveInt = 2,
249
- co_extractor: Optional[GenerateKwargs] = None,
250
254
  **kwargs: Unpack[GenerateKwargs],
251
255
  ) -> List[T]: ...
252
256
  @overload
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
256
260
  validator: Callable[[str], T | None],
257
261
  default: None = None,
258
262
  max_validations: PositiveInt = 2,
259
- co_extractor: Optional[GenerateKwargs] = None,
260
263
  **kwargs: Unpack[GenerateKwargs],
261
264
  ) -> Optional[T]: ...
262
265
 
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
267
270
  validator: Callable[[str], T | None],
268
271
  default: None = None,
269
272
  max_validations: PositiveInt = 2,
270
- co_extractor: Optional[GenerateKwargs] = None,
271
273
  **kwargs: Unpack[GenerateKwargs],
272
274
  ) -> List[Optional[T]]: ...
273
275
 
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
277
279
  validator: Callable[[str], T | None],
278
280
  default: Optional[T] = None,
279
281
  max_validations: PositiveInt = 3,
280
- co_extractor: Optional[GenerateKwargs] = None,
281
282
  **kwargs: Unpack[GenerateKwargs],
282
283
  ) -> Optional[T] | List[Optional[T]] | List[T] | T:
283
284
  """Asynchronously asks a question and validates the response using a given validator.
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
287
288
  validator (Callable[[str], T | None]): A function to validate the response.
288
289
  default (T | None): Default value to return if validation fails. Defaults to None.
289
290
  max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
290
- co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
291
291
  **kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
292
292
 
293
293
  Returns:
294
- Optional[T] | List[Optional[T]] | List[T] | T: The validated response.
294
+ Optional[T] | List[T | None] | List[T] | T: The validated response.
295
295
  """
296
296
 
297
297
  async def _inner(q: str) -> Optional[T]:
298
298
  for lap in range(max_validations):
299
299
  try:
300
- if ((validated := validator(response := await self.aask(question=q, **kwargs))) is not None) or (
301
- co_extractor is not None
302
- and logger.debug("Co-extraction is enabled.") is None
303
- and (
304
- validated := validator(
305
- response := await self.aask(
306
- question=(
307
- TEMPLATE_MANAGER.render_template(
308
- configs.templates.co_validation_template,
309
- {"original_q": q, "original_a": response},
310
- )
311
- ),
312
- **co_extractor,
313
- )
314
- )
315
- )
316
- is not None
317
- ):
300
+ if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
318
301
  logger.debug(f"Successfully validated the response at {lap}th attempt.")
319
302
  return validated
320
303
 
fabricatio/parser.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """A module to parse text using regular expressions."""
2
2
 
3
+ import re
4
+ from functools import lru_cache
5
+ from re import Pattern, compile
3
6
  from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
4
7
 
5
- import orjson
6
- import regex
8
+ import ujson
7
9
  from json_repair import repair_json
8
10
  from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
9
- from regex import Pattern, compile
10
11
 
11
12
  from fabricatio.config import configs
12
13
  from fabricatio.journal import logger
@@ -25,7 +26,7 @@ class Capture(BaseModel):
25
26
  """The target groups to capture from the pattern."""
26
27
  pattern: str = Field(frozen=True)
27
28
  """The regular expression pattern to search for."""
28
- flags: PositiveInt = Field(default=regex.DOTALL | regex.MULTILINE | regex.IGNORECASE, frozen=True)
29
+ flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
29
30
  """The flags to use when compiling the regular expression pattern."""
30
31
  capture_type: Optional[str] = None
31
32
  """The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
@@ -49,7 +50,8 @@ class Capture(BaseModel):
49
50
  logger.debug("Applying json repair to text.")
50
51
  if isinstance(text, str):
51
52
  return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
52
- return [repair_json(item, ensure_ascii=False) for item in text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
53
+ return [repair_json(item, ensure_ascii=False) for item in
54
+ text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
53
55
  case _:
54
56
  return text # pyright: ignore [reportReturnType]
55
57
 
@@ -63,7 +65,7 @@ class Capture(BaseModel):
63
65
  str | None: The captured text if the pattern is found, otherwise None.
64
66
 
65
67
  """
66
- if (match :=self._compiled.match(text) or self._compiled.search(text) ) is None:
68
+ if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
67
69
  logger.debug(f"Capture Failed {type(text)}: \n{text}")
68
70
  return None
69
71
  groups = self.fix(match.groups())
@@ -94,12 +96,12 @@ class Capture(BaseModel):
94
96
  return None
95
97
 
96
98
  def validate_with[K, T, E](
97
- self,
98
- text: str,
99
- target_type: Type[T],
100
- elements_type: Optional[Type[E]] = None,
101
- length: Optional[int] = None,
102
- deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = orjson.loads,
99
+ self,
100
+ text: str,
101
+ target_type: Type[T],
102
+ elements_type: Optional[Type[E]] = None,
103
+ length: Optional[int] = None,
104
+ deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
103
105
  ) -> T | None:
104
106
  """Validate the given text using the pattern.
105
107
 
@@ -124,6 +126,7 @@ class Capture(BaseModel):
124
126
  return None
125
127
 
126
128
  @classmethod
129
+ @lru_cache(32)
127
130
  def capture_code_block(cls, language: str) -> Self:
128
131
  """Capture the first occurrence of a code block in the given text.
129
132
 
@@ -136,6 +139,7 @@ class Capture(BaseModel):
136
139
  return cls(pattern=f"```{language}(.*?)```", capture_type=language)
137
140
 
138
141
  @classmethod
142
+ @lru_cache(32)
139
143
  def capture_generic_block(cls, language: str) -> Self:
140
144
  """Capture the first occurrence of a generic code block in the given text.
141
145
 
Binary file