fabricatio 0.2.10.dev1__cp312-cp312-win_amd64.whl → 0.2.11.dev1__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +38 -9
- fabricatio/actions/article_rag.py +115 -13
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +3 -3
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +70 -0
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +30 -30
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/extra/aricle_rag.py +125 -9
- fabricatio/models/extra/article_main.py +46 -2
- fabricatio/models/extra/problem.py +5 -1
- fabricatio/models/generic.py +29 -11
- fabricatio/models/kwargs_types.py +3 -1
- fabricatio/models/usages.py +9 -26
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +124 -6
- fabricatio/utils.py +11 -3
- fabricatio-0.2.11.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev1.dist-info}/METADATA +18 -9
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev1.dist-info}/RECORD +28 -26
- fabricatio-0.2.10.dev1.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev1.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev1.dist-info → fabricatio-0.2.11.dev1.dist-info}/licenses/LICENSE +0 -0
fabricatio/decorators.py
CHANGED
@@ -8,14 +8,34 @@ from shutil import which
|
|
8
8
|
from types import ModuleType
|
9
9
|
from typing import Callable, List, Optional
|
10
10
|
|
11
|
-
from questionary import confirm
|
12
|
-
|
13
11
|
from fabricatio.config import configs
|
14
12
|
from fabricatio.journal import logger
|
15
13
|
|
16
14
|
|
15
|
+
def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
16
|
+
"""Check if a package exists in the current environment.
|
17
|
+
|
18
|
+
Args:
|
19
|
+
package_name (str): The name of the package to check.
|
20
|
+
msg (str): The message to display if the package is not found.
|
21
|
+
|
22
|
+
Returns:
|
23
|
+
bool: True if the package exists, False otherwise.
|
24
|
+
"""
|
25
|
+
|
26
|
+
def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
|
27
|
+
def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
28
|
+
if find_spec(package_name):
|
29
|
+
return func(*args, **kwargs)
|
30
|
+
raise RuntimeError(msg)
|
31
|
+
|
32
|
+
return _inner
|
33
|
+
|
34
|
+
return _wrapper
|
35
|
+
|
36
|
+
|
17
37
|
def depend_on_external_cmd[**P, R](
|
18
|
-
|
38
|
+
bin_name: str, install_tip: Optional[str], homepage: Optional[str] = None
|
19
39
|
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
20
40
|
"""Decorator to check for the presence of an external command.
|
21
41
|
|
@@ -68,6 +88,8 @@ def logging_execution_info[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
68
88
|
return _wrapper
|
69
89
|
|
70
90
|
|
91
|
+
@precheck_package("questionary",
|
92
|
+
"'questionary' is required to run this function. Have you installed `fabricatio[qa]`?.")
|
71
93
|
def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]] | Callable[P, R]:
|
72
94
|
"""Decorator to confirm before executing a function.
|
73
95
|
|
@@ -80,14 +102,15 @@ def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]]
|
|
80
102
|
if not configs.general.confirm_on_ops:
|
81
103
|
# Skip confirmation if the configuration is set to False
|
82
104
|
return func
|
105
|
+
from questionary import confirm
|
83
106
|
|
84
107
|
if iscoroutinefunction(func):
|
85
108
|
|
86
109
|
@wraps(func)
|
87
110
|
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
|
88
111
|
if await confirm(
|
89
|
-
|
90
|
-
|
112
|
+
f"Are you sure to execute function: {func.__name__}{signature(func)} \n📦 Args:{args}\n🔑 Kwargs:{kwargs}\n",
|
113
|
+
instruction="Please input [Yes/No] to proceed (default: Yes):",
|
91
114
|
).ask_async():
|
92
115
|
return await func(*args, **kwargs)
|
93
116
|
logger.warning(f"Function: {func.__name__}{signature(func)} canceled by user.")
|
@@ -98,8 +121,8 @@ def confirm_to_execute[**P, R](func: Callable[P, R]) -> Callable[P, Optional[R]]
|
|
98
121
|
@wraps(func)
|
99
122
|
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
|
100
123
|
if confirm(
|
101
|
-
|
102
|
-
|
124
|
+
f"Are you sure to execute function: {func.__name__}{signature(func)} \n📦 Args:{args}\n🔑 Kwargs:{kwargs}\n",
|
125
|
+
instruction="Please input [Yes/No] to proceed (default: Yes):",
|
103
126
|
).ask():
|
104
127
|
return func(*args, **kwargs)
|
105
128
|
logger.warning(f"Function: {func.__name__}{signature(func)} canceled by user.")
|
@@ -192,7 +215,6 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
192
215
|
from time import time
|
193
216
|
|
194
217
|
if iscoroutinefunction(func):
|
195
|
-
|
196
218
|
@wraps(func)
|
197
219
|
async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
198
220
|
start_time = time()
|
@@ -210,25 +232,3 @@ def logging_exec_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
|
|
210
232
|
return result
|
211
233
|
|
212
234
|
return _wrapper
|
213
|
-
|
214
|
-
|
215
|
-
def precheck_package[**P, R](package_name: str, msg: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
216
|
-
"""Check if a package exists in the current environment.
|
217
|
-
|
218
|
-
Args:
|
219
|
-
package_name (str): The name of the package to check.
|
220
|
-
msg (str): The message to display if the package is not found.
|
221
|
-
|
222
|
-
Returns:
|
223
|
-
bool: True if the package exists, False otherwise.
|
224
|
-
"""
|
225
|
-
|
226
|
-
def _wrapper(func: Callable[P, R]) -> Callable[P, R]:
|
227
|
-
def _inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
228
|
-
if find_spec(package_name):
|
229
|
-
return func(*args, **kwargs)
|
230
|
-
raise RuntimeError(msg)
|
231
|
-
|
232
|
-
return _inner
|
233
|
-
|
234
|
-
return _wrapper
|
fabricatio/fs/__init__.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""FileSystem manipulation module for Fabricatio."""
|
2
|
+
from importlib.util import find_spec
|
2
3
|
|
4
|
+
from fabricatio.config import configs
|
3
5
|
from fabricatio.fs.curd import (
|
4
6
|
absolute_path,
|
5
7
|
copy_file,
|
@@ -11,10 +13,9 @@ from fabricatio.fs.curd import (
|
|
11
13
|
move_file,
|
12
14
|
tree,
|
13
15
|
)
|
14
|
-
from fabricatio.fs.readers import
|
16
|
+
from fabricatio.fs.readers import safe_json_read, safe_text_read
|
15
17
|
|
16
18
|
__all__ = [
|
17
|
-
"MAGIKA",
|
18
19
|
"absolute_path",
|
19
20
|
"copy_file",
|
20
21
|
"create_directory",
|
@@ -27,3 +28,9 @@ __all__ = [
|
|
27
28
|
"safe_text_read",
|
28
29
|
"tree",
|
29
30
|
]
|
31
|
+
|
32
|
+
if find_spec("magika"):
|
33
|
+
from magika import Magika
|
34
|
+
|
35
|
+
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
36
|
+
__all__ += ["MAGIKA"]
|
fabricatio/fs/readers.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1
1
|
"""Filesystem readers for Fabricatio."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
5
|
from typing import Dict, List, Tuple
|
5
6
|
|
6
|
-
import
|
7
|
-
import regex
|
8
|
-
from magika import Magika
|
7
|
+
import ujson
|
9
8
|
|
10
|
-
from fabricatio.config import configs
|
11
9
|
from fabricatio.journal import logger
|
12
10
|
|
13
|
-
MAGIKA = Magika(model_dir=configs.magika.model_dir)
|
14
|
-
|
15
11
|
|
16
12
|
def safe_text_read(path: Path | str) -> str:
|
17
13
|
"""Safely read the text from a file.
|
@@ -41,8 +37,8 @@ def safe_json_read(path: Path | str) -> Dict:
|
|
41
37
|
"""
|
42
38
|
path = Path(path)
|
43
39
|
try:
|
44
|
-
return
|
45
|
-
except (
|
40
|
+
return ujson.loads(path.read_text(encoding="utf-8"))
|
41
|
+
except (ujson.JSONDecodeError, IsADirectoryError, FileNotFoundError) as e:
|
46
42
|
logger.error(f"Failed to read file {path}: {e!s}")
|
47
43
|
return {}
|
48
44
|
|
@@ -58,8 +54,8 @@ def extract_sections(string: str, level: int, section_char: str = "#") -> List[T
|
|
58
54
|
Returns:
|
59
55
|
List[Tuple[str, str]]: List of (header_text, section_content) tuples
|
60
56
|
"""
|
61
|
-
return
|
57
|
+
return re.findall(
|
62
58
|
r"^%s{%d}\s+(.+?)\n((?:(?!^%s{%d}\s).|\n)*)" % (section_char, level, section_char, level),
|
63
59
|
string,
|
64
|
-
|
60
|
+
re.MULTILINE,
|
65
61
|
)
|
fabricatio/models/extra/aricle_rag.py
CHANGED
@@ -1,22 +1,27 @@
|
|
1
1
|
"""A Module containing the article rag models."""
|
2
2
|
|
3
|
+
import re
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import ClassVar, Dict, List, Self, Unpack
|
5
|
+
from typing import ClassVar, Dict, List, Optional, Self, Unpack
|
5
6
|
|
6
7
|
from fabricatio.fs import safe_text_read
|
7
8
|
from fabricatio.journal import logger
|
8
9
|
from fabricatio.models.extra.rag import MilvusDataBase
|
9
10
|
from fabricatio.models.generic import AsPrompt
|
10
11
|
from fabricatio.models.kwargs_types import ChunkKwargs
|
11
|
-
from fabricatio.rust import BibManager, split_into_chunks
|
12
|
-
from fabricatio.utils import ok
|
13
|
-
from more_itertools.recipes import flatten
|
12
|
+
from fabricatio.rust import BibManager, is_chinese, split_into_chunks
|
13
|
+
from fabricatio.utils import ok
|
14
|
+
from more_itertools.recipes import flatten, unique
|
14
15
|
from pydantic import Field
|
15
16
|
|
16
17
|
|
17
18
|
class ArticleChunk(MilvusDataBase, AsPrompt):
|
18
19
|
"""The chunk of an article."""
|
19
20
|
|
21
|
+
etc_word: ClassVar[str] = "等"
|
22
|
+
and_word: ClassVar[str] = "与"
|
23
|
+
_cite_number: Optional[int] = None
|
24
|
+
|
20
25
|
head_split: ClassVar[List[str]] = [
|
21
26
|
"引 言",
|
22
27
|
"引言",
|
@@ -48,12 +53,14 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
48
53
|
|
49
54
|
def _as_prompt_inner(self) -> Dict[str, str]:
|
50
55
|
return {
|
51
|
-
|
52
|
-
f"Authors: {';'.join(self.authors)}\n"
|
53
|
-
f"Published Year: {self.year}\n"
|
54
|
-
f"Bibtex Key: {self.bibtex_cite_key}\n",
|
56
|
+
f"[[{ok(self._cite_number, 'You need to update cite number first.')}]] reference `{self.article_title}`": self.chunk
|
55
57
|
}
|
56
58
|
|
59
|
+
@property
|
60
|
+
def cite_number(self) -> int:
|
61
|
+
"""Get the cite number."""
|
62
|
+
return ok(self._cite_number, "cite number not set")
|
63
|
+
|
57
64
|
def _prepare_vectorization_inner(self) -> str:
|
58
65
|
return self.chunk
|
59
66
|
|
@@ -89,8 +96,9 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
89
96
|
|
90
97
|
result = [
|
91
98
|
cls(chunk=c, year=year, authors=authors, article_title=article_title, bibtex_cite_key=key)
|
92
|
-
for c in split_into_chunks(cls.strip(safe_text_read(path)), **kwargs)
|
99
|
+
for c in split_into_chunks(cls.purge_numeric_citation(cls.strip(safe_text_read(path))), **kwargs)
|
93
100
|
]
|
101
|
+
|
94
102
|
logger.debug(f"Number of chunks created from file {path.as_posix()}: {len(result)}")
|
95
103
|
return result
|
96
104
|
|
@@ -118,3 +126,111 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
118
126
|
logger.warning("No decrease at tail strip, which is might be abnormal.")
|
119
127
|
|
120
128
|
return string
|
129
|
+
|
130
|
+
def as_typst_cite(self) -> str:
|
131
|
+
"""As typst cite."""
|
132
|
+
return f"#cite(<{self.bibtex_cite_key}>)"
|
133
|
+
|
134
|
+
@staticmethod
|
135
|
+
def purge_numeric_citation(string: str) -> str:
|
136
|
+
"""Purge numeric citation."""
|
137
|
+
import re
|
138
|
+
|
139
|
+
return re.sub(r"\[[\d\s,\\~–-]+]", "", string)
|
140
|
+
|
141
|
+
@property
|
142
|
+
def auther_firstnames(self) -> List[str]:
|
143
|
+
"""Get the first name of the authors."""
|
144
|
+
ret = []
|
145
|
+
for n in self.authors:
|
146
|
+
if is_chinese(n):
|
147
|
+
ret.append(n[0])
|
148
|
+
else:
|
149
|
+
ret.append(n.split()[-1])
|
150
|
+
return ret
|
151
|
+
|
152
|
+
def as_auther_seq(self) -> str:
|
153
|
+
"""Get the auther sequence."""
|
154
|
+
match len(self.authors):
|
155
|
+
case 0:
|
156
|
+
raise ValueError("No authors found")
|
157
|
+
case 1:
|
158
|
+
return f"({self.auther_firstnames[0]},{self.year}){self.as_typst_cite()}"
|
159
|
+
case 2:
|
160
|
+
return f"({self.auther_firstnames[0]}{self.and_word}{self.auther_firstnames[1]},{self.year}){self.as_typst_cite()}"
|
161
|
+
case 3:
|
162
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]},{self.year}){self.as_typst_cite()}"
|
163
|
+
case _:
|
164
|
+
return f"({self.auther_firstnames[0]},{self.auther_firstnames[1]}{self.and_word}{self.auther_firstnames[2]}{self.etc_word},{self.year}){self.as_typst_cite()}"
|
165
|
+
|
166
|
+
def update_cite_number(self, cite_number: int) -> Self:
|
167
|
+
"""Update the cite number."""
|
168
|
+
self._cite_number = cite_number
|
169
|
+
return self
|
170
|
+
|
171
|
+
|
172
|
+
class CitationManager(AsPrompt):
|
173
|
+
"""Citation manager."""
|
174
|
+
|
175
|
+
article_chunks: List[ArticleChunk] = Field(default_factory=list)
|
176
|
+
"""Article chunks."""
|
177
|
+
|
178
|
+
pat: str = r"(\[\[([\d\s,-]*)]])"
|
179
|
+
"""Regex pattern to match citations."""
|
180
|
+
sep: str = ","
|
181
|
+
"""Separator for citation numbers."""
|
182
|
+
abbr_sep: str = "-"
|
183
|
+
"""Separator for abbreviated citation numbers."""
|
184
|
+
|
185
|
+
def update_chunks(self, article_chunks: List[ArticleChunk], set_cite_number: bool = True) -> Self:
|
186
|
+
"""Update article chunks."""
|
187
|
+
self.article_chunks.clear()
|
188
|
+
self.article_chunks.extend(article_chunks)
|
189
|
+
if set_cite_number:
|
190
|
+
self.set_cite_number_all()
|
191
|
+
return self
|
192
|
+
|
193
|
+
def set_cite_number_all(self) -> Self:
|
194
|
+
"""Set citation numbers for all article chunks."""
|
195
|
+
for i, a in enumerate(self.article_chunks, 1):
|
196
|
+
a.update_cite_number(i)
|
197
|
+
return self
|
198
|
+
|
199
|
+
def _as_prompt_inner(self) -> Dict[str, str]:
|
200
|
+
"""Generate prompt inner representation."""
|
201
|
+
return {"References": "\n".join(r.as_prompt() for r in self.article_chunks)}
|
202
|
+
|
203
|
+
def apply(self, string: str) -> str:
|
204
|
+
"""Apply citation replacements to the input string."""
|
205
|
+
for origin,m in re.findall(self.pat, string):
|
206
|
+
logger.info(f"Matching citation: {m}")
|
207
|
+
notations = self.convert_to_numeric_notations(m)
|
208
|
+
logger.info(f"Citing Notations: {notations}")
|
209
|
+
citation_number_seq = list(flatten(self.decode_expr(n) for n in notations))
|
210
|
+
logger.info(f"Citation Number Sequence: {citation_number_seq}")
|
211
|
+
dedup = self.deduplicate_citation(citation_number_seq)
|
212
|
+
logger.info(f"Deduplicated Citation Number Sequence: {dedup}")
|
213
|
+
string=string.replace(origin, self.unpack_cite_seq(dedup))
|
214
|
+
return string
|
215
|
+
|
216
|
+
def decode_expr(self, string: str) -> List[int]:
|
217
|
+
"""Decode citation expression into a list of integers."""
|
218
|
+
if self.abbr_sep in string:
|
219
|
+
start, end = string.split(self.abbr_sep)
|
220
|
+
return list(range(int(start), int(end) + 1))
|
221
|
+
return [int(string)]
|
222
|
+
|
223
|
+
def convert_to_numeric_notations(self, string: str) -> List[str]:
|
224
|
+
"""Convert citation string into numeric notations."""
|
225
|
+
return [s.strip() for s in string.split(self.sep)]
|
226
|
+
|
227
|
+
def deduplicate_citation(self, citation_seq: List[int]) -> List[int]:
|
228
|
+
"""Deduplicate citation sequence."""
|
229
|
+
chunk_seq = [a for a in self.article_chunks if a.cite_number in citation_seq]
|
230
|
+
deduped = unique(chunk_seq, lambda a: a.bibtex_cite_key)
|
231
|
+
return [a.cite_number for a in deduped]
|
232
|
+
|
233
|
+
def unpack_cite_seq(self, citation_seq: List[int]) -> str:
|
234
|
+
"""Unpack citation sequence into a string."""
|
235
|
+
chunk_seq = [a for a in self.article_chunks if a.cite_number in citation_seq]
|
236
|
+
return "".join(a.as_typst_cite() for a in chunk_seq)
|
fabricatio/models/extra/article_main.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from typing import Dict, Generator, List, Self, Tuple, override
|
4
4
|
|
5
|
+
from fabricatio.decorators import precheck_package
|
5
6
|
from fabricatio.fs.readers import extract_sections
|
6
7
|
from fabricatio.journal import logger
|
7
8
|
from fabricatio.models.extra.article_base import (
|
@@ -14,8 +15,8 @@ from fabricatio.models.extra.article_outline import (
|
|
14
15
|
ArticleOutline,
|
15
16
|
)
|
16
17
|
from fabricatio.models.generic import Described, PersistentAble, SequencePatch, SketchedAble, WithRef, WordCount
|
17
|
-
from fabricatio.rust import word_count
|
18
|
-
from pydantic import Field
|
18
|
+
from fabricatio.rust import convert_all_block_tex, convert_all_inline_tex, word_count
|
19
|
+
from pydantic import Field, NonNegativeInt
|
19
20
|
|
20
21
|
PARAGRAPH_SEP = "// - - -"
|
21
22
|
|
@@ -23,6 +24,9 @@ PARAGRAPH_SEP = "// - - -"
|
|
23
24
|
class Paragraph(SketchedAble, WordCount, Described):
|
24
25
|
"""Structured academic paragraph blueprint for controlled content generation."""
|
25
26
|
|
27
|
+
expected_word_count: NonNegativeInt = 0
|
28
|
+
"""The expected word count of this paragraph, 0 means not specified"""
|
29
|
+
|
26
30
|
description: str = Field(
|
27
31
|
alias="elaboration",
|
28
32
|
description=Described.model_fields["description"].description,
|
@@ -153,6 +157,26 @@ class Article(
|
|
153
157
|
"Original Article": self.display(),
|
154
158
|
}
|
155
159
|
|
160
|
+
def convert_tex(self) -> Self:
|
161
|
+
"""Convert tex to typst code."""
|
162
|
+
for _, _, subsec in self.iter_subsections():
|
163
|
+
for p in subsec.paragraphs:
|
164
|
+
p.content = convert_all_inline_tex(p.content)
|
165
|
+
p.content = convert_all_block_tex(p.content)
|
166
|
+
return self
|
167
|
+
|
168
|
+
def fix_wrapper(self) -> Self:
|
169
|
+
"""Fix wrapper."""
|
170
|
+
for _, _, subsec in self.iter_subsections():
|
171
|
+
for p in subsec.paragraphs:
|
172
|
+
p.content = (
|
173
|
+
p.content.replace(r" \( ", "$")
|
174
|
+
.replace(r" \) ", "$")
|
175
|
+
.replace("\\[\n", "$$\n")
|
176
|
+
.replace("\n\\]", "\n$$")
|
177
|
+
)
|
178
|
+
return self
|
179
|
+
|
156
180
|
@override
|
157
181
|
def iter_subsections(self) -> Generator[Tuple[ArticleChapter, ArticleSection, ArticleSubsection], None, None]:
|
158
182
|
return super().iter_subsections() # pyright: ignore [reportReturnType]
|
@@ -204,3 +228,23 @@ class Article(
|
|
204
228
|
expected_word_count=word_count(body),
|
205
229
|
abstract="",
|
206
230
|
)
|
231
|
+
|
232
|
+
@classmethod
|
233
|
+
def from_mixed_source(cls, article_outline: ArticleOutline, typst_code: str) -> Self:
|
234
|
+
"""Generates an article from the given outline and Typst code."""
|
235
|
+
self = cls.from_typst_code(article_outline.title, typst_code)
|
236
|
+
self.expected_word_count = article_outline.expected_word_count
|
237
|
+
self.description = article_outline.description
|
238
|
+
for a, o in zip(self.iter_dfs(), article_outline.iter_dfs(), strict=True):
|
239
|
+
a.update_metadata(o)
|
240
|
+
return self.update_ref(article_outline)
|
241
|
+
|
242
|
+
@precheck_package(
|
243
|
+
"questionary", "'questionary' is required to run this function. Have you installed `fabricatio[qa]`?."
|
244
|
+
)
|
245
|
+
async def edit_titles(self) -> Self:
|
246
|
+
"""Edits the titles of the article."""
|
247
|
+
from questionary import text
|
248
|
+
|
249
|
+
for a in self.iter_dfs():
|
250
|
+
a.title = await text(f"Edit `{a.title}`.", default=a.title).ask_async() or a.title
|
fabricatio/models/extra/problem.py
CHANGED
@@ -7,7 +7,6 @@ from fabricatio.journal import logger
|
|
7
7
|
from fabricatio.models.generic import SketchedAble, WithBriefing
|
8
8
|
from fabricatio.utils import ask_edit
|
9
9
|
from pydantic import Field
|
10
|
-
from questionary import Choice, checkbox, text
|
11
10
|
from rich import print as r_print
|
12
11
|
|
13
12
|
|
@@ -74,6 +73,9 @@ class ProblemSolutions(SketchedAble):
|
|
74
73
|
return len(self.solutions) > 0
|
75
74
|
|
76
75
|
async def edit_problem(self) -> Self:
|
76
|
+
"""Interactively edit the problem description."""
|
77
|
+
from questionary import text
|
78
|
+
|
77
79
|
"""Interactively edit the problem description."""
|
78
80
|
self.problem = Problem.model_validate_strings(
|
79
81
|
await text("Please edit the problem below:", default=self.problem.display()).ask_async()
|
@@ -127,6 +129,8 @@ class Improvement(SketchedAble):
|
|
127
129
|
Returns:
|
128
130
|
Self: The current instance with filtered problems and solutions.
|
129
131
|
"""
|
132
|
+
from questionary import Choice, checkbox
|
133
|
+
|
130
134
|
# Choose the problems to retain
|
131
135
|
chosen_ones: List[ProblemSolutions] = await checkbox(
|
132
136
|
"Please choose the problems you want to retain.(Default: retain all)",
|
fabricatio/models/generic.py
CHANGED
@@ -3,11 +3,11 @@
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
from datetime import datetime
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Self, Type, Union, final, overload
|
6
|
+
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Self, Type, Union, final, overload
|
7
7
|
|
8
|
-
import orjson
|
8
|
+
import ujson
|
9
9
|
from fabricatio.config import configs
|
10
|
-
from fabricatio.fs.readers import
|
10
|
+
from fabricatio.fs.readers import safe_text_read
|
11
11
|
from fabricatio.journal import logger
|
12
12
|
from fabricatio.parser import JsonCapture
|
13
13
|
from fabricatio.rust import blake3_hash, detect_language
|
@@ -117,6 +117,15 @@ class WordCount(Base):
|
|
117
117
|
"""Expected word count of this research component."""
|
118
118
|
|
119
119
|
|
120
|
+
class FromMapping(Base):
|
121
|
+
"""Class that provides a method to generate a list of objects from a mapping."""
|
122
|
+
|
123
|
+
@classmethod
|
124
|
+
@abstractmethod
|
125
|
+
def from_mapping(cls, mapping: Mapping[str, Any], **kwargs: Any) -> List[Self]:
|
126
|
+
"""Generate a list of objects from a mapping."""
|
127
|
+
|
128
|
+
|
120
129
|
class AsPrompt(Base):
|
121
130
|
"""Class that provides a method to generate a prompt from the model.
|
122
131
|
|
@@ -170,10 +179,13 @@ class WithRef[T](Base):
|
|
170
179
|
|
171
180
|
@overload
|
172
181
|
def update_ref[S: WithRef](self: S, reference: T) -> S: ...
|
182
|
+
|
173
183
|
@overload
|
174
184
|
def update_ref[S: WithRef](self: S, reference: "WithRef[T]") -> S: ...
|
185
|
+
|
175
186
|
@overload
|
176
187
|
def update_ref[S: WithRef](self: S, reference: None = None) -> S: ...
|
188
|
+
|
177
189
|
def update_ref[S: WithRef](self: S, reference: Union[T, "WithRef[T]", None] = None) -> S: # noqa: PYI019
|
178
190
|
"""Update the reference of the object.
|
179
191
|
|
@@ -455,10 +467,9 @@ class WithFormatedJsonSchema(Base):
|
|
455
467
|
Returns:
|
456
468
|
str: The JSON schema of the model in a formatted string.
|
457
469
|
"""
|
458
|
-
return orjson.dumps(
|
459
|
-
cls.model_json_schema(schema_generator=UnsortGenerate),
|
460
|
-
|
461
|
-
).decode()
|
470
|
+
return ujson.dumps(
|
471
|
+
cls.model_json_schema(schema_generator=UnsortGenerate), indent=2, ensure_ascii=False, sort_keys=False
|
472
|
+
)
|
462
473
|
|
463
474
|
|
464
475
|
class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
@@ -470,9 +481,11 @@ class CreateJsonObjPrompt(WithFormatedJsonSchema):
|
|
470
481
|
@classmethod
|
471
482
|
@overload
|
472
483
|
def create_json_prompt(cls, requirement: List[str]) -> List[str]: ...
|
484
|
+
|
473
485
|
@classmethod
|
474
486
|
@overload
|
475
487
|
def create_json_prompt(cls, requirement: str) -> str: ...
|
488
|
+
|
476
489
|
@classmethod
|
477
490
|
def create_json_prompt(cls, requirement: str | List[str]) -> str | List[str]:
|
478
491
|
"""Create the prompt for creating a JSON object with given requirement.
|
@@ -639,6 +652,8 @@ class WithDependency(Base):
|
|
639
652
|
Returns:
|
640
653
|
str: The generated prompt for the task.
|
641
654
|
"""
|
655
|
+
from fabricatio.fs import MAGIKA
|
656
|
+
|
642
657
|
return TEMPLATE_MANAGER.render_template(
|
643
658
|
configs.templates.dependencies_template,
|
644
659
|
{
|
@@ -734,6 +749,12 @@ class ScopedConfig(Base):
|
|
734
749
|
llm_rpm: Optional[PositiveInt] = None
|
735
750
|
"""The requests per minute of the LLM model."""
|
736
751
|
|
752
|
+
llm_presence_penalty: Optional[PositiveFloat] = None
|
753
|
+
"""The presence penalty of the LLM model."""
|
754
|
+
|
755
|
+
llm_frequency_penalty: Optional[PositiveFloat] = None
|
756
|
+
"""The frequency penalty of the LLM model."""
|
757
|
+
|
737
758
|
embedding_api_endpoint: Optional[HttpUrl] = None
|
738
759
|
"""The OpenAI API endpoint."""
|
739
760
|
|
@@ -862,10 +883,7 @@ class Patch[T](ProposedAble):
|
|
862
883
|
)
|
863
884
|
my_schema["description"] = ref_cls.__doc__
|
864
885
|
|
865
|
-
return orjson.dumps(
|
866
|
-
my_schema,
|
867
|
-
option=orjson.OPT_INDENT_2,
|
868
|
-
).decode()
|
886
|
+
return ujson.dumps(my_schema, indent=2, ensure_ascii=False, sort_keys=False)
|
869
887
|
|
870
888
|
|
871
889
|
class SequencePatch[T](ProposedUpdateAble):
|
fabricatio/models/kwargs_types.py
CHANGED
@@ -45,6 +45,8 @@ class LLMKwargs(TypedDict, total=False):
|
|
45
45
|
no_store: bool # If store the response of this call to cache
|
46
46
|
cache_ttl: int # how long the stored cache is alive, in seconds
|
47
47
|
s_maxage: int # max accepted age of cached response, in seconds
|
48
|
+
presence_penalty: float
|
49
|
+
frequency_penalty: float
|
48
50
|
|
49
51
|
|
50
52
|
class GenerateKwargs(LLMKwargs, total=False):
|
@@ -66,7 +68,7 @@ class ValidateKwargs[T](GenerateKwargs, total=False):
|
|
66
68
|
|
67
69
|
default: Optional[T]
|
68
70
|
max_validations: int
|
69
|
-
|
71
|
+
|
70
72
|
|
71
73
|
|
72
74
|
class CompositeScoreKwargs(ValidateKwargs[List[Dict[str, float]]], total=False):
|
fabricatio/models/usages.py
CHANGED
@@ -63,7 +63,7 @@ class LLMUsage(ScopedConfig):
|
|
63
63
|
self._added_deployment = ROUTER.upsert_deployment(deployment)
|
64
64
|
return ROUTER
|
65
65
|
|
66
|
-
# noinspection PyTypeChecker,PydanticTypeChecker
|
66
|
+
# noinspection PyTypeChecker,PydanticTypeChecker,t
|
67
67
|
async def aquery(
|
68
68
|
self,
|
69
69
|
messages: List[Dict[str, str]],
|
@@ -122,6 +122,12 @@ class LLMUsage(ScopedConfig):
|
|
122
122
|
"cache-ttl": kwargs.get("cache_ttl"),
|
123
123
|
"s-maxage": kwargs.get("s_maxage"),
|
124
124
|
},
|
125
|
+
presence_penalty=kwargs.get("presence_penalty")
|
126
|
+
or self.llm_presence_penalty
|
127
|
+
or configs.llm.presence_penalty,
|
128
|
+
frequency_penalty=kwargs.get("frequency_penalty")
|
129
|
+
or self.llm_frequency_penalty
|
130
|
+
or configs.llm.frequency_penalty,
|
125
131
|
)
|
126
132
|
|
127
133
|
async def ainvoke(
|
@@ -236,7 +242,6 @@ class LLMUsage(ScopedConfig):
|
|
236
242
|
validator: Callable[[str], T | None],
|
237
243
|
default: T = ...,
|
238
244
|
max_validations: PositiveInt = 2,
|
239
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
240
245
|
**kwargs: Unpack[GenerateKwargs],
|
241
246
|
) -> T: ...
|
242
247
|
@overload
|
@@ -246,7 +251,6 @@ class LLMUsage(ScopedConfig):
|
|
246
251
|
validator: Callable[[str], T | None],
|
247
252
|
default: T = ...,
|
248
253
|
max_validations: PositiveInt = 2,
|
249
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
250
254
|
**kwargs: Unpack[GenerateKwargs],
|
251
255
|
) -> List[T]: ...
|
252
256
|
@overload
|
@@ -256,7 +260,6 @@ class LLMUsage(ScopedConfig):
|
|
256
260
|
validator: Callable[[str], T | None],
|
257
261
|
default: None = None,
|
258
262
|
max_validations: PositiveInt = 2,
|
259
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
260
263
|
**kwargs: Unpack[GenerateKwargs],
|
261
264
|
) -> Optional[T]: ...
|
262
265
|
|
@@ -267,7 +270,6 @@ class LLMUsage(ScopedConfig):
|
|
267
270
|
validator: Callable[[str], T | None],
|
268
271
|
default: None = None,
|
269
272
|
max_validations: PositiveInt = 2,
|
270
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
271
273
|
**kwargs: Unpack[GenerateKwargs],
|
272
274
|
) -> List[Optional[T]]: ...
|
273
275
|
|
@@ -277,7 +279,6 @@ class LLMUsage(ScopedConfig):
|
|
277
279
|
validator: Callable[[str], T | None],
|
278
280
|
default: Optional[T] = None,
|
279
281
|
max_validations: PositiveInt = 3,
|
280
|
-
co_extractor: Optional[GenerateKwargs] = None,
|
281
282
|
**kwargs: Unpack[GenerateKwargs],
|
282
283
|
) -> Optional[T] | List[Optional[T]] | List[T] | T:
|
283
284
|
"""Asynchronously asks a question and validates the response using a given validator.
|
@@ -287,34 +288,16 @@ class LLMUsage(ScopedConfig):
|
|
287
288
|
validator (Callable[[str], T | None]): A function to validate the response.
|
288
289
|
default (T | None): Default value to return if validation fails. Defaults to None.
|
289
290
|
max_validations (PositiveInt): Maximum number of validation attempts. Defaults to 3.
|
290
|
-
co_extractor (Optional[GenerateKwargs]): Keyword arguments for the co-extractor, if provided will enable co-extraction.
|
291
291
|
**kwargs (Unpack[GenerateKwargs]): Additional keyword arguments for the LLM usage.
|
292
292
|
|
293
293
|
Returns:
|
294
|
-
Optional[T] | List[
|
294
|
+
Optional[T] | List[T | None] | List[T] | T: The validated response.
|
295
295
|
"""
|
296
296
|
|
297
297
|
async def _inner(q: str) -> Optional[T]:
|
298
298
|
for lap in range(max_validations):
|
299
299
|
try:
|
300
|
-
if (
|
301
|
-
co_extractor is not None
|
302
|
-
and logger.debug("Co-extraction is enabled.") is None
|
303
|
-
and (
|
304
|
-
validated := validator(
|
305
|
-
response := await self.aask(
|
306
|
-
question=(
|
307
|
-
TEMPLATE_MANAGER.render_template(
|
308
|
-
configs.templates.co_validation_template,
|
309
|
-
{"original_q": q, "original_a": response},
|
310
|
-
)
|
311
|
-
),
|
312
|
-
**co_extractor,
|
313
|
-
)
|
314
|
-
)
|
315
|
-
)
|
316
|
-
is not None
|
317
|
-
):
|
300
|
+
if (validated := validator(response := await self.aask(question=q, **kwargs))) is not None:
|
318
301
|
logger.debug(f"Successfully validated the response at {lap}th attempt.")
|
319
302
|
return validated
|
320
303
|
|