fabricatio 0.2.4.dev2__cp312-cp312-win_amd64.whl → 0.2.5__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. fabricatio/__init__.py +14 -5
  2. fabricatio/_rust.cp312-win_amd64.pyd +0 -0
  3. fabricatio/_rust.pyi +65 -16
  4. fabricatio/_rust_instances.py +2 -0
  5. fabricatio/actions/article.py +46 -14
  6. fabricatio/actions/output.py +21 -0
  7. fabricatio/actions/rag.py +1 -1
  8. fabricatio/capabilities/propose.py +14 -20
  9. fabricatio/capabilities/rag.py +85 -26
  10. fabricatio/capabilities/rating.py +59 -51
  11. fabricatio/capabilities/review.py +241 -0
  12. fabricatio/capabilities/task.py +7 -8
  13. fabricatio/config.py +36 -4
  14. fabricatio/fs/__init__.py +13 -1
  15. fabricatio/fs/curd.py +27 -8
  16. fabricatio/fs/readers.py +6 -3
  17. fabricatio/journal.py +1 -1
  18. fabricatio/models/action.py +6 -8
  19. fabricatio/models/events.py +6 -4
  20. fabricatio/models/extra.py +100 -25
  21. fabricatio/models/generic.py +56 -4
  22. fabricatio/models/kwargs_types.py +123 -35
  23. fabricatio/models/role.py +3 -3
  24. fabricatio/models/task.py +0 -14
  25. fabricatio/models/tool.py +7 -6
  26. fabricatio/models/usages.py +144 -101
  27. fabricatio/parser.py +26 -5
  28. fabricatio/toolboxes/__init__.py +1 -3
  29. fabricatio/toolboxes/fs.py +17 -1
  30. fabricatio/workflows/articles.py +10 -6
  31. fabricatio/workflows/rag.py +11 -0
  32. fabricatio-0.2.5.data/scripts/tdown.exe +0 -0
  33. {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/METADATA +2 -1
  34. fabricatio-0.2.5.dist-info/RECORD +41 -0
  35. fabricatio/toolboxes/task.py +0 -6
  36. fabricatio-0.2.4.dev2.data/scripts/tdown.exe +0 -0
  37. fabricatio-0.2.4.dev2.dist-info/RECORD +0 -39
  38. {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/WHEEL +0 -0
  39. {fabricatio-0.2.4.dev2.dist-info → fabricatio-0.2.5.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py CHANGED
@@ -2,10 +2,12 @@
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
5
+ from fabricatio._rust import BibManager
5
6
  from fabricatio._rust_instances import template_manager
6
- from fabricatio.actions.article import ExtractArticleEssence
7
+ from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
8
+ from fabricatio.actions.output import DumpFinalizedOutput
7
9
  from fabricatio.core import env
8
- from fabricatio.fs import magika
10
+ from fabricatio.fs import magika, safe_json_read, safe_text_read
9
11
  from fabricatio.journal import logger
10
12
  from fabricatio.models.action import Action, WorkFlow
11
13
  from fabricatio.models.events import Event
@@ -15,15 +17,20 @@ from fabricatio.models.task import Task
15
17
  from fabricatio.models.tool import ToolBox
16
18
  from fabricatio.models.utils import Message, Messages
17
19
  from fabricatio.parser import Capture, CodeBlockCapture, JsonCapture, PythonCapture
18
- from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox, task_toolbox
20
+ from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
21
+ from fabricatio.workflows.articles import WriteOutlineWorkFlow
19
22
 
20
23
  __all__ = [
21
24
  "Action",
22
25
  "ArticleEssence",
26
+ "BibManager",
23
27
  "Capture",
24
28
  "CodeBlockCapture",
29
+ "DumpFinalizedOutput",
25
30
  "Event",
26
31
  "ExtractArticleEssence",
32
+ "GenerateArticleProposal",
33
+ "GenerateOutline",
27
34
  "JsonCapture",
28
35
  "Message",
29
36
  "Messages",
@@ -32,13 +39,15 @@ __all__ = [
32
39
  "Task",
33
40
  "ToolBox",
34
41
  "WorkFlow",
42
+ "WriteOutlineWorkFlow",
35
43
  "arithmetic_toolbox",
36
44
  "basic_toolboxes",
37
45
  "env",
38
46
  "fs_toolbox",
39
47
  "logger",
40
48
  "magika",
41
- "task_toolbox",
49
+ "safe_json_read",
50
+ "safe_text_read",
42
51
  "template_manager",
43
52
  ]
44
53
 
@@ -46,6 +55,6 @@ __all__ = [
46
55
  if find_spec("pymilvus"):
47
56
  from fabricatio.actions.rag import InjectToDB
48
57
  from fabricatio.capabilities.rag import RAG
49
- from fabricatio.workflows.articles import StoreArticle
58
+ from fabricatio.workflows.rag import StoreArticle
50
59
 
51
60
  __all__ += ["RAG", "InjectToDB", "StoreArticle"]
Binary file
fabricatio/_rust.pyi CHANGED
@@ -2,52 +2,101 @@ from pathlib import Path
2
2
  from typing import Any, Dict, List, Optional
3
3
 
4
4
  class TemplateManager:
5
- """TemplateManager class for managing handlebars templates."""
5
+ """Template rendering engine using Handlebars templates.
6
+
7
+ This manager handles template discovery, loading, and rendering
8
+ through a wrapper around the handlebars-rust engine.
9
+
10
+ See: https://crates.io/crates/handlebars
11
+ """
6
12
  def __init__(
7
13
  self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
8
14
  ) -> None:
9
15
  """Initialize the template manager.
10
16
 
11
17
  Args:
12
- template_dirs (List[Path]): A list of paths to directories containing templates.
13
- suffix (str, optional): The suffix of template files. None means 'hbs' suffix.
14
- active_loading (bool, optional): Whether to enable active loading of templates.
18
+ template_dirs: List of directories containing template files
19
+ suffix: File extension for templates (defaults to 'hbs')
20
+ active_loading: Whether to enable dev mode for reloading templates on change
15
21
  """
16
22
 
17
23
  @property
18
24
  def template_count(self) -> int:
19
- """Get the number of templates discovered."""
25
+ """Returns the number of currently loaded templates."""
20
26
 
21
27
  def get_template_source(self, name: str) -> Optional[str]:
22
- """Get the source path of a template by name.
28
+ """Get the filesystem path for a template.
23
29
 
24
30
  Args:
25
- name (str): The name of the template to retrieve.
31
+ name: Template name (without extension)
26
32
 
27
33
  Returns:
28
- Optional[str]: The source path of the template.
34
+ Path to the template file if found, None otherwise
29
35
  """
30
36
 
31
37
  def discover_templates(self) -> None:
32
- """Discover templates in the specified directories."""
38
+ """Scan template directories and load available templates.
39
+
40
+ This refreshes the template cache, finding any new or modified templates.
41
+ """
33
42
 
34
43
  def render_template(self, name: str, data: Dict[str, Any]) -> str:
35
- """Render a template with the given name and data.
44
+ """Render a template with context data.
36
45
 
37
46
  Args:
38
- name (str): The name of the template to render.
39
- data (Dict[str, Any]): The data to pass to the template.
47
+ name: Template name (without extension)
48
+ data: Context dictionary to provide variables to the template
40
49
 
41
50
  Returns:
42
- str: The rendered template.
51
+ Rendered template content as string
52
+
53
+ Raises:
54
+ RuntimeError: If template rendering fails
43
55
  """
44
56
 
45
57
  def blake3_hash(content: bytes) -> str:
46
- """Calculate the BLAKE3 hash of the given data.
58
+ """Calculate the BLAKE3 cryptographic hash of data.
47
59
 
48
60
  Args:
49
- content (bytes): The data to hash.
61
+ content: Bytes to be hashed
50
62
 
51
63
  Returns:
52
- str: The BLAKE3 hash of the data.
64
+ Hex-encoded BLAKE3 hash string
53
65
  """
66
+
67
+ class BibManager:
68
+ """BibTeX bibliography manager for parsing and querying citation data."""
69
+
70
+ def __init__(self, path: str) -> None:
71
+ """Initialize the bibliography manager.
72
+
73
+ Args:
74
+ path: Path to BibTeX (.bib) file to load
75
+
76
+ Raises:
77
+ RuntimeError: If file cannot be read or parsed
78
+ """
79
+
80
+ def get_cite_key(self, title: str) -> Optional[str]:
81
+ """Find citation key by exact title match.
82
+
83
+ Args:
84
+ title: Full title to search for (case-insensitive)
85
+
86
+ Returns:
87
+ Citation key if exact match found, None otherwise
88
+ """
89
+
90
+ def get_cite_key_fuzzy(self, query: str) -> Optional[str]:
91
+ """Find best matching citation using fuzzy text search.
92
+
93
+ Args:
94
+ query: Search term to find in bibliography entries
95
+
96
+ Returns:
97
+ Citation key of best matching entry, or None if no good match
98
+
99
+ Notes:
100
+ Uses nucleo_matcher for high-quality fuzzy text searching
101
+ See: https://crates.io/crates/nucleo-matcher
102
+ """
@@ -1,3 +1,5 @@
1
+ """Some necessary instances."""
2
+
1
3
  from fabricatio._rust import TemplateManager
2
4
  from fabricatio.config import configs
3
5
 
@@ -2,11 +2,12 @@
2
2
 
3
3
  from os import PathLike
4
4
  from pathlib import Path
5
- from typing import Callable, List
5
+ from typing import Callable, List, Optional
6
6
 
7
+ from fabricatio.fs import safe_text_read
7
8
  from fabricatio.journal import logger
8
9
  from fabricatio.models.action import Action
9
- from fabricatio.models.extra import ArticleEssence
10
+ from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
10
11
  from fabricatio.models.task import Task
11
12
 
12
13
 
@@ -18,11 +19,6 @@ class ExtractArticleEssence(Action):
18
19
  which is converted from pdf files using `magic-pdf` from the `MinerU` project, see https://github.com/opendatalab/MinerU
19
20
  """
20
21
 
21
- name: str = "extract article essence"
22
- """The name of the action."""
23
- description: str = "Extract the essence of article(s) from the paths specified in the task dependencies."
24
- """The description of the action."""
25
-
26
22
  output_key: str = "article_essence"
27
23
  """The key of the output data."""
28
24
 
@@ -31,13 +27,9 @@ class ExtractArticleEssence(Action):
31
27
  task_input: Task,
32
28
  reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
33
29
  **_,
34
- ) -> List[ArticleEssence]:
35
- if not await self.ajudge(
36
- f"= Task\n{task_input.briefing}\n\n\n= Role\n{self.briefing}",
37
- affirm_case="The task does not violate the role, and could be approved since the file dependencies are specified.",
38
- deny_case="The task does violate the role, and could not be approved.",
39
- ):
40
- logger.info(err := "Task not approved.")
30
+ ) -> Optional[List[ArticleEssence]]:
31
+ if not task_input.dependencies:
32
+ logger.info(err := "Task not approved, since no dependencies are provided.")
41
33
  raise RuntimeError(err)
42
34
 
43
35
  # trim the references
@@ -47,3 +39,43 @@ class ExtractArticleEssence(Action):
47
39
  contents,
48
40
  system_message=f"# your personal briefing: \n{self.briefing}",
49
41
  )
42
+
43
+
44
+ class GenerateArticleProposal(Action):
45
+ """Generate an outline for the article based on the extracted essence."""
46
+
47
+ output_key: str = "article_proposal"
48
+ """The key of the output data."""
49
+
50
+ async def _execute(
51
+ self,
52
+ task_input: Task,
53
+ **_,
54
+ ) -> Optional[ArticleProposal]:
55
+ input_path = await self.awhich_pathstr(
56
+ f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
57
+ )
58
+
59
+ return await self.propose(
60
+ ArticleProposal,
61
+ safe_text_read(input_path),
62
+ system_message=f"# your personal briefing: \n{self.briefing}",
63
+ )
64
+
65
+
66
+ class GenerateOutline(Action):
67
+ """Generate the article based on the outline."""
68
+
69
+ output_key: str = "article"
70
+ """The key of the output data."""
71
+
72
+ async def _execute(
73
+ self,
74
+ article_proposal: ArticleProposal,
75
+ **_,
76
+ ) -> Optional[ArticleOutline]:
77
+ return await self.propose(
78
+ ArticleOutline,
79
+ article_proposal.display(),
80
+ system_message=f"# your personal briefing: \n{self.briefing}",
81
+ )
@@ -0,0 +1,21 @@
1
+ """Dump the finalized output to a file."""
2
+
3
+ from typing import Unpack
4
+
5
+ from fabricatio.models.action import Action
6
+ from fabricatio.models.generic import FinalizedDumpAble
7
+ from fabricatio.models.task import Task
8
+
9
+
10
+ class DumpFinalizedOutput(Action):
11
+ """Dump the finalized output to a file."""
12
+
13
+ output_key: str = "dump_path"
14
+
15
+ async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **cxt: Unpack) -> str:
16
+ dump_path = await self.awhich_pathstr(
17
+ f"{task_input.briefing}\n\nExtract a single path of the file, to which I will dump the data."
18
+ )
19
+
20
+ to_dump.finalized_dump_to(dump_path)
21
+ return dump_path
fabricatio/actions/rag.py CHANGED
@@ -14,7 +14,7 @@ class InjectToDB(Action, RAG):
14
14
 
15
15
  async def _execute[T: PrepareVectorization](
16
16
  self, to_inject: T | List[T], collection_name: Optional[str] = "my_collection", **cxt: Unpack
17
- ) -> str:
17
+ ) -> Optional[str]:
18
18
  if not isinstance(to_inject, list):
19
19
  to_inject = [to_inject]
20
20
 
@@ -1,37 +1,37 @@
1
1
  """A module for the task capabilities of the Fabricatio library."""
2
2
 
3
- from typing import List, Type, Unpack, overload
3
+ from typing import List, Optional, Type, Unpack, overload
4
4
 
5
5
  from fabricatio.models.generic import ProposedAble
6
- from fabricatio.models.kwargs_types import GenerateKwargs
6
+ from fabricatio.models.kwargs_types import ValidateKwargs
7
7
  from fabricatio.models.usages import LLMUsage
8
8
 
9
9
 
10
- class Propose[M: ProposedAble](LLMUsage):
10
+ class Propose(LLMUsage):
11
11
  """A class that proposes an Obj based on a prompt."""
12
12
 
13
13
  @overload
14
- async def propose(
14
+ async def propose[M: ProposedAble](
15
15
  self,
16
16
  cls: Type[M],
17
17
  prompt: List[str],
18
- **kwargs: Unpack[GenerateKwargs],
19
- ) -> List[M]: ...
18
+ **kwargs: Unpack[ValidateKwargs[M]],
19
+ ) -> Optional[List[M]]: ...
20
20
 
21
21
  @overload
22
- async def propose(
22
+ async def propose[M: ProposedAble](
23
23
  self,
24
24
  cls: Type[M],
25
25
  prompt: str,
26
- **kwargs: Unpack[GenerateKwargs],
27
- ) -> M: ...
26
+ **kwargs: Unpack[ValidateKwargs[M]],
27
+ ) -> Optional[M]: ...
28
28
 
29
- async def propose(
29
+ async def propose[M: ProposedAble](
30
30
  self,
31
31
  cls: Type[M],
32
32
  prompt: List[str] | str,
33
- **kwargs: Unpack[GenerateKwargs],
34
- ) -> List[M] | M:
33
+ **kwargs: Unpack[ValidateKwargs[M]],
34
+ ) -> Optional[List[M] | M]:
35
35
  """Asynchronously proposes a task based on a given prompt and parameters.
36
36
 
37
37
  Parameters:
@@ -42,14 +42,8 @@ class Propose[M: ProposedAble](LLMUsage):
42
42
  Returns:
43
43
  A Task object based on the proposal result.
44
44
  """
45
- if isinstance(prompt, str):
46
- return await self.aask_validate(
47
- question=cls.create_json_prompt(prompt),
48
- validator=cls.instantiate_from_string,
49
- **kwargs,
50
- )
51
- return await self.aask_validate_batch(
52
- questions=[cls.create_json_prompt(p) for p in prompt],
45
+ return await self.aask_validate(
46
+ question=cls.create_json_prompt(prompt),
53
47
  validator=cls.instantiate_from_string,
54
48
  **kwargs,
55
49
  )
@@ -8,15 +8,21 @@ from functools import lru_cache
8
8
  from operator import itemgetter
9
9
  from os import PathLike
10
10
  from pathlib import Path
11
- from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
11
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
12
12
 
13
13
  from fabricatio._rust_instances import template_manager
14
14
  from fabricatio.config import configs
15
15
  from fabricatio.journal import logger
16
- from fabricatio.models.kwargs_types import CollectionSimpleConfigKwargs, EmbeddingKwargs, FetchKwargs, LLMKwargs
16
+ from fabricatio.models.kwargs_types import (
17
+ ChooseKwargs,
18
+ CollectionSimpleConfigKwargs,
19
+ EmbeddingKwargs,
20
+ FetchKwargs,
21
+ LLMKwargs,
22
+ )
17
23
  from fabricatio.models.usages import EmbeddingUsage
18
24
  from fabricatio.models.utils import MilvusData
19
- from more_itertools.recipes import flatten
25
+ from more_itertools.recipes import flatten, unique
20
26
  from pydantic import Field, PrivateAttr
21
27
 
22
28
 
@@ -105,9 +111,9 @@ class RAG(EmbeddingUsage):
105
111
  create (bool): Whether to create the collection if it does not exist.
106
112
  **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
107
113
  """
108
- if create and collection_name and not self._client.has_collection(collection_name):
114
+ if create and collection_name and self.client.has_collection(collection_name):
109
115
  kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
110
- self._client.create_collection(collection_name, auto_id=True, **kwargs)
116
+ self.client.create_collection(collection_name, auto_id=True, **kwargs)
111
117
  logger.info(f"Creating collection {collection_name}")
112
118
 
113
119
  self.target_collection = collection_name
@@ -146,15 +152,17 @@ class RAG(EmbeddingUsage):
146
152
  Self: The current instance, allowing for method chaining.
147
153
  """
148
154
  if isinstance(data, MilvusData):
149
- data = data.prepare_insertion()
150
- if isinstance(data, list):
151
- data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
155
+ prepared_data = data.prepare_insertion()
156
+ elif isinstance(data, list):
157
+ prepared_data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
158
+ else:
159
+ raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
152
160
  c_name = collection_name or self.safe_target_collection
153
- self._client.insert(c_name, data)
161
+ self.client.insert(c_name, prepared_data)
154
162
 
155
163
  if flush:
156
164
  logger.debug(f"Flushing collection {c_name}")
157
- self._client.flush(c_name)
165
+ self.client.flush(c_name)
158
166
  return self
159
167
 
160
168
  async def consume_file(
@@ -190,14 +198,14 @@ class RAG(EmbeddingUsage):
190
198
  self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
191
199
  return self
192
200
 
193
- async def afetch_document(
201
+ async def afetch_document[V: (int, str, float, bytes)](
194
202
  self,
195
203
  vecs: List[List[float]],
196
204
  desired_fields: List[str] | str,
197
205
  collection_name: Optional[str] = None,
198
206
  similarity_threshold: float = 0.37,
199
207
  result_per_query: int = 10,
200
- ) -> List[Dict[str, Any]] | List[Any]:
208
+ ) -> List[Dict[str, Any]] | List[V]:
201
209
  """Fetch data from the collection.
202
210
 
203
211
  Args:
@@ -211,7 +219,7 @@ class RAG(EmbeddingUsage):
211
219
  List[Dict[str, Any]] | List[Any]: The retrieved data.
212
220
  """
213
221
  # Step 1: Search for vectors
214
- search_results = self._client.search(
222
+ search_results = self.client.search(
215
223
  collection_name or self.safe_target_collection,
216
224
  vecs,
217
225
  search_params={"radius": similarity_threshold},
@@ -221,9 +229,9 @@ class RAG(EmbeddingUsage):
221
229
 
222
230
  # Step 2: Flatten the search results
223
231
  flattened_results = flatten(search_results)
224
-
232
+ unique_results = unique(flattened_results, key=itemgetter("id"))
225
233
  # Step 3: Sort by distance (descending)
226
- sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
234
+ sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
227
235
 
228
236
  logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
229
237
  # Step 4: Extract the entities
@@ -231,12 +239,11 @@ class RAG(EmbeddingUsage):
231
239
 
232
240
  if isinstance(desired_fields, list):
233
241
  return resp
234
- return [r.get(desired_fields) for r in resp]
242
+ return [r.get(desired_fields) for r in resp] # extract the single field as list
235
243
 
236
244
  async def aretrieve(
237
245
  self,
238
246
  query: List[str] | str,
239
- collection_name: Optional[str] = None,
240
247
  final_limit: int = 20,
241
248
  **kwargs: Unpack[FetchKwargs],
242
249
  ) -> List[str]:
@@ -244,7 +251,6 @@ class RAG(EmbeddingUsage):
244
251
 
245
252
  Args:
246
253
  query (List[str] | str): The query to be used for retrieval.
247
- collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
248
254
  final_limit (int): The final limit on the number of results to return.
249
255
  **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
250
256
 
@@ -253,19 +259,19 @@ class RAG(EmbeddingUsage):
253
259
  """
254
260
  if isinstance(query, str):
255
261
  query = [query]
256
- return (
262
+ return cast(
263
+ List[str],
257
264
  await self.afetch_document(
258
265
  vecs=(await self.vectorize(query)),
259
266
  desired_fields="text",
260
- collection_name=collection_name,
261
267
  **kwargs,
262
- )
268
+ ),
263
269
  )[:final_limit]
264
270
 
265
271
  async def aask_retrieved(
266
272
  self,
267
- question: str | List[str],
268
- query: List[str] | str,
273
+ question: str,
274
+ query: Optional[List[str] | str] = None,
269
275
  collection_name: Optional[str] = None,
270
276
  extra_system_message: str = "",
271
277
  result_per_query: int = 10,
@@ -279,7 +285,7 @@ class RAG(EmbeddingUsage):
279
285
  specified question using the retrieved documents as context.
280
286
 
281
287
  Args:
282
- question (str | List[str]): The question or list of questions to be asked.
288
+ question (str): The question to be asked.
283
289
  query (List[str] | str): The query or list of queries used for document retrieval.
284
290
  collection_name (Optional[str]): The name of the collection to retrieve documents from.
285
291
  If not provided, the currently viewed collection is used.
@@ -293,9 +299,9 @@ class RAG(EmbeddingUsage):
293
299
  str: A string response generated after asking with the context of retrieved documents.
294
300
  """
295
301
  docs = await self.aretrieve(
296
- query,
297
- collection_name,
302
+ query or question,
298
303
  final_limit,
304
+ collection_name=collection_name,
299
305
  result_per_query=result_per_query,
300
306
  similarity_threshold=similarity_threshold,
301
307
  )
@@ -308,3 +314,56 @@ class RAG(EmbeddingUsage):
308
314
  f"{rendered}\n\n{extra_system_message}",
309
315
  **kwargs,
310
316
  )
317
+
318
+ async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
319
+ """Refines the given question using a template.
320
+
321
+ Args:
322
+ question (List[str] | str): The question to be refined.
323
+ **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the refinement process.
324
+
325
+ Returns:
326
+ List[str]: A list of refined questions.
327
+ """
328
+ return await self.aliststr(
329
+ template_manager.render_template(
330
+ configs.templates.refined_query_template,
331
+ {"question": [question] if isinstance(question, str) else question},
332
+ ),
333
+ **kwargs,
334
+ )
335
+
336
+ async def aask_refined(
337
+ self,
338
+ question: str,
339
+ collection_name: Optional[str] = None,
340
+ extra_system_message: str = "",
341
+ result_per_query: int = 10,
342
+ final_limit: int = 20,
343
+ similarity_threshold: float = 0.37,
344
+ **kwargs: Unpack[LLMKwargs],
345
+ ) -> str:
346
+ """Asks a question using a refined query based on the provided question.
347
+
348
+ Args:
349
+ question (str): The question to be asked.
350
+ collection_name (Optional[str]): The name of the collection to retrieve documents from.
351
+ extra_system_message (str): An additional system message to be included in the prompt.
352
+ result_per_query (int): The number of results to return per query. Default is 10.
353
+ final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
354
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
355
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
356
+
357
+ Returns:
358
+ str: A string response generated after asking with the refined question.
359
+ """
360
+ return await self.aask_retrieved(
361
+ question,
362
+ await self.arefined_query(question, **kwargs),
363
+ collection_name=collection_name,
364
+ extra_system_message=extra_system_message,
365
+ result_per_query=result_per_query,
366
+ final_limit=final_limit,
367
+ similarity_threshold=similarity_threshold,
368
+ **kwargs,
369
+ )