fabricatio 0.2.3.dev0__cp312-cp312-win_amd64.whl → 0.2.3.dev2__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
@@ -0,0 +1,179 @@
1
+ """A module for the RAG (Retrieval Augmented Generation) model."""
2
+
3
+ from operator import itemgetter
4
+ from os import PathLike
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Dict, List, Optional, Self, Union
7
+
8
+ from fabricatio.config import configs
9
+ from fabricatio.models.usages import LLMUsage
10
+ from fabricatio.models.utils import MilvusData
11
+ from more_itertools.recipes import flatten
12
+
13
+ try:
14
+ from pymilvus import MilvusClient
15
+ except ImportError as e:
16
+ raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
17
+ from pydantic import PrivateAttr
18
+
19
+
20
+ class Rag(LLMUsage):
21
+ """A class representing the RAG (Retrieval Augmented Generation) model."""
22
+
23
+ _client: MilvusClient = PrivateAttr(
24
+ default=MilvusClient(
25
+ uri=configs.rag.milvus_uri.unicode_string(),
26
+ token=configs.rag.milvus_token.get_secret_value(),
27
+ timeout=configs.rag.milvus_timeout,
28
+ ),
29
+ )
30
+ _target_collection: Optional[str] = PrivateAttr(default=None)
31
+
32
+ @property
33
+ def client(self) -> MilvusClient:
34
+ """The Milvus client."""
35
+ return self._client
36
+
37
+ def view(self, collection_name: str, create: bool = False) -> Self:
38
+ """View the specified collection.
39
+
40
+ Args:
41
+ collection_name (str): The name of the collection.
42
+ create (bool): Whether to create the collection if it does not exist.
43
+ """
44
+ if create and self._client.has_collection(collection_name):
45
+ self._client.create_collection(collection_name)
46
+
47
+ self._target_collection = collection_name
48
+ return self
49
+
50
+ def quit_view(self) -> Self:
51
+ """Quit the current view.
52
+
53
+ Returns:
54
+ Self: The current instance, allowing for method chaining.
55
+ """
56
+ self._target_collection = None
57
+ return self
58
+
59
+ @property
60
+ def viewing_collection(self) -> Optional[str]:
61
+ """Get the name of the collection being viewed.
62
+
63
+ Returns:
64
+ Optional[str]: The name of the collection being viewed.
65
+ """
66
+ return self._target_collection
67
+
68
+ @property
69
+ def safe_viewing_collection(self) -> str:
70
+ """Get the name of the collection being viewed, raise an error if not viewing any collection.
71
+
72
+ Returns:
73
+ str: The name of the collection being viewed.
74
+ """
75
+ if self._target_collection is None:
76
+ raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
77
+ return self._target_collection
78
+
79
+ def add_document[D: Union[Dict[str, Any], MilvusData]](
80
+ self, data: D | List[D], collection_name: Optional[str] = None
81
+ ) -> Self:
82
+ """Adds a document to the specified collection.
83
+
84
+ Args:
85
+ data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
86
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
87
+
88
+ Returns:
89
+ Self: The current instance, allowing for method chaining.
90
+ """
91
+ if isinstance(data, MilvusData):
92
+ data = data.prepare_insertion()
93
+ if isinstance(data, list):
94
+ data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
95
+ self._client.insert(collection_name or self.safe_viewing_collection, data)
96
+ return self
97
+
98
+ def consume(
99
+ self, source: PathLike, reader: Callable[[PathLike], MilvusData], collection_name: Optional[str] = None
100
+ ) -> Self:
101
+ """Consume a file and add its content to the collection.
102
+
103
+ Args:
104
+ source (PathLike): The path to the file to be consumed.
105
+ reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
106
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
107
+
108
+ Returns:
109
+ Self: The current instance, allowing for method chaining.
110
+ """
111
+ data = reader(Path(source))
112
+ self.add_document(data, collection_name or self.safe_viewing_collection)
113
+ return self
114
+
115
+ async def afetch_document(
116
+ self,
117
+ vecs: List[List[float]],
118
+ desired_fields: List[str] | str,
119
+ collection_name: Optional[str] = None,
120
+ result_per_query: int = 10,
121
+ ) -> List[Dict[str, Any]] | List[Any]:
122
+ """Fetch data from the collection.
123
+
124
+ Args:
125
+ vecs (List[List[float]]): The vectors to search for.
126
+ desired_fields (List[str] | str): The fields to retrieve.
127
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
128
+ result_per_query (int): The number of results to return per query.
129
+
130
+ Returns:
131
+ List[Dict[str, Any]] | List[Any]: The retrieved data.
132
+ """
133
+ # Step 1: Search for vectors
134
+ search_results = self._client.search(
135
+ collection_name or self.safe_viewing_collection,
136
+ vecs,
137
+ output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
138
+ limit=result_per_query,
139
+ )
140
+
141
+ # Step 2: Flatten the search results
142
+ flattened_results = flatten(search_results)
143
+
144
+ # Step 3: Sort by distance (descending)
145
+ sorted_results = sorted(flattened_results, key=itemgetter("distance"), reverse=True)
146
+
147
+ # Step 4: Extract the entities
148
+ resp = [result["entity"] for result in sorted_results]
149
+
150
+ if isinstance(desired_fields, list):
151
+ return resp
152
+ return [r.get(desired_fields) for r in resp]
153
+
154
+ async def aretrieve(
155
+ self,
156
+ query: List[str] | str,
157
+ collection_name: Optional[str] = None,
158
+ result_per_query: int = 10,
159
+ final_limit: int = 20,
160
+ ) -> List[str]:
161
+ """Retrieve data from the collection.
162
+
163
+ Args:
164
+ query (List[str] | str): The query to be used for retrieval.
165
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
166
+ result_per_query (int): The number of results to be returned per query.
167
+ final_limit (int): The final limit on the number of results to return.
168
+
169
+ Returns:
170
+ List[str]: A list of strings containing the retrieved data.
171
+ """
172
+ if isinstance(query, str):
173
+ query = [query]
174
+ return await self.afetch_document(
175
+ vecs=(await self.vectorize(query)),
176
+ desired_fields="text",
177
+ collection_name=collection_name,
178
+ result_per_query=result_per_query,
179
+ )[:final_limit]
fabricatio/config.py CHANGED
@@ -11,6 +11,7 @@ from pydantic import (
11
11
  FilePath,
12
12
  HttpUrl,
13
13
  NonNegativeFloat,
14
+ PositiveFloat,
14
15
  PositiveInt,
15
16
  SecretStr,
16
17
  )
@@ -207,6 +208,19 @@ class ToolBoxConfig(BaseModel):
207
208
  """The name of the module containing the data."""
208
209
 
209
210
 
211
class RagConfig(BaseModel):
    """RAG configuration class."""

    # Attribute docstrings below double as the pydantic field descriptions.
    model_config = ConfigDict(use_attribute_docstrings=True)

    # Connection endpoint; defaults to a Milvus server on localhost.
    milvus_uri: HttpUrl = HttpUrl("http://localhost:19530")
    """The URI of the Milvus server."""

    # None leaves the timeout unset.
    milvus_timeout: Optional[PositiveFloat] = None
    """The timeout of the Milvus server."""

    # Kept as a SecretStr; None means no token is configured.
    milvus_token: Optional[SecretStr] = None
    """The token of the Milvus server."""
222
+
223
+
210
224
  class Settings(BaseSettings):
211
225
  """Application settings class.
212
226
 
@@ -250,6 +264,9 @@ class Settings(BaseSettings):
250
264
  toolbox: ToolBoxConfig = Field(default_factory=ToolBoxConfig)
251
265
  """Toolbox Configuration"""
252
266
 
267
+ rag: RagConfig = Field(default_factory=RagConfig)
268
+ """RAG Configuration"""
269
+
253
270
  @classmethod
254
271
  def settings_customise_sources(
255
272
  cls,
@@ -5,6 +5,15 @@ from typing import List, NotRequired, TypedDict
5
5
  from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
6
6
 
7
7
 
8
class EmbeddingKwargs(TypedDict):
    """Optional keyword arguments accepted by the embedding methods.

    Every key is optional; each one matches a keyword parameter of
    `LLMUsage.aembedding`.
    """

    # Name of the embedding model to use.
    model: NotRequired[str]
    # Requested size of the embedding vectors.
    dimensions: NotRequired[int]
    # Timeout for the embedding request.
    timeout: NotRequired[PositiveInt]
    # Whether the embedding result should be cached.
    caching: NotRequired[bool]
15
+
16
+
8
17
  class LLMKwargs(TypedDict):
9
18
  """A type representing the keyword arguments for the LLM (Large Language Model) usage."""
10
19
 
@@ -10,14 +10,15 @@ from fabricatio._rust_instances import template_manager
10
10
  from fabricatio.config import configs
11
11
  from fabricatio.journal import logger
12
12
  from fabricatio.models.generic import Base, WithBriefing
13
- from fabricatio.models.kwargs_types import ChooseKwargs, GenerateKwargs, LLMKwargs
13
+ from fabricatio.models.kwargs_types import ChooseKwargs, EmbeddingKwargs, GenerateKwargs, LLMKwargs
14
14
  from fabricatio.models.task import Task
15
15
  from fabricatio.models.tool import Tool, ToolBox
16
- from fabricatio.models.utils import Messages
16
+ from fabricatio.models.utils import Messages, MilvusData
17
17
  from fabricatio.parser import JsonCapture
18
18
  from litellm import stream_chunk_builder
19
19
  from litellm.types.utils import (
20
20
  Choices,
21
+ EmbeddingResponse,
21
22
  ModelResponse,
22
23
  StreamingChoices,
23
24
  )
@@ -61,6 +62,97 @@ class LLMUsage(Base):
61
62
  llm_max_tokens: Optional[PositiveInt] = None
62
63
  """The maximum number of tokens to generate."""
63
64
 
65
async def aembedding(
    self,
    input_text: List[str],
    model: Optional[str] = None,
    dimensions: Optional[int] = None,
    timeout: Optional[PositiveInt] = None,
    caching: Optional[bool] = False,
) -> EmbeddingResponse:
    """Request embeddings for a batch of texts through litellm.

    Each setting is resolved in order: explicit argument, then the instance
    attribute, then the global configuration.

    Args:
        input_text (List[str]): The texts to embed.
        model (Optional[str]): The embedding model. Falls back to `llm_model`, then `configs.llm.model`.
        dimensions (Optional[int]): The dimensions of the embedding. Defaults to None.
        timeout (Optional[PositiveInt]): The request timeout. Falls back to `llm_timeout`, then `configs.llm.timeout`.
        caching (Optional[bool]): Whether to cache the embedding result. Defaults to False.

    Returns:
        EmbeddingResponse: The response containing the embeddings.
    """
    # Credentials: prefer the instance-level key over the global one.
    if self.llm_api_key:
        api_key = self.llm_api_key.get_secret_value()
    else:
        api_key = configs.llm.api_key.get_secret_value()

    endpoint = self.llm_api_endpoint if self.llm_api_endpoint else configs.llm.api_endpoint
    # seems embedding function takes no base_url end with a slash
    api_base = endpoint.unicode_string().rstrip("/")

    return await litellm.aembedding(
        input=input_text,
        caching=caching,
        dimensions=dimensions,
        model=model or self.llm_model or configs.llm.model,
        timeout=timeout or self.llm_timeout or configs.llm.timeout,
        api_key=api_key,
        api_base=api_base,
    )
99
+
100
@overload
async def vectorize(self, input_text: List[str], **kwargs: Unpack[EmbeddingKwargs]) -> List[List[float]]: ...
@overload
async def vectorize(self, input_text: str, **kwargs: Unpack[EmbeddingKwargs]) -> List[float]: ...

async def vectorize(
    self, input_text: List[str] | str, **kwargs: Unpack[EmbeddingKwargs]
) -> List[List[float]] | List[float]:
    """Embed one text or a list of texts and return the raw vectors.

    Args:
        input_text (List[str] | str): A string or list of strings to generate embeddings for.
        **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments forwarded to `aembedding`.

    Returns:
        List[List[float]] | List[float]: A single vector for a string input,
            or one vector per item for a list input.
    """
    is_single = isinstance(input_text, str)
    # `aembedding` always takes a list, so a lone string is wrapped first.
    batch = [input_text] if is_single else input_text
    response = await self.aembedding(batch, **kwargs)

    if is_single:
        return response.data[0].get("embedding")
    return [item.get("embedding") for item in response.data]
121
+
122
@overload
async def pack(
    self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> List[MilvusData]: ...
@overload
async def pack(
    self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> MilvusData: ...

async def pack(
    self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
) -> List[MilvusData] | MilvusData:
    """Embed the given text(s) and wrap each result in a MilvusData object.

    Args:
        input_text (List[str] | str): A string or list of strings to generate embeddings for.
        subject (Optional[str]): The subject attached to every produced object. Defaults to None.
        **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.

    Returns:
        List[MilvusData] | MilvusData: One object for a string input, or one
            object per item for a list input.
    """
    embedded = await self.vectorize(input_text, **kwargs)

    # A lone string yields a single vector and therefore a single object.
    if isinstance(input_text, str):
        return MilvusData(vector=embedded, text=input_text, subject=subject)

    # strict=True raises if the number of vectors differs from the number of texts.
    return [
        MilvusData(vector=vector, text=text, subject=subject)
        for text, vector in zip(input_text, embedded, strict=True)
    ]
155
+
64
156
  async def aquery(
65
157
  self,
66
158
  messages: List[Dict[str, str]],
@@ -1,6 +1,6 @@
1
1
  """A module containing utility classes for the models."""
2
2
 
3
- from typing import Dict, List, Literal, Self
3
+ from typing import Any, Dict, List, Literal, Optional, Self
4
4
 
5
5
  from pydantic import BaseModel, ConfigDict, Field
6
6
 
@@ -76,3 +76,52 @@ class Messages(list):
76
76
  list[dict]: A list of dictionaries representing the messages.
77
77
  """
78
78
  return [message.model_dump() for message in self]
79
+
80
+
81
class MilvusData(BaseModel):
    """A class representing data stored in Milvus."""

    # Attribute docstrings below double as the pydantic field descriptions.
    model_config = ConfigDict(use_attribute_docstrings=True)

    id: Optional[int] = None
    """The identifier of the data."""

    vector: List[float]
    """The vector representation of the data."""

    text: str
    """The text representation of the data."""

    subject: Optional[str] = None
    """A subject label that we use to demo metadata filtering later."""

    def prepare_insertion(self) -> Dict[str, Any]:
        """Dump the model into a dict suitable for insertion into Milvus.

        Fields whose value is None (for example an unset `id` or `subject`)
        are left out of the result.

        Returns:
            dict: A dictionary containing the data to be inserted into Milvus.
        """
        payload = self.model_dump(exclude_none=True)
        return payload

    def update_subject(self, new_subject: str) -> Self:
        """Replace the subject label in place.

        Args:
            new_subject (str): The new subject label.

        Returns:
            Self: This same instance, to allow method chaining.
        """
        self.subject = new_subject
        return self

    def update_id(self, new_id: int) -> Self:
        """Replace the identifier in place.

        Args:
            new_id (int): The new identifier.

        Returns:
            Self: This same instance, to allow method chaining.
        """
        self.id = new_id
        return self
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fabricatio
3
- Version: 0.2.3.dev0
3
+ Version: 0.2.3.dev2
4
4
  Classifier: License :: OSI Approved :: MIT License
5
5
  Classifier: Programming Language :: Rust
6
6
  Classifier: Programming Language :: Python :: 3.12
@@ -1,12 +1,13 @@
1
- fabricatio-0.2.3.dev0.dist-info/METADATA,sha256=AbSeMTWK-eEnD_dWakSjEUsP-XE9mBHA-8_1JK3EYr4,12339
2
- fabricatio-0.2.3.dev0.dist-info/WHEEL,sha256=tpW5AN9B-9qsM9WW2FXG2r193YXiqexDadpKp0A2daI,96
3
- fabricatio-0.2.3.dev0.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
1
+ fabricatio-0.2.3.dev2.dist-info/METADATA,sha256=BlVLqYv59JMHHleZ0q386UhyragcQ3iZbMXonozLX3c,12339
2
+ fabricatio-0.2.3.dev2.dist-info/WHEEL,sha256=tpW5AN9B-9qsM9WW2FXG2r193YXiqexDadpKp0A2daI,96
3
+ fabricatio-0.2.3.dev2.dist-info/licenses/LICENSE,sha256=do7J7EiCGbq0QPbMAL_FqLYufXpHnCnXBOuqVPwSV8Y,1088
4
4
  fabricatio/actions/communication.py,sha256=NZxIIncKgJSDyBrqNebUtH_haqtxHa8ld2TZxT3CMdU,429
5
5
  fabricatio/actions/transmission.py,sha256=xpvKqbXqgpi1BWy-vUUvmd8NZ1GhRNfsYUBp-l2jLyk,862
6
6
  fabricatio/actions/__init__.py,sha256=eFmFVPQvtNgFynIXBVr3eP-vWQDWCPng60YY5LXvZgg,115
7
+ fabricatio/capabilities/rag.py,sha256=WeUn-7wppyISbmNpm4flogzUZrCnrJ5iQYNjepZpdww,6954
7
8
  fabricatio/capabilities/rating.py,sha256=zmTUvsUfxFgovRQzy4djL2zKRYTHmN6JY7A4lyT5uVQ,14907
8
9
  fabricatio/capabilities/task.py,sha256=d2xtrwQxXWI40UskQCR5YhHarY7ST0ppr8TjY12uWQE,5327
9
- fabricatio/config.py,sha256=wzaaUHZZMRCYc37M_M4qKuLOYtwdEjYtyG77-AGkqCg,11467
10
+ fabricatio/config.py,sha256=6_cpyGu87NmuPC2Kc0fq_cwOogH3fjXTILiDY7poWi4,12041
10
11
  fabricatio/core.py,sha256=VQ_JKgUGIy2gZ8xsTBZCdr_IP7wC5aPg0_bsOmjQ588,6458
11
12
  fabricatio/decorators.py,sha256=uzsP4tFKQNjDHBkofsjjoJA0IUAaYOtt6YVedoyOqlo,6551
12
13
  fabricatio/fs/curd.py,sha256=faMstgGUiQ4k2AW3OXfvvWWTldTtKXco7QINYaMjmyA,3981
@@ -16,12 +17,12 @@ fabricatio/journal.py,sha256=siqimKF0M_QaaOCMxtjr_BJVNyUIAQWILzE9Q4T6-7c,781
16
17
  fabricatio/models/action.py,sha256=NpklAVUHYO5JIY9YLwYowZ-U8R9CFf5aC10DhLF7gxQ,5924
17
18
  fabricatio/models/events.py,sha256=mrihNEFgQ5o7qFWja1z_qX8dnaTLwPBoJdVlzxQV5oM,2719
18
19
  fabricatio/models/generic.py,sha256=WEjZ96rTyBjaBjkM6e8E4Pg_Naot4xWRvGJteqBiCCI,5133
19
- fabricatio/models/kwargs_types.py,sha256=nTtD3wzSpCg-NlrJ43yW6lmfeWzD2V_XGMPlL5mXzyc,1147
20
+ fabricatio/models/kwargs_types.py,sha256=a-e7rdMZJi8xTBL_RLmTC9OPzI-Js7rlS689PR03VvA,1401
20
21
  fabricatio/models/role.py,sha256=gYvleTeKUGDUNKPAC5B0EPMLC4jZ4vHsFHmHiVXkU6c,1830
21
22
  fabricatio/models/task.py,sha256=ip6VeOV7vgXqhiQFOCjVl3hzc6lgdhfyvxbBuSz2-C0,11529
22
23
  fabricatio/models/tool.py,sha256=WTFnpF6xZ1nJbmIOonLsGQcM-kkDCeZiAFqyil9xg2U,6988
23
- fabricatio/models/usages.py,sha256=iLxas1gE7MA55ZtQJJ-qu3W6JP5KLjPgmNKqNYIF6yU,23972
24
- fabricatio/models/utils.py,sha256=i_kpcQpct04mQFk1nbcVGV-pl1YThWu4Qk3wbewzKkc,2535
24
+ fabricatio/models/usages.py,sha256=bzsTDrAekiQyIwKeWds5YdgsXk8qiZkD7OZdno1Q_Ck,28213
25
+ fabricatio/models/utils.py,sha256=mXea76bd4r2jy_zx74GM4t5kCvkMu0JTOaw_VGvTCxk,3952
25
26
  fabricatio/parser.py,sha256=uLabsvF07wRKW1PoTGuGEENCx3P4mhmuO8JkmOEkKko,3522
26
27
  fabricatio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
28
  fabricatio/toolboxes/arithmetic.py,sha256=WLqhY-Pikv11Y_0SGajwZx3WhsLNpHKf9drzAqOf_nY,1369
@@ -31,6 +32,6 @@ fabricatio/toolboxes/__init__.py,sha256=b13KmASO8q5fBLwew964fn9oH86ER5g-S1PgA4fZ
31
32
  fabricatio/_rust.pyi,sha256=0wCqtwWkVxxoqprvk8T27T8QYKIAKHS7xgsmdMNjQKc,1756
32
33
  fabricatio/_rust_instances.py,sha256=dl0-yZ4UvT5g20tQgnPJpmqtkjFGXNG_YK4eLfi_ugQ,279
33
34
  fabricatio/__init__.py,sha256=opIrN8lGyT-h2If4Qez0bRuWBa3uIT9GsM9CZy7_XJ0,1100
34
- fabricatio/_rust.cp312-win_amd64.pyd,sha256=MMgVw-oe5Yw0aBgcHHDX4FExLxDD0Xx1mnDmzDM7Wcc,1266176
35
- fabricatio-0.2.3.dev0.data/scripts/tdown.exe,sha256=AqyvPpwXz-hImUl93bB9nQO67AtojXG53SDN-Ltor8M,3397120
36
- fabricatio-0.2.3.dev0.dist-info/RECORD,,
35
+ fabricatio/_rust.cp312-win_amd64.pyd,sha256=wJkACEnysxtUfMBOwQ3851hSYNy1HOtY9vSIepLzEFg,1272320
36
+ fabricatio-0.2.3.dev2.data/scripts/tdown.exe,sha256=A9eIQbGxwDhpQ0gYqtEHA4Ms-LakyYkSkjiNnccvfYY,3398144
37
+ fabricatio-0.2.3.dev2.dist-info/RECORD,,