fabricatio 0.2.9.dev3__cp312-cp312-win_amd64.whl → 0.2.10.dev0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/actions/article.py +13 -113
- fabricatio/actions/article_rag.py +9 -2
- fabricatio/capabilities/check.py +15 -9
- fabricatio/capabilities/correct.py +5 -6
- fabricatio/capabilities/rag.py +39 -232
- fabricatio/capabilities/rating.py +46 -40
- fabricatio/config.py +2 -2
- fabricatio/constants.py +20 -0
- fabricatio/decorators.py +23 -0
- fabricatio/fs/readers.py +20 -1
- fabricatio/models/adv_kwargs_types.py +42 -0
- fabricatio/models/events.py +6 -6
- fabricatio/models/extra/advanced_judge.py +4 -4
- fabricatio/models/extra/article_base.py +25 -211
- fabricatio/models/extra/article_main.py +69 -95
- fabricatio/models/extra/article_proposal.py +15 -14
- fabricatio/models/extra/patches.py +6 -6
- fabricatio/models/extra/problem.py +12 -17
- fabricatio/models/extra/rag.py +72 -0
- fabricatio/models/extra/rule.py +1 -2
- fabricatio/models/generic.py +34 -10
- fabricatio/models/kwargs_types.py +1 -38
- fabricatio/models/task.py +3 -3
- fabricatio/models/usages.py +78 -8
- fabricatio/parser.py +5 -5
- fabricatio/rust.cp312-win_amd64.pyd +0 -0
- fabricatio/rust.pyi +27 -12
- fabricatio-0.2.10.dev0.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/METADATA +1 -1
- fabricatio-0.2.10.dev0.dist-info/RECORD +62 -0
- fabricatio/models/utils.py +0 -148
- fabricatio-0.2.9.dev3.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.9.dev3.dist-info/RECORD +0 -61
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.9.dev3.dist-info → fabricatio-0.2.10.dev0.dist-info}/licenses/LICENSE +0 -0
fabricatio/actions/article.py
CHANGED
@@ -11,7 +11,6 @@ from fabricatio.capabilities.propose import Propose
|
|
11
11
|
from fabricatio.fs import safe_text_read
|
12
12
|
from fabricatio.journal import logger
|
13
13
|
from fabricatio.models.action import Action
|
14
|
-
from fabricatio.models.extra.article_base import SubSectionBase
|
15
14
|
from fabricatio.models.extra.article_essence import ArticleEssence
|
16
15
|
from fabricatio.models.extra.article_main import Article
|
17
16
|
from fabricatio.models.extra.article_outline import ArticleOutline
|
@@ -105,7 +104,6 @@ class GenerateArticleProposal(Action, Propose):
|
|
105
104
|
task_input: Optional[Task] = None,
|
106
105
|
article_briefing: Optional[str] = None,
|
107
106
|
article_briefing_path: Optional[str] = None,
|
108
|
-
langauge: Optional[str] = None,
|
109
107
|
**_,
|
110
108
|
) -> Optional[ArticleProposal]:
|
111
109
|
if article_briefing is None and article_briefing_path is None and task_input is None:
|
@@ -122,17 +120,14 @@ class GenerateArticleProposal(Action, Propose):
|
|
122
120
|
)
|
123
121
|
)
|
124
122
|
|
125
|
-
|
123
|
+
logger.info("Start generating the proposal.")
|
124
|
+
return ok(
|
126
125
|
await self.propose(
|
127
126
|
ArticleProposal,
|
128
|
-
f"{briefing}\n\nWrite the value string using `{detect_language(briefing)}`",
|
127
|
+
f"{briefing}\n\nWrite the value string using `{detect_language(briefing)}` as written language.",
|
129
128
|
),
|
130
129
|
"Could not generate the proposal.",
|
131
130
|
).update_ref(briefing)
|
132
|
-
if langauge:
|
133
|
-
proposal.language = langauge
|
134
|
-
|
135
|
-
return proposal
|
136
131
|
|
137
132
|
|
138
133
|
class GenerateInitialOutline(Action, Propose):
|
@@ -146,11 +141,17 @@ class GenerateInitialOutline(Action, Propose):
|
|
146
141
|
article_proposal: ArticleProposal,
|
147
142
|
**_,
|
148
143
|
) -> Optional[ArticleOutline]:
|
144
|
+
raw_outline = await self.aask(
|
145
|
+
f"{(article_proposal.as_prompt())}\n\nNote that you should use `{article_proposal.language}` to write the `ArticleOutline`\n"
|
146
|
+
f"Design each chapter of a proper and academic and ready for release manner.\n"
|
147
|
+
f"You Must make sure every chapter have sections, and every section have subsections.\n"
|
148
|
+
f"Make the chapter and sections and subsections bing divided into a specific enough article component.",
|
149
|
+
)
|
150
|
+
|
149
151
|
return ok(
|
150
152
|
await self.propose(
|
151
153
|
ArticleOutline,
|
152
|
-
f"{
|
153
|
-
f"You Must make sure every chapter have sections, and every section have subsections.",
|
154
|
+
f"{raw_outline}\n\n\n\noutline provided above is the outline i need to extract to a JSON,",
|
154
155
|
),
|
155
156
|
"Could not generate the initial outline.",
|
156
157
|
).update_ref(article_proposal)
|
@@ -182,7 +183,7 @@ class FixIntrospectedErrors(Action, Censor):
|
|
182
183
|
await self.censor_obj(
|
183
184
|
article_outline,
|
184
185
|
ruleset=ok(intro_fix_ruleset or self.ruleset, "No ruleset provided"),
|
185
|
-
reference=f"{article_outline.
|
186
|
+
reference=f"{article_outline.display()}\n # Fatal Error of the Original Article Outline\n{pack}",
|
186
187
|
),
|
187
188
|
"Could not correct the component.",
|
188
189
|
).update_ref(origin)
|
@@ -195,107 +196,6 @@ class FixIntrospectedErrors(Action, Censor):
|
|
195
196
|
return article_outline
|
196
197
|
|
197
198
|
|
198
|
-
class FixIllegalReferences(Action, Censor):
|
199
|
-
"""Fix illegal references in the article outline."""
|
200
|
-
|
201
|
-
output_key: str = "illegal_references_fixed_outline"
|
202
|
-
"""The key of the output data."""
|
203
|
-
|
204
|
-
ruleset: Optional[RuleSet] = None
|
205
|
-
"""Ruleset to use to fix the illegal references."""
|
206
|
-
max_error_count: Optional[int] = None
|
207
|
-
"""The maximum number of errors to fix."""
|
208
|
-
|
209
|
-
async def _execute(
|
210
|
-
self,
|
211
|
-
article_outline: ArticleOutline,
|
212
|
-
ref_fix_ruleset: Optional[RuleSet] = None,
|
213
|
-
**_,
|
214
|
-
) -> Optional[ArticleOutline]:
|
215
|
-
counter = 0
|
216
|
-
while pack := article_outline.find_illegal_ref(gather_identical=True):
|
217
|
-
logger.info(f"Found {counter}th illegal references")
|
218
|
-
ref_seq, err = ok(pack)
|
219
|
-
logger.warning(f"Found illegal referring error: {err}")
|
220
|
-
new = ok(
|
221
|
-
await self.censor_obj(
|
222
|
-
ref_seq[0],
|
223
|
-
ruleset=ok(ref_fix_ruleset or self.ruleset, "No ruleset provided"),
|
224
|
-
reference=f"{article_outline.as_prompt()}\n# Some Basic errors found that need to be fixed\n{err}",
|
225
|
-
),
|
226
|
-
"Could not correct the component",
|
227
|
-
)
|
228
|
-
for r in ref_seq:
|
229
|
-
r.update_from(new)
|
230
|
-
if self.max_error_count and counter > self.max_error_count:
|
231
|
-
logger.warning("Max error count reached, stopping.")
|
232
|
-
break
|
233
|
-
counter += 1
|
234
|
-
|
235
|
-
return article_outline
|
236
|
-
|
237
|
-
|
238
|
-
class TweakOutlineForwardRef(Action, Censor):
|
239
|
-
"""Tweak the forward references in the article outline.
|
240
|
-
|
241
|
-
Ensures that the conclusions of the current chapter effectively support the analysis of subsequent chapters.
|
242
|
-
"""
|
243
|
-
|
244
|
-
output_key: str = "article_outline_fw_ref_checked"
|
245
|
-
ruleset: Optional[RuleSet] = None
|
246
|
-
"""Ruleset to use to fix the illegal references."""
|
247
|
-
|
248
|
-
async def _execute(
|
249
|
-
self, article_outline: ArticleOutline, ref_twk_ruleset: Optional[RuleSet] = None, **cxt
|
250
|
-
) -> ArticleOutline:
|
251
|
-
return await self._inner(
|
252
|
-
article_outline,
|
253
|
-
ruleset=ok(ref_twk_ruleset or self.ruleset, "No ruleset provided"),
|
254
|
-
field_name="support_to",
|
255
|
-
)
|
256
|
-
|
257
|
-
async def _inner(self, article_outline: ArticleOutline, ruleset: RuleSet, field_name: str) -> ArticleOutline:
|
258
|
-
await gather(
|
259
|
-
*[self._loop(a[-1], article_outline, field_name, ruleset) for a in article_outline.iter_subsections()],
|
260
|
-
)
|
261
|
-
|
262
|
-
return article_outline
|
263
|
-
|
264
|
-
async def _loop(
|
265
|
-
self, a: SubSectionBase, article_outline: ArticleOutline, field_name: str, ruleset: RuleSet
|
266
|
-
) -> None:
|
267
|
-
if judge := await self.evidently_judge(
|
268
|
-
f"{article_outline.as_prompt()}\n\n{a.display()}\n"
|
269
|
-
f"Does the `{a.__class__.__name__}`'s `{field_name}` field need to be extended or tweaked?"
|
270
|
-
):
|
271
|
-
await self.censor_obj_inplace(
|
272
|
-
a,
|
273
|
-
ruleset=ruleset,
|
274
|
-
reference=f"{article_outline.as_prompt()}\n"
|
275
|
-
f"The Article component titled `{a.title}` whose `{field_name}` field needs to be extended or tweaked.\n"
|
276
|
-
f"# Judgement\n{judge.display()}",
|
277
|
-
)
|
278
|
-
|
279
|
-
|
280
|
-
class TweakOutlineBackwardRef(TweakOutlineForwardRef):
|
281
|
-
"""Tweak the backward references in the article outline.
|
282
|
-
|
283
|
-
Ensures that the prerequisites of the current chapter are correctly referenced in the `depend_on` field.
|
284
|
-
"""
|
285
|
-
|
286
|
-
output_key: str = "article_outline_bw_ref_checked"
|
287
|
-
ruleset: Optional[RuleSet] = None
|
288
|
-
|
289
|
-
async def _execute(
|
290
|
-
self, article_outline: ArticleOutline, ref_twk_ruleset: Optional[RuleSet] = None, **cxt
|
291
|
-
) -> ArticleOutline:
|
292
|
-
return await self._inner(
|
293
|
-
article_outline,
|
294
|
-
ruleset=ok(ref_twk_ruleset or self.ruleset, "No ruleset provided"),
|
295
|
-
field_name="depend_on",
|
296
|
-
)
|
297
|
-
|
298
|
-
|
299
199
|
class GenerateArticle(Action, Censor):
|
300
200
|
"""Generate the article based on the outline."""
|
301
201
|
|
@@ -318,7 +218,7 @@ class GenerateArticle(Action, Censor):
|
|
318
218
|
self.censor_obj_inplace(
|
319
219
|
subsec,
|
320
220
|
ruleset=ok(article_gen_ruleset or self.ruleset, "No ruleset provided"),
|
321
|
-
reference=f"{article_outline.as_prompt()}\n# Error Need to be fixed\n{err}",
|
221
|
+
reference=f"{article_outline.as_prompt()}\n# Error Need to be fixed\n{err}\nYou should use `{subsec.language}` to write the new `Subsection`.",
|
322
222
|
)
|
323
223
|
for _, _, subsec in article.iter_subsections()
|
324
224
|
if (err := subsec.introspect()) and logger.warning(f"Found Introspection Error:\n{err}") is None
|
@@ -29,6 +29,9 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
29
29
|
ruleset: Optional[RuleSet] = None
|
30
30
|
"""The ruleset to be used for censoring the article."""
|
31
31
|
|
32
|
+
ref_limit: int = 30
|
33
|
+
"""The limit of references to be retrieved"""
|
34
|
+
|
32
35
|
async def _execute(
|
33
36
|
self,
|
34
37
|
article: Article,
|
@@ -88,11 +91,15 @@ class TweakArticleRAG(Action, RAG, Censor):
|
|
88
91
|
f"{subsec.display()}\n"
|
89
92
|
f"# Requirement\n"
|
90
93
|
f"Search related articles in the base to find reference candidates, "
|
91
|
-
f"
|
94
|
+
f"provide queries in both `English` and `{subsec.language}` can get more accurate results.",
|
92
95
|
)
|
93
96
|
)
|
94
97
|
await self.censor_obj_inplace(
|
95
98
|
subsec,
|
96
99
|
ruleset=ruleset,
|
97
|
-
reference=await self.aretrieve_compact(refind_q, final_limit=
|
100
|
+
reference=f"{await self.aretrieve_compact(refind_q, final_limit=self.ref_limit)}\n\n"
|
101
|
+
f"You can use Reference above to rewrite the `{subsec.__class__.__name__}`.\n"
|
102
|
+
f"You should Always use `{subsec.language}` as written language, "
|
103
|
+
f"which is the original language of the `{subsec.title}`. "
|
104
|
+
f"since rewrite a `{subsec.__class__.__name__}` in a different language is usually a bad choice",
|
98
105
|
)
|
fabricatio/capabilities/check.py
CHANGED
@@ -8,7 +8,7 @@ from fabricatio.capabilities.advanced_judge import AdvancedJudge
|
|
8
8
|
from fabricatio.capabilities.propose import Propose
|
9
9
|
from fabricatio.config import configs
|
10
10
|
from fabricatio.journal import logger
|
11
|
-
from fabricatio.models.extra.patches import
|
11
|
+
from fabricatio.models.extra.patches import RuleSetMetadata
|
12
12
|
from fabricatio.models.extra.problem import Improvement
|
13
13
|
from fabricatio.models.extra.rule import Rule, RuleSet
|
14
14
|
from fabricatio.models.generic import Display, WithBriefing
|
@@ -42,12 +42,17 @@ class Check(AdvancedJudge, Propose):
|
|
42
42
|
- Returns None if any step in rule generation fails
|
43
43
|
- Uses `alist_str` for requirement breakdown and iterative rule proposal
|
44
44
|
"""
|
45
|
-
rule_reqs =
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
45
|
+
rule_reqs = (
|
46
|
+
await self.alist_str(
|
47
|
+
TEMPLATE_MANAGER.render_template(
|
48
|
+
configs.templates.ruleset_requirement_breakdown_template,
|
49
|
+
{"ruleset_requirement": ruleset_requirement},
|
50
|
+
),
|
51
|
+
rule_count,
|
52
|
+
**override_kwargs(kwargs, default=None),
|
53
|
+
)
|
54
|
+
if rule_count > 1
|
55
|
+
else [ruleset_requirement]
|
51
56
|
)
|
52
57
|
|
53
58
|
if rule_reqs is None:
|
@@ -65,7 +70,7 @@ class Check(AdvancedJudge, Propose):
|
|
65
70
|
return None
|
66
71
|
|
67
72
|
ruleset_patch = await self.propose(
|
68
|
-
|
73
|
+
RuleSetMetadata,
|
69
74
|
f"{ruleset_requirement}\n\nYou should use `{detect_language(ruleset_requirement)}`!",
|
70
75
|
**override_kwargs(kwargs, default=None),
|
71
76
|
)
|
@@ -99,7 +104,8 @@ class Check(AdvancedJudge, Propose):
|
|
99
104
|
- Proposes Improvement only when violation is confirmed
|
100
105
|
"""
|
101
106
|
if judge := await self.evidently_judge(
|
102
|
-
f"# Content to exam\n{input_text}\n\n# Rule Must to follow\n{rule.display()}\nDoes `Content to exam` provided above violate the `
|
107
|
+
f"# Content to exam\n{input_text}\n\n# Rule Must to follow\n{rule.display()}\nDoes `Content to exam` provided above violate the `{rule.name}` provided above?"
|
108
|
+
f"should I take some measure to fix that violation? true for I do need, false for I don't need.",
|
103
109
|
**override_kwargs(kwargs, default=None),
|
104
110
|
):
|
105
111
|
logger.info(f"Rule `{rule.name}` violated: \n{judge.display()}")
|
@@ -57,7 +57,7 @@ class Correct(Rating, Propose):
|
|
57
57
|
self.decide_solution(
|
58
58
|
ps,
|
59
59
|
**fallback_kwargs(
|
60
|
-
kwargs, topic=f"which solution is better to deal this problem {ps.problem.
|
60
|
+
kwargs, topic=f"which solution is better to deal this problem {ps.problem.description}\n\n"
|
61
61
|
),
|
62
62
|
)
|
63
63
|
for ps in improvement.problem_solutions
|
@@ -167,13 +167,12 @@ class Correct(Rating, Propose):
|
|
167
167
|
logger.info(f"Improvement {improvement.focused_on} not decided, start deciding...")
|
168
168
|
improvement = await self.decide_improvement(improvement, **override_kwargs(kwargs, default=None))
|
169
169
|
|
170
|
-
|
171
|
-
|
170
|
+
total = len(improvement.problem_solutions)
|
171
|
+
for idx, ps in enumerate(improvement.problem_solutions):
|
172
|
+
logger.info(f"[{idx + 1}/{total}] Fixing {obj.__class__.__name__} for problem `{ps.problem.name}`")
|
172
173
|
fixed_obj = await self.fix_troubled_obj(obj, ps, reference, **kwargs)
|
173
174
|
if fixed_obj is None:
|
174
|
-
logger.error(
|
175
|
-
f"Failed to fix troubling obj {obj.__class__.__name__} when deal with problem: {ps.problem.name}",
|
176
|
-
)
|
175
|
+
logger.error(f"[{idx + 1}/{total}] Failed to fix problem `{ps.problem.name}`")
|
177
176
|
return None
|
178
177
|
obj = fixed_obj
|
179
178
|
return obj
|
fabricatio/capabilities/rag.py
CHANGED
@@ -3,28 +3,22 @@
|
|
3
3
|
try:
|
4
4
|
from pymilvus import MilvusClient
|
5
5
|
except ImportError as e:
|
6
|
-
raise RuntimeError(
|
6
|
+
raise RuntimeError(
|
7
|
+
"pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`?"
|
8
|
+
) from e
|
7
9
|
from functools import lru_cache
|
8
10
|
from operator import itemgetter
|
9
|
-
from
|
10
|
-
from pathlib import Path
|
11
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
|
11
|
+
from typing import List, Optional, Self, Type, Unpack
|
12
12
|
|
13
13
|
from more_itertools.recipes import flatten, unique
|
14
14
|
from pydantic import Field, PrivateAttr
|
15
15
|
|
16
16
|
from fabricatio.config import configs
|
17
17
|
from fabricatio.journal import logger
|
18
|
-
from fabricatio.models.
|
19
|
-
|
20
|
-
|
21
|
-
EmbeddingKwargs,
|
22
|
-
FetchKwargs,
|
23
|
-
LLMKwargs,
|
24
|
-
RetrievalKwargs,
|
25
|
-
)
|
18
|
+
from fabricatio.models.adv_kwargs_types import CollectionConfigKwargs, FetchKwargs
|
19
|
+
from fabricatio.models.extra.rag import MilvusDataBase
|
20
|
+
from fabricatio.models.kwargs_types import ChooseKwargs
|
26
21
|
from fabricatio.models.usages import EmbeddingUsage
|
27
|
-
from fabricatio.models.utils import MilvusData
|
28
22
|
from fabricatio.rust_instances import TEMPLATE_MANAGER
|
29
23
|
from fabricatio.utils import ok
|
30
24
|
|
@@ -78,40 +72,6 @@ class RAG(EmbeddingUsage):
|
|
78
72
|
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
79
73
|
return self
|
80
74
|
|
81
|
-
@overload
|
82
|
-
async def pack(
|
83
|
-
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
84
|
-
) -> List[MilvusData]: ...
|
85
|
-
@overload
|
86
|
-
async def pack(
|
87
|
-
self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
88
|
-
) -> MilvusData: ...
|
89
|
-
|
90
|
-
async def pack(
|
91
|
-
self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
92
|
-
) -> List[MilvusData] | MilvusData:
|
93
|
-
"""Asynchronously generates MilvusData objects for the given input text.
|
94
|
-
|
95
|
-
Args:
|
96
|
-
input_text (List[str] | str): A string or list of strings to generate embeddings for.
|
97
|
-
subject (Optional[str]): The subject of the input text. Defaults to None.
|
98
|
-
**kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
|
99
|
-
|
100
|
-
Returns:
|
101
|
-
List[MilvusData] | MilvusData: The generated MilvusData objects.
|
102
|
-
"""
|
103
|
-
if isinstance(input_text, str):
|
104
|
-
return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
|
105
|
-
vecs = await self.vectorize(input_text, **kwargs)
|
106
|
-
return [
|
107
|
-
MilvusData(
|
108
|
-
vector=vec,
|
109
|
-
text=text,
|
110
|
-
subject=subject,
|
111
|
-
)
|
112
|
-
for text, vec in zip(input_text, vecs, strict=True)
|
113
|
-
]
|
114
|
-
|
115
75
|
def view(
|
116
76
|
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
|
117
77
|
) -> Self:
|
@@ -152,29 +112,27 @@ class RAG(EmbeddingUsage):
|
|
152
112
|
Returns:
|
153
113
|
str: The name of the collection being viewed.
|
154
114
|
"""
|
155
|
-
|
156
|
-
raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
|
157
|
-
return self.target_collection
|
115
|
+
return ok(self.target_collection, "No collection is being viewed. Have you called `self.view()`?")
|
158
116
|
|
159
|
-
def add_document[D:
|
160
|
-
self, data: D |
|
117
|
+
async def add_document[D: MilvusDataBase](
|
118
|
+
self, data: List[D] | D, collection_name: Optional[str] = None, flush: bool = False
|
161
119
|
) -> Self:
|
162
120
|
"""Adds a document to the specified collection.
|
163
121
|
|
164
122
|
Args:
|
165
|
-
data (Union[Dict[str, Any],
|
123
|
+
data (Union[Dict[str, Any], MilvusDataBase] | List[Union[Dict[str, Any], MilvusDataBase]]): The data to be added to the collection.
|
166
124
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
167
125
|
flush (bool): Whether to flush the collection after insertion.
|
168
126
|
|
169
127
|
Returns:
|
170
128
|
Self: The current instance, allowing for method chaining.
|
171
129
|
"""
|
172
|
-
if isinstance(data,
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
130
|
+
if isinstance(data, MilvusDataBase):
|
131
|
+
data = [data]
|
132
|
+
|
133
|
+
data_vec = await self.vectorize([d.to_vectorize for d in data])
|
134
|
+
prepared_data = [d.prepare_insertion(vec) for d, vec in zip(data, data_vec, strict=True)]
|
135
|
+
|
178
136
|
c_name = collection_name or self.safe_target_collection
|
179
137
|
self.check_client().client.insert(c_name, prepared_data)
|
180
138
|
|
@@ -183,84 +141,33 @@ class RAG(EmbeddingUsage):
|
|
183
141
|
self.client.flush(c_name)
|
184
142
|
return self
|
185
143
|
|
186
|
-
async def
|
187
|
-
self,
|
188
|
-
source: List[PathLike] | PathLike,
|
189
|
-
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
190
|
-
collection_name: Optional[str] = None,
|
191
|
-
) -> Self:
|
192
|
-
"""Consume a file and add its content to the collection.
|
193
|
-
|
194
|
-
Args:
|
195
|
-
source (PathLike): The path to the file to be consumed.
|
196
|
-
reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
|
197
|
-
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
198
|
-
|
199
|
-
Returns:
|
200
|
-
Self: The current instance, allowing for method chaining.
|
201
|
-
"""
|
202
|
-
if not isinstance(source, list):
|
203
|
-
source = [source]
|
204
|
-
return await self.consume_string([reader(s) for s in source], collection_name)
|
205
|
-
|
206
|
-
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
|
207
|
-
"""Consume a string and add it to the collection.
|
208
|
-
|
209
|
-
Args:
|
210
|
-
text (List[str] | str): The text to be added to the collection.
|
211
|
-
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
212
|
-
|
213
|
-
Returns:
|
214
|
-
Self: The current instance, allowing for method chaining.
|
215
|
-
"""
|
216
|
-
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
217
|
-
return self
|
218
|
-
|
219
|
-
@overload
|
220
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
144
|
+
async def afetch_document[D: MilvusDataBase](
|
221
145
|
self,
|
222
146
|
vecs: List[List[float]],
|
223
|
-
|
147
|
+
document_model: Type[D],
|
224
148
|
collection_name: Optional[str] = None,
|
225
149
|
similarity_threshold: float = 0.37,
|
226
150
|
result_per_query: int = 10,
|
227
|
-
) -> List[
|
228
|
-
|
229
|
-
@overload
|
230
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
231
|
-
self,
|
232
|
-
vecs: List[List[float]],
|
233
|
-
desired_fields: str,
|
234
|
-
collection_name: Optional[str] = None,
|
235
|
-
similarity_threshold: float = 0.37,
|
236
|
-
result_per_query: int = 10,
|
237
|
-
) -> List[V]: ...
|
238
|
-
async def afetch_document[V: (int, str, float, bytes)](
|
239
|
-
self,
|
240
|
-
vecs: List[List[float]],
|
241
|
-
desired_fields: List[str] | str,
|
242
|
-
collection_name: Optional[str] = None,
|
243
|
-
similarity_threshold: float = 0.37,
|
244
|
-
result_per_query: int = 10,
|
245
|
-
) -> List[Dict[str, Any]] | List[V]:
|
246
|
-
"""Fetch data from the collection.
|
151
|
+
) -> List[D]:
|
152
|
+
"""Asynchronously fetches documents from a Milvus database based on input vectors.
|
247
153
|
|
248
154
|
Args:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
155
|
+
vecs (List[List[float]]): A list of vectors to search for in the database.
|
156
|
+
document_model (Type[D]): The model class used to convert fetched data into document objects.
|
157
|
+
collection_name (Optional[str]): The name of the collection to search within.
|
158
|
+
If None, the currently viewed collection is used.
|
159
|
+
similarity_threshold (float): The similarity threshold for vector search. Defaults to 0.37.
|
160
|
+
result_per_query (int): The maximum number of results to return per query. Defaults to 10.
|
254
161
|
|
255
162
|
Returns:
|
256
|
-
|
163
|
+
List[D]: A list of document objects created from the fetched data.
|
257
164
|
"""
|
258
165
|
# Step 1: Search for vectors
|
259
166
|
search_results = self.check_client().client.search(
|
260
167
|
collection_name or self.safe_target_collection,
|
261
168
|
vecs,
|
262
169
|
search_params={"radius": similarity_threshold},
|
263
|
-
output_fields=
|
170
|
+
output_fields=list(document_model.model_fields),
|
264
171
|
limit=result_per_query,
|
265
172
|
)
|
266
173
|
|
@@ -270,20 +177,20 @@ class RAG(EmbeddingUsage):
|
|
270
177
|
# Step 3: Sort by distance (descending)
|
271
178
|
sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
|
272
179
|
|
273
|
-
logger.debug(
|
180
|
+
logger.debug(
|
181
|
+
f"Fetched {len(sorted_results)} document,searched similarities: {[t['distance'] for t in sorted_results]}"
|
182
|
+
)
|
274
183
|
# Step 4: Extract the entities
|
275
184
|
resp = [result["entity"] for result in sorted_results]
|
276
185
|
|
277
|
-
|
278
|
-
return resp
|
279
|
-
return [r.get(desired_fields) for r in resp] # extract the single field as list
|
186
|
+
return document_model.from_sequence(resp)
|
280
187
|
|
281
|
-
async def aretrieve(
|
188
|
+
async def aretrieve[D: MilvusDataBase](
|
282
189
|
self,
|
283
190
|
query: List[str] | str,
|
284
191
|
final_limit: int = 20,
|
285
|
-
**kwargs: Unpack[FetchKwargs],
|
286
|
-
) -> List[
|
192
|
+
**kwargs: Unpack[FetchKwargs[D]],
|
193
|
+
) -> List[D]:
|
287
194
|
"""Retrieve data from the collection.
|
288
195
|
|
289
196
|
Args:
|
@@ -292,82 +199,17 @@ class RAG(EmbeddingUsage):
|
|
292
199
|
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
293
200
|
|
294
201
|
Returns:
|
295
|
-
List[
|
202
|
+
List[D]: A list of document objects created from the retrieved data.
|
296
203
|
"""
|
297
204
|
if isinstance(query, str):
|
298
205
|
query = [query]
|
299
|
-
return
|
300
|
-
"List[str]",
|
206
|
+
return (
|
301
207
|
await self.afetch_document(
|
302
208
|
vecs=(await self.vectorize(query)),
|
303
|
-
desired_fields="text",
|
304
209
|
**kwargs,
|
305
|
-
)
|
210
|
+
)
|
306
211
|
)[:final_limit]
|
307
212
|
|
308
|
-
async def aretrieve_compact(
|
309
|
-
self,
|
310
|
-
query: List[str] | str,
|
311
|
-
**kwargs: Unpack[RetrievalKwargs],
|
312
|
-
) -> str:
|
313
|
-
"""Retrieve data from the collection and format it for display.
|
314
|
-
|
315
|
-
Args:
|
316
|
-
query (List[str] | str): The query to be used for retrieval.
|
317
|
-
**kwargs (Unpack[RetrievalKwargs]): Additional keyword arguments for retrieval.
|
318
|
-
|
319
|
-
Returns:
|
320
|
-
str: A formatted string containing the retrieved data.
|
321
|
-
"""
|
322
|
-
return TEMPLATE_MANAGER.render_template(
|
323
|
-
configs.templates.retrieved_display_template, {"docs": (await self.aretrieve(query, **kwargs))}
|
324
|
-
)
|
325
|
-
|
326
|
-
async def aask_retrieved(
|
327
|
-
self,
|
328
|
-
question: str,
|
329
|
-
query: Optional[List[str] | str] = None,
|
330
|
-
collection_name: Optional[str] = None,
|
331
|
-
extra_system_message: str = "",
|
332
|
-
result_per_query: int = 10,
|
333
|
-
final_limit: int = 20,
|
334
|
-
similarity_threshold: float = 0.37,
|
335
|
-
**kwargs: Unpack[LLMKwargs],
|
336
|
-
) -> str:
|
337
|
-
"""Asks a question by retrieving relevant documents based on the provided query.
|
338
|
-
|
339
|
-
This method performs document retrieval using the given query, then asks the
|
340
|
-
specified question using the retrieved documents as context.
|
341
|
-
|
342
|
-
Args:
|
343
|
-
question (str): The question to be asked.
|
344
|
-
query (List[str] | str): The query or list of queries used for document retrieval.
|
345
|
-
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
346
|
-
If not provided, the currently viewed collection is used.
|
347
|
-
extra_system_message (str): An additional system message to be included in the prompt.
|
348
|
-
result_per_query (int): The number of results to return per query. Default is 10.
|
349
|
-
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
350
|
-
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
351
|
-
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
352
|
-
|
353
|
-
Returns:
|
354
|
-
str: A string response generated after asking with the context of retrieved documents.
|
355
|
-
"""
|
356
|
-
rendered = await self.aretrieve_compact(
|
357
|
-
query or question,
|
358
|
-
final_limit=final_limit,
|
359
|
-
collection_name=collection_name,
|
360
|
-
result_per_query=result_per_query,
|
361
|
-
similarity_threshold=similarity_threshold,
|
362
|
-
)
|
363
|
-
|
364
|
-
logger.debug(f"Retrieved Documents: \n{rendered}")
|
365
|
-
return await self.aask(
|
366
|
-
question,
|
367
|
-
f"{rendered}\n\n{extra_system_message}",
|
368
|
-
**kwargs,
|
369
|
-
)
|
370
|
-
|
371
213
|
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> Optional[List[str]]:
|
372
214
|
"""Refines the given question using a template.
|
373
215
|
|
@@ -385,38 +227,3 @@ class RAG(EmbeddingUsage):
|
|
385
227
|
),
|
386
228
|
**kwargs,
|
387
229
|
)
|
388
|
-
|
389
|
-
async def aask_refined(
|
390
|
-
self,
|
391
|
-
question: str,
|
392
|
-
collection_name: Optional[str] = None,
|
393
|
-
extra_system_message: str = "",
|
394
|
-
result_per_query: int = 10,
|
395
|
-
final_limit: int = 20,
|
396
|
-
similarity_threshold: float = 0.37,
|
397
|
-
**kwargs: Unpack[LLMKwargs],
|
398
|
-
) -> str:
|
399
|
-
"""Asks a question using a refined query based on the provided question.
|
400
|
-
|
401
|
-
Args:
|
402
|
-
question (str): The question to be asked.
|
403
|
-
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
404
|
-
extra_system_message (str): An additional system message to be included in the prompt.
|
405
|
-
result_per_query (int): The number of results to return per query. Default is 10.
|
406
|
-
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
407
|
-
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
408
|
-
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
409
|
-
|
410
|
-
Returns:
|
411
|
-
str: A string response generated after asking with the refined question.
|
412
|
-
"""
|
413
|
-
return await self.aask_retrieved(
|
414
|
-
question,
|
415
|
-
await self.arefined_query(question, **kwargs),
|
416
|
-
collection_name=collection_name,
|
417
|
-
extra_system_message=extra_system_message,
|
418
|
-
result_per_query=result_per_query,
|
419
|
-
final_limit=final_limit,
|
420
|
-
similarity_threshold=similarity_threshold,
|
421
|
-
**kwargs,
|
422
|
-
)
|