fabricatio 0.2.10.dev0__cp312-cp312-win_amd64.whl → 0.2.11__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +55 -10
- fabricatio/actions/article_rag.py +297 -12
- fabricatio/actions/fs.py +25 -0
- fabricatio/actions/output.py +17 -3
- fabricatio/actions/rag.py +42 -20
- fabricatio/actions/rules.py +14 -3
- fabricatio/capabilities/extract.py +70 -0
- fabricatio/capabilities/rag.py +5 -2
- fabricatio/capabilities/rating.py +5 -2
- fabricatio/capabilities/task.py +16 -16
- fabricatio/config.py +9 -2
- fabricatio/decorators.py +43 -26
- fabricatio/fs/__init__.py +9 -2
- fabricatio/fs/readers.py +6 -10
- fabricatio/models/action.py +16 -11
- fabricatio/models/adv_kwargs_types.py +5 -12
- fabricatio/models/extra/aricle_rag.py +254 -0
- fabricatio/models/extra/article_base.py +56 -7
- fabricatio/models/extra/article_essence.py +8 -7
- fabricatio/models/extra/article_main.py +102 -6
- fabricatio/models/extra/problem.py +5 -1
- fabricatio/models/extra/rag.py +49 -23
- fabricatio/models/generic.py +43 -24
- fabricatio/models/kwargs_types.py +12 -3
- fabricatio/models/task.py +13 -1
- fabricatio/models/usages.py +10 -27
- fabricatio/parser.py +16 -12
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +177 -63
- fabricatio/utils.py +50 -10
- fabricatio-0.2.11.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/METADATA +20 -12
- fabricatio-0.2.11.dist-info/RECORD +65 -0
- fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.10.dev0.dist-info/RECORD +0 -62
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.10.dev0.dist-info → fabricatio-0.2.11.dist-info}/licenses/LICENSE +0 -0
fabricatio/actions/article.py
CHANGED
@@ -5,8 +5,11 @@ from pathlib import Path
|
|
5
5
|
from typing import Callable, List, Optional
|
6
6
|
|
7
7
|
from more_itertools import filter_map
|
8
|
+
from pydantic import Field
|
9
|
+
from rich import print as r_print
|
8
10
|
|
9
11
|
from fabricatio.capabilities.censor import Censor
|
12
|
+
from fabricatio.capabilities.extract import Extract
|
10
13
|
from fabricatio.capabilities.propose import Propose
|
11
14
|
from fabricatio.fs import safe_text_read
|
12
15
|
from fabricatio.journal import logger
|
@@ -16,9 +19,10 @@ from fabricatio.models.extra.article_main import Article
|
|
16
19
|
from fabricatio.models.extra.article_outline import ArticleOutline
|
17
20
|
from fabricatio.models.extra.article_proposal import ArticleProposal
|
18
21
|
from fabricatio.models.extra.rule import RuleSet
|
22
|
+
from fabricatio.models.kwargs_types import ValidateKwargs
|
19
23
|
from fabricatio.models.task import Task
|
20
24
|
from fabricatio.rust import BibManager, detect_language
|
21
|
-
from fabricatio.utils import ok
|
25
|
+
from fabricatio.utils import ok, wrapp_in_block
|
22
26
|
|
23
27
|
|
24
28
|
class ExtractArticleEssence(Action, Propose):
|
@@ -78,7 +82,7 @@ class FixArticleEssence(Action):
|
|
78
82
|
out = []
|
79
83
|
count = 0
|
80
84
|
for a in article_essence:
|
81
|
-
if key := (bib_mgr.
|
85
|
+
if key := (bib_mgr.get_cite_key_by_title(a.title) or bib_mgr.get_cite_key_fuzzy(a.title)):
|
82
86
|
a.title = bib_mgr.get_title_by_key(key) or a.title
|
83
87
|
a.authors = bib_mgr.get_author_by_key(key) or a.authors
|
84
88
|
a.publication_year = bib_mgr.get_year_by_key(key) or a.publication_year
|
@@ -130,33 +134,65 @@ class GenerateArticleProposal(Action, Propose):
|
|
130
134
|
).update_ref(briefing)
|
131
135
|
|
132
136
|
|
133
|
-
class GenerateInitialOutline(Action,
|
137
|
+
class GenerateInitialOutline(Action, Extract):
|
134
138
|
"""Generate the initial article outline based on the article proposal."""
|
135
139
|
|
136
140
|
output_key: str = "initial_article_outline"
|
137
141
|
"""The key of the output data."""
|
138
142
|
|
143
|
+
supervisor: bool = False
|
144
|
+
"""Whether to use the supervisor to fix the outline."""
|
145
|
+
|
146
|
+
extract_kwargs: ValidateKwargs[Optional[ArticleOutline]] = Field(default_factory=ValidateKwargs)
|
147
|
+
"""The kwargs to extract the outline."""
|
148
|
+
|
139
149
|
async def _execute(
|
140
150
|
self,
|
141
151
|
article_proposal: ArticleProposal,
|
152
|
+
supervisor: Optional[bool] = None,
|
142
153
|
**_,
|
143
154
|
) -> Optional[ArticleOutline]:
|
144
|
-
|
145
|
-
f"{(article_proposal.as_prompt())}\n\nNote that you should use `{article_proposal.language}` to write the `ArticleOutline`\n"
|
155
|
+
req = (
|
146
156
|
f"Design each chapter of a proper and academic and ready for release manner.\n"
|
147
157
|
f"You Must make sure every chapter have sections, and every section have subsections.\n"
|
148
|
-
f"Make the chapter and sections and subsections bing divided into a specific enough article component
|
158
|
+
f"Make the chapter and sections and subsections bing divided into a specific enough article component.\n"
|
159
|
+
f"Every chapter must have sections, every section must have subsections.\n"
|
160
|
+
f"Note that you SHALL use `{article_proposal.language}` as written language",
|
149
161
|
)
|
150
162
|
|
163
|
+
raw_outline = await self.aask(f"{(article_proposal.as_prompt())}\n{req}")
|
164
|
+
|
165
|
+
if supervisor or (supervisor is None and self.supervisor):
|
166
|
+
from questionary import confirm, text
|
167
|
+
|
168
|
+
r_print(raw_outline)
|
169
|
+
while not await confirm("Accept this version and continue?", default=True).ask_async():
|
170
|
+
imp = await text("Enter the improvement:").ask_async()
|
171
|
+
raw_outline = await self.aask(
|
172
|
+
f"{article_proposal.as_prompt()}\n{wrapp_in_block(raw_outline, 'Previous ArticleOutline')}\n{req}\n{wrapp_in_block(imp, title='Improvement')}"
|
173
|
+
)
|
174
|
+
r_print(raw_outline)
|
175
|
+
|
151
176
|
return ok(
|
152
|
-
await self.
|
153
|
-
ArticleOutline,
|
154
|
-
f"{raw_outline}\n\n\n\noutline provided above is the outline i need to extract to a JSON,",
|
155
|
-
),
|
177
|
+
await self.extract(ArticleOutline, raw_outline, **self.extract_kwargs),
|
156
178
|
"Could not generate the initial outline.",
|
157
179
|
).update_ref(article_proposal)
|
158
180
|
|
159
181
|
|
182
|
+
class ExtractOutlineFromRaw(Action, Extract):
|
183
|
+
"""Extract the outline from the raw outline."""
|
184
|
+
|
185
|
+
output_key: str = "article_outline_from_raw"
|
186
|
+
|
187
|
+
async def _execute(self, article_outline_raw_path: str | Path, **cxt) -> ArticleOutline:
|
188
|
+
logger.info(f"Extracting outline from raw: {Path(article_outline_raw_path).as_posix()}")
|
189
|
+
|
190
|
+
return ok(
|
191
|
+
await self.extract(ArticleOutline, safe_text_read(article_outline_raw_path)),
|
192
|
+
"Could not extract the outline from raw.",
|
193
|
+
)
|
194
|
+
|
195
|
+
|
160
196
|
class FixIntrospectedErrors(Action, Censor):
|
161
197
|
"""Fix introspected errors in the article outline."""
|
162
198
|
|
@@ -226,3 +262,12 @@ class GenerateArticle(Action, Censor):
|
|
226
262
|
)
|
227
263
|
|
228
264
|
return article
|
265
|
+
|
266
|
+
|
267
|
+
class LoadArticle(Action):
|
268
|
+
"""Load the article from the outline and typst code."""
|
269
|
+
|
270
|
+
output_key: str = "loaded_article"
|
271
|
+
|
272
|
+
async def _execute(self, article_outline: ArticleOutline, typst_code: str, **cxt) -> Article:
|
273
|
+
return Article.from_mixed_source(article_outline, typst_code)
|
@@ -1,14 +1,276 @@
|
|
1
1
|
"""A module for writing articles using RAG (Retrieval-Augmented Generation) capabilities."""
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
|
-
from
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import List, Optional
|
5
6
|
|
7
|
+
from fabricatio import BibManager
|
6
8
|
from fabricatio.capabilities.censor import Censor
|
9
|
+
from fabricatio.capabilities.extract import Extract
|
7
10
|
from fabricatio.capabilities.rag import RAG
|
11
|
+
from fabricatio.decorators import precheck_package
|
12
|
+
from fabricatio.journal import logger
|
8
13
|
from fabricatio.models.action import Action
|
9
|
-
from fabricatio.models.extra.
|
14
|
+
from fabricatio.models.extra.aricle_rag import ArticleChunk, CitationManager
|
15
|
+
from fabricatio.models.extra.article_essence import ArticleEssence
|
16
|
+
from fabricatio.models.extra.article_main import Article, ArticleChapter, ArticleSection, ArticleSubsection
|
17
|
+
from fabricatio.models.extra.article_outline import ArticleOutline
|
10
18
|
from fabricatio.models.extra.rule import RuleSet
|
11
|
-
from fabricatio.utils import ok
|
19
|
+
from fabricatio.utils import ask_retain, ok
|
20
|
+
|
21
|
+
TYPST_CITE_USAGE = (
|
22
|
+
"citation number is REQUIRED to cite any reference!,for example in Auther Pattern: 'Doe et al.[[1]], Jack et al.[[2]]' or in Sentence Suffix Sattern: 'Global requirement is incresing[[1]].'\n"
|
23
|
+
"Everything is build upon the typst language, which is similar to latex, \n"
|
24
|
+
"Legal citing syntax examples(seperated by |): [[1]]|[[1,2]]|[[1-3]]|[[12,13-15]]|[[1-3,5-7]]\n"
|
25
|
+
"Illegal citing syntax examples(seperated by |): [[1],[2],[3]]|[[1],[1-2]]\n"
|
26
|
+
"Those reference mark shall not be omitted during the extraction\n"
|
27
|
+
"It's recommended to cite multiple references that supports your conclusion at a time.\n"
|
28
|
+
"Wrapp inline expression using $ $,like '$>5m$' '$89%$' , and wrapp block equation using $$ $$. if you are using '$' as the money unit, you should add a '\\' before it to avoid being interpreted as a inline equation. For example 'The pants worths 5\\$.'\n"
|
29
|
+
"In addition to that, you can add a label outside the block equation which can be used as a cross reference identifier, the label is a string wrapped in `<` and `>` like `<energy-release-rate-equation>`.Note that the label string should be a summarizing title for the equation being labeled.\n"
|
30
|
+
"you can refer to that label by using the syntax with prefix of `@eqt:`, which indicate that this notation is citing a label from the equations. For example ' @eqt:energy-release-rate-equation ' DO remember that the notation shall have both suffixed and prefixed space char which enable the compiler to distinguish the notation from the plaintext."
|
31
|
+
"Below is a usage example:\n"
|
32
|
+
"```typst\n"
|
33
|
+
"See @eqt:mass-energy-equation , it's the foundation of physics.\n"
|
34
|
+
"$$\n"
|
35
|
+
"E = m c^2\n"
|
36
|
+
"$$ <mass-energy-equation>\n\n\n"
|
37
|
+
"In @eqt:mass-energy-equation , $m$ stands for mass, $c$ stands for speed of light, and $E$ stands for energy. \n"
|
38
|
+
"```"
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
class WriteArticleContentRAG(Action, RAG, Extract):
|
43
|
+
"""Write an article based on the provided outline."""
|
44
|
+
|
45
|
+
ref_limit: int = 35
|
46
|
+
"""The limit of references to be retrieved"""
|
47
|
+
threshold: float = 0.62
|
48
|
+
"""The threshold of relevance"""
|
49
|
+
extractor_model: str
|
50
|
+
"""The model to use for extracting the content from the retrieved references."""
|
51
|
+
query_model: str
|
52
|
+
"""The model to use for querying the database"""
|
53
|
+
supervisor: bool = False
|
54
|
+
"""Whether to use supervisor mode"""
|
55
|
+
req: str = TYPST_CITE_USAGE
|
56
|
+
|
57
|
+
async def _execute(
|
58
|
+
self,
|
59
|
+
article_outline: ArticleOutline,
|
60
|
+
collection_name: Optional[str] = None,
|
61
|
+
supervisor: Optional[bool] = None,
|
62
|
+
**cxt,
|
63
|
+
) -> Article:
|
64
|
+
article = Article.from_outline(article_outline).update_ref(article_outline)
|
65
|
+
self.target_collection = collection_name or self.safe_target_collection
|
66
|
+
if supervisor or (supervisor is None and self.supervisor):
|
67
|
+
for chap, sec, subsec in article.iter_subsections():
|
68
|
+
await self._supervisor_inner(article, article_outline, chap, sec, subsec)
|
69
|
+
|
70
|
+
else:
|
71
|
+
await gather(
|
72
|
+
*[
|
73
|
+
self._inner(article, article_outline, chap, sec, subsec)
|
74
|
+
for chap, sec, subsec in article.iter_subsections()
|
75
|
+
]
|
76
|
+
)
|
77
|
+
return article.convert_tex()
|
78
|
+
|
79
|
+
@precheck_package(
|
80
|
+
"questionary", "`questionary` is required for supervisor mode, please install it by `fabricatio[qa]`"
|
81
|
+
)
|
82
|
+
async def _supervisor_inner(
|
83
|
+
self,
|
84
|
+
article: Article,
|
85
|
+
article_outline: ArticleOutline,
|
86
|
+
chap: ArticleChapter,
|
87
|
+
sec: ArticleSection,
|
88
|
+
subsec: ArticleSubsection,
|
89
|
+
) -> ArticleSubsection:
|
90
|
+
from questionary import confirm, text
|
91
|
+
from rich import print as r_print
|
92
|
+
|
93
|
+
ret = await self.search_database(article, article_outline, chap, sec, subsec)
|
94
|
+
|
95
|
+
cm = CitationManager(article_chunks=await ask_retain([r.chunk for r in ret], ret)).set_cite_number_all()
|
96
|
+
|
97
|
+
raw = await self.write_raw(article, article_outline, chap, sec, subsec, cm)
|
98
|
+
r_print(raw)
|
99
|
+
|
100
|
+
while not await confirm("Accept this version and continue?").ask_async():
|
101
|
+
if inst := await text("Search for more refs for additional spec.").ask_async():
|
102
|
+
new_refs = await self.search_database(
|
103
|
+
article,
|
104
|
+
article_outline,
|
105
|
+
chap,
|
106
|
+
sec,
|
107
|
+
subsec,
|
108
|
+
supervisor=True,
|
109
|
+
extra_instruction=inst,
|
110
|
+
)
|
111
|
+
cm.add_chunks(await ask_retain([r.chunk for r in new_refs], new_refs))
|
112
|
+
|
113
|
+
if instruction := await text("Enter the instructions to improve").ask_async():
|
114
|
+
raw = await self.write_raw(article, article_outline, chap, sec, subsec, cm, instruction)
|
115
|
+
if edt := await text("Edit", default=raw).ask_async():
|
116
|
+
raw = edt
|
117
|
+
|
118
|
+
r_print(raw)
|
119
|
+
|
120
|
+
return await self.extract_new_subsec(subsec, raw, cm)
|
121
|
+
|
122
|
+
async def _inner(
|
123
|
+
self,
|
124
|
+
article: Article,
|
125
|
+
article_outline: ArticleOutline,
|
126
|
+
chap: ArticleChapter,
|
127
|
+
sec: ArticleSection,
|
128
|
+
subsec: ArticleSubsection,
|
129
|
+
) -> ArticleSubsection:
|
130
|
+
ret = await self.search_database(article, article_outline, chap, sec, subsec)
|
131
|
+
cm = CitationManager(article_chunks=ret).set_cite_number_all()
|
132
|
+
|
133
|
+
raw_paras = await self.write_raw(article, article_outline, chap, sec, subsec, cm)
|
134
|
+
|
135
|
+
return await self.extract_new_subsec(subsec, raw_paras, cm)
|
136
|
+
|
137
|
+
async def extract_new_subsec(
|
138
|
+
self, subsec: ArticleSubsection, raw_paras: str, cm: CitationManager
|
139
|
+
) -> ArticleSubsection:
|
140
|
+
"""Extract the new subsec."""
|
141
|
+
new_subsec = ok(
|
142
|
+
await self.extract(
|
143
|
+
ArticleSubsection,
|
144
|
+
raw_paras,
|
145
|
+
f"Above is the subsection titled `{subsec.title}`.\n"
|
146
|
+
f"I need you to extract the content to update my subsection obj provided below.\n{self.req}"
|
147
|
+
f"{subsec.display()}\n",
|
148
|
+
model=self.extractor_model,
|
149
|
+
),
|
150
|
+
"Failed to propose new subsection.",
|
151
|
+
)
|
152
|
+
for p in new_subsec.paragraphs:
|
153
|
+
p.content = cm.apply(p.content).replace("$$", "\n$$\n")
|
154
|
+
subsec.update_from(new_subsec)
|
155
|
+
logger.debug(f"{subsec.title}:rpl\n{subsec.display()}")
|
156
|
+
return subsec
|
157
|
+
|
158
|
+
async def write_raw(
|
159
|
+
self,
|
160
|
+
article: Article,
|
161
|
+
article_outline: ArticleOutline,
|
162
|
+
chap: ArticleChapter,
|
163
|
+
sec: ArticleSection,
|
164
|
+
subsec: ArticleSubsection,
|
165
|
+
cm: CitationManager,
|
166
|
+
extra_instruction: str = "",
|
167
|
+
) -> str:
|
168
|
+
"""Write the raw paragraphs of the subsec."""
|
169
|
+
return (
|
170
|
+
(
|
171
|
+
await self.aask(
|
172
|
+
f"{cm.as_prompt()}\nAbove is some related reference from other auther retrieved for you."
|
173
|
+
f"{article_outline.finalized_dump()}\n\nAbove is my article outline, I m writing graduate thesis titled `{article.title}`. "
|
174
|
+
f"More specifically, i m witting the Chapter `{chap.title}` >> Section `{sec.title}` >> Subsection `{subsec.title}`.\n"
|
175
|
+
f"Please help me write the paragraphs of the subsec mentioned above, which is `{subsec.title}`.\n"
|
176
|
+
f"{self.req}\n"
|
177
|
+
f"You SHALL use `{article.language}` as writing language.\n{extra_instruction}\n"
|
178
|
+
f"Do not use numbered list to display the outcome, you should regard you are writing the main text of the article\n"
|
179
|
+
f"You should not copy others' works from the references directly on to my thesis, we can only harness the conclusion they have drawn."
|
180
|
+
)
|
181
|
+
)
|
182
|
+
.replace(r" \( ", "$")
|
183
|
+
.replace(r" \) ", "$")
|
184
|
+
.replace(r"\(", "$")
|
185
|
+
.replace(r"\)", "$")
|
186
|
+
.replace("\\[\n", "$$\n")
|
187
|
+
.replace("\\[ ", "$$\n")
|
188
|
+
.replace("\n\\]", "\n$$")
|
189
|
+
.replace(" \\]", "\n$$")
|
190
|
+
)
|
191
|
+
|
192
|
+
async def search_database(
|
193
|
+
self,
|
194
|
+
article: Article,
|
195
|
+
article_outline: ArticleOutline,
|
196
|
+
chap: ArticleChapter,
|
197
|
+
sec: ArticleSection,
|
198
|
+
subsec: ArticleSubsection,
|
199
|
+
extra_instruction: str = "",
|
200
|
+
supervisor: bool = False,
|
201
|
+
) -> List[ArticleChunk]:
|
202
|
+
"""Search database for related references."""
|
203
|
+
ref_q = ok(
|
204
|
+
await self.arefined_query(
|
205
|
+
f"{article_outline.finalized_dump()}\n\nAbove is my article outline, I m writing graduate thesis titled `{article.title}`. "
|
206
|
+
f"More specifically, i m witting the Chapter `{chap.title}` >> Section `{sec.title}` >> Subsection `{subsec.title}`.\n"
|
207
|
+
f"I need to search related references to build up the content of the subsec mentioned above, which is `{subsec.title}`.\n"
|
208
|
+
f"provide 10~16 queries as possible, to get best result!\n"
|
209
|
+
f"You should provide both English version and chinese version of the refined queries!\n{extra_instruction}\n",
|
210
|
+
model=self.query_model,
|
211
|
+
),
|
212
|
+
"Failed to refine query.",
|
213
|
+
)
|
214
|
+
|
215
|
+
if supervisor:
|
216
|
+
ref_q = await ask_retain(ref_q)
|
217
|
+
|
218
|
+
return await self.aretrieve(
|
219
|
+
ref_q, ArticleChunk, final_limit=self.ref_limit, result_per_query=3, similarity_threshold=self.threshold
|
220
|
+
)
|
221
|
+
|
222
|
+
|
223
|
+
class ArticleConsultRAG(Action, RAG):
|
224
|
+
"""Write an article based on the provided outline."""
|
225
|
+
|
226
|
+
output_key:str ="consult_count"
|
227
|
+
|
228
|
+
ref_limit: int = 20
|
229
|
+
"""The final limit of references."""
|
230
|
+
ref_per_q: int = 3
|
231
|
+
"""The limit of references to retrieve per query."""
|
232
|
+
similarity_threshold: float = 0.62
|
233
|
+
"""The similarity threshold of references to retrieve."""
|
234
|
+
ref_q_model: Optional[str] = None
|
235
|
+
"""The model to use for refining query."""
|
236
|
+
req: str = TYPST_CITE_USAGE
|
237
|
+
"""The request for the rag model."""
|
238
|
+
|
239
|
+
@precheck_package(
|
240
|
+
"questionary", "`questionary` is required for supervisor mode, please install it by `fabricatio[qa]`"
|
241
|
+
)
|
242
|
+
async def _execute(self, collection_name: Optional[str] = None, **cxt) -> int:
|
243
|
+
from questionary import confirm, text
|
244
|
+
from rich import print as r_print
|
245
|
+
|
246
|
+
from fabricatio.rust import convert_all_block_tex, convert_all_inline_tex
|
247
|
+
|
248
|
+
self.target_collection = collection_name or self.safe_target_collection
|
249
|
+
|
250
|
+
cm = CitationManager()
|
251
|
+
|
252
|
+
counter = 0
|
253
|
+
while (req := await text("User: ").ask_async()) is not None:
|
254
|
+
if await confirm("Empty the cm?").ask_async():
|
255
|
+
cm.empty()
|
256
|
+
ref_q = await self.arefined_query(req, model=self.ref_q_model)
|
257
|
+
refs = await self.aretrieve(
|
258
|
+
ok(ref_q, "Failed to refine query."),
|
259
|
+
ArticleChunk,
|
260
|
+
final_limit=self.ref_limit,
|
261
|
+
result_per_query=self.ref_per_q,
|
262
|
+
similarity_threshold=self.similarity_threshold,
|
263
|
+
)
|
264
|
+
|
265
|
+
ret = await self.aask(f"{cm.add_chunks(refs).as_prompt()}\n{self.req}\n{req}")
|
266
|
+
ret = convert_all_inline_tex(ret)
|
267
|
+
ret = convert_all_block_tex(ret)
|
268
|
+
ret = cm.apply(ret)
|
269
|
+
|
270
|
+
r_print(ret)
|
271
|
+
counter += 1
|
272
|
+
logger.info(f"{counter} rounds of conversation.")
|
273
|
+
return counter
|
12
274
|
|
13
275
|
|
14
276
|
class TweakArticleRAG(Action, RAG, Censor):
|
@@ -39,7 +301,7 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
39
301
|
twk_rag_ruleset: Optional[RuleSet] = None,
|
40
302
|
parallel: bool = False,
|
41
303
|
**cxt,
|
42
|
-
) ->
|
304
|
+
) -> Article:
|
43
305
|
"""Write an article based on the provided outline.
|
44
306
|
|
45
307
|
This method processes the article outline, either in parallel or sequentially, by enhancing each subsection
|
@@ -53,7 +315,7 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
53
315
|
**cxt: Additional context parameters.
|
54
316
|
|
55
317
|
Returns:
|
56
|
-
|
318
|
+
Article: The processed article with enhanced subsections and applied censoring rules.
|
57
319
|
"""
|
58
320
|
self.view(collection_name)
|
59
321
|
|
@@ -86,20 +348,43 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
86
348
|
"""
|
87
349
|
refind_q = ok(
|
88
350
|
await self.arefined_query(
|
89
|
-
f"{article.referenced.as_prompt()}\n"
|
90
|
-
f"# Subsection requiring reference enhancement\n"
|
91
|
-
f"{subsec.display()}\n"
|
92
|
-
f"# Requirement\n"
|
93
|
-
f"Search related articles in the base to find reference candidates, "
|
94
|
-
f"provide queries in both `English` and `{subsec.language}` can get more accurate results.",
|
351
|
+
f"{article.referenced.as_prompt()}\n# Subsection requiring reference enhancement\n{subsec.display()}\n"
|
95
352
|
)
|
96
353
|
)
|
97
354
|
await self.censor_obj_inplace(
|
98
355
|
subsec,
|
99
356
|
ruleset=ruleset,
|
100
|
-
reference=f"{await self.
|
357
|
+
reference=f"{'\n\n'.join(d.display() for d in await self.aretrieve(refind_q, document_model=ArticleEssence, final_limit=self.ref_limit))}\n\n"
|
101
358
|
f"You can use Reference above to rewrite the `{subsec.__class__.__name__}`.\n"
|
102
359
|
f"You should Always use `{subsec.language}` as written language, "
|
103
360
|
f"which is the original language of the `{subsec.title}`. "
|
104
361
|
f"since rewrite a `{subsec.__class__.__name__}` in a different language is usually a bad choice",
|
105
362
|
)
|
363
|
+
|
364
|
+
|
365
|
+
class ChunkArticle(Action):
|
366
|
+
"""Chunk an article into smaller chunks."""
|
367
|
+
|
368
|
+
output_key: str = "article_chunks"
|
369
|
+
"""The key used to store the output of the action."""
|
370
|
+
max_chunk_size: Optional[int] = None
|
371
|
+
"""The maximum size of each chunk."""
|
372
|
+
max_overlapping_rate: Optional[float] = None
|
373
|
+
"""The maximum overlapping rate between chunks."""
|
374
|
+
|
375
|
+
async def _execute(
|
376
|
+
self,
|
377
|
+
article_path: str | Path,
|
378
|
+
bib_manager: BibManager,
|
379
|
+
max_chunk_size: Optional[int] = None,
|
380
|
+
max_overlapping_rate: Optional[float] = None,
|
381
|
+
**_,
|
382
|
+
) -> List[ArticleChunk]:
|
383
|
+
return ArticleChunk.from_file(
|
384
|
+
article_path,
|
385
|
+
bib_manager,
|
386
|
+
max_chunk_size=ok(max_chunk_size or self.max_chunk_size, "No max_chunk_size provided!"),
|
387
|
+
max_overlapping_rate=ok(
|
388
|
+
max_overlapping_rate or self.max_overlapping_rate, "No max_overlapping_rate provided!"
|
389
|
+
),
|
390
|
+
)
|
fabricatio/actions/fs.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
"""A module for file system utilities."""
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any, List, Mapping, Self
|
5
|
+
|
6
|
+
from fabricatio.fs import safe_text_read
|
7
|
+
from fabricatio.journal import logger
|
8
|
+
from fabricatio.models.action import Action
|
9
|
+
from fabricatio.models.generic import FromMapping
|
10
|
+
|
11
|
+
|
12
|
+
class ReadText(Action, FromMapping):
|
13
|
+
"""Read text from a file."""
|
14
|
+
output_key: str = "read_text"
|
15
|
+
read_path: str | Path
|
16
|
+
"""Path to the file to read."""
|
17
|
+
|
18
|
+
async def _execute(self, *_: Any, **cxt) -> str:
|
19
|
+
logger.info(f"Read text from {Path(self.read_path).as_posix()} to {self.output_key}")
|
20
|
+
return safe_text_read(self.read_path)
|
21
|
+
|
22
|
+
@classmethod
|
23
|
+
def from_mapping(cls, mapping: Mapping[str, str | Path], **kwargs: Any) -> List[Self]:
|
24
|
+
"""Create a list of ReadText actions from a mapping of output_key to read_path."""
|
25
|
+
return [cls(read_path=p, output_key=k, **kwargs) for k, p in mapping.items()]
|
fabricatio/actions/output.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
"""Dump the finalized output to a file."""
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Any, Iterable, List, Optional, Type
|
4
|
+
from typing import Any, Iterable, List, Mapping, Optional, Type
|
5
5
|
|
6
6
|
from fabricatio.journal import logger
|
7
7
|
from fabricatio.models.action import Action
|
8
|
-
from fabricatio.models.generic import FinalizedDumpAble, PersistentAble
|
8
|
+
from fabricatio.models.generic import FinalizedDumpAble, FromMapping, PersistentAble
|
9
9
|
from fabricatio.models.task import Task
|
10
10
|
from fabricatio.utils import ok
|
11
11
|
|
@@ -115,7 +115,7 @@ class RetrieveFromPersistent[T: PersistentAble](Action):
|
|
115
115
|
return self.retrieve_cls.from_persistent(self.load_path)
|
116
116
|
|
117
117
|
|
118
|
-
class RetrieveFromLatest[T: PersistentAble](RetrieveFromPersistent[T]):
|
118
|
+
class RetrieveFromLatest[T: PersistentAble](RetrieveFromPersistent[T], FromMapping):
|
119
119
|
"""Retrieve the object from the latest persistent file in the dir at `load_path`."""
|
120
120
|
|
121
121
|
async def _execute(self, /, **_) -> Optional[T]:
|
@@ -130,6 +130,20 @@ class RetrieveFromLatest[T: PersistentAble](RetrieveFromPersistent[T]):
|
|
130
130
|
logger.error(f"Path {self.load_path} is not a directory")
|
131
131
|
return None
|
132
132
|
|
133
|
+
@classmethod
|
134
|
+
def from_mapping(
|
135
|
+
cls,
|
136
|
+
mapping: Mapping[str, str | Path],
|
137
|
+
*,
|
138
|
+
retrieve_cls: Type[T],
|
139
|
+
**kwargs,
|
140
|
+
) -> List["RetrieveFromLatest[T]"]:
|
141
|
+
"""Create a list of `RetrieveFromLatest` from the mapping."""
|
142
|
+
return [
|
143
|
+
cls(retrieve_cls=retrieve_cls, load_path=Path(p).as_posix(), output_key=o, **kwargs)
|
144
|
+
for o, p in mapping.items()
|
145
|
+
]
|
146
|
+
|
133
147
|
|
134
148
|
class GatherAsList(Action):
|
135
149
|
"""Gather the objects from the context as a list.
|
fabricatio/actions/rag.py
CHANGED
@@ -2,37 +2,57 @@
|
|
2
2
|
|
3
3
|
from typing import List, Optional
|
4
4
|
|
5
|
-
from questionary import text
|
6
|
-
|
7
5
|
from fabricatio.capabilities.rag import RAG
|
6
|
+
from fabricatio.config import configs
|
8
7
|
from fabricatio.journal import logger
|
9
8
|
from fabricatio.models.action import Action
|
10
|
-
from fabricatio.models.
|
9
|
+
from fabricatio.models.extra.rag import MilvusClassicModel, MilvusDataBase
|
11
10
|
from fabricatio.models.task import Task
|
11
|
+
from fabricatio.utils import ok
|
12
12
|
|
13
13
|
|
14
14
|
class InjectToDB(Action, RAG):
|
15
15
|
"""Inject data into the database."""
|
16
16
|
|
17
17
|
output_key: str = "collection_name"
|
18
|
+
collection_name: str = "my_collection"
|
19
|
+
"""The name of the collection to inject data into."""
|
18
20
|
|
19
|
-
async def _execute[T:
|
20
|
-
|
21
|
+
async def _execute[T: MilvusDataBase](
|
22
|
+
self, to_inject: Optional[T] | List[Optional[T]], override_inject: bool = False, **_
|
21
23
|
) -> Optional[str]:
|
24
|
+
from pymilvus.milvus_client import IndexParams
|
25
|
+
|
26
|
+
if to_inject is None:
|
27
|
+
return None
|
22
28
|
if not isinstance(to_inject, list):
|
23
29
|
to_inject = [to_inject]
|
24
|
-
|
30
|
+
if not (seq := [t for t in to_inject if t is not None]): # filter out None
|
31
|
+
return None
|
32
|
+
logger.info(f"Injecting {len(seq)} items into the collection '{self.collection_name}'")
|
25
33
|
if override_inject:
|
26
|
-
self.check_client().client.drop_collection(collection_name)
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
34
|
+
self.check_client().client.drop_collection(self.collection_name)
|
35
|
+
|
36
|
+
await self.view(
|
37
|
+
self.collection_name,
|
38
|
+
create=True,
|
39
|
+
schema=seq[0].as_milvus_schema(
|
40
|
+
ok(
|
41
|
+
self.milvus_dimensions
|
42
|
+
or configs.rag.milvus_dimensions
|
43
|
+
or self.embedding_dimensions
|
44
|
+
or configs.embedding.dimensions
|
45
|
+
),
|
46
|
+
),
|
47
|
+
index_params=IndexParams(
|
48
|
+
seq[0].vector_field_name,
|
49
|
+
index_name=seq[0].vector_field_name,
|
50
|
+
index_type=seq[0].index_type,
|
51
|
+
metric_type=seq[0].metric_type,
|
52
|
+
),
|
53
|
+
).add_document(seq, flush=True)
|
54
|
+
|
55
|
+
return self.collection_name
|
36
56
|
|
37
57
|
|
38
58
|
class RAGTalk(Action, RAG):
|
@@ -52,6 +72,8 @@ class RAGTalk(Action, RAG):
|
|
52
72
|
output_key: str = "task_output"
|
53
73
|
|
54
74
|
async def _execute(self, task_input: Task[str], **kwargs) -> int:
|
75
|
+
from questionary import text
|
76
|
+
|
55
77
|
collection_name = kwargs.get("collection_name", "my_collection")
|
56
78
|
counter = 0
|
57
79
|
|
@@ -62,10 +84,10 @@ class RAGTalk(Action, RAG):
|
|
62
84
|
user_say = await text("User: ").ask_async()
|
63
85
|
if user_say is None:
|
64
86
|
break
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
87
|
+
ret: List[MilvusClassicModel] = await self.aretrieve(user_say, document_model=MilvusClassicModel)
|
88
|
+
|
89
|
+
gpt_say = await self.aask(
|
90
|
+
user_say, system_message="\n".join(m.text for m in ret) + "\nYou can refer facts provided above."
|
69
91
|
)
|
70
92
|
print(f"GPT: {gpt_say}") # noqa: T201
|
71
93
|
counter += 1
|