fabricatio 0.2.9.dev2__cp312-cp312-win_amd64.whl → 0.2.9.dev4__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +84 -107
- fabricatio/actions/article_rag.py +15 -10
- fabricatio/actions/output.py +20 -4
- fabricatio/actions/rules.py +37 -4
- fabricatio/capabilities/censor.py +21 -5
- fabricatio/capabilities/check.py +40 -22
- fabricatio/capabilities/correct.py +30 -11
- fabricatio/capabilities/rating.py +53 -47
- fabricatio/config.py +2 -2
- fabricatio/fs/readers.py +20 -1
- fabricatio/models/action.py +6 -6
- fabricatio/models/extra/advanced_judge.py +3 -3
- fabricatio/models/extra/article_base.py +117 -57
- fabricatio/models/extra/article_main.py +102 -14
- fabricatio/models/extra/article_proposal.py +15 -14
- fabricatio/models/extra/patches.py +6 -6
- fabricatio/models/extra/problem.py +20 -7
- fabricatio/models/extra/rule.py +16 -4
- fabricatio/models/generic.py +23 -6
- fabricatio/models/usages.py +7 -16
- fabricatio/parser.py +5 -5
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +33 -0
- fabricatio/utils.py +5 -5
- fabricatio/workflows/articles.py +3 -5
- fabricatio-0.2.9.dev4.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev2.dist-info → fabricatio-0.2.9.dev4.dist-info}/METADATA +1 -1
- {fabricatio-0.2.9.dev2.dist-info → fabricatio-0.2.9.dev4.dist-info}/RECORD +30 -30
- fabricatio-0.2.9.dev2.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev2.dist-info → fabricatio-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev2.dist-info → fabricatio-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
fabricatio/actions/article.py
CHANGED
@@ -2,24 +2,23 @@
|
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import
|
5
|
+
from typing import Callable, List, Optional
|
6
6
|
|
7
7
|
from more_itertools import filter_map
|
8
8
|
|
9
9
|
from fabricatio.capabilities.censor import Censor
|
10
|
-
from fabricatio.capabilities.correct import Correct
|
11
10
|
from fabricatio.capabilities.propose import Propose
|
12
11
|
from fabricatio.fs import safe_text_read
|
13
12
|
from fabricatio.journal import logger
|
14
13
|
from fabricatio.models.action import Action
|
15
|
-
from fabricatio.models.extra.article_base import
|
14
|
+
from fabricatio.models.extra.article_base import SubSectionBase
|
16
15
|
from fabricatio.models.extra.article_essence import ArticleEssence
|
17
16
|
from fabricatio.models.extra.article_main import Article
|
18
17
|
from fabricatio.models.extra.article_outline import ArticleOutline
|
19
18
|
from fabricatio.models.extra.article_proposal import ArticleProposal
|
20
19
|
from fabricatio.models.extra.rule import RuleSet
|
21
20
|
from fabricatio.models.task import Task
|
22
|
-
from fabricatio.rust import BibManager
|
21
|
+
from fabricatio.rust import BibManager, detect_language
|
23
22
|
from fabricatio.utils import ok
|
24
23
|
|
25
24
|
|
@@ -53,7 +52,7 @@ class ExtractArticleEssence(Action, Propose):
|
|
53
52
|
for ess in await self.propose(
|
54
53
|
ArticleEssence,
|
55
54
|
[
|
56
|
-
f"{c}\n\n\nBased the provided academic article above, you need to extract the essence from it
|
55
|
+
f"{c}\n\n\nBased the provided academic article above, you need to extract the essence from it.\n\nWrite the value string using `{detect_language(c)}`"
|
57
56
|
for c in contents
|
58
57
|
],
|
59
58
|
):
|
@@ -106,35 +105,30 @@ class GenerateArticleProposal(Action, Propose):
|
|
106
105
|
task_input: Optional[Task] = None,
|
107
106
|
article_briefing: Optional[str] = None,
|
108
107
|
article_briefing_path: Optional[str] = None,
|
109
|
-
langauge: Optional[str] = None,
|
110
108
|
**_,
|
111
109
|
) -> Optional[ArticleProposal]:
|
112
110
|
if article_briefing is None and article_briefing_path is None and task_input is None:
|
113
111
|
logger.error("Task not approved, since all inputs are None.")
|
114
112
|
return None
|
115
113
|
|
116
|
-
|
114
|
+
briefing = article_briefing or safe_text_read(
|
115
|
+
ok(
|
116
|
+
article_briefing_path
|
117
|
+
or await self.awhich_pathstr(
|
118
|
+
f"{ok(task_input).briefing}\nExtract the path of file which contains the article briefing."
|
119
|
+
),
|
120
|
+
"Could not find the path of file to read.",
|
121
|
+
)
|
122
|
+
)
|
123
|
+
|
124
|
+
logger.info("Start generating the proposal.")
|
125
|
+
return ok(
|
117
126
|
await self.propose(
|
118
127
|
ArticleProposal,
|
119
|
-
briefing
|
120
|
-
article_briefing
|
121
|
-
or safe_text_read(
|
122
|
-
ok(
|
123
|
-
article_briefing_path
|
124
|
-
or await self.awhich_pathstr(
|
125
|
-
f"{ok(task_input).briefing}\nExtract the path of file which contains the article briefing."
|
126
|
-
),
|
127
|
-
"Could not find the path of file to read.",
|
128
|
-
)
|
129
|
-
)
|
130
|
-
),
|
128
|
+
f"{briefing}\n\nWrite the value string using `{detect_language(briefing)}` as written language.",
|
131
129
|
),
|
132
130
|
"Could not generate the proposal.",
|
133
131
|
).update_ref(briefing)
|
134
|
-
if langauge:
|
135
|
-
proposal.language = langauge
|
136
|
-
|
137
|
-
return proposal
|
138
132
|
|
139
133
|
|
140
134
|
class GenerateInitialOutline(Action, Propose):
|
@@ -151,7 +145,8 @@ class GenerateInitialOutline(Action, Propose):
|
|
151
145
|
return ok(
|
152
146
|
await self.propose(
|
153
147
|
ArticleOutline,
|
154
|
-
article_proposal.as_prompt()
|
148
|
+
f"{(article_proposal.as_prompt())}\n\nNote that you should use `{article_proposal.language}` to write the `ArticleOutline`\n"
|
149
|
+
f"You Must make sure every chapter have sections, and every section have subsections.",
|
155
150
|
),
|
156
151
|
"Could not generate the initial outline.",
|
157
152
|
).update_ref(article_proposal)
|
@@ -165,25 +160,33 @@ class FixIntrospectedErrors(Action, Censor):
|
|
165
160
|
|
166
161
|
ruleset: Optional[RuleSet] = None
|
167
162
|
"""The ruleset to use to fix the introspected errors."""
|
163
|
+
max_error_count: Optional[int] = None
|
164
|
+
"""The maximum number of errors to fix."""
|
168
165
|
|
169
166
|
async def _execute(
|
170
167
|
self,
|
171
168
|
article_outline: ArticleOutline,
|
172
|
-
|
169
|
+
intro_fix_ruleset: Optional[RuleSet] = None,
|
173
170
|
**_,
|
174
171
|
) -> Optional[ArticleOutline]:
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
172
|
+
counter = 0
|
173
|
+
origin = article_outline
|
174
|
+
while pack := article_outline.gather_introspected():
|
175
|
+
logger.info(f"Found {counter}th introspected errors")
|
176
|
+
logger.warning(f"Found introspected error: {pack}")
|
177
|
+
article_outline = ok(
|
179
178
|
await self.censor_obj(
|
180
|
-
|
181
|
-
ruleset=ok(
|
182
|
-
reference=f"
|
179
|
+
article_outline,
|
180
|
+
ruleset=ok(intro_fix_ruleset or self.ruleset, "No ruleset provided"),
|
181
|
+
reference=f"{article_outline.as_prompt()}\n # Fatal Error of the Original Article Outline\n{pack}",
|
183
182
|
),
|
184
183
|
"Could not correct the component.",
|
185
|
-
)
|
186
|
-
|
184
|
+
).update_ref(origin)
|
185
|
+
|
186
|
+
if self.max_error_count and counter > self.max_error_count:
|
187
|
+
logger.warning("Max error count reached, stopping.")
|
188
|
+
break
|
189
|
+
counter += 1
|
187
190
|
|
188
191
|
return article_outline
|
189
192
|
|
@@ -196,27 +199,36 @@ class FixIllegalReferences(Action, Censor):
|
|
196
199
|
|
197
200
|
ruleset: Optional[RuleSet] = None
|
198
201
|
"""Ruleset to use to fix the illegal references."""
|
202
|
+
max_error_count: Optional[int] = None
|
203
|
+
"""The maximum number of errors to fix."""
|
199
204
|
|
200
205
|
async def _execute(
|
201
206
|
self,
|
202
207
|
article_outline: ArticleOutline,
|
203
|
-
|
208
|
+
ref_fix_ruleset: Optional[RuleSet] = None,
|
204
209
|
**_,
|
205
210
|
) -> Optional[ArticleOutline]:
|
211
|
+
counter = 0
|
206
212
|
while pack := article_outline.find_illegal_ref(gather_identical=True):
|
207
|
-
|
213
|
+
logger.info(f"Found {counter}th illegal references")
|
214
|
+
ref_seq, err = ok(pack)
|
208
215
|
logger.warning(f"Found illegal referring error: {err}")
|
209
|
-
|
216
|
+
new = ok(
|
210
217
|
await self.censor_obj(
|
211
|
-
|
212
|
-
ruleset=ok(
|
213
|
-
reference=f"
|
214
|
-
)
|
218
|
+
ref_seq[0],
|
219
|
+
ruleset=ok(ref_fix_ruleset or self.ruleset, "No ruleset provided"),
|
220
|
+
reference=f"{article_outline.as_prompt()}\n# Some Basic errors found that need to be fixed\n{err}",
|
221
|
+
),
|
222
|
+
"Could not correct the component",
|
215
223
|
)
|
216
|
-
for
|
217
|
-
|
224
|
+
for r in ref_seq:
|
225
|
+
r.update_from(new)
|
226
|
+
if self.max_error_count and counter > self.max_error_count:
|
227
|
+
logger.warning("Max error count reached, stopping.")
|
228
|
+
break
|
229
|
+
counter += 1
|
218
230
|
|
219
|
-
return article_outline
|
231
|
+
return article_outline
|
220
232
|
|
221
233
|
|
222
234
|
class TweakOutlineForwardRef(Action, Censor):
|
@@ -230,32 +242,36 @@ class TweakOutlineForwardRef(Action, Censor):
|
|
230
242
|
"""Ruleset to use to fix the illegal references."""
|
231
243
|
|
232
244
|
async def _execute(
|
233
|
-
self, article_outline: ArticleOutline,
|
245
|
+
self, article_outline: ArticleOutline, ref_twk_ruleset: Optional[RuleSet] = None, **cxt
|
234
246
|
) -> ArticleOutline:
|
235
247
|
return await self._inner(
|
236
248
|
article_outline,
|
237
|
-
ruleset=ok(
|
249
|
+
ruleset=ok(ref_twk_ruleset or self.ruleset, "No ruleset provided"),
|
238
250
|
field_name="support_to",
|
239
251
|
)
|
240
252
|
|
241
253
|
async def _inner(self, article_outline: ArticleOutline, ruleset: RuleSet, field_name: str) -> ArticleOutline:
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
):
|
247
|
-
patch = ArticleRefSequencePatch.default()
|
248
|
-
patch.tweaked = getattr(a, field_name)
|
249
|
-
|
250
|
-
await self.censor_obj_inplace(
|
251
|
-
patch,
|
252
|
-
ruleset=ruleset,
|
253
|
-
reference=f"{article_outline.as_prompt()}\n"
|
254
|
-
f"The Article component titled `{a.title}` whose `{field_name}` field needs to be extended or tweaked.\n"
|
255
|
-
f"# Judgement\n{judge.display()}",
|
256
|
-
)
|
254
|
+
await gather(
|
255
|
+
*[self._loop(a[-1], article_outline, field_name, ruleset) for a in article_outline.iter_subsections()],
|
256
|
+
)
|
257
|
+
|
257
258
|
return article_outline
|
258
259
|
|
260
|
+
async def _loop(
|
261
|
+
self, a: SubSectionBase, article_outline: ArticleOutline, field_name: str, ruleset: RuleSet
|
262
|
+
) -> None:
|
263
|
+
if judge := await self.evidently_judge(
|
264
|
+
f"{article_outline.as_prompt()}\n\n{a.display()}\n"
|
265
|
+
f"Does the `{a.__class__.__name__}`'s `{field_name}` field need to be extended or tweaked?"
|
266
|
+
):
|
267
|
+
await self.censor_obj_inplace(
|
268
|
+
a,
|
269
|
+
ruleset=ruleset,
|
270
|
+
reference=f"{article_outline.as_prompt()}\n"
|
271
|
+
f"The Article component titled `{a.title}` whose `{field_name}` field needs to be extended or tweaked.\n"
|
272
|
+
f"# Judgement\n{judge.display()}",
|
273
|
+
)
|
274
|
+
|
259
275
|
|
260
276
|
class TweakOutlineBackwardRef(TweakOutlineForwardRef):
|
261
277
|
"""Tweak the backward references in the article outline.
|
@@ -267,11 +283,11 @@ class TweakOutlineBackwardRef(TweakOutlineForwardRef):
|
|
267
283
|
ruleset: Optional[RuleSet] = None
|
268
284
|
|
269
285
|
async def _execute(
|
270
|
-
self, article_outline: ArticleOutline,
|
286
|
+
self, article_outline: ArticleOutline, ref_twk_ruleset: Optional[RuleSet] = None, **cxt
|
271
287
|
) -> ArticleOutline:
|
272
288
|
return await self._inner(
|
273
289
|
article_outline,
|
274
|
-
ruleset=ok(
|
290
|
+
ruleset=ok(ref_twk_ruleset or self.ruleset, "No ruleset provided"),
|
275
291
|
field_name="depend_on",
|
276
292
|
)
|
277
293
|
|
@@ -286,7 +302,7 @@ class GenerateArticle(Action, Censor):
|
|
286
302
|
async def _execute(
|
287
303
|
self,
|
288
304
|
article_outline: ArticleOutline,
|
289
|
-
|
305
|
+
article_gen_ruleset: Optional[RuleSet] = None,
|
290
306
|
**_,
|
291
307
|
) -> Optional[Article]:
|
292
308
|
article: Article = Article.from_outline(ok(article_outline, "Article outline not specified.")).update_ref(
|
@@ -297,51 +313,12 @@ class GenerateArticle(Action, Censor):
|
|
297
313
|
*[
|
298
314
|
self.censor_obj_inplace(
|
299
315
|
subsec,
|
300
|
-
ruleset=ok(
|
301
|
-
reference=f"
|
316
|
+
ruleset=ok(article_gen_ruleset or self.ruleset, "No ruleset provided"),
|
317
|
+
reference=f"{article_outline.as_prompt()}\n# Error Need to be fixed\n{err}\nYou should use `{subsec.language}` to write the new `Subsection`.",
|
302
318
|
)
|
303
|
-
for _,
|
304
|
-
if (err := subsec.introspect())
|
319
|
+
for _, _, subsec in article.iter_subsections()
|
320
|
+
if (err := subsec.introspect()) and logger.warning(f"Found Introspection Error:\n{err}") is None
|
305
321
|
],
|
306
|
-
return_exceptions=True,
|
307
322
|
)
|
308
323
|
|
309
324
|
return article
|
310
|
-
|
311
|
-
|
312
|
-
class CorrectProposal(Action, Censor):
|
313
|
-
"""Correct the proposal of the article."""
|
314
|
-
|
315
|
-
output_key: str = "corrected_proposal"
|
316
|
-
|
317
|
-
async def _execute(self, article_proposal: ArticleProposal, **_) -> Any:
|
318
|
-
raise NotImplementedError("Not implemented.")
|
319
|
-
|
320
|
-
|
321
|
-
class CorrectOutline(Action, Correct):
|
322
|
-
"""Correct the outline of the article."""
|
323
|
-
|
324
|
-
output_key: str = "corrected_outline"
|
325
|
-
"""The key of the output data."""
|
326
|
-
|
327
|
-
async def _execute(
|
328
|
-
self,
|
329
|
-
article_outline: ArticleOutline,
|
330
|
-
**_,
|
331
|
-
) -> ArticleOutline:
|
332
|
-
raise NotImplementedError("Not implemented.")
|
333
|
-
|
334
|
-
|
335
|
-
class CorrectArticle(Action, Correct):
|
336
|
-
"""Correct the article based on the outline."""
|
337
|
-
|
338
|
-
output_key: str = "corrected_article"
|
339
|
-
"""The key of the output data."""
|
340
|
-
|
341
|
-
async def _execute(
|
342
|
-
self,
|
343
|
-
article: Article,
|
344
|
-
article_outline: ArticleOutline,
|
345
|
-
**_,
|
346
|
-
) -> Article:
|
347
|
-
raise NotImplementedError("Not implemented.")
|
@@ -6,7 +6,7 @@ from typing import Optional
|
|
6
6
|
from fabricatio.capabilities.censor import Censor
|
7
7
|
from fabricatio.capabilities.rag import RAG
|
8
8
|
from fabricatio.models.action import Action
|
9
|
-
from fabricatio.models.extra.article_main import Article,
|
9
|
+
from fabricatio.models.extra.article_main import Article, ArticleSubsection
|
10
10
|
from fabricatio.models.extra.rule import RuleSet
|
11
11
|
from fabricatio.utils import ok
|
12
12
|
|
@@ -29,11 +29,14 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
29
29
|
ruleset: Optional[RuleSet] = None
|
30
30
|
"""The ruleset to be used for censoring the article."""
|
31
31
|
|
32
|
+
ref_limit: int = 30
|
33
|
+
"""The limit of references to be retrieved"""
|
34
|
+
|
32
35
|
async def _execute(
|
33
36
|
self,
|
34
37
|
article: Article,
|
35
38
|
collection_name: str = "article_essence",
|
36
|
-
|
39
|
+
twk_rag_ruleset: Optional[RuleSet] = None,
|
37
40
|
parallel: bool = False,
|
38
41
|
**cxt,
|
39
42
|
) -> Optional[Article]:
|
@@ -45,7 +48,7 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
45
48
|
Args:
|
46
49
|
article (Article): The article to be processed.
|
47
50
|
collection_name (str): The name of the collection to view for processing.
|
48
|
-
|
51
|
+
twk_rag_ruleset (Optional[RuleSet]): The ruleset to apply for censoring. If not provided, the class's ruleset is used.
|
49
52
|
parallel (bool): If True, process subsections in parallel. Otherwise, process them sequentially.
|
50
53
|
**cxt: Additional context parameters.
|
51
54
|
|
@@ -57,14 +60,14 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
57
60
|
if parallel:
|
58
61
|
await gather(
|
59
62
|
*[
|
60
|
-
self._inner(article, subsec, ok(
|
63
|
+
self._inner(article, subsec, ok(twk_rag_ruleset or self.ruleset, "No ruleset provided!"))
|
61
64
|
for _, __, subsec in article.iter_subsections()
|
62
65
|
],
|
63
66
|
return_exceptions=True,
|
64
67
|
)
|
65
68
|
else:
|
66
69
|
for _, __, subsec in article.iter_subsections():
|
67
|
-
await self._inner(article, subsec, ok(
|
70
|
+
await self._inner(article, subsec, ok(twk_rag_ruleset or self.ruleset, "No ruleset provided!"))
|
68
71
|
return article
|
69
72
|
|
70
73
|
async def _inner(self, article: Article, subsec: ArticleSubsection, ruleset: RuleSet) -> None:
|
@@ -88,13 +91,15 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
88
91
|
f"{subsec.display()}\n"
|
89
92
|
f"# Requirement\n"
|
90
93
|
f"Search related articles in the base to find reference candidates, "
|
91
|
-
f"
|
94
|
+
f"provide queries in both `English` and `{subsec.language}` can get more accurate results.",
|
92
95
|
)
|
93
96
|
)
|
94
|
-
patch = ArticleParagraphSequencePatch.default()
|
95
|
-
patch.tweaked = subsec.paragraphs
|
96
97
|
await self.censor_obj_inplace(
|
97
|
-
|
98
|
+
subsec,
|
98
99
|
ruleset=ruleset,
|
99
|
-
reference=await self.aretrieve_compact(refind_q, final_limit=
|
100
|
+
reference=f"{await self.aretrieve_compact(refind_q, final_limit=self.ref_limit)}\n\n"
|
101
|
+
f"You can use Reference above to rewrite the `{subsec.__class__.__name__}`.\n"
|
102
|
+
f"You should Always use `{subsec.language}` as written language, "
|
103
|
+
f"which is the original language of the `{subsec.title}`. "
|
104
|
+
f"since rewrite a `{subsec.__class__.__name__}` in a different language is usually a bad choice",
|
100
105
|
)
|
fabricatio/actions/output.py
CHANGED
@@ -103,7 +103,7 @@ class RetrieveFromPersistent[T: PersistentAble](Action):
|
|
103
103
|
retrieve_cls: Type[T]
|
104
104
|
"""The class of the object to retrieve."""
|
105
105
|
|
106
|
-
async def _execute(self, /, **
|
106
|
+
async def _execute(self, /, **_) -> Optional[T | List[T]]:
|
107
107
|
logger.info(f"Retrieve `{self.retrieve_cls.__name__}` from {self.load_path}")
|
108
108
|
if not (p := Path(self.load_path)).exists():
|
109
109
|
logger.warning(f"Path {self.load_path} does not exist")
|
@@ -115,12 +115,29 @@ class RetrieveFromPersistent[T: PersistentAble](Action):
|
|
115
115
|
return self.retrieve_cls.from_persistent(self.load_path)
|
116
116
|
|
117
117
|
|
118
|
+
class RetrieveFromLatest[T: PersistentAble](RetrieveFromPersistent[T]):
|
119
|
+
"""Retrieve the object from the latest persistent file in the dir at `load_path`."""
|
120
|
+
|
121
|
+
async def _execute(self, /, **_) -> Optional[T]:
|
122
|
+
logger.info(f"Retrieve latest `{self.retrieve_cls.__name__}` from {self.load_path}")
|
123
|
+
if not (p := Path(self.load_path)).exists():
|
124
|
+
logger.warning(f"Path {self.load_path} does not exist")
|
125
|
+
return None
|
126
|
+
|
127
|
+
if p.is_dir():
|
128
|
+
logger.info(f"Found directory with {len(list(p.glob('*')))} items")
|
129
|
+
return self.retrieve_cls.from_latest_persistent(self.load_path)
|
130
|
+
logger.error(f"Path {self.load_path} is not a directory")
|
131
|
+
return None
|
132
|
+
|
133
|
+
|
118
134
|
class GatherAsList(Action):
|
119
135
|
"""Gather the objects from the context as a list.
|
120
136
|
|
121
137
|
Notes:
|
122
138
|
If both `gather_suffix` and `gather_prefix` are specified, only the objects with the suffix will be gathered.
|
123
139
|
"""
|
140
|
+
|
124
141
|
output_key: str = "gathered"
|
125
142
|
"""Gather the objects from the context as a list."""
|
126
143
|
gather_suffix: Optional[str] = None
|
@@ -129,13 +146,12 @@ class GatherAsList(Action):
|
|
129
146
|
"""Gather the objects from the context as a list."""
|
130
147
|
|
131
148
|
async def _execute(self, **cxt) -> List[Any]:
|
132
|
-
|
133
149
|
if self.gather_suffix is not None:
|
134
150
|
result = [cxt[k] for k in cxt if k.endswith(self.gather_suffix)]
|
135
151
|
logger.debug(f"Gathered {len(result)} items with suffix {self.gather_suffix}")
|
136
152
|
return result
|
137
|
-
if
|
138
|
-
logger.error(err:="Either `gather_suffix` or `gather_prefix` must be specified.")
|
153
|
+
if self.gather_prefix is None:
|
154
|
+
logger.error(err := "Either `gather_suffix` or `gather_prefix` must be specified.")
|
139
155
|
raise ValueError(err)
|
140
156
|
result = [cxt[k] for k in cxt if k.startswith(self.gather_prefix)]
|
141
157
|
logger.debug(f"Gathered {len(result)} items with prefix {self.gather_prefix}")
|
fabricatio/actions/rules.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
"""A module containing the DraftRuleSet action."""
|
2
2
|
|
3
|
-
from typing import Optional
|
3
|
+
from typing import List, Optional
|
4
4
|
|
5
5
|
from fabricatio.capabilities.check import Check
|
6
|
+
from fabricatio.journal import logger
|
6
7
|
from fabricatio.models.action import Action
|
7
8
|
from fabricatio.models.extra.rule import RuleSet
|
8
9
|
from fabricatio.utils import ok
|
@@ -18,9 +19,10 @@ class DraftRuleSet(Action, Check):
|
|
18
19
|
"""The natural language description of the desired ruleset characteristics."""
|
19
20
|
rule_count: int = 0
|
20
21
|
"""The number of rules to generate in the ruleset (0 for no restriction)."""
|
22
|
+
|
21
23
|
async def _execute(
|
22
24
|
self,
|
23
|
-
ruleset_requirement: Optional[str]=None,
|
25
|
+
ruleset_requirement: Optional[str] = None,
|
24
26
|
**_,
|
25
27
|
) -> Optional[RuleSet]:
|
26
28
|
"""Draft a ruleset based on the requirement description.
|
@@ -33,7 +35,38 @@ class DraftRuleSet(Action, Check):
|
|
33
35
|
Returns:
|
34
36
|
Optional[RuleSet]: Drafted ruleset object or None if generation fails
|
35
37
|
"""
|
36
|
-
|
37
|
-
ruleset_requirement=ok(ruleset_requirement or self.ruleset_requirement,"No ruleset requirement provided"),
|
38
|
+
ruleset = await self.draft_ruleset(
|
39
|
+
ruleset_requirement=ok(ruleset_requirement or self.ruleset_requirement, "No ruleset requirement provided"),
|
38
40
|
rule_count=self.rule_count,
|
39
41
|
)
|
42
|
+
if ruleset:
|
43
|
+
logger.info(f"Drafted Ruleset length: {len(ruleset.rules)}\n{ruleset.display()}")
|
44
|
+
else:
|
45
|
+
logger.warning(f"Drafting Rule Failed for:\n{ruleset_requirement}")
|
46
|
+
return ruleset
|
47
|
+
|
48
|
+
|
49
|
+
class GatherRuleset(Action):
|
50
|
+
"""Action to gather a ruleset from a given requirement description."""
|
51
|
+
|
52
|
+
output_key: str = "gathered_ruleset"
|
53
|
+
"""The key used to store the drafted ruleset in the context dictionary."""
|
54
|
+
|
55
|
+
to_gather: List[str]
|
56
|
+
"""the cxt name of RuleSet to gather"""
|
57
|
+
|
58
|
+
async def _execute(self, **cxt) -> RuleSet:
|
59
|
+
logger.info(f"Gathering Ruleset from {self.to_gather}")
|
60
|
+
# Fix for not_found
|
61
|
+
not_found = next((t for t in self.to_gather if t not in cxt), None)
|
62
|
+
if not_found:
|
63
|
+
raise ValueError(
|
64
|
+
f"Not all required keys found in context: {self.to_gather}|`{not_found}` not found in context."
|
65
|
+
)
|
66
|
+
|
67
|
+
# Fix for invalid RuleSet check
|
68
|
+
invalid = next((t for t in self.to_gather if not isinstance(cxt[t], RuleSet)), None)
|
69
|
+
if invalid is not None:
|
70
|
+
raise TypeError(f"Invalid RuleSet instance for key `{invalid}`")
|
71
|
+
|
72
|
+
return RuleSet.gather(*[cxt[t] for t in self.to_gather])
|
@@ -8,6 +8,8 @@ from typing import Optional, Unpack
|
|
8
8
|
|
9
9
|
from fabricatio.capabilities.check import Check
|
10
10
|
from fabricatio.capabilities.correct import Correct
|
11
|
+
from fabricatio.journal import logger
|
12
|
+
from fabricatio.models.extra.problem import Improvement
|
11
13
|
from fabricatio.models.extra.rule import RuleSet
|
12
14
|
from fabricatio.models.generic import ProposedUpdateAble, SketchedAble
|
13
15
|
from fabricatio.models.kwargs_types import ReferencedKwargs
|
@@ -41,7 +43,11 @@ class Censor(Correct, Check):
|
|
41
43
|
imp = await self.check_obj(obj, ruleset, **override_kwargs(kwargs, default=None))
|
42
44
|
if imp is None:
|
43
45
|
return None
|
44
|
-
|
46
|
+
if not imp:
|
47
|
+
logger.info(f"No improvement found for `{obj.__class__.__name__}`.")
|
48
|
+
return obj
|
49
|
+
logger.info(f'Generated {len(imp)} improvement(s) for `{obj.__class__.__name__}')
|
50
|
+
return await self.correct_obj(obj, Improvement.gather(*imp), **kwargs)
|
45
51
|
|
46
52
|
async def censor_string(
|
47
53
|
self, input_text: str, ruleset: RuleSet, **kwargs: Unpack[ReferencedKwargs[str]]
|
@@ -61,8 +67,13 @@ class Censor(Correct, Check):
|
|
61
67
|
"""
|
62
68
|
imp = await self.check_string(input_text, ruleset, **override_kwargs(kwargs, default=None))
|
63
69
|
if imp is None:
|
64
|
-
|
65
|
-
|
70
|
+
logger.warning(f"Censor failed for string:\n{input_text}")
|
71
|
+
return None
|
72
|
+
if not imp:
|
73
|
+
logger.info("No improvement found for string.")
|
74
|
+
return input_text
|
75
|
+
logger.info(f'Generated {len(imp)} improvement(s) for string.')
|
76
|
+
return await self.correct_string(input_text, Improvement.gather(*imp), **kwargs)
|
66
77
|
|
67
78
|
async def censor_obj_inplace[M: ProposedUpdateAble](
|
68
79
|
self, obj: M, ruleset: RuleSet, **kwargs: Unpack[ReferencedKwargs[M]]
|
@@ -84,5 +95,10 @@ class Censor(Correct, Check):
|
|
84
95
|
"""
|
85
96
|
imp = await self.check_obj(obj, ruleset, **override_kwargs(kwargs, default=None))
|
86
97
|
if imp is None:
|
87
|
-
|
88
|
-
|
98
|
+
logger.warning(f"Censor failed for `{obj.__class__.__name__}`")
|
99
|
+
return None
|
100
|
+
if not imp:
|
101
|
+
logger.info(f"No improvement found for `{obj.__class__.__name__}`.")
|
102
|
+
return obj
|
103
|
+
logger.info(f'Generated {len(imp)} improvement(s) for `{obj.__class__.__name__}')
|
104
|
+
return await self.correct_obj_inplace(obj, improvement=Improvement.gather(*imp), **kwargs)
|