fabricatio 0.2.8.dev1__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.8.dev3__cp312-cp312-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
- fabricatio/_rust.pyi +50 -0
- fabricatio/actions/article.py +103 -65
- fabricatio/actions/article_rag.py +73 -19
- fabricatio/actions/output.py +39 -6
- fabricatio/actions/rag.py +3 -3
- fabricatio/capabilities/check.py +97 -0
- fabricatio/capabilities/correct.py +7 -6
- fabricatio/capabilities/propose.py +20 -4
- fabricatio/capabilities/rag.py +3 -2
- fabricatio/capabilities/rating.py +7 -10
- fabricatio/capabilities/review.py +18 -187
- fabricatio/capabilities/task.py +8 -9
- fabricatio/config.py +2 -0
- fabricatio/fs/curd.py +4 -0
- fabricatio/models/action.py +10 -5
- fabricatio/models/extra/advanced_judge.py +16 -9
- fabricatio/models/extra/article_base.py +53 -10
- fabricatio/models/extra/article_essence.py +47 -171
- fabricatio/models/extra/article_main.py +6 -1
- fabricatio/models/extra/article_proposal.py +19 -1
- fabricatio/models/extra/problem.py +120 -0
- fabricatio/models/extra/rule.py +23 -0
- fabricatio/models/generic.py +50 -42
- fabricatio/models/role.py +4 -1
- fabricatio/models/usages.py +8 -6
- fabricatio/models/utils.py +0 -46
- fabricatio/utils.py +54 -0
- fabricatio-0.2.8.dev3.data/scripts/tdown +0 -0
- {fabricatio-0.2.8.dev1.dist-info → fabricatio-0.2.8.dev3.dist-info}/METADATA +2 -1
- fabricatio-0.2.8.dev3.dist-info/RECORD +53 -0
- fabricatio-0.2.8.dev1.data/scripts/tdown +0 -0
- fabricatio-0.2.8.dev1.dist-info/RECORD +0 -49
- {fabricatio-0.2.8.dev1.dist-info → fabricatio-0.2.8.dev3.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.8.dev1.dist-info → fabricatio-0.2.8.dev3.dist-info}/licenses/LICENSE +0 -0
Binary file
|
fabricatio/_rust.pyi
CHANGED
@@ -122,3 +122,53 @@ class BibManager:
|
|
122
122
|
Returns:
|
123
123
|
List of all titles in the bibliography
|
124
124
|
"""
|
125
|
+
|
126
|
+
def get_author_by_key(self, key: str) -> Optional[List[str]]:
|
127
|
+
"""Retrieve authors by citation key.
|
128
|
+
|
129
|
+
Args:
|
130
|
+
key: Citation key
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
List of authors if found, None otherwise
|
134
|
+
"""
|
135
|
+
|
136
|
+
def get_year_by_key(self, key: str) -> Optional[int]:
|
137
|
+
"""Retrieve the publication year by citation key.
|
138
|
+
|
139
|
+
Args:
|
140
|
+
key: Citation key
|
141
|
+
|
142
|
+
Returns:
|
143
|
+
Publication year if found, None otherwise
|
144
|
+
"""
|
145
|
+
|
146
|
+
def get_abstract_by_key(self, key: str) -> Optional[str]:
|
147
|
+
"""Retrieve the abstract by citation key.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
key: Citation key
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
Abstract if found, None otherwise
|
154
|
+
"""
|
155
|
+
def get_title_by_key(self, key: str) -> Optional[str]:
|
156
|
+
"""Retrieve the title by citation key.
|
157
|
+
|
158
|
+
Args:
|
159
|
+
key: Citation key
|
160
|
+
|
161
|
+
Returns:
|
162
|
+
Title if found, None otherwise
|
163
|
+
"""
|
164
|
+
|
165
|
+
def get_field_by_key(self, key: str, field: str)-> Optional[str]:
|
166
|
+
"""Retrieve a specific field by citation key.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
key: Citation key
|
170
|
+
field: Field name
|
171
|
+
|
172
|
+
Returns:
|
173
|
+
Field value if found, None otherwise
|
174
|
+
"""
|
fabricatio/actions/article.py
CHANGED
@@ -4,6 +4,7 @@ from asyncio import gather
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Any, Callable, List, Optional
|
6
6
|
|
7
|
+
from fabricatio._rust import BibManager
|
7
8
|
from fabricatio.capabilities.advanced_judge import AdvancedJudge
|
8
9
|
from fabricatio.fs import safe_text_read
|
9
10
|
from fabricatio.journal import logger
|
@@ -14,7 +15,8 @@ from fabricatio.models.extra.article_main import Article
|
|
14
15
|
from fabricatio.models.extra.article_outline import ArticleOutline
|
15
16
|
from fabricatio.models.extra.article_proposal import ArticleProposal
|
16
17
|
from fabricatio.models.task import Task
|
17
|
-
from fabricatio.
|
18
|
+
from fabricatio.utils import ok
|
19
|
+
from more_itertools import filter_map
|
18
20
|
|
19
21
|
|
20
22
|
class ExtractArticleEssence(Action):
|
@@ -31,20 +33,62 @@ class ExtractArticleEssence(Action):
|
|
31
33
|
async def _execute(
|
32
34
|
self,
|
33
35
|
task_input: Task,
|
34
|
-
reader: Callable[[str], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
36
|
+
reader: Callable[[str], Optional[str]] = lambda p: Path(p).read_text(encoding="utf-8"),
|
35
37
|
**_,
|
36
|
-
) ->
|
38
|
+
) -> List[ArticleEssence]:
|
37
39
|
if not task_input.dependencies:
|
38
40
|
logger.info(err := "Task not approved, since no dependencies are provided.")
|
39
41
|
raise RuntimeError(err)
|
40
|
-
|
42
|
+
logger.info(f"Extracting article essence from {len(task_input.dependencies)} files.")
|
41
43
|
# trim the references
|
42
|
-
contents =
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
44
|
+
contents = list(filter_map(reader, task_input.dependencies))
|
45
|
+
logger.info(f"Read {len(task_input.dependencies)} to get {len(contents)} contents.")
|
46
|
+
|
47
|
+
out = []
|
48
|
+
|
49
|
+
for ess in await self.propose(
|
50
|
+
ArticleEssence,
|
51
|
+
[
|
52
|
+
f"{c}\n\n\nBased the provided academic article above, you need to extract the essence from it."
|
53
|
+
for c in contents
|
54
|
+
],
|
55
|
+
):
|
56
|
+
if ess is None:
|
57
|
+
logger.warning("Could not extract article essence")
|
58
|
+
else:
|
59
|
+
out.append(ess)
|
60
|
+
logger.info(f"Extracted {len(out)} article essence from {len(task_input.dependencies)} files.")
|
61
|
+
return out
|
62
|
+
|
63
|
+
|
64
|
+
class FixArticleEssence(Action):
|
65
|
+
"""Fix the article essence based on the bibtex key."""
|
66
|
+
|
67
|
+
output_key: str = "fixed_article_essence"
|
68
|
+
"""The key of the output data."""
|
69
|
+
|
70
|
+
async def _execute(
|
71
|
+
self,
|
72
|
+
bib_mgr: BibManager,
|
73
|
+
article_essence: List[ArticleEssence],
|
74
|
+
**_,
|
75
|
+
) -> List[ArticleEssence]:
|
76
|
+
out = []
|
77
|
+
count = 0
|
78
|
+
for a in article_essence:
|
79
|
+
if key := (bib_mgr.get_cite_key(a.title) or bib_mgr.get_cite_key_fuzzy(a.title)):
|
80
|
+
a.title = bib_mgr.get_title_by_key(key) or a.title
|
81
|
+
a.authors = bib_mgr.get_author_by_key(key) or a.authors
|
82
|
+
a.publication_year = bib_mgr.get_year_by_key(key) or a.publication_year
|
83
|
+
a.bibtex_cite_key = key
|
84
|
+
logger.info(f"Updated {a.title} with {key}")
|
85
|
+
out.append(a)
|
86
|
+
else:
|
87
|
+
logger.warning(f"No key found for {a.title}")
|
88
|
+
count += 1
|
89
|
+
if count:
|
90
|
+
logger.warning(f"{count} articles have no key")
|
91
|
+
return out
|
48
92
|
|
49
93
|
|
50
94
|
class GenerateArticleProposal(Action):
|
@@ -80,7 +124,6 @@ class GenerateArticleProposal(Action):
|
|
80
124
|
)
|
81
125
|
)
|
82
126
|
),
|
83
|
-
**self.prepend_sys_msg(),
|
84
127
|
),
|
85
128
|
"Could not generate the proposal.",
|
86
129
|
).update_ref(briefing)
|
@@ -105,10 +148,9 @@ class GenerateInitialOutline(Action):
|
|
105
148
|
await self.propose(
|
106
149
|
ArticleOutline,
|
107
150
|
article_proposal.as_prompt(),
|
108
|
-
**self.prepend_sys_msg(),
|
109
151
|
),
|
110
152
|
"Could not generate the initial outline.",
|
111
|
-
)
|
153
|
+
).update_ref(article_proposal)
|
112
154
|
|
113
155
|
|
114
156
|
class FixIntrospectedErrors(Action):
|
@@ -120,6 +162,7 @@ class FixIntrospectedErrors(Action):
|
|
120
162
|
async def _execute(
|
121
163
|
self,
|
122
164
|
article_outline: ArticleOutline,
|
165
|
+
supervisor_check: bool = False,
|
123
166
|
**_,
|
124
167
|
) -> Optional[ArticleOutline]:
|
125
168
|
introspect_manual = ok(
|
@@ -141,7 +184,7 @@ class FixIntrospectedErrors(Action):
|
|
141
184
|
reference=f"# Original Article Outline\n{article_outline.display()}\n# Error Need to be fixed\n{err}",
|
142
185
|
topic=intro_topic,
|
143
186
|
rating_manual=introspect_manual,
|
144
|
-
supervisor_check=
|
187
|
+
supervisor_check=supervisor_check,
|
145
188
|
),
|
146
189
|
"Could not correct the component.",
|
147
190
|
)
|
@@ -159,6 +202,7 @@ class FixIllegalReferences(Action):
|
|
159
202
|
async def _execute(
|
160
203
|
self,
|
161
204
|
article_outline: ArticleOutline,
|
205
|
+
supervisor_check: bool = False,
|
162
206
|
**_,
|
163
207
|
) -> Optional[ArticleOutline]:
|
164
208
|
ref_manual = ok(
|
@@ -171,88 +215,81 @@ class FixIllegalReferences(Action):
|
|
171
215
|
"Could not generate the rating manual.",
|
172
216
|
)
|
173
217
|
|
174
|
-
while pack := article_outline.find_illegal_ref():
|
175
|
-
|
218
|
+
while pack := article_outline.find_illegal_ref(gather_identical=True):
|
219
|
+
refs, err = ok(pack)
|
176
220
|
logger.warning(f"Found illegal referring error: {err}")
|
177
|
-
ok(
|
178
|
-
await self.
|
179
|
-
|
180
|
-
reference=f"# Original Article Outline\n{article_outline.display()}\n# Error Need to be fixed\n{err}
|
221
|
+
corrected_ref = ok(
|
222
|
+
await self.correct_obj(
|
223
|
+
refs[0], # pyright: ignore [reportIndexIssue]
|
224
|
+
reference=f"# Original Article Outline\n{article_outline.display()}\n# Error Need to be fixed\n{err}",
|
181
225
|
topic=ref_topic,
|
182
226
|
rating_manual=ref_manual,
|
183
|
-
supervisor_check=
|
227
|
+
supervisor_check=supervisor_check,
|
184
228
|
)
|
185
229
|
)
|
230
|
+
for ref in refs:
|
231
|
+
ref.update_from(corrected_ref) # pyright: ignore [reportAttributeAccessIssue]
|
232
|
+
|
186
233
|
return article_outline.update_ref(article_outline)
|
187
234
|
|
188
235
|
|
189
|
-
class
|
190
|
-
"""Tweak the
|
236
|
+
class TweakOutlineForwardRef(Action, AdvancedJudge):
|
237
|
+
"""Tweak the forward references in the article outline.
|
191
238
|
|
192
|
-
Ensures that the
|
239
|
+
Ensures that the conclusions of the current chapter effectively support the analysis of subsequent chapters.
|
193
240
|
"""
|
194
241
|
|
195
|
-
output_key: str = "
|
242
|
+
output_key: str = "article_outline_fw_ref_checked"
|
196
243
|
|
197
|
-
async def _execute(self, article_outline: ArticleOutline, **cxt) -> ArticleOutline:
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
"
|
244
|
+
async def _execute(self, article_outline: ArticleOutline, supervisor_check: bool = False, **cxt) -> ArticleOutline:
|
245
|
+
return await self._inner(
|
246
|
+
article_outline,
|
247
|
+
supervisor_check,
|
248
|
+
topic="Ensure conclusions support the analysis of subsequent chapters, sections or subsections.",
|
249
|
+
field_name="support_to",
|
203
250
|
)
|
204
251
|
|
252
|
+
async def _inner(
|
253
|
+
self, article_outline: ArticleOutline, supervisor_check: bool, topic: str, field_name: str
|
254
|
+
) -> ArticleOutline:
|
255
|
+
tweak_support_to_manual = ok(
|
256
|
+
await self.draft_rating_manual(topic),
|
257
|
+
"Could not generate the rating manual.",
|
258
|
+
)
|
205
259
|
for a in article_outline.iter_dfs():
|
206
260
|
if await self.evidently_judge(
|
207
261
|
f"{article_outline.as_prompt()}\n\n{a.display()}\n"
|
208
|
-
f"Does the `{a.__class__.__name__}`'s `
|
262
|
+
f"Does the `{a.__class__.__name__}`'s `{field_name}` field need to be extended or tweaked?"
|
209
263
|
):
|
210
|
-
patch=ArticleRefPatch.default()
|
211
|
-
patch.tweaked=a
|
264
|
+
patch = ArticleRefPatch.default()
|
265
|
+
patch.tweaked = getattr(a, field_name)
|
212
266
|
|
213
267
|
await self.correct_obj_inplace(
|
214
268
|
patch,
|
215
269
|
topic=topic,
|
216
|
-
reference=f"{article_outline.as_prompt()}\nThe Article component whose `
|
217
|
-
rating_manual=
|
270
|
+
reference=f"{article_outline.as_prompt()}\nThe Article component whose `{field_name}` field needs to be extended or tweaked",
|
271
|
+
rating_manual=tweak_support_to_manual,
|
272
|
+
supervisor_check=supervisor_check,
|
218
273
|
)
|
219
|
-
|
220
274
|
return article_outline
|
221
275
|
|
222
276
|
|
223
|
-
class TweakOutlineForwardRef
|
224
|
-
"""Tweak the
|
277
|
+
class TweakOutlineBackwardRef(TweakOutlineForwardRef):
|
278
|
+
"""Tweak the backward references in the article outline.
|
225
279
|
|
226
|
-
Ensures that the
|
280
|
+
Ensures that the prerequisites of the current chapter are correctly referenced in the `depend_on` field.
|
227
281
|
"""
|
228
282
|
|
229
|
-
output_key: str = "
|
283
|
+
output_key: str = "article_outline_bw_ref_checked"
|
230
284
|
|
231
|
-
async def _execute(self, article_outline: ArticleOutline, **cxt) -> ArticleOutline:
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
"
|
285
|
+
async def _execute(self, article_outline: ArticleOutline, supervisor_check: bool = False, **cxt) -> ArticleOutline:
|
286
|
+
return await self._inner(
|
287
|
+
article_outline,
|
288
|
+
supervisor_check,
|
289
|
+
topic="Ensure the dependencies of the current chapter are neither abused nor missing.",
|
290
|
+
field_name="depend_on",
|
237
291
|
)
|
238
292
|
|
239
|
-
for a in article_outline.iter_dfs():
|
240
|
-
if await self.evidently_judge(
|
241
|
-
f"{article_outline.as_prompt()}\n\n{a.display()}\n"
|
242
|
-
f"Does the `{a.__class__.__name__}`'s `support_to` field need to be extended or tweaked?"
|
243
|
-
):
|
244
|
-
patch=ArticleRefPatch.default()
|
245
|
-
patch.tweaked=a.support_to
|
246
|
-
|
247
|
-
await self.correct_obj_inplace(
|
248
|
-
patch,
|
249
|
-
topic=topic,
|
250
|
-
reference=f"{article_outline.as_prompt()}\nThe Article component whose `support_to` field needs to be extended or tweaked",
|
251
|
-
rating_manual=tweak_support_to_manual,
|
252
|
-
)
|
253
|
-
|
254
|
-
return article_outline
|
255
|
-
|
256
293
|
|
257
294
|
class GenerateArticle(Action):
|
258
295
|
"""Generate the article based on the outline."""
|
@@ -263,6 +300,7 @@ class GenerateArticle(Action):
|
|
263
300
|
async def _execute(
|
264
301
|
self,
|
265
302
|
article_outline: ArticleOutline,
|
303
|
+
supervisor_check: bool = False,
|
266
304
|
**_,
|
267
305
|
) -> Optional[Article]:
|
268
306
|
article: Article = Article.from_outline(ok(article_outline, "Article outline not specified.")).update_ref(
|
@@ -280,7 +318,7 @@ class GenerateArticle(Action):
|
|
280
318
|
reference=f"# Original Article Outline\n{article_outline.display()}\n# Error Need to be fixed\n{err}",
|
281
319
|
topic=w_topic,
|
282
320
|
rating_manual=write_para_manual,
|
283
|
-
supervisor_check=
|
321
|
+
supervisor_check=supervisor_check,
|
284
322
|
)
|
285
323
|
for _, __, subsec in article.iter_subsections()
|
286
324
|
if (err := subsec.introspect())
|
@@ -1,35 +1,89 @@
|
|
1
1
|
"""A module for writing articles using RAG (Retrieval-Augmented Generation) capabilities."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from asyncio import gather
|
4
|
+
from typing import Dict, Optional
|
4
5
|
|
5
6
|
from fabricatio.capabilities.rag import RAG
|
6
|
-
from fabricatio.journal import logger
|
7
7
|
from fabricatio.models.action import Action
|
8
|
-
from fabricatio.models.extra.article_main import Article
|
9
|
-
from fabricatio.
|
8
|
+
from fabricatio.models.extra.article_main import Article, ArticleParagraphPatch, ArticleSubsection
|
9
|
+
from fabricatio.utils import ok
|
10
10
|
|
11
11
|
|
12
|
-
class
|
12
|
+
class TweakArticleRAG(Action, RAG):
|
13
13
|
"""Write an article based on the provided outline."""
|
14
14
|
|
15
|
-
output_key: str = "
|
15
|
+
output_key: str = "rag_tweaked_article"
|
16
16
|
|
17
|
-
async def _execute(
|
17
|
+
async def _execute(
|
18
|
+
self,
|
19
|
+
article: Article,
|
20
|
+
collection_name: str = "article_essence",
|
21
|
+
citation_requirement: str = "# Citation Format\n"
|
22
|
+
"Use correct citation format based on author count. Cite using author surnames and year:"
|
23
|
+
"For 3+ authors: 'Author1, Author2 et al. (YYYY)'"
|
24
|
+
"For 2 authors: 'Author1 & Author2 (YYYY)'"
|
25
|
+
"Single author: 'Author1 (YYYY)'"
|
26
|
+
"Multiple citations: 'Author1 (YYYY), Author2 (YYYY)'"
|
27
|
+
"Prioritize formulas from reference highlights."
|
28
|
+
"Specify authors/years only."
|
29
|
+
"You can create numeric citation numbers for article whose `bibtex_cite_key` is 'wangWind2024' by using notation like `#cite(<wangWind2024>)`."
|
30
|
+
"Paragraphs must exceed 2-3 sentences",
|
31
|
+
supervisor_check: bool = False,
|
32
|
+
parallel: bool = False,
|
33
|
+
**cxt,
|
34
|
+
) -> Optional[Article]:
|
18
35
|
"""Write an article based on the provided outline."""
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
36
|
+
criteria = await self.draft_rating_criteria(
|
37
|
+
topic := "choose appropriate reference to insert into the article, "
|
38
|
+
"making conclusions or reasoning based on concrete evidence instead of unreliable guesses."
|
39
|
+
"Extensively use formulas highlighted in reference文献 should be translated to 'references'."
|
40
|
+
"Only specify authors and years without numeric citation numbers (like [1],[2])."
|
41
|
+
"Each paragraph should not end with just 2-3 sentences for better readability.",
|
42
|
+
criteria_count=13,
|
25
43
|
)
|
26
44
|
|
45
|
+
tweak_manual = ok(await self.draft_rating_manual(topic, criteria=criteria))
|
46
|
+
self.view(collection_name)
|
27
47
|
|
28
|
-
|
29
|
-
|
48
|
+
if parallel:
|
49
|
+
await gather(
|
50
|
+
*[
|
51
|
+
self._inner(article, subsec, supervisor_check, citation_requirement, topic, tweak_manual)
|
52
|
+
for _, __, subsec in article.iter_subsections()
|
53
|
+
],
|
54
|
+
return_exceptions=True,
|
55
|
+
)
|
56
|
+
else:
|
57
|
+
for _, __, subsec in article.iter_subsections():
|
58
|
+
await self._inner(article, subsec, supervisor_check, citation_requirement, topic, tweak_manual)
|
30
59
|
|
31
|
-
|
60
|
+
return article
|
32
61
|
|
33
|
-
async def
|
34
|
-
|
35
|
-
|
62
|
+
async def _inner(
|
63
|
+
self,
|
64
|
+
article: Article,
|
65
|
+
subsec: ArticleSubsection,
|
66
|
+
supervisor_check: bool,
|
67
|
+
citation_requirement: str,
|
68
|
+
topic: str,
|
69
|
+
tweak_manual: Dict[str, str],
|
70
|
+
) -> None:
|
71
|
+
refind_q = ok(
|
72
|
+
await self.arefined_query(
|
73
|
+
f"{article.referenced.as_prompt()}\n"
|
74
|
+
f"# Subsection requiring reference enhancement\n"
|
75
|
+
f"{subsec.display()}\n"
|
76
|
+
f"# Requirement\n"
|
77
|
+
f"Search related articles in the base to find reference candidates, "
|
78
|
+
f"prioritizing both original article language and English usage",
|
79
|
+
)
|
80
|
+
)
|
81
|
+
patch = ArticleParagraphPatch.default()
|
82
|
+
patch.tweaked = subsec.paragraphs
|
83
|
+
await self.correct_obj_inplace(
|
84
|
+
patch,
|
85
|
+
reference=f"{await self.aretrieve_compact(refind_q, final_limit=50)}\n{citation_requirement}",
|
86
|
+
topic=topic,
|
87
|
+
rating_manual=tweak_manual,
|
88
|
+
supervisor_check=supervisor_check,
|
89
|
+
)
|
fabricatio/actions/output.py
CHANGED
@@ -1,19 +1,20 @@
|
|
1
1
|
"""Dump the finalized output to a file."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Optional
|
4
|
+
from typing import Iterable, List, Optional, Type
|
5
5
|
|
6
6
|
from fabricatio.journal import logger
|
7
7
|
from fabricatio.models.action import Action
|
8
8
|
from fabricatio.models.generic import FinalizedDumpAble, PersistentAble
|
9
9
|
from fabricatio.models.task import Task
|
10
|
-
from fabricatio.
|
10
|
+
from fabricatio.utils import ok
|
11
11
|
|
12
12
|
|
13
13
|
class DumpFinalizedOutput(Action):
|
14
14
|
"""Dump the finalized output to a file."""
|
15
15
|
|
16
16
|
output_key: str = "dump_path"
|
17
|
+
dump_path: Optional[str] = None
|
17
18
|
|
18
19
|
async def _execute(
|
19
20
|
self,
|
@@ -24,6 +25,7 @@ class DumpFinalizedOutput(Action):
|
|
24
25
|
) -> str:
|
25
26
|
dump_path = Path(
|
26
27
|
dump_path
|
28
|
+
or self.dump_path
|
27
29
|
or ok(
|
28
30
|
await self.awhich_pathstr(
|
29
31
|
f"{ok(task_input, 'Neither `task_input` and `dump_path` is provided.').briefing}\n\nExtract a single path of the file, to which I will dump the data."
|
@@ -39,6 +41,7 @@ class PersistentAll(Action):
|
|
39
41
|
"""Persist all the data to a file."""
|
40
42
|
|
41
43
|
output_key: str = "persistent_count"
|
44
|
+
persist_dir: Optional[str] = None
|
42
45
|
|
43
46
|
async def _execute(
|
44
47
|
self,
|
@@ -48,6 +51,7 @@ class PersistentAll(Action):
|
|
48
51
|
) -> int:
|
49
52
|
persist_dir = Path(
|
50
53
|
persist_dir
|
54
|
+
or self.persist_dir
|
51
55
|
or ok(
|
52
56
|
await self.awhich_pathstr(
|
53
57
|
f"{ok(task_input, 'Neither `task_input` and `dump_path` is provided.').briefing}\n\nExtract a single path of the file, to which I will persist the data."
|
@@ -60,10 +64,39 @@ class PersistentAll(Action):
|
|
60
64
|
if persist_dir.is_file():
|
61
65
|
logger.warning("Dump should be a directory, but it is a file. Skip dumping.")
|
62
66
|
return count
|
63
|
-
|
64
|
-
for v in cxt.
|
67
|
+
|
68
|
+
for k, v in cxt.items():
|
69
|
+
final_dir = persist_dir.joinpath(k)
|
65
70
|
if isinstance(v, PersistentAble):
|
66
|
-
|
71
|
+
final_dir.mkdir(parents=True, exist_ok=True)
|
72
|
+
v.persist(final_dir)
|
67
73
|
count += 1
|
68
|
-
|
74
|
+
if isinstance(v, Iterable) and any(
|
75
|
+
persistent_ables := (pers for pers in v if isinstance(pers, PersistentAble))
|
76
|
+
):
|
77
|
+
final_dir.mkdir(parents=True, exist_ok=True)
|
78
|
+
for per in persistent_ables:
|
79
|
+
per.persist(final_dir)
|
80
|
+
count += 1
|
81
|
+
logger.info(f"Persisted {count} objects to {persist_dir}")
|
69
82
|
return count
|
83
|
+
|
84
|
+
|
85
|
+
class RetrieveFromPersistent[T: PersistentAble](Action):
|
86
|
+
"""Retrieve the object from the persistent file."""
|
87
|
+
|
88
|
+
output_key: str = "retrieved_obj"
|
89
|
+
"""Retrieve the object from the persistent file."""
|
90
|
+
load_path: str
|
91
|
+
"""The path of the persistent file or directory contains multiple file."""
|
92
|
+
retrieve_cls: Type[T]
|
93
|
+
"""The class of the object to retrieve."""
|
94
|
+
|
95
|
+
async def _execute(self, /, **__) -> Optional[T | List[T]]:
|
96
|
+
logger.info(f"Retrieve `{self.retrieve_cls.__name__}` from persistent file: {self.load_path}")
|
97
|
+
if not (p := Path(self.load_path)).exists():
|
98
|
+
return None
|
99
|
+
|
100
|
+
if p.is_dir():
|
101
|
+
return [self.retrieve_cls.from_persistent(per) for per in p.glob("*")]
|
102
|
+
return self.retrieve_cls.from_persistent(self.load_path)
|
fabricatio/actions/rag.py
CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional
|
|
5
5
|
from fabricatio.capabilities.rag import RAG
|
6
6
|
from fabricatio.journal import logger
|
7
7
|
from fabricatio.models.action import Action
|
8
|
-
from fabricatio.models.generic import
|
8
|
+
from fabricatio.models.generic import Vectorizable
|
9
9
|
from fabricatio.models.task import Task
|
10
10
|
from questionary import text
|
11
11
|
|
@@ -15,7 +15,7 @@ class InjectToDB(Action, RAG):
|
|
15
15
|
|
16
16
|
output_key: str = "collection_name"
|
17
17
|
|
18
|
-
async def _execute[T:
|
18
|
+
async def _execute[T: Vectorizable](
|
19
19
|
self, to_inject: Optional[T] | List[Optional[T]], collection_name: str = "my_collection",override_inject:bool=False, **_
|
20
20
|
) -> Optional[str]:
|
21
21
|
if not isinstance(to_inject, list):
|
@@ -27,7 +27,7 @@ class InjectToDB(Action, RAG):
|
|
27
27
|
[
|
28
28
|
t.prepare_vectorization(self.embedding_max_sequence_length)
|
29
29
|
for t in to_inject
|
30
|
-
if isinstance(t,
|
30
|
+
if isinstance(t, Vectorizable)
|
31
31
|
],
|
32
32
|
)
|
33
33
|
|
@@ -0,0 +1,97 @@
|
|
1
|
+
"""A class that provides the capability to check strings and objects against rules and guidelines."""
|
2
|
+
from typing import Optional, Unpack
|
3
|
+
|
4
|
+
from fabricatio import TEMPLATE_MANAGER
|
5
|
+
from fabricatio.capabilities.advanced_judge import AdvancedJudge
|
6
|
+
from fabricatio.capabilities.propose import Propose
|
7
|
+
from fabricatio.config import configs
|
8
|
+
from fabricatio.models.extra.problem import Improvement
|
9
|
+
from fabricatio.models.extra.rule import Rule, RuleSet
|
10
|
+
from fabricatio.models.generic import Display, WithBriefing
|
11
|
+
from fabricatio.models.kwargs_types import ValidateKwargs
|
12
|
+
from fabricatio.utils import override_kwargs
|
13
|
+
|
14
|
+
|
15
|
+
class Check(AdvancedJudge, Propose):
|
16
|
+
"""Class that provides the capability to validate strings/objects against predefined rules and guidelines."""
|
17
|
+
|
18
|
+
async def draft_ruleset(
|
19
|
+
self, ruleset_requirement: str, **kwargs: Unpack[ValidateKwargs[RuleSet]]
|
20
|
+
) -> Optional[RuleSet]:
|
21
|
+
"""Generate a rule set based on specified requirements.
|
22
|
+
|
23
|
+
Args:
|
24
|
+
ruleset_requirement (str): Description of desired rule set characteristics
|
25
|
+
**kwargs: Validation configuration parameters
|
26
|
+
|
27
|
+
Returns:
|
28
|
+
Optional[RuleSet]: Generated rule set if successful
|
29
|
+
"""
|
30
|
+
return await self.propose(RuleSet, ruleset_requirement, **kwargs)
|
31
|
+
|
32
|
+
async def draft_rule(self, rule_requirement: str, **kwargs: Unpack[ValidateKwargs[Rule]]) -> Optional[Rule]:
|
33
|
+
"""Create a specific rule based on given specifications.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
rule_requirement (str): Detailed rule description requirements
|
37
|
+
**kwargs: Validation configuration parameters
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
Optional[Rule]: Generated rule instance if successful
|
41
|
+
"""
|
42
|
+
return await self.propose(Rule, rule_requirement, **kwargs)
|
43
|
+
|
44
|
+
async def check_string(
|
45
|
+
self,
|
46
|
+
input_text: str,
|
47
|
+
rule: Rule,
|
48
|
+
**kwargs: Unpack[ValidateKwargs[Improvement]],
|
49
|
+
) -> Optional[Improvement]:
|
50
|
+
"""Evaluate text against a specific rule.
|
51
|
+
|
52
|
+
Args:
|
53
|
+
input_text (str): Text content to be evaluated
|
54
|
+
rule (Rule): Rule instance used for validation
|
55
|
+
**kwargs: Validation configuration parameters
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
Optional[Improvement]: Suggested improvement if violations found, else None
|
59
|
+
"""
|
60
|
+
if judge := await self.evidently_judge(
|
61
|
+
f"# Content to exam\n{input_text}\n\n# Rule Must to follow\n{rule.display()}\nDoes `Content to exam` provided above violate the `Rule Must to follow` provided above?",
|
62
|
+
**override_kwargs(kwargs, default=None),
|
63
|
+
):
|
64
|
+
return await self.propose(
|
65
|
+
Improvement,
|
66
|
+
TEMPLATE_MANAGER.render_template(
|
67
|
+
configs.templates.check_string_template,
|
68
|
+
{"to_check": input_text, "rule": rule, "judge": judge.display()},
|
69
|
+
),
|
70
|
+
**kwargs,
|
71
|
+
)
|
72
|
+
return None
|
73
|
+
|
74
|
+
async def check_obj[M: (Display, WithBriefing)](
|
75
|
+
self,
|
76
|
+
obj: M,
|
77
|
+
rule: Rule,
|
78
|
+
**kwargs: Unpack[ValidateKwargs[Improvement]],
|
79
|
+
) -> Optional[Improvement]:
|
80
|
+
"""Validate an object against specified rule.
|
81
|
+
|
82
|
+
Args:
|
83
|
+
obj (M): Object implementing Display or WithBriefing interface
|
84
|
+
rule (Rule): Validation rule to apply
|
85
|
+
**kwargs: Validation configuration parameters
|
86
|
+
|
87
|
+
Returns:
|
88
|
+
Optional[Improvement]: Improvement suggestion if issues detected
|
89
|
+
"""
|
90
|
+
if isinstance(obj, Display):
|
91
|
+
input_text = obj.display()
|
92
|
+
elif isinstance(obj, WithBriefing):
|
93
|
+
input_text = obj.briefing
|
94
|
+
else:
|
95
|
+
raise TypeError("obj must be either Display or WithBriefing")
|
96
|
+
|
97
|
+
return await self.check_string(input_text, rule, **kwargs)
|