ai-parrot 0.3.4__cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-parrot might be problematic. Click here for more details.

Files changed (109)
  1. ai_parrot-0.3.4.dist-info/LICENSE +21 -0
  2. ai_parrot-0.3.4.dist-info/METADATA +319 -0
  3. ai_parrot-0.3.4.dist-info/RECORD +109 -0
  4. ai_parrot-0.3.4.dist-info/WHEEL +6 -0
  5. ai_parrot-0.3.4.dist-info/top_level.txt +3 -0
  6. parrot/__init__.py +21 -0
  7. parrot/chatbots/__init__.py +7 -0
  8. parrot/chatbots/abstract.py +728 -0
  9. parrot/chatbots/asktroc.py +16 -0
  10. parrot/chatbots/base.py +366 -0
  11. parrot/chatbots/basic.py +9 -0
  12. parrot/chatbots/bose.py +17 -0
  13. parrot/chatbots/cody.py +17 -0
  14. parrot/chatbots/copilot.py +83 -0
  15. parrot/chatbots/dataframe.py +103 -0
  16. parrot/chatbots/hragents.py +15 -0
  17. parrot/chatbots/odoo.py +17 -0
  18. parrot/chatbots/retrievals/__init__.py +578 -0
  19. parrot/chatbots/retrievals/constitutional.py +19 -0
  20. parrot/conf.py +110 -0
  21. parrot/crew/__init__.py +3 -0
  22. parrot/crew/tools/__init__.py +22 -0
  23. parrot/crew/tools/bing.py +13 -0
  24. parrot/crew/tools/config.py +43 -0
  25. parrot/crew/tools/duckgo.py +62 -0
  26. parrot/crew/tools/file.py +24 -0
  27. parrot/crew/tools/google.py +168 -0
  28. parrot/crew/tools/gtrends.py +16 -0
  29. parrot/crew/tools/md2pdf.py +25 -0
  30. parrot/crew/tools/rag.py +42 -0
  31. parrot/crew/tools/search.py +32 -0
  32. parrot/crew/tools/url.py +21 -0
  33. parrot/exceptions.cpython-311-x86_64-linux-gnu.so +0 -0
  34. parrot/handlers/__init__.py +4 -0
  35. parrot/handlers/bots.py +196 -0
  36. parrot/handlers/chat.py +162 -0
  37. parrot/interfaces/__init__.py +6 -0
  38. parrot/interfaces/database.py +29 -0
  39. parrot/llms/__init__.py +137 -0
  40. parrot/llms/abstract.py +47 -0
  41. parrot/llms/anthropic.py +42 -0
  42. parrot/llms/google.py +42 -0
  43. parrot/llms/groq.py +45 -0
  44. parrot/llms/hf.py +45 -0
  45. parrot/llms/openai.py +59 -0
  46. parrot/llms/pipes.py +114 -0
  47. parrot/llms/vertex.py +78 -0
  48. parrot/loaders/__init__.py +20 -0
  49. parrot/loaders/abstract.py +456 -0
  50. parrot/loaders/audio.py +106 -0
  51. parrot/loaders/basepdf.py +102 -0
  52. parrot/loaders/basevideo.py +280 -0
  53. parrot/loaders/csv.py +42 -0
  54. parrot/loaders/dir.py +37 -0
  55. parrot/loaders/excel.py +349 -0
  56. parrot/loaders/github.py +65 -0
  57. parrot/loaders/handlers/__init__.py +5 -0
  58. parrot/loaders/handlers/data.py +213 -0
  59. parrot/loaders/image.py +119 -0
  60. parrot/loaders/json.py +52 -0
  61. parrot/loaders/pdf.py +437 -0
  62. parrot/loaders/pdfchapters.py +142 -0
  63. parrot/loaders/pdffn.py +112 -0
  64. parrot/loaders/pdfimages.py +207 -0
  65. parrot/loaders/pdfmark.py +88 -0
  66. parrot/loaders/pdftables.py +145 -0
  67. parrot/loaders/ppt.py +30 -0
  68. parrot/loaders/qa.py +81 -0
  69. parrot/loaders/repo.py +103 -0
  70. parrot/loaders/rtd.py +65 -0
  71. parrot/loaders/txt.py +92 -0
  72. parrot/loaders/utils/__init__.py +1 -0
  73. parrot/loaders/utils/models.py +25 -0
  74. parrot/loaders/video.py +96 -0
  75. parrot/loaders/videolocal.py +120 -0
  76. parrot/loaders/vimeo.py +106 -0
  77. parrot/loaders/web.py +216 -0
  78. parrot/loaders/web_base.py +112 -0
  79. parrot/loaders/word.py +125 -0
  80. parrot/loaders/youtube.py +192 -0
  81. parrot/manager.py +166 -0
  82. parrot/models.py +372 -0
  83. parrot/py.typed +0 -0
  84. parrot/stores/__init__.py +48 -0
  85. parrot/stores/abstract.py +171 -0
  86. parrot/stores/milvus.py +632 -0
  87. parrot/stores/qdrant.py +153 -0
  88. parrot/tools/__init__.py +12 -0
  89. parrot/tools/abstract.py +53 -0
  90. parrot/tools/asknews.py +32 -0
  91. parrot/tools/bing.py +13 -0
  92. parrot/tools/duck.py +62 -0
  93. parrot/tools/google.py +170 -0
  94. parrot/tools/stack.py +26 -0
  95. parrot/tools/weather.py +70 -0
  96. parrot/tools/wikipedia.py +59 -0
  97. parrot/tools/zipcode.py +179 -0
  98. parrot/utils/__init__.py +2 -0
  99. parrot/utils/parsers/__init__.py +5 -0
  100. parrot/utils/parsers/toml.cpython-311-x86_64-linux-gnu.so +0 -0
  101. parrot/utils/toml.py +11 -0
  102. parrot/utils/types.cpython-311-x86_64-linux-gnu.so +0 -0
  103. parrot/utils/uv.py +11 -0
  104. parrot/version.py +10 -0
  105. resources/users/__init__.py +5 -0
  106. resources/users/handlers.py +13 -0
  107. resources/users/models.py +205 -0
  108. settings/__init__.py +0 -0
  109. settings/settings.py +51 -0
@@ -0,0 +1,632 @@
1
+ from typing import Optional, Union, Any
2
+ import asyncio
3
+ import uuid
4
+ import torch
5
+ from pymilvus import (
6
+ MilvusClient,
7
+ Collection,
8
+ FieldSchema,
9
+ CollectionSchema,
10
+ DataType,
11
+ connections,
12
+ db
13
+ )
14
+ from pymilvus.exceptions import MilvusException
15
+ from langchain_milvus import Milvus # pylint: disable=import-error, E0611
16
+ from langchain.memory import VectorStoreRetrieverMemory
17
+ from .abstract import AbstractStore
18
+ from ..conf import (
19
+ MILVUS_HOST,
20
+ MILVUS_PROTOCOL,
21
+ MILVUS_PORT,
22
+ MILVUS_URL,
23
+ MILVUS_TOKEN,
24
+ MILVUS_USER,
25
+ MILVUS_PASSWORD,
26
+ MILVUS_SECURE,
27
+ MILVUS_SERVER_NAME,
28
+ MILVUS_CA_CERT,
29
+ MILVUS_SERVER_CERT,
30
+ MILVUS_SERVER_KEY,
31
+ MILVUS_USE_TLSv2
32
+ )
33
+
34
+
35
class MilvusConnection:
    """
    Context Manager for Milvus Connections.

    On ``__enter__`` a pyMilvus connection is registered under ``alias``
    using the keyword arguments given to the constructor; on ``__exit__``
    that alias is disconnected again.
    """
    def __init__(self, alias: str = 'default', **kwargs):
        # Tracks whether connect() has succeeded and close() has not run yet.
        self._connected: bool = False
        # Connection arguments forwarded to connections.connect() on enter.
        self.kwargs = kwargs
        self.alias: str = alias

    def connect(self, alias: str = None, **kwargs):
        """Register a connection under ``alias`` and return the alias used.

        Args:
            alias: connection alias; falls back to the instance alias.
            **kwargs: passed unchanged to ``connections.connect``.
        """
        if not alias:
            alias = self.alias
        connections.connect(
            alias=alias,
            **kwargs
        )
        self._connected = True
        return alias

    def is_connected(self):
        """Return True once connect() succeeded and close() has not run."""
        return self._connected

    def close(self, alias: str = None):
        """Disconnect ``alias`` (the instance alias when omitted)."""
        # Fall back to the instance alias instead of handing None to
        # pyMilvus, mirroring what connect() does.
        if not alias:
            alias = self.alias
        try:
            connections.disconnect(alias=alias)
        finally:
            # The flag is cleared even when disconnect raises.
            self._connected = False

    def __enter__(self):
        self.connect(alias=self.alias, **self.kwargs)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close(alias=self.alias)
        # A truthy return value tells Python to suppress the in-flight
        # exception; return False so errors raised inside the ``with``
        # block reach the caller.
        return False
70
+
71
+
72
+ class MilvusStore(AbstractStore):
73
+ """MilvusStore class.
74
+
75
+ Milvus is a Vector Database multi-layered.
76
+
77
+ Args:
78
+ host (str): Milvus host.
79
+ port (int): Milvus port.
80
+ url (str): Milvus URL.
81
+ """
82
+
83
def __init__(self, embeddings = None, **kwargs):
    """Configure the store and build the Milvus connection arguments.

    Recognised keyword arguments (all optional): ``use_bge``,
    ``use_fastembed``, ``database``, ``collection_name``, ``dimension``,
    ``metric_type``, ``index_type``, ``host``, ``port``, ``protocol``,
    ``create_database``, ``url``, ``client_id``, ``token``, ``user``,
    ``password``, ``secure``, ``server_name``, ``server_pem_path``,
    ``ca_pem_path`` and ``client_key_path``. Connection settings default
    to the ``MILVUS_*`` values imported from the package configuration.
    Anything left over is forwarded to the Milvus client.

    When ``database`` is given, ``use_database`` is called so the
    database is verified (and optionally created) at construction time.
    """
    # Function-scope import: keeps the URL parsing fix self-contained.
    from urllib.parse import urlsplit

    super().__init__(embeddings, **kwargs)
    self.use_bge: bool = kwargs.pop("use_bge", False)
    self.fastembed: bool = kwargs.pop("use_fastembed", False)
    self.database: str = kwargs.pop('database', '')
    self.collection = kwargs.pop('collection_name', '')
    self.dimension: int = kwargs.pop("dimension", 768)
    self._metric_type: str = kwargs.pop("metric_type", 'COSINE')
    self._index_type: str = kwargs.pop("index_type", 'IVF_FLAT')
    self.host = kwargs.pop("host", MILVUS_HOST)
    self.port = kwargs.pop("port", MILVUS_PORT)
    self.protocol = kwargs.pop("protocol", MILVUS_PROTOCOL)
    self.create_database: bool = kwargs.pop('create_database', True)
    self.url = kwargs.pop("url", MILVUS_URL)
    self._client_id = kwargs.pop('client_id', 'default')
    if not self.url:
        self.url = f"{self.protocol}://{self.host}:{self.port}"
    elif not self.host or not self.port:
        # Extract host and port from URL. A "//" prefix makes urlsplit
        # treat a scheme-less "host:port" value as a network location.
        target = self.url if "://" in self.url else f"//{self.url}"
        parts = urlsplit(target)
        if not self.host:
            self.host = parts.hostname
        if not self.port:
            try:
                port = parts.port
            except ValueError:
                # Non-numeric or out-of-range port in the URL.
                port = None
            # URLs without an explicit port (or with a path) used to
            # crash in int(); now the port is simply left unset.
            if port is not None:
                self.port = port
    self.token = kwargs.pop("token", MILVUS_TOKEN)
    # user and password (if required)
    self.user = kwargs.pop("user", MILVUS_USER)
    self.password = kwargs.pop("password", MILVUS_PASSWORD)
    # SSL/TLS
    self._secure: bool = kwargs.pop('secure', MILVUS_SECURE)
    self._server_name: str = kwargs.pop('server_name', MILVUS_SERVER_NAME)
    self._cert: str = kwargs.pop('server_pem_path', MILVUS_SERVER_CERT)
    self._ca_cert: str = kwargs.pop('ca_pem_path', MILVUS_CA_CERT)
    self._cert_key: str = kwargs.pop('client_key_path', MILVUS_SERVER_KEY)
    # Any other argument will be passed to the Milvus client
    self.kwargs = {
        "uri": self.url,
        "host": self.host,
        "port": self.port,
        **kwargs
    }
    if self.token:
        self.kwargs['token'] = self.token
    if self.user:
        # A user/password pair takes precedence over an explicit token.
        self.kwargs['token'] = f"{self.user}:{self.password}"
    # SSL Security:
    if self._secure is True:
        args = {
            "secure": self._secure,
            "server_name": self._server_name
        }
        if self._cert:
            if MILVUS_USE_TLSv2 is True:
                # Two-way TLS: client certificate plus its key.
                args['client_pem_path'] = self._cert
                args['client_key_path'] = self._cert_key
            else:
                args["server_pem_path"] = self._cert
        if self._ca_cert:
            args['ca_pem_path'] = self._ca_cert
        self.kwargs = {**self.kwargs, **args}
    # 1. Check if database exists:
    if self.database:
        self.kwargs['db_name'] = self.database
        self.use_database(
            self.database,
            create=self.create_database
        )
149
+
150
async def __aenter__(self):
    """Enter the async context: set up ``self.tensor`` and the embedding.

    ``self.tensor`` holds a 1000x1000 random tensor on the GPU when CUDA
    is usable, otherwise ``None``. The embedding model is created only
    if ``self._embed_`` has not been set already.

    Returns:
        The store itself.
    """
    self.tensor = None
    # Ask torch whether CUDA is usable instead of relying on .cuda()
    # failing with one specific exception type.
    if torch.cuda.is_available():
        try:
            self.tensor = torch.randn(1000, 1000).cuda()
        except RuntimeError:
            # e.g. out of GPU memory: continue without the tensor.
            self.tensor = None
    if self._embed_ is None:
        self._embed_ = self.create_embedding(
            model_name=self.embedding_name
        )
    return self
160
+
161
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the async context: drop the embedding, tensor and connection.

    Cleanup is best-effort. A ``RuntimeError`` raised while closing the
    connection is ignored, as before. Exceptions raised inside the
    ``async with`` body are not suppressed (the method returns None).
    """
    # closing Embedding
    self._embed_ = None
    # Rebind rather than ``del``: ``del`` raises AttributeError when the
    # attribute was never set (e.g. __aenter__ did not complete).
    self.tensor = None
    try:
        # Nothing to close if connection()/connect() was never called.
        if getattr(self, "_client", None) is not None:
            self.close(alias=self._client_id)
    except RuntimeError:
        pass
    finally:
        # Release cached GPU memory even when closing failed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
170
+
171
def connection(self, alias: str = None):
    """Open the Milvus connection and return the store for chaining.

    The alias defaults to ``'uri-connection'`` when none is given. The
    client and the alias returned by ``connect`` are stored on
    ``self._client`` and ``self._client_id``.
    """
    self._client_id = alias if alias else 'uri-connection'
    # making the connection:
    client, client_id = self.connect(alias=self._client_id)
    self._client, self._client_id = client, client_id
    return self
182
+
183
def connect(self, alias: str = None) -> tuple:
    """Register a pyMilvus connection and create a ``MilvusClient``.

    Both are built from ``self.kwargs``; the alias defaults to
    ``"default"``.

    Returns:
        A ``(client, alias)`` tuple.
    """
    name = alias or "default"
    # 1. Set up a pyMilvus connection under the chosen alias
    connections.connect(alias=name, **self.kwargs)
    # 2. High-level client built from the same arguments
    milvus_client = MilvusClient(**self.kwargs)
    self._connected = True
    return milvus_client, name
197
+
198
def close(self, alias: str = "default"):
    """Disconnect ``alias`` and close the Milvus client, if any.

    Each step runs even when the previous one raises, so a failed
    disconnect no longer leaves the client open or ``_connected`` True.
    """
    try:
        connections.disconnect(alias=alias)
    finally:
        try:
            # ``_client`` only exists once connection() has been called.
            client = getattr(self, "_client", None)
            if client is not None:
                client.close()
        finally:
            self._connected = False
204
+
205
def create_db(self, db_name: str, alias: str = 'default', **kwargs) -> bool:
    """Create the database ``db_name`` over a temporary connection.

    Args:
        db_name: name of the database to create.
        alias: connection alias used for the temporary connection.
        **kwargs: extra connection arguments; they override the store's
            ``uri``/``host``/``port``.

    Returns:
        True when the database was created.

    Raises:
        ValueError: if connecting or creating the database fails.
    """
    args = {
        "uri": self.url,
        "host": self.host,
        "port": self.port,
        **kwargs
    }
    try:
        connections.connect(alias, **args)
        # Create the database through the connection opened above
        # rather than whatever is registered as "default".
        db.create_database(db_name, using=alias)
        self.logger.notice(
            f"Database {db_name} created successfully."
        )
        return True
    except Exception as e:
        raise ValueError(
            f"Error creating database: {e}"
        ) from e
    finally:
        # Disconnect the alias opened here (previously a hard-coded,
        # unrelated alias was disconnected).
        connections.disconnect(alias=alias)
224
+
225
def use_database(
    self,
    db_name: str,
    alias:str = 'default',
    create: bool = False
) -> None:
    """Verify that ``db_name`` exists, creating it when allowed.

    A first connection attempt is made with ``self.kwargs``; if Milvus
    reports "database not found", ``create_db`` is called. The method
    then reconnects and checks the database list. A missing database is
    created when ``self.create_database`` or ``create`` is True,
    otherwise ``ValueError`` is raised. The connection registered under
    ``alias`` is always disconnected before returning.

    Raises:
        ValueError: if the database is missing and may not be created,
            or if creating it fails.
    """
    try:
        connections.connect(alias, **self.kwargs)
    except MilvusException as exc:
        if "database not found" in exc.message:
            # create_db connects without selecting the missing database
            params = self.kwargs.copy()
            del params['db_name']
            self.create_db(db_name, alias=alias, **params)
    # re-connect:
    try:
        connections.connect(alias, **self.kwargs)
        if db_name in db.list_database(using=alias):
            return
        may_create = self.create_database is True or create is True
        if not may_create:
            raise ValueError(
                f"Database {db_name} does not exist."
            )
        try:
            db.create_database(db_name, using=alias, timeout=10)
            self.logger.notice(
                f"Database {db_name} created successfully."
            )
        except Exception as e:
            raise ValueError(
                f"Error creating database: {e}"
            )
    finally:
        connections.disconnect(alias=alias)
258
+
259
def setup_vector(self):
    """Build the LangChain ``Milvus`` store for ``self.collection``.

    The store is kept on ``self.vector`` and also returned.
    """
    store = Milvus(
        self._embed_,
        collection_name=self.collection,
        connection_args=dict(self.kwargs),
        consistency_level='Bounded',
    )
    self.vector = store
    return self.vector
267
+
268
def get_vectorstore(self):
    """Return the vector store produced by ``get_vector()`` with its defaults."""
    vectorstore = self.get_vector()
    return vectorstore
270
+
271
def collection_exists(self, collection_name: str) -> bool:
    """Tell whether ``collection_name`` is listed by the Milvus client."""
    existing = self._client.list_collections()
    return collection_name in existing
275
+
276
def check_state(self, collection_name: str) -> dict:
    """Return the client's load state for ``collection_name``."""
    state = self._client.get_load_state(collection_name=collection_name)
    return state
278
+
279
async def delete_collection(self, collection: str = None) -> dict:
    """Drop ``collection`` through the Milvus client.

    The name is passed to the client unchanged; nothing is returned.
    """
    client = self._client
    client.drop_collection(collection_name=collection)
283
+
284
async def create_collection(
    self,
    collection_name: str,
    document: Any = None,
    dimension: int = 768,
    index_type: str = None,
    metric_type: str = None,
    schema_type: str = 'default',
    metadata_field: str = None,
    **kwargs
) -> dict:
    """create_collection.

    Create a Schema (Milvus Collection) on the Current Database.

    Args:
        collection_name (str): Collection Name.
        document (Any): Document used to seed the collection when
            ``schema_type`` is not 'default'.
        dimension (int, optional): Vector Dimension. Defaults to 768.
        index_type (str, optional): Index type of the vector field.
            Defaults to the store's configured index type.
        metric_type (str, optional): Metric of the vector index.
            Defaults to the store's configured metric type.
        schema_type (str, optional): 'default' builds the fixed schema
            below; any other value lets LangChain derive the schema
            from ``document``.
        metadata_field (str, optional): forwarded to LangChain's Milvus
            store when ``schema_type`` is not 'default'.
        **kwargs: extra arguments for ``Milvus.from_documents``.

    Returns:
        None when the collection already exists or when the default
        schema is created; otherwise the LangChain ``Milvus`` store.
    """
    # Check if collection exists:
    if self.collection_exists(collection_name):
        self.logger.warning(
            f"Collection {collection_name} already exists."
        )
        return None
    if not index_type:
        index_type = self._index_type
    if not metric_type:
        metric_type = self._metric_type  # default metric type
    # Build parameters per index family.
    if index_type == 'HNSW':
        idx_params = {"M": 36, "efConstruction": 1024}
    elif index_type in ('IVF_FLAT', 'SCANN', 'IVF_SQ8'):
        idx_params = {"nlist": 1024}
    elif index_type == 'IVF_PQ':
        # Equality, not ``in ('IVF_PQ')``: the latter is a substring
        # test against a plain string.
        idx_params = {"nlist": 1024, "m": 16}
    else:
        idx_params = {}
    if schema_type == 'default':
        # Default Collection for all loaders:
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
            description=collection_name
        )
        schema.add_field(
            field_name="pk",
            datatype=DataType.INT64,
            is_primary=True,
            auto_id=True,
            max_length=100
        )
        # (field name, max length, optional description), in schema order
        varchar_fields = (
            ("index", 65535, None),
            ("url", 65535, None),
            ("source", 65535, None),
            ("filename", 65535, None),
            ("question", 65535, None),
            ("answer", 65535, None),
            ("source_type", 128, None),
            ("type", 65535, None),
            ("text", 65535, "Text"),
            ("summary", 65535, "Summary (refine resume)"),
        )
        for field_name, max_length, description in varchar_fields:
            extra = {"description": description} if description else {}
            schema.add_field(
                field_name=field_name,
                datatype=DataType.VARCHAR,
                max_length=max_length,
                **extra
            )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector"
        )
        schema.add_field(
            field_name="document_meta",
            datatype=DataType.JSON,
            description="Custom Metadata information"
        )
        index_params = self._client.prepare_index_params()
        index_params.add_index(
            field_name="pk",
            index_type="STL_SORT"
        )
        index_params.add_index(
            field_name="text",
            index_type="marisa-trie"
        )
        index_params.add_index(
            field_name="summary",
            index_type="marisa-trie"
        )
        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=idx_params
        )
        self._client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
            num_shards=2
        )
        await asyncio.sleep(2)
        # The result is not used; the call is kept so a failing
        # collection surfaces here rather than on first use.
        self._client.get_load_state(
            collection_name=collection_name
        )
        return None
    else:
        self._client.create_collection(
            collection_name=collection_name,
            dimension=dimension
        )
        if metadata_field:
            kwargs['metadata_field'] = metadata_field
        # Here using drop_old=True to force recreate based on the first document
        docstore = Milvus.from_documents(
            [document],  # Only the first document
            self._embed_,
            connection_args={**self.kwargs},
            collection_name=collection_name,
            drop_old=True,
            # consistency_level='Session',
            primary_field='pk',
            text_field='text',
            vector_field='vector',
            **kwargs
        )
        return docstore
469
+
470
async def load_documents(
    self,
    documents: list,
    collection: str = None,
    upsert: bool = False,
    attribute: str = 'source_type',
    metadata_field: str = None,
    **kwargs
):
    """Insert ``documents`` into a collection through LangChain's Milvus store.

    Args:
        documents: documents to embed and store.
        collection: target collection; defaults to ``self.collection``.
        upsert: when True, rows whose ``attribute`` field equals the
            value found in the first document's metadata are deleted
            before inserting.
        attribute: metadata key / collection field used for that delete.
        metadata_field: forwarded to ``Milvus.from_documents``.
        **kwargs: extra arguments for ``Milvus.from_documents``.

    Returns:
        The LangChain ``Milvus`` store returned by ``from_documents``.
    """
    if not collection:
        collection = self.collection
    try:
        tensor = torch.randn(1000, 1000).cuda()
    except Exception:  # pylint: disable=broad-except
        # No usable GPU: continue without the tensor.
        tensor = None
    if upsert is True and attribute and documents:
        # Read the value of the *configured* attribute from the first
        # document (previously the literal key 'attribute' was read).
        doc_type = documents[0].metadata.get(attribute)
        if doc_type is not None:
            deleted = self._client.delete(
                collection_name=collection,
                filter=f'{attribute} == "{doc_type}"'
            )
            self.logger.notice(
                f"Deleted documents with {attribute} {doc_type}: {deleted}"
            )
    if metadata_field:
        # document_meta
        kwargs['metadata_field'] = metadata_field
    docstore = Milvus.from_documents(
        documents,
        self._embed_,
        connection_args={**self.kwargs},
        collection_name=collection,
        consistency_level='Bounded',
        drop_old=False,
        primary_field='pk',
        text_field='text',
        vector_field='vector',
        **kwargs
    )
    del tensor
    return docstore
515
+
516
def upsert(self, payload: dict, collection: str = None) -> None:
    """Placeholder: upsert is not implemented; the call does nothing."""
    return None
518
+
519
def insert(
    self,
    payload: Union[dict, list],
    collection: Union[str, None] = None
) -> dict:
    """Insert ``payload`` into a collection and flush it.

    Args:
        payload: one row (dict) or a list of rows.
        collection: target collection; defaults to ``self.collection``.

    Returns:
        The result returned by the client's ``insert`` call.
    """
    if collection is None:
        collection = self.collection
    result = self._client.insert(
        collection_name=collection,
        data=payload
    )
    # ``collection`` is a name (str), so it has no flush() of its own.
    # Flush through the client, or through the ORM Collection bound to
    # this store's connection alias when the client lacks flush().
    flush = getattr(self._client, "flush", None)
    if callable(flush):
        flush(collection_name=collection)
    else:
        Collection(collection, using=self._client_id).flush()
    return result
532
+
533
def get_vector(
    self,
    collection: Union[str, None] = None,
    metric_type: str = None,
    nprobe: int = 32,
    metadata_field: str = None,
    consistency_level: str = 'session'
) -> Milvus:
    """Build a LangChain ``Milvus`` store configured for searching.

    Args:
        collection: collection name; defaults to ``self.collection``.
        metric_type: search metric; defaults to the store's metric type.
        nprobe: ``nprobe`` search parameter.
        metadata_field: optional metadata field name passed to the store.
        consistency_level: consistency level passed to the store.

    Returns:
        A ``Milvus`` store backed by a freshly created embedding.
    """
    options = {
        "search_params": {
            "metric_type": metric_type or self._metric_type,
            "params": {"nprobe": nprobe, "nlist": 1024},
        }
    }
    if metadata_field:
        # document_meta
        options['metadata_field'] = metadata_field
    embedding = self.create_embedding(
        model_name=self.embedding_name
    )
    return Milvus(
        embedding_function=embedding,
        collection_name=collection or self.collection,
        consistency_level=consistency_level,
        connection_args=dict(self.kwargs),
        primary_field='pk',
        text_field='text',
        vector_field='vector',
        **options
    )
569
+
570
def similarity_search(
    self,
    query: str,
    collection: Union[str, None] = None,
    limit: int = 2,
    consistency_level: str = 'Bounded'
) -> list:
    """Return the ``limit`` documents most similar to ``query``.

    Uses ``self.collection`` when no collection is given. The existing
    embedding is reused; one is created only if none is set.
    """
    target = self.collection if collection is None else collection
    if self._embed_ is not None:
        embedding = self._embed_
    else:
        embedding = self.create_embedding(
            model_name=self.embedding_name
        )
    store = Milvus(
        embedding_function=embedding,
        collection_name=target,
        consistency_level=consistency_level,
        connection_args=dict(self.kwargs),
        primary_field='pk',
        text_field='text',
        vector_field='vector'
    )
    return store.similarity_search(query, k=limit)
597
+
598
def search(
    self,
    payload: Union[dict, list],
    collection: Union[str, None] = None,
    limit: Optional[int] = None
) -> list:
    """Run a search through the Milvus client.

    A single dict payload is wrapped in a list. ``limit`` is passed on
    only when given. Uses ``self.collection`` when no collection is
    given.
    """
    target = self.collection if collection is None else collection
    extra = {} if limit is None else {"limit": limit}
    rows = [payload] if isinstance(payload, dict) else payload
    return self._client.search(
        collection_name=target,
        data=rows,
        **extra
    )
617
+
618
def memory_retriever(self, num_results: int = 5) -> VectorStoreRetrieverMemory:
    """Build a retriever-backed memory over a Milvus vector store.

    The store is created from an empty document collection with the
    current embedding; the retriever returns ``num_results`` hits.
    """
    vectordb = Milvus.from_documents(
        {},
        self._embed_,
        connection_args=dict(self.kwargs)
    )
    retriever = vectordb.as_retriever(
        search_kwargs={"k": num_results}
    )
    return VectorStoreRetrieverMemory(retriever=retriever)
629
+
630
def save_context(self, memory: VectorStoreRetrieverMemory, context: list) -> None:
    """Store every entry of ``context`` in ``memory``.

    Each entry is expected to be an ``(inputs, outputs)`` pair, matching
    the two-argument ``save_context(inputs, outputs)`` of LangChain
    memories. Any other entry is passed through as a single argument,
    as before.
    """
    for val in context:
        if isinstance(val, (tuple, list)) and len(val) == 2:
            inputs, outputs = val
            memory.save_context(inputs, outputs)
        else:
            # NOTE(review): a single positional argument does not match
            # LangChain's signature — confirm the shape callers pass.
            memory.save_context(val)