fabricatio 0.2.6.dev2__cp312-cp312-win_amd64.whl → 0.2.7.dev2__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +7 -24
- fabricatio/_rust.cp312-win_amd64.pyd +0 -0
- fabricatio/_rust.pyi +22 -0
- fabricatio/actions/article.py +147 -19
- fabricatio/actions/output.py +21 -6
- fabricatio/actions/rag.py +51 -3
- fabricatio/capabilities/correct.py +34 -4
- fabricatio/capabilities/rag.py +67 -16
- fabricatio/capabilities/rating.py +15 -6
- fabricatio/capabilities/review.py +7 -4
- fabricatio/capabilities/task.py +5 -5
- fabricatio/config.py +29 -21
- fabricatio/decorators.py +32 -0
- fabricatio/models/action.py +117 -43
- fabricatio/models/extra.py +724 -84
- fabricatio/models/generic.py +60 -9
- fabricatio/models/kwargs_types.py +40 -10
- fabricatio/models/role.py +30 -6
- fabricatio/models/tool.py +6 -2
- fabricatio/models/usages.py +94 -47
- fabricatio/models/utils.py +25 -0
- fabricatio/parser.py +2 -0
- fabricatio/workflows/articles.py +12 -1
- fabricatio-0.2.7.dev2.data/scripts/tdown.exe +0 -0
- {fabricatio-0.2.6.dev2.dist-info → fabricatio-0.2.7.dev2.dist-info}/METADATA +6 -2
- fabricatio-0.2.7.dev2.dist-info/RECORD +42 -0
- {fabricatio-0.2.6.dev2.dist-info → fabricatio-0.2.7.dev2.dist-info}/WHEEL +1 -1
- fabricatio-0.2.6.dev2.data/scripts/tdown.exe +0 -0
- fabricatio-0.2.6.dev2.dist-info/RECORD +0 -42
- {fabricatio-0.2.6.dev2.dist-info → fabricatio-0.2.7.dev2.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -2,59 +2,42 @@
|
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
5
|
+
from fabricatio import actions, toolboxes, workflows
|
5
6
|
from fabricatio._rust import BibManager
|
6
7
|
from fabricatio._rust_instances import TEMPLATE_MANAGER
|
7
|
-
from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
|
8
|
-
from fabricatio.actions.output import DumpFinalizedOutput
|
9
8
|
from fabricatio.core import env
|
10
|
-
from fabricatio.fs import MAGIKA, safe_json_read, safe_text_read
|
11
9
|
from fabricatio.journal import logger
|
10
|
+
from fabricatio.models import extra
|
12
11
|
from fabricatio.models.action import Action, WorkFlow
|
13
12
|
from fabricatio.models.events import Event
|
14
|
-
from fabricatio.models.extra import ArticleEssence
|
15
13
|
from fabricatio.models.role import Role
|
16
14
|
from fabricatio.models.task import Task
|
17
15
|
from fabricatio.models.tool import ToolBox
|
18
|
-
from fabricatio.models.utils import Message, Messages
|
19
16
|
from fabricatio.parser import Capture, GenericCapture, JsonCapture, PythonCapture
|
20
|
-
from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
21
|
-
from fabricatio.workflows.articles import WriteOutlineWorkFlow
|
22
17
|
|
23
18
|
__all__ = [
|
24
|
-
"MAGIKA",
|
25
19
|
"TEMPLATE_MANAGER",
|
26
20
|
"Action",
|
27
|
-
"ArticleEssence",
|
28
21
|
"BibManager",
|
29
22
|
"Capture",
|
30
|
-
"DumpFinalizedOutput",
|
31
23
|
"Event",
|
32
|
-
"ExtractArticleEssence",
|
33
|
-
"GenerateArticleProposal",
|
34
|
-
"GenerateOutline",
|
35
24
|
"GenericCapture",
|
36
25
|
"JsonCapture",
|
37
|
-
"Message",
|
38
|
-
"Messages",
|
39
26
|
"PythonCapture",
|
40
27
|
"Role",
|
41
28
|
"Task",
|
42
29
|
"ToolBox",
|
43
30
|
"WorkFlow",
|
44
|
-
"
|
45
|
-
"arithmetic_toolbox",
|
46
|
-
"basic_toolboxes",
|
31
|
+
"actions",
|
47
32
|
"env",
|
48
|
-
"
|
33
|
+
"extra",
|
49
34
|
"logger",
|
50
|
-
"
|
51
|
-
"
|
35
|
+
"toolboxes",
|
36
|
+
"workflows",
|
52
37
|
]
|
53
38
|
|
54
39
|
|
55
40
|
if find_spec("pymilvus"):
|
56
|
-
from fabricatio.actions.rag import InjectToDB
|
57
41
|
from fabricatio.capabilities.rag import RAG
|
58
|
-
from fabricatio.workflows.rag import StoreArticle
|
59
42
|
|
60
|
-
__all__ += ["RAG"
|
43
|
+
__all__ += ["RAG"]
|
Binary file
|
fabricatio/_rust.pyi
CHANGED
@@ -9,6 +9,7 @@ class TemplateManager:
|
|
9
9
|
|
10
10
|
See: https://crates.io/crates/handlebars
|
11
11
|
"""
|
12
|
+
|
12
13
|
def __init__(
|
13
14
|
self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
|
14
15
|
) -> None:
|
@@ -54,6 +55,17 @@ class TemplateManager:
|
|
54
55
|
RuntimeError: If template rendering fails
|
55
56
|
"""
|
56
57
|
|
58
|
+
def render_template_raw(self, template: str, data: Dict[str, Any]) -> str:
|
59
|
+
"""Render a template with context data.
|
60
|
+
|
61
|
+
Args:
|
62
|
+
template: The template string
|
63
|
+
data: Context dictionary to provide variables to the template
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
Rendered template content as string
|
67
|
+
"""
|
68
|
+
|
57
69
|
def blake3_hash(content: bytes) -> str:
|
58
70
|
"""Calculate the BLAKE3 cryptographic hash of data.
|
59
71
|
|
@@ -100,3 +112,13 @@ class BibManager:
|
|
100
112
|
Uses nucleo_matcher for high-quality fuzzy text searching
|
101
113
|
See: https://crates.io/crates/nucleo-matcher
|
102
114
|
"""
|
115
|
+
|
116
|
+
def list_titles(self, is_verbatim: Optional[bool] = False) -> List[str]:
|
117
|
+
"""List all titles in the bibliography.
|
118
|
+
|
119
|
+
Args:
|
120
|
+
is_verbatim: Whether to return verbatim titles (without formatting)
|
121
|
+
|
122
|
+
Returns:
|
123
|
+
List of all titles in the bibliography
|
124
|
+
"""
|
fabricatio/actions/article.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
"""Actions for transmitting tasks to targets."""
|
2
2
|
|
3
|
-
from os import PathLike
|
4
3
|
from pathlib import Path
|
5
|
-
from typing import Callable, List, Optional
|
4
|
+
from typing import Any, Callable, List, Optional
|
6
5
|
|
7
6
|
from fabricatio.fs import safe_text_read
|
8
7
|
from fabricatio.journal import logger
|
9
8
|
from fabricatio.models.action import Action
|
10
|
-
from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
|
9
|
+
from fabricatio.models.extra import Article, ArticleEssence, ArticleOutline, ArticleProposal
|
11
10
|
from fabricatio.models.task import Task
|
11
|
+
from fabricatio.models.utils import ok
|
12
12
|
|
13
13
|
|
14
14
|
class ExtractArticleEssence(Action):
|
@@ -22,10 +22,10 @@ class ExtractArticleEssence(Action):
|
|
22
22
|
output_key: str = "article_essence"
|
23
23
|
"""The key of the output data."""
|
24
24
|
|
25
|
-
async def _execute
|
25
|
+
async def _execute(
|
26
26
|
self,
|
27
27
|
task_input: Task,
|
28
|
-
reader: Callable[[
|
28
|
+
reader: Callable[[str], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
29
29
|
**_,
|
30
30
|
) -> Optional[List[ArticleEssence]]:
|
31
31
|
if not task_input.dependencies:
|
@@ -49,24 +49,39 @@ class GenerateArticleProposal(Action):
|
|
49
49
|
|
50
50
|
async def _execute(
|
51
51
|
self,
|
52
|
-
task_input: Task,
|
52
|
+
task_input: Optional[Task] = None,
|
53
|
+
article_briefing: Optional[str] = None,
|
54
|
+
article_briefing_path: Optional[str] = None,
|
53
55
|
**_,
|
54
56
|
) -> Optional[ArticleProposal]:
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
return
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
57
|
+
if article_briefing is None and article_briefing_path is None and task_input is None:
|
58
|
+
logger.error("Task not approved, since all inputs are None.")
|
59
|
+
return None
|
60
|
+
|
61
|
+
return (
|
62
|
+
await self.propose(
|
63
|
+
ArticleProposal,
|
64
|
+
briefing := (
|
65
|
+
article_briefing
|
66
|
+
or safe_text_read(
|
67
|
+
ok(
|
68
|
+
article_briefing_path
|
69
|
+
or await self.awhich_pathstr(
|
70
|
+
f"{task_input.briefing}\nExtract the path of file which contains the article briefing."
|
71
|
+
),
|
72
|
+
"Could not find the path of file to read.",
|
73
|
+
)
|
74
|
+
)
|
75
|
+
),
|
76
|
+
**self.prepend_sys_msg(),
|
77
|
+
)
|
78
|
+
).update_ref(briefing)
|
64
79
|
|
65
80
|
|
66
81
|
class GenerateOutline(Action):
|
67
82
|
"""Generate the article based on the outline."""
|
68
83
|
|
69
|
-
output_key: str = "
|
84
|
+
output_key: str = "article_outline"
|
70
85
|
"""The key of the output data."""
|
71
86
|
|
72
87
|
async def _execute(
|
@@ -74,8 +89,121 @@ class GenerateOutline(Action):
|
|
74
89
|
article_proposal: ArticleProposal,
|
75
90
|
**_,
|
76
91
|
) -> Optional[ArticleOutline]:
|
77
|
-
|
92
|
+
out = await self.propose(
|
78
93
|
ArticleOutline,
|
79
|
-
article_proposal.
|
80
|
-
|
94
|
+
article_proposal.as_prompt(),
|
95
|
+
**self.prepend_sys_msg(),
|
81
96
|
)
|
97
|
+
|
98
|
+
manual = await self.draft_rating_manual(
|
99
|
+
topic=(
|
100
|
+
topic
|
101
|
+
:= "Fix the internal referring error, make sure there is no more `ArticleRef` pointing to a non-existing article component."
|
102
|
+
),
|
103
|
+
)
|
104
|
+
while err := out.resolve_ref_error():
|
105
|
+
logger.warning(f"Found error in the outline: \n{err}")
|
106
|
+
out = await self.correct_obj(
|
107
|
+
out,
|
108
|
+
reference=f"# Referring Error\n{err}",
|
109
|
+
topic=topic,
|
110
|
+
rating_manual=manual,
|
111
|
+
supervisor_check=False,
|
112
|
+
)
|
113
|
+
return out.update_ref(article_proposal)
|
114
|
+
|
115
|
+
|
116
|
+
class CorrectProposal(Action):
|
117
|
+
"""Correct the proposal of the article."""
|
118
|
+
|
119
|
+
output_key: str = "corrected_proposal"
|
120
|
+
|
121
|
+
async def _execute(self, article_proposal: ArticleProposal, **_) -> Any:
|
122
|
+
return (await self.censor_obj(article_proposal, reference=article_proposal.referenced)).update_ref(
|
123
|
+
article_proposal
|
124
|
+
)
|
125
|
+
|
126
|
+
|
127
|
+
class CorrectOutline(Action):
|
128
|
+
"""Correct the outline of the article."""
|
129
|
+
|
130
|
+
output_key: str = "corrected_outline"
|
131
|
+
"""The key of the output data."""
|
132
|
+
|
133
|
+
async def _execute(
|
134
|
+
self,
|
135
|
+
article_outline: ArticleOutline,
|
136
|
+
**_,
|
137
|
+
) -> ArticleOutline:
|
138
|
+
return (await self.censor_obj(article_outline, reference=article_outline.referenced.as_prompt())).update_ref(
|
139
|
+
article_outline
|
140
|
+
)
|
141
|
+
|
142
|
+
|
143
|
+
class GenerateArticle(Action):
|
144
|
+
"""Generate the article based on the outline."""
|
145
|
+
|
146
|
+
output_key: str = "article"
|
147
|
+
"""The key of the output data."""
|
148
|
+
|
149
|
+
async def _execute(
|
150
|
+
self,
|
151
|
+
article_outline: ArticleOutline,
|
152
|
+
**_,
|
153
|
+
) -> Optional[Article]:
|
154
|
+
article: Article = Article.from_outline(article_outline).update_ref(article_outline)
|
155
|
+
|
156
|
+
writing_manual = await self.draft_rating_manual(
|
157
|
+
topic=(
|
158
|
+
topic_1
|
159
|
+
:= "improve the content of the subsection to fit the outline. SHALL never add or remove any section or subsection, you can only add or delete paragraphs within the subsection."
|
160
|
+
),
|
161
|
+
)
|
162
|
+
err_resolve_manual = await self.draft_rating_manual(
|
163
|
+
topic=(topic_2 := "this article component has violated the constrain, please correct it.")
|
164
|
+
)
|
165
|
+
for c, deps in article.iter_dfs_with_deps(chapter=False):
|
166
|
+
logger.info(f"Updating the article component: \n{c.display()}")
|
167
|
+
|
168
|
+
out = ok(
|
169
|
+
await self.correct_obj(
|
170
|
+
c,
|
171
|
+
reference=(
|
172
|
+
ref := f"{article_outline.referenced.as_prompt()}\n" + "\n".join(d.display() for d in deps)
|
173
|
+
),
|
174
|
+
topic=topic_1,
|
175
|
+
rating_manual=writing_manual,
|
176
|
+
supervisor_check=False,
|
177
|
+
),
|
178
|
+
"Could not correct the article component.",
|
179
|
+
)
|
180
|
+
while err := c.resolve_update_error(out):
|
181
|
+
logger.warning(f"Found error in the article component: \n{err}")
|
182
|
+
out = ok(
|
183
|
+
await self.correct_obj(
|
184
|
+
out,
|
185
|
+
reference=f"{ref}\n\n# Violated Error\n{err}",
|
186
|
+
topic=topic_2,
|
187
|
+
rating_manual=err_resolve_manual,
|
188
|
+
supervisor_check=False,
|
189
|
+
),
|
190
|
+
"Could not correct the article component.",
|
191
|
+
)
|
192
|
+
|
193
|
+
c.update_from(out)
|
194
|
+
return article
|
195
|
+
|
196
|
+
|
197
|
+
class CorrectArticle(Action):
|
198
|
+
"""Correct the article based on the outline."""
|
199
|
+
|
200
|
+
output_key: str = "corrected_article"
|
201
|
+
"""The key of the output data."""
|
202
|
+
|
203
|
+
async def _execute(
|
204
|
+
self,
|
205
|
+
article: Article,
|
206
|
+
article_outline: ArticleOutline,
|
207
|
+
**_,
|
208
|
+
) -> Article:
|
209
|
+
return await self.censor_obj(article, reference=article_outline.referenced.as_prompt())
|
fabricatio/actions/output.py
CHANGED
@@ -1,8 +1,12 @@
|
|
1
1
|
"""Dump the finalized output to a file."""
|
2
2
|
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Optional
|
5
|
+
|
3
6
|
from fabricatio.models.action import Action
|
4
7
|
from fabricatio.models.generic import FinalizedDumpAble
|
5
8
|
from fabricatio.models.task import Task
|
9
|
+
from fabricatio.models.utils import ok
|
6
10
|
|
7
11
|
|
8
12
|
class DumpFinalizedOutput(Action):
|
@@ -10,10 +14,21 @@ class DumpFinalizedOutput(Action):
|
|
10
14
|
|
11
15
|
output_key: str = "dump_path"
|
12
16
|
|
13
|
-
async def _execute(
|
14
|
-
|
15
|
-
|
17
|
+
async def _execute(
|
18
|
+
self,
|
19
|
+
to_dump: FinalizedDumpAble,
|
20
|
+
task_input: Optional[Task] = None,
|
21
|
+
dump_path: Optional[str | Path] = None,
|
22
|
+
**_,
|
23
|
+
) -> str:
|
24
|
+
dump_path = Path(
|
25
|
+
dump_path
|
26
|
+
or ok(
|
27
|
+
await self.awhich_pathstr(
|
28
|
+
f"{ok(task_input, 'Neither `task_input` and `dump_path` is provided.').briefing}\n\nExtract a single path of the file, to which I will dump the data."
|
29
|
+
),
|
30
|
+
"Could not find the path of file to dump the data.",
|
31
|
+
)
|
16
32
|
)
|
17
|
-
|
18
|
-
|
19
|
-
return dump_path
|
33
|
+
ok(to_dump, "Could not dump the data since the path is not specified.").finalized_dump_to(dump_path)
|
34
|
+
return dump_path.as_posix()
|
fabricatio/actions/rag.py
CHANGED
@@ -3,8 +3,11 @@
|
|
3
3
|
from typing import List, Optional
|
4
4
|
|
5
5
|
from fabricatio.capabilities.rag import RAG
|
6
|
+
from fabricatio.journal import logger
|
6
7
|
from fabricatio.models.action import Action
|
7
8
|
from fabricatio.models.generic import PrepareVectorization
|
9
|
+
from fabricatio.models.task import Task
|
10
|
+
from questionary import text
|
8
11
|
|
9
12
|
|
10
13
|
class InjectToDB(Action, RAG):
|
@@ -13,13 +16,58 @@ class InjectToDB(Action, RAG):
|
|
13
16
|
output_key: str = "collection_name"
|
14
17
|
|
15
18
|
async def _execute[T: PrepareVectorization](
|
16
|
-
self, to_inject: T | List[T], collection_name:
|
19
|
+
self, to_inject: Optional[T] | List[Optional[T]], collection_name: str = "my_collection",override_inject:bool=False, **_
|
17
20
|
) -> Optional[str]:
|
18
21
|
if not isinstance(to_inject, list):
|
19
22
|
to_inject = [to_inject]
|
20
|
-
|
23
|
+
logger.info(f"Injecting {len(to_inject)} items into the collection '{collection_name}'")
|
24
|
+
if override_inject:
|
25
|
+
self.check_client().client.drop_collection(collection_name)
|
21
26
|
await self.view(collection_name, create=True).consume_string(
|
22
|
-
[
|
27
|
+
[
|
28
|
+
t.prepare_vectorization(self.embedding_max_sequence_length)
|
29
|
+
for t in to_inject
|
30
|
+
if isinstance(t, PrepareVectorization)
|
31
|
+
],
|
23
32
|
)
|
24
33
|
|
25
34
|
return collection_name
|
35
|
+
|
36
|
+
|
37
|
+
class RAGTalk(Action, RAG):
|
38
|
+
"""RAG-enabled conversational action that processes user questions based on a given task.
|
39
|
+
|
40
|
+
This action establishes an interactive conversation loop where it retrieves context-relevant
|
41
|
+
information to answer user queries according to the assigned task briefing.
|
42
|
+
|
43
|
+
Notes:
|
44
|
+
task_input: Task briefing that guides how to respond to user questions
|
45
|
+
collection_name: Name of the vector collection to use for retrieval (default: "my_collection")
|
46
|
+
|
47
|
+
Returns:
|
48
|
+
Number of conversation turns completed before termination
|
49
|
+
"""
|
50
|
+
|
51
|
+
output_key: str = "task_output"
|
52
|
+
|
53
|
+
async def _execute(self, task_input: Task[str], **kwargs) -> int:
|
54
|
+
collection_name = kwargs.get("collection_name", "my_collection")
|
55
|
+
counter = 0
|
56
|
+
|
57
|
+
self.view(collection_name, create=True)
|
58
|
+
|
59
|
+
try:
|
60
|
+
while True:
|
61
|
+
user_say = await text("User: ").ask_async()
|
62
|
+
if user_say is None:
|
63
|
+
break
|
64
|
+
gpt_say = await self.aask_retrieved(
|
65
|
+
user_say,
|
66
|
+
user_say,
|
67
|
+
extra_system_message=f"You have to answer to user obeying task assigned to you:\n{task_input.briefing}",
|
68
|
+
)
|
69
|
+
print(f"GPT: {gpt_say}") # noqa: T201
|
70
|
+
counter += 1
|
71
|
+
except KeyboardInterrupt:
|
72
|
+
logger.info(f"executed talk action {counter} times")
|
73
|
+
return counter
|
@@ -10,9 +10,11 @@ from typing import Optional, Unpack, cast
|
|
10
10
|
from fabricatio._rust_instances import TEMPLATE_MANAGER
|
11
11
|
from fabricatio.capabilities.review import Review, ReviewResult
|
12
12
|
from fabricatio.config import configs
|
13
|
-
from fabricatio.models.generic import Display, ProposedAble, WithBriefing
|
14
|
-
from fabricatio.models.kwargs_types import CorrectKwargs, ReviewKwargs
|
13
|
+
from fabricatio.models.generic import CensoredAble, Display, ProposedAble, WithBriefing
|
14
|
+
from fabricatio.models.kwargs_types import CensoredCorrectKwargs, CorrectKwargs, ReviewKwargs
|
15
15
|
from fabricatio.models.task import Task
|
16
|
+
from questionary import confirm, text
|
17
|
+
from rich import print as rprint
|
16
18
|
|
17
19
|
|
18
20
|
class Correct(Review):
|
@@ -55,7 +57,7 @@ class Correct(Review):
|
|
55
57
|
if supervisor_check:
|
56
58
|
await review_res.supervisor_check()
|
57
59
|
if "default" in kwargs:
|
58
|
-
cast(ReviewKwargs[None], kwargs)["default"] = None
|
60
|
+
cast("ReviewKwargs[None]", kwargs)["default"] = None
|
59
61
|
return await self.propose(
|
60
62
|
obj.__class__,
|
61
63
|
TEMPLATE_MANAGER.render_template(
|
@@ -89,7 +91,7 @@ class Correct(Review):
|
|
89
91
|
await review_res.supervisor_check()
|
90
92
|
|
91
93
|
if "default" in kwargs:
|
92
|
-
cast(ReviewKwargs[None], kwargs)["default"] = None
|
94
|
+
cast("ReviewKwargs[None]", kwargs)["default"] = None
|
93
95
|
return await self.ageneric_string(
|
94
96
|
TEMPLATE_MANAGER.render_template(
|
95
97
|
configs.templates.correct_template, {"content": input_text, "review": review_res.display()}
|
@@ -113,3 +115,31 @@ class Correct(Review):
|
|
113
115
|
Optional[Task[T]]: The corrected task, or None if correction fails.
|
114
116
|
"""
|
115
117
|
return await self.correct_obj(task, **kwargs)
|
118
|
+
|
119
|
+
async def censor_obj[M: CensoredAble](
|
120
|
+
self, obj: M, **kwargs: Unpack[CensoredCorrectKwargs[ReviewResult[str]]]
|
121
|
+
) -> M:
|
122
|
+
"""Censor and correct an object based on defined criteria and templates.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
obj (M): The object to be reviewed and corrected.
|
126
|
+
**kwargs (Unpack[CensoredCorrectKwargs]): Additional keyword
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
M: The censored and corrected object.
|
130
|
+
"""
|
131
|
+
last_modified_obj = obj
|
132
|
+
modified_obj = None
|
133
|
+
rprint(obj.finalized_dump())
|
134
|
+
while await confirm("Begin to correct obj above with human censorship?").ask_async():
|
135
|
+
while (topic := await text("What is the topic of the obj reviewing?").ask_async()) is not None and topic:
|
136
|
+
...
|
137
|
+
if (modified_obj := await self.correct_obj(
|
138
|
+
last_modified_obj,
|
139
|
+
topic=topic,
|
140
|
+
**kwargs,
|
141
|
+
)) is None:
|
142
|
+
break
|
143
|
+
last_modified_obj = modified_obj
|
144
|
+
rprint(last_modified_obj.finalized_dump())
|
145
|
+
return modified_obj or last_modified_obj
|
fabricatio/capabilities/rag.py
CHANGED
@@ -15,13 +15,14 @@ from fabricatio.config import configs
|
|
15
15
|
from fabricatio.journal import logger
|
16
16
|
from fabricatio.models.kwargs_types import (
|
17
17
|
ChooseKwargs,
|
18
|
-
|
18
|
+
CollectionConfigKwargs,
|
19
19
|
EmbeddingKwargs,
|
20
20
|
FetchKwargs,
|
21
21
|
LLMKwargs,
|
22
|
+
RetrievalKwargs,
|
22
23
|
)
|
23
24
|
from fabricatio.models.usages import EmbeddingUsage
|
24
|
-
from fabricatio.models.utils import MilvusData
|
25
|
+
from fabricatio.models.utils import MilvusData, ok
|
25
26
|
from more_itertools.recipes import flatten, unique
|
26
27
|
from pydantic import Field, PrivateAttr
|
27
28
|
|
@@ -60,13 +61,21 @@ class RAG(EmbeddingUsage):
|
|
60
61
|
) -> Self:
|
61
62
|
"""Initialize the Milvus client."""
|
62
63
|
self._client = create_client(
|
63
|
-
uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
|
64
|
+
uri=milvus_uri or ok(self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
|
64
65
|
token=milvus_token
|
65
66
|
or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
|
66
|
-
timeout=milvus_timeout or self.milvus_timeout,
|
67
|
+
timeout=milvus_timeout or self.milvus_timeout or configs.rag.milvus_timeout,
|
67
68
|
)
|
68
69
|
return self
|
69
70
|
|
71
|
+
def check_client(self, init: bool = True) -> Self:
|
72
|
+
"""Check if the client is initialized, and if not, initialize it."""
|
73
|
+
if self._client is None and init:
|
74
|
+
return self.init_client()
|
75
|
+
if self._client is None and not init:
|
76
|
+
raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
77
|
+
return self
|
78
|
+
|
70
79
|
@overload
|
71
80
|
async def pack(
|
72
81
|
self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
|
@@ -102,17 +111,24 @@ class RAG(EmbeddingUsage):
|
|
102
111
|
]
|
103
112
|
|
104
113
|
def view(
|
105
|
-
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[
|
114
|
+
self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
|
106
115
|
) -> Self:
|
107
116
|
"""View the specified collection.
|
108
117
|
|
109
118
|
Args:
|
110
119
|
collection_name (str): The name of the collection.
|
111
120
|
create (bool): Whether to create the collection if it does not exist.
|
112
|
-
**kwargs (Unpack[
|
121
|
+
**kwargs (Unpack[CollectionConfigKwargs]): Additional keyword arguments for collection configuration.
|
113
122
|
"""
|
114
|
-
if create and collection_name and self.client.has_collection(collection_name):
|
115
|
-
kwargs["dimension"] =
|
123
|
+
if create and collection_name and not self.check_client().client.has_collection(collection_name):
|
124
|
+
kwargs["dimension"] = ok(
|
125
|
+
kwargs.get("dimension")
|
126
|
+
or self.milvus_dimensions
|
127
|
+
or configs.rag.milvus_dimensions
|
128
|
+
or self.embedding_dimensions
|
129
|
+
or configs.embedding.dimensions,
|
130
|
+
"`dimension` is not set at any level.",
|
131
|
+
)
|
116
132
|
self.client.create_collection(collection_name, auto_id=True, **kwargs)
|
117
133
|
logger.info(f"Creating collection {collection_name}")
|
118
134
|
|
@@ -158,7 +174,7 @@ class RAG(EmbeddingUsage):
|
|
158
174
|
else:
|
159
175
|
raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
|
160
176
|
c_name = collection_name or self.safe_target_collection
|
161
|
-
self.client.insert(c_name, prepared_data)
|
177
|
+
self.check_client().client.insert(c_name, prepared_data)
|
162
178
|
|
163
179
|
if flush:
|
164
180
|
logger.debug(f"Flushing collection {c_name}")
|
@@ -198,6 +214,25 @@ class RAG(EmbeddingUsage):
|
|
198
214
|
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
199
215
|
return self
|
200
216
|
|
217
|
+
@overload
|
218
|
+
async def afetch_document[V: (int, str, float, bytes)](
|
219
|
+
self,
|
220
|
+
vecs: List[List[float]],
|
221
|
+
desired_fields: List[str],
|
222
|
+
collection_name: Optional[str] = None,
|
223
|
+
similarity_threshold: float = 0.37,
|
224
|
+
result_per_query: int = 10,
|
225
|
+
) -> List[Dict[str, V]]: ...
|
226
|
+
|
227
|
+
@overload
|
228
|
+
async def afetch_document[V: (int, str, float, bytes)](
|
229
|
+
self,
|
230
|
+
vecs: List[List[float]],
|
231
|
+
desired_fields: str,
|
232
|
+
collection_name: Optional[str] = None,
|
233
|
+
similarity_threshold: float = 0.37,
|
234
|
+
result_per_query: int = 10,
|
235
|
+
) -> List[V]: ...
|
201
236
|
async def afetch_document[V: (int, str, float, bytes)](
|
202
237
|
self,
|
203
238
|
vecs: List[List[float]],
|
@@ -219,7 +254,7 @@ class RAG(EmbeddingUsage):
|
|
219
254
|
List[Dict[str, Any]] | List[Any]: The retrieved data.
|
220
255
|
"""
|
221
256
|
# Step 1: Search for vectors
|
222
|
-
search_results = self.client.search(
|
257
|
+
search_results = self.check_client().client.search(
|
223
258
|
collection_name or self.safe_target_collection,
|
224
259
|
vecs,
|
225
260
|
search_params={"radius": similarity_threshold},
|
@@ -260,7 +295,7 @@ class RAG(EmbeddingUsage):
|
|
260
295
|
if isinstance(query, str):
|
261
296
|
query = [query]
|
262
297
|
return cast(
|
263
|
-
List[str],
|
298
|
+
"List[str]",
|
264
299
|
await self.afetch_document(
|
265
300
|
vecs=(await self.vectorize(query)),
|
266
301
|
desired_fields="text",
|
@@ -268,6 +303,24 @@ class RAG(EmbeddingUsage):
|
|
268
303
|
),
|
269
304
|
)[:final_limit]
|
270
305
|
|
306
|
+
async def aretrieve_compact(
|
307
|
+
self,
|
308
|
+
query: List[str] | str,
|
309
|
+
**kwargs: Unpack[RetrievalKwargs],
|
310
|
+
) -> str:
|
311
|
+
"""Retrieve data from the collection and format it for display.
|
312
|
+
|
313
|
+
Args:
|
314
|
+
query (List[str] | str): The query to be used for retrieval.
|
315
|
+
**kwargs (Unpack[RetrievalKwargs]): Additional keyword arguments for retrieval.
|
316
|
+
|
317
|
+
Returns:
|
318
|
+
str: A formatted string containing the retrieved data.
|
319
|
+
"""
|
320
|
+
return TEMPLATE_MANAGER.render_template(
|
321
|
+
configs.templates.retrieved_display_template, {"docs": (await self.aretrieve(query, **kwargs))}
|
322
|
+
)
|
323
|
+
|
271
324
|
async def aask_retrieved(
|
272
325
|
self,
|
273
326
|
question: str,
|
@@ -298,16 +351,14 @@ class RAG(EmbeddingUsage):
|
|
298
351
|
Returns:
|
299
352
|
str: A string response generated after asking with the context of retrieved documents.
|
300
353
|
"""
|
301
|
-
|
354
|
+
rendered = await self.aretrieve_compact(
|
302
355
|
query or question,
|
303
|
-
final_limit,
|
356
|
+
final_limit=final_limit,
|
304
357
|
collection_name=collection_name,
|
305
358
|
result_per_query=result_per_query,
|
306
359
|
similarity_threshold=similarity_threshold,
|
307
360
|
)
|
308
361
|
|
309
|
-
rendered = TEMPLATE_MANAGER.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
|
310
|
-
|
311
362
|
logger.debug(f"Retrieved Documents: \n{rendered}")
|
312
363
|
return await self.aask(
|
313
364
|
question,
|
@@ -315,7 +366,7 @@ class RAG(EmbeddingUsage):
|
|
315
366
|
**kwargs,
|
316
367
|
)
|
317
368
|
|
318
|
-
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
|
369
|
+
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> Optional[List[str]]:
|
319
370
|
"""Refines the given question using a template.
|
320
371
|
|
321
372
|
Args:
|