fabricatio 0.2.13.dev0__cp312-cp312-win_amd64.whl → 0.2.13.dev2__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +49 -1
- fabricatio/actions/article_rag.py +23 -59
- fabricatio/capabilities/advanced_rag.py +5 -1
- fabricatio/capabilities/rag.py +37 -5
- fabricatio/config.py +8 -3
- fabricatio/models/action.py +16 -3
- fabricatio/models/adv_kwargs_types.py +4 -1
- fabricatio/models/extra/aricle_rag.py +26 -11
- fabricatio/models/extra/article_base.py +40 -7
- fabricatio/models/extra/article_main.py +22 -33
- fabricatio/models/extra/article_outline.py +1 -3
- fabricatio/models/generic.py +3 -3
- fabricatio/models/kwargs_types.py +9 -1
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/utils.py +162 -1
- fabricatio-0.2.13.dev2.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.13.dev2.data/scripts/ttm.exe +0 -0
- {fabricatio-0.2.13.dev0.dist-info → fabricatio-0.2.13.dev2.dist-info}/METADATA +5 -1
- {fabricatio-0.2.13.dev0.dist-info → fabricatio-0.2.13.dev2.dist-info}/RECORD +21 -21
- fabricatio-0.2.13.dev0.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.13.dev0.data/scripts/ttm.exe +0 -0
- {fabricatio-0.2.13.dev0.dist-info → fabricatio-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.13.dev0.dist-info → fabricatio-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
fabricatio/actions/article.py
CHANGED
@@ -8,10 +8,12 @@ from more_itertools import filter_map
|
|
8
8
|
from pydantic import Field
|
9
9
|
from rich import print as r_print
|
10
10
|
|
11
|
+
from fabricatio import TEMPLATE_MANAGER
|
11
12
|
from fabricatio.capabilities.censor import Censor
|
12
13
|
from fabricatio.capabilities.extract import Extract
|
13
14
|
from fabricatio.capabilities.propose import Propose
|
14
|
-
from fabricatio.
|
15
|
+
from fabricatio.config import configs
|
16
|
+
from fabricatio.fs import dump_text, safe_text_read
|
15
17
|
from fabricatio.journal import logger
|
16
18
|
from fabricatio.models.action import Action
|
17
19
|
from fabricatio.models.extra.article_essence import ArticleEssence
|
@@ -21,6 +23,7 @@ from fabricatio.models.extra.article_proposal import ArticleProposal
|
|
21
23
|
from fabricatio.models.extra.rule import RuleSet
|
22
24
|
from fabricatio.models.kwargs_types import ValidateKwargs
|
23
25
|
from fabricatio.models.task import Task
|
26
|
+
from fabricatio.models.usages import LLMUsage
|
24
27
|
from fabricatio.rust import BibManager, detect_language
|
25
28
|
from fabricatio.utils import ok, wrapp_in_block
|
26
29
|
|
@@ -271,3 +274,48 @@ class LoadArticle(Action):
|
|
271
274
|
|
272
275
|
async def _execute(self, article_outline: ArticleOutline, typst_code: str, **cxt) -> Article:
|
273
276
|
return Article.from_mixed_source(article_outline, typst_code)
|
277
|
+
|
278
|
+
|
279
|
+
class WriteChapterSummary(Action, LLMUsage):
|
280
|
+
"""Write the chapter summary."""
|
281
|
+
|
282
|
+
output_key: str = "chapter_summaries"
|
283
|
+
|
284
|
+
paragraph_count: int = 1
|
285
|
+
|
286
|
+
summary_word_count: int = 200
|
287
|
+
|
288
|
+
summary_title: str = "Chapter Summary"
|
289
|
+
write_to: Optional[Path] = None
|
290
|
+
|
291
|
+
async def _execute(self, article: Article, write_to: Optional[Path] = None, **cxt) -> List[str]:
|
292
|
+
logger.info(";".join(a.title for a in article.chapters))
|
293
|
+
|
294
|
+
ret = [
|
295
|
+
f"== {self.summary_title}\n{raw}"
|
296
|
+
for raw in (
|
297
|
+
await self.aask(
|
298
|
+
TEMPLATE_MANAGER.render_template(
|
299
|
+
configs.templates.chap_summary_template,
|
300
|
+
[
|
301
|
+
{
|
302
|
+
"chapter": a.to_typst_code(),
|
303
|
+
"title": a.title,
|
304
|
+
"language": a.language,
|
305
|
+
"summary_word_count": self.summary_word_count,
|
306
|
+
"paragraph_count": self.paragraph_count,
|
307
|
+
}
|
308
|
+
for a in article.chapters
|
309
|
+
],
|
310
|
+
)
|
311
|
+
)
|
312
|
+
)
|
313
|
+
]
|
314
|
+
|
315
|
+
if (to := (self.write_to or write_to)) is not None:
|
316
|
+
dump_text(
|
317
|
+
to,
|
318
|
+
"\n\n\n".join(f"//{a.title}\n\n{s}" for a, s in zip(article.chapters, ret, strict=True)),
|
319
|
+
)
|
320
|
+
|
321
|
+
return ret
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
from asyncio import gather
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import List, Optional
|
5
|
+
from typing import ClassVar, List, Optional
|
6
6
|
|
7
7
|
from pydantic import Field, PositiveInt
|
8
8
|
|
@@ -21,17 +21,17 @@ from fabricatio.models.extra.article_outline import ArticleOutline
|
|
21
21
|
from fabricatio.models.extra.rule import RuleSet
|
22
22
|
from fabricatio.models.kwargs_types import ChooseKwargs, LLMKwargs
|
23
23
|
from fabricatio.rust import convert_to_block_formula, convert_to_inline_formula
|
24
|
-
from fabricatio.utils import
|
24
|
+
from fabricatio.utils import ok
|
25
25
|
|
26
26
|
TYPST_CITE_USAGE = (
|
27
|
-
"citation number is REQUIRED to cite any reference
|
28
|
-
"Everything is build upon the typst language, which is similar to latex, \n"
|
27
|
+
"citation number is REQUIRED to cite any reference!'\n"
|
29
28
|
"Legal citing syntax examples(seperated by |): [[1]]|[[1,2]]|[[1-3]]|[[12,13-15]]|[[1-3,5-7]]\n"
|
30
29
|
"Illegal citing syntax examples(seperated by |): [[1],[2],[3]]|[[1],[1-2]]\n"
|
30
|
+
"You SHALL not cite a single reference more than once!"
|
31
31
|
"Those reference mark shall not be omitted during the extraction\n"
|
32
32
|
"It's recommended to cite multiple references that supports your conclusion at a time.\n"
|
33
33
|
"Wrap inline expression with '\\(' and '\\)',like '\\(>5m\\)' '\\(89%\\)', and wrap block equation with '\\[' and '\\]'.\n"
|
34
|
-
"In addition to that, you can add a label outside the block equation which can be used as a cross reference identifier, the label is a string wrapped in `<` and `>` like `<energy-release-rate-equation>`.Note that the label string should be a summarizing title for the equation being labeled.\n"
|
34
|
+
"In addition to that, you can add a label outside the block equation which can be used as a cross reference identifier, the label is a string wrapped in `<` and `>` like `<energy-release-rate-equation>`.Note that the label string should be a summarizing title for the equation being labeled and should never be written within the formula block.\n"
|
35
35
|
"you can refer to that label by using the syntax with prefix of `@eqt:`, which indicate that this notation is citing a label from the equations. For example ' @eqt:energy-release-rate-equation ' DO remember that the notation shall have both suffixed and prefixed space char which enable the compiler to distinguish the notation from the plaintext."
|
36
36
|
"Below is two usage example:\n"
|
37
37
|
"```typst\n"
|
@@ -47,6 +47,7 @@ TYPST_CITE_USAGE = (
|
|
47
47
|
class WriteArticleContentRAG(Action, RAG, Extract):
|
48
48
|
"""Write an article based on the provided outline."""
|
49
49
|
|
50
|
+
ctx_override: ClassVar[bool] = True
|
50
51
|
search_increment_multiplier: float = 1.6
|
51
52
|
"""The increment multiplier of the search increment."""
|
52
53
|
ref_limit: int = 35
|
@@ -63,6 +64,7 @@ class WriteArticleContentRAG(Action, RAG, Extract):
|
|
63
64
|
"""The number of results to be returned per query."""
|
64
65
|
req: str = TYPST_CITE_USAGE
|
65
66
|
"""The req of the write article content."""
|
67
|
+
tei_endpoint: Optional[str] = None
|
66
68
|
|
67
69
|
async def _execute(
|
68
70
|
self,
|
@@ -108,16 +110,7 @@ class WriteArticleContentRAG(Action, RAG, Extract):
|
|
108
110
|
|
109
111
|
while not await confirm("Accept this version and continue?").ask_async():
|
110
112
|
if inst := await text("Search for more refs for additional spec.").ask_async():
|
111
|
-
await self.search_database(
|
112
|
-
article,
|
113
|
-
article_outline,
|
114
|
-
chap,
|
115
|
-
sec,
|
116
|
-
subsec,
|
117
|
-
cm,
|
118
|
-
supervisor=True,
|
119
|
-
extra_instruction=inst,
|
120
|
-
)
|
113
|
+
await self.search_database(article, article_outline, chap, sec, subsec, cm, extra_instruction=inst)
|
121
114
|
|
122
115
|
if instruction := await text("Enter the instructions to improve").ask_async():
|
123
116
|
raw = await self.write_raw(article, article_outline, chap, sec, subsec, cm, instruction)
|
@@ -200,7 +193,6 @@ class WriteArticleContentRAG(Action, RAG, Extract):
|
|
200
193
|
subsec: ArticleSubsection,
|
201
194
|
cm: CitationManager,
|
202
195
|
extra_instruction: str = "",
|
203
|
-
supervisor: bool = False,
|
204
196
|
) -> None:
|
205
197
|
"""Search database for related references."""
|
206
198
|
search_req = (
|
@@ -208,61 +200,31 @@ class WriteArticleContentRAG(Action, RAG, Extract):
|
|
208
200
|
f"More specifically, i m witting the Chapter `{chap.title}` >> Section `{sec.title}` >> Subsection `{subsec.title}`.\n"
|
209
201
|
f"I need to search related references to build up the content of the subsec mentioned above, which is `{subsec.title}`.\n"
|
210
202
|
f"provide 10~16 queries as possible, to get best result!\n"
|
211
|
-
f"You should provide both English version and chinese version of the refined queries!\n{extra_instruction}
|
203
|
+
f"You should provide both English version and chinese version of the refined queries!\n{extra_instruction}"
|
212
204
|
)
|
213
205
|
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
ref_q = await ask_retain(ref_q)
|
224
|
-
ret = await self.aretrieve(
|
225
|
-
ref_q,
|
226
|
-
ArticleChunk,
|
227
|
-
max_accepted=self.ref_limit,
|
228
|
-
result_per_query=self.result_per_query,
|
229
|
-
similarity_threshold=self.threshold,
|
230
|
-
)
|
231
|
-
|
232
|
-
cm.add_chunks(ok(ret))
|
233
|
-
ref_q = await self.arefined_query(
|
234
|
-
f"{cm.as_prompt()}\n\nAbove is the retrieved references in the first RAG, now we need to perform the second RAG.\n\n{search_req}",
|
235
|
-
**self.query_model,
|
236
|
-
)
|
237
|
-
|
238
|
-
if ref_q is None:
|
239
|
-
logger.warning("Second refine query is None, skipping.")
|
240
|
-
return
|
241
|
-
if supervisor:
|
242
|
-
ref_q = await ask_retain(ref_q)
|
243
|
-
|
244
|
-
ret = await self.aretrieve(
|
245
|
-
ref_q,
|
246
|
-
ArticleChunk,
|
247
|
-
max_accepted=int(self.ref_limit * self.search_increment_multiplier),
|
248
|
-
result_per_query=int(self.result_per_query * self.search_increment_multiplier),
|
249
|
-
similarity_threshold=self.threshold,
|
206
|
+
await self.clued_search(
|
207
|
+
search_req,
|
208
|
+
cm,
|
209
|
+
refinery_kwargs=self.ref_q_model,
|
210
|
+
expand_multiplier=self.search_increment_multiplier,
|
211
|
+
base_accepted=self.ref_limit,
|
212
|
+
result_per_query=self.ref_per_q,
|
213
|
+
similarity_threshold=self.similarity_threshold,
|
214
|
+
tei_endpoint=self.tei_endpoint,
|
250
215
|
)
|
251
|
-
if ret is None:
|
252
|
-
logger.warning("Second retrieve is None, skipping.")
|
253
|
-
return
|
254
|
-
cm.add_chunks(ret)
|
255
216
|
|
256
217
|
|
257
218
|
class ArticleConsultRAG(Action, AdvancedRAG):
|
258
219
|
"""Write an article based on the provided outline."""
|
259
220
|
|
221
|
+
ctx_override: ClassVar[bool] = True
|
260
222
|
output_key: str = "consult_count"
|
261
223
|
search_increment_multiplier: float = 1.6
|
262
224
|
"""The multiplier to increase the limit of references to retrieve per query."""
|
263
|
-
ref_limit: int =
|
225
|
+
ref_limit: int = 26
|
264
226
|
"""The final limit of references."""
|
265
|
-
ref_per_q: int =
|
227
|
+
ref_per_q: int = 13
|
266
228
|
"""The limit of references to retrieve per query."""
|
267
229
|
similarity_threshold: float = 0.62
|
268
230
|
"""The similarity threshold of references to retrieve."""
|
@@ -270,6 +232,7 @@ class ArticleConsultRAG(Action, AdvancedRAG):
|
|
270
232
|
"""The model to use for refining query."""
|
271
233
|
req: str = TYPST_CITE_USAGE
|
272
234
|
"""The request for the rag model."""
|
235
|
+
tei_endpoint: Optional[str] = None
|
273
236
|
|
274
237
|
@precheck_package(
|
275
238
|
"questionary", "`questionary` is required for supervisor mode, please install it by `fabricatio[qa]`"
|
@@ -300,6 +263,7 @@ class ArticleConsultRAG(Action, AdvancedRAG):
|
|
300
263
|
base_accepted=self.ref_limit,
|
301
264
|
result_per_query=self.ref_per_q,
|
302
265
|
similarity_threshold=self.similarity_threshold,
|
266
|
+
tei_endpoint=self.tei_endpoint,
|
303
267
|
)
|
304
268
|
|
305
269
|
ret = await self.aask(f"{cm.as_prompt()}\n{self.req}\n{req}")
|
@@ -7,6 +7,7 @@ from fabricatio.journal import logger
|
|
7
7
|
from fabricatio.models.adv_kwargs_types import FetchKwargs
|
8
8
|
from fabricatio.models.extra.aricle_rag import ArticleChunk, CitationManager
|
9
9
|
from fabricatio.models.kwargs_types import ChooseKwargs
|
10
|
+
from fabricatio.utils import fallback_kwargs
|
10
11
|
|
11
12
|
|
12
13
|
class AdvancedRAG(RAG):
|
@@ -40,10 +41,13 @@ class AdvancedRAG(RAG):
|
|
40
41
|
f"\n\n{requirement}",
|
41
42
|
**refinery_kwargs,
|
42
43
|
)
|
44
|
+
|
43
45
|
if ref_q is None:
|
44
46
|
logger.error(f"At round [{i}/{max_round}] search, failed to refine the query, exit.")
|
45
47
|
return cm
|
46
|
-
refs = await self.aretrieve(
|
48
|
+
refs = await self.aretrieve(
|
49
|
+
ref_q, ArticleChunk, base_accepted, **fallback_kwargs(kwargs, filter_expr=cm.as_milvus_filter_expr())
|
50
|
+
)
|
47
51
|
|
48
52
|
if (max_capacity := max_capacity - len(refs)) < 0:
|
49
53
|
cm.add_chunks(refs[0:max_capacity])
|
fabricatio/capabilities/rag.py
CHANGED
@@ -143,21 +143,27 @@ class RAG(EmbeddingUsage):
|
|
143
143
|
|
144
144
|
async def afetch_document[D: MilvusDataBase](
|
145
145
|
self,
|
146
|
-
|
146
|
+
query: List[str],
|
147
147
|
document_model: Type[D],
|
148
148
|
collection_name: Optional[str] = None,
|
149
149
|
similarity_threshold: float = 0.37,
|
150
150
|
result_per_query: int = 10,
|
151
|
+
tei_endpoint: Optional[str] = None,
|
152
|
+
reranker_threshold: float = 0.7,
|
153
|
+
filter_expr: str = "",
|
151
154
|
) -> List[D]:
|
152
155
|
"""Asynchronously fetches documents from a Milvus database based on input vectors.
|
153
156
|
|
154
157
|
Args:
|
155
|
-
|
158
|
+
query (List[str]): A list of vectors to search for in the database.
|
156
159
|
document_model (Type[D]): The model class used to convert fetched data into document objects.
|
157
160
|
collection_name (Optional[str]): The name of the collection to search within.
|
158
161
|
If None, the currently viewed collection is used.
|
159
162
|
similarity_threshold (float): The similarity threshold for vector search. Defaults to 0.37.
|
160
163
|
result_per_query (int): The maximum number of results to return per query. Defaults to 10.
|
164
|
+
tei_endpoint (str): the endpoint of the TEI api.
|
165
|
+
reranker_threshold (float): The threshold used to filtered low relativity document.
|
166
|
+
filter_expr (str): filter_expression parsed into pymilvus search.
|
161
167
|
|
162
168
|
Returns:
|
163
169
|
List[D]: A list of document objects created from the fetched data.
|
@@ -165,15 +171,38 @@ class RAG(EmbeddingUsage):
|
|
165
171
|
# Step 1: Search for vectors
|
166
172
|
search_results = self.check_client().client.search(
|
167
173
|
collection_name or self.safe_target_collection,
|
168
|
-
|
174
|
+
await self.vectorize(query),
|
169
175
|
search_params={"radius": similarity_threshold},
|
170
176
|
output_fields=list(document_model.model_fields),
|
177
|
+
filter=filter_expr,
|
171
178
|
limit=result_per_query,
|
172
179
|
)
|
180
|
+
if tei_endpoint is not None:
|
181
|
+
from fabricatio.utils import RerankerAPI
|
182
|
+
|
183
|
+
reranker = RerankerAPI(base_url=tei_endpoint)
|
184
|
+
|
185
|
+
retrieved_id = set()
|
186
|
+
raw_result = []
|
187
|
+
|
188
|
+
for q, g in zip(query, search_results, strict=True):
|
189
|
+
models = document_model.from_sequence([res["entity"] for res in g if res["id"] not in retrieved_id])
|
190
|
+
logger.debug(f"Retrived {len(g)} raw document, filtered out {len(models)}.")
|
191
|
+
retrieved_id.update(res["id"] for res in g)
|
192
|
+
if not models:
|
193
|
+
continue
|
194
|
+
rank_scores = await reranker.arerank(q, [m.prepare_vectorization() for m in models], truncate=True)
|
195
|
+
raw_result.extend(
|
196
|
+
(models[s["index"]], s["score"]) for s in rank_scores if s["score"] > reranker_threshold
|
197
|
+
)
|
198
|
+
|
199
|
+
raw_result_sorted = sorted(raw_result, key=lambda x: x[1], reverse=True)
|
200
|
+
return [r[0] for r in raw_result_sorted]
|
173
201
|
|
174
202
|
# Step 2: Flatten the search results
|
175
203
|
flattened_results = flatten(search_results)
|
176
204
|
unique_results = unique(flattened_results, key=itemgetter("id"))
|
205
|
+
|
177
206
|
# Step 3: Sort by distance (descending)
|
178
207
|
sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
|
179
208
|
|
@@ -205,15 +234,18 @@ class RAG(EmbeddingUsage):
|
|
205
234
|
"""
|
206
235
|
if isinstance(query, str):
|
207
236
|
query = [query]
|
237
|
+
|
208
238
|
return (
|
209
239
|
await self.afetch_document(
|
210
|
-
|
240
|
+
query=query,
|
211
241
|
document_model=document_model,
|
212
242
|
**kwargs,
|
213
243
|
)
|
214
244
|
)[:max_accepted]
|
215
245
|
|
216
|
-
async def arefined_query(
|
246
|
+
async def arefined_query(
|
247
|
+
self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs[Optional[List[str]]]]
|
248
|
+
) -> Optional[List[str]]:
|
217
249
|
"""Refines the given question using a template.
|
218
250
|
|
219
251
|
Args:
|
fabricatio/config.py
CHANGED
@@ -86,10 +86,12 @@ class LLMConfig(BaseModel):
|
|
86
86
|
|
87
87
|
tpm: Optional[PositiveInt] = Field(default=1000000)
|
88
88
|
"""The rate limit of the LLM model in tokens per minute. None means not checked."""
|
89
|
-
presence_penalty:Optional[PositiveFloat]=None
|
89
|
+
presence_penalty: Optional[PositiveFloat] = None
|
90
90
|
"""The presence penalty of the LLM model."""
|
91
|
-
frequency_penalty:Optional[PositiveFloat]=None
|
91
|
+
frequency_penalty: Optional[PositiveFloat] = None
|
92
92
|
"""The frequency penalty of the LLM model."""
|
93
|
+
|
94
|
+
|
93
95
|
class EmbeddingConfig(BaseModel):
|
94
96
|
"""Embedding configuration class."""
|
95
97
|
|
@@ -252,10 +254,13 @@ class TemplateConfig(BaseModel):
|
|
252
254
|
rule_requirement_template: str = Field(default="rule_requirement")
|
253
255
|
"""The name of the rule requirement template which will be used to generate a rule requirement."""
|
254
256
|
|
255
|
-
|
256
257
|
extract_template: str = Field(default="extract")
|
257
258
|
"""The name of the extract template which will be used to extract model from string."""
|
258
259
|
|
260
|
+
chap_summary_template: str = Field(default="chap_summary")
|
261
|
+
"""The name of the chap summary template which will be used to generate a chapter summary."""
|
262
|
+
|
263
|
+
|
259
264
|
class MagikaConfig(BaseModel):
|
260
265
|
"""Magika configuration class."""
|
261
266
|
|
fabricatio/models/action.py
CHANGED
@@ -12,7 +12,7 @@ Classes:
|
|
12
12
|
import traceback
|
13
13
|
from abc import abstractmethod
|
14
14
|
from asyncio import Queue, create_task
|
15
|
-
from typing import Any, Dict, Self, Sequence, Tuple, Type, Union, final
|
15
|
+
from typing import Any, ClassVar, Dict, Self, Sequence, Tuple, Type, Union, final
|
16
16
|
|
17
17
|
from fabricatio.journal import logger
|
18
18
|
from fabricatio.models.generic import WithBriefing
|
@@ -33,6 +33,9 @@ class Action(WithBriefing):
|
|
33
33
|
a specific operation and can modify the shared context data.
|
34
34
|
"""
|
35
35
|
|
36
|
+
ctx_override: ClassVar[bool] = False
|
37
|
+
"""Whether to override the instance attr by the context variable."""
|
38
|
+
|
36
39
|
name: str = Field(default="")
|
37
40
|
"""The name of the action."""
|
38
41
|
|
@@ -157,6 +160,15 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
157
160
|
action.personality = personality
|
158
161
|
return self
|
159
162
|
|
163
|
+
def override_action_variable(self, action: Action, ctx: Dict[str, Any]) -> Self:
|
164
|
+
"""Override action variable with context values."""
|
165
|
+
if action.ctx_override:
|
166
|
+
for k, v in ctx.items():
|
167
|
+
if hasattr(action, k):
|
168
|
+
setattr(action, k, v)
|
169
|
+
|
170
|
+
return self
|
171
|
+
|
160
172
|
async def serve(self, task: Task) -> None:
|
161
173
|
"""Execute workflow to complete given task.
|
162
174
|
|
@@ -178,11 +190,12 @@ class WorkFlow(WithBriefing, ToolBoxUsage):
|
|
178
190
|
try:
|
179
191
|
# Process each action in sequence
|
180
192
|
for i, step in enumerate(self._instances):
|
181
|
-
current_action
|
182
|
-
logger.info(f"Executing step [{i}] >> {current_action}")
|
193
|
+
logger.info(f"Executing step [{i}] >> {(current_action := step.name)}")
|
183
194
|
|
184
195
|
# Get current context and execute action
|
185
196
|
context = await self._context.get()
|
197
|
+
|
198
|
+
self.override_action_variable(step, context)
|
186
199
|
act_task = create_task(step.act(context))
|
187
200
|
# Handle task cancellation
|
188
201
|
if task.is_cancelled():
|
@@ -1,7 +1,7 @@
|
|
1
1
|
"""A module containing kwargs types for content correction and checking operations."""
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
|
-
from typing import NotRequired, TypedDict
|
4
|
+
from typing import NotRequired, Optional, TypedDict
|
5
5
|
|
6
6
|
from fabricatio.models.extra.problem import Improvement
|
7
7
|
from fabricatio.models.extra.rule import RuleSet
|
@@ -58,3 +58,6 @@ if find_spec("pymilvus"):
|
|
58
58
|
collection_name: NotRequired[str | None]
|
59
59
|
similarity_threshold: NotRequired[float]
|
60
60
|
result_per_query: NotRequired[int]
|
61
|
+
tei_endpoint: NotRequired[Optional[str]]
|
62
|
+
reranker_threshold: NotRequired[float]
|
63
|
+
filter_expr: NotRequired[str]
|
@@ -1,6 +1,7 @@
|
|
1
1
|
"""A Module containing the article rag models."""
|
2
2
|
|
3
3
|
import re
|
4
|
+
from itertools import groupby
|
4
5
|
from pathlib import Path
|
5
6
|
from typing import ClassVar, Dict, List, Optional, Self, Unpack
|
6
7
|
|
@@ -10,12 +11,13 @@ from fabricatio.models.extra.rag import MilvusDataBase
|
|
10
11
|
from fabricatio.models.generic import AsPrompt
|
11
12
|
from fabricatio.models.kwargs_types import ChunkKwargs
|
12
13
|
from fabricatio.rust import BibManager, blake3_hash, split_into_chunks
|
13
|
-
from fabricatio.utils import ok
|
14
|
+
from fabricatio.utils import ok, wrapp_in_block
|
15
|
+
from more_itertools.more import first
|
14
16
|
from more_itertools.recipes import flatten, unique
|
15
17
|
from pydantic import Field
|
16
18
|
|
17
19
|
|
18
|
-
class ArticleChunk(MilvusDataBase
|
20
|
+
class ArticleChunk(MilvusDataBase):
|
19
21
|
"""The chunk of an article."""
|
20
22
|
|
21
23
|
etc_word: ClassVar[str] = "等"
|
@@ -51,10 +53,9 @@ class ArticleChunk(MilvusDataBase, AsPrompt):
|
|
51
53
|
bibtex_cite_key: str
|
52
54
|
"""The bibtex cite key of the article"""
|
53
55
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
}
|
56
|
+
@property
|
57
|
+
def reference_header(self) -> str:
|
58
|
+
return f"[[{ok(self._cite_number, 'You need to update cite number first.')}]] reference `{self.article_title}` from {self.as_auther_seq()}"
|
58
59
|
|
59
60
|
@property
|
60
61
|
def cite_number(self) -> int:
|
@@ -204,13 +205,23 @@ class CitationManager(AsPrompt):
|
|
204
205
|
|
205
206
|
def set_cite_number_all(self) -> Self:
|
206
207
|
"""Set citation numbers for all article chunks."""
|
207
|
-
|
208
|
-
|
208
|
+
number_mapping = {a.bibtex_cite_key: 0 for a in self.article_chunks}
|
209
|
+
|
210
|
+
for i, k in enumerate(number_mapping.keys()):
|
211
|
+
number_mapping[k] = i
|
212
|
+
|
213
|
+
for a in self.article_chunks:
|
214
|
+
a.update_cite_number(number_mapping[a.bibtex_cite_key])
|
209
215
|
return self
|
210
216
|
|
211
217
|
def _as_prompt_inner(self) -> Dict[str, str]:
|
212
218
|
"""Generate prompt inner representation."""
|
213
|
-
|
219
|
+
seg = []
|
220
|
+
for k, g in groupby(self.article_chunks, key=lambda a: a.bibtex_cite_key):
|
221
|
+
g = list(g)
|
222
|
+
logger.debug(f"Group [{k}]: {len(g)}")
|
223
|
+
seg.append(wrapp_in_block("\n\n".join(a.chunk for a in g), first(g).reference_header))
|
224
|
+
return {"References": "\n".join(seg)}
|
214
225
|
|
215
226
|
def apply(self, string: str) -> str:
|
216
227
|
"""Apply citation replacements to the input string."""
|
@@ -261,5 +272,9 @@ class CitationManager(AsPrompt):
|
|
261
272
|
|
262
273
|
def unpack_cite_seq(self, citation_seq: List[int]) -> str:
|
263
274
|
"""Unpack citation sequence into a string."""
|
264
|
-
chunk_seq =
|
265
|
-
return "".join(a.as_typst_cite() for a in chunk_seq)
|
275
|
+
chunk_seq = {a.bibtex_cite_key: a for a in self.article_chunks if a.cite_number in citation_seq}
|
276
|
+
return "".join(a.as_typst_cite() for a in chunk_seq.values())
|
277
|
+
|
278
|
+
def as_milvus_filter_expr(self, blacklist: bool = True) -> str:
|
279
|
+
if blacklist:
|
280
|
+
return " and ".join(f'bibtex_cite_key != "{a.bibtex_cite_key}"' for a in self.article_chunks)
|
@@ -2,8 +2,9 @@
|
|
2
2
|
|
3
3
|
from abc import ABC
|
4
4
|
from enum import StrEnum
|
5
|
-
from typing import Generator, List, Optional, Self, Tuple
|
5
|
+
from typing import ClassVar, Generator, List, Optional, Self, Tuple, Type
|
6
6
|
|
7
|
+
from fabricatio.fs.readers import extract_sections
|
7
8
|
from fabricatio.models.generic import (
|
8
9
|
AsPrompt,
|
9
10
|
Described,
|
@@ -105,12 +106,6 @@ class ArticleOutlineBase(
|
|
105
106
|
self.description = other.description
|
106
107
|
return self
|
107
108
|
|
108
|
-
def display_metadata(self) -> str:
|
109
|
-
"""Displays the metadata of the current instance."""
|
110
|
-
return self.model_dump_json(
|
111
|
-
indent=1, include={"title", "writing_aim", "description", "support_to", "depend_on"}
|
112
|
-
)
|
113
|
-
|
114
109
|
def update_from_inner(self, other: Self) -> Self:
|
115
110
|
"""Updates the current instance with the attributes of another instance."""
|
116
111
|
return self.update_metadata(other)
|
@@ -140,6 +135,8 @@ class SectionBase[T: SubSectionBase](ArticleOutlineBase):
|
|
140
135
|
subsections: List[T]
|
141
136
|
"""Subsections of the section. Contains at least one subsection. You can also add more as needed."""
|
142
137
|
|
138
|
+
child_type: ClassVar[Type[SubSectionBase]]
|
139
|
+
|
143
140
|
def to_typst_code(self) -> str:
|
144
141
|
"""Converts the section into a Typst formatted code snippet.
|
145
142
|
|
@@ -148,6 +145,17 @@ class SectionBase[T: SubSectionBase](ArticleOutlineBase):
|
|
148
145
|
"""
|
149
146
|
return f"== {super().to_typst_code()}" + "\n\n".join(subsec.to_typst_code() for subsec in self.subsections)
|
150
147
|
|
148
|
+
@classmethod
|
149
|
+
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
150
|
+
"""Creates an Article object from the given Typst code."""
|
151
|
+
return super().from_typst_code(
|
152
|
+
title,
|
153
|
+
body,
|
154
|
+
subsections=[
|
155
|
+
cls.child_type.from_typst_code(*pack) for pack in extract_sections(body, level=3, section_char="=")
|
156
|
+
],
|
157
|
+
)
|
158
|
+
|
151
159
|
def resolve_update_conflict(self, other: Self) -> str:
|
152
160
|
"""Resolve update errors in the article outline."""
|
153
161
|
out = ""
|
@@ -186,11 +194,23 @@ class ChapterBase[T: SectionBase](ArticleOutlineBase):
|
|
186
194
|
|
187
195
|
sections: List[T]
|
188
196
|
"""Sections of the chapter. Contains at least one section. You can also add more as needed."""
|
197
|
+
child_type: ClassVar[Type[SectionBase]]
|
189
198
|
|
190
199
|
def to_typst_code(self) -> str:
|
191
200
|
"""Converts the chapter into a Typst formatted code snippet for rendering."""
|
192
201
|
return f"= {super().to_typst_code()}" + "\n\n".join(sec.to_typst_code() for sec in self.sections)
|
193
202
|
|
203
|
+
@classmethod
|
204
|
+
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
205
|
+
"""Creates an Article object from the given Typst code."""
|
206
|
+
return super().from_typst_code(
|
207
|
+
title,
|
208
|
+
body,
|
209
|
+
sections=[
|
210
|
+
cls.child_type.from_typst_code(*pack) for pack in extract_sections(body, level=2, section_char="=")
|
211
|
+
],
|
212
|
+
)
|
213
|
+
|
194
214
|
def resolve_update_conflict(self, other: Self) -> str:
|
195
215
|
"""Resolve update errors in the article outline."""
|
196
216
|
out = ""
|
@@ -238,6 +258,19 @@ class ArticleBase[T: ChapterBase](FinalizedDumpAble, AsPrompt, FromTypstCode, To
|
|
238
258
|
chapters: List[T]
|
239
259
|
"""Chapters of the article. Contains at least one chapter. You can also add more as needed."""
|
240
260
|
|
261
|
+
child_type: ClassVar[Type[ChapterBase]]
|
262
|
+
|
263
|
+
@classmethod
|
264
|
+
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
265
|
+
"""Generates an article from the given Typst code."""
|
266
|
+
return super().from_typst_code(
|
267
|
+
title,
|
268
|
+
body,
|
269
|
+
chapters=[
|
270
|
+
cls.child_type.from_typst_code(*pack) for pack in extract_sections(body, level=1, section_char="=")
|
271
|
+
],
|
272
|
+
)
|
273
|
+
|
241
274
|
def iter_dfs_rev(
|
242
275
|
self,
|
243
276
|
) -> Generator[ArticleOutlineBase, None, None]:
|
@@ -1,9 +1,8 @@
|
|
1
1
|
"""ArticleBase and ArticleSubsection classes for managing hierarchical document components."""
|
2
2
|
|
3
|
-
from typing import Dict, Generator, List, Self, Tuple, override
|
3
|
+
from typing import ClassVar, Dict, Generator, List, Self, Tuple, Type, override
|
4
4
|
|
5
5
|
from fabricatio.decorators import precheck_package
|
6
|
-
from fabricatio.fs.readers import extract_sections
|
7
6
|
from fabricatio.journal import logger
|
8
7
|
from fabricatio.models.extra.article_base import (
|
9
8
|
ArticleBase,
|
@@ -52,6 +51,11 @@ class Paragraph(SketchedAble, WordCount, Described):
|
|
52
51
|
"""Create a Paragraph object from the given content."""
|
53
52
|
return cls(elaboration="", aims=[], expected_word_count=word_count(content), content=content)
|
54
53
|
|
54
|
+
@property
|
55
|
+
def exact_wordcount(self) -> int:
|
56
|
+
"""Calculates the exact word count of the content."""
|
57
|
+
return word_count(self.content)
|
58
|
+
|
55
59
|
|
56
60
|
class ArticleParagraphSequencePatch(SequencePatch[Paragraph]):
|
57
61
|
"""Patch for `Paragraph` list of `ArticleSubsection`."""
|
@@ -115,31 +119,13 @@ class ArticleSubsection(SubSectionBase):
|
|
115
119
|
class ArticleSection(SectionBase[ArticleSubsection]):
|
116
120
|
"""Atomic argumentative unit with high-level specificity."""
|
117
121
|
|
118
|
-
|
119
|
-
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
120
|
-
"""Creates an Article object from the given Typst code."""
|
121
|
-
return super().from_typst_code(
|
122
|
-
title,
|
123
|
-
body,
|
124
|
-
subsections=[
|
125
|
-
ArticleSubsection.from_typst_code(*pack) for pack in extract_sections(body, level=3, section_char="=")
|
126
|
-
],
|
127
|
-
)
|
122
|
+
child_type: ClassVar[Type[SubSectionBase]] = ArticleSubsection
|
128
123
|
|
129
124
|
|
130
125
|
class ArticleChapter(ChapterBase[ArticleSection]):
|
131
126
|
"""Thematic progression implementing research function."""
|
132
127
|
|
133
|
-
|
134
|
-
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
135
|
-
"""Creates an Article object from the given Typst code."""
|
136
|
-
return super().from_typst_code(
|
137
|
-
title,
|
138
|
-
body,
|
139
|
-
sections=[
|
140
|
-
ArticleSection.from_typst_code(*pack) for pack in extract_sections(body, level=2, section_char="=")
|
141
|
-
],
|
142
|
-
)
|
128
|
+
child_type: ClassVar[Type[SectionBase]] = ArticleSection
|
143
129
|
|
144
130
|
|
145
131
|
class Article(
|
@@ -153,6 +139,8 @@ class Article(
|
|
153
139
|
aiming to provide a comprehensive model for academic papers.
|
154
140
|
"""
|
155
141
|
|
142
|
+
child_type: ClassVar[Type[ChapterBase]] = ArticleChapter
|
143
|
+
|
156
144
|
def _as_prompt_inner(self) -> Dict[str, str]:
|
157
145
|
return {
|
158
146
|
"Original Article Briefing": self.referenced.referenced.referenced,
|
@@ -261,17 +249,6 @@ class Article(
|
|
261
249
|
article.chapters.append(article_chapter)
|
262
250
|
return article
|
263
251
|
|
264
|
-
@classmethod
|
265
|
-
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
266
|
-
"""Generates an article from the given Typst code."""
|
267
|
-
return super().from_typst_code(
|
268
|
-
title,
|
269
|
-
body,
|
270
|
-
chapters=[
|
271
|
-
ArticleChapter.from_typst_code(*pack) for pack in extract_sections(body, level=1, section_char="=")
|
272
|
-
],
|
273
|
-
)
|
274
|
-
|
275
252
|
@classmethod
|
276
253
|
def from_mixed_source(cls, article_outline: ArticleOutline, typst_code: str) -> Self:
|
277
254
|
"""Generates an article from the given outline and Typst code."""
|
@@ -292,3 +269,15 @@ class Article(
|
|
292
269
|
for a in self.iter_dfs():
|
293
270
|
a.title = await text(f"Edit `{a.title}`.", default=a.title).ask_async() or a.title
|
294
271
|
return self
|
272
|
+
|
273
|
+
def check_short_paragraphs(self, threshold: int = 60) -> str:
|
274
|
+
"""Checks for short paragraphs in the article."""
|
275
|
+
err = []
|
276
|
+
for chap, sec, subsec in self.iter_subsections():
|
277
|
+
for i, p in enumerate(subsec.paragraphs):
|
278
|
+
if p.exact_wordcount <= threshold:
|
279
|
+
err.append(
|
280
|
+
f"{chap.title}->{sec.title}->{subsec.title}-> Paragraph [{i}] is too short, {p.exact_wordcount} words."
|
281
|
+
)
|
282
|
+
|
283
|
+
return "\n".join(err)
|
@@ -19,6 +19,7 @@ class ArticleSubsectionOutline(SubSectionBase):
|
|
19
19
|
|
20
20
|
class ArticleSectionOutline(SectionBase[ArticleSubsectionOutline]):
|
21
21
|
"""A slightly more detailed research component specification for academic paper generation, Must contain subsections."""
|
22
|
+
|
22
23
|
@classmethod
|
23
24
|
def from_typst_code(cls, title: str, body: str, **kwargs) -> Self:
|
24
25
|
"""Parse the given Typst code into an ArticleSectionOutline instance."""
|
@@ -32,7 +33,6 @@ class ArticleSectionOutline(SectionBase[ArticleSubsectionOutline]):
|
|
32
33
|
)
|
33
34
|
|
34
35
|
|
35
|
-
|
36
36
|
class ArticleChapterOutline(ChapterBase[ArticleSectionOutline]):
|
37
37
|
"""Macro-structural unit implementing standard academic paper organization. Must contain sections."""
|
38
38
|
|
@@ -46,11 +46,9 @@ class ArticleChapterOutline(ChapterBase[ArticleSectionOutline]):
|
|
46
46
|
ArticleSectionOutline.from_typst_code(*pack)
|
47
47
|
for pack in extract_sections(body, level=2, section_char="=")
|
48
48
|
],
|
49
|
-
|
50
49
|
)
|
51
50
|
|
52
51
|
|
53
|
-
|
54
52
|
class ArticleOutline(
|
55
53
|
WithRef[ArticleProposal],
|
56
54
|
PersistentAble,
|
fabricatio/models/generic.py
CHANGED
@@ -312,11 +312,11 @@ class Language(Base):
|
|
312
312
|
@property
|
313
313
|
def language(self) -> str:
|
314
314
|
"""Get the language of the object."""
|
315
|
-
if isinstance(self, Described):
|
315
|
+
if isinstance(self, Described) and self.description:
|
316
316
|
return detect_language(self.description)
|
317
|
-
if isinstance(self, Titled):
|
317
|
+
if isinstance(self, Titled) and self.title:
|
318
318
|
return detect_language(self.title)
|
319
|
-
if isinstance(self, Named):
|
319
|
+
if isinstance(self, Named) and self.name:
|
320
320
|
return detect_language(self.name)
|
321
321
|
|
322
322
|
return detect_language(self.model_dump_json(by_alias=True))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
"""This module contains the types for the keyword arguments of the methods in the models module."""
|
2
2
|
|
3
|
-
from typing import Any, Dict, List, NotRequired, Optional, Required, TypedDict
|
3
|
+
from typing import Any, Dict, List, Literal, NotRequired, Optional, Required, TypedDict
|
4
4
|
|
5
5
|
from litellm.caching.caching import CacheMode
|
6
6
|
from litellm.types.caching import CachingSupportedCallTypes
|
@@ -164,3 +164,11 @@ class CacheKwargs(TypedDict, total=False):
|
|
164
164
|
qdrant_collection_name: str
|
165
165
|
qdrant_quantization_config: str
|
166
166
|
qdrant_semantic_cache_embedding_model: str
|
167
|
+
|
168
|
+
|
169
|
+
class RerankOptions(TypedDict, total=False):
|
170
|
+
"""Optional keyword arguments for the rerank method."""
|
171
|
+
|
172
|
+
raw_scores: bool
|
173
|
+
truncate: bool
|
174
|
+
truncation_direction: Literal["Left", "Right"]
|
Binary file
|
fabricatio/utils.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1
1
|
"""A collection of utility functions for the fabricatio package."""
|
2
2
|
|
3
|
-
from typing import Any, Dict, List, Mapping, Optional, overload
|
3
|
+
from typing import Any, Dict, List, Mapping, Optional, TypedDict, Unpack, overload
|
4
|
+
|
5
|
+
import aiohttp
|
6
|
+
import requests
|
4
7
|
|
5
8
|
from fabricatio.decorators import precheck_package
|
9
|
+
from fabricatio.journal import logger
|
10
|
+
from fabricatio.models.kwargs_types import RerankOptions
|
6
11
|
|
7
12
|
|
8
13
|
@precheck_package(
|
@@ -92,3 +97,159 @@ def wrapp_in_block(string: str, title: str, style: str = "-") -> str:
|
|
92
97
|
str: The wrapped string.
|
93
98
|
"""
|
94
99
|
return f"--- Start of {title} ---\n{string}\n--- End of {title} ---".replace("-", style)
|
100
|
+
|
101
|
+
|
102
|
+
class RerankResult(TypedDict):
|
103
|
+
"""The rerank result."""
|
104
|
+
|
105
|
+
index: int
|
106
|
+
score: float
|
107
|
+
|
108
|
+
|
109
|
+
class RerankerAPI:
|
110
|
+
"""A class to interact with the /rerank API for text reranking."""
|
111
|
+
|
112
|
+
def __init__(self, base_url: str) -> None:
|
113
|
+
"""Initialize the RerankerAPI instance.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
base_url (str): The base URL of the TEI-deployed reranker model API.
|
117
|
+
Example: "http://localhost:8000".
|
118
|
+
"""
|
119
|
+
self.base_url = base_url.rstrip("/") # Ensure no trailing slashes
|
120
|
+
|
121
|
+
@staticmethod
|
122
|
+
def _map_error_code(status_code: int, error_data: Dict[str, str]) -> Exception:
|
123
|
+
"""Map HTTP status codes and error data to specific exceptions.
|
124
|
+
|
125
|
+
Args:
|
126
|
+
status_code (int): The HTTP status code returned by the API.
|
127
|
+
error_data (Dict[str, str]): The error details returned by the API.
|
128
|
+
|
129
|
+
Returns:
|
130
|
+
Exception: A specific exception based on the error code and message.
|
131
|
+
"""
|
132
|
+
error_message = error_data.get("error", "Unknown error")
|
133
|
+
|
134
|
+
if status_code == 400:
|
135
|
+
return ValueError(f"Bad request: {error_message}")
|
136
|
+
if status_code == 413:
|
137
|
+
return ValueError(f"Batch size error: {error_message}")
|
138
|
+
if status_code == 422:
|
139
|
+
return RuntimeError(f"Tokenization error: {error_message}")
|
140
|
+
if status_code == 424:
|
141
|
+
return RuntimeError(f"Rerank error: {error_message}")
|
142
|
+
if status_code == 429:
|
143
|
+
return RuntimeError(f"Model overloaded: {error_message}")
|
144
|
+
return RuntimeError(f"Unexpected error ({status_code}): {error_message}")
|
145
|
+
|
146
|
+
def rerank(self, query: str, texts: List[str], **kwargs: Unpack[RerankOptions]) -> List[RerankResult]:
|
147
|
+
"""Call the /rerank API to rerank a list of texts based on a query (synchronous).
|
148
|
+
|
149
|
+
Args:
|
150
|
+
query (str): The query string used for matching with the texts.
|
151
|
+
texts (List[str]): A list of texts to be reranked.
|
152
|
+
**kwargs (Unpack[RerankOptions]): Optional keyword arguments:
|
153
|
+
- raw_scores (bool, optional): Whether to return raw scores. Defaults to False.
|
154
|
+
- truncate (bool, optional): Whether to truncate the texts. Defaults to False.
|
155
|
+
- truncation_direction (Literal["left", "right"], optional): Direction of truncation. Defaults to "right".
|
156
|
+
|
157
|
+
Returns:
|
158
|
+
List[RerankResult]: A list of dictionaries containing the reranked results.
|
159
|
+
Each dictionary includes:
|
160
|
+
- "index" (int): The original index of the text.
|
161
|
+
- "score" (float): The relevance score.
|
162
|
+
|
163
|
+
Raises:
|
164
|
+
ValueError: If input parameters are invalid or the API returns a client-side error.
|
165
|
+
RuntimeError: If the API call fails or returns a server-side error.
|
166
|
+
"""
|
167
|
+
# Validate inputs
|
168
|
+
if not isinstance(query, str) or not query.strip():
|
169
|
+
raise ValueError("Query must be a non-empty string.")
|
170
|
+
if not isinstance(texts, list) or not all(isinstance(text, str) for text in texts):
|
171
|
+
raise ValueError("Texts must be a list of strings.")
|
172
|
+
|
173
|
+
# Construct the request payload
|
174
|
+
payload = {
|
175
|
+
"query": query,
|
176
|
+
"texts": texts,
|
177
|
+
**kwargs,
|
178
|
+
}
|
179
|
+
|
180
|
+
try:
|
181
|
+
# Send POST request to the API
|
182
|
+
response = requests.post(f"{self.base_url}/rerank", json=payload)
|
183
|
+
|
184
|
+
# Handle non-200 status codes
|
185
|
+
if not response.ok:
|
186
|
+
error_data = None
|
187
|
+
if "application/json" in response.headers.get("Content-Type", ""):
|
188
|
+
error_data = response.json()
|
189
|
+
else:
|
190
|
+
error_data = {"error": response.text, "error_type": "unexpected_mimetype"}
|
191
|
+
raise self._map_error_code(response.status_code, error_data)
|
192
|
+
|
193
|
+
# Parse the JSON response
|
194
|
+
data: List[RerankResult] = response.json()
|
195
|
+
logger.debug(f"Rerank for `{query}` get {[s['score'] for s in data]}")
|
196
|
+
return data
|
197
|
+
|
198
|
+
except requests.exceptions.RequestException as e:
|
199
|
+
raise RuntimeError(f"Failed to connect to the API: {e}") from e
|
200
|
+
|
201
|
+
async def arerank(self, query: str, texts: List[str], **kwargs: Unpack[RerankOptions]) -> List[RerankResult]:
|
202
|
+
"""Call the /rerank API to rerank a list of texts based on a query (asynchronous).
|
203
|
+
|
204
|
+
Args:
|
205
|
+
query (str): The query string used for matching with the texts.
|
206
|
+
texts (List[str]): A list of texts to be reranked.
|
207
|
+
**kwargs (Unpack[RerankOptions]): Optional keyword arguments:
|
208
|
+
- raw_scores (bool, optional): Whether to return raw scores. Defaults to False.
|
209
|
+
- truncate (bool, optional): Whether to truncate the texts. Defaults to False.
|
210
|
+
- truncation_direction (Literal["left", "right"], optional): Direction of truncation. Defaults to "right".
|
211
|
+
|
212
|
+
Returns:
|
213
|
+
List[RerankResult]: A list of dictionaries containing the reranked results.
|
214
|
+
Each dictionary includes:
|
215
|
+
- "index" (int): The original index of the text.
|
216
|
+
- "score" (float): The relevance score.
|
217
|
+
|
218
|
+
Raises:
|
219
|
+
ValueError: If input parameters are invalid or the API returns a client-side error.
|
220
|
+
RuntimeError: If the API call fails or returns a server-side error.
|
221
|
+
"""
|
222
|
+
# Validate inputs
|
223
|
+
if not isinstance(query, str) or not query.strip():
|
224
|
+
raise ValueError("Query must be a non-empty string.")
|
225
|
+
if not isinstance(texts, list) or not all(isinstance(text, str) for text in texts):
|
226
|
+
raise ValueError("Texts must be a list of strings.")
|
227
|
+
|
228
|
+
# Construct the request payload
|
229
|
+
payload = {
|
230
|
+
"query": query,
|
231
|
+
"texts": texts,
|
232
|
+
**kwargs,
|
233
|
+
}
|
234
|
+
|
235
|
+
try:
|
236
|
+
# Send POST request to the API using aiohttp
|
237
|
+
async with (
|
238
|
+
aiohttp.ClientSession() as session,
|
239
|
+
session.post(f"{self.base_url}/rerank", json=payload) as response,
|
240
|
+
):
|
241
|
+
# Handle non-200 status codes
|
242
|
+
if not response.ok:
|
243
|
+
if "application/json" in response.headers.get("Content-Type", ""):
|
244
|
+
error_data = await response.json()
|
245
|
+
else:
|
246
|
+
error_data = {"error": await response.text(), "error_type": "unexpected_mimetype"}
|
247
|
+
raise self._map_error_code(response.status, error_data)
|
248
|
+
|
249
|
+
# Parse the JSON response
|
250
|
+
data: List[RerankResult] = await response.json()
|
251
|
+
logger.debug(f"Rerank for `{query}` get {[s['score'] for s in data]}")
|
252
|
+
return data
|
253
|
+
|
254
|
+
except aiohttp.ClientError as e:
|
255
|
+
raise RuntimeError(f"Failed to connect to the API: {e}") from e
|
Binary file
|
Binary file
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: fabricatio
|
3
|
-
Version: 0.2.13.
|
3
|
+
Version: 0.2.13.dev2
|
4
4
|
Classifier: License :: OSI Approved :: MIT License
|
5
5
|
Classifier: Programming Language :: Rust
|
6
6
|
Classifier: Programming Language :: Python :: 3.12
|
@@ -165,6 +165,10 @@ max_tokens = 8192
|
|
165
165
|
```bash
|
166
166
|
make test
|
167
167
|
```
|
168
|
+
## TODO
|
169
|
+
|
170
|
+
- Add an element based format strategy
|
171
|
+
|
168
172
|
|
169
173
|
## Contributing
|
170
174
|
|
@@ -1,26 +1,26 @@
|
|
1
|
-
fabricatio-0.2.13.
|
2
|
-
fabricatio-0.2.13.
|
3
|
-
fabricatio-0.2.13.
|
4
|
-
fabricatio/actions/article.py,sha256=
|
5
|
-
fabricatio/actions/article_rag.py,sha256=
|
1
|
+
fabricatio-0.2.13.dev2.dist-info/METADATA,sha256=-Wa7FaGztQxqO0dcMsIB6FNH6w5r4jQi5URaPkevk60,5316
|
2
|
+
fabricatio-0.2.13.dev2.dist-info/WHEEL,sha256=jABKVkLC9kJr8mi_er5jOqpiQUjARSLXDUIIxDqsS50,96
|
3
|
+
fabricatio-0.2.13.dev2.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
|
4
|
+
fabricatio/actions/article.py,sha256=c9HfLM1p0kkdEcQrl8w6Qa5p8riPnkEO--QbNZj88HQ,12562
|
5
|
+
fabricatio/actions/article_rag.py,sha256=0J5aeTyHS30bDaQBnxxc8p16oB87hRK_fOgijShWJPA,17797
|
6
6
|
fabricatio/actions/fs.py,sha256=gJR14U4ln35nt8Z7OWLVAZpqGaLnED-r1Yi-lX22tkI,959
|
7
7
|
fabricatio/actions/output.py,sha256=lX0HkDse3ypCzZgeF-8Dr-EnNFdBiE-WQ1iLPFlGM1g,8302
|
8
8
|
fabricatio/actions/rag.py,sha256=KN-OWgcQjGmNgSZ-s5B8m4LpYKSGFJR8eq72mo2CP9k,3592
|
9
9
|
fabricatio/actions/rules.py,sha256=dkvCgNDjt2KSO1VgPRsxT4YBmIIMeetZb5tiz-slYkU,3640
|
10
10
|
fabricatio/actions/__init__.py,sha256=wVENCFtpVb1rLFxoOFJt9-8smLWXuJV7IwA8P3EfFz4,48
|
11
11
|
fabricatio/capabilities/advanced_judge.py,sha256=selB0Gwf1F4gGJlwBiRo6gI4KOUROgh3WnzO3mZFEls,706
|
12
|
-
fabricatio/capabilities/advanced_rag.py,sha256=
|
12
|
+
fabricatio/capabilities/advanced_rag.py,sha256=FaXHGOqS4VleGSLsnC5qm4S4EBcHZLZbj8TjXRieKBs,2513
|
13
13
|
fabricatio/capabilities/censor.py,sha256=bBT5qy-kp7fh8g4Lz3labSwxwJ60gGd_vrkc6k1cZ1U,4719
|
14
14
|
fabricatio/capabilities/check.py,sha256=kYqzohhv2bZfl1aKSUt7a8snT8YEl2zgha_ZdAdMMfQ,8622
|
15
15
|
fabricatio/capabilities/correct.py,sha256=W_cInqlciNEhyMK0YI53jk4EvW9uAdge90IO9OElUmA,10420
|
16
16
|
fabricatio/capabilities/extract.py,sha256=PMjkWvbsv57IYT7zzd_xbIu4eQqQjpcmBtJzqlWZhHY,2495
|
17
17
|
fabricatio/capabilities/propose.py,sha256=hkBeSlmcTdfYWT-ph6nlbtHXBozi_JXqXlWcnBy3W78,2007
|
18
|
-
fabricatio/capabilities/rag.py,sha256=
|
18
|
+
fabricatio/capabilities/rag.py,sha256=icJgFwEX4eOlZdtchDaITjFRBrAON6xezyFfPmzehI8,11058
|
19
19
|
fabricatio/capabilities/rating.py,sha256=iMtQs3H6vCjuEjiuuz4SRKMVaX7yff7MHWz-slYvi5g,17835
|
20
20
|
fabricatio/capabilities/review.py,sha256=-EMZe0ADFPT6fPGmra16UPjJC1M3rAs6dPFdTZ88Fgg,5060
|
21
21
|
fabricatio/capabilities/task.py,sha256=uks1U-4LNCUdwdRxAbJJjMc31hOw6jlrcYriuQQfb04,4475
|
22
22
|
fabricatio/capabilities/__init__.py,sha256=v1cHRHIJ2gxyqMLNCs6ERVcCakSasZNYzmMI4lqAcls,57
|
23
|
-
fabricatio/config.py,sha256=
|
23
|
+
fabricatio/config.py,sha256=XoGU5PMqVNW8p_5VD2qE3_g3UpjpSBHUzvjXlgNS0lM,18126
|
24
24
|
fabricatio/constants.py,sha256=thfDuF6JEtJ5CHOnAJLfqvn5834n8ep6DH2jc6XGzQM,577
|
25
25
|
fabricatio/core.py,sha256=VQ_JKgUGIy2gZ8xsTBZCdr_IP7wC5aPg0_bsOmjQ588,6458
|
26
26
|
fabricatio/decorators.py,sha256=RFMYUlQPf561-BIHetpMd7fPig5bZ2brzWiQTgoLOlY,8966
|
@@ -28,23 +28,23 @@ fabricatio/fs/curd.py,sha256=652nHulbJ3gwt0Z3nywtPMmjhEyglDvEfc3p7ieJNNA,4777
|
|
28
28
|
fabricatio/fs/readers.py,sha256=UXvcJO3UCsxHu9PPkg34Yh55Zi-miv61jD_wZQJgKRs,1751
|
29
29
|
fabricatio/fs/__init__.py,sha256=FydmlEY_3QY74r1BpGDc5lFLhE6g6gkwOAtE30Fo-aI,786
|
30
30
|
fabricatio/journal.py,sha256=stnEP88aUBA_GmU9gfTF2EZI8FS2OyMLGaMSTgK4QgA,476
|
31
|
-
fabricatio/models/action.py,sha256=
|
32
|
-
fabricatio/models/adv_kwargs_types.py,sha256=
|
31
|
+
fabricatio/models/action.py,sha256=qxPeOD_nYNN94MzOhCzRDhySZFvM8uoZb_hhA7d_yn4,10609
|
32
|
+
fabricatio/models/adv_kwargs_types.py,sha256=IBV3ZcsNLvvEjO_2hBpYg_wLSpNKaMx6Ndam3qXJCw8,2097
|
33
33
|
fabricatio/models/events.py,sha256=wiirk_ASg3iXDOZU_gIimci1VZVzWE1nDmxy-hQVJ9M,4150
|
34
34
|
fabricatio/models/extra/advanced_judge.py,sha256=INUl_41C8jkausDekkjnEmTwNfLCJ23TwFjq2cM23Cw,1092
|
35
|
-
fabricatio/models/extra/aricle_rag.py,sha256=
|
36
|
-
fabricatio/models/extra/article_base.py,sha256=
|
35
|
+
fabricatio/models/extra/aricle_rag.py,sha256=6mkuxNZD_8cdWINmxP8ajmTwdwSH45jcdUSBmY6ZZfQ,11685
|
36
|
+
fabricatio/models/extra/article_base.py,sha256=qBkRYAOdrtTnO02G0W1zDDtQWYrQIKot_XyyDaaCLp8,15697
|
37
37
|
fabricatio/models/extra/article_essence.py,sha256=mlIkkRMR3I1RtqiiOnmIE3Vy623L4eECumkRzryE1pw,2749
|
38
|
-
fabricatio/models/extra/article_main.py,sha256=
|
39
|
-
fabricatio/models/extra/article_outline.py,sha256=
|
38
|
+
fabricatio/models/extra/article_main.py,sha256=zUkFQRLv6cLoPBbo7H21rrRScjgGv_SzhFd0Y514FsA,11211
|
39
|
+
fabricatio/models/extra/article_outline.py,sha256=M4TSrhQ7zGaOcGN91Z-zrhm_IKr8GrPM6uOpK_0JfFI,2789
|
40
40
|
fabricatio/models/extra/article_proposal.py,sha256=NbyjW-7UiFPtnVD9nte75re4xL2pD4qL29PpNV4Cg_M,1870
|
41
41
|
fabricatio/models/extra/patches.py,sha256=_WNCxtYzzsVfUxI16vu4IqsLahLYRHdbQN9er9tqhC0,997
|
42
42
|
fabricatio/models/extra/problem.py,sha256=8tTU-3giFHOi5j7NJsvH__JJyYcaGrcfsRnkzQNm0Ew,7216
|
43
43
|
fabricatio/models/extra/rag.py,sha256=RMi8vhEPB0I5mVmjRLRLxYHUnm9pFhvVwysaIwmW2s0,3955
|
44
44
|
fabricatio/models/extra/rule.py,sha256=KQQELVhCLUXhEZ35jU3WGYqKHuCYEAkn0p6pxAE-hOU,2625
|
45
45
|
fabricatio/models/extra/__init__.py,sha256=XlYnS_2B9nhLhtQkjE7rvvfPmAAtXVdNi9bSDAR-Ge8,54
|
46
|
-
fabricatio/models/generic.py,sha256=
|
47
|
-
fabricatio/models/kwargs_types.py,sha256=
|
46
|
+
fabricatio/models/generic.py,sha256=xk2Q_dADxUIGUuaqahCPTnZ4HwBRD67HCs13ZA5_LnE,31364
|
47
|
+
fabricatio/models/kwargs_types.py,sha256=gIvLSof3XE-B0cGE5d1BrOQB1HO8Pd666_scd-9JaF4,5000
|
48
48
|
fabricatio/models/role.py,sha256=b8FDRF4VjMMt93Uh5yiAufFbsoH7RcUaaFJAjVmq2l0,2931
|
49
49
|
fabricatio/models/task.py,sha256=bLYSKjlRAlb4jMYyF12RTnm_8pVXysSmX8CYLrEmbQ8,11096
|
50
50
|
fabricatio/models/tool.py,sha256=jQ51g4lwTPfsMF1nbreDJtBczbxIHoXcPuLSOqHliq8,12506
|
@@ -56,12 +56,12 @@ fabricatio/rust_instances.py,sha256=Byeo8KHW_dJiXujJq7YPGDLBX5bHNDYbBc4sY3uubVY,
|
|
56
56
|
fabricatio/toolboxes/arithmetic.py,sha256=WLqhY-Pikv11Y_0SGajwZx3WhsLNpHKf9drzAqOf_nY,1369
|
57
57
|
fabricatio/toolboxes/fs.py,sha256=l4L1CVxJmjw9Ld2XUpIlWfV0_Fu_2Og6d3E13I-S4aE,736
|
58
58
|
fabricatio/toolboxes/__init__.py,sha256=KBJi5OG_pExscdlM7Bnt_UF43j4I3Lv6G71kPVu4KQU,395
|
59
|
-
fabricatio/utils.py,sha256=
|
59
|
+
fabricatio/utils.py,sha256=IBKfs2Rg3bJnazzvj1-Fz1rMWNKhiuQG5_rZ1nxQeMI,10299
|
60
60
|
fabricatio/workflows/articles.py,sha256=ObYTFUqLUk_CzdmmnX6S7APfxcGmPFqnFr9pdjU7Z4Y,969
|
61
61
|
fabricatio/workflows/rag.py,sha256=-YYp2tlE9Vtfgpg6ROpu6QVO8j8yVSPa6yDzlN3qVxs,520
|
62
62
|
fabricatio/workflows/__init__.py,sha256=5ScFSTA-bvhCesj3U9Mnmi6Law6N1fmh5UKyh58L3u8,51
|
63
63
|
fabricatio/__init__.py,sha256=Rmvq2VgdS2u68vnOi2i5RbeWbAwrJDbk8D8D883PJWE,1022
|
64
|
-
fabricatio/rust.cp312-win_amd64.pyd,sha256=
|
65
|
-
fabricatio-0.2.13.
|
66
|
-
fabricatio-0.2.13.
|
67
|
-
fabricatio-0.2.13.
|
64
|
+
fabricatio/rust.cp312-win_amd64.pyd,sha256=wOu3NK1TiWnQPmJd_p7QJgVnP91_xrD4AkIzBFbUBhY,4449280
|
65
|
+
fabricatio-0.2.13.dev2.data/scripts/tdown.exe,sha256=e7KGiGIaLTHej3keQvWbFbjQGLGstQWBuQkQp3w0uhM,3359232
|
66
|
+
fabricatio-0.2.13.dev2.data/scripts/ttm.exe,sha256=OMXhJIJdSO8HuoWtClyLtllDQ2M_06FzeC2nXi7ryMM,2554880
|
67
|
+
fabricatio-0.2.13.dev2.dist-info/RECORD,,
|
Binary file
|
Binary file
|
File without changes
|
File without changes
|