ag2 0.4b1__py3-none-any.whl → 0.4.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (118)
  1. ag2-0.4.2b1.dist-info/METADATA +19 -0
  2. ag2-0.4.2b1.dist-info/RECORD +6 -0
  3. ag2-0.4.2b1.dist-info/top_level.txt +1 -0
  4. ag2-0.4b1.dist-info/METADATA +0 -496
  5. ag2-0.4b1.dist-info/RECORD +0 -115
  6. ag2-0.4b1.dist-info/top_level.txt +0 -1
  7. autogen/__init__.py +0 -17
  8. autogen/_pydantic.py +0 -116
  9. autogen/agentchat/__init__.py +0 -42
  10. autogen/agentchat/agent.py +0 -142
  11. autogen/agentchat/assistant_agent.py +0 -85
  12. autogen/agentchat/chat.py +0 -306
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +0 -787
  15. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  18. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  19. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  22. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  25. autogen/agentchat/contrib/captainagent.py +0 -487
  26. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  27. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  28. autogen/agentchat/contrib/graph_rag/document.py +0 -24
  29. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -76
  30. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -50
  31. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -56
  32. autogen/agentchat/contrib/img_utils.py +0 -390
  33. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  34. autogen/agentchat/contrib/llava_agent.py +0 -176
  35. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  36. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  37. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  38. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  39. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -701
  40. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  41. autogen/agentchat/contrib/swarm_agent.py +0 -414
  42. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  43. autogen/agentchat/contrib/tool_retriever.py +0 -114
  44. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  45. autogen/agentchat/contrib/vectordb/base.py +0 -243
  46. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  47. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  48. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  49. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  50. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  51. autogen/agentchat/contrib/web_surfer.py +0 -305
  52. autogen/agentchat/conversable_agent.py +0 -2908
  53. autogen/agentchat/groupchat.py +0 -1668
  54. autogen/agentchat/user_proxy_agent.py +0 -109
  55. autogen/agentchat/utils.py +0 -207
  56. autogen/browser_utils.py +0 -291
  57. autogen/cache/__init__.py +0 -10
  58. autogen/cache/abstract_cache_base.py +0 -78
  59. autogen/cache/cache.py +0 -182
  60. autogen/cache/cache_factory.py +0 -85
  61. autogen/cache/cosmos_db_cache.py +0 -150
  62. autogen/cache/disk_cache.py +0 -109
  63. autogen/cache/in_memory_cache.py +0 -61
  64. autogen/cache/redis_cache.py +0 -128
  65. autogen/code_utils.py +0 -745
  66. autogen/coding/__init__.py +0 -22
  67. autogen/coding/base.py +0 -113
  68. autogen/coding/docker_commandline_code_executor.py +0 -262
  69. autogen/coding/factory.py +0 -45
  70. autogen/coding/func_with_reqs.py +0 -203
  71. autogen/coding/jupyter/__init__.py +0 -22
  72. autogen/coding/jupyter/base.py +0 -32
  73. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  74. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  75. autogen/coding/jupyter/jupyter_client.py +0 -224
  76. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  77. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  78. autogen/coding/local_commandline_code_executor.py +0 -410
  79. autogen/coding/markdown_code_extractor.py +0 -44
  80. autogen/coding/utils.py +0 -57
  81. autogen/exception_utils.py +0 -46
  82. autogen/extensions/__init__.py +0 -0
  83. autogen/formatting_utils.py +0 -76
  84. autogen/function_utils.py +0 -362
  85. autogen/graph_utils.py +0 -148
  86. autogen/io/__init__.py +0 -15
  87. autogen/io/base.py +0 -105
  88. autogen/io/console.py +0 -43
  89. autogen/io/websockets.py +0 -213
  90. autogen/logger/__init__.py +0 -11
  91. autogen/logger/base_logger.py +0 -140
  92. autogen/logger/file_logger.py +0 -287
  93. autogen/logger/logger_factory.py +0 -29
  94. autogen/logger/logger_utils.py +0 -42
  95. autogen/logger/sqlite_logger.py +0 -459
  96. autogen/math_utils.py +0 -356
  97. autogen/oai/__init__.py +0 -33
  98. autogen/oai/anthropic.py +0 -428
  99. autogen/oai/bedrock.py +0 -600
  100. autogen/oai/cerebras.py +0 -264
  101. autogen/oai/client.py +0 -1148
  102. autogen/oai/client_utils.py +0 -167
  103. autogen/oai/cohere.py +0 -453
  104. autogen/oai/completion.py +0 -1216
  105. autogen/oai/gemini.py +0 -469
  106. autogen/oai/groq.py +0 -281
  107. autogen/oai/mistral.py +0 -279
  108. autogen/oai/ollama.py +0 -576
  109. autogen/oai/openai_utils.py +0 -810
  110. autogen/oai/together.py +0 -343
  111. autogen/retrieve_utils.py +0 -487
  112. autogen/runtime_logging.py +0 -163
  113. autogen/token_count_utils.py +0 -257
  114. autogen/types.py +0 -20
  115. autogen/version.py +0 -7
  116. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/LICENSE +0 -0
  117. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/NOTICE.md +0 -0
  118. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/WHEEL +0 -0
@@ -1,559 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- from copy import deepcopy
8
- from time import monotonic, sleep
9
- from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
10
-
11
- import numpy as np
12
- from pymongo import MongoClient, UpdateOne, errors
13
- from pymongo.collection import Collection
14
- from pymongo.driver_info import DriverInfo
15
- from pymongo.operations import SearchIndexModel
16
- from sentence_transformers import SentenceTransformer
17
-
18
- from .base import Document, ItemID, QueryResults, VectorDB
19
- from .utils import get_logger
20
-
21
# Module-level logger provided by the vectordb utilities.
logger = get_logger(__name__)

# Number of documents embedded and inserted per batch by MongoDBAtlasVectorDB.insert_docs.
DEFAULT_INSERT_BATCH_SIZE = 100 * 1000
# One-element probe batch; its embedding length gives the index's vector dimension.
_SAMPLE_SENTENCE = [
    "The weather is lovely today in paradise.",
]
# Seconds slept between polls while waiting for indexes or documents to become available.
_DELAY = 1 / 2
26
-
27
-
28
def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
    """Convert MongoDB records into Document-shaped dicts.

    Every field except ``_id`` is copied unchanged. The value of ``_id`` is then
    exposed under the ``id`` key, overwriting any pre-existing ``id`` field.

    Args:
        docs: Iterable of mappings, each containing an ``_id`` key.

    Returns:
        A new list of dicts, one per input record, in input order.
    """
    renamed = []
    for record in docs:
        converted = {key: value for key, value in record.items() if key != "_id"}
        converted["id"] = record["_id"]
        renamed.append(converted)
    return renamed
31
-
32
-
33
class MongoDBAtlasVectorDB(VectorDB):
    """
    A VectorDB backed by a MongoDB Atlas collection and an Atlas ``vectorSearch`` index.

    Documents are stored with their ``id`` as the MongoDB ``_id`` and their vector under the
    ``embedding`` field, which is the path the vector search index is defined on.
    """

    # Default SentenceTransformer model. It is created on first use and shared by every instance
    # that does not supply its own embedding function. Creating it lazily, rather than in the
    # ``__init__`` default argument, avoids loading the model when this module is merely imported.
    _default_embedding_model = None

    def __init__(
        self,
        connection_string: str = "",
        database_name: str = "vector_db",
        embedding_function: Union[Callable, None] = None,
        collection_name: str = None,
        index_name: str = "vector_index",
        overwrite: bool = False,
        wait_until_index_ready: float = None,
        wait_until_document_ready: float = None,
    ):
        """
        Initialize the vector database.

        Args:
            connection_string: str | The MongoDB connection string to connect to. Default is ''.
            database_name: str | The name of the database. Default is 'vector_db'.
            embedding_function: Callable | None | The embedding function used to generate the vector
                representation. It is called with a list of strings. None, the default, uses
                ``SentenceTransformer("all-MiniLM-L6-v2").encode``.
            collection_name: str | The name of the collection to create for this vector database.
                Defaults to None (no active collection).
            index_name: str | Index name for the vector database, defaults to 'vector_index'.
            overwrite: bool | Whether to drop and recreate ``collection_name`` if it already exists.
                Default is False.
            wait_until_index_ready: float | None | Seconds to block until search indexes are
                created/deleted. None, the default, means no wait.
            wait_until_document_ready: float | None | Seconds to block until the last inserted or
                updated document is returned by vector search. None, the default, means no wait.

        Raises:
            ConnectionError: If the MongoDB server cannot be selected (ping fails with
                ServerSelectionTimeoutError).
        """
        if embedding_function is None:
            embedding_function = self._get_default_embedding_function()
        self.embedding_function = embedding_function
        self.index_name = index_name
        self._wait_until_index_ready = wait_until_index_ready
        self._wait_until_document_ready = wait_until_document_ready

        # Get the model dimension size by computing the embedding of a sample sentence.
        self.dimensions = self._get_embedding_size()

        try:
            self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen"))
            self.client.admin.command("ping")
            logger.debug("Successfully created MongoClient")
        except errors.ServerSelectionTimeoutError as err:
            raise ConnectionError("Could not connect to MongoDB server") from err

        self.db = self.client[database_name]
        logger.debug(f"Atlas Database name: {self.db.name}")
        if collection_name:
            self.active_collection = self.create_collection(collection_name, overwrite)
        else:
            self.active_collection = None

    @classmethod
    def _get_default_embedding_function(cls) -> Callable:
        """Return ``encode`` of the shared default SentenceTransformer, creating the model on first use."""
        if cls._default_embedding_model is None:
            cls._default_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        return cls._default_embedding_model.encode

    def _is_index_ready(self, collection: Collection, index_name: str) -> bool:
        """Check whether a ``vectorSearch`` index named ``index_name`` exists with status READY.

        Args:
            collection (Collection): MongoDB Collection whose search indexes are listed.
            index_name (str): Vector Search Index name.

        Returns:
            bool: True if the index is present and READY, False otherwise.
        """
        for index in collection.list_search_indexes(index_name):
            if index["type"] == "vectorSearch" and index["status"] == "READY":
                return True
        return False

    def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create") -> None:
        """Poll until the index action on ``index_name`` has completed.

        The timeout is ``wait_until_index_ready`` from instantiation.

        Args:
            collection: Collection owning the index.
            index_name: Name of the search index being created or deleted.
            action: "create" waits for the index to be READY; "delete" waits for that index
                (and only that index) to disappear from the listing.

        Raises:
            ValueError: If ``action`` is not "create" or "delete".
            TimeoutError: If the action does not complete within the timeout.
        """
        if action not in ("create", "delete"):
            raise ValueError(f"{action=} must be create or delete.")
        start = monotonic()
        while monotonic() - start < self._wait_until_index_ready:
            if action == "create" and self._is_index_ready(collection, index_name):
                return
            # Only check the index being deleted. Other indexes on the collection may still
            # legitimately exist, e.g. while delete_collection drops them one at a time.
            if action == "delete" and not list(collection.list_search_indexes(index_name)):
                return
            sleep(_DELAY)

        outcome = "ready" if action == "create" else "deleted"
        raise TimeoutError(f"Index {index_name} is not {outcome}!")

    def _wait_for_document(self, collection: Collection, index_name: str, doc: Document) -> None:
        """Poll vector search until ``doc`` is its own top-1 result.

        The timeout is ``wait_until_document_ready`` from instantiation.

        Raises:
            TimeoutError: If the document is not returned within the timeout.
        """
        # Embed once, outside the polling loop. Pass a one-element list for consistency with the
        # other embedding call sites, then take the single resulting vector.
        query_vector = np.array(self.embedding_function([doc["content"]])).tolist()[0]
        start = monotonic()
        while monotonic() - start < self._wait_until_document_ready:
            query_result = _vector_search(
                embedding_vector=query_vector,
                n_results=1,
                collection=collection,
                index_name=index_name,
            )
            if query_result and query_result[0][0]["_id"] == doc["id"]:
                return
            sleep(_DELAY)

        raise TimeoutError(f"Document {doc['id']} is not ready!")

    def _get_embedding_size(self) -> int:
        """Return the length of the vector produced for a one-sentence sample batch."""
        return len(self.embedding_function(_SAMPLE_SENTENCE)[0])

    def list_collections(self):
        """
        List the collections in the vector database.

        Returns:
            List[str] | The list of collections.
        """
        return self.db.list_collection_names()

    def create_collection(
        self,
        collection_name: str,
        overwrite: bool = False,
        get_or_create: bool = True,
    ) -> Collection:
        """
        Create a collection in the vector database and create a vector search index in the collection.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get or create the collection. Default is True.

        Returns:
            Collection | The new or existing collection.

        Raises:
            ValueError: If the collection exists and ``get_or_create`` is False.
        """
        if overwrite:
            self.delete_collection(collection_name)

        if collection_name not in self.db.list_collection_names():
            # Create a new collection.
            coll = self.db.create_collection(collection_name)
            self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
            return coll

        if not get_or_create:
            raise ValueError(f"Collection {collection_name} already exists.")

        # The collection already exists; return it, making sure its index exists too.
        coll = self.db[collection_name]
        self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
        return coll

    def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None:
        """
        Create a vector search index on ``collection`` unless a ``vectorSearch`` index named
        ``index_name`` already exists.

        If such an index exists but is not yet READY, nothing is created. When
        ``wait_until_index_ready`` is set, this method blocks until the index is READY.

        Args:
            index_name (str, optional): The name of the vector search index. Defaults to "vector_index".
            collection (Collection, optional): The MongoDB collection to create the index on.
        """
        if self._is_index_ready(collection, index_name):
            return
        pending = [
            index for index in collection.list_search_indexes(index_name) if index.get("type") == "vectorSearch"
        ]
        if pending:
            # The index is still building; re-creating an index with the same name would fail.
            if self._wait_until_index_ready:
                self._wait_for_index(collection, index_name, "create")
            return
        self.create_vector_search_index(collection, index_name)

    def get_collection(self, collection_name: str = None) -> Collection:
        """
        Get the collection from the vector database.

        Passing a name also makes that collection the active collection.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Collection | The collection object.

        Raises:
            ValueError: If no name is given and there is no active collection.
        """
        if collection_name is None:
            if self.active_collection is None:
                raise ValueError("No collection is specified.")
            logger.debug(
                f"No collection is specified. Using current active collection {self.active_collection.name}."
            )
        else:
            self.active_collection = self.db[collection_name]

        return self.active_collection

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete the collection, and all its search indexes, from the vector database.

        Args:
            collection_name: str | The name of the collection.
        """
        collection = self.db[collection_name]
        # Materialize the listing first so the cursor is not consumed while indexes are dropped.
        for index in list(collection.list_search_indexes()):
            collection.drop_search_index(index["name"])
            if self._wait_until_index_ready:
                self._wait_for_index(collection, index["name"], "delete")
        return collection.drop()

    def create_vector_search_index(
        self,
        collection: Collection,
        index_name: Union[str, None] = "vector_index",
        similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
    ) -> None:
        """Create a vector search index on the ``embedding`` field of the collection.

        Args:
            collection: An existing Collection in the Atlas Database.
            index_name: Vector Search Index name.
            similarity: Algorithm used for measuring vector similarity.

        Raises:
            Exception: Re-raises whatever the driver raised, after logging a hint about Atlas
                cluster requirements.
        """
        search_index_model = SearchIndexModel(
            definition={
                "fields": [
                    {
                        "type": "vector",
                        "numDimensions": self.dimensions,
                        "path": "embedding",
                        "similarity": similarity,
                    },
                ]
            },
            name=index_name,
            type="vectorSearch",
        )
        try:
            collection.create_search_index(model=search_index_model)
            if self._wait_until_index_ready:
                self._wait_for_index(collection, index_name, "create")
            logger.debug(f"Search index {index_name} created successfully.")
        except Exception as e:
            logger.error(
                f"Error creating search index: {e}. \n"
                f"Your client must be connected to an Atlas cluster. "
                f"You may have to manually create a Collection and Search Index "
                f"if you are on a free/shared cluster."
            )
            raise

    def insert_docs(
        self,
        docs: List[Document],
        collection_name: str = None,
        upsert: bool = False,
        batch_size=DEFAULT_INSERT_BATCH_SIZE,
        **kwargs,
    ) -> None:
        """Insert Documents and Vector Embeddings into the collection of the vector database.

        For large numbers of Documents, insertion is performed in batches.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None (active collection).
            upsert: bool | Whether to update the document if it exists. Default is False.
            batch_size: Number of documents to be inserted in each batch.
            kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If the first document lacks ``content`` or ``id``.
        """
        if not docs:
            logger.info("No documents to insert.")
            return

        collection = self.get_collection(collection_name)
        if upsert:
            # update_docs performs its own optional wait for the last document.
            self.update_docs(docs, collection.name, upsert=True)
            return

        # Sanity-check the first document.
        if docs[0].get("content") is None:
            raise ValueError("The document content is required.")
        if docs[0].get("id") is None:
            raise ValueError("The document id is required.")

        input_ids = set()
        result_ids = set()
        id_batch = []
        text_batch = []
        metadata_batch = []
        size = 0
        for i, doc in enumerate(docs):
            doc_id = doc["id"]
            text = doc["content"]
            metadata = doc.get("metadata", {})
            id_batch.append(doc_id)
            text_batch.append(text)
            metadata_batch.append(metadata)
            # Rough running size of the batch; flush early before it nears 47 MB.
            id_size = 1 if isinstance(doc_id, int) else len(doc_id)
            size += len(text) + len(metadata) + id_size
            if (i + 1) % batch_size == 0 or size >= 47_000_000:
                result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
                input_ids.update(id_batch)
                id_batch = []
                text_batch = []
                metadata_batch = []
                size = 0
        if text_batch:
            result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))  # type: ignore
            input_ids.update(id_batch)

        if result_ids != input_ids:
            logger.warning(
                "Possible data corruption. "
                "input_ids not in result_ids: {in_diff}.\n"
                "result_ids not in input_ids: {out_diff}".format(
                    in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
                )
            )
        if self._wait_until_document_ready:
            self._wait_for_document(collection, self.index_name, docs[-1])

    def _insert_batch(
        self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
    ) -> Set[ItemID]:
        """Compute embeddings for, and insert, a batch of Documents into the Collection.

        For performance reasons, self.embedding_function is called just once per batch,
        at the cost of recreating the Document dicts.

        Args:
            collection: MongoDB Collection
            texts: List of the main contents of each document
            metadatas: List of metadata mappings
            ids: List of ids. Note that these are stored as _id in Collection.

        Returns:
            Set of ids reported as inserted by MongoDB.

        Raises:
            ValueError: If the embedding function returns a different number of vectors than texts.
        """
        n_texts = len(texts)
        if n_texts == 0:
            return set()
        # np.array accepts both ndarray and list-of-lists outputs from the embedding function.
        embeddings = np.array(self.embedding_function(texts)).tolist()
        if len(embeddings) != n_texts:
            raise ValueError(
                f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) "
                f"does not match the number of texts provided to it ({n_texts})."
            )
        to_insert = [
            {"_id": i, "content": t, "metadata": m, "embedding": e}
            for i, t, m, e in zip(ids, texts, metadatas, embeddings)
        ]
        insert_result = collection.insert_many(to_insert)  # type: ignore
        return set(insert_result.inserted_ids)

    def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
        """Update documents, including their embeddings, in the Collection.

        Optionally allow upsert as kwarg. Uses deepcopy to avoid changing the caller's docs.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None (active collection).
            kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in collection.
        """
        if not docs:
            # bulk_write rejects an empty list of operations.
            logger.info("No documents to update.")
            return

        collection = self.get_collection(collection_name)
        n_docs = len(docs)
        logger.info(f"Preparing to embed and update {n_docs=}")
        embeddings = np.array(self.embedding_function([doc["content"] for doc in docs])).tolist()

        upsert = kwargs.get("upsert", False)
        all_updates = []
        for doc, embedding in zip(docs, embeddings):
            update = deepcopy(doc)
            update["embedding"] = embedding
            update["_id"] = update.pop("id")
            all_updates.append(UpdateOne({"_id": update["_id"]}, {"$set": update}, upsert=upsert))

        result = collection.bulk_write(all_updates)

        if self._wait_until_document_ready:
            self._wait_for_document(collection, self.index_name, docs[-1])

        logger.info(
            "Matched: %s, Modified: %s, Upserted: %s",
            result.matched_count,
            result.modified_count,
            result.upserted_count,
        )

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs):
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None (active collection).

        Returns:
            The pymongo result of ``delete_many``.
        """
        collection = self.get_collection(collection_name)
        return collection.delete_many({"_id": {"$in": ids}})

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include.
                If None, will include ["metadata", "content"]; ids will always be included.
                Use include to choose whether to include embedding and metadata.
            kwargs: dict | Additional keyword arguments (unused).

        Returns:
            List[Document] | The results, with ``_id`` renamed to ``id``.
        """
        if include is None:
            include_fields = {"_id": 1, "content": 1, "metadata": 1}
        else:
            include_fields = {k: 1 for k in set(include).union({"_id"})}
        collection = self.get_collection(collection_name)
        query = {} if ids is None else {"_id": {"$in": ids}}
        return with_id_rename(collection.find(query, include_fields))

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score; only distances smaller
                than it are returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments. Recognized keys:
                oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
                    It determines the number of nearest neighbor candidates to consider during the
                    search phase. A higher value leads to more accuracy, but is slower. Default is 10.
                include_embedding: bool | Whether to return the ``embedding`` field. Default is False.
                Other keys are ignored.

        Returns:
            QueryResults | For each query string, a list of (document, score) tuples. An empty
            list is returned if the collection holds no documents.
        """
        collection = self.get_collection(collection_name)
        # Trivial case of an empty collection.
        if collection.count_documents({}) == 0:
            return []

        # Pick out only the options _vector_search understands. Forwarding **kwargs wholesale
        # raised TypeError for duplicate or unknown keyword arguments.
        oversampling_factor = kwargs.get("oversampling_factor", 10)
        include_embedding = kwargs.get("include_embedding", False)

        logger.debug(f"Using index: {self.index_name}")
        results = []
        for query_text in queries:
            logger.debug(f"Query: {query_text}")
            query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
            query_result = _vector_search(
                query_vector,
                n_results,
                collection,
                self.index_name,
                distance_threshold,
                oversampling_factor=oversampling_factor,
                include_embedding=include_embedding,
            )
            # Rename each _id to id, as with_id_rename does, but on (doc, score) tuples.
            results.append(
                [({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result]
            )
        return results
509
-
510
-
511
def _vector_search(
    embedding_vector: List[float],
    n_results: int,
    collection: Collection,
    index_name: str,
    distance_threshold: float = -1.0,
    oversampling_factor=10,
    include_embedding=False,
) -> List[Tuple[Dict, float]]:
    """Run the core ``$vectorSearch`` aggregation pipeline.

    Args:
        embedding_vector: Embedding vector of the semantic query.
        n_results: Number of documents to return (the ``$vectorSearch`` ``limit``).
        collection: MongoDB Collection with a vector index on ``embedding``.
        index_name: Name of the vector index.
        distance_threshold: When >= 0, only results with ``score >= 1 - distance_threshold``
            are kept. Negative values (default -1) disable this filter.
        oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm, i.e.
            ``numCandidates``. It determines the number of nearest neighbor candidates considered
            during the search phase. Higher is more accurate, but slower. Default = 10.
            The resulting candidate count is clamped to ``[n_results, 10000]``.
        include_embedding: If False (default), the ``embedding`` field is projected out.

    Returns:
        List of at most n_results tuples from Collection. Each tuple contains a document dict
        (with ``_id`` and without ``score``) and its ``vectorSearchScore``.
    """
    # Atlas $vectorSearch rejects numCandidates greater than 10000 or smaller than `limit`.
    # Clamp the oversampled count into that range instead of letting the query fail.
    max_num_candidates = 10_000
    num_candidates = max(n_results, min(n_results * oversampling_factor, max_num_candidates))

    pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "limit": n_results,
                "numCandidates": num_candidates,
                "queryVector": embedding_vector,
                "path": "embedding",
            }
        },
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    ]
    if distance_threshold >= 0.0:
        # Higher score means closer, so a maximum distance becomes a minimum score.
        similarity_threshold = 1.0 - distance_threshold
        pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}})

    if not include_embedding:
        pipeline.append({"$project": {"embedding": 0}})

    logger.debug("pipeline: %s", pipeline)
    agg = collection.aggregate(pipeline)
    # Detach the score from each document so callers receive (document, score) pairs.
    return [(doc, doc.pop("score")) for doc in agg]