fabricatio 0.2.4.dev2__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.5__cp312-cp312-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +14 -5
- fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
- fabricatio/_rust.pyi +65 -16
- fabricatio/_rust_instances.py +2 -0
- fabricatio/actions/article.py +46 -14
- fabricatio/actions/output.py +21 -0
- fabricatio/actions/rag.py +1 -1
- fabricatio/capabilities/propose.py +14 -20
- fabricatio/capabilities/rag.py +85 -26
- fabricatio/capabilities/rating.py +59 -51
- fabricatio/capabilities/review.py +241 -0
- fabricatio/capabilities/task.py +7 -8
- fabricatio/config.py +36 -4
- fabricatio/fs/__init__.py +13 -1
- fabricatio/fs/curd.py +27 -8
- fabricatio/fs/readers.py +6 -3
- fabricatio/journal.py +1 -1
- fabricatio/models/action.py +6 -8
- fabricatio/models/events.py +6 -4
- fabricatio/models/extra.py +100 -25
- fabricatio/models/generic.py +56 -4
- fabricatio/models/kwargs_types.py +123 -35
- fabricatio/models/role.py +3 -3
- fabricatio/models/task.py +0 -14
- fabricatio/models/tool.py +7 -6
- fabricatio/models/usages.py +144 -101
- fabricatio/parser.py +26 -5
- fabricatio/toolboxes/__init__.py +1 -3
- fabricatio/toolboxes/fs.py +17 -1
- fabricatio/workflows/articles.py +10 -6
- fabricatio/workflows/rag.py +11 -0
- fabricatio-0.2.5.data/scripts/tdown +0 -0
- {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/METADATA +2 -1
- fabricatio-0.2.5.dist-info/RECORD +41 -0
- fabricatio/toolboxes/task.py +0 -6
- fabricatio-0.2.4.dev2.data/scripts/tdown +0 -0
- fabricatio-0.2.4.dev2.dist-info/RECORD +0 -39
- {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -2,10 +2,12 @@
|
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
5
|
+
from fabricatio._rust import BibManager
|
5
6
|
from fabricatio._rust_instances import template_manager
|
6
|
-
from fabricatio.actions.article import ExtractArticleEssence
|
7
|
+
from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
|
8
|
+
from fabricatio.actions.output import DumpFinalizedOutput
|
7
9
|
from fabricatio.core import env
|
8
|
-
from fabricatio.fs import magika
|
10
|
+
from fabricatio.fs import magika, safe_json_read, safe_text_read
|
9
11
|
from fabricatio.journal import logger
|
10
12
|
from fabricatio.models.action import Action, WorkFlow
|
11
13
|
from fabricatio.models.events import Event
|
@@ -15,15 +17,20 @@ from fabricatio.models.task import Task
|
|
15
17
|
from fabricatio.models.tool import ToolBox
|
16
18
|
from fabricatio.models.utils import Message, Messages
|
17
19
|
from fabricatio.parser import Capture, CodeBlockCapture, JsonCapture, PythonCapture
|
18
|
-
from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
20
|
+
from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
21
|
+
from fabricatio.workflows.articles import WriteOutlineWorkFlow
|
19
22
|
|
20
23
|
__all__ = [
|
21
24
|
"Action",
|
22
25
|
"ArticleEssence",
|
26
|
+
"BibManager",
|
23
27
|
"Capture",
|
24
28
|
"CodeBlockCapture",
|
29
|
+
"DumpFinalizedOutput",
|
25
30
|
"Event",
|
26
31
|
"ExtractArticleEssence",
|
32
|
+
"GenerateArticleProposal",
|
33
|
+
"GenerateOutline",
|
27
34
|
"JsonCapture",
|
28
35
|
"Message",
|
29
36
|
"Messages",
|
@@ -32,13 +39,15 @@ __all__ = [
|
|
32
39
|
"Task",
|
33
40
|
"ToolBox",
|
34
41
|
"WorkFlow",
|
42
|
+
"WriteOutlineWorkFlow",
|
35
43
|
"arithmetic_toolbox",
|
36
44
|
"basic_toolboxes",
|
37
45
|
"env",
|
38
46
|
"fs_toolbox",
|
39
47
|
"logger",
|
40
48
|
"magika",
|
41
|
-
"
|
49
|
+
"safe_json_read",
|
50
|
+
"safe_text_read",
|
42
51
|
"template_manager",
|
43
52
|
]
|
44
53
|
|
@@ -46,6 +55,6 @@ __all__ = [
|
|
46
55
|
if find_spec("pymilvus"):
|
47
56
|
from fabricatio.actions.rag import InjectToDB
|
48
57
|
from fabricatio.capabilities.rag import RAG
|
49
|
-
from fabricatio.workflows.
|
58
|
+
from fabricatio.workflows.rag import StoreArticle
|
50
59
|
|
51
60
|
__all__ += ["RAG", "InjectToDB", "StoreArticle"]
|
Binary file
|
fabricatio/_rust.pyi
CHANGED
@@ -2,52 +2,101 @@ from pathlib import Path
|
|
2
2
|
from typing import Any, Dict, List, Optional
|
3
3
|
|
4
4
|
class TemplateManager:
|
5
|
-
"""
|
5
|
+
"""Template rendering engine using Handlebars templates.
|
6
|
+
|
7
|
+
This manager handles template discovery, loading, and rendering
|
8
|
+
through a wrapper around the handlebars-rust engine.
|
9
|
+
|
10
|
+
See: https://crates.io/crates/handlebars
|
11
|
+
"""
|
6
12
|
def __init__(
|
7
13
|
self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
|
8
14
|
) -> None:
|
9
15
|
"""Initialize the template manager.
|
10
16
|
|
11
17
|
Args:
|
12
|
-
template_dirs
|
13
|
-
suffix
|
14
|
-
active_loading
|
18
|
+
template_dirs: List of directories containing template files
|
19
|
+
suffix: File extension for templates (defaults to 'hbs')
|
20
|
+
active_loading: Whether to enable dev mode for reloading templates on change
|
15
21
|
"""
|
16
22
|
|
17
23
|
@property
|
18
24
|
def template_count(self) -> int:
|
19
|
-
"""
|
25
|
+
"""Returns the number of currently loaded templates."""
|
20
26
|
|
21
27
|
def get_template_source(self, name: str) -> Optional[str]:
|
22
|
-
"""Get the
|
28
|
+
"""Get the filesystem path for a template.
|
23
29
|
|
24
30
|
Args:
|
25
|
-
name
|
31
|
+
name: Template name (without extension)
|
26
32
|
|
27
33
|
Returns:
|
28
|
-
|
34
|
+
Path to the template file if found, None otherwise
|
29
35
|
"""
|
30
36
|
|
31
37
|
def discover_templates(self) -> None:
|
32
|
-
"""
|
38
|
+
"""Scan template directories and load available templates.
|
39
|
+
|
40
|
+
This refreshes the template cache, finding any new or modified templates.
|
41
|
+
"""
|
33
42
|
|
34
43
|
def render_template(self, name: str, data: Dict[str, Any]) -> str:
|
35
|
-
"""Render a template with
|
44
|
+
"""Render a template with context data.
|
36
45
|
|
37
46
|
Args:
|
38
|
-
name
|
39
|
-
data
|
47
|
+
name: Template name (without extension)
|
48
|
+
data: Context dictionary to provide variables to the template
|
40
49
|
|
41
50
|
Returns:
|
42
|
-
|
51
|
+
Rendered template content as string
|
52
|
+
|
53
|
+
Raises:
|
54
|
+
RuntimeError: If template rendering fails
|
43
55
|
"""
|
44
56
|
|
45
57
|
def blake3_hash(content: bytes) -> str:
|
46
|
-
"""Calculate the BLAKE3 hash of
|
58
|
+
"""Calculate the BLAKE3 cryptographic hash of data.
|
47
59
|
|
48
60
|
Args:
|
49
|
-
content
|
61
|
+
content: Bytes to be hashed
|
50
62
|
|
51
63
|
Returns:
|
52
|
-
|
64
|
+
Hex-encoded BLAKE3 hash string
|
53
65
|
"""
|
66
|
+
|
67
|
+
class BibManager:
|
68
|
+
"""BibTeX bibliography manager for parsing and querying citation data."""
|
69
|
+
|
70
|
+
def __init__(self, path: str) -> None:
|
71
|
+
"""Initialize the bibliography manager.
|
72
|
+
|
73
|
+
Args:
|
74
|
+
path: Path to BibTeX (.bib) file to load
|
75
|
+
|
76
|
+
Raises:
|
77
|
+
RuntimeError: If file cannot be read or parsed
|
78
|
+
"""
|
79
|
+
|
80
|
+
def get_cite_key(self, title: str) -> Optional[str]:
|
81
|
+
"""Find citation key by exact title match.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
title: Full title to search for (case-insensitive)
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
Citation key if exact match found, None otherwise
|
88
|
+
"""
|
89
|
+
|
90
|
+
def get_cite_key_fuzzy(self, query: str) -> Optional[str]:
|
91
|
+
"""Find best matching citation using fuzzy text search.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
query: Search term to find in bibliography entries
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
Citation key of best matching entry, or None if no good match
|
98
|
+
|
99
|
+
Notes:
|
100
|
+
Uses nucleo_matcher for high-quality fuzzy text searching
|
101
|
+
See: https://crates.io/crates/nucleo-matcher
|
102
|
+
"""
|
fabricatio/_rust_instances.py
CHANGED
fabricatio/actions/article.py
CHANGED
@@ -2,11 +2,12 @@
|
|
2
2
|
|
3
3
|
from os import PathLike
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Callable, List
|
5
|
+
from typing import Callable, List, Optional
|
6
6
|
|
7
|
+
from fabricatio.fs import safe_text_read
|
7
8
|
from fabricatio.journal import logger
|
8
9
|
from fabricatio.models.action import Action
|
9
|
-
from fabricatio.models.extra import ArticleEssence
|
10
|
+
from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
|
10
11
|
from fabricatio.models.task import Task
|
11
12
|
|
12
13
|
|
@@ -18,11 +19,6 @@ class ExtractArticleEssence(Action):
|
|
18
19
|
which is converted from pdf files using `magic-pdf` from the `MinerU` project, see https://github.com/opendatalab/MinerU
|
19
20
|
"""
|
20
21
|
|
21
|
-
name: str = "extract article essence"
|
22
|
-
"""The name of the action."""
|
23
|
-
description: str = "Extract the essence of article(s) from the paths specified in the task dependencies."
|
24
|
-
"""The description of the action."""
|
25
|
-
|
26
22
|
output_key: str = "article_essence"
|
27
23
|
"""The key of the output data."""
|
28
24
|
|
@@ -31,13 +27,9 @@ class ExtractArticleEssence(Action):
|
|
31
27
|
task_input: Task,
|
32
28
|
reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
33
29
|
**_,
|
34
|
-
) -> List[ArticleEssence]:
|
35
|
-
if not
|
36
|
-
|
37
|
-
affirm_case="The task does not violate the role, and could be approved since the file dependencies are specified.",
|
38
|
-
deny_case="The task does violate the role, and could not be approved.",
|
39
|
-
):
|
40
|
-
logger.info(err := "Task not approved.")
|
30
|
+
) -> Optional[List[ArticleEssence]]:
|
31
|
+
if not task_input.dependencies:
|
32
|
+
logger.info(err := "Task not approved, since no dependencies are provided.")
|
41
33
|
raise RuntimeError(err)
|
42
34
|
|
43
35
|
# trim the references
|
@@ -47,3 +39,43 @@ class ExtractArticleEssence(Action):
|
|
47
39
|
contents,
|
48
40
|
system_message=f"# your personal briefing: \n{self.briefing}",
|
49
41
|
)
|
42
|
+
|
43
|
+
|
44
|
+
class GenerateArticleProposal(Action):
|
45
|
+
"""Generate a proposal for the article based on the briefing file referenced by the task."""
|
46
|
+
|
47
|
+
output_key: str = "article_proposal"
|
48
|
+
"""The key of the output data."""
|
49
|
+
|
50
|
+
async def _execute(
|
51
|
+
self,
|
52
|
+
task_input: Task,
|
53
|
+
**_,
|
54
|
+
) -> Optional[ArticleProposal]:
|
55
|
+
input_path = await self.awhich_pathstr(
|
56
|
+
f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
|
57
|
+
)
|
58
|
+
|
59
|
+
return await self.propose(
|
60
|
+
ArticleProposal,
|
61
|
+
safe_text_read(input_path),
|
62
|
+
system_message=f"# your personal briefing: \n{self.briefing}",
|
63
|
+
)
|
64
|
+
|
65
|
+
|
66
|
+
class GenerateOutline(Action):
|
67
|
+
"""Generate the article outline based on the article proposal."""
|
68
|
+
|
69
|
+
output_key: str = "article"
|
70
|
+
"""The key of the output data."""
|
71
|
+
|
72
|
+
async def _execute(
|
73
|
+
self,
|
74
|
+
article_proposal: ArticleProposal,
|
75
|
+
**_,
|
76
|
+
) -> Optional[ArticleOutline]:
|
77
|
+
return await self.propose(
|
78
|
+
ArticleOutline,
|
79
|
+
article_proposal.display(),
|
80
|
+
system_message=f"# your personal briefing: \n{self.briefing}",
|
81
|
+
)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""Dump the finalized output to a file."""
|
2
|
+
|
3
|
+
from typing import Unpack
|
4
|
+
|
5
|
+
from fabricatio.models.action import Action
|
6
|
+
from fabricatio.models.generic import FinalizedDumpAble
|
7
|
+
from fabricatio.models.task import Task
|
8
|
+
|
9
|
+
|
10
|
+
class DumpFinalizedOutput(Action):
|
11
|
+
"""Dump the finalized output to a file."""
|
12
|
+
|
13
|
+
output_key: str = "dump_path"
|
14
|
+
|
15
|
+
async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **cxt: Unpack) -> str:
|
16
|
+
dump_path = await self.awhich_pathstr(
|
17
|
+
f"{task_input.briefing}\n\nExtract a single path of the file, to which I will dump the data."
|
18
|
+
)
|
19
|
+
|
20
|
+
to_dump.finalized_dump_to(dump_path)
|
21
|
+
return dump_path
|
fabricatio/actions/rag.py
CHANGED
@@ -14,7 +14,7 @@ class InjectToDB(Action, RAG):
|
|
14
14
|
|
15
15
|
async def _execute[T: PrepareVectorization](
|
16
16
|
self, to_inject: T | List[T], collection_name: Optional[str] = "my_collection", **cxt: Unpack
|
17
|
-
) -> str:
|
17
|
+
) -> Optional[str]:
|
18
18
|
if not isinstance(to_inject, list):
|
19
19
|
to_inject = [to_inject]
|
20
20
|
|
@@ -1,37 +1,37 @@
|
|
1
1
|
"""A module for the task capabilities of the Fabricatio library."""
|
2
2
|
|
3
|
-
from typing import List, Type, Unpack, overload
|
3
|
+
from typing import List, Optional, Type, Unpack, overload
|
4
4
|
|
5
5
|
from fabricatio.models.generic import ProposedAble
|
6
|
-
from fabricatio.models.kwargs_types import
|
6
|
+
from fabricatio.models.kwargs_types import ValidateKwargs
|
7
7
|
from fabricatio.models.usages import LLMUsage
|
8
8
|
|
9
9
|
|
10
|
-
class Propose
|
10
|
+
class Propose(LLMUsage):
|
11
11
|
"""A class that proposes an Obj based on a prompt."""
|
12
12
|
|
13
13
|
@overload
|
14
|
-
async def propose(
|
14
|
+
async def propose[M: ProposedAble](
|
15
15
|
self,
|
16
16
|
cls: Type[M],
|
17
17
|
prompt: List[str],
|
18
|
-
**kwargs: Unpack[
|
19
|
-
) -> List[M]: ...
|
18
|
+
**kwargs: Unpack[ValidateKwargs[M]],
|
19
|
+
) -> Optional[List[M]]: ...
|
20
20
|
|
21
21
|
@overload
|
22
|
-
async def propose(
|
22
|
+
async def propose[M: ProposedAble](
|
23
23
|
self,
|
24
24
|
cls: Type[M],
|
25
25
|
prompt: str,
|
26
|
-
**kwargs: Unpack[
|
27
|
-
) -> M: ...
|
26
|
+
**kwargs: Unpack[ValidateKwargs[M]],
|
27
|
+
) -> Optional[M]: ...
|
28
28
|
|
29
|
-
async def propose(
|
29
|
+
async def propose[M: ProposedAble](
|
30
30
|
self,
|
31
31
|
cls: Type[M],
|
32
32
|
prompt: List[str] | str,
|
33
|
-
**kwargs: Unpack[
|
34
|
-
) -> List[M] | M:
|
33
|
+
**kwargs: Unpack[ValidateKwargs[M]],
|
34
|
+
) -> Optional[List[M] | M]:
|
35
35
|
"""Asynchronously proposes an object of the given class based on a given prompt and parameters.
|
36
36
|
|
37
37
|
Parameters:
|
@@ -42,14 +42,8 @@ class Propose[M: ProposedAble](LLMUsage):
|
|
42
42
|
Returns:
|
43
43
|
The proposed instance of `cls` (or a list of instances when a list of prompts is given) built from the proposal result, or None.
|
44
44
|
"""
|
45
|
-
|
46
|
-
|
47
|
-
question=cls.create_json_prompt(prompt),
|
48
|
-
validator=cls.instantiate_from_string,
|
49
|
-
**kwargs,
|
50
|
-
)
|
51
|
-
return await self.aask_validate_batch(
|
52
|
-
questions=[cls.create_json_prompt(p) for p in prompt],
|
45
|
+
return await self.aask_validate(
|
46
|
+
question=cls.create_json_prompt(prompt),
|
53
47
|
validator=cls.instantiate_from_string,
|
54
48
|
**kwargs,
|
55
49
|
)
|
fabricatio/capabilities/rag.py
CHANGED
@@ -8,15 +8,21 @@ from functools import lru_cache
|
|
8
8
|
from operator import itemgetter
|
9
9
|
from os import PathLike
|
10
10
|
from pathlib import Path
|
11
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
|
12
12
|
|
13
13
|
from fabricatio._rust_instances import template_manager
|
14
14
|
from fabricatio.config import configs
|
15
15
|
from fabricatio.journal import logger
|
16
|
-
from fabricatio.models.kwargs_types import
|
16
|
+
from fabricatio.models.kwargs_types import (
|
17
|
+
ChooseKwargs,
|
18
|
+
CollectionSimpleConfigKwargs,
|
19
|
+
EmbeddingKwargs,
|
20
|
+
FetchKwargs,
|
21
|
+
LLMKwargs,
|
22
|
+
)
|
17
23
|
from fabricatio.models.usages import EmbeddingUsage
|
18
24
|
from fabricatio.models.utils import MilvusData
|
19
|
-
from more_itertools.recipes import flatten
|
25
|
+
from more_itertools.recipes import flatten, unique
|
20
26
|
from pydantic import Field, PrivateAttr
|
21
27
|
|
22
28
|
|
@@ -105,9 +111,9 @@ class RAG(EmbeddingUsage):
|
|
105
111
|
create (bool): Whether to create the collection if it does not exist.
|
106
112
|
**kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
|
107
113
|
"""
|
108
|
-
if create and collection_name and
|
114
|
+
if create and collection_name and self.client.has_collection(collection_name):
|
109
115
|
kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
|
110
|
-
self.
|
116
|
+
self.client.create_collection(collection_name, auto_id=True, **kwargs)
|
111
117
|
logger.info(f"Creating collection {collection_name}")
|
112
118
|
|
113
119
|
self.target_collection = collection_name
|
@@ -146,15 +152,17 @@ class RAG(EmbeddingUsage):
|
|
146
152
|
Self: The current instance, allowing for method chaining.
|
147
153
|
"""
|
148
154
|
if isinstance(data, MilvusData):
|
149
|
-
|
150
|
-
|
151
|
-
|
155
|
+
prepared_data = data.prepare_insertion()
|
156
|
+
elif isinstance(data, list):
|
157
|
+
prepared_data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
|
158
|
+
else:
|
159
|
+
raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
|
152
160
|
c_name = collection_name or self.safe_target_collection
|
153
|
-
self.
|
161
|
+
self.client.insert(c_name, prepared_data)
|
154
162
|
|
155
163
|
if flush:
|
156
164
|
logger.debug(f"Flushing collection {c_name}")
|
157
|
-
self.
|
165
|
+
self.client.flush(c_name)
|
158
166
|
return self
|
159
167
|
|
160
168
|
async def consume_file(
|
@@ -190,14 +198,14 @@ class RAG(EmbeddingUsage):
|
|
190
198
|
self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
|
191
199
|
return self
|
192
200
|
|
193
|
-
async def afetch_document(
|
201
|
+
async def afetch_document[V: (int, str, float, bytes)](
|
194
202
|
self,
|
195
203
|
vecs: List[List[float]],
|
196
204
|
desired_fields: List[str] | str,
|
197
205
|
collection_name: Optional[str] = None,
|
198
206
|
similarity_threshold: float = 0.37,
|
199
207
|
result_per_query: int = 10,
|
200
|
-
) -> List[Dict[str, Any]] | List[
|
208
|
+
) -> List[Dict[str, Any]] | List[V]:
|
201
209
|
"""Fetch data from the collection.
|
202
210
|
|
203
211
|
Args:
|
@@ -211,7 +219,7 @@ class RAG(EmbeddingUsage):
|
|
211
219
|
List[Dict[str, Any]] | List[Any]: The retrieved data.
|
212
220
|
"""
|
213
221
|
# Step 1: Search for vectors
|
214
|
-
search_results = self.
|
222
|
+
search_results = self.client.search(
|
215
223
|
collection_name or self.safe_target_collection,
|
216
224
|
vecs,
|
217
225
|
search_params={"radius": similarity_threshold},
|
@@ -221,9 +229,9 @@ class RAG(EmbeddingUsage):
|
|
221
229
|
|
222
230
|
# Step 2: Flatten the search results
|
223
231
|
flattened_results = flatten(search_results)
|
224
|
-
|
232
|
+
unique_results = unique(flattened_results, key=itemgetter("id"))
|
225
233
|
# Step 3: Sort by distance (descending)
|
226
|
-
sorted_results = sorted(
|
234
|
+
sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
|
227
235
|
|
228
236
|
logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
|
229
237
|
# Step 4: Extract the entities
|
@@ -231,12 +239,11 @@ class RAG(EmbeddingUsage):
|
|
231
239
|
|
232
240
|
if isinstance(desired_fields, list):
|
233
241
|
return resp
|
234
|
-
return [r.get(desired_fields) for r in resp]
|
242
|
+
return [r.get(desired_fields) for r in resp] # extract the single field as list
|
235
243
|
|
236
244
|
async def aretrieve(
|
237
245
|
self,
|
238
246
|
query: List[str] | str,
|
239
|
-
collection_name: Optional[str] = None,
|
240
247
|
final_limit: int = 20,
|
241
248
|
**kwargs: Unpack[FetchKwargs],
|
242
249
|
) -> List[str]:
|
@@ -244,7 +251,6 @@ class RAG(EmbeddingUsage):
|
|
244
251
|
|
245
252
|
Args:
|
246
253
|
query (List[str] | str): The query to be used for retrieval.
|
247
|
-
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
248
254
|
final_limit (int): The final limit on the number of results to return.
|
249
255
|
**kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
|
250
256
|
|
@@ -253,19 +259,19 @@ class RAG(EmbeddingUsage):
|
|
253
259
|
"""
|
254
260
|
if isinstance(query, str):
|
255
261
|
query = [query]
|
256
|
-
return (
|
262
|
+
return cast(
|
263
|
+
List[str],
|
257
264
|
await self.afetch_document(
|
258
265
|
vecs=(await self.vectorize(query)),
|
259
266
|
desired_fields="text",
|
260
|
-
collection_name=collection_name,
|
261
267
|
**kwargs,
|
262
|
-
)
|
268
|
+
),
|
263
269
|
)[:final_limit]
|
264
270
|
|
265
271
|
async def aask_retrieved(
|
266
272
|
self,
|
267
|
-
question: str
|
268
|
-
query: List[str] | str,
|
273
|
+
question: str,
|
274
|
+
query: Optional[List[str] | str] = None,
|
269
275
|
collection_name: Optional[str] = None,
|
270
276
|
extra_system_message: str = "",
|
271
277
|
result_per_query: int = 10,
|
@@ -279,7 +285,7 @@ class RAG(EmbeddingUsage):
|
|
279
285
|
specified question using the retrieved documents as context.
|
280
286
|
|
281
287
|
Args:
|
282
|
-
question (str
|
288
|
+
question (str): The question to be asked.
|
283
289
|
query (Optional[List[str] | str]): The query or list of queries used for document retrieval. If not provided, the question itself is used as the query.
|
284
290
|
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
285
291
|
If not provided, the currently viewed collection is used.
|
@@ -293,9 +299,9 @@ class RAG(EmbeddingUsage):
|
|
293
299
|
str: A string response generated after asking with the context of retrieved documents.
|
294
300
|
"""
|
295
301
|
docs = await self.aretrieve(
|
296
|
-
query,
|
297
|
-
collection_name,
|
302
|
+
query or question,
|
298
303
|
final_limit,
|
304
|
+
collection_name=collection_name,
|
299
305
|
result_per_query=result_per_query,
|
300
306
|
similarity_threshold=similarity_threshold,
|
301
307
|
)
|
@@ -308,3 +314,56 @@ class RAG(EmbeddingUsage):
|
|
308
314
|
f"{rendered}\n\n{extra_system_message}",
|
309
315
|
**kwargs,
|
310
316
|
)
|
317
|
+
|
318
|
+
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
|
319
|
+
"""Refines the given question using a template.
|
320
|
+
|
321
|
+
Args:
|
322
|
+
question (List[str] | str): The question to be refined.
|
323
|
+
**kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the refinement process.
|
324
|
+
|
325
|
+
Returns:
|
326
|
+
List[str]: A list of refined questions.
|
327
|
+
"""
|
328
|
+
return await self.aliststr(
|
329
|
+
template_manager.render_template(
|
330
|
+
configs.templates.refined_query_template,
|
331
|
+
{"question": [question] if isinstance(question, str) else question},
|
332
|
+
),
|
333
|
+
**kwargs,
|
334
|
+
)
|
335
|
+
|
336
|
+
async def aask_refined(
|
337
|
+
self,
|
338
|
+
question: str,
|
339
|
+
collection_name: Optional[str] = None,
|
340
|
+
extra_system_message: str = "",
|
341
|
+
result_per_query: int = 10,
|
342
|
+
final_limit: int = 20,
|
343
|
+
similarity_threshold: float = 0.37,
|
344
|
+
**kwargs: Unpack[LLMKwargs],
|
345
|
+
) -> str:
|
346
|
+
"""Asks a question using a refined query based on the provided question.
|
347
|
+
|
348
|
+
Args:
|
349
|
+
question (str): The question to be asked.
|
350
|
+
collection_name (Optional[str]): The name of the collection to retrieve documents from.
|
351
|
+
extra_system_message (str): An additional system message to be included in the prompt.
|
352
|
+
result_per_query (int): The number of results to return per query. Default is 10.
|
353
|
+
final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
|
354
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
355
|
+
**kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
|
356
|
+
|
357
|
+
Returns:
|
358
|
+
str: A string response generated after asking with the refined question.
|
359
|
+
"""
|
360
|
+
return await self.aask_retrieved(
|
361
|
+
question,
|
362
|
+
await self.arefined_query(question, **kwargs),
|
363
|
+
collection_name=collection_name,
|
364
|
+
extra_system_message=extra_system_message,
|
365
|
+
result_per_query=result_per_query,
|
366
|
+
final_limit=final_limit,
|
367
|
+
similarity_threshold=similarity_threshold,
|
368
|
+
**kwargs,
|
369
|
+
)
|