fabricatio 0.2.6.dev3__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. fabricatio/__init__.py +60 -0
  2. fabricatio/_rust.cp39-win_amd64.pyd +0 -0
  3. fabricatio/_rust.pyi +116 -0
  4. fabricatio/_rust_instances.py +10 -0
  5. fabricatio/actions/article.py +81 -0
  6. fabricatio/actions/output.py +19 -0
  7. fabricatio/actions/rag.py +25 -0
  8. fabricatio/capabilities/correct.py +115 -0
  9. fabricatio/capabilities/propose.py +49 -0
  10. fabricatio/capabilities/rag.py +369 -0
  11. fabricatio/capabilities/rating.py +339 -0
  12. fabricatio/capabilities/review.py +278 -0
  13. fabricatio/capabilities/task.py +113 -0
  14. fabricatio/config.py +400 -0
  15. fabricatio/core.py +181 -0
  16. fabricatio/decorators.py +179 -0
  17. fabricatio/fs/__init__.py +29 -0
  18. fabricatio/fs/curd.py +149 -0
  19. fabricatio/fs/readers.py +46 -0
  20. fabricatio/journal.py +21 -0
  21. fabricatio/models/action.py +158 -0
  22. fabricatio/models/events.py +120 -0
  23. fabricatio/models/extra.py +171 -0
  24. fabricatio/models/generic.py +406 -0
  25. fabricatio/models/kwargs_types.py +158 -0
  26. fabricatio/models/role.py +48 -0
  27. fabricatio/models/task.py +299 -0
  28. fabricatio/models/tool.py +189 -0
  29. fabricatio/models/usages.py +682 -0
  30. fabricatio/models/utils.py +167 -0
  31. fabricatio/parser.py +149 -0
  32. fabricatio/py.typed +0 -0
  33. fabricatio/toolboxes/__init__.py +15 -0
  34. fabricatio/toolboxes/arithmetic.py +62 -0
  35. fabricatio/toolboxes/fs.py +31 -0
  36. fabricatio/workflows/articles.py +15 -0
  37. fabricatio/workflows/rag.py +11 -0
  38. fabricatio-0.2.6.dev3.data/scripts/tdown.exe +0 -0
  39. fabricatio-0.2.6.dev3.dist-info/METADATA +432 -0
  40. fabricatio-0.2.6.dev3.dist-info/RECORD +42 -0
  41. fabricatio-0.2.6.dev3.dist-info/WHEEL +4 -0
  42. fabricatio-0.2.6.dev3.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,369 @@
1
+ """A module for the RAG (Retrieval Augmented Generation) model."""
2
+
3
+ try:
4
+ from pymilvus import MilvusClient
5
+ except ImportError as e:
6
+ raise RuntimeError("pymilvus is not installed. Have you installed `fabricatio[rag]` instead of `fabricatio`") from e
7
+ from functools import lru_cache
8
+ from operator import itemgetter
9
+ from os import PathLike
10
+ from pathlib import Path
11
+ from typing import Any, Callable, Dict, List, Optional, Self, Union, Unpack, cast, overload
12
+
13
+ from fabricatio._rust_instances import TEMPLATE_MANAGER
14
+ from fabricatio.config import configs
15
+ from fabricatio.journal import logger
16
+ from fabricatio.models.kwargs_types import (
17
+ ChooseKwargs,
18
+ CollectionSimpleConfigKwargs,
19
+ EmbeddingKwargs,
20
+ FetchKwargs,
21
+ LLMKwargs,
22
+ )
23
+ from fabricatio.models.usages import EmbeddingUsage
24
+ from fabricatio.models.utils import MilvusData
25
+ from more_itertools.recipes import flatten, unique
26
+ from pydantic import Field, PrivateAttr
27
+
28
+
29
@lru_cache(maxsize=None)
def create_client(uri: str, token: str = "", timeout: Optional[float] = None) -> MilvusClient:
    """Create a Milvus client.

    The factory is wrapped in an unbounded ``lru_cache``. Repeated calls with identical
    arguments, passed the same way, therefore return the same ``MilvusClient`` instance
    instead of constructing a new one.

    Args:
        uri (str): Address of the Milvus server.
        token (str): Authentication token; empty string when none is used.
        timeout (Optional[float]): Timeout forwarded to ``MilvusClient``.

    Returns:
        MilvusClient: The (possibly cached) client instance.
    """
    connection_params = {"uri": uri, "token": token, "timeout": timeout}
    return MilvusClient(**connection_params)
37
+
38
+
39
+ class RAG(EmbeddingUsage):
40
+ """A class representing the RAG (Retrieval Augmented Generation) model."""
41
+
42
+ target_collection: Optional[str] = Field(default=None)
43
+ """The name of the collection being viewed."""
44
+
45
+ _client: Optional[MilvusClient] = PrivateAttr(None)
46
+ """The Milvus client used for the RAG model."""
47
+
48
+ @property
49
+ def client(self) -> MilvusClient:
50
+ """Return the Milvus client."""
51
+ if self._client is None:
52
+ raise RuntimeError("Client is not initialized. Have you called `self.init_client()`?")
53
+ return self._client
54
+
55
+ def init_client(
56
+ self,
57
+ milvus_uri: Optional[str] = None,
58
+ milvus_token: Optional[str] = None,
59
+ milvus_timeout: Optional[float] = None,
60
+ ) -> Self:
61
+ """Initialize the Milvus client."""
62
+ self._client = create_client(
63
+ uri=milvus_uri or (self.milvus_uri or configs.rag.milvus_uri).unicode_string(),
64
+ token=milvus_token
65
+ or (token.get_secret_value() if (token := (self.milvus_token or configs.rag.milvus_token)) else ""),
66
+ timeout=milvus_timeout or self.milvus_timeout,
67
+ )
68
+ return self
69
+
70
+ @overload
71
+ async def pack(
72
+ self, input_text: List[str], subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
73
+ ) -> List[MilvusData]: ...
74
+ @overload
75
+ async def pack(
76
+ self, input_text: str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
77
+ ) -> MilvusData: ...
78
+
79
+ async def pack(
80
+ self, input_text: List[str] | str, subject: Optional[str] = None, **kwargs: Unpack[EmbeddingKwargs]
81
+ ) -> List[MilvusData] | MilvusData:
82
+ """Asynchronously generates MilvusData objects for the given input text.
83
+
84
+ Args:
85
+ input_text (List[str] | str): A string or list of strings to generate embeddings for.
86
+ subject (Optional[str]): The subject of the input text. Defaults to None.
87
+ **kwargs (Unpack[EmbeddingKwargs]): Additional keyword arguments for embedding.
88
+
89
+ Returns:
90
+ List[MilvusData] | MilvusData: The generated MilvusData objects.
91
+ """
92
+ if isinstance(input_text, str):
93
+ return MilvusData(vector=await self.vectorize(input_text, **kwargs), text=input_text, subject=subject)
94
+ vecs = await self.vectorize(input_text, **kwargs)
95
+ return [
96
+ MilvusData(
97
+ vector=vec,
98
+ text=text,
99
+ subject=subject,
100
+ )
101
+ for text, vec in zip(input_text, vecs, strict=True)
102
+ ]
103
+
104
+ def view(
105
+ self, collection_name: Optional[str], create: bool = False, **kwargs: Unpack[CollectionSimpleConfigKwargs]
106
+ ) -> Self:
107
+ """View the specified collection.
108
+
109
+ Args:
110
+ collection_name (str): The name of the collection.
111
+ create (bool): Whether to create the collection if it does not exist.
112
+ **kwargs (Unpack[CollectionSimpleConfigKwargs]): Additional keyword arguments for collection configuration.
113
+ """
114
+ if create and collection_name and self.client.has_collection(collection_name):
115
+ kwargs["dimension"] = kwargs.get("dimension") or self.milvus_dimensions or configs.rag.milvus_dimensions
116
+ self.client.create_collection(collection_name, auto_id=True, **kwargs)
117
+ logger.info(f"Creating collection {collection_name}")
118
+
119
+ self.target_collection = collection_name
120
+ return self
121
+
122
+ def quit_viewing(self) -> Self:
123
+ """Quit the current view.
124
+
125
+ Returns:
126
+ Self: The current instance, allowing for method chaining.
127
+ """
128
+ return self.view(None)
129
+
130
+ @property
131
+ def safe_target_collection(self) -> str:
132
+ """Get the name of the collection being viewed, raise an error if not viewing any collection.
133
+
134
+ Returns:
135
+ str: The name of the collection being viewed.
136
+ """
137
+ if self.target_collection is None:
138
+ raise RuntimeError("No collection is being viewed. Have you called `self.view()`?")
139
+ return self.target_collection
140
+
141
+ def add_document[D: Union[Dict[str, Any], MilvusData]](
142
+ self, data: D | List[D], collection_name: Optional[str] = None, flush: bool = False
143
+ ) -> Self:
144
+ """Adds a document to the specified collection.
145
+
146
+ Args:
147
+ data (Union[Dict[str, Any], MilvusData] | List[Union[Dict[str, Any], MilvusData]]): The data to be added to the collection.
148
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
149
+ flush (bool): Whether to flush the collection after insertion.
150
+
151
+ Returns:
152
+ Self: The current instance, allowing for method chaining.
153
+ """
154
+ if isinstance(data, MilvusData):
155
+ prepared_data = data.prepare_insertion()
156
+ elif isinstance(data, list):
157
+ prepared_data = [d.prepare_insertion() if isinstance(d, MilvusData) else d for d in data]
158
+ else:
159
+ raise TypeError(f"Expected MilvusData or list of MilvusData, got {type(data)}")
160
+ c_name = collection_name or self.safe_target_collection
161
+ self.client.insert(c_name, prepared_data)
162
+
163
+ if flush:
164
+ logger.debug(f"Flushing collection {c_name}")
165
+ self.client.flush(c_name)
166
+ return self
167
+
168
+ async def consume_file(
169
+ self,
170
+ source: List[PathLike] | PathLike,
171
+ reader: Callable[[PathLike], str] = lambda path: Path(path).read_text(encoding="utf-8"),
172
+ collection_name: Optional[str] = None,
173
+ ) -> Self:
174
+ """Consume a file and add its content to the collection.
175
+
176
+ Args:
177
+ source (PathLike): The path to the file to be consumed.
178
+ reader (Callable[[PathLike], MilvusData]): The reader function to read the file.
179
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
180
+
181
+ Returns:
182
+ Self: The current instance, allowing for method chaining.
183
+ """
184
+ if not isinstance(source, list):
185
+ source = [source]
186
+ return await self.consume_string([reader(s) for s in source], collection_name)
187
+
188
+ async def consume_string(self, text: List[str] | str, collection_name: Optional[str] = None) -> Self:
189
+ """Consume a string and add it to the collection.
190
+
191
+ Args:
192
+ text (List[str] | str): The text to be added to the collection.
193
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
194
+
195
+ Returns:
196
+ Self: The current instance, allowing for method chaining.
197
+ """
198
+ self.add_document(await self.pack(text), collection_name or self.safe_target_collection, flush=True)
199
+ return self
200
+
201
+ async def afetch_document[V: (int, str, float, bytes)](
202
+ self,
203
+ vecs: List[List[float]],
204
+ desired_fields: List[str] | str,
205
+ collection_name: Optional[str] = None,
206
+ similarity_threshold: float = 0.37,
207
+ result_per_query: int = 10,
208
+ ) -> List[Dict[str, Any]] | List[V]:
209
+ """Fetch data from the collection.
210
+
211
+ Args:
212
+ vecs (List[List[float]]): The vectors to search for.
213
+ desired_fields (List[str] | str): The fields to retrieve.
214
+ collection_name (Optional[str]): The name of the collection. If not provided, the currently viewed collection is used.
215
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
216
+ result_per_query (int): The number of results to return per query.
217
+
218
+ Returns:
219
+ List[Dict[str, Any]] | List[Any]: The retrieved data.
220
+ """
221
+ # Step 1: Search for vectors
222
+ search_results = self.client.search(
223
+ collection_name or self.safe_target_collection,
224
+ vecs,
225
+ search_params={"radius": similarity_threshold},
226
+ output_fields=desired_fields if isinstance(desired_fields, list) else [desired_fields],
227
+ limit=result_per_query,
228
+ )
229
+
230
+ # Step 2: Flatten the search results
231
+ flattened_results = flatten(search_results)
232
+ unique_results = unique(flattened_results, key=itemgetter("id"))
233
+ # Step 3: Sort by distance (descending)
234
+ sorted_results = sorted(unique_results, key=itemgetter("distance"), reverse=True)
235
+
236
+ logger.debug(f"Searched similarities: {[t['distance'] for t in sorted_results]}")
237
+ # Step 4: Extract the entities
238
+ resp = [result["entity"] for result in sorted_results]
239
+
240
+ if isinstance(desired_fields, list):
241
+ return resp
242
+ return [r.get(desired_fields) for r in resp] # extract the single field as list
243
+
244
+ async def aretrieve(
245
+ self,
246
+ query: List[str] | str,
247
+ final_limit: int = 20,
248
+ **kwargs: Unpack[FetchKwargs],
249
+ ) -> List[str]:
250
+ """Retrieve data from the collection.
251
+
252
+ Args:
253
+ query (List[str] | str): The query to be used for retrieval.
254
+ final_limit (int): The final limit on the number of results to return.
255
+ **kwargs (Unpack[FetchKwargs]): Additional keyword arguments for retrieval.
256
+
257
+ Returns:
258
+ List[str]: A list of strings containing the retrieved data.
259
+ """
260
+ if isinstance(query, str):
261
+ query = [query]
262
+ return cast(
263
+ List[str],
264
+ await self.afetch_document(
265
+ vecs=(await self.vectorize(query)),
266
+ desired_fields="text",
267
+ **kwargs,
268
+ ),
269
+ )[:final_limit]
270
+
271
+ async def aask_retrieved(
272
+ self,
273
+ question: str,
274
+ query: Optional[List[str] | str] = None,
275
+ collection_name: Optional[str] = None,
276
+ extra_system_message: str = "",
277
+ result_per_query: int = 10,
278
+ final_limit: int = 20,
279
+ similarity_threshold: float = 0.37,
280
+ **kwargs: Unpack[LLMKwargs],
281
+ ) -> str:
282
+ """Asks a question by retrieving relevant documents based on the provided query.
283
+
284
+ This method performs document retrieval using the given query, then asks the
285
+ specified question using the retrieved documents as context.
286
+
287
+ Args:
288
+ question (str): The question to be asked.
289
+ query (List[str] | str): The query or list of queries used for document retrieval.
290
+ collection_name (Optional[str]): The name of the collection to retrieve documents from.
291
+ If not provided, the currently viewed collection is used.
292
+ extra_system_message (str): An additional system message to be included in the prompt.
293
+ result_per_query (int): The number of results to return per query. Default is 10.
294
+ final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
295
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
296
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
297
+
298
+ Returns:
299
+ str: A string response generated after asking with the context of retrieved documents.
300
+ """
301
+ docs = await self.aretrieve(
302
+ query or question,
303
+ final_limit,
304
+ collection_name=collection_name,
305
+ result_per_query=result_per_query,
306
+ similarity_threshold=similarity_threshold,
307
+ )
308
+
309
+ rendered = TEMPLATE_MANAGER.render_template(configs.templates.retrieved_display_template, {"docs": docs[::-1]})
310
+
311
+ logger.debug(f"Retrieved Documents: \n{rendered}")
312
+ return await self.aask(
313
+ question,
314
+ f"{rendered}\n\n{extra_system_message}",
315
+ **kwargs,
316
+ )
317
+
318
+ async def arefined_query(self, question: List[str] | str, **kwargs: Unpack[ChooseKwargs]) -> List[str]:
319
+ """Refines the given question using a template.
320
+
321
+ Args:
322
+ question (List[str] | str): The question to be refined.
323
+ **kwargs (Unpack[ChooseKwargs]): Additional keyword arguments for the refinement process.
324
+
325
+ Returns:
326
+ List[str]: A list of refined questions.
327
+ """
328
+ return await self.aliststr(
329
+ TEMPLATE_MANAGER.render_template(
330
+ configs.templates.refined_query_template,
331
+ {"question": [question] if isinstance(question, str) else question},
332
+ ),
333
+ **kwargs,
334
+ )
335
+
336
+ async def aask_refined(
337
+ self,
338
+ question: str,
339
+ collection_name: Optional[str] = None,
340
+ extra_system_message: str = "",
341
+ result_per_query: int = 10,
342
+ final_limit: int = 20,
343
+ similarity_threshold: float = 0.37,
344
+ **kwargs: Unpack[LLMKwargs],
345
+ ) -> str:
346
+ """Asks a question using a refined query based on the provided question.
347
+
348
+ Args:
349
+ question (str): The question to be asked.
350
+ collection_name (Optional[str]): The name of the collection to retrieve documents from.
351
+ extra_system_message (str): An additional system message to be included in the prompt.
352
+ result_per_query (int): The number of results to return per query. Default is 10.
353
+ final_limit (int): The maximum number of retrieved documents to consider. Default is 20.
354
+ similarity_threshold (float): The threshold for similarity, only results above this threshold will be returned.
355
+ **kwargs (Unpack[LLMKwargs]): Additional keyword arguments passed to the underlying `aask` method.
356
+
357
+ Returns:
358
+ str: A string response generated after asking with the refined question.
359
+ """
360
+ return await self.aask_retrieved(
361
+ question,
362
+ await self.arefined_query(question, **kwargs),
363
+ collection_name=collection_name,
364
+ extra_system_message=extra_system_message,
365
+ result_per_query=result_per_query,
366
+ final_limit=final_limit,
367
+ similarity_threshold=similarity_threshold,
368
+ **kwargs,
369
+ )