fabricatio 0.2.10.dev1__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +55 -10
- fabricatio/actions/article_rag.py +268 -14
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +3 -3
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +70 -0
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +43 -26
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/action.py +16 -11
- fabricatio/models/extra/aricle_rag.py +143 -9
- fabricatio/models/extra/article_base.py +56 -7
- fabricatio/models/extra/article_main.py +102 -6
- fabricatio/models/extra/problem.py +5 -1
- fabricatio/models/generic.py +31 -13
- fabricatio/models/kwargs_types.py +4 -2
- fabricatio/models/task.py +13 -1
- fabricatio/models/usages.py +10 -27
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +167 -62
- fabricatio/utils.py +38 -11
- fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -9
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/RECORD +31 -29
- fabricatio-0.2.10.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
@@ -18,6 +18,7 @@ from fabricatio.models.generic import (
|
|
18
18
|
Titled,
|
19
19
|
WordCount,
|
20
20
|
)
|
21
|
+
from fabricatio.rust import comment
|
21
22
|
from pydantic import Field
|
22
23
|
|
23
24
|
|
@@ -29,11 +30,9 @@ class ReferringType(StrEnum):
|
|
29
30
|
SUBSECTION = "subsection"
|
30
31
|
|
31
32
|
|
32
|
-
|
33
33
|
type RefKey = Tuple[str, Optional[str], Optional[str]]
|
34
34
|
|
35
35
|
|
36
|
-
|
37
36
|
class ArticleMetaData(SketchedAble, Described, WordCount, Titled, Language):
|
38
37
|
"""Metadata for an article component."""
|
39
38
|
|
@@ -47,7 +46,16 @@ class ArticleMetaData(SketchedAble, Described, WordCount, Titled, Language):
|
|
47
46
|
aims: List[str]
|
48
47
|
"""List of writing aims of the research component in academic style."""
|
49
48
|
|
50
|
-
|
49
|
+
@property
def typst_metadata_comment(self) -> str:
    """Build the metadata comment text for this article component.

    The description, the aims and the expected word count are each passed
    through ``comment`` and emitted only when the corresponding field is
    truthy. The pieces are concatenated in that order.

    Returns:
        str: The concatenated comment text, or an empty string when none of
            the three fields is set.
    """
    pieces: list[str] = []
    if self.description:
        pieces.append(comment(f"Desc:\n{self.description}\n"))
    if self.aims:
        joined_aims = "\n ".join(self.aims)
        pieces.append(comment(f"Aims:\n{joined_aims}\n"))
    # Gate only the word-count piece on the word count. A trailing
    # `if self.expected_word_count else ""` after the whole `+` chain would
    # bind to the entire sum and discard description/aims as well.
    if self.expected_word_count:
        pieces.append(comment(f"Expected Word Count:{self.expected_word_count}"))
    return "".join(pieces)
|
51
59
|
|
52
60
|
|
53
61
|
class ArticleOutlineBase(
|
@@ -92,7 +100,7 @@ class SubSectionBase(ArticleOutlineBase):
|
|
92
100
|
|
93
101
|
def to_typst_code(self) -> str:
|
94
102
|
"""Converts the component into a Typst code snippet for rendering."""
|
95
|
-
return f"=== {self.title}\n"
|
103
|
+
return f"=== {self.title}\n{self.typst_metadata_comment}\n"
|
96
104
|
|
97
105
|
def introspect(self) -> str:
|
98
106
|
"""Introspects the article subsection outline."""
|
@@ -117,7 +125,9 @@ class SectionBase[T: SubSectionBase](ArticleOutlineBase):
|
|
117
125
|
Returns:
|
118
126
|
str: The formatted Typst code snippet.
|
119
127
|
"""
|
120
|
-
return f"== {self.title}\n" + "\n\n".join(
|
128
|
+
return f"== {self.title}\n{self.typst_metadata_comment}\n" + "\n\n".join(
|
129
|
+
subsec.to_typst_code() for subsec in self.subsections
|
130
|
+
)
|
121
131
|
|
122
132
|
def resolve_update_conflict(self, other: Self) -> str:
|
123
133
|
"""Resolve update errors in the article outline."""
|
@@ -160,7 +170,9 @@ class ChapterBase[T: SectionBase](ArticleOutlineBase):
|
|
160
170
|
|
161
171
|
def to_typst_code(self) -> str:
|
162
172
|
"""Converts the chapter into a Typst formatted code snippet for rendering."""
|
163
|
-
return f"= {self.title}\n" + "\n\n".join(
|
173
|
+
return f"= {self.title}\n{self.typst_metadata_comment}\n" + "\n\n".join(
|
174
|
+
sec.to_typst_code() for sec in self.sections
|
175
|
+
)
|
164
176
|
|
165
177
|
def resolve_update_conflict(self, other: Self) -> str:
|
166
178
|
"""Resolve update errors in the article outline."""
|
@@ -302,4 +314,41 @@ class ArticleBase[T: ChapterBase](FinalizedDumpAble, AsPrompt, WordCount, Descri
|
|
302
314
|
=== Implementation Details
|
303
315
|
== Evaluation Protocol
|
304
316
|
"""
|
305
|
-
return
|
317
|
+
return (
|
318
|
+
comment(
|
319
|
+
f"Title:{self.title}\n"
|
320
|
+
+ (f"Desc:\n{self.description}\n" if self.description else "")
|
321
|
+
+ f"Word Count:{self.expected_word_count}"
|
322
|
+
if self.expected_word_count
|
323
|
+
else ""
|
324
|
+
)
|
325
|
+
+ "\n\n"
|
326
|
+
+ "\n\n".join(a.to_typst_code() for a in self.chapters)
|
327
|
+
)
|
328
|
+
|
329
|
+
def avg_chap_wordcount[S](self:S) -> S:
|
330
|
+
"""Set all chap have same word count sum up to be `self.expected_word_count`."""
|
331
|
+
avg = int(self.expected_word_count / len(self.chapters))
|
332
|
+
for c in self.chapters:
|
333
|
+
c.expected_word_count = avg
|
334
|
+
return self
|
335
|
+
|
336
|
+
def avg_sec_wordcount[S](self:S) -> S:
|
337
|
+
"""Set all sec have same word count sum up to be `self.expected_word_count`."""
|
338
|
+
for c in self.chapters:
|
339
|
+
avg = int(c.expected_word_count / len(c.sections))
|
340
|
+
for s in c.sections:
|
341
|
+
s.expected_word_count = avg
|
342
|
+
return self
|
343
|
+
|
344
|
+
def avg_subsec_wordcount[S](self:S) -> S:
|
345
|
+
"""Set all subsec have same word count sum up to be `self.expected_word_count`."""
|
346
|
+
for _, s in self.iter_sections():
|
347
|
+
avg = int(s.expected_word_count / len(s.subsections))
|
348
|
+
for ss in s.subsections:
|
349
|
+
ss.expected_word_count = avg
|
350
|
+
return self
|
351
|
+
|
352
|
+
def avg_wordcount_recursive(self) -> Self:
|
353
|
+
"""Set all chap, sec, subsec have same word count sum up to be `self.expected_word_count`."""
|
354
|
+
return self.avg_chap_wordcount().avg_sec_wordcount().avg_sec_wordcount()
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from typing import Dict, Generator, List, Self, Tuple, override
|
4
4
|
|
5
|
+
from fabricatio.decorators import precheck_package
|
5
6
|
from fabricatio.fs.readers import extract_sections
|
6
7
|
from fabricatio.journal import logger
|
7
8
|
from fabricatio.models.extra.article_base import (
|
@@ -11,11 +12,15 @@ from fabricatio.models.extra.article_base import (
|
|
11
12
|
SubSectionBase,
|
12
13
|
)
|
13
14
|
from fabricatio.models.extra.article_outline import (
|
15
|
+
ArticleChapterOutline,
|
14
16
|
ArticleOutline,
|
17
|
+
ArticleSectionOutline,
|
18
|
+
ArticleSubsectionOutline,
|
15
19
|
)
|
16
20
|
from fabricatio.models.generic import Described, PersistentAble, SequencePatch, SketchedAble, WithRef, WordCount
|
17
|
-
from fabricatio.rust import word_count
|
18
|
-
from
|
21
|
+
from fabricatio.rust import convert_all_block_tex, convert_all_inline_tex, word_count
|
22
|
+
from fabricatio.utils import fallback_kwargs
|
23
|
+
from pydantic import Field, NonNegativeInt
|
19
24
|
|
20
25
|
PARAGRAPH_SEP = "// - - -"
|
21
26
|
|
@@ -23,6 +28,9 @@ PARAGRAPH_SEP = "// - - -"
|
|
23
28
|
class Paragraph(SketchedAble, WordCount, Described):
|
24
29
|
"""Structured academic paragraph blueprint for controlled content generation."""
|
25
30
|
|
31
|
+
expected_word_count: NonNegativeInt = 0
|
32
|
+
"""The expected word count of this paragraph, 0 means not specified"""
|
33
|
+
|
26
34
|
description: str = Field(
|
27
35
|
alias="elaboration",
|
28
36
|
description=Described.model_fields["description"].description,
|
@@ -85,7 +93,7 @@ class ArticleSubsection(SubSectionBase):
|
|
85
93
|
Returns:
|
86
94
|
str: Typst code snippet for rendering.
|
87
95
|
"""
|
88
|
-
return
|
96
|
+
return super().to_typst_code() + f"\n\n{PARAGRAPH_SEP}\n\n".join(p.content for p in self.paragraphs)
|
89
97
|
|
90
98
|
@classmethod
|
91
99
|
def from_typst_code(cls, title: str, body: str) -> Self:
|
@@ -153,10 +161,74 @@ class Article(
|
|
153
161
|
"Original Article": self.display(),
|
154
162
|
}
|
155
163
|
|
164
|
+
def convert_tex(self) -> Self:
|
165
|
+
"""Convert tex to typst code."""
|
166
|
+
for _, _, subsec in self.iter_subsections():
|
167
|
+
for p in subsec.paragraphs:
|
168
|
+
p.content = convert_all_inline_tex(p.content)
|
169
|
+
p.content = convert_all_block_tex(p.content)
|
170
|
+
return self
|
171
|
+
|
172
|
+
def fix_wrapper(self) -> Self:
|
173
|
+
"""Fix wrapper."""
|
174
|
+
for _, _, subsec in self.iter_subsections():
|
175
|
+
for p in subsec.paragraphs:
|
176
|
+
p.content = (
|
177
|
+
p.content.replace(r" \( ", "$")
|
178
|
+
.replace(r" \) ", "$")
|
179
|
+
.replace("\\[\n", "$$\n")
|
180
|
+
.replace("\n\\]", "\n$$")
|
181
|
+
)
|
182
|
+
return self
|
183
|
+
|
156
184
|
@override
|
157
185
|
def iter_subsections(self) -> Generator[Tuple[ArticleChapter, ArticleSection, ArticleSubsection], None, None]:
|
158
186
|
return super().iter_subsections() # pyright: ignore [reportReturnType]
|
159
187
|
|
188
|
+
def extrac_outline(self) -> ArticleOutline:
    """Build an outline that mirrors this article's chapter/section/subsection tree.

    Every level is dumped with ``model_dump(by_alias=True)`` minus its child
    collection (``paragraphs``, ``subsections``, ``sections``, ``chapters``),
    and the dump is fed to the matching outline class together with the
    already-converted children.

    Returns:
        ArticleOutline: The outline built from this article.
    """
    chapter_outlines = [
        ArticleChapterOutline(
            **chapter.model_dump(exclude={"sections"}, by_alias=True),
            sections=[
                ArticleSectionOutline(
                    **section.model_dump(exclude={"subsections"}, by_alias=True),
                    subsections=[
                        ArticleSubsectionOutline(**subsection.model_dump(exclude={"paragraphs"}, by_alias=True))
                        for subsection in section.subsections
                    ],
                )
                for section in chapter.sections
            ],
        )
        for chapter in self.chapters
    ]
    return ArticleOutline(
        **self.model_dump(exclude={"chapters"}, by_alias=True),
        chapters=chapter_outlines,
    )
|
231
|
+
|
160
232
|
@classmethod
|
161
233
|
def from_outline(cls, outline: ArticleOutline) -> "Article":
|
162
234
|
"""Generates an article from the given outline.
|
@@ -194,13 +266,37 @@ class Article(
|
|
194
266
|
return article
|
195
267
|
|
196
268
|
@classmethod
|
197
|
-
def from_typst_code(cls, title: str, body: str) -> Self:
|
269
|
+
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
198
270
|
"""Generates an article from the given Typst code."""
|
199
271
|
return cls(
|
200
272
|
chapters=[
|
201
273
|
ArticleChapter.from_typst_code(*pack) for pack in extract_sections(body, level=1, section_char="=")
|
202
274
|
],
|
203
275
|
heading=title,
|
204
|
-
|
205
|
-
|
276
|
+
**fallback_kwargs(
|
277
|
+
kwargs,
|
278
|
+
expected_word_count=word_count(body),
|
279
|
+
abstract="",
|
280
|
+
),
|
206
281
|
)
|
282
|
+
|
283
|
+
@classmethod
def from_mixed_source(cls, article_outline: ArticleOutline, typst_code: str) -> Self:
    """Build an article from Typst code and copy metadata over from an outline.

    The article is parsed with ``from_typst_code`` using the outline's title.
    Its expected word count and description are then taken from the outline,
    and each component from ``iter_dfs()`` is paired with the outline component
    at the same position and updated via ``update_metadata``. Pairing uses
    ``zip(..., strict=True)``, so a length mismatch raises ``ValueError``.

    Args:
        article_outline: The outline supplying the title and metadata.
        typst_code: The Typst source to parse into the article body.

    Returns:
        Self: The result of ``update_ref(article_outline)`` on the new article.
    """
    article = cls.from_typst_code(article_outline.title, typst_code)
    article.expected_word_count = article_outline.expected_word_count
    article.description = article_outline.description
    paired = zip(article.iter_dfs(), article_outline.iter_dfs(), strict=True)
    for component, outlined in paired:
        component.update_metadata(outlined)
    return article.update_ref(article_outline)
|
292
|
+
|
293
|
+
@precheck_package(
    "questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
)
async def edit_titles(self) -> Self:
    """Prompt for a new title for each component yielded by ``iter_dfs()``.

    The current title is offered as the default answer. When the prompt
    returns a falsy answer, the existing title is kept.

    Returns:
        Self: The same instance, so calls can be chained.
    """
    # Imported lazily: questionary is an optional dependency guarded by the decorator.
    from questionary import text

    for component in self.iter_dfs():
        answer = await text(f"Edit `{component.title}`.", default=component.title).ask_async()
        component.title = answer or component.title
    return self
|
@@ -7,7 +7,6 @@ from fabricatio.journal import logger
|
|
7
7
|
from fabricatio.models.generic import SketchedAble, WithBriefing
|
8
8
|
from fabricatio.utils import ask_edit
|
9
9
|
from pydantic import Field
|
10
|
-
from questionary import Choice, checkbox, text
|
11
10
|
from rich import print as r_print
|
12
11
|
|
13
12
|
|
@@ -74,6 +73,9 @@ class ProblemSolutions(SketchedAble):
|
|
74
73
|
return len(self.solutions) > 0
|
75
74
|
|
76
75
|
async def edit_problem(self) -> Self:
|
76
|
+
"""Interactively edit the problem description."""
|
77
|
+
from questionary import text
|
78
|
+
|
77
79
|
"""Interactively edit the problem description."""
|
78
80
|
self.problem = Problem.model_validate_strings(
|
79
81
|
await text("Please edit the problem below:", default=self.problem.display()).ask_async()
|
@@ -127,6 +129,8 @@ class Improvement(SketchedAble):
|
|
127
129
|
Returns:
|
128
130
|
Self: The current instance with filtered problems and solutions.
|
129
131
|
"""
|
132
|
+
from questionary import Choice, checkbox
|
133
|
+
|
130
134
|
# Choose the problems to retain
|
131
135
|
chosen_ones: List[ProblemSolutions] = await checkbox(
|
132
136
|
"Please choose the problems you want to retain.(Default: retain all)",
|
fabricatio/models/generic.py
CHANGED
@@ -3,11 +3,11 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from datetime import datetime
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
|
6
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
|
7
7
|
|
8
|
-
import
|
8
|
+
import ujson
|
9
9
|
from fabricatio.config import configs
|
10
|
-
from fabricatio.fs.readers import
|
10
|
+
from fabricatio.fs.readers import safe_text_read
|
11
11
|
from fabricatio.journal import logger
|
12
12
|
from fabricatio.parser import JsonCapture
|
13
13
|
from fabricatio.rust import blake3_hash, detect_language
|
@@ -117,6 +117,15 @@ class WordCount(Base):
|
|
117
117
|
"""Expected word count of this research component."""
|
118
118
|
|
119
119
|
|
120
|
+
class FromMapping(Base):
    """Mixin for models that can be built in bulk from a mapping."""

    @classmethod
    @abstractmethod
    def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
        """Create a list of instances from ``mapping``.

        Args:
            mapping: The string-keyed source data.
            **kwargs: Extra options left to the implementing class.

        Returns:
            List[Self]: The objects generated from the mapping.
        """
        ...
|
127
|
+
|
128
|
+
|
120
129
|
class AsPrompt(Base):
|
121
130
|
"""Class that provides a method to generate a prompt from the model.
|
122
131
|
|
@@ -170,11 +179,14 @@ class WithRef[T](Base):
|
|
170
179
|
|
171
180
|
@overload
|
172
181
|
def update_ref[S: WithRef](self: S, reference: T) -> S: ...
|
182
|
+
|
173
183
|
@overload
|
174
184
|
def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S: ...
|
185
|
+
|
175
186
|
@overload
|
176
187
|
def update_ref[S: WithRef](self: S, reference: None = None) -> S: ...
|
177
|
-
|
188
|
+
|
189
|
+
def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S:
|
178
190
|
"""Update the reference of the object.
|
179
191
|
|
180
192
|
Args:
|
@@ -189,7 +201,7 @@ class WithRef[T](Base):
|
|
189
201
|
self._reference = reference # pyright: ignore [reportAttributeAccessIssue]
|
190
202
|
return self
|
191
203
|
|
192
|
-
def derive[S: WithRef](self: S, reference: Any) -> S:
|
204
|
+
def derive[S: WithRef](self: S, reference: Any) -> S:
|
193
205
|
"""Derive a new object from the current object.
|
194
206
|
|
195
207
|
Args:
|
@@ -455,10 +467,9 @@ class WithFormatedJsonSchema(Base):
|
|
455
467
|
Returns:
|
456
468
|
str: The JSON schema of the model in a formatted string.
|
457
469
|
"""
|
458
|
-
return
|
459
|
-
cls.model_json_schema(schema_generator=UnsortGenerate),
|
460
|
-
|
461
|
-
).decode()
|
470
|
+
return ujson.dumps(
|
471
|
+
cls.model_json_schema(schema_generator=UnsortGenerate), indent=2, ensure_ascii=False, sort_keys=False
|
472
|
+
)
|
462
473
|
|
463
474
|
|
464
475
|
class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
@@ -470,9 +481,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
|
470
481
|
@classmethod
|
471
482
|
@overload
|
472
483
|
def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
|
484
|
+
|
473
485
|
@classmethod
|
474
486
|
@overload
|
475
487
|
def create_json_prompt(cls, requirement: str) -> str: ...
|
488
|
+
|
476
489
|
@classmethod
|
477
490
|
def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
|
478
491
|
"""Create the prompt for creating a JSON object with given requirement.
|
@@ -639,6 +652,8 @@ class WithDependency(Base):
|
|
639
652
|
Returns:
|
640
653
|
str: The generated prompt for the task.
|
641
654
|
"""
|
655
|
+
from fabricatio.fs import MAGIKA
|
656
|
+
|
642
657
|
return TEMPLATE_MANAGER.render_template(
|
643
658
|
configs.templates.dependencies_template,
|
644
659
|
{
|
@@ -734,6 +749,12 @@ class ScopedConfig(Base):
|
|
734
749
|
llm_rpm: Optional[PositiveInt] = None
|
735
750
|
"""The requests per minute of the LLM model."""
|
736
751
|
|
752
|
+
llm_presence_penalty: Optional[PositiveFloat] = None
|
753
|
+
"""The presence penalty of the LLM model."""
|
754
|
+
|
755
|
+
llm_frequency_penalty: Optional[PositiveFloat] = None
|
756
|
+
"""The frequency penalty of the LLM model."""
|
757
|
+
|
737
758
|
embedding_api_endpoint: Optional[HttpUrl] = None
|
738
759
|
"""The OpenAI API endpoint."""
|
739
760
|
|
@@ -862,10 +883,7 @@ class Patch[T](ProposedAble):
|
|
862
883
|
)
|
863
884
|
my_schema["description"] = ref_cls.__doc__
|
864
885
|
|
865
|
-
return
|
866
|
-
my_schema,
|
867
|
-
option=orjson.OPT_INDENT_2,
|
868
|
-
).decode()
|
886
|
+
return ujson.dumps(my_schema, indent=2, ensure_ascii=False, sort_keys=False)
|
869
887
|
|
870
888
|
|
871
889
|
class SequencePatch[T](ProposedUpdateAble):
|
@@ -33,7 +33,7 @@ class LLMKwargs(TypedDict, total=False):
|
|
33
33
|
including generation parameters and caching options.
|
34
34
|
"""
|
35
35
|
|
36
|
-
model: str
|
36
|
+
model: Optional[str]
|
37
37
|
temperature: float
|
38
38
|
stop: str | list[str]
|
39
39
|
top_p: float
|
@@ -45,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
|
|
45
45
|
no_store: bool # If store the response of this call to cache
|
46
46
|
cache_ttl: int # how long the stored cache is alive, in seconds
|
47
47
|
s_maxage: int # max accepted age of cached response, in seconds
|
48
|
+
presence_penalty: float
|
49
|
+
frequency_penalty: float
|
48
50
|
|
49
51
|
|
50
52
|
class GenerateKwargs(LLMKwargs, total=False):
|
@@ -66,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
|
|
66
68
|
|
67
69
|
default: Optional[T]
|
68
70
|
max_validations: int
|
69
|
-
|
71
|
+
|
70
72
|
|
71
73
|
|
72
74
|
class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
|
fabricatio/models/task.py
CHANGED
@@ -4,7 +4,7 @@ It includes methods to manage the task's lifecycle, such as starting, finishing,
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
from asyncio import Queue
|
7
|
-
from typing import Any, List, Optional, Self
|
7
|
+
from typing import Any, Dict, List, Optional, Self
|
8
8
|
|
9
9
|
from fabricatio.config import configs
|
10
10
|
from fabricatio.constants import TaskStatus
|
@@ -50,6 +50,18 @@ class Task[T](WithBriefing, ProposedAble, WithDependency):
|
|
50
50
|
|
51
51
|
_namespace: Event = PrivateAttr(default_factory=Event)
|
52
52
|
"""The namespace of the task as an event, which is generated from the namespace list."""
|
53
|
+
_extra_init_context: Dict = PrivateAttr(default_factory=dict)
|
54
|
+
"""Extra initialization context for the task, which is designed to override the one of the Workflow."""
|
55
|
+
|
56
|
+
@property
def extra_init_context(self) -> Dict:
    """Return the task's extra initialization context.

    This is the private ``_extra_init_context`` dict itself, not a copy, so
    changes made to the returned dict are kept on the task.
    """
    context = self._extra_init_context
    return context
|
60
|
+
|
61
|
+
def update_init_context(self, /, **kwargs) -> Self:
|
62
|
+
"""Update the extra initialization context for the task."""
|
63
|
+
self.extra_init_context.update(kwargs)
|
64
|
+
return self
|
53
65
|
|
54
66
|
def model_post_init(self, __context: Any) -> None:
|
55
67
|
"""Initialize the task with a namespace event."""
|
fabricatio/models/usages.py
CHANGED
@@ -31,7 +31,7 @@ from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
|
|
31
31
|
|
32
32
|
if configs.cache.enabled and configs.cache.type:
|
33
33
|
litellm.enable_cache(type=configs.cache.type, **configs.cache.params)
|
34
|
-
logger.
|
34
|
+
logger.debug(f"{configs.cache.type.name} Cache enabled")
|
35
35
|
|
36
36
|
ROUTER = Router(
|
37
37
|
routing_strategy="usage-based-routing-v2",
|
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
|
|
63
63
|
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
64
64
|
return ROUTER
|
65
65
|
|
66
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
66
|
+
# noinspection PyTypeChecker,PydanticTypeChecker,t
|
67
67
|
async def aquery(
|
68
68
|
self,
|
69
69
|
messages: List[Dict[str, str]],
|
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
|
|
122
122
|
"cache-ttl": kwargs.get("cache_ttl"),
|
123
123
|
"s-maxage": kwargs.get("s_maxage"),
|
124
124
|
},
|
125
|
+
presence_penalty=kwargs.get("presence_penalty")
|
126
|
+
or self.llm_presence_penalty
|
127
|
+
or configs.llm.presence_penalty,
|
128
|
+
frequency_penalty=kwargs.get("frequency_penalty")
|
129
|
+
or self.llm_frequency_penalty
|
130
|
+
or configs.llm.frequency_penalty,
|
125
131
|
)
|
126
132
|
|
127
133
|
async def ainvoke(
|
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
|
|
236
242
|
validator: Callable[[str], T | None],
|
237
243
|
default: T = ...,
|
238
244
|
max_validations: PositiveInt = 2,
|
239
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
240
245
|
**kwargs: Unpack[GenerateKwargs],
|
241
246
|
) -> T: ...
|
242
247
|
@overload
|
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
|
|
246
251
|
validator: Callable[[str], T | None],
|
247
252
|
default: T = ...,
|
248
253
|
max_validations: PositiveInt = 2,
|
249
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
250
254
|
**kwargs: Unpack[GenerateKwargs],
|
251
255
|
) -> List[T]: ...
|
252
256
|
@overload
|
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
|
|
256
260
|
validator: Callable[[str], T | None],
|
257
261
|
default: None = None,
|
258
262
|
max_validations: PositiveInt = 2,
|
259
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
260
263
|
**kwargs: Unpack[GenerateKwargs],
|
261
264
|
) -> Optional[T]: ...
|
262
265
|
|
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
|
|
267
270
|
validator: Callable[[str], T | None],
|
268
271
|
default: None = None,
|
269
272
|
max_validations: PositiveInt = 2,
|
270
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
271
273
|
**kwargs: Unpack[GenerateKwargs],
|
272
274
|
) -> List[Optional[T]]: ...
|
273
275
|
|
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
|
|
277
279
|
validator: Callable[[str], T | None],
|
278
280
|
default: Optional[T] = None,
|
279
281
|
max_validations: PositiveInt = 3,
|
280
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
281
282
|
**kwargs: Unpack[GenerateKwargs],
|
282
283
|
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
283
284
|
"""Asynchronously asks a question and validates the response using a given validator.
|
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
|
|
287
288
|
validator (Callable[[str], T | None]): A function to validate the response.
|
288
289
|
default (T | None): Default value to return if validation fails. Defaults to None.
|
289
290
|
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
|
290
|
-
co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
|
291
291
|
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
292
292
|
|
293
293
|
Returns:
|
294
|
-
Optional[T] | List[
|
294
|
+
Optional[T] | List[T | None] | List[T] | T: The validated response.
|
295
295
|
"""
|
296
296
|
|
297
297
|
async def _inner(q: str) -> Optional[T]:
|
298
298
|
for lap in range(max_validations):
|
299
299
|
try:
|
300
|
-
if (
|
301
|
-
co_extractor is not None
|
302
|
-
and logger.debug("Co-extraction is enabled.") is None
|
303
|
-
and (
|
304
|
-
validated := validator(
|
305
|
-
response := await self.aask(
|
306
|
-
question=(
|
307
|
-
TEMPLATE_MANAGER.render_template(
|
308
|
-
configs.templates.co_validation_template,
|
309
|
-
{"original_q": q, "original_a": response},
|
310
|
-
)
|
311
|
-
),
|
312
|
-
**co_extractor,
|
313
|
-
)
|
314
|
-
)
|
315
|
-
)
|
316
|
-
is not None
|
317
|
-
):
|
300
|
+
if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
|
318
301
|
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
319
302
|
return validated
|
320
303
|
|
fabricatio/parser.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
"""A module to parse text using regular expressions."""
|
2
2
|
|
3
|
+
import re
|
4
|
+
from functools import lru_cache
|
5
|
+
from re import Pattern, compile
|
3
6
|
from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
|
4
7
|
|
5
|
-
import
|
6
|
-
import regex
|
8
|
+
import ujson
|
7
9
|
from json_repair import repair_json
|
8
10
|
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
|
9
|
-
from regex import Pattern, compile
|
10
11
|
|
11
12
|
from fabricatio.config import configs
|
12
13
|
from fabricatio.journal import logger
|
@@ -25,7 +26,7 @@ class Capture(BaseModel):
|
|
25
26
|
"""The target groups to capture from the pattern."""
|
26
27
|
pattern: str = Field(frozen=True)
|
27
28
|
"""The regular expression pattern to search for."""
|
28
|
-
flags: PositiveInt = Field(default=
|
29
|
+
flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
|
29
30
|
"""The flags to use when compiling the regular expression pattern."""
|
30
31
|
capture_type: Optional[str] = None
|
31
32
|
"""The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
|
@@ -49,7 +50,8 @@ class Capture(BaseModel):
|
|
49
50
|
logger.debug("Applying json repair to text.")
|
50
51
|
if isinstance(text, str):
|
51
52
|
return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
|
52
|
-
return [repair_json(item, ensure_ascii=False) for item in
|
53
|
+
return [repair_json(item, ensure_ascii=False) for item in
|
54
|
+
text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
|
53
55
|
case _:
|
54
56
|
return text # pyright: ignore [reportReturnType]
|
55
57
|
|
@@ -63,7 +65,7 @@ class Capture(BaseModel):
|
|
63
65
|
str | None: The captured text if the pattern is found, otherwise None.
|
64
66
|
|
65
67
|
"""
|
66
|
-
if (match :=self._compiled.match(text) or self._compiled.search(text)
|
68
|
+
if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
|
67
69
|
logger.debug(f"Capture Failed {type(text)}: \n{text}")
|
68
70
|
return None
|
69
71
|
groups = self.fix(match.groups())
|
@@ -94,12 +96,12 @@ class Capture(BaseModel):
|
|
94
96
|
return None
|
95
97
|
|
96
98
|
def validate_with[K, T, E](
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
99
|
+
self,
|
100
|
+
text: str,
|
101
|
+
target_type: Type[T],
|
102
|
+
elements_type: Optional[Type[E]] = None,
|
103
|
+
length: Optional[int] = None,
|
104
|
+
deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
|
103
105
|
) -> T | None:
|
104
106
|
"""Validate the given text using the pattern.
|
105
107
|
|
@@ -124,6 +126,7 @@ class Capture(BaseModel):
|
|
124
126
|
return None
|
125
127
|
|
126
128
|
@classmethod
|
129
|
+
@lru_cache(32)
|
127
130
|
def capture_code_block(cls, language: str) -> Self:
|
128
131
|
"""Capture the first occurrence of a code block in the given text.
|
129
132
|
|
@@ -136,6 +139,7 @@ class Capture(BaseModel):
|
|
136
139
|
return cls(pattern=f"```{language}(.*?)```", capture_type=language)
|
137
140
|
|
138
141
|
@classmethod
|
142
|
+
@lru_cache(32)
|
139
143
|
def capture_generic_block(cls, language: str) -> Self:
|
140
144
|
"""Capture the first occurrence of a generic code block in the given text.
|
141
145
|
|
Binary file
|