fabricatio 0.2.10.dev1__cp312-cp312-win_amd64.whl → 0.2.11.dev0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +12 -2
- fabricatio/actions/article_rag.py +103 -13
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +3 -3
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +65 -0
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +30 -30
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/extra/aricle_rag.py +124 -9
- fabricatio/models/extra/article_main.py +39 -1
- fabricatio/models/extra/problem.py +7 -3
- fabricatio/models/generic.py +46 -19
- fabricatio/models/kwargs_types.py +3 -1
- fabricatio/models/usages.py +9 -26
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +130 -11
- fabricatio/utils.py +11 -3
- fabricatio-0.2.11.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev0.dist-info}/METADATA +18 -9
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev0.dist-info}/RECORD +28 -26
- fabricatio-0.2.10.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/fs/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""FileSystem manipulation module for Fabricatio."""
|
2
|
+
from importlib.util import find_spec
|
2
3
|
|
4
|
+
from fabricatio.config import configs
|
3
5
|
from fabricatio.fs.curd import (
|
4
6
|
absolute_path,
|
5
7
|
copy_file,
|
@@ -11,10 +13,9 @@ from fabricatio.fs.curd import (
|
|
11
13
|
move_file,
|
12
14
|
tree,
|
13
15
|
)
|
14
|
-
from fabricatio.fs.readers import
|
16
|
+
from fabricatio.fs.readers import safe_json_read, safe_text_read
|
15
17
|
|
16
18
|
__all__ = [
|
17
|
-
"MAGIKA",
|
18
19
|
"absolute_path",
|
19
20
|
"copy_file",
|
20
21
|
"create_directory",
|
@@ -27,3 +28,9 @@ __all__ = [
|
|
27
28
|
"safe_text_read",
|
28
29
|
"tree",
|
29
30
|
]
|
31
|
+
|
32
|
+
if find_spec("magika"):
|
33
|
+
from magika import Magika
|
34
|
+
|
35
|
+
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
36
|
+
__all__ += ["MAGIKA"]
|
fabricatio/fs/readers.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1
1
|
"""Filesystem readers for Fabricatio."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
5
|
from typing import Dict, List, Tuple
|
5
6
|
|
6
|
-
import
|
7
|
-
import regex
|
8
|
-
from magika import Magika
|
7
|
+
import ujson
|
9
8
|
|
10
|
-
from fabricatio.config import configs
|
11
9
|
from fabricatio.journal import logger
|
12
10
|
|
13
|
-
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
14
|
-
|
15
11
|
|
16
12
|
def safe_text_read(path: Path | str) -> str:
|
17
13
|
"""Safely read the text from a file.
|
@@ -41,8 +37,8 @@ def safe_json_read(path: Path | str) -> Dict:
|
|
41
37
|
"""
|
42
38
|
path = Path(path)
|
43
39
|
try:
|
44
|
-
return
|
45
|
-
except (
|
40
|
+
return ujson.loads(path.read_text(encoding="utf-8"))
|
41
|
+
except (ujson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
|
46
42
|
logger.error(f"Failed to read file {path}: {e!s}")
|
47
43
|
return {}
|
48
44
|
|
@@ -58,8 +54,8 @@ def extract_sections(string: str, level: int, section_char: str = "#") -> List[T
|
|
58
54
|
Returns:
|
59
55
|
List[Tuple[str, str]]: List of (header_text, section_content) tuples
|
60
56
|
"""
|
61
|
-
return
|
57
|
+
return re.findall(
|
62
58
|
r"^%s{%d}\s+(.+?)\n((?:(?!^%s{%d}\s).|\n)*)" % (section_char, level, section_char, level),
|
63
59
|
string,
|
64
|
-
|
60
|
+
re.MULTILINE,
|
65
61
|
)
|
@@ -1,22 +1,27 @@
|
|
1
1
|
"""A Module containing the article rag models."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import ClassVar, Dict, List, Self, Unpack
|
5
|
+
from typing import ClassVar, Dict, List, Optional, Self, Unpack
|
5
6
|
|
6
7
|
from fabricatio.fs import safe_text_read
|
7
8
|
from fabricatio.journal import logger
|
8
9
|
from fabricatio.models.extra.rag import MilvusDataBase
|
9
10
|
from fabricatio.models.generic import AsPrompt
|
10
11
|
from fabricatio.models.kwargs_types import ChunkKwargs
|
11
|
-
from fabricatio.rust import BibManager, split_into_chunks
|
12
|
-
from fabricatio.utils import ok
|
13
|
-
from more_itertools.recipes import flatten
|
12
|
+
from fabricatio.rust import BibManager, is_chinese, split_into_chunks
|
13
|
+
from fabricatio.utils import ok
|
14
|
+
from more_itertools.recipes import flatten, unique
|
14
15
|
from pydantic import Field
|
15
16
|
|
16
17
|
|
17
18
|
class ArticleChunk(MilvusDataBase, AsPrompt):
|
18
19
|
"""The chunk of an article."""
|
19
20
|
|
21
|
+
etc_word: ClassVar[str] = "等"
|
22
|
+
and_word: ClassVar[str] = "与"
|
23
|
+
_cite_number: Optional[int] = None
|
24
|
+
|
20
25
|
head_split: ClassVar[List[str]] = [
|
21
26
|
"引 言",
|
22
27
|
"引言",
|
@@ -48,12 +53,14 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
48
53
|
|
49
54
|
def _as_prompt_inner(self) -> Dict[str, str]:
|
50
55
|
return {
|
51
|
-
|
52
|
-
f"Authors: {';'.join(self.authors)}\n"
|
53
|
-
f"Published Year: {self.year}\n"
|
54
|
-
f"Bibtex Key: {self.bibtex_cite_key}\n",
|
56
|
+
f"[[{ok(self._cite_number, 'You need to update cite number first.')}]] reference `{self.article_title}`": self.chunk
|
55
57
|
}
|
56
58
|
|
59
|
+
@property
|
60
|
+
def cite_number(self) -> int:
|
61
|
+
"""Get the cite number."""
|
62
|
+
return ok(self._cite_number, "cite number not set")
|
63
|
+
|
57
64
|
def _prepare_vectorization_inner(self) -> str:
|
58
65
|
return self.chunk
|
59
66
|
|
@@ -89,8 +96,9 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
89
96
|
|
90
97
|
result = [
|
91
98
|
cls(chunk=c, year=year, authors=authors, article_title=article_title, bibtex_cite_key=key)
|
92
|
-
for c in split_into_chunks(cls.strip(safe_text_read(path)), **kwargs)
|
99
|
+
for c in split_into_chunks(cls.purge_numeric_citation(cls.strip(safe_text_read(path))), **kwargs)
|
93
100
|
]
|
101
|
+
|
94
102
|
logger.debug(f"Number of chunks created from file {path.as_posix()}: {len(result)}")
|
95
103
|
return result
|
96
104
|
|
@@ -118,3 +126,110 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
118
126
|
logger.warning("No decrease at tail strip, which is might be abnormal.")
|
119
127
|
|
120
128
|
return string
|
129
|
+
|
130
|
+
def as_typst_cite(self) -> str:
|
131
|
+
"""As typst cite."""
|
132
|
+
return f"#cite(<{self.bibtex_cite_key}>)"
|
133
|
+
|
134
|
+
@staticmethod
|
135
|
+
def purge_numeric_citation(string: str) -> str:
|
136
|
+
"""Purge numeric citation."""
|
137
|
+
import re
|
138
|
+
|
139
|
+
return re.sub(r"\[[\d\s,\\~–-]+]", "", string) # noqa: RUF001
|
140
|
+
|
141
|
+
@property
def auther_firstnames(self) -> List[str]:
    """Get the short citation name of every author, in ``self.authors`` order.

    For a name that ``is_chinese`` reports as Chinese, the first character is
    used; for any other name, the last whitespace-separated token is used.
    NOTE(review): despite the property name, this is presumably the family
    name in both cases - confirm against the bibliography data.

    A blank (empty or whitespace-only) name yields an empty string instead of
    raising ``IndexError``, so the result always has one entry per author.

    Returns:
        One short name per entry of ``self.authors``.
    """
    ret = []
    for n in self.authors:
        if not n.strip():
            # Keep the list aligned with `self.authors` instead of crashing
            # on `n[0]` / `n.split()[-1]`.
            ret.append("")
        elif is_chinese(n):
            ret.append(n[0])
        else:
            ret.append(n.split()[-1])
    return ret
|
151
|
+
|
152
|
+
def as_auther_seq(self) -> str:
|
153
|
+
"""Get the auther sequence."""
|
154
|
+
match len(self.authors):
|
155
|
+
case 0:
|
156
|
+
raise ValueError("No authors found")
|
157
|
+
case 1:
|
158
|
+
return f"({self.auther_firstnames[0]},{self.year}){self.as_typst_cite()}"
|
159
|
+
case 2:
|
160
|
+
return f"({self.auther_firstnames[0]}{self.and_word}{self.auther_firstnames[1]},{self.year}){self.as_typst_cite()}"
|
161
|
+
case 3:
|
162
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]},{self.year}){self.as_typst_cite()}"
|
163
|
+
case _:
|
164
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]}{self.etc_word},{self.year}){self.as_typst_cite()}"
|
165
|
+
|
166
|
+
def update_cite_number(self, cite_number: int) -> Self:
|
167
|
+
"""Update the cite number."""
|
168
|
+
self._cite_number = cite_number
|
169
|
+
return self
|
170
|
+
|
171
|
+
|
172
|
+
class CitationManager(AsPrompt):
    """Citation manager.

    Keeps an ordered list of article chunks, assigns each one a 1-based cite
    number, and rewrites numeric citation markers (by default ``[[1,3-5]]``)
    in a text into the typst cite calls of the referenced chunks.
    """

    article_chunks: List[ArticleChunk] = Field(default_factory=list)
    """Article chunks."""

    pat: str = r"\[\[([\d\s,-]*)]]"
    """Regex pattern to match citations."""
    sep: str = ","
    """Separator for citation numbers."""
    abbr_sep: str = "-"
    """Separator for abbreviated citation numbers."""

    def update_chunks(self, article_chunks: List[ArticleChunk], set_cite_number: bool = True) -> Self:
        """Replace the managed chunks with ``article_chunks``.

        Args:
            article_chunks: The chunks to manage from now on.
            set_cite_number: Whether to renumber all chunks afterwards.

        Returns:
            This manager, for chaining.
        """
        self.article_chunks.clear()
        self.article_chunks.extend(article_chunks)
        if set_cite_number:
            self.set_cite_number_all()
        return self

    def set_cite_number_all(self) -> Self:
        """Number every managed chunk by its position, starting at 1."""
        for i, a in enumerate(self.article_chunks, 1):
            a.update_cite_number(i)
        return self

    def _as_prompt_inner(self) -> Dict[str, str]:
        """Generate prompt inner representation."""
        return {"References": "\n".join(r.as_prompt() for r in self.article_chunks)}

    def apply(self, string: str) -> str:
        """Apply citation replacements to the input string.

        Every span matched by ``pat`` is replaced as a whole by the typst cite
        calls of the chunks it references; group 1 of ``pat`` supplies the
        citation expression. Numbers that match no managed chunk are dropped.

        Args:
            string: Text containing citation markers.

        Returns:
            A new string with the markers replaced.

        Raises:
            ValueError: If a marker holds a malformed expression (see
                ``decode_expr``).
        """

        def _render(match: "re.Match[str]") -> str:
            notations = self.convert_to_numeric_notations(match.group(1))
            citation_number_seq = list(flatten(self.decode_expr(n) for n in notations))
            return self.unpack_cite_seq(self.deduplicate_citation(citation_number_seq))

        # `str` is immutable, so the replacement must be returned; `re.sub`
        # with a callback also swaps the entire marker, brackets included,
        # and only at the positions where the pattern actually matched.
        return re.sub(self.pat, _render, string)

    def decode_expr(self, string: str) -> List[int]:
        """Decode citation expression into a list of integers.

        ``"3"`` becomes ``[3]`` and ``"3-5"`` becomes ``[3, 4, 5]``. A blank
        expression (from an empty marker or a trailing separator) decodes to
        an empty list.

        Raises:
            ValueError: If the expression is not an integer or a
                ``start<abbr_sep>end`` pair of integers.
        """
        if not string.strip():
            return []
        if self.abbr_sep in string:
            start, end = string.split(self.abbr_sep)
            return list(range(int(start), int(end) + 1))
        return [int(string)]

    def convert_to_numeric_notations(self, string: str) -> List[str]:
        """Split a citation string on ``sep`` and strip each piece."""
        return [s.strip() for s in string.split(self.sep)]

    def deduplicate_citation(self, citation_seq: List[int]) -> List[int]:
        """Deduplicate citation sequence.

        Only numbers that belong to a managed chunk are kept.
        """
        wanted = set(citation_seq)  # one hash lookup per chunk instead of a list scan
        chunk_seq = [a for a in self.article_chunks if a.cite_number in wanted]
        deduped = unique(chunk_seq, lambda a: a.cite_number)
        return [a.cite_number for a in deduped]

    def unpack_cite_seq(self, citation_seq: List[int]) -> str:
        """Concatenate the typst cite calls of the chunks in ``citation_seq``.

        The calls are emitted in the order the chunks are managed.
        """
        wanted = set(citation_seq)
        chunk_seq = [a for a in self.article_chunks if a.cite_number in wanted]
        return "".join(a.as_typst_cite() for a in chunk_seq)
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from typing import Dict, Generator, List, Self, Tuple, override
|
4
4
|
|
5
|
+
from fabricatio.decorators import precheck_package
|
5
6
|
from fabricatio.fs.readers import extract_sections
|
6
7
|
from fabricatio.journal import logger
|
7
8
|
from fabricatio.models.extra.article_base import (
|
@@ -14,7 +15,7 @@ from fabricatio.models.extra.article_outline import (
|
|
14
15
|
ArticleOutline,
|
15
16
|
)
|
16
17
|
from fabricatio.models.generic import Described, PersistentAble, SequencePatch, SketchedAble, WithRef, WordCount
|
17
|
-
from fabricatio.rust import word_count
|
18
|
+
from fabricatio.rust import convert_all_block_tex, convert_all_inline_tex, word_count
|
18
19
|
from pydantic import Field
|
19
20
|
|
20
21
|
PARAGRAPH_SEP = "// - - -"
|
@@ -153,6 +154,26 @@ class Article(
|
|
153
154
|
"Original Article": self.display(),
|
154
155
|
}
|
155
156
|
|
157
|
+
def convert_tex(self) -> Self:
|
158
|
+
"""Convert tex to typst code."""
|
159
|
+
for _, _, subsec in self.iter_subsections():
|
160
|
+
for p in subsec.paragraphs:
|
161
|
+
p.content = convert_all_inline_tex(p.content)
|
162
|
+
p.content = convert_all_block_tex(p.content)
|
163
|
+
return self
|
164
|
+
|
165
|
+
def fix_wrapper(self) -> Self:
|
166
|
+
"""Fix wrapper."""
|
167
|
+
for _, _, subsec in self.iter_subsections():
|
168
|
+
for p in subsec.paragraphs:
|
169
|
+
p.content = (
|
170
|
+
p.content.replace(r" \( ", "$")
|
171
|
+
.replace(r" \) ", "$")
|
172
|
+
.replace("\\[\n", "$$\n")
|
173
|
+
.replace("\n\\]", "\n$$")
|
174
|
+
)
|
175
|
+
return self
|
176
|
+
|
156
177
|
@override
|
157
178
|
def iter_subsections(self) -> Generator[Tuple[ArticleChapter, ArticleSection, ArticleSubsection], None, None]:
|
158
179
|
return super().iter_subsections() # pyright: ignore [reportReturnType]
|
@@ -204,3 +225,20 @@ class Article(
|
|
204
225
|
expected_word_count=word_count(body),
|
205
226
|
abstract="",
|
206
227
|
)
|
228
|
+
|
229
|
+
@classmethod
|
230
|
+
def from_mixed_source(cls, article_outline: ArticleOutline, typst_code: str) -> Self:
|
231
|
+
"""Generates an article from the given outline and Typst code."""
|
232
|
+
self = cls.from_typst_code(article_outline.title, typst_code)
|
233
|
+
self.expected_word_count = article_outline.expected_word_count
|
234
|
+
self.description = article_outline.description
|
235
|
+
for a, o in zip(self.iter_dfs(), article_outline.iter_dfs(), strict=True):
|
236
|
+
a.update_metadata(o)
|
237
|
+
return self.update_ref(article_outline)
|
238
|
+
|
239
|
+
@precheck_package(
    "questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
)
def edit_titles(self) -> Self:
    """Walk the article tree as a placeholder for interactive title editing.

    No title is changed yet: the loop only iterates over ``self.iter_dfs()``.

    Returns:
        This article, so the call can be chained as the annotation promises.
    """
    # TODO: prompt for a new title on each component once editing is implemented.
    for _ in self.iter_dfs():
        pass
    return self
|
@@ -3,12 +3,12 @@
|
|
3
3
|
from itertools import chain
|
4
4
|
from typing import Any, List, Optional, Self, Tuple, Unpack
|
5
5
|
|
6
|
+
from pydantic import Field
|
7
|
+
from rich import print as r_print
|
8
|
+
|
6
9
|
from fabricatio.journal import logger
|
7
10
|
from fabricatio.models.generic import SketchedAble, WithBriefing
|
8
11
|
from fabricatio.utils import ask_edit
|
9
|
-
from pydantic import Field
|
10
|
-
from questionary import Choice, checkbox, text
|
11
|
-
from rich import print as r_print
|
12
12
|
|
13
13
|
|
14
14
|
class Problem(SketchedAble, WithBriefing):
|
@@ -74,6 +74,8 @@ class ProblemSolutions(SketchedAble):
|
|
74
74
|
return len(self.solutions) > 0
|
75
75
|
|
76
76
|
async def edit_problem(self) -> Self:
|
77
|
+
from questionary import text
|
78
|
+
|
77
79
|
"""Interactively edit the problem description."""
|
78
80
|
self.problem = Problem.model_validate_strings(
|
79
81
|
await text("Please edit the problem below:", default=self.problem.display()).ask_async()
|
@@ -127,6 +129,8 @@ class Improvement(SketchedAble):
|
|
127
129
|
Returns:
|
128
130
|
Self: The current instance with filtered problems and solutions.
|
129
131
|
"""
|
132
|
+
from questionary import Choice, checkbox
|
133
|
+
|
130
134
|
# Choose the problems to retain
|
131
135
|
chosen_ones: List[ProblemSolutions] = await checkbox(
|
132
136
|
"Please choose the problems you want to retain.(Default: retain all)",
|
fabricatio/models/generic.py
CHANGED
@@ -3,16 +3,10 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from datetime import datetime
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
|
6
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
|
7
7
|
|
8
|
-
import
|
9
|
-
from fabricatio.config import configs
|
10
|
-
from fabricatio.fs.readers import MAGIKA, safe_text_read
|
11
|
-
from fabricatio.journal import logger
|
12
|
-
from fabricatio.parser import JsonCapture
|
8
|
+
import ujson
|
13
9
|
from fabricatio.rust import blake3_hash, detect_language
|
14
|
-
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
15
|
-
from fabricatio.utils import ok
|
16
10
|
from litellm.utils import token_counter
|
17
11
|
from pydantic import (
|
18
12
|
BaseModel,
|
@@ -27,6 +21,13 @@ from pydantic import (
|
|
27
21
|
)
|
28
22
|
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
|
29
23
|
|
24
|
+
from fabricatio.config import configs
|
25
|
+
from fabricatio.fs.readers import safe_text_read
|
26
|
+
from fabricatio.journal import logger
|
27
|
+
from fabricatio.parser import JsonCapture
|
28
|
+
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
29
|
+
from fabricatio.utils import ok
|
30
|
+
|
30
31
|
|
31
32
|
class Base(BaseModel):
|
32
33
|
"""Base class for all models with Pydantic configuration.
|
@@ -74,9 +75,9 @@ class Display(Base):
|
|
74
75
|
str: Combined display output with boundary markers
|
75
76
|
"""
|
76
77
|
return (
|
77
|
-
|
78
|
-
|
79
|
-
|
78
|
+
"--- Start of Extra Info Sequence ---"
|
79
|
+
+ "\n".join(d.compact() if compact else d.display() for d in seq)
|
80
|
+
+ "--- End of Extra Info Sequence ---"
|
80
81
|
)
|
81
82
|
|
82
83
|
|
@@ -117,6 +118,15 @@ class WordCount(Base):
|
|
117
118
|
"""Expected word count of this research component."""
|
118
119
|
|
119
120
|
|
121
|
+
class FromMapping(Base):
|
122
|
+
"""Class that provides a method to generate a list of objects from a mapping."""
|
123
|
+
|
124
|
+
@classmethod
|
125
|
+
@abstractmethod
|
126
|
+
def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
|
127
|
+
"""Generate a list of objects from a mapping."""
|
128
|
+
|
129
|
+
|
120
130
|
class AsPrompt(Base):
|
121
131
|
"""Class that provides a method to generate a prompt from the model.
|
122
132
|
|
@@ -169,11 +179,17 @@ class WithRef[T](Base):
|
|
169
179
|
)
|
170
180
|
|
171
181
|
@overload
|
172
|
-
def update_ref[S: WithRef](self: S, reference: T) -> S:
|
182
|
+
def update_ref[S: WithRef](self: S, reference: T) -> S:
|
183
|
+
...
|
184
|
+
|
173
185
|
@overload
|
174
|
-
def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S:
|
186
|
+
def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S:
|
187
|
+
...
|
188
|
+
|
175
189
|
@overload
|
176
|
-
def update_ref[S: WithRef](self: S, reference: None = None) -> S:
|
190
|
+
def update_ref[S: WithRef](self: S, reference: None = None) -> S:
|
191
|
+
...
|
192
|
+
|
177
193
|
def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S: # noqa: PYI019
|
178
194
|
"""Update the reference of the object.
|
179
195
|
|
@@ -455,9 +471,9 @@ class WithFormatedJsonSchema(Base):
|
|
455
471
|
Returns:
|
456
472
|
str: The JSON schema of the model in a formatted string.
|
457
473
|
"""
|
458
|
-
return
|
474
|
+
return ujson.dumps(
|
459
475
|
cls.model_json_schema(schema_generator=UnsortGenerate),
|
460
|
-
option=
|
476
|
+
option=ujson.OPT_INDENT_2,
|
461
477
|
).decode()
|
462
478
|
|
463
479
|
|
@@ -470,9 +486,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
|
470
486
|
@classmethod
|
471
487
|
@overload
|
472
488
|
def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
|
489
|
+
|
473
490
|
@classmethod
|
474
491
|
@overload
|
475
492
|
def create_json_prompt(cls, requirement: str) -> str: ...
|
493
|
+
|
476
494
|
@classmethod
|
477
495
|
def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
|
478
496
|
"""Create the prompt for creating a JSON object with given requirement.
|
@@ -639,6 +657,8 @@ class WithDependency(Base):
|
|
639
657
|
Returns:
|
640
658
|
str: The generated prompt for the task.
|
641
659
|
"""
|
660
|
+
from fabricatio.fs import MAGIKA
|
661
|
+
|
642
662
|
return TEMPLATE_MANAGER.render_template(
|
643
663
|
configs.templates.dependencies_template,
|
644
664
|
{
|
@@ -734,6 +754,12 @@ class ScopedConfig(Base):
|
|
734
754
|
llm_rpm: Optional[PositiveInt] = None
|
735
755
|
"""The requests per minute of the LLM model."""
|
736
756
|
|
757
|
+
llm_presence_penalty: Optional[PositiveFloat] = None
|
758
|
+
"""The presence penalty of the LLM model."""
|
759
|
+
|
760
|
+
llm_frequency_penalty: Optional[PositiveFloat] = None
|
761
|
+
"""The frequency penalty of the LLM model."""
|
762
|
+
|
737
763
|
embedding_api_endpoint: Optional[HttpUrl] = None
|
738
764
|
"""The OpenAI API endpoint."""
|
739
765
|
|
@@ -858,13 +884,14 @@ class Patch[T](ProposedAble):
|
|
858
884
|
# copy the desc info of each corresponding fields from `ref_cls`
|
859
885
|
for field_name in [f for f in cls.model_fields if f in ref_cls.model_fields]:
|
860
886
|
my_schema["properties"][field_name]["description"] = (
|
861
|
-
|
887
|
+
ref_cls.model_fields[field_name].description or my_schema["properties"][field_name][
|
888
|
+
"description"]
|
862
889
|
)
|
863
890
|
my_schema["description"] = ref_cls.__doc__
|
864
891
|
|
865
|
-
return
|
892
|
+
return ujson.dumps(
|
866
893
|
my_schema,
|
867
|
-
option=
|
894
|
+
option=ujson.OPT_INDENT_2,
|
868
895
|
).decode()
|
869
896
|
|
870
897
|
|
@@ -45,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
|
|
45
45
|
no_store: bool # If store the response of this call to cache
|
46
46
|
cache_ttl: int # how long the stored cache is alive, in seconds
|
47
47
|
s_maxage: int # max accepted age of cached response, in seconds
|
48
|
+
presence_penalty: float
|
49
|
+
frequency_penalty: float
|
48
50
|
|
49
51
|
|
50
52
|
class GenerateKwargs(LLMKwargs, total=False):
|
@@ -66,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
|
|
66
68
|
|
67
69
|
default: Optional[T]
|
68
70
|
max_validations: int
|
69
|
-
|
71
|
+
|
70
72
|
|
71
73
|
|
72
74
|
class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
|
fabricatio/models/usages.py
CHANGED
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
|
|
63
63
|
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
64
64
|
return ROUTER
|
65
65
|
|
66
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
66
|
+
# noinspection PyTypeChecker,PydanticTypeChecker,t
|
67
67
|
async def aquery(
|
68
68
|
self,
|
69
69
|
messages: List[Dict[str, str]],
|
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
|
|
122
122
|
"cache-ttl": kwargs.get("cache_ttl"),
|
123
123
|
"s-maxage": kwargs.get("s_maxage"),
|
124
124
|
},
|
125
|
+
presence_penalty=kwargs.get("presence_penalty")
|
126
|
+
or self.llm_presence_penalty
|
127
|
+
or configs.llm.presence_penalty,
|
128
|
+
frequency_penalty=kwargs.get("frequency_penalty")
|
129
|
+
or self.llm_frequency_penalty
|
130
|
+
or configs.llm.frequency_penalty,
|
125
131
|
)
|
126
132
|
|
127
133
|
async def ainvoke(
|
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
|
|
236
242
|
validator: Callable[[str], T | None],
|
237
243
|
default: T = ...,
|
238
244
|
max_validations: PositiveInt = 2,
|
239
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
240
245
|
**kwargs: Unpack[GenerateKwargs],
|
241
246
|
) -> T: ...
|
242
247
|
@overload
|
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
|
|
246
251
|
validator: Callable[[str], T | None],
|
247
252
|
default: T = ...,
|
248
253
|
max_validations: PositiveInt = 2,
|
249
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
250
254
|
**kwargs: Unpack[GenerateKwargs],
|
251
255
|
) -> List[T]: ...
|
252
256
|
@overload
|
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
|
|
256
260
|
validator: Callable[[str], T | None],
|
257
261
|
default: None = None,
|
258
262
|
max_validations: PositiveInt = 2,
|
259
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
260
263
|
**kwargs: Unpack[GenerateKwargs],
|
261
264
|
) -> Optional[T]: ...
|
262
265
|
|
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
|
|
267
270
|
validator: Callable[[str], T | None],
|
268
271
|
default: None = None,
|
269
272
|
max_validations: PositiveInt = 2,
|
270
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
271
273
|
**kwargs: Unpack[GenerateKwargs],
|
272
274
|
) -> List[Optional[T]]: ...
|
273
275
|
|
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
|
|
277
279
|
validator: Callable[[str], T | None],
|
278
280
|
default: Optional[T] = None,
|
279
281
|
max_validations: PositiveInt = 3,
|
280
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
281
282
|
**kwargs: Unpack[GenerateKwargs],
|
282
283
|
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
283
284
|
"""Asynchronously asks a question and validates the response using a given validator.
|
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
|
|
287
288
|
validator (Callable[[str], T | None]): A function to validate the response.
|
288
289
|
default (T | None): Default value to return if validation fails. Defaults to None.
|
289
290
|
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
|
290
|
-
co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
|
291
291
|
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
292
292
|
|
293
293
|
Returns:
|
294
|
-
Optional[T] | List[
|
294
|
+
Optional[T] | List[T | None] | List[T] | T: The validated response.
|
295
295
|
"""
|
296
296
|
|
297
297
|
async def _inner(q: str) -> Optional[T]:
|
298
298
|
for lap in range(max_validations):
|
299
299
|
try:
|
300
|
-
if (
|
301
|
-
co_extractor is not None
|
302
|
-
and logger.debug("Co-extraction is enabled.") is None
|
303
|
-
and (
|
304
|
-
validated := validator(
|
305
|
-
response := await self.aask(
|
306
|
-
question=(
|
307
|
-
TEMPLATE_MANAGER.render_template(
|
308
|
-
configs.templates.co_validation_template,
|
309
|
-
{"original_q": q, "original_a": response},
|
310
|
-
)
|
311
|
-
),
|
312
|
-
**co_extractor,
|
313
|
-
)
|
314
|
-
)
|
315
|
-
)
|
316
|
-
is not None
|
317
|
-
):
|
300
|
+
if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
|
318
301
|
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
319
302
|
return validated
|
320
303
|
|
fabricatio/parser.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
1
|
"""A module to parse text using regular expressions."""
|
2
2
|
|
3
|
+
import re
|
4
|
+
from functools import lru_cache
|
5
|
+
from re import Pattern, compile
|
3
6
|
from typing import Any, Callable, Iterable, List, Optional, Self, Tuple, Type
|
4
7
|
|
5
|
-
import
|
6
|
-
import regex
|
8
|
+
import ujson
|
7
9
|
from json_repair import repair_json
|
8
10
|
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr, ValidationError
|
9
|
-
from regex import Pattern, compile
|
10
11
|
|
11
12
|
from fabricatio.config import configs
|
12
13
|
from fabricatio.journal import logger
|
@@ -25,7 +26,7 @@ class Capture(BaseModel):
|
|
25
26
|
"""The target groups to capture from the pattern."""
|
26
27
|
pattern: str = Field(frozen=True)
|
27
28
|
"""The regular expression pattern to search for."""
|
28
|
-
flags: PositiveInt = Field(default=
|
29
|
+
flags: PositiveInt = Field(default=re.DOTALL | re.MULTILINE | re.IGNORECASE, frozen=True)
|
29
30
|
"""The flags to use when compiling the regular expression pattern."""
|
30
31
|
capture_type: Optional[str] = None
|
31
32
|
"""The type of capture to perform, e.g., 'json', which is used to dispatch the fixer accordingly."""
|
@@ -49,7 +50,8 @@ class Capture(BaseModel):
|
|
49
50
|
logger.debug("Applying json repair to text.")
|
50
51
|
if isinstance(text, str):
|
51
52
|
return repair_json(text, ensure_ascii=False) # pyright: ignore [reportReturnType]
|
52
|
-
return [repair_json(item, ensure_ascii=False) for item in
|
53
|
+
return [repair_json(item, ensure_ascii=False) for item in
|
54
|
+
text] # pyright: ignore [reportReturnType, reportGeneralTypeIssues]
|
53
55
|
case _:
|
54
56
|
return text # pyright: ignore [reportReturnType]
|
55
57
|
|
@@ -63,7 +65,7 @@ class Capture(BaseModel):
|
|
63
65
|
str | None: The captured text if the pattern is found, otherwise None.
|
64
66
|
|
65
67
|
"""
|
66
|
-
if (match :=self._compiled.match(text) or self._compiled.search(text)
|
68
|
+
if (match := self._compiled.match(text) or self._compiled.search(text)) is None:
|
67
69
|
logger.debug(f"Capture Failed {type(text)}: \n{text}")
|
68
70
|
return None
|
69
71
|
groups = self.fix(match.groups())
|
@@ -94,12 +96,12 @@ class Capture(BaseModel):
|
|
94
96
|
return None
|
95
97
|
|
96
98
|
def validate_with[K, T, E](
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
99
|
+
self,
|
100
|
+
text: str,
|
101
|
+
target_type: Type[T],
|
102
|
+
elements_type: Optional[Type[E]] = None,
|
103
|
+
length: Optional[int] = None,
|
104
|
+
deserializer: Callable[[Tuple[str, ...]], K] | Callable[[str], K] = ujson.loads,
|
103
105
|
) -> T | None:
|
104
106
|
"""Validate the given text using the pattern.
|
105
107
|
|
@@ -124,6 +126,7 @@ class Capture(BaseModel):
|
|
124
126
|
return None
|
125
127
|
|
126
128
|
@classmethod
|
129
|
+
@lru_cache(32)
|
127
130
|
def capture_code_block(cls, language: str) -> Self:
|
128
131
|
"""Capture the first occurrence of a code block in the given text.
|
129
132
|
|
@@ -136,6 +139,7 @@ class Capture(BaseModel):
|
|
136
139
|
return cls(pattern=f"```{language}(.*?)```", capture_type=language)
|
137
140
|
|
138
141
|
@classmethod
|
142
|
+
@lru_cache(32)
|
139
143
|
def capture_generic_block(cls, language: str) -> Self:
|
140
144
|
"""Capture the first occurrence of a generic code block in the given text.
|
141
145
|
|
Binary file
|