ai-parrot 0.1.0__cp311-cp311-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-parrot might be problematic.
- ai_parrot-0.1.0.dist-info/LICENSE +21 -0
- ai_parrot-0.1.0.dist-info/METADATA +299 -0
- ai_parrot-0.1.0.dist-info/RECORD +108 -0
- ai_parrot-0.1.0.dist-info/WHEEL +5 -0
- ai_parrot-0.1.0.dist-info/top_level.txt +3 -0
- parrot/__init__.py +18 -0
- parrot/chatbots/__init__.py +7 -0
- parrot/chatbots/abstract.py +965 -0
- parrot/chatbots/asktroc.py +16 -0
- parrot/chatbots/base.py +257 -0
- parrot/chatbots/basic.py +9 -0
- parrot/chatbots/bose.py +17 -0
- parrot/chatbots/cody.py +17 -0
- parrot/chatbots/copilot.py +100 -0
- parrot/chatbots/dataframe.py +103 -0
- parrot/chatbots/hragents.py +15 -0
- parrot/chatbots/oddie.py +17 -0
- parrot/chatbots/retrievals/__init__.py +515 -0
- parrot/chatbots/retrievals/constitutional.py +19 -0
- parrot/conf.py +108 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +169 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +29 -0
- parrot/llms/__init__.py +0 -0
- parrot/llms/abstract.py +41 -0
- parrot/llms/anthropic.py +36 -0
- parrot/llms/google.py +37 -0
- parrot/llms/groq.py +33 -0
- parrot/llms/hf.py +39 -0
- parrot/llms/openai.py +49 -0
- parrot/llms/pipes.py +103 -0
- parrot/llms/vertex.py +68 -0
- parrot/loaders/__init__.py +20 -0
- parrot/loaders/abstract.py +456 -0
- parrot/loaders/basepdf.py +102 -0
- parrot/loaders/basevideo.py +280 -0
- parrot/loaders/csv.py +42 -0
- parrot/loaders/dir.py +37 -0
- parrot/loaders/excel.py +349 -0
- parrot/loaders/github.py +65 -0
- parrot/loaders/handlers/__init__.py +5 -0
- parrot/loaders/handlers/data.py +213 -0
- parrot/loaders/image.py +119 -0
- parrot/loaders/json.py +52 -0
- parrot/loaders/pdf.py +187 -0
- parrot/loaders/pdfchapters.py +142 -0
- parrot/loaders/pdffn.py +112 -0
- parrot/loaders/pdfimages.py +207 -0
- parrot/loaders/pdfmark.py +88 -0
- parrot/loaders/pdftables.py +145 -0
- parrot/loaders/ppt.py +30 -0
- parrot/loaders/qa.py +81 -0
- parrot/loaders/repo.py +103 -0
- parrot/loaders/rtd.py +65 -0
- parrot/loaders/txt.py +92 -0
- parrot/loaders/utils/__init__.py +1 -0
- parrot/loaders/utils/models.py +25 -0
- parrot/loaders/video.py +96 -0
- parrot/loaders/videolocal.py +107 -0
- parrot/loaders/vimeo.py +106 -0
- parrot/loaders/web.py +216 -0
- parrot/loaders/web_base.py +112 -0
- parrot/loaders/word.py +125 -0
- parrot/loaders/youtube.py +192 -0
- parrot/manager.py +152 -0
- parrot/models.py +347 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +0 -0
- parrot/stores/abstract.py +170 -0
- parrot/stores/milvus.py +540 -0
- parrot/stores/qdrant.py +153 -0
- parrot/tools/__init__.py +16 -0
- parrot/tools/abstract.py +53 -0
- parrot/tools/asknews.py +32 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/google.py +170 -0
- parrot/tools/stack.py +26 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +59 -0
- parrot/tools/zipcode.py +179 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
- settings/__init__.py +0 -0
- settings/settings.py +51 -0
parrot/stores/milvus.py
ADDED
@@ -0,0 +1,540 @@
from typing import Optional, Union, Any
import asyncio
import uuid
import torch
from pymilvus import (
    MilvusClient,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    connections,
    db
)
from pymilvus.exceptions import MilvusException
from langchain_milvus import Milvus  # pylint: disable=import-error, E0611
from langchain.memory import VectorStoreRetrieverMemory
from .abstract import AbstractStore
from ..conf import (
    MILVUS_HOST,
    MILVUS_PROTOCOL,
    MILVUS_PORT,
    MILVUS_URL,
    MILVUS_TOKEN,
    MILVUS_USER,
    MILVUS_PASSWORD,
    MILVUS_SECURE,
    MILVUS_SERVER_NAME,
    MILVUS_CA_CERT,
    MILVUS_SERVER_CERT,
    MILVUS_SERVER_KEY,
    MILVUS_USE_TLSv2
)


class MilvusStore(AbstractStore):
    """MilvusStore class.

    Milvus is a Vector Database multi-layered.

    Args:
        host (str): Milvus host.
        port (int): Milvus port.
        url (str): Milvus URL.
    """

    def __init__(self, embeddings = None, **kwargs):
        super().__init__(embeddings, **kwargs)
        self.use_bge: bool = kwargs.pop("use_bge", False)
        self.fastembed: bool = kwargs.pop("use_fastembed", False)
        self.database: str = kwargs.pop('database', '')
        self.collection = kwargs.pop('collection_name', '')
        self.dimension: int = kwargs.pop("dimension", 768)
        self._metric_type: str = kwargs.pop("metric_type", 'COSINE')
        self._index_type: str = kwargs.pop("index_type", 'IVF_FLAT')
        self.host = kwargs.pop("host", MILVUS_HOST)
        self.port = kwargs.pop("port", MILVUS_PORT)
        self.protocol = kwargs.pop("protocol", MILVUS_PROTOCOL)
        self.create_database: bool = kwargs.pop('create_database', True)
        self.url = kwargs.pop("url", MILVUS_URL)
        if not self.url:
            self.url = f"{self.protocol}://{self.host}:{self.port}"
        else:
            # Extract host and port from URL
            if not self.host:
                self.host = self.url.split("://")[-1].split(":")[0]
            if not self.port:
                self.port = int(self.url.split(":")[-1])
        self.token = kwargs.pop("token", MILVUS_TOKEN)
        # user and password (if required)
        self.user = kwargs.pop("user", MILVUS_USER)
        self.password = kwargs.pop("password", MILVUS_PASSWORD)
        # SSL/TLS
        self._secure: bool = kwargs.pop('secure', MILVUS_SECURE)
        self._server_name: str = kwargs.pop('server_name', MILVUS_SERVER_NAME)
        self._cert: str = kwargs.pop('server_pem_path', MILVUS_SERVER_CERT)
        self._ca_cert: str = kwargs.pop('ca_pem_path', MILVUS_CA_CERT)
        self._cert_key: str = kwargs.pop('client_key_path', MILVUS_SERVER_KEY)
        # Any other argument will be passed to the Milvus client
        self.kwargs = {
            "uri": self.url,
            "host": self.host,
            "port": self.port,
            **kwargs
        }
        if self.token:
            self.kwargs['token'] = self.token
        if self.user:
            self.kwargs['token'] = f"{self.user}:{self.password}"
        # SSL Security:
        if self._secure is True:
            args = {
                "secure": self._secure,
                "server_name": self._server_name
            }
            if self._cert:
                if MILVUS_USE_TLSv2 is True:
                    args['client_pem_path'] = self._cert
                    args['client_key_path'] = self._cert_key
                else:
                    args["server_pem_path"] = self._cert
            if self._ca_cert:
                args['ca_pem_path'] = self._ca_cert
            self.kwargs = {**self.kwargs, **args}
        # 1. Check if database exists:
        if self.database:
            self.kwargs['db_name'] = self.database
            self.use_database(self.database, create=True)

    def connect(self, client_id: str = None):
        # 1. Set up a pyMilvus default connection
        # Unique connection:
        if not client_id:
            client_id = "uri-connection"
        self.conn = connections.connect(
            alias=client_id,
            **self.kwargs
        )
        self._client = MilvusClient(
            **self.kwargs
        )
        self._connected = True
        return self._client, client_id

    def close(self, client_id: str = "uri-connection"):
        connections.disconnect(alias=client_id)
        self._connected = False
        self._client.close()

    def create_db(self, db_name: str, **kwargs) -> bool:
        args = {
            "uri": self.url,
            "host": self.host,
            "port": self.port,
            **kwargs
        }
        print('ARGS >', args)
        try:
            conn = connections.connect(**args)
            db.create_database(db_name)
            self.logger.notice(
                f"Database {db_name} created successfully."
            )
        except Exception as e:
            raise ValueError(
                f"Error creating database: {e}"
            )
        finally:
            connections.disconnect(alias="uri-connection")

    def use_database(self, db_name: str, create: bool = False) -> None:
        try:
            conn = connections.connect(**self.kwargs)
        except MilvusException as exc:
            if "database not found" in exc.message:
                args = self.kwargs.copy()
                del args['db_name']
                self.create_db(db_name, **args)
                # re-connect:
                conn = connections.connect(**self.kwargs)
        if db_name not in db.list_database():
            if self.create_database is True or create is True:
                try:
                    db.create_database(db_name)
                    self.logger.notice(
                        f"Database {db_name} created successfully."
                    )
                except Exception as e:
                    raise ValueError(
                        f"Error creating database: {e}"
                    )
            else:
                raise ValueError(
                    f"Database {db_name} does not exist."
                )
        connections.disconnect(alias='default')

    def setup_vector(self):
        self.vector = Milvus(
            self._embed_,
            consistency_level='Bounded',
            connection_args={**self.kwargs},
            collection_name=self.collection,
        )
        return self.vector

    def get_vectorstore(self):
        return self.get_vector()

    def collection_exists(self, collection_name: str) -> bool:
        if collection_name in self._client.list_collections():
            return True
        return False

    def check_state(self, collection_name: str) -> dict:
        return self._client.get_load_state(collection_name=collection_name)

    async def delete_collection(self, collection: str = None) -> dict:
        self._client.drop_collection(
            collection_name=collection
        )

    async def create_collection(
        self,
        collection_name: str,
        document: Any = None,
        dimension: int = 768,
        index_type: str = None,
        metric_type: str = None,
        schema_type: str = 'default',
        metadata_field: str = None,
        **kwargs
    ) -> dict:
        """create_collection.

        Create a Schema (Milvus Collection) on the Current Database.

        Args:
            collection_name (str): Collection Name.
            document (Any): List of Documents.
            dimension (int, optional): Vector Dimension. Defaults to 768.
            index_type (str, optional): Default index type of Vector Field. Defaults to "HNSW".
            metric_type (str, optional): Default Metric for Vector Index. Defaults to "L2".
            schema_type (str, optional): Description of Model. Defaults to 'default'.

        Returns:
            dict: _description_
        """
        # Check if collection exists:
        if self.collection_exists(collection_name):
            self.logger.warning(
                f"Collection {collection_name} already exists."
            )
            return None
        idx_params = {}
        if not index_type:
            index_type = self._index_type
        if index_type == 'HNSW':
            idx_params = {
                "M": 36,
                "efConstruction": 1024
            }
        elif index_type in ('IVF_FLAT', 'SCANN', 'IVF_SQ8'):
            idx_params = {
                "nlist": 1024
            }
        elif index_type in ('IVF_PQ'):
            idx_params = {
                "nlist": 1024,
                "m": 16
            }
        if not metric_type:
            metric_type = self._metric_type  # default metric type
        if schema_type == 'default':
            # Default Collection for all loaders:
            schema = MilvusClient.create_schema(
                auto_id=False,
                enable_dynamic_field=True,
                description=collection_name
            )
            schema.add_field(
                field_name="pk",
                datatype=DataType.INT64,
                is_primary=True,
                auto_id=True,
                max_length=100
            )
            schema.add_field(
                field_name="index",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="url",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="source",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="filename",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="question",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="answer",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="source_type",
                datatype=DataType.VARCHAR,
                max_length=128
            )
            schema.add_field(
                field_name="type",
                datatype=DataType.VARCHAR,
                max_length=65535
            )
            schema.add_field(
                field_name="text",
                datatype=DataType.VARCHAR,
                description="Text",
                max_length=65535
            )
            schema.add_field(
                field_name="summary",
                datatype=DataType.VARCHAR,
                description="Summary (refine resume)",
                max_length=65535
            )
            schema.add_field(
                field_name="vector",
                datatype=DataType.FLOAT_VECTOR,
                dim=dimension,
                description="vector"
            )
            # schema.add_field(
            #     field_name="embedding",
            #     datatype=DataType.FLOAT_VECTOR,
            #     dim=dimension,
            #     description="Binary Embeddings"
            # )
            schema.add_field(
                field_name="document_meta",
                datatype=DataType.JSON,
                description="Custom Metadata information"
            )
            index_params = self._client.prepare_index_params()
            index_params.add_index(
                field_name="pk",
                index_type="STL_SORT"
            )
            index_params.add_index(
                field_name="text",
                index_type="marisa-trie"
            )
            index_params.add_index(
                field_name="summary",
                index_type="marisa-trie"
            )
            index_params.add_index(
                field_name="vector",
                index_type=index_type,
                metric_type=metric_type,
                params=idx_params
            )
            self._client.create_collection(
                collection_name=collection_name,
                schema=schema,
                index_params=index_params,
                num_shards=2
            )
            await asyncio.sleep(2)
            res = self._client.get_load_state(
                collection_name=collection_name
            )
            return None
        else:
            self._client.create_collection(
                collection_name=collection_name,
                dimension=dimension
            )
            if metadata_field:
                kwargs['metadata_field'] = metadata_field
            # Here using drop_old=True to force recreate based on the first document
            docstore = Milvus.from_documents(
                [document],  # Only the first document
                self._embed_,
                connection_args={**self.kwargs},
                collection_name=collection_name,
                drop_old=True,
                # consistency_level='Session',
                primary_field='pk',
                text_field='text',
                vector_field='vector',
                **kwargs
            )
            return docstore

    async def load_documents(
        self,
        documents: list,
        collection: str = None,
        upsert: bool = False,
        attribute: str = 'source_type',
        metadata_field: str = None,
        **kwargs
    ):
        if not collection:
            collection = self.collection
        try:
            tensor = torch.randn(1000, 1000).cuda()
        except Exception:
            tensor = None
        if upsert is True:
            # get first document
            doc = documents[0]
            # getting source type:
            doc_type = doc.metadata.get('attribute', None)
            if attribute:
                deleted = self._client.delete(
                    collection_name=collection,
                    filter=f'{attribute} == "{doc_type}"'
                )
                self.logger.notice(
                    f"Deleted documents with {attribute} {attribute}: {deleted}"
                )
        if metadata_field:
            # document_meta
            kwargs['metadata_field'] = metadata_field
        docstore = Milvus.from_documents(
            documents,
            self._embed_,
            connection_args={**self.kwargs},
            collection_name=collection,
            consistency_level='Bounded',
            drop_old=False,
            primary_field='pk',
            text_field='text',
            vector_field='vector',
            **kwargs
        )
        del tensor
        return docstore

    def upsert(self, payload: dict, collection: str = None) -> None:
        pass

    def insert(
        self,
        payload: Union[dict, list],
        collection: str = None
    ) -> dict:
        if collection is None:
            collection = self.collection
        result = self._client.insert(
            collection_name=collection,
            data=payload
        )
        collection.flush()
        return result

    def get_vector(
        self,
        collection: str = None,
        metric_type: str = None,
        nprobe: int = 200,
        metadata_field: str = None
    ) -> Milvus:
        if not metric_type:
            metric_type = self._metric_type
        if not collection:
            collection = self.collection
        _search = {
            "search_params": {
                "metric_type": metric_type,
                "params": {"nprobe": nprobe},
            }
        }
        if metadata_field:
            # document_meta
            _search['metadata_field'] = metadata_field
        _embed_ = self.create_embedding(
            model_name=self.embedding_name
        )
        return Milvus(
            embedding_function=_embed_,
            collection_name=collection,
            consistency_level='Bounded',
            connection_args={
                **self.kwargs
            },
            primary_field='pk',
            text_field='text',
            vector_field='vector',
            **_search
        )

    def similarity_search(self, query: str, collection: str = None, limit: int = 2) -> list:
        if collection is None:
            collection = self.collection
        if self._embed_ is None:
            _embed_ = self.create_embedding(
                model_name=self.embedding_name
            )
        else:
            _embed_ = self._embed_
        vector_db = Milvus(
            embedding_function=_embed_,
            collection_name=collection,
            consistency_level='Bounded',
            connection_args={
                **self.kwargs
            },
            primary_field='pk',
            text_field='text',
            vector_field='vector'
        )
        return vector_db.similarity_search(query, k=limit)

    def search(
        self,
        payload: Union[dict, list],
        collection: str = None,
        limit: Optional[int] = None
    ) -> list:
        args = {}
        if collection is None:
            collection = self.collection
        if limit is not None:
            args = {"limit": limit}
        if isinstance(payload, dict):
            payload = [payload]
        result = self._client.search(
            collection_name=collection,
            data=payload,
            **args
        )
        return result

    def memory_retriever(self, num_results: int = 5) -> VectorStoreRetrieverMemory:
        vectordb = Milvus.from_documents(
            {},
            self._embed_,
            connection_args={**self.kwargs}
        )
        retriever = Milvus.as_retriever(vectordb, search_kwargs=dict(k=num_results))
        return VectorStoreRetrieverMemory(retriever=retriever)

    def save_context(self, memory: VectorStoreRetrieverMemory, context: list) -> None:
        for val in context:
            memory.save_context(val)
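Usage sketch: a minimal, illustrative way to drive the MilvusStore API above. The embeddings provider, model, database, and collection names are assumed placeholders (the AbstractStore base that wires up self._embed_ is not part of this excerpt), and a Milvus server is assumed to be reachable on localhost:19530.

import asyncio
from langchain_huggingface import HuggingFaceEmbeddings  # assumed embeddings provider, not a declared dependency
from parrot.stores.milvus import MilvusStore


async def main():
    # Assumption: AbstractStore keeps this object as self._embed_ for the vectorstore calls.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"  # 768-dim, matches the default dimension
    )
    store = MilvusStore(
        embeddings=embeddings,
        database="demo_db",            # placeholder database name
        collection_name="demo_docs",   # placeholder collection name
        host="localhost",
        port=19530,
    )
    store.connect()
    # Builds the default schema (pk, text, summary, vector, document_meta, ...) and its indexes.
    await store.create_collection("demo_docs", dimension=768)
    print(store.similarity_search("example query", limit=2))
    store.close()


asyncio.run(main())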
parrot/stores/qdrant.py
ADDED
@@ -0,0 +1,153 @@
from typing import Any
from qdrant_client import QdrantClient  # pylint: disable=import-error
from langchain_community.vectorstores import (  # pylint: disable=import-error, E0611
    Qdrant
)
from .abstract import AbstractStore
from ..conf import (
    QDRANT_PROTOCOL,
    QDRANT_HOST,
    QDRANT_PORT,
    QDRANT_USE_HTTPS,
    QDRANT_CONN_TYPE,
    QDRANT_URL
)


class QdrantStore(AbstractStore):
    """QdrantStore class.

    Args:
        host (str): Qdrant host.
        port (int): Qdrant port.
        index_name (str): Qdrant index name.
    """

    def _create_qdrant_client(self, host, port, url, https, verify, qdrant_args):
        """
        Creates a Qdrant client based on the provided configuration.

        Args:
            host: Host of the Qdrant server (if using "server" connection).
            port: Port of the Qdrant server (if using "server" connection).
            url: URL of the Qdrant cloud service (if using "cloud" connection).
            https: Whether to use HTTPS for the connection.
            verify: Whether to verify the SSL certificate.
            qdrant_args: Additional arguments for the Qdrant client.

        Returns:
            A QdrantClient object.
        """
        if url is not None:
            return QdrantClient(
                url=url,
                port=None,
                verify=verify,
                **qdrant_args
            )
        else:
            return QdrantClient(
                host,
                port=port,
                https=https,
                verify=verify,
                **qdrant_args
            )

    def __init__(self, embeddings = None, **kwargs):
        super().__init__(embeddings, **kwargs)
        self.host = kwargs.get("host", QDRANT_HOST)
        self.port = kwargs.get("port", QDRANT_PORT)
        qdrant_args = kwargs.get("qdrant_args", {})
        connection_type = kwargs.get("connection_type", QDRANT_CONN_TYPE)
        url = kwargs.get("url", QDRANT_URL)
        if connection_type == "server":
            self.client = self._create_qdrant_client(
                self.host, self.port, url, QDRANT_USE_HTTPS, False, qdrant_args
            )
        elif connection_type == "cloud":
            if url is None:
                raise ValueError(
                    "A URL is required for 'cloud' connection"
                )
            self.client = self._create_qdrant_client(
                None, None, url, False, False, qdrant_args
            )
        else:
            raise ValueError(
                f"Invalid connection type: {connection_type}"
            )
        if url is not None:
            self.url = url
        else:
            self.url = f"{QDRANT_PROTOCOL}://{self.host}"
            if self.port:
                self.url += f":{self.port}"

    def get_vectorstore(self):
        if self._embed_ is None:
            _embed_ = self.create_embedding(
                model_name=self.embedding_name
            )
        else:
            _embed_ = self._embed_
        self.vector = Qdrant(
            client=self.client,
            collection_name=self.collection,
            embeddings=_embed_,
        )
        return self.vector

    async def load_documents(
        self,
        documents: list,
        collection: str = None
    ):
        if collection is None:
            collection = self.collection

        docstore = Qdrant.from_documents(
            documents,
            self._embed_,
            url=self.url,
            # location=":memory:",  # Local mode with in-memory storage only
            collection_name=collection,
            force_recreate=False,
        )
        return docstore

    def upsert(self, payload: dict, collection: str = None) -> None:
        if collection is None:
            collection = self.collection
        self.client.upsert(
            collection_name=collection,
            points=self._embed_,
            payload=payload
        )

    def search(self, payload: dict, collection: str = None) -> dict:
        pass

    async def delete_collection(self, collection: str = None) -> dict:
        self.client.delete_collection(
            collection_name=collection
        )

    async def create_collection(
        self,
        collection_name: str,
        document: Any,
        dimension: int = 768,
        **kwargs
    ) -> dict:
        # Here using drop_old=True to force recreate based on the first document
        docstore = Qdrant.from_documents(
            [document],
            self._embed_,
            url=self.url,
            # location=":memory:",  # Local mode with in-memory storage only
            collection_name=collection_name,
            force_recreate=True,
        )
        return docstore
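Usage sketch: an illustrative call sequence for QdrantStore, under the same assumptions as above (a LangChain-style embeddings object wired up as self._embed_ by the AbstractStore base, which is not part of this excerpt) plus a local Qdrant server on port 6333; the collection name and documents are placeholders.

import asyncio
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings  # assumed embeddings provider, not a declared dependency
from parrot.stores.qdrant import QdrantStore


async def main():
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    store = QdrantStore(
        embeddings=embeddings,
        connection_type="server",      # "cloud" would instead require url=...
        host="localhost",
        port=6333,
        collection_name="demo_docs",   # assumed to be handled by the AbstractStore base
    )
    docs = [Document(page_content="hello parrot", metadata={"source": "example"})]
    # create_collection() force-recreates the collection from the first document.
    await store.create_collection("demo_docs", docs[0])
    await store.load_documents(docs, collection="demo_docs")
    vectorstore = store.get_vectorstore()
    print(vectorstore.similarity_search("hello", k=1))


asyncio.run(main())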
parrot/tools/__init__.py
ADDED
@@ -0,0 +1,16 @@
import os
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
from langchain_community.tools import YouTubeSearchTool
from langchain_community.agent_toolkits import O365Toolkit
from navconfig import config
from .wikipedia import WikipediaTool, WikidataTool
from .asknews import AskNewsTool
from .duck import DuckDuckGoSearchTool, DuckDuckGoRelevantSearch
from .weather import OpenWeather, OpenWeatherMapTool
from .google import GoogleLocationFinder, GoogleSiteSearchTool, GoogleSearchTool
from .zipcode import ZipcodeAPIToolkit
from .bing import BingSearchTool
from .stack import StackExchangeTool


os.environ["USER_AGENT"] = "Parrot.AI/1.0"
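Usage sketch: importing parrot.tools re-exports the tool classes listed above and sets USER_AGENT as a side effect; the no-argument constructors below are assumptions, since the tool modules themselves (duck.py, wikipedia.py) are not part of this excerpt.

import os
from parrot.tools import DuckDuckGoSearchTool, WikipediaTool  # re-exported by the __init__ above

print(os.environ["USER_AGENT"])   # set to "Parrot.AI/1.0" when the package is imported
search = DuckDuckGoSearchTool()   # assumed no-argument constructor; duck.py not shown in this diff
wiki = WikipediaTool()            # assumed no-argument constructor; wikipedia.py not shown in this diff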