ai-parrot 0.3.4__cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-parrot might be problematic. Click here for more details.
- ai_parrot-0.3.4.dist-info/LICENSE +21 -0
- ai_parrot-0.3.4.dist-info/METADATA +319 -0
- ai_parrot-0.3.4.dist-info/RECORD +109 -0
- ai_parrot-0.3.4.dist-info/WHEEL +6 -0
- ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
- parrot/__init__.py +21 -0
- parrot/chatbots/__init__.py +7 -0
- parrot/chatbots/abstract.py +728 -0
- parrot/chatbots/asktroc.py +16 -0
- parrot/chatbots/base.py +366 -0
- parrot/chatbots/basic.py +9 -0
- parrot/chatbots/bose.py +17 -0
- parrot/chatbots/cody.py +17 -0
- parrot/chatbots/copilot.py +83 -0
- parrot/chatbots/dataframe.py +103 -0
- parrot/chatbots/hragents.py +15 -0
- parrot/chatbots/odoo.py +17 -0
- parrot/chatbots/retrievals/__init__.py +578 -0
- parrot/chatbots/retrievals/constitutional.py +19 -0
- parrot/conf.py +110 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +162 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +29 -0
- parrot/llms/__init__.py +137 -0
- parrot/llms/abstract.py +47 -0
- parrot/llms/anthropic.py +42 -0
- parrot/llms/google.py +42 -0
- parrot/llms/groq.py +45 -0
- parrot/llms/hf.py +45 -0
- parrot/llms/openai.py +59 -0
- parrot/llms/pipes.py +114 -0
- parrot/llms/vertex.py +78 -0
- parrot/loaders/__init__.py +20 -0
- parrot/loaders/abstract.py +456 -0
- parrot/loaders/audio.py +106 -0
- parrot/loaders/basepdf.py +102 -0
- parrot/loaders/basevideo.py +280 -0
- parrot/loaders/csv.py +42 -0
- parrot/loaders/dir.py +37 -0
- parrot/loaders/excel.py +349 -0
- parrot/loaders/github.py +65 -0
- parrot/loaders/handlers/__init__.py +5 -0
- parrot/loaders/handlers/data.py +213 -0
- parrot/loaders/image.py +119 -0
- parrot/loaders/json.py +52 -0
- parrot/loaders/pdf.py +437 -0
- parrot/loaders/pdfchapters.py +142 -0
- parrot/loaders/pdffn.py +112 -0
- parrot/loaders/pdfimages.py +207 -0
- parrot/loaders/pdfmark.py +88 -0
- parrot/loaders/pdftables.py +145 -0
- parrot/loaders/ppt.py +30 -0
- parrot/loaders/qa.py +81 -0
- parrot/loaders/repo.py +103 -0
- parrot/loaders/rtd.py +65 -0
- parrot/loaders/txt.py +92 -0
- parrot/loaders/utils/__init__.py +1 -0
- parrot/loaders/utils/models.py +25 -0
- parrot/loaders/video.py +96 -0
- parrot/loaders/videolocal.py +120 -0
- parrot/loaders/vimeo.py +106 -0
- parrot/loaders/web.py +216 -0
- parrot/loaders/web_base.py +112 -0
- parrot/loaders/word.py +125 -0
- parrot/loaders/youtube.py +192 -0
- parrot/manager.py +166 -0
- parrot/models.py +372 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +48 -0
- parrot/stores/abstract.py +171 -0
- parrot/stores/milvus.py +632 -0
- parrot/stores/qdrant.py +153 -0
- parrot/tools/__init__.py +12 -0
- parrot/tools/abstract.py +53 -0
- parrot/tools/asknews.py +32 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/google.py +170 -0
- parrot/tools/stack.py +26 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +59 -0
- parrot/tools/zipcode.py +179 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
- settings/__init__.py +0 -0
- settings/settings.py +51 -0
parrot/stores/milvus.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
from typing import Optional, Union, Any
|
|
2
|
+
import asyncio
|
|
3
|
+
import uuid
|
|
4
|
+
import torch
|
|
5
|
+
from pymilvus import (
|
|
6
|
+
MilvusClient,
|
|
7
|
+
Collection,
|
|
8
|
+
FieldSchema,
|
|
9
|
+
CollectionSchema,
|
|
10
|
+
DataType,
|
|
11
|
+
connections,
|
|
12
|
+
db
|
|
13
|
+
)
|
|
14
|
+
from pymilvus.exceptions import MilvusException
|
|
15
|
+
from langchain_milvus import Milvus # pylint: disable=import-error, E0611
|
|
16
|
+
from langchain.memory import VectorStoreRetrieverMemory
|
|
17
|
+
from .abstract import AbstractStore
|
|
18
|
+
from ..conf import (
|
|
19
|
+
MILVUS_HOST,
|
|
20
|
+
MILVUS_PROTOCOL,
|
|
21
|
+
MILVUS_PORT,
|
|
22
|
+
MILVUS_URL,
|
|
23
|
+
MILVUS_TOKEN,
|
|
24
|
+
MILVUS_USER,
|
|
25
|
+
MILVUS_PASSWORD,
|
|
26
|
+
MILVUS_SECURE,
|
|
27
|
+
MILVUS_SERVER_NAME,
|
|
28
|
+
MILVUS_CA_CERT,
|
|
29
|
+
MILVUS_SERVER_CERT,
|
|
30
|
+
MILVUS_SERVER_KEY,
|
|
31
|
+
MILVUS_USE_TLSv2
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MilvusConnection:
    """Context manager for a pymilvus connection registered under an alias.

    Entering the context calls ``connections.connect`` with the alias and
    keyword arguments given to the constructor; leaving it disconnects that
    alias again.

    Args:
        alias (str): connection alias handed to pymilvus.
        **kwargs: connection arguments forwarded to ``connections.connect``
            when the context is entered.
    """
    def __init__(self, alias: str = 'default', **kwargs):
        self._connected: bool = False
        self.kwargs = kwargs
        self.alias: str = alias

    def connect(self, alias: str = None, **kwargs):
        """Open the connection and return the alias that was used.

        Only the ``kwargs`` passed to this call are forwarded; the ones
        stored on the instance are applied by ``__enter__``.
        """
        if not alias:
            alias = self.alias
        connections.connect(
            alias=alias,
            **kwargs
        )
        self._connected = True
        return alias

    def is_connected(self):
        """Return True while a connection opened by this object is active."""
        return self._connected

    def close(self, alias: str = None):
        """Disconnect ``alias`` (the instance alias when omitted)."""
        # Fall back to the instance alias, mirroring connect(); previously a
        # missing alias was forwarded as None.
        if not alias:
            alias = self.alias
        try:
            connections.disconnect(alias=alias)
        finally:
            # The flag is cleared even when disconnect() raises.
            self._connected = False

    def __enter__(self):
        self.connect(alias=self.alias, **self.kwargs)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close(alias=self.alias)
        # A truthy return value would suppress any exception raised inside
        # the ``with`` body, so report "not handled" explicitly.
        return False
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class MilvusStore(AbstractStore):
|
|
73
|
+
"""MilvusStore class.
|
|
74
|
+
|
|
75
|
+
Milvus is a Vector Database multi-layered.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
host (str): Milvus host.
|
|
79
|
+
port (int): Milvus port.
|
|
80
|
+
url (str): Milvus URL.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(self, embeddings = None, **kwargs):
    """Read the store configuration and assemble the connection arguments.

    Known options are popped from ``kwargs`` (falling back to the
    ``MILVUS_*`` settings where one exists); whatever remains is forwarded
    untouched inside ``self.kwargs``. When a database name is configured,
    ``use_database`` is called before the constructor returns.

    Args:
        embeddings: passed through to the parent constructor.
        **kwargs: use_bge, use_fastembed, database, collection_name,
            dimension, metric_type, index_type, host, port, protocol,
            create_database, url, client_id, token, user, password,
            secure, server_name, server_pem_path, ca_pem_path,
            client_key_path, plus any extra connection arguments.
    """
    super().__init__(embeddings, **kwargs)
    self.use_bge: bool = kwargs.pop("use_bge", False)
    self.fastembed: bool = kwargs.pop("use_fastembed", False)
    self.database: str = kwargs.pop('database', '')
    self.collection = kwargs.pop('collection_name', '')
    self.dimension: int = kwargs.pop("dimension", 768)
    self._metric_type: str = kwargs.pop("metric_type", 'COSINE')
    self._index_type: str = kwargs.pop("index_type", 'IVF_FLAT')
    self.host = kwargs.pop("host", MILVUS_HOST)
    self.port = kwargs.pop("port", MILVUS_PORT)
    self.protocol = kwargs.pop("protocol", MILVUS_PROTOCOL)
    self.create_database: bool = kwargs.pop('create_database', True)
    self.url = kwargs.pop("url", MILVUS_URL)
    self._client_id = kwargs.pop('client_id', 'default')
    if not self.url:
        self.url = f"{self.protocol}://{self.host}:{self.port}"
    else:
        # Fill in a missing host/port from the URL authority. The scheme,
        # any path and any "user:pass@" prefix are stripped first so that
        # URLs without an explicit port, or with a trailing path, no
        # longer break the integer conversion.
        authority = self.url.split("://", 1)[-1].split("/", 1)[0]
        authority = authority.rsplit("@", 1)[-1]
        url_host, sep, url_port = authority.rpartition(":")
        if not sep:
            url_host, url_port = authority, ""
        if not self.host:
            self.host = url_host
        if not self.port and url_port.isdigit():
            self.port = int(url_port)
    self.token = kwargs.pop("token", MILVUS_TOKEN)
    # user and password (if required)
    self.user = kwargs.pop("user", MILVUS_USER)
    self.password = kwargs.pop("password", MILVUS_PASSWORD)
    # SSL/TLS
    self._secure: bool = kwargs.pop('secure', MILVUS_SECURE)
    self._server_name: str = kwargs.pop('server_name', MILVUS_SERVER_NAME)
    self._cert: str = kwargs.pop('server_pem_path', MILVUS_SERVER_CERT)
    self._ca_cert: str = kwargs.pop('ca_pem_path', MILVUS_CA_CERT)
    self._cert_key: str = kwargs.pop('client_key_path', MILVUS_SERVER_KEY)
    # Any other argument will be passed to the Milvus client
    self.kwargs = {
        "uri": self.url,
        "host": self.host,
        "port": self.port,
        **kwargs
    }
    if self.token:
        self.kwargs['token'] = self.token
    if self.user:
        # A configured user replaces any plain token with "user:password".
        self.kwargs['token'] = f"{self.user}:{self.password}"
    # SSL Security:
    if self._secure is True:
        tls_args = {
            "secure": self._secure,
            "server_name": self._server_name
        }
        if self._cert:
            if MILVUS_USE_TLSv2 is True:
                # Two-way TLS: client certificate plus its private key.
                tls_args['client_pem_path'] = self._cert
                tls_args['client_key_path'] = self._cert_key
            else:
                tls_args["server_pem_path"] = self._cert
        if self._ca_cert:
            tls_args['ca_pem_path'] = self._ca_cert
        self.kwargs = {**self.kwargs, **tls_args}
    # 1. Check if database exists:
    if self.database:
        self.kwargs['db_name'] = self.database
        self.use_database(
            self.database,
            create=self.create_database
        )
|
|
149
|
+
|
|
150
|
+
async def __aenter__(self):
    """Prepare the store for use inside ``async with``.

    Tries to place a scratch tensor on the GPU and makes sure an embedding
    model is loaded, then returns the store itself.
    """
    # NOTE(review): the scratch tensor is never read in this file;
    # presumably it only forces CUDA initialisation - confirm.
    try:
        self.tensor = torch.randn(1000, 1000).cuda()
    except (RuntimeError, AssertionError):
        # CPU-only torch builds report a missing CUDA runtime with
        # AssertionError rather than RuntimeError; both mean "no GPU".
        self.tensor = None
    if self._embed_ is None:
        self._embed_ = self.create_embedding(
            model_name=self.embedding_name
        )
    return self
|
|
160
|
+
|
|
161
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
    """Release the embedding, the scratch tensor and the Milvus connection.

    Exceptions raised inside the ``async with`` body are not suppressed
    (the method returns ``None``).
    """
    # closing Embedding
    self._embed_ = None
    # The tensor only exists when __aenter__ ran; do not fail otherwise.
    if hasattr(self, 'tensor'):
        del self.tensor
    try:
        self.close(alias=self._client_id)
    except RuntimeError as exc:
        # Closing stays best-effort, but the failure is no longer silent.
        self.logger.warning(
            f"Error closing Milvus connection: {exc}"
        )
    finally:
        # Always give cached GPU memory back, even if close() failed.
        try:
            torch.cuda.empty_cache()
        except RuntimeError:
            pass
|
|
170
|
+
|
|
171
|
+
def connection(self, alias: str = None):
    """Open the Milvus connection and keep the client on the instance.

    Args:
        alias (str, optional): connection alias; ``'uri-connection'`` is
            used when none is given.

    Returns:
        The store itself, so calls can be chained.
    """
    # Pick the alias first: it is stored before connecting.
    self._client_id = alias if alias else 'uri-connection'
    # connect() hands back the client together with the alias it used.
    client, client_id = self.connect(alias=self._client_id)
    self._client = client
    self._client_id = client_id
    return self
|
|
182
|
+
|
|
183
|
+
def connect(self, alias: str = None) -> tuple:
    """Register a pymilvus connection and build a ``MilvusClient``.

    Both are created from ``self.kwargs``.

    Args:
        alias (str, optional): connection alias, ``"default"`` when empty.

    Returns:
        tuple: ``(client, alias)``.
    """
    conn_alias = alias or "default"
    # 1. Set up the pyMilvus connection under the chosen alias.
    connections.connect(alias=conn_alias, **self.kwargs)
    # 2. Build the high-level client from the same arguments.
    milvus_client = MilvusClient(**self.kwargs)
    self._connected = True
    return milvus_client, conn_alias
|
|
197
|
+
|
|
198
|
+
def close(self, alias: str = "default"):
    """Disconnect ``alias`` and close the Milvus client, if one was created.

    The client is closed and the connected flag is cleared even when
    disconnecting the alias raises; the error still propagates.
    """
    try:
        connections.disconnect(alias=alias)
    finally:
        try:
            # connection() may never have been called, in which case
            # there is no client to close.
            client = getattr(self, '_client', None)
            if client is not None:
                client.close()
        finally:
            self._connected = False
|
|
204
|
+
|
|
205
|
+
def create_db(self, db_name: str, alias: str = 'default', **kwargs) -> bool:
    """Create database ``db_name`` over a temporary connection.

    Args:
        db_name (str): name of the database to create.
        alias (str): connection alias used for the operation.
        **kwargs: extra connection arguments; they override the
            uri/host/port taken from the instance.

    Returns:
        bool: True when the database was created.

    Raises:
        ValueError: if connecting or creating the database fails; the
            original exception is chained as the cause.
    """
    args = {
        "uri": self.url,
        "host": self.host,
        "port": self.port,
        **kwargs
    }
    try:
        connections.connect(alias, **args)
        # Run against the alias just opened, not the implicit default.
        db.create_database(db_name, using=alias)
        self.logger.notice(
            f"Database {db_name} created successfully."
        )
        return True
    except Exception as e:
        raise ValueError(
            f"Error creating database: {e}"
        ) from e
    finally:
        # Release the alias opened above (previously a fixed, unrelated
        # alias was disconnected and this one was left open).
        connections.disconnect(alias=alias)
|
|
224
|
+
|
|
225
|
+
def use_database(
    self,
    db_name: str,
    alias:str = 'default',
    create: bool = False
) -> None:
    """Make sure database ``db_name`` exists, creating it when allowed.

    Args:
        db_name (str): database to look up.
        alias (str): connection alias used for the check.
        create (bool): create the database when it is missing; creation
            also happens when ``self.create_database`` is True.

    Raises:
        ValueError: if the database is missing and may not be created,
            or if creating it fails.
    """
    try:
        connections.connect(alias, **self.kwargs)
    except MilvusException as exc:
        if "database not found" in exc.message:
            # Connect without the database name so it can be created.
            args = self.kwargs.copy()
            args.pop('db_name', None)
            self.create_db(db_name, alias=alias, **args)
        else:
            # Not re-raised here: the reconnect below is still attempted.
            self.logger.warning(
                f"Error connecting to Milvus: {exc}"
            )
    # re-connect:
    try:
        connections.connect(alias, **self.kwargs)
        if db_name not in db.list_database(using=alias):
            if self.create_database is True or create is True:
                try:
                    db.create_database(db_name, using=alias, timeout=10)
                    self.logger.notice(
                        f"Database {db_name} created successfully."
                    )
                except Exception as e:
                    raise ValueError(
                        f"Error creating database: {e}"
                    ) from e
            else:
                raise ValueError(
                    f"Database {db_name} does not exist."
                )
    finally:
        connections.disconnect(alias=alias)
|
|
258
|
+
|
|
259
|
+
def setup_vector(self):
    """Build the LangChain ``Milvus`` store for the configured collection.

    The store is kept on ``self.vector`` and also returned.
    """
    store = Milvus(
        self._embed_,
        consistency_level='Bounded',
        # A copy, so the store never shares the instance's dict.
        connection_args=dict(self.kwargs),
        collection_name=self.collection,
    )
    self.vector = store
    return store
|
|
267
|
+
|
|
268
|
+
def get_vectorstore(self):
    """Return the vector store; thin alias of ``get_vector()``."""
    vectorstore = self.get_vector()
    return vectorstore
|
|
270
|
+
|
|
271
|
+
def collection_exists(self, collection_name: str) -> bool:
    """Tell whether the client lists a collection named ``collection_name``."""
    known_collections = self._client.list_collections()
    return collection_name in known_collections
|
|
275
|
+
|
|
276
|
+
def check_state(self, collection_name: str) -> dict:
    """Return the client's load state for ``collection_name``."""
    client = self._client
    load_state = client.get_load_state(collection_name=collection_name)
    return load_state
|
|
278
|
+
|
|
279
|
+
async def delete_collection(self, collection: str = None) -> dict:
    """Drop a collection through the Milvus client.

    Args:
        collection (str, optional): collection to drop. When omitted, the
            store's configured collection is used, as the other methods
            of this class do.
    """
    if not collection:
        collection = self.collection
    self._client.drop_collection(
        collection_name=collection
    )
|
|
283
|
+
|
|
284
|
+
async def create_collection(
    self,
    collection_name: str,
    document: Any = None,
    dimension: int = 768,
    index_type: str = None,
    metric_type: str = None,
    schema_type: str = 'default',
    metadata_field: str = None,
    **kwargs
) -> dict:
    """create_collection.

    Create a Schema (Milvus Collection) on the Current Database.

    Args:
        collection_name (str): Collection Name.
        document (Any): first document, only used when ``schema_type`` is
            not 'default'.
        dimension (int, optional): Vector Dimension. Defaults to 768.
        index_type (str, optional): index type of the vector field.
            Defaults to the store's configured index type.
        metric_type (str, optional): metric of the vector index.
            Defaults to the store's configured metric type.
        schema_type (str, optional): 'default' builds the fixed schema
            below; any other value lets LangChain derive the collection
            from ``document``.
        metadata_field (str, optional): forwarded to LangChain as
            ``metadata_field`` in the non-default branch.

    Returns:
        None when the collection already exists or the default schema was
        created; otherwise the LangChain ``Milvus`` store built from
        ``document``.
    """
    # Check if collection exists:
    if self.collection_exists(collection_name):
        self.logger.warning(
            f"Collection {collection_name} already exists."
        )
        return None
    if not index_type:
        index_type = self._index_type
    if not metric_type:
        metric_type = self._metric_type  # default metric type
    # Build parameters for the vector index.
    if index_type == 'HNSW':
        idx_params = {
            "M": 36,
            "efConstruction": 1024
        }
    elif index_type in ('IVF_FLAT', 'SCANN', 'IVF_SQ8'):
        idx_params = {
            "nlist": 1024
        }
    elif index_type == 'IVF_PQ':
        # Exact comparison: ``in ('IVF_PQ')`` was a substring test on a
        # plain string and also matched values such as 'IVF' or 'PQ'.
        idx_params = {
            "nlist": 1024,
            "m": 16
        }
    else:
        idx_params = {}
    if schema_type == 'default':
        # Default Collection for all loaders:
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
            description=collection_name
        )
        schema.add_field(
            field_name="pk",
            datatype=DataType.INT64,
            is_primary=True,
            auto_id=True,
            max_length=100
        )
        # Plain VARCHAR columns, added in their original order.
        varchar_fields = (
            ("index", 65535),
            ("url", 65535),
            ("source", 65535),
            ("filename", 65535),
            ("question", 65535),
            ("answer", 65535),
            ("source_type", 128),
            ("type", 65535),
        )
        for field_name, max_length in varchar_fields:
            schema.add_field(
                field_name=field_name,
                datatype=DataType.VARCHAR,
                max_length=max_length
            )
        schema.add_field(
            field_name="text",
            datatype=DataType.VARCHAR,
            description="Text",
            max_length=65535
        )
        schema.add_field(
            field_name="summary",
            datatype=DataType.VARCHAR,
            description="Summary (refine resume)",
            max_length=65535
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector"
        )
        schema.add_field(
            field_name="document_meta",
            datatype=DataType.JSON,
            description="Custom Metadata information"
        )
        index_params = self._client.prepare_index_params()
        index_params.add_index(
            field_name="pk",
            index_type="STL_SORT"
        )
        index_params.add_index(
            field_name="text",
            index_type="marisa-trie"
        )
        index_params.add_index(
            field_name="summary",
            index_type="marisa-trie"
        )
        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=idx_params
        )
        self._client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
            num_shards=2
        )
        await asyncio.sleep(2)
        # The result is not used; the request itself is kept.
        # NOTE(review): confirm whether this round-trip is still needed.
        self._client.get_load_state(
            collection_name=collection_name
        )
        return None
    else:
        self._client.create_collection(
            collection_name=collection_name,
            dimension=dimension
        )
        if metadata_field:
            kwargs['metadata_field'] = metadata_field
        # Here using drop_old=True to force recreate based on the first document
        docstore = Milvus.from_documents(
            [document],  # Only the first document
            self._embed_,
            connection_args={**self.kwargs},
            collection_name=collection_name,
            drop_old=True,
            primary_field='pk',
            text_field='text',
            vector_field='vector',
            **kwargs
        )
        return docstore
|
|
469
|
+
|
|
470
|
+
async def load_documents(
    self,
    documents: list,
    collection: str = None,
    upsert: bool = False,
    attribute: str = 'source_type',
    metadata_field: str = None,
    **kwargs
):
    """Embed ``documents`` and store them in a Milvus collection.

    Args:
        documents (list): documents to load; each is expected to expose a
            ``metadata`` mapping.
        collection (str, optional): target collection, defaults to the
            store's configured collection.
        upsert (bool): when True, rows whose ``attribute`` equals the
            value found in the first document's metadata are deleted
            before loading.
        attribute (str): metadata key / collection field used for that
            deletion.
        metadata_field (str, optional): forwarded to LangChain as
            ``metadata_field``.
        **kwargs: extra arguments for ``Milvus.from_documents``.

    Returns:
        The LangChain ``Milvus`` store returned by ``from_documents``.
    """
    if not collection:
        collection = self.collection
    # NOTE(review): scratch GPU tensor, never read below; presumably it
    # only forces CUDA initialisation - confirm.
    try:
        tensor = torch.randn(1000, 1000).cuda()
    except Exception:
        tensor = None
    if upsert is True and attribute and documents:
        # The value to replace comes from the first document, looked up
        # under the key named by ``attribute`` (not the literal string
        # 'attribute').
        doc_type = documents[0].metadata.get(attribute, None)
        if doc_type is None:
            self.logger.warning(
                f"First document has no '{attribute}' metadata; nothing deleted."
            )
        else:
            deleted = self._client.delete(
                collection_name=collection,
                filter=f'{attribute} == "{doc_type}"'
            )
            self.logger.notice(
                f"Deleted documents with {attribute} {doc_type}: {deleted}"
            )
    if metadata_field:
        # document_meta
        kwargs['metadata_field'] = metadata_field
    docstore = Milvus.from_documents(
        documents,
        self._embed_,
        connection_args={**self.kwargs},
        collection_name=collection,
        consistency_level='Bounded',
        drop_old=False,
        primary_field='pk',
        text_field='text',
        vector_field='vector',
        **kwargs
    )
    del tensor
    return docstore
|
|
515
|
+
|
|
516
|
+
def upsert(self, payload: dict, collection: str = None) -> None:
    """Placeholder: accepts the arguments and does nothing."""
    return None
|
|
518
|
+
|
|
519
|
+
def insert(
    self,
    payload: Union[dict, list],
    collection: Union[str, None] = None
) -> dict:
    """Insert ``payload`` into a collection and flush it.

    Args:
        payload (dict | list): one row or a list of rows.
        collection (str, optional): target collection, defaults to the
            store's configured collection.

    Returns:
        dict: whatever the client's ``insert`` call returns.
    """
    if collection is None:
        collection = self.collection
    result = self._client.insert(
        collection_name=collection,
        data=payload
    )
    # ``collection`` is a plain name here, so flushing has to go through
    # the client. Older clients without ``flush`` are simply skipped.
    flush = getattr(self._client, 'flush', None)
    if callable(flush):
        flush(collection_name=collection)
    return result
|
|
532
|
+
|
|
533
|
+
def get_vector(
    self,
    collection: Union[str, None] = None,
    metric_type: str = None,
    nprobe: int = 32,
    metadata_field: str = None,
    consistency_level: str = 'session'
) -> Milvus:
    """Build a LangChain ``Milvus`` store with a freshly created embedding.

    Args:
        collection (str, optional): collection name, defaults to the
            store's configured collection.
        metric_type (str, optional): search metric, defaults to the
            store's configured metric type.
        nprobe (int): ``nprobe`` value placed in the search parameters.
        metadata_field (str, optional): forwarded as ``metadata_field``.
        consistency_level (str): forwarded as ``consistency_level``.

    Returns:
        Milvus: the configured vector store.
    """
    metric = metric_type or self._metric_type
    target = collection or self.collection
    store_options = {
        "search_params": {
            "metric_type": metric,
            "params": {"nprobe": nprobe, "nlist": 1024},
        }
    }
    if metadata_field:
        # document_meta
        store_options['metadata_field'] = metadata_field
    embedding = self.create_embedding(
        model_name=self.embedding_name
    )
    return Milvus(
        embedding_function=embedding,
        collection_name=target,
        consistency_level=consistency_level,
        connection_args=dict(self.kwargs),
        primary_field='pk',
        text_field='text',
        vector_field='vector',
        **store_options
    )
|
|
569
|
+
|
|
570
|
+
def similarity_search(
    self,
    query: str,
    collection: Union[str, None] = None,
    limit: int = 2,
    consistency_level: str = 'Bounded'
) -> list:
    """Run a similarity search for ``query``.

    Args:
        query (str): text to search for.
        collection (str, optional): collection name, defaults to the
            store's configured collection.
        limit (int): number of results requested (passed as ``k``).
        consistency_level (str): forwarded to the ``Milvus`` store.

    Returns:
        list: the result of the store's ``similarity_search``.
    """
    target = self.collection if collection is None else collection
    # Reuse the loaded embedding when there is one, else create it now.
    embedding = self._embed_
    if embedding is None:
        embedding = self.create_embedding(
            model_name=self.embedding_name
        )
    store = Milvus(
        embedding_function=embedding,
        collection_name=target,
        consistency_level=consistency_level,
        connection_args=dict(self.kwargs),
        primary_field='pk',
        text_field='text',
        vector_field='vector'
    )
    return store.similarity_search(query, k=limit)
|
|
597
|
+
|
|
598
|
+
def search(
    self,
    payload: Union[dict, list],
    collection: Union[str, None] = None,
    limit: Optional[int] = None
) -> list:
    """Forward a search request to the Milvus client.

    Args:
        payload (dict | list): search data; a single dict is wrapped in
            a list.
        collection (str, optional): collection name, defaults to the
            store's configured collection.
        limit (int, optional): passed as ``limit`` only when given.

    Returns:
        list: whatever the client's ``search`` call returns.
    """
    target = self.collection if collection is None else collection
    extra = {} if limit is None else {"limit": limit}
    queries = [payload] if isinstance(payload, dict) else payload
    return self._client.search(
        collection_name=target,
        data=queries,
        **extra
    )
|
|
617
|
+
|
|
618
|
+
def memory_retriever(self, num_results: int = 5) -> VectorStoreRetrieverMemory:
    """Build a retriever-backed memory on top of a Milvus store.

    Args:
        num_results (int): value passed to the retriever as ``k``.

    Returns:
        VectorStoreRetrieverMemory: memory wrapping the retriever.
    """
    # The store is created from an empty document set.
    store = Milvus.from_documents(
        {},
        self._embed_,
        connection_args=dict(self.kwargs)
    )
    retriever = Milvus.as_retriever(
        store,
        search_kwargs={"k": num_results}
    )
    return VectorStoreRetrieverMemory(retriever=retriever)
|
|
629
|
+
|
|
630
|
+
def save_context(self, memory: "VectorStoreRetrieverMemory", context: list) -> None:
    """Store every entry of ``context`` in ``memory``.

    Args:
        memory: object exposing ``save_context``.
        context (list): entries to store. A 2-item tuple or list is
            unpacked as ``(inputs, outputs)``, which is the signature of
            LangChain's ``save_context``; any other entry is passed as a
            single argument, as before.
            NOTE(review): the intended shape of the entries is not
            visible in this file - confirm against the callers.
    """
    for entry in context:
        if isinstance(entry, (tuple, list)) and len(entry) == 2:
            inputs, outputs = entry
            memory.save_context(inputs, outputs)
        else:
            memory.save_context(entry)
|