fabricatio 0.2.5.dev4__cp312-cp312-manylinux_2_34_x86_64.whl → 0.2.6__cp312-cp312-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. fabricatio/__init__.py +11 -28
  2. fabricatio/_rust.cpython-312-x86_64-linux-gnu.so +0 -0
  3. fabricatio/_rust.pyi +14 -1
  4. fabricatio/_rust_instances.py +1 -1
  5. fabricatio/actions/article.py +52 -5
  6. fabricatio/actions/output.py +1 -3
  7. fabricatio/actions/rag.py +51 -5
  8. fabricatio/capabilities/correct.py +115 -0
  9. fabricatio/capabilities/propose.py +14 -20
  10. fabricatio/capabilities/rag.py +43 -25
  11. fabricatio/capabilities/rating.py +48 -43
  12. fabricatio/capabilities/review.py +61 -25
  13. fabricatio/capabilities/task.py +12 -13
  14. fabricatio/config.py +66 -26
  15. fabricatio/fs/__init__.py +2 -2
  16. fabricatio/fs/readers.py +2 -2
  17. fabricatio/journal.py +1 -7
  18. fabricatio/models/action.py +117 -45
  19. fabricatio/models/events.py +6 -4
  20. fabricatio/models/extra.py +575 -88
  21. fabricatio/models/generic.py +53 -10
  22. fabricatio/models/kwargs_types.py +96 -76
  23. fabricatio/models/role.py +32 -8
  24. fabricatio/models/task.py +2 -2
  25. fabricatio/models/tool.py +4 -4
  26. fabricatio/models/usages.py +180 -98
  27. fabricatio/models/utils.py +46 -0
  28. fabricatio/parser.py +38 -7
  29. fabricatio/workflows/articles.py +12 -1
  30. fabricatio-0.2.6.data/scripts/tdown +0 -0
  31. fabricatio-0.2.6.dist-info/METADATA +432 -0
  32. fabricatio-0.2.6.dist-info/RECORD +42 -0
  33. {fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.6.dist-info}/WHEEL +1 -1
  34. fabricatio-0.2.5.dev4.data/scripts/tdown +0 -0
  35. fabricatio-0.2.5.dev4.dist-info/METADATA +0 -311
  36. fabricatio-0.2.5.dev4.dist-info/RECORD +0 -41
  37. {fabricatio-0.2.5.dev4.dist-info → fabricatio-0.2.6.dist-info}/licenses/LICENSE +0 -0
fabricatio/__init__.py CHANGED
@@ -2,59 +2,42 @@
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
5
+ from fabricatio import actions, toolboxes, workflows
5
6
  from fabricatio._rust import BibManager
6
- from fabricatio._rust_instances import template_manager
7
- from fabricatio.actions.article import ExtractArticleEssence, GenerateArticleProposal, GenerateOutline
8
- from fabricatio.actions.output import DumpFinalizedOutput
7
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
9
8
  from fabricatio.core import env
10
- from fabricatio.fs import magika, safe_json_read, safe_text_read
11
9
  from fabricatio.journal import logger
10
+ from fabricatio.models import extra
12
11
  from fabricatio.models.action import Action, WorkFlow
13
12
  from fabricatio.models.events import Event
14
- from fabricatio.models.extra import ArticleEssence
15
13
  from fabricatio.models.role import Role
16
14
  from fabricatio.models.task import Task
17
15
  from fabricatio.models.tool import ToolBox
18
- from fabricatio.models.utils import Message, Messages
19
- from fabricatio.parser import Capture, CodeBlockCapture, JsonCapture, PythonCapture
20
- from fabricatio.toolboxes import arithmetic_toolbox, basic_toolboxes, fs_toolbox
21
- from fabricatio.workflows.articles import WriteOutlineWorkFlow
16
+ from fabricatio.parser import Capture, GenericCapture, JsonCapture, PythonCapture
22
17
 
23
18
  __all__ = [
19
+ "TEMPLATE_MANAGER",
24
20
  "Action",
25
- "ArticleEssence",
26
21
  "BibManager",
27
22
  "Capture",
28
- "CodeBlockCapture",
29
- "DumpFinalizedOutput",
30
23
  "Event",
31
- "ExtractArticleEssence",
32
- "GenerateArticleProposal",
33
- "GenerateOutline",
24
+ "GenericCapture",
34
25
  "JsonCapture",
35
- "Message",
36
- "Messages",
37
26
  "PythonCapture",
38
27
  "Role",
39
28
  "Task",
40
29
  "ToolBox",
41
30
  "WorkFlow",
42
- "WriteOutlineWorkFlow",
43
- "arithmetic_toolbox",
44
- "basic_toolboxes",
31
+ "actions",
45
32
  "env",
46
- "fs_toolbox",
33
+ "extra",
47
34
  "logger",
48
- "magika",
49
- "safe_json_read",
50
- "safe_text_read",
51
- "template_manager",
35
+ "toolboxes",
36
+ "workflows",
52
37
  ]
53
38
 
54
39
 
55
40
  if find_spec("pymilvus"):
56
- from fabricatio.actions.rag import InjectToDB
57
41
  from fabricatio.capabilities.rag import RAG
58
- from fabricatio.workflows.rag import StoreArticle
59
42
 
60
- __all__ += ["RAG", "InjectToDB", "StoreArticle"]
43
+ __all__ += ["RAG"]
fabricatio/_rust.pyi CHANGED
@@ -9,8 +9,9 @@ class TemplateManager:
9
9
 
10
10
  See: https://crates.io/crates/handlebars
11
11
  """
12
+
12
13
  def __init__(
13
- self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
14
+ self, template_dirs: List[Path], suffix: Optional[str] = None, active_loading: Optional[bool] = None
14
15
  ) -> None:
15
16
  """Initialize the template manager.
16
17
 
@@ -54,6 +55,7 @@ class TemplateManager:
54
55
  RuntimeError: If template rendering fails
55
56
  """
56
57
 
58
+
57
59
  def blake3_hash(content: bytes) -> str:
58
60
  """Calculate the BLAKE3 cryptographic hash of data.
59
61
 
@@ -64,6 +66,7 @@ def blake3_hash(content: bytes) -> str:
64
66
  Hex-encoded BLAKE3 hash string
65
67
  """
66
68
 
69
+
67
70
  class BibManager:
68
71
  """BibTeX bibliography manager for parsing and querying citation data."""
69
72
 
@@ -100,3 +103,13 @@ class BibManager:
100
103
  Uses nucleo_matcher for high-quality fuzzy text searching
101
104
  See: https://crates.io/crates/nucleo-matcher
102
105
  """
106
+
107
+ def list_titles(self, is_verbatim: Optional[bool] = False) -> List[str]:
108
+ """List all titles in the bibliography.
109
+
110
+ Args:
111
+ is_verbatim: Whether to return verbatim titles (without formatting)
112
+
113
+ Returns:
114
+ List of all titles in the bibliography
115
+ """
fabricatio/_rust_instances.py CHANGED
@@ -3,7 +3,7 @@
3
3
  from fabricatio._rust import TemplateManager
4
4
  from fabricatio.config import configs
5
5
 
6
- template_manager = TemplateManager(
6
+ TEMPLATE_MANAGER = TemplateManager(
7
7
  template_dirs=configs.templates.template_dir,
8
8
  suffix=configs.templates.template_suffix,
9
9
  active_loading=configs.templates.active_loading,
fabricatio/actions/article.py CHANGED
@@ -2,13 +2,15 @@
2
2
 
3
3
  from os import PathLike
4
4
  from pathlib import Path
5
- from typing import Callable, List
5
+ from typing import Any, Callable, List, Optional
6
6
 
7
7
  from fabricatio.fs import safe_text_read
8
8
  from fabricatio.journal import logger
9
9
  from fabricatio.models.action import Action
10
10
  from fabricatio.models.extra import ArticleEssence, ArticleOutline, ArticleProposal
11
11
  from fabricatio.models.task import Task
12
+ from questionary import confirm, text
13
+ from rich import print as rprint
12
14
 
13
15
 
14
16
  class ExtractArticleEssence(Action):
@@ -27,7 +29,7 @@ class ExtractArticleEssence(Action):
27
29
  task_input: Task,
28
30
  reader: Callable[[P], str] = lambda p: Path(p).read_text(encoding="utf-8"),
29
31
  **_,
30
- ) -> List[ArticleEssence]:
32
+ ) -> Optional[List[ArticleEssence]]:
31
33
  if not task_input.dependencies:
32
34
  logger.info(err := "Task not approved, since no dependencies are provided.")
33
35
  raise RuntimeError(err)
@@ -51,7 +53,7 @@ class GenerateArticleProposal(Action):
51
53
  self,
52
54
  task_input: Task,
53
55
  **_,
54
- ) -> ArticleProposal:
56
+ ) -> Optional[ArticleProposal]:
55
57
  input_path = await self.awhich_pathstr(
56
58
  f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
57
59
  )
@@ -66,16 +68,61 @@ class GenerateArticleProposal(Action):
66
68
  class GenerateOutline(Action):
67
69
  """Generate the article based on the outline."""
68
70
 
69
- output_key: str = "article"
71
+ output_key: str = "article_outline"
70
72
  """The key of the output data."""
71
73
 
72
74
  async def _execute(
73
75
  self,
74
76
  article_proposal: ArticleProposal,
75
77
  **_,
76
- ) -> ArticleOutline:
78
+ ) -> Optional[ArticleOutline]:
77
79
  return await self.propose(
78
80
  ArticleOutline,
79
81
  article_proposal.display(),
80
82
  system_message=f"# your personal briefing: \n{self.briefing}",
81
83
  )
84
+
85
+
86
+ class CorrectProposal(Action):
87
+ """Correct the proposal of the article."""
88
+
89
+ output_key: str = "corrected_proposal"
90
+
91
+ async def _execute(self, task_input: Task, article_proposal: ArticleProposal, **_) -> Any:
92
+ input_path = await self.awhich_pathstr(
93
+ f"{task_input.briefing}\nExtract the path of file, which contains the article briefing that I need to read."
94
+ )
95
+
96
+ ret = None
97
+ while await confirm("Do you want to correct the Proposal?").ask_async():
98
+ rprint(article_proposal.display())
99
+ while not (topic := await text("What is the topic of the proposal reviewing?").ask_async()):
100
+ ...
101
+ ret = await self.correct_obj(
102
+ article_proposal,
103
+ safe_text_read(input_path),
104
+ topic=topic,
105
+ )
106
+ return ret or article_proposal
107
+
108
+
109
+ class CorrectOutline(Action):
110
+ """Correct the outline of the article."""
111
+
112
+ output_key: str = "corrected_outline"
113
+ """The key of the output data."""
114
+
115
+ async def _execute(
116
+ self,
117
+ article_outline: ArticleOutline,
118
+ article_proposal: ArticleProposal,
119
+
120
+ **_,
121
+ ) -> Optional[str]:
122
+ ret = None
123
+ while await confirm("Do you want to correct the outline?").ask_async():
124
+ rprint(article_outline.finalized_dump())
125
+ while not (topic := await text("What is the topic of the outline reviewing?").ask_async()):
126
+ ...
127
+ ret = await self.correct_obj(article_outline, article_proposal.display(), topic=topic)
128
+ return ret or article_outline
fabricatio/actions/output.py CHANGED
@@ -1,7 +1,5 @@
1
1
  """Dump the finalized output to a file."""
2
2
 
3
- from typing import Unpack
4
-
5
3
  from fabricatio.models.action import Action
6
4
  from fabricatio.models.generic import FinalizedDumpAble
7
5
  from fabricatio.models.task import Task
@@ -12,7 +10,7 @@ class DumpFinalizedOutput(Action):
12
10
 
13
11
  output_key: str = "dump_path"
14
12
 
15
- async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **cxt: Unpack) -> str:
13
+ async def _execute(self, task_input: Task, to_dump: FinalizedDumpAble, **_) -> str:
16
14
  dump_path = await self.awhich_pathstr(
17
15
  f"{task_input.briefing}\n\nExtract a single path of the file, to which I will dump the data."
18
16
  )
fabricatio/actions/rag.py CHANGED
@@ -1,10 +1,13 @@
1
1
  """Inject data into the database."""
2
2
 
3
- from typing import List, Optional, Unpack
3
+ from typing import List, Optional
4
4
 
5
5
  from fabricatio.capabilities.rag import RAG
6
+ from fabricatio.journal import logger
6
7
  from fabricatio.models.action import Action
7
8
  from fabricatio.models.generic import PrepareVectorization
9
+ from fabricatio.models.task import Task
10
+ from questionary import text
8
11
 
9
12
 
10
13
  class InjectToDB(Action, RAG):
@@ -13,13 +16,56 @@ class InjectToDB(Action, RAG):
13
16
  output_key: str = "collection_name"
14
17
 
15
18
  async def _execute[T: PrepareVectorization](
16
- self, to_inject: T | List[T], collection_name: Optional[str] = "my_collection", **cxt: Unpack
17
- ) -> str:
19
+ self, to_inject: Optional[T] | List[Optional[T]], collection_name: Optional[str] = "my_collection", **_
20
+ ) -> Optional[str]:
18
21
  if not isinstance(to_inject, list):
19
22
  to_inject = [to_inject]
20
-
23
+ logger.info(f"Injecting {len(to_inject)} items into the collection '{collection_name}'")
21
24
  await self.view(collection_name, create=True).consume_string(
22
- [t.prepare_vectorization(self.embedding_max_sequence_length) for t in to_inject],
25
+ [
26
+ t.prepare_vectorization(self.embedding_max_sequence_length)
27
+ for t in to_inject
28
+ if isinstance(t, PrepareVectorization)
29
+ ],
23
30
  )
24
31
 
25
32
  return collection_name
33
+
34
+
35
+ class RAGTalk(Action, RAG):
36
+ """RAG-enabled conversational action that processes user questions based on a given task.
37
+
38
+ This action establishes an interactive conversation loop where it retrieves context-relevant
39
+ information to answer user queries according to the assigned task briefing.
40
+
41
+ Notes:
42
+ task_input: Task briefing that guides how to respond to user questions
43
+ collection_name: Name of the vector collection to use for retrieval (default: "my_collection")
44
+
45
+ Returns:
46
+ Number of conversation turns completed before termination
47
+ """
48
+
49
+ output_key: str = "task_output"
50
+
51
+ async def _execute(self, task_input: Task[str], **kwargs) -> int:
52
+ collection_name = kwargs.get("collection_name", "my_collection")
53
+ counter = 0
54
+
55
+ self.view(collection_name, create=True)
56
+
57
+ try:
58
+ while True:
59
+ user_say = await text("User: ").ask_async()
60
+ if user_say is None:
61
+ break
62
+ gpt_say = await self.aask_retrieved(
63
+ user_say,
64
+ user_say,
65
+ extra_system_message=f"You have to answer to user obeying task assigned to you:\n{task_input.briefing}",
66
+ )
67
+ print(f"GPT: {gpt_say}") # noqa: T201
68
+ counter += 1
69
+ except KeyboardInterrupt:
70
+ logger.info(f"executed talk action {counter} times")
71
+ return counter
fabricatio/capabilities/correct.py ADDED
@@ -0,0 +1,115 @@
1
+ """Correct capability module providing advanced review and validation functionality.
2
+
3
+ This module implements the Correct capability, which extends the Review functionality
4
+ to provide mechanisms for reviewing, validating, and correcting various objects and tasks
5
+ based on predefined criteria and templates.
6
+ """
7
+
8
+ from typing import Optional, Unpack, cast
9
+
10
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
11
+ from fabricatio.capabilities.review import Review, ReviewResult
12
+ from fabricatio.config import configs
13
+ from fabricatio.models.generic import Display, ProposedAble, WithBriefing
14
+ from fabricatio.models.kwargs_types import CorrectKwargs, ReviewKwargs
15
+ from fabricatio.models.task import Task
16
+
17
+
18
+ class Correct(Review):
19
+ """Correct capability for reviewing, validating, and improving objects.
20
+
21
+ This class enhances the Review capability with specialized functionality for
22
+ correcting and improving objects based on review feedback. It can process
23
+ various inputs including tasks, strings, and generic objects that implement
24
+ the required interfaces, applying corrections based on templated review processes.
25
+ """
26
+
27
+ async def correct_obj[M: ProposedAble](
28
+ self,
29
+ obj: M,
30
+ reference: str = "",
31
+ supervisor_check: bool = True,
32
+ **kwargs: Unpack[ReviewKwargs[ReviewResult[str]]],
33
+ ) -> Optional[M]:
34
+ """Review and correct an object based on defined criteria and templates.
35
+
36
+ This method first conducts a review of the given object, then uses the review results
37
+ to generate a corrected version of the object using appropriate templates.
38
+
39
+ Args:
40
+ obj (M): The object to be reviewed and corrected. Must implement ProposedAble.
41
+ reference (str): A reference or contextual information for the object.
42
+ supervisor_check (bool, optional): Whether to perform a supervisor check on the review results. Defaults to True.
43
+ **kwargs: Review configuration parameters including criteria and review options.
44
+
45
+ Returns:
46
+ Optional[M]: A corrected version of the input object, or None if correction fails.
47
+
48
+ Raises:
49
+ TypeError: If the provided object doesn't implement Display or WithBriefing interfaces.
50
+ """
51
+ if not isinstance(obj, (Display, WithBriefing)):
52
+ raise TypeError(f"Expected Display or WithBriefing, got {type(obj)}")
53
+
54
+ review_res = await self.review_obj(obj, **kwargs)
55
+ if supervisor_check:
56
+ await review_res.supervisor_check()
57
+ if "default" in kwargs:
58
+ cast(ReviewKwargs[None], kwargs)["default"] = None
59
+ return await self.propose(
60
+ obj.__class__,
61
+ TEMPLATE_MANAGER.render_template(
62
+ configs.templates.correct_template,
63
+ {
64
+ "content": f"{(reference + '\n\nAbove is referencing material') if reference else ''}{obj.display() if isinstance(obj, Display) else obj.briefing}",
65
+ "review": review_res.display(),
66
+ },
67
+ ),
68
+ **kwargs,
69
+ )
70
+
71
+ async def correct_string(
72
+ self, input_text: str, supervisor_check: bool = True, **kwargs: Unpack[ReviewKwargs[ReviewResult[str]]]
73
+ ) -> Optional[str]:
74
+ """Review and correct a string based on defined criteria and templates.
75
+
76
+ This method applies the review process to the input text and generates
77
+ a corrected version based on the review results.
78
+
79
+ Args:
80
+ input_text (str): The text content to be reviewed and corrected.
81
+ supervisor_check (bool, optional): Whether to perform a supervisor check on the review results. Defaults to True.
82
+ **kwargs: Review configuration parameters including criteria and review options.
83
+
84
+ Returns:
85
+ Optional[str]: The corrected text content, or None if correction fails.
86
+ """
87
+ review_res = await self.review_string(input_text, **kwargs)
88
+ if supervisor_check:
89
+ await review_res.supervisor_check()
90
+
91
+ if "default" in kwargs:
92
+ cast(ReviewKwargs[None], kwargs)["default"] = None
93
+ return await self.ageneric_string(
94
+ TEMPLATE_MANAGER.render_template(
95
+ configs.templates.correct_template, {"content": input_text, "review": review_res.display()}
96
+ ),
97
+ **kwargs,
98
+ )
99
+
100
+ async def correct_task[T](
101
+ self, task: Task[T], **kwargs: Unpack[CorrectKwargs[ReviewResult[str]]]
102
+ ) -> Optional[Task[T]]:
103
+ """Review and correct a task object based on defined criteria.
104
+
105
+ This is a specialized version of correct_obj specifically for Task objects,
106
+ applying the same review and correction process to task definitions.
107
+
108
+ Args:
109
+ task (Task[T]): The task to be reviewed and corrected.
110
+ **kwargs: Review configuration parameters including criteria and review options.
111
+
112
+ Returns:
113
+ Optional[Task[T]]: The corrected task, or None if correction fails.
114
+ """
115
+ return await self.correct_obj(task, **kwargs)
fabricatio/capabilities/propose.py CHANGED
@@ -1,37 +1,37 @@
1
1
  """A module for the task capabilities of the Fabricatio library."""
2
2
 
3
- from typing import List, Type, Unpack, overload
3
+ from typing import List, Optional, Type, Unpack, overload
4
4
 
5
5
  from fabricatio.models.generic import ProposedAble
6
- from fabricatio.models.kwargs_types import GenerateKwargs
6
+ from fabricatio.models.kwargs_types import ValidateKwargs
7
7
  from fabricatio.models.usages import LLMUsage
8
8
 
9
9
 
10
- class Propose[M: ProposedAble](LLMUsage):
10
+ class Propose(LLMUsage):
11
11
  """A class that proposes an Obj based on a prompt."""
12
12
 
13
13
  @overload
14
- async def propose(
14
+ async def propose[M: ProposedAble](
15
15
  self,
16
16
  cls: Type[M],
17
17
  prompt: List[str],
18
- **kwargs: Unpack[GenerateKwargs[M]],
19
- ) -> List[M]: ...
18
+ **kwargs: Unpack[ValidateKwargs[M]],
19
+ ) -> Optional[List[M]]: ...
20
20
 
21
21
  @overload
22
- async def propose(
22
+ async def propose[M: ProposedAble](
23
23
  self,
24
24
  cls: Type[M],
25
25
  prompt: str,
26
- **kwargs: Unpack[GenerateKwargs[M]],
27
- ) -> M: ...
26
+ **kwargs: Unpack[ValidateKwargs[M]],
27
+ ) -> Optional[M]: ...
28
28
 
29
- async def propose(
29
+ async def propose[M: ProposedAble](
30
30
  self,
31
31
  cls: Type[M],
32
32
  prompt: List[str] | str,
33
- **kwargs: Unpack[GenerateKwargs[M]],
34
- ) -> List[M] | M:
33
+ **kwargs: Unpack[ValidateKwargs[M]],
34
+ ) -> Optional[List[M] | M]:
35
35
  """Asynchronously proposes a task based on a given prompt and parameters.
36
36
 
37
37
  Parameters:
@@ -42,14 +42,8 @@ class Propose[M: ProposedAble](LLMUsage):
42
42
  Returns:
43
43
  A Task object based on the proposal result.
44
44
  """
45
- if isinstance(prompt, str):
46
- return await self.aask_validate(
47
- question=cls.create_json_prompt(prompt),
48
- validator=cls.instantiate_from_string,
49
- **kwargs,
50
- )
51
- return await self.aask_validate_batch(
52
- questions=[cls.create_json_prompt(p) for p in prompt],
45
+ return await self.aask_validate(
46
+ question=cls.create_json_prompt(prompt),
53
47
  validator=cls.instantiate_from_string,
54
48
  **kwargs,
55
49
  )
fabricatio/capabilities/rag.py CHANGED
@@ -8,20 +8,20 @@ from functools import lru_cache
8
8
  from operator import itemgetter
9
9
  from os import PathLike
10
10
  from pathlib import Path
11
- from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, overload
11
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
12
12
 
13
- from fabricatio._rust_instances import template_manager
13
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
14
14
  from fabricatio.config import configs
15
15
  from fabricatio.journal import logger
16
16
  from fabricatio.models.kwargs_types import (
17
17
  ChooseKwargs,
18
- CollectionSimpleConfigKwargs,
18
+ CollectionConfigKwargs,
19
19
  EmbeddingKwargs,
20
20
  FetchKwargs,
21
21
  LLMKwargs,
22
22
  )
23
23
  from fabricatio.models.usages import EmbeddingUsage
24
- from fabricatio.models.utils import MilvusData
24
+ from fabricatio.models.utils import MilvusData, ok
25
25
  from more_itertools.recipes import flatten, unique
26
26
  from pydantic import Field, PrivateAttr
27
27
 
@@ -60,13 +60,21 @@ class RAG(EmbeddingUsage):
60
60
  ) -> Self:
61
61
  """Initialize the Milvus client."""
62
62
  self._client = create_client(
63
- uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
63
+ uri=milvus_uri or ok(self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
64
64
  token=milvus_token
65
65
  or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
66
- timeout=milvus_timeout or self.milvus_timeout,
66
+ timeout=milvus_timeout or self.milvus_timeout or configs.rag.milvus_timeout,
67
67
  )
68
68
  return self
69
69
 
70
+ def check_client(self, init: bool = True) -> Self:
71
+ """Check if the client is initialized, and if not, initialize it."""
72
+ if self._client is None and init:
73
+ return self.init_client()
74
+ if self._client is None and not init:
75
+ raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
76
+ return self
77
+
70
78
  @overload
71
79
  async def pack(
72
80
  self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
@@ -102,18 +110,25 @@ class RAG(EmbeddingUsage):
102
110
  ]
103
111
 
104
112
  def view(
105
- self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
113
+ self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionConfigKwargs]
106
114
  ) -> Self:
107
115
  """View the specified collection.
108
116
 
109
117
  Args:
110
118
  collection_name (str): The name of the collection.
111
119
  create (bool): Whether to create the collection if it does not exist.
112
- **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
120
+ **kwargs (Unpack[CollectionConfigKwargs]): Additional keyword arguments for collection configuration.
113
121
  """
114
- if create and collection_name and not self._client.has_collection(collection_name):
115
- kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
116
- self._client.create_collection(collection_name, auto_id=True, **kwargs)
122
+ if create and collection_name and not self.check_client().client.has_collection(collection_name):
123
+ kwargs["dimension"] = ok(
124
+ kwargs.get("dimension")
125
+ or self.milvus_dimensions
126
+ or configs.rag.milvus_dimensions
127
+ or self.embedding_dimensions
128
+ or configs.embedding.dimensions,
129
+ "`dimension` is not set at any level.",
130
+ )
131
+ self.client.create_collection(collection_name, auto_id=True, **kwargs)
117
132
  logger.info(f"Creating collection {collection_name}")
118
133
 
119
134
  self.target_collection = collection_name
@@ -152,15 +167,17 @@ class RAG(EmbeddingUsage):
152
167
  Self: The current instance, allowing for method chaining.
153
168
  """
154
169
  if isinstance(data, MilvusData):
155
- data = data.prepare_insertion()
156
- if isinstance(data, list):
157
- data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
170
+ prepared_data = data.prepare_insertion()
171
+ elif isinstance(data, list):
172
+ prepared_data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
173
+ else:
174
+ raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
158
175
  c_name = collection_name or self.safe_target_collection
159
- self._client.insert(c_name, data)
176
+ self.check_client().client.insert(c_name, prepared_data)
160
177
 
161
178
  if flush:
162
179
  logger.debug(f"Flushing collection {c_name}")
163
- self._client.flush(c_name)
180
+ self.client.flush(c_name)
164
181
  return self
165
182
 
166
183
  async def consume_file(
@@ -196,14 +213,14 @@ class RAG(EmbeddingUsage):
196
213
  self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
197
214
  return self
198
215
 
199
- async def afetch_document(
216
+ async def afetch_document[V: (int, str, float, bytes)](
200
217
  self,
201
218
  vecs: List[List[float]],
202
219
  desired_fields: List[str] | str,
203
220
  collection_name: Optional[str] = None,
204
221
  similarity_threshold: float = 0.37,
205
222
  result_per_query: int = 10,
206
- ) -> List[Dict[str, Any]] | List[Any]:
223
+ ) -> List[Dict[str, Any]] | List[V]:
207
224
  """Fetch data from the collection.
208
225
 
209
226
  Args:
@@ -217,7 +234,7 @@ class RAG(EmbeddingUsage):
217
234
  List[Dict[str, Any]] | List[Any]: The retrieved data.
218
235
  """
219
236
  # Step 1: Search for vectors
220
- search_results = self._client.search(
237
+ search_results = self.check_client().client.search(
221
238
  collection_name or self.safe_target_collection,
222
239
  vecs,
223
240
  search_params={"radius": similarity_threshold},
@@ -237,7 +254,7 @@ class RAG(EmbeddingUsage):
237
254
 
238
255
  if isinstance(desired_fields, list):
239
256
  return resp
240
- return [r.get(desired_fields) for r in resp]
257
+ return [r.get(desired_fields) for r in resp] # extract the single field as list
241
258
 
242
259
  async def aretrieve(
243
260
  self,
@@ -257,12 +274,13 @@ class RAG(EmbeddingUsage):
257
274
  """
258
275
  if isinstance(query, str):
259
276
  query = [query]
260
- return (
277
+ return cast(
278
+ List[str],
261
279
  await self.afetch_document(
262
280
  vecs=(await self.vectorize(query)),
263
281
  desired_fields="text",
264
282
  **kwargs,
265
- )
283
+ ),
266
284
  )[:final_limit]
267
285
 
268
286
  async def aask_retrieved(
@@ -303,7 +321,7 @@ class RAG(EmbeddingUsage):
303
321
  similarity_threshold=similarity_threshold,
304
322
  )
305
323
 
306
- rendered = template_manager.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
324
+ rendered = TEMPLATE_MANAGER.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
307
325
 
308
326
  logger.debug(f"Retrieved Documents: \n{rendered}")
309
327
  return await self.aask(
@@ -312,7 +330,7 @@ class RAG(EmbeddingUsage):
312
330
  **kwargs,
313
331
  )
314
332
 
315
- async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
333
+ async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> Optional[List[str]]:
316
334
  """Refines the given question using a template.
317
335
 
318
336
  Args:
@@ -323,7 +341,7 @@ class RAG(EmbeddingUsage):
323
341
  List[str]: A list of refined questions.
324
342
  """
325
343
  return await self.aliststr(
326
- template_manager.render_template(
344
+ TEMPLATE_MANAGER.render_template(
327
345
  configs.templates.refined_query_template,
328
346
  {"question": [question] if isinstance(question, str) else question},
329
347
  ),