fabricatio 0.2.3.dev3__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.4__cp312-cp312-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. fabricatio/__init__.py +18 -5
  2. fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
  3. fabricatio/actions/article.py +81 -0
  4. fabricatio/actions/output.py +21 -0
  5. fabricatio/actions/rag.py +25 -0
  6. fabricatio/capabilities/propose.py +55 -0
  7. fabricatio/capabilities/rag.py +193 -52
  8. fabricatio/capabilities/rating.py +12 -36
  9. fabricatio/capabilities/task.py +6 -23
  10. fabricatio/config.py +43 -2
  11. fabricatio/fs/__init__.py +24 -2
  12. fabricatio/fs/curd.py +14 -8
  13. fabricatio/fs/readers.py +5 -2
  14. fabricatio/models/action.py +19 -4
  15. fabricatio/models/events.py +36 -0
  16. fabricatio/models/extra.py +168 -0
  17. fabricatio/models/generic.py +218 -7
  18. fabricatio/models/kwargs_types.py +15 -0
  19. fabricatio/models/task.py +5 -37
  20. fabricatio/models/tool.py +3 -2
  21. fabricatio/models/usages.py +153 -184
  22. fabricatio/models/utils.py +19 -0
  23. fabricatio/parser.py +35 -8
  24. fabricatio/toolboxes/__init__.py +1 -3
  25. fabricatio/toolboxes/fs.py +15 -1
  26. fabricatio/workflows/articles.py +15 -0
  27. fabricatio/workflows/rag.py +11 -0
  28. fabricatio-0.2.4.data/scripts/tdown +0 -0
  29. {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dist-info}/METADATA +65 -177
  30. fabricatio-0.2.4.dist-info/RECORD +40 -0
  31. fabricatio/actions/__init__.py +0 -5
  32. fabricatio/actions/communication.py +0 -15
  33. fabricatio/actions/transmission.py +0 -23
  34. fabricatio/toolboxes/task.py +0 -6
  35. fabricatio-0.2.3.dev3.data/scripts/tdown +0 -0
  36. fabricatio-0.2.3.dev3.dist-info/RECORD +0 -37
  37. {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dist-info}/WHEEL +0 -0
  38. {fabricatio-0.2.3.dev3.dist-info → fabricatio-0.2.4.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py CHANGED
@@ -3,23 +3,32 @@
  from importlib.util import find_spec
 
  from fabricatio._rust_instances import template_manager
+ from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
+ from fabricatio.actions.output import DumpFinalizedOutput
  from fabricatio.core import env
- from fabricatio.fs import magika
+ from fabricatio.fs import magika, safe_json_read, safe_text_read
  from fabricatio.journal import logger
  from fabricatio.models.action import Action, WorkFlow
  from fabricatio.models.events import Event
+ from fabricatio.models.extra import ArticleEssence
  from fabricatio.models.role import Role
  from fabricatio.models.task import Task
  from fabricatio.models.tool import ToolBox
  from fabricatio.models.utils import Message, Messages
  from fabricatio.parser import Capture, CodeBlockCapture, JsonCapture, PythonCapture
- from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox, task_toolbox
+ from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
+ from fabricatio.workflows.articles import WriteOutlineWorkFlow
 
  __all__ = [
      "Action",
+     "ArticleEssence",
      "Capture",
      "CodeBlockCapture",
+     "DumpFinalizedOutput",
      "Event",
+     "ExtractArticleEssence",
+     "GenerateArticleProposal",
+     "GenerateOutline",
      "JsonCapture",
      "Message",
      "Messages",
@@ -28,18 +37,22 @@ __all__ = [
      "Task",
      "ToolBox",
      "WorkFlow",
+     "WriteOutlineWorkFlow",
      "arithmetic_toolbox",
      "basic_toolboxes",
      "env",
      "fs_toolbox",
      "logger",
      "magika",
-     "task_toolbox",
+     "safe_json_read",
+     "safe_text_read",
      "template_manager",
  ]
 
 
  if find_spec("pymilvus"):
-     from fabricatio.capabilities.rag import Rag
+     from fabricatio.actions.rag import InjectToDB
+     from fabricatio.capabilities.rag import RAG
+     from fabricatio.workflows.rag import StoreArticle
 
-     __all__ += ["Rag"]
+     __all__ += ["RAG", "InjectToDB", "StoreArticle"]
fabricatio/actions/article.py ADDED
@@ -0,0 +1,81 @@
+ """Actions for transmitting tasks to targets."""
+
+ from os import PathLike
+ from pathlib import Path
+ from typing import Callable, List
+
+ from fabricatio.fs import safe_text_read
+ from fabricatio.journal import logger
+ from fabricatio.models.action import Action
+ from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
+ from fabricatio.models.task import Task
+
+
+ class ExtractArticleEssence(Action):
+     """Extract the essence of article(s) in text format from the paths specified in the task dependencies.
+
+     Notes:
+         This action is designed to extract vital information from articles with Markdown format, which is pure text, and
+         which is converted from pdf files using `magic-pdf` from the `MinerU` project, see https://github.com/opendatalab/MinerU
+     """
+
+     output_key: str = "article_essence"
+     """The key of the output data."""
+
+     async def _execute[P: PathLike | str](
+         self,
+         task_input: Task,
+         reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
+         **_,
+     ) -> List[ArticleEssence]:
+         if not task_input.dependencies:
+             logger.info(err := "Task not approved, since no dependencies are provided.")
+             raise RuntimeError(err)
+
+         # trim the references
+         contents = ["References".join(c.split("References")[:-1]) for c in map(reader, task_input.dependencies)]
+         return await self.propose(
+             ArticleEssence,
+             contents,
+             system_message=f"# your personal briefing: \n{self.briefing}",
+         )
+
+
+ class GenerateArticleProposal(Action):
+     """Generate an outline for the article based on the extracted essence."""
+
+     output_key: str = "article_proposal"
+     """The key of the output data."""
+
+     async def _execute(
+         self,
+         task_input: Task,
+         **_,
+     ) -> ArticleProposal:
+         input_path = await self.awhich_pathstr(
+             f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
+         )
+
+         return await self.propose(
+             ArticleProposal,
+             safe_text_read(input_path),
+             system_message=f"# your personal briefing: \n{self.briefing}",
+         )
+
+
+ class GenerateOutline(Action):
+     """Generate the article based on the outline."""
+
+     output_key: str = "article"
+     """The key of the output data."""
+
+     async def _execute(
+         self,
+         article_proposal: ArticleProposal,
+         **_,
+     ) -> ArticleOutline:
+         return await self.propose(
+             ArticleOutline,
+             article_proposal.display(),
+             system_message=f"# your personal briefing: \n{self.briefing}",
+         )
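The three actions above hand data to each other through the workflow context: `ExtractArticleEssence` publishes under `article_essence`, `GenerateArticleProposal` under `article_proposal`, and `GenerateOutline` reads the proposal and publishes the outline under `article`. A rough sketch of composing the last two, assuming `WorkFlow` accepts `name` and a `steps` tuple of Action classes (that constructor shape is an assumption, not shown in this diff):

```python
# Sketch only: WorkFlow(name=..., steps=...) is assumed from earlier 0.2.x usage,
# not confirmed by this diff.
from fabricatio import WorkFlow
from fabricatio.actions.article import GenerateArticleProposal, GenerateOutline

# GenerateArticleProposal writes "article_proposal"; GenerateOutline is expected to
# receive it via its `article_proposal` parameter and write the result under "article".
outline_flow = WorkFlow(
    name="write_outline",
    steps=(GenerateArticleProposal, GenerateOutline),
)
```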
fabricatio/actions/output.py ADDED
@@ -0,0 +1,21 @@
+ """Dump the finalized output to a file."""
+
+ from typing import Unpack
+
+ from fabricatio.models.action import Action
+ from fabricatio.models.generic import FinalizedDumpAble
+ from fabricatio.models.task import Task
+
+
+ class DumpFinalizedOutput(Action):
+     """Dump the finalized output to a file."""
+
+     output_key: str = "dump_path"
+
+     async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **cxt: Unpack) -> str:
+         dump_path = await self.awhich_pathstr(
+             f"{task_input.briefing}\n\nExtract a single path of the file, to which I will dump the data."
+         )
+
+         to_dump.finalized_dump_to(dump_path)
+         return dump_path
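`DumpFinalizedOutput` resolves the destination path from the task briefing via `awhich_pathstr` and then defers to the object's own `finalized_dump_to`. A hedged sketch of the contract a `FinalizedDumpAble` is expected to satisfy (the class below is illustrative, not part of the package):

```python
# Illustrative stand-in for the FinalizedDumpAble contract used by DumpFinalizedOutput.
from pathlib import Path


class OutlineDraft:  # hypothetical example object
    def __init__(self, text: str) -> None:
        self.text = text

    def finalized_dump_to(self, path) -> None:
        # Persist the finalized representation at the path the action resolved.
        Path(path).write_text(self.text, encoding="utf-8")
```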
fabricatio/actions/rag.py ADDED
@@ -0,0 +1,25 @@
+ """Inject data into the database."""
+
+ from typing import List, Optional, Unpack
+
+ from fabricatio.capabilities.rag import RAG
+ from fabricatio.models.action import Action
+ from fabricatio.models.generic import PrepareVectorization
+
+
+ class InjectToDB(Action, RAG):
+     """Inject data into the database."""
+
+     output_key: str = "collection_name"
+
+     async def _execute[T: PrepareVectorization](
+         self, to_inject: T | List[T], collection_name: Optional[str] = "my_collection", **cxt: Unpack
+     ) -> str:
+         if not isinstance(to_inject, list):
+             to_inject = [to_inject]
+
+         await self.view(collection_name, create=True).consume_string(
+             [t.prepare_vectorization(self.embedding_max_sequence_length) for t in to_inject],
+         )
+
+         return collection_name
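`InjectToDB` mixes the `RAG` capability into an `Action`, so the embedding and Milvus plumbing come from the capability while the action only prepares the payload. The call chain it performs is roughly the following, shown here as a standalone sketch (the `agent` object and collection name are placeholders):

```python
# Illustrative only: mirrors what InjectToDB._execute does with the RAG capability.
async def inject(agent, items, collection_name: str = "my_collection") -> str:
    # Prepare each item for embedding, bounded by the configured max sequence length.
    payload = [i.prepare_vectorization(agent.embedding_max_sequence_length) for i in items]
    # view(..., create=True) creates the collection on first use; consume_string()
    # embeds every string and inserts it (with a flush) into that collection.
    await agent.view(collection_name, create=True).consume_string(payload)
    return collection_name
```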
fabricatio/capabilities/propose.py ADDED
@@ -0,0 +1,55 @@
+ """A module for the task capabilities of the Fabricatio library."""
+
+ from typing import List, Type, Unpack, overload
+
+ from fabricatio.models.generic import ProposedAble
+ from fabricatio.models.kwargs_types import GenerateKwargs
+ from fabricatio.models.usages import LLMUsage
+
+
+ class Propose[M: ProposedAble](LLMUsage):
+     """A class that proposes an Obj based on a prompt."""
+
+     @overload
+     async def propose(
+         self,
+         cls: Type[M],
+         prompt: List[str],
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> List[M]: ...
+
+     @overload
+     async def propose(
+         self,
+         cls: Type[M],
+         prompt: str,
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> M: ...
+
+     async def propose(
+         self,
+         cls: Type[M],
+         prompt: List[str] | str,
+         **kwargs: Unpack[GenerateKwargs],
+     ) -> List[M] | M:
+         """Asynchronously proposes a task based on a given prompt and parameters.
+
+         Parameters:
+             cls: The class type of the task to be proposed.
+             prompt: The prompt text for proposing a task, which is a string that must be provided.
+             **kwargs: The keyword arguments for the LLM (Large Language Model) usage.
+
+         Returns:
+             A Task object based on the proposal result.
+         """
+         if isinstance(prompt, str):
+             return await self.aask_validate(
+                 question=cls.create_json_prompt(prompt),
+                 validator=cls.instantiate_from_string,
+                 **kwargs,
+             )
+         return await self.aask_validate_batch(
+             questions=[cls.create_json_prompt(p) for p in prompt],
+             validator=cls.instantiate_from_string,
+             **kwargs,
+         )
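`Propose` is the generic "ask the LLM for a validated object" capability: a single prompt yields a single instance, a list of prompts yields a list. A hedged usage sketch (the `ArticleEssence` target comes from elsewhere in this diff; the `capability` object stands in for whatever Role or Action mixes `Propose` in):

```python
# Sketch: requesting a typed object instead of a raw completion.
from fabricatio.models.extra import ArticleEssence


async def summarize(capability, markdown_text: str) -> ArticleEssence:
    # Internally propose() builds cls.create_json_prompt(prompt) and validates the
    # reply with cls.instantiate_from_string via aask_validate, as shown above.
    return await capability.propose(ArticleEssence, markdown_text)
```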
fabricatio/capabilities/rag.py CHANGED
@@ -1,71 +1,120 @@
  """A module for the RAG (Retrieval Augmented Generation) model."""
 
+ try:
+     from pymilvus import MilvusClient
+ except ImportError as e:
+     raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
  from functools import lru_cache
  from operator import itemgetter
  from os import PathLike
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
 
- from fabricatio import template_manager
+ from fabricatio._rust_instances import template_manager
  from fabricatio.config import configs
- from fabricatio.models.kwargs_types import LLMKwargs
- from fabricatio.models.usages import LLMUsage
+ from fabricatio.journal import logger
+ from fabricatio.models.kwargs_types import (
+     ChooseKwargs,
+     CollectionSimpleConfigKwargs,
+     EmbeddingKwargs,
+     FetchKwargs,
+     LLMKwargs,
+ )
+ from fabricatio.models.usages import EmbeddingUsage
  from fabricatio.models.utils import MilvusData
- from more_itertools.recipes import flatten
-
- try:
-     from pymilvus import MilvusClient
- except ImportError as e:
-     raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
+ from more_itertools.recipes import flatten, unique
  from pydantic import Field, PrivateAttr
 
 
  @lru_cache(maxsize=None)
- def create_client(
-     uri: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None
- ) -> MilvusClient:
+ def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
      """Create a Milvus client."""
      return MilvusClient(
-         uri=uri or configs.rag.milvus_uri.unicode_string(),
-         token=token or configs.rag.milvus_token.get_secret_value() if configs.rag.milvus_token else "",
-         timeout=timeout or configs.rag.milvus_timeout,
+         uri=uri,
+         token=token,
+         timeout=timeout,
      )
 
 
- class Rag(LLMUsage):
+ class RAG(EmbeddingUsage):
      """A class representing the RAG (Retrieval Augmented Generation) model."""
 
-     milvus_uri: Optional[str] = Field(default=None, frozen=True)
-     """The URI of the Milvus server."""
-     milvus_token: Optional[str] = Field(default=None, frozen=True)
-     """The token for the Milvus server."""
-     milvus_timeout: Optional[float] = Field(default=None, frozen=True)
-     """The timeout for the Milvus server."""
      target_collection: Optional[str] = Field(default=None)
      """The name of the collection being viewed."""
 
-     _client: MilvusClient = PrivateAttr(None)
+     _client: Optional[MilvusClient] = PrivateAttr(None)
      """The Milvus client used for the RAG model."""
 
      @property
      def client(self) -> MilvusClient:
          """Return the Milvus client."""
+         if self._client is None:
+             raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
          return self._client
 
-     def model_post_init(self, __context: Any) -> None:
-         """Initialize the RAG model by creating the collection if it does not exist."""
-         self._client = create_client(self.milvus_uri, self.milvus_token, self.milvus_timeout)
-         self.view(self.target_collection, create=True)
+     def init_client(
+         self,
+         milvus_uri: Optional[str] = None,
+         milvus_token: Optional[str] = None,
+         milvus_timeout: Optional[float] = None,
+     ) -> Self:
+         """Initialize the Milvus client."""
+         self._client = create_client(
+             uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
+             token=milvus_token
+             or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
+             timeout=milvus_timeout or self.milvus_timeout,
+         )
+         return self
+
+     @overload
+     async def pack(
+         self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
+     ) -> List[MilvusData]: ...
+     @overload
+     async def pack(
+         self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
+     ) -> MilvusData: ...
 
-     def view(self, collection_name: Optional[str], create: bool = False) -> Self:
+     async def pack(
+         self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
+     ) -> List[MilvusData] | MilvusData:
+         """Asynchronously generates MilvusData objects for the given input text.
+
+         Args:
+             input_text (List[str] | str): A string or list of strings to generate embeddings for.
+             subject (Optional[str]): The subject of the input text. Defaults to None.
+             **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
+
+         Returns:
+             List[MilvusData] | MilvusData: The generated MilvusData objects.
+         """
+         if isinstance(input_text, str):
+             return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
+         vecs = await self.vectorize(input_text, **kwargs)
+         return [
+             MilvusData(
+                 vector=vec,
+                 text=text,
+                 subject=subject,
+             )
+             for text, vec in zip(input_text, vecs, strict=True)
+         ]
+
+     def view(
+         self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
+     ) -> Self:
          """View the specified collection.
 
          Args:
              collection_name (str): The name of the collection.
              create (bool): Whether to create the collection if it does not exist.
+             **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
          """
          if create and collection_name and not self._client.has_collection(collection_name):
-             self._client.create_collection(collection_name)
+             kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
+             self._client.create_collection(collection_name, auto_id=True, **kwargs)
+             logger.info(f"Creating collection {collection_name}")
 
          self.target_collection = collection_name
          return self
@@ -90,13 +139,14 @@ class Rag(LLMUsage):
          return self.target_collection
 
      def add_document[D: Union[Dict[str, Any], MilvusData]](
-         self, data: D | List[D], collection_name: Optional[str] = None
+         self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
      ) -> Self:
          """Adds a document to the specified collection.
 
          Args:
              data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
              collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
+             flush (bool): Whether to flush the collection after insertion.
 
          Returns:
              Self: The current instance, allowing for method chaining.
@@ -105,11 +155,19 @@ class Rag(LLMUsage):
              data = data.prepare_insertion()
          if isinstance(data, list):
              data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
-         self._client.insert(collection_name or self.safe_target_collection, data)
+         c_name = collection_name or self.safe_target_collection
+         self._client.insert(c_name, data)
+
+         if flush:
+             logger.debug(f"Flushing collection {c_name}")
+             self._client.flush(c_name)
          return self
 
-     def consume(
-         self, source: PathLike, reader: Callable[[PathLike], MilvusData], collection_name: Optional[str] = None
+     async def consume_file(
+         self,
+         source: List[PathLike] | PathLike,
+         reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
+         collection_name: Optional[str] = None,
      ) -> Self:
          """Consume a file and add its content to the collection.
 
@@ -121,8 +179,21 @@
          Returns:
              Self: The current instance, allowing for method chaining.
          """
-         data = reader(Path(source))
-         self.add_document(data, collection_name or self.safe_target_collection)
+         if not isinstance(source, list):
+             source = [source]
+         return await self.consume_string([reader(s) for s in source], collection_name)
+
+     async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
+         """Consume a string and add it to the collection.
+
+         Args:
+             text (List[str] | str): The text to be added to the collection.
+             collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
+
+         Returns:
+             Self: The current instance, allowing for method chaining.
+         """
+         self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
          return self
 
      async def afetch_document(
@@ -130,6 +201,7 @@
          vecs: List[List[float]],
          desired_fields: List[str] | str,
          collection_name: Optional[str] = None,
+         similarity_threshold: float = 0.37,
          result_per_query: int = 10,
      ) -> List[Dict[str, Any]] | List[Any]:
          """Fetch data from the collection.
@@ -138,6 +210,7 @@
              vecs (List[List[float]]): The vectors to search for.
              desired_fields (List[str] | str): The fields to retrieve.
              collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
+             similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
              result_per_query (int): The number of results to return per query.
 
          Returns:
@@ -147,16 +220,18 @@
          search_results = self._client.search(
              collection_name or self.safe_target_collection,
              vecs,
+             search_params={"radius": similarity_threshold},
              output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
              limit=result_per_query,
          )
 
          # Step 2: Flatten the search results
          flattened_results = flatten(search_results)
-
+         unique_results = unique(flattened_results, key=itemgetter("id"))
          # Step 3: Sort by distance (descending)
-         sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
+         sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
 
+         logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
          # Step 4: Extract the entities
          resp = [result["entity"] for result in sorted_results]
 
@@ -167,37 +242,38 @@
      async def aretrieve(
          self,
          query: List[str] | str,
-         collection_name: Optional[str] = None,
-         result_per_query: int = 10,
          final_limit: int = 20,
+         **kwargs: Unpack[FetchKwargs],
      ) -> List[str]:
          """Retrieve data from the collection.
 
          Args:
              query (List[str] | str): The query to be used for retrieval.
-             collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
-             result_per_query (int): The number of results to be returned per query.
              final_limit (int): The final limit on the number of results to return.
+             **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
 
          Returns:
              List[str]: A list of strings containing the retrieved data.
          """
          if isinstance(query, str):
              query = [query]
-         return await self.afetch_document(
-             vecs=(await self.vectorize(query)),
-             desired_fields="text",
-             collection_name=collection_name,
-             result_per_query=result_per_query,
+         return (
+             await self.afetch_document(
+                 vecs=(await self.vectorize(query)),
+                 desired_fields="text",
+                 **kwargs,
+             )
          )[:final_limit]
 
      async def aask_retrieved(
          self,
-         question: str | List[str],
-         query: List[str] | str,
+         question: str,
+         query: Optional[List[str] | str] = None,
          collection_name: Optional[str] = None,
+         extra_system_message: str = "",
          result_per_query: int = 10,
          final_limit: int = 20,
+         similarity_threshold: float = 0.37,
          **kwargs: Unpack[LLMKwargs],
      ) -> str:
          """Asks a question by retrieving relevant documents based on the provided query.
@@ -206,20 +282,85 @@
          specified question using the retrieved documents as context.
 
          Args:
-             question (str | List[str]): The question or list of questions to be asked.
+             question (str): The question to be asked.
              query (List[str] | str): The query or list of queries used for document retrieval.
              collection_name (Optional[str]): The name of the collection to retrieve documents from.
                  If not provided, the currently viewed collection is used.
+             extra_system_message (str): An additional system message to be included in the prompt.
              result_per_query (int): The number of results to return per query. Default is 10.
              final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
+             similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
              **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
 
          Returns:
              str: A string response generated after asking with the context of retrieved documents.
          """
-         docs = await self.aretrieve(query, collection_name, result_per_query, final_limit)
+         docs = await self.aretrieve(
+             query or question,
+             final_limit,
+             collection_name=collection_name,
+             result_per_query=result_per_query,
+             similarity_threshold=similarity_threshold,
+         )
+
+         rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
+
+         logger.debug(f"Retrieved Documents: \n{rendered}")
          return await self.aask(
              question,
-             template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs}),
+             f"{rendered}\n\n{extra_system_message}",
+             **kwargs,
+         )
+
+     async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
+         """Refines the given question using a template.
+
+         Args:
+             question (List[str] | str): The question to be refined.
+             **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the refinement process.
+
+         Returns:
+             List[str]: A list of refined questions.
+         """
+         return await self.aliststr(
+             template_manager.render_template(
+                 configs.templates.refined_query_template,
+                 {"question": [question] if isinstance(question, str) else question},
+             ),
+             **kwargs,
+         )
+
+     async def aask_refined(
+         self,
+         question: str,
+         collection_name: Optional[str] = None,
+         extra_system_message: str = "",
+         result_per_query: int = 10,
+         final_limit: int = 20,
+         similarity_threshold: float = 0.37,
+         **kwargs: Unpack[LLMKwargs],
+     ) -> str:
+         """Asks a question using a refined query based on the provided question.
+
+         Args:
+             question (str): The question to be asked.
+             collection_name (Optional[str]): The name of the collection to retrieve documents from.
+             extra_system_message (str): An additional system message to be included in the prompt.
+             result_per_query (int): The number of results to return per query. Default is 10.
+             final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
+             similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
+             **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
+
+         Returns:
+             str: A string response generated after asking with the refined question.
+         """
+         return await self.aask_retrieved(
+             question,
+             await self.arefined_query(question, **kwargs),
+             collection_name=collection_name,
+             extra_system_message=extra_system_message,
+             result_per_query=result_per_query,
+             final_limit=final_limit,
+             similarity_threshold=similarity_threshold,
              **kwargs,
          )
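Taken together, 0.2.4 replaces the implicit `model_post_init` client setup with an explicit `init_client()` call and layers retrieval helpers (`aretrieve`, `aask_retrieved`, `arefined_query`, `aask_refined`) on top of the Milvus-backed store. A rough end-to-end sketch using only methods visible in this diff (the collection name, file paths, and the `rag` object itself are placeholders):

```python
# Sketch of the 0.2.4 RAG flow; `rag` is any object exposing this capability.
async def ask_papers(rag, question: str) -> str:
    rag.init_client()                    # explicit in 0.2.4; no longer done in model_post_init
    rag.view("papers", create=True)      # creates the collection (auto_id=True) on first use
    await rag.consume_file(["./paper_a.md", "./paper_b.md"])  # read -> embed -> insert -> flush
    # aask_refined() first rewrites the question into retrieval queries via arefined_query(),
    # then answers with the retrieved chunks rendered as context.
    return await rag.aask_refined(question, collection_name="papers", similarity_threshold=0.37)
```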