fabricatio 0.2.3.dev2__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.4__cp312-cp312-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fabricatio/__init__.py +24 -3
- fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
- fabricatio/actions/article.py +81 -0
- fabricatio/actions/output.py +21 -0
- fabricatio/actions/rag.py +25 -0
- fabricatio/capabilities/propose.py +55 -0
- fabricatio/capabilities/rag.py +241 -54
- fabricatio/capabilities/rating.py +12 -36
- fabricatio/capabilities/task.py +6 -23
- fabricatio/config.py +46 -2
- fabricatio/fs/__init__.py +24 -2
- fabricatio/fs/curd.py +14 -8
- fabricatio/fs/readers.py +5 -2
- fabricatio/models/action.py +19 -4
- fabricatio/models/events.py +36 -0
- fabricatio/models/extra.py +168 -0
- fabricatio/models/generic.py +218 -7
- fabricatio/models/kwargs_types.py +15 -0
- fabricatio/models/task.py +11 -43
- fabricatio/models/tool.py +3 -2
- fabricatio/models/usages.py +153 -184
- fabricatio/models/utils.py +19 -0
- fabricatio/parser.py +35 -8
- fabricatio/toolboxes/__init__.py +1 -3
- fabricatio/toolboxes/fs.py +15 -1
- fabricatio/workflows/articles.py +15 -0
- fabricatio/workflows/rag.py +11 -0
- fabricatio-0.2.4.data/scripts/tdown +0 -0
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dist-info}/METADATA +39 -147
- fabricatio-0.2.4.dist-info/RECORD +40 -0
- fabricatio/actions/__init__.py +0 -5
- fabricatio/actions/communication.py +0 -15
- fabricatio/actions/transmission.py +0 -23
- fabricatio/toolboxes/task.py +0 -6
- fabricatio-0.2.3.dev2.data/scripts/tdown +0 -0
- fabricatio-0.2.3.dev2.dist-info/RECORD +0 -37
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dist-info}/WHEEL +0 -0
- {fabricatio-0.2.3.dev2.dist-info → fabricatio-0.2.4.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py
CHANGED
@@ -1,23 +1,34 @@
|
|
1
1
|
"""Fabricatio is a Python library for building llm app using event-based agent structure."""
|
2
2
|
|
3
|
+
from importlib.util import find_spec
|
4
|
+
|
3
5
|
from fabricatio._rust_instances import template_manager
|
6
|
+
from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
|
7
|
+
from fabricatio.actions.output import DumpFinalizedOutput
|
4
8
|
from fabricatio.core import env
|
5
|
-
from fabricatio.fs import magika
|
9
|
+
from fabricatio.fs import magika, safe_json_read, safe_text_read
|
6
10
|
from fabricatio.journal import logger
|
7
11
|
from fabricatio.models.action import Action, WorkFlow
|
8
12
|
from fabricatio.models.events import Event
|
13
|
+
from fabricatio.models.extra import ArticleEssence
|
9
14
|
from fabricatio.models.role import Role
|
10
15
|
from fabricatio.models.task import Task
|
11
16
|
from fabricatio.models.tool import ToolBox
|
12
17
|
from fabricatio.models.utils import Message, Messages
|
13
18
|
from fabricatio.parser import Capture, CodeBlockCapture, JsonCapture, PythonCapture
|
14
|
-
from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
19
|
+
from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
|
20
|
+
from fabricatio.workflows.articles import WriteOutlineWorkFlow
|
15
21
|
|
16
22
|
__all__ = [
|
17
23
|
"Action",
|
24
|
+
"ArticleEssence",
|
18
25
|
"Capture",
|
19
26
|
"CodeBlockCapture",
|
27
|
+
"DumpFinalizedOutput",
|
20
28
|
"Event",
|
29
|
+
"ExtractArticleEssence",
|
30
|
+
"GenerateArticleProposal",
|
31
|
+
"GenerateOutline",
|
21
32
|
"JsonCapture",
|
22
33
|
"Message",
|
23
34
|
"Messages",
|
@@ -26,12 +37,22 @@ __all__ = [
|
|
26
37
|
"Task",
|
27
38
|
"ToolBox",
|
28
39
|
"WorkFlow",
|
40
|
+
"WriteOutlineWorkFlow",
|
29
41
|
"arithmetic_toolbox",
|
30
42
|
"basic_toolboxes",
|
31
43
|
"env",
|
32
44
|
"fs_toolbox",
|
33
45
|
"logger",
|
34
46
|
"magika",
|
35
|
-
"
|
47
|
+
"safe_json_read",
|
48
|
+
"safe_text_read",
|
36
49
|
"template_manager",
|
37
50
|
]
|
51
|
+
|
52
|
+
|
53
|
+
if find_spec("pymilvus"):
|
54
|
+
from fabricatio.actions.rag import InjectToDB
|
55
|
+
from fabricatio.capabilities.rag import RAG
|
56
|
+
from fabricatio.workflows.rag import StoreArticle
|
57
|
+
|
58
|
+
__all__ += ["RAG", "InjectToDB", "StoreArticle"]
|
Binary file
|
@@ -0,0 +1,81 @@
|
|
1
|
+
"""Actions for extracting article essences and generating article proposals and outlines."""
|
2
|
+
|
3
|
+
from os import PathLike
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Callable, List
|
6
|
+
|
7
|
+
from fabricatio.fs import safe_text_read
|
8
|
+
from fabricatio.journal import logger
|
9
|
+
from fabricatio.models.action import Action
|
10
|
+
from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
|
11
|
+
from fabricatio.models.task import Task
|
12
|
+
|
13
|
+
|
14
|
+
class ExtractArticleEssence(Action):
|
15
|
+
"""Extract the essence of article(s) in text format from the paths specified in the task dependencies.
|
16
|
+
|
17
|
+
Notes:
|
18
|
+
This action is designed to extract vital information from articles with Markdown format, which is pure text, and
|
19
|
+
which is converted from pdf files using `magic-pdf` from the `MinerU` project, see https://github.com/opendatalab/MinerU
|
20
|
+
"""
|
21
|
+
|
22
|
+
output_key: str = "article_essence"
|
23
|
+
"""The key of the output data."""
|
24
|
+
|
25
|
+
async def _execute[P: PathLike | str](
|
26
|
+
self,
|
27
|
+
task_input: Task,
|
28
|
+
reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
|
29
|
+
**_,
|
30
|
+
) -> List[ArticleEssence]:
|
31
|
+
if not task_input.dependencies:
|
32
|
+
logger.info(err := "Task not approved, since no dependencies are provided.")
|
33
|
+
raise RuntimeError(err)
|
34
|
+
|
35
|
+
# trim the references
|
36
|
+
contents = ["References".join(c.split("References")[:-1]) for c in map(reader, task_input.dependencies)]
|
37
|
+
return await self.propose(
|
38
|
+
ArticleEssence,
|
39
|
+
contents,
|
40
|
+
system_message=f"# your personal briefing: \n{self.briefing}",
|
41
|
+
)
|
42
|
+
|
43
|
+
|
44
|
+
class GenerateArticleProposal(Action):
    """Generate an article proposal from a briefing file referenced by the task."""

    output_key: str = "article_proposal"
    """The key of the output data."""

    async def _execute(
        self,
        task_input: Task,
        **_,
    ) -> ArticleProposal:
        """Locate the briefing file named in the task, read it and propose an `ArticleProposal`.

        Args:
            task_input (Task): The task whose briefing is used to ask for the file path.
            **_: Ignored extra context.

        Returns:
            ArticleProposal: The result of `self.propose` for the text read from that file.
        """
        # Ask for the path first; the file is only read once the path is known.
        path_request = (
            f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
        )
        briefing_path = await self.awhich_pathstr(path_request)

        briefing_text = safe_text_read(briefing_path)
        return await self.propose(
            ArticleProposal,
            briefing_text,
            system_message=f"# your personal briefing: \n{self.briefing}",
        )
|
64
|
+
|
65
|
+
|
66
|
+
class GenerateOutline(Action):
    """Generate an article outline from an article proposal."""

    output_key: str = "article"
    """The key of the output data."""

    async def _execute(
        self,
        article_proposal: ArticleProposal,
        **_,
    ) -> ArticleOutline:
        """Propose an `ArticleOutline` from the displayed form of the given proposal.

        Args:
            article_proposal (ArticleProposal): The proposal; its `display()` output is the prompt.
            **_: Ignored extra context.

        Returns:
            ArticleOutline: The result of `self.propose` for that prompt.
        """
        proposal_text = article_proposal.display()
        persona = f"# your personal briefing: \n{self.briefing}"
        return await self.propose(ArticleOutline, proposal_text, system_message=persona)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
"""Dump the finalized output to a file."""
|
2
|
+
|
3
|
+
from typing import Unpack
|
4
|
+
|
5
|
+
from fabricatio.models.action import Action
|
6
|
+
from fabricatio.models.generic import FinalizedDumpAble
|
7
|
+
from fabricatio.models.task import Task
|
8
|
+
|
9
|
+
|
10
|
+
class DumpFinalizedOutput(Action):
    """Dump the finalized output to a file."""

    output_key: str = "dump_path"

    async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **cxt: Unpack) -> str:
        """Ask for a destination path based on the task briefing and dump the object there.

        Args:
            task_input (Task): The task whose briefing is used to ask for the destination path.
            to_dump (FinalizedDumpAble): The object whose `finalized_dump_to` is called with that path.
            **cxt: Extra context; not used here.

        Returns:
            str: The path that was passed to `finalized_dump_to`.
        """
        path_request = f"{task_input.briefing}\n\nExtract a single path of the file, to which I will dump the data."
        destination = await self.awhich_pathstr(path_request)

        to_dump.finalized_dump_to(destination)
        return destination
|
@@ -0,0 +1,25 @@
|
|
1
|
+
"""Inject data into the database."""
|
2
|
+
|
3
|
+
from typing import List, Optional, Unpack
|
4
|
+
|
5
|
+
from fabricatio.capabilities.rag import RAG
|
6
|
+
from fabricatio.models.action import Action
|
7
|
+
from fabricatio.models.generic import PrepareVectorization
|
8
|
+
|
9
|
+
|
10
|
+
class InjectToDB(Action, RAG):
|
11
|
+
"""Inject data into the database."""
|
12
|
+
|
13
|
+
output_key: str = "collection_name"
|
14
|
+
|
15
|
+
async def _execute[T: PrepareVectorization](
|
16
|
+
self, to_inject: T | List[T], collection_name: Optional[str] = "my_collection", **cxt: Unpack
|
17
|
+
) -> str:
|
18
|
+
if not isinstance(to_inject, list):
|
19
|
+
to_inject = [to_inject]
|
20
|
+
|
21
|
+
await self.view(collection_name, create=True).consume_string(
|
22
|
+
[t.prepare_vectorization(self.embedding_max_sequence_length) for t in to_inject],
|
23
|
+
)
|
24
|
+
|
25
|
+
return collection_name
|
@@ -0,0 +1,55 @@
|
|
1
|
+
"""A module for the proposal capability of the Fabricatio library."""
|
2
|
+
|
3
|
+
from typing import List, Type, Unpack, overload
|
4
|
+
|
5
|
+
from fabricatio.models.generic import ProposedAble
|
6
|
+
from fabricatio.models.kwargs_types import GenerateKwargs
|
7
|
+
from fabricatio.models.usages import LLMUsage
|
8
|
+
|
9
|
+
|
10
|
+
class Propose[M: ProposedAble](LLMUsage):
|
11
|
+
"""A class that proposes an Obj based on a prompt."""
|
12
|
+
|
13
|
+
@overload
|
14
|
+
async def propose(
|
15
|
+
self,
|
16
|
+
cls: Type[M],
|
17
|
+
prompt: List[str],
|
18
|
+
**kwargs: Unpack[GenerateKwargs],
|
19
|
+
) -> List[M]: ...
|
20
|
+
|
21
|
+
@overload
|
22
|
+
async def propose(
|
23
|
+
self,
|
24
|
+
cls: Type[M],
|
25
|
+
prompt: str,
|
26
|
+
**kwargs: Unpack[GenerateKwargs],
|
27
|
+
) -> M: ...
|
28
|
+
|
29
|
+
async def propose(
|
30
|
+
self,
|
31
|
+
cls: Type[M],
|
32
|
+
prompt: List[str] | str,
|
33
|
+
**kwargs: Unpack[GenerateKwargs],
|
34
|
+
) -> List[M] | M:
|
35
|
+
"""Asynchronously proposes a task based on a given prompt and parameters.
|
36
|
+
|
37
|
+
Parameters:
|
38
|
+
cls: The class type of the task to be proposed.
|
39
|
+
prompt: The prompt text for proposing a task, which is a string that must be provided.
|
40
|
+
**kwargs: The keyword arguments for the LLM (Large Language Model) usage.
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
A Task object based on the proposal result.
|
44
|
+
"""
|
45
|
+
if isinstance(prompt, str):
|
46
|
+
return await self.aask_validate(
|
47
|
+
question=cls.create_json_prompt(prompt),
|
48
|
+
validator=cls.instantiate_from_string,
|
49
|
+
**kwargs,
|
50
|
+
)
|
51
|
+
return await self.aask_validate_batch(
|
52
|
+
questions=[cls.create_json_prompt(p) for p in prompt],
|
53
|
+
validator=cls.instantiate_from_string,
|
54
|
+
**kwargs,
|
55
|
+
)
|
fabricatio/capabilities/rag.py
CHANGED
@@ -1,89 +1,152 @@
|
|
1
1
|
"""A module for the RAG (Retrieval Augmented Generation) model."""
|
2
2
|
|
3
|
+
try:
|
4
|
+
from pymilvus import MilvusClient
|
5
|
+
except ImportError as e:
|
6
|
+
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
7
|
+
from functools import lru_cache
|
3
8
|
from operator import itemgetter
|
4
9
|
from os import PathLike
|
5
10
|
from pathlib import Path
|
6
|
-
from typing import Any, Callable, Dict, List, Optional, Self, Union
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
|
7
12
|
|
13
|
+
from fabricatio._rust_instances import template_manager
|
8
14
|
from fabricatio.config import configs
|
9
|
-
from fabricatio.
|
15
|
+
from fabricatio.journal import logger
|
16
|
+
from fabricatio.models.kwargs_types import (
|
17
|
+
ChooseKwargs,
|
18
|
+
CollectionSimpleConfigKwargs,
|
19
|
+
EmbeddingKwargs,
|
20
|
+
FetchKwargs,
|
21
|
+
LLMKwargs,
|
22
|
+
)
|
23
|
+
from fabricatio.models.usages import EmbeddingUsage
|
10
24
|
from fabricatio.models.utils import MilvusData
|
11
|
-
from more_itertools.recipes import flatten
|
25
|
+
from more_itertools.recipes import flatten, unique
|
26
|
+
from pydantic import Field, PrivateAttr
|
12
27
|
|
13
|
-
try:
|
14
|
-
from pymilvus import MilvusClient
|
15
|
-
except ImportError as e:
|
16
|
-
raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
|
17
|
-
from pydantic import PrivateAttr
|
18
28
|
|
29
|
+
@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Create a Milvus client for the given connection settings.

    The result is memoized by `lru_cache`, so repeated calls with the same arguments
    return the same client object.

    Args:
        uri (str): Passed to `MilvusClient` as `uri`.
        token (str): Passed to `MilvusClient` as `token`.
        timeout (Optional[float]): Passed to `MilvusClient` as `timeout`.

    Returns:
        MilvusClient: The constructed (or cached) client.
    """
    connection = MilvusClient(uri=uri, token=token, timeout=timeout)
    return connection
|
19
37
|
|
20
|
-
|
38
|
+
|
39
|
+
class RAG(EmbeddingUsage):
|
21
40
|
"""A class representing the RAG (Retrieval Augmented Generation) model."""
|
22
41
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
),
|
29
|
-
)
|
30
|
-
_target_collection: Optional[str] = PrivateAttr(default=None)
|
42
|
+
target_collection: Optional[str] = Field(default=None)
|
43
|
+
"""The name of the collection being viewed."""
|
44
|
+
|
45
|
+
_client: Optional[MilvusClient] = PrivateAttr(None)
|
46
|
+
"""The Milvus client used for the RAG model."""
|
31
47
|
|
32
48
|
@property
def client(self) -> MilvusClient:
    """Return the Milvus client.

    Raises:
        RuntimeError: If no client has been set on this instance yet.
    """
    if (active := self._client) is not None:
        return active
    raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
|
36
54
|
|
37
|
-
def
|
55
|
+
def init_client(
    self,
    milvus_uri: Optional[str] = None,
    milvus_token: Optional[str] = None,
    milvus_timeout: Optional[float] = None,
) -> Self:
    """Initialize the Milvus client.

    Each setting falls back from the explicit argument to the instance attribute; the URI
    and token additionally fall back to `configs.rag`.

    Args:
        milvus_uri (Optional[str]): Explicit URI; otherwise `self.milvus_uri` or `configs.rag.milvus_uri`.
        milvus_token (Optional[str]): Explicit token; otherwise the secret value of
            `self.milvus_token` or `configs.rag.milvus_token`, or an empty string if neither is set.
        milvus_timeout (Optional[float]): Explicit timeout; otherwise `self.milvus_timeout`.

    Returns:
        Self: The current instance, allowing for method chaining.
    """
    # The fallbacks are only evaluated when the explicit argument is falsy.
    if milvus_uri:
        uri = milvus_uri
    else:
        uri = (self.milvus_uri or configs.rag.milvus_uri).unicode_string()

    if milvus_token:
        token = milvus_token
    else:
        secret = self.milvus_token or configs.rag.milvus_token
        token = secret.get_secret_value() if secret else ""

    timeout = milvus_timeout or self.milvus_timeout

    self._client = create_client(uri=uri, token=token, timeout=timeout)
    return self
|
69
|
+
|
70
|
+
@overload
async def pack(
    self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> List[MilvusData]: ...

@overload
async def pack(
    self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> MilvusData: ...

async def pack(
    self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> List[MilvusData] | MilvusData:
    """Asynchronously build MilvusData object(s) for the given text(s).

    Args:
        input_text (List[str] | str): A string or list of strings to vectorize.
        subject (Optional[str]): The subject stored on every produced object. Defaults to None.
        **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments forwarded to `self.vectorize`.

    Returns:
        List[MilvusData] | MilvusData: One object for a string input, a list for a list input.
    """
    if not isinstance(input_text, str):
        vectors = await self.vectorize(input_text, **kwargs)
        packed = []
        # strict=True: a length mismatch between texts and vectors raises instead of truncating.
        for chunk, vector in zip(input_text, vectors, strict=True):
            packed.append(MilvusData(vector=vector, text=chunk, subject=subject))
        return packed

    embedding = await self.vectorize(input_text, **kwargs)
    return MilvusData(vector=embedding, text=input_text, subject=subject)
|
103
|
+
|
104
|
+
def view(
    self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
) -> Self:
    """View the specified collection.

    Args:
        collection_name (Optional[str]): The name of the collection; `None` clears the current view.
        create (bool): Whether to create the collection if it does not exist.
        **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection
            configuration. A missing or falsy `dimension` is filled from `self.milvus_dimensions`,
            then `configs.rag.milvus_dimensions`.

    Returns:
        Self: The current instance, allowing for method chaining.

    Raises:
        RuntimeError: If a collection has to be looked up or created but the client is not initialized.
    """
    # Go through the guarded `client` property rather than `_client`, so that an
    # uninitialized client yields its explicit RuntimeError instead of an
    # AttributeError on None. The client is only touched when creation is requested.
    if create and collection_name and not self.client.has_collection(collection_name):
        kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
        logger.info(f"Creating collection {collection_name}")
        self.client.create_collection(collection_name, auto_id=True, **kwargs)

    self.target_collection = collection_name
    return self
|
49
121
|
|
50
|
-
def
|
122
|
+
def quit_viewing(self) -> Self:
    """Stop viewing any collection by viewing `None`.

    Returns:
        Self: The current instance, allowing for method chaining.
    """
    detached = self.view(None)
    return detached
|
67
129
|
|
68
130
|
@property
def safe_target_collection(self) -> str:
    """Return the name of the viewed collection, refusing to return `None`.

    Returns:
        str: The name of the collection being viewed.

    Raises:
        RuntimeError: If no collection is being viewed.
    """
    viewed = self.target_collection
    if viewed is not None:
        return viewed
    raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
|
78
140
|
|
79
141
|
def add_document[D: Union[Dict[str, Any], MilvusData]](
|
80
|
-
self, data: D | List[D], collection_name: Optional[str] = None
|
142
|
+
self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
|
81
143
|
) -> Self:
|
82
144
|
"""Adds a document to the specified collection.
|
83
145
|
|
84
146
|
Args:
|
85
147
|
data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
|
86
148
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
149
|
+
flush (bool): Whether to flush the collection after insertion.
|
87
150
|
|
88
151
|
Returns:
|
89
152
|
Self: The current instance, allowing for method chaining.
|
@@ -92,11 +155,19 @@ class Rag(LLMUsage):
|
|
92
155
|
data = data.prepare_insertion()
|
93
156
|
if isinstance(data, list):
|
94
157
|
data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
|
95
|
-
|
158
|
+
c_name = collection_name or self.safe_target_collection
|
159
|
+
self._client.insert(c_name, data)
|
160
|
+
|
161
|
+
if flush:
|
162
|
+
logger.debug(f"Flushing collection {c_name}")
|
163
|
+
self._client.flush(c_name)
|
96
164
|
return self
|
97
165
|
|
98
|
-
def
|
99
|
-
self,
|
166
|
+
async def consume_file(
|
167
|
+
self,
|
168
|
+
source: List[PathLike] | PathLike,
|
169
|
+
reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
|
170
|
+
collection_name: Optional[str] = None,
|
100
171
|
) -> Self:
|
101
172
|
"""Consume a file and add its content to the collection.
|
102
173
|
|
@@ -108,8 +179,21 @@ class Rag(LLMUsage):
|
|
108
179
|
Returns:
|
109
180
|
Self: The current instance, allowing for method chaining.
|
110
181
|
"""
|
111
|
-
|
112
|
-
|
182
|
+
if not isinstance(source, list):
|
183
|
+
source = [source]
|
184
|
+
return await self.consume_string([reader(s) for s in source], collection_name)
|
185
|
+
|
186
|
+
async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
    """Pack the given text(s) and insert them into a collection, flushing afterwards.

    Args:
        text (List[str] | str): The text to be added to the collection.
        collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.

    Returns:
        Self: The current instance, allowing for method chaining.
    """
    # Packing happens before the target is resolved, matching the argument order of the insert call.
    packed = await self.pack(text)
    destination = collection_name or self.safe_target_collection
    self.add_document(packed, destination, flush=True)
    return self
|
114
198
|
|
115
199
|
async def afetch_document(
|
@@ -117,6 +201,7 @@ class Rag(LLMUsage):
|
|
117
201
|
vecs: List[List[float]],
|
118
202
|
desired_fields: List[str] | str,
|
119
203
|
collection_name: Optional[str] = None,
|
204
|
+
similarity_threshold: float = 0.37,
|
120
205
|
result_per_query: int = 10,
|
121
206
|
) -> List[Dict[str, Any]] | List[Any]:
|
122
207
|
"""Fetch data from the collection.
|
@@ -125,6 +210,7 @@ class Rag(LLMUsage):
|
|
125
210
|
vecs (List[List[float]]): The vectors to search for.
|
126
211
|
desired_fields (List[str] | str): The fields to retrieve.
|
127
212
|
collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
|
213
|
+
similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
|
128
214
|
result_per_query (int): The number of results to return per query.
|
129
215
|
|
130
216
|
Returns:
|
@@ -132,18 +218,20 @@ class Rag(LLMUsage):
|
|
132
218
|
"""
|
133
219
|
# Step 1: Search for vectors
|
134
220
|
search_results = self._client.search(
|
135
|
-
collection_name or self.
|
221
|
+
collection_name or self.safe_target_collection,
|
136
222
|
vecs,
|
223
|
+
search_params={"radius": similarity_threshold},
|
137
224
|
output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
|
138
225
|
limit=result_per_query,
|
139
226
|
)
|
140
227
|
|
141
228
|
# Step 2: Flatten the search results
|
142
229
|
flattened_results = flatten(search_results)
|
143
|
-
|
230
|
+
unique_results = unique(flattened_results, key=itemgetter("id"))
|
144
231
|
# Step 3: Sort by distance (descending)
|
145
|
-
sorted_results = sorted(
|
232
|
+
sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
|
146
233
|
|
234
|
+
logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
|
147
235
|
# Step 4: Extract the entities
|
148
236
|
resp = [result["entity"] for result in sorted_results]
|
149
237
|
|
@@ -154,26 +242,125 @@ class Rag(LLMUsage):
|
|
154
242
|
async def aretrieve(
    self,
    query: List[str] | str,
    final_limit: int = 20,
    **kwargs: Unpack[FetchKwargs],
) -> List[str]:
    """Retrieve the `text` field of documents matching the query or queries.

    Args:
        query (List[str] | str): The query to be used for retrieval; a single string is wrapped in a list.
        final_limit (int): The final limit on the number of results to return.
        **kwargs (Unpack[FetchKwargs]): Additional keyword arguments forwarded to `afetch_document`.

    Returns:
        List[str]: At most `final_limit` entries from what `afetch_document` returns.
    """
    queries = [query] if isinstance(query, str) else query
    query_vectors = await self.vectorize(queries)
    documents = await self.afetch_document(
        vecs=query_vectors,
        desired_fields="text",
        **kwargs,
    )
    return documents[:final_limit]
|
267
|
+
|
268
|
+
async def aask_retrieved(
    self,
    question: str,
    query: Optional[List[str] | str] = None,
    collection_name: Optional[str] = None,
    extra_system_message: str = "",
    result_per_query: int = 10,
    final_limit: int = 20,
    similarity_threshold: float = 0.37,
    **kwargs: Unpack[LLMKwargs],
) -> str:
    """Ask a question with retrieved documents rendered into the system message.

    Documents are retrieved with `aretrieve`, rendered through the
    `retrieved_display_template`, and the rendered text followed by
    `extra_system_message` is passed to `aask` alongside the question.

    Args:
        question (str): The question to be asked.
        query (Optional[List[str] | str]): The query or queries used for retrieval; when falsy,
            the question itself is used as the query.
        collection_name (Optional[str]): The name of the collection to retrieve documents from.
        extra_system_message (str): Text appended after the rendered documents.
        result_per_query (int): The number of results to return per query. Default is 10.
        final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
        similarity_threshold (float): The similarity threshold forwarded to retrieval.
        **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.

    Returns:
        str: The response returned by `aask`.
    """
    retrieval_query = query or question
    documents = await self.aretrieve(
        retrieval_query,
        final_limit,
        collection_name=collection_name,
        result_per_query=result_per_query,
        similarity_threshold=similarity_threshold,
    )

    # The documents are handed to the template in reverse of the retrieval order.
    template_context = {"docs": documents[::-1]}
    rendered = template_manager.render_template(configs.templates.retrieved_display_template, template_context)

    logger.debug(f"Retrieved Documents: \n{rendered}")
    system_message = f"{rendered}\n\n{extra_system_message}"
    return await self.aask(question, system_message, **kwargs)
|
314
|
+
|
315
|
+
async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
    """Produce refined queries for the given question(s) using the refined-query template.

    Args:
        question (List[str] | str): The question(s) to refine; a single string is wrapped in a list.
        **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments forwarded to `aliststr`.

    Returns:
        List[str]: The list returned by `aliststr` for the rendered template.
    """
    questions = question if not isinstance(question, str) else [question]
    prompt = template_manager.render_template(
        configs.templates.refined_query_template,
        {"question": questions},
    )
    return await self.aliststr(prompt, **kwargs)
|
332
|
+
|
333
|
+
async def aask_refined(
    self,
    question: str,
    collection_name: Optional[str] = None,
    extra_system_message: str = "",
    result_per_query: int = 10,
    final_limit: int = 20,
    similarity_threshold: float = 0.37,
    **kwargs: Unpack[LLMKwargs],
) -> str:
    """Ask a question, retrieving context with queries refined from that same question.

    Args:
        question (str): The question to be asked.
        collection_name (Optional[str]): The name of the collection to retrieve documents from.
        extra_system_message (str): An additional system message to be included in the prompt.
        result_per_query (int): The number of results to return per query. Default is 10.
        final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
        similarity_threshold (float): The similarity threshold forwarded to retrieval.
        **kwargs (Unpack[LLMKwargs]): Additional keyword arguments; forwarded to both
            `arefined_query` and `aask_retrieved`.

    Returns:
        str: The response returned by `aask_retrieved`.
    """
    # NOTE(review): the same kwargs go to the refinement step and to the final ask —
    # confirm both callees accept the same keys.
    refined_queries = await self.arefined_query(question, **kwargs)
    return await self.aask_retrieved(
        question,
        refined_queries,
        collection_name=collection_name,
        extra_system_message=extra_system_message,
        result_per_query=result_per_query,
        final_limit=final_limit,
        similarity_threshold=similarity_threshold,
        **kwargs,
    )
|