ag2 0.4b1__py3-none-any.whl → 0.4.2b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ag2 might be problematic. Click here for more details.

Files changed (118) hide show
  1. ag2-0.4.2b1.dist-info/METADATA +19 -0
  2. ag2-0.4.2b1.dist-info/RECORD +6 -0
  3. ag2-0.4.2b1.dist-info/top_level.txt +1 -0
  4. ag2-0.4b1.dist-info/METADATA +0 -496
  5. ag2-0.4b1.dist-info/RECORD +0 -115
  6. ag2-0.4b1.dist-info/top_level.txt +0 -1
  7. autogen/__init__.py +0 -17
  8. autogen/_pydantic.py +0 -116
  9. autogen/agentchat/__init__.py +0 -42
  10. autogen/agentchat/agent.py +0 -142
  11. autogen/agentchat/assistant_agent.py +0 -85
  12. autogen/agentchat/chat.py +0 -306
  13. autogen/agentchat/contrib/__init__.py +0 -0
  14. autogen/agentchat/contrib/agent_builder.py +0 -787
  15. autogen/agentchat/contrib/agent_optimizer.py +0 -450
  16. autogen/agentchat/contrib/capabilities/__init__.py +0 -0
  17. autogen/agentchat/contrib/capabilities/agent_capability.py +0 -21
  18. autogen/agentchat/contrib/capabilities/generate_images.py +0 -297
  19. autogen/agentchat/contrib/capabilities/teachability.py +0 -406
  20. autogen/agentchat/contrib/capabilities/text_compressors.py +0 -72
  21. autogen/agentchat/contrib/capabilities/transform_messages.py +0 -92
  22. autogen/agentchat/contrib/capabilities/transforms.py +0 -565
  23. autogen/agentchat/contrib/capabilities/transforms_util.py +0 -120
  24. autogen/agentchat/contrib/capabilities/vision_capability.py +0 -217
  25. autogen/agentchat/contrib/captainagent.py +0 -487
  26. autogen/agentchat/contrib/gpt_assistant_agent.py +0 -545
  27. autogen/agentchat/contrib/graph_rag/__init__.py +0 -0
  28. autogen/agentchat/contrib/graph_rag/document.py +0 -24
  29. autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +0 -76
  30. autogen/agentchat/contrib/graph_rag/graph_query_engine.py +0 -50
  31. autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +0 -56
  32. autogen/agentchat/contrib/img_utils.py +0 -390
  33. autogen/agentchat/contrib/llamaindex_conversable_agent.py +0 -123
  34. autogen/agentchat/contrib/llava_agent.py +0 -176
  35. autogen/agentchat/contrib/math_user_proxy_agent.py +0 -471
  36. autogen/agentchat/contrib/multimodal_conversable_agent.py +0 -128
  37. autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +0 -325
  38. autogen/agentchat/contrib/retrieve_assistant_agent.py +0 -56
  39. autogen/agentchat/contrib/retrieve_user_proxy_agent.py +0 -701
  40. autogen/agentchat/contrib/society_of_mind_agent.py +0 -203
  41. autogen/agentchat/contrib/swarm_agent.py +0 -414
  42. autogen/agentchat/contrib/text_analyzer_agent.py +0 -76
  43. autogen/agentchat/contrib/tool_retriever.py +0 -114
  44. autogen/agentchat/contrib/vectordb/__init__.py +0 -0
  45. autogen/agentchat/contrib/vectordb/base.py +0 -243
  46. autogen/agentchat/contrib/vectordb/chromadb.py +0 -326
  47. autogen/agentchat/contrib/vectordb/mongodb.py +0 -559
  48. autogen/agentchat/contrib/vectordb/pgvectordb.py +0 -958
  49. autogen/agentchat/contrib/vectordb/qdrant.py +0 -334
  50. autogen/agentchat/contrib/vectordb/utils.py +0 -126
  51. autogen/agentchat/contrib/web_surfer.py +0 -305
  52. autogen/agentchat/conversable_agent.py +0 -2908
  53. autogen/agentchat/groupchat.py +0 -1668
  54. autogen/agentchat/user_proxy_agent.py +0 -109
  55. autogen/agentchat/utils.py +0 -207
  56. autogen/browser_utils.py +0 -291
  57. autogen/cache/__init__.py +0 -10
  58. autogen/cache/abstract_cache_base.py +0 -78
  59. autogen/cache/cache.py +0 -182
  60. autogen/cache/cache_factory.py +0 -85
  61. autogen/cache/cosmos_db_cache.py +0 -150
  62. autogen/cache/disk_cache.py +0 -109
  63. autogen/cache/in_memory_cache.py +0 -61
  64. autogen/cache/redis_cache.py +0 -128
  65. autogen/code_utils.py +0 -745
  66. autogen/coding/__init__.py +0 -22
  67. autogen/coding/base.py +0 -113
  68. autogen/coding/docker_commandline_code_executor.py +0 -262
  69. autogen/coding/factory.py +0 -45
  70. autogen/coding/func_with_reqs.py +0 -203
  71. autogen/coding/jupyter/__init__.py +0 -22
  72. autogen/coding/jupyter/base.py +0 -32
  73. autogen/coding/jupyter/docker_jupyter_server.py +0 -164
  74. autogen/coding/jupyter/embedded_ipython_code_executor.py +0 -182
  75. autogen/coding/jupyter/jupyter_client.py +0 -224
  76. autogen/coding/jupyter/jupyter_code_executor.py +0 -161
  77. autogen/coding/jupyter/local_jupyter_server.py +0 -168
  78. autogen/coding/local_commandline_code_executor.py +0 -410
  79. autogen/coding/markdown_code_extractor.py +0 -44
  80. autogen/coding/utils.py +0 -57
  81. autogen/exception_utils.py +0 -46
  82. autogen/extensions/__init__.py +0 -0
  83. autogen/formatting_utils.py +0 -76
  84. autogen/function_utils.py +0 -362
  85. autogen/graph_utils.py +0 -148
  86. autogen/io/__init__.py +0 -15
  87. autogen/io/base.py +0 -105
  88. autogen/io/console.py +0 -43
  89. autogen/io/websockets.py +0 -213
  90. autogen/logger/__init__.py +0 -11
  91. autogen/logger/base_logger.py +0 -140
  92. autogen/logger/file_logger.py +0 -287
  93. autogen/logger/logger_factory.py +0 -29
  94. autogen/logger/logger_utils.py +0 -42
  95. autogen/logger/sqlite_logger.py +0 -459
  96. autogen/math_utils.py +0 -356
  97. autogen/oai/__init__.py +0 -33
  98. autogen/oai/anthropic.py +0 -428
  99. autogen/oai/bedrock.py +0 -600
  100. autogen/oai/cerebras.py +0 -264
  101. autogen/oai/client.py +0 -1148
  102. autogen/oai/client_utils.py +0 -167
  103. autogen/oai/cohere.py +0 -453
  104. autogen/oai/completion.py +0 -1216
  105. autogen/oai/gemini.py +0 -469
  106. autogen/oai/groq.py +0 -281
  107. autogen/oai/mistral.py +0 -279
  108. autogen/oai/ollama.py +0 -576
  109. autogen/oai/openai_utils.py +0 -810
  110. autogen/oai/together.py +0 -343
  111. autogen/retrieve_utils.py +0 -487
  112. autogen/runtime_logging.py +0 -163
  113. autogen/token_count_utils.py +0 -257
  114. autogen/types.py +0 -20
  115. autogen/version.py +0 -7
  116. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/LICENSE +0 -0
  117. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/NOTICE.md +0 -0
  118. {ag2-0.4b1.dist-info → ag2-0.4.2b1.dist-info}/WHEEL +0 -0
@@ -1,958 +0,0 @@
1
- # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
2
- #
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
6
- # SPDX-License-Identifier: MIT
7
- import os
8
- import re
9
- import urllib.parse
10
- from typing import Callable, List, Optional, Union
11
-
12
- import numpy as np
13
- from sentence_transformers import SentenceTransformer
14
-
15
- from .base import Document, ItemID, QueryResults, VectorDB
16
- from .utils import get_logger
17
-
18
- try:
19
- import pgvector
20
- from pgvector.psycopg import register_vector
21
- except ImportError:
22
- raise ImportError("Please install pgvector: `pip install pgvector`")
23
-
24
try:
    import psycopg
except ImportError:
    # psycopg is the package missing on this path, so name it in the message
    # (the message previously told the user to install pgvector).
    raise ImportError("Please install psycopg: `pip install psycopg`")
28
-
29
- PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
30
- logger = get_logger(__name__)
31
-
32
-
33
- class Collection:
34
- """
35
- A Collection object for PGVector.
36
-
37
- Attributes:
38
- client: The PGVector client.
39
- collection_name (str): The name of the collection. Default is "autogen-docs".
40
- embedding_function (Callable): The embedding function used to generate the vector representation.
41
- Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
42
- Models can be chosen from:
43
- https://huggingface.co/models?library=sentence-transformers
44
- metadata (Optional[dict]): The metadata of the collection.
45
- get_or_create (Optional): The flag indicating whether to get or create the collection.
46
- """
47
-
48
def __init__(
    self,
    client=None,
    collection_name: str = "autogen-docs",
    embedding_function: Callable = None,
    metadata=None,
    get_or_create=None,
):
    """
    Initialize the Collection object.

    Args:
        client: The PostgreSQL client; the other methods call ``client.cursor()`` on it.
        collection_name: The name of the collection. Default is "autogen-docs". Dashes are
            replaced with underscores (see ``set_collection_name``) before the name is stored.
        embedding_function: The embedding function used to generate the vector representation.
            When falsy, SentenceTransformer("all-MiniLM-L6-v2").encode is used.
        metadata: HNSW settings for the collection. Keys that are not supplied fall back to
            {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}.
        get_or_create: The flag indicating whether to get or create the collection. It is only
            stored on the instance here.
    Returns:
        None
    """
    default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}

    self.client = client
    self.name = self.set_collection_name(collection_name)
    self.require_embeddings_or_documents = False
    self.ids = []
    if embedding_function:
        self.embedding_function = embedding_function
    else:
        self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
    # Merge instead of replace: create_collection indexes "hnsw:M" and
    # "hnsw:construction_ef" directly, so a partial dict must not drop them.
    self.metadata = {**default_metadata, **metadata} if metadata else default_metadata
    self.documents = ""
    self.get_or_create = get_or_create
    # Embed one probe sentence to discover the vector size the model produces.
    sentences = [
        "The weather is lovely today in paradise.",
    ]
    embeddings = self.embedding_function(sentences)
    self.dimension = len(embeddings[0])
85
-
86
def set_collection_name(self, collection_name) -> str:
    """
    Store ``collection_name`` on the instance with every dash replaced by an underscore.

    Args:
        collection_name: The requested collection name.

    Returns:
        str: The stored name.
    """
    self.name = re.sub("-", "_", collection_name)
    return self.name
90
-
91
def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
    """
    Add documents to the collection.

    Rows are built by zipping the supplied lists, so the shortest list decides how many rows
    are inserted. When ``embeddings`` is None, each document is embedded with
    ``self.embedding_function``.

    Args:
        ids (List[ItemID]): A list of document IDs.
        documents (List): A list of documents.
        embeddings (List): A list of document embeddings. Optional
        metadatas (List): A list of document metadatas. Optional

    Returns:
        None
    """
    import json  # local import: only needed for metadata serialisation

    def _as_json(metadata):
        # Strings keep the legacy quote swap. Everything else goes through json.dumps so
        # None/True/False and values containing apostrophes become valid JSON.
        if isinstance(metadata, str):
            return re.sub("'", '"', metadata)
        return json.dumps(metadata, default=str)

    sql_values = []
    if metadatas is not None:
        columns = "id, embedding, metadatas, documents"
        placeholders = "%s, %s, %s, %s"
        if embeddings is not None:
            for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
                sql_values.append((doc_id, embedding, _as_json(metadata), document))
        else:
            for doc_id, metadata, document in zip(ids, metadatas, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, embedding, _as_json(metadata), document))
    else:
        columns = "id, embedding, documents"
        placeholders = "%s, %s, %s"
        if embeddings is not None:
            for doc_id, embedding, document in zip(ids, embeddings, documents):
                sql_values.append((doc_id, embedding, document))
        else:
            for doc_id, document in zip(ids, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, embedding, document))

    sql_string = f"INSERT INTO {self.name} ({columns})\nVALUES ({placeholders});\n"
    logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
    cursor = self.client.cursor()
    try:
        cursor.executemany(sql_string, sql_values)
    finally:
        # Close the cursor even when the insert raises (e.g. duplicate id).
        cursor.close()
133
-
134
def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
    """
    Upsert documents into the collection.

    A row whose id already exists has every supplied column overwritten, including the
    embedding. When ``embeddings`` is None, each document is embedded with
    ``self.embedding_function``.

    Args:
        ids (List[ItemID]): A list of document IDs.
        documents (List): A list of documents.
        embeddings (List): A list of document embeddings.
        metadatas (List): A list of document metadatas.

    Returns:
        None
    """
    import json  # local import: only needed for metadata serialisation

    def _as_json(metadata):
        # Strings keep the legacy quote swap. Everything else goes through json.dumps so
        # None/True/False and values containing apostrophes become valid JSON.
        if isinstance(metadata, str):
            return re.sub("'", '"', metadata)
        return json.dumps(metadata, default=str)

    sql_values = []
    if metadatas is not None:
        columns = ["id", "embedding", "metadatas", "documents"]
        if embeddings is not None:
            for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
                sql_values.append((doc_id, embedding, _as_json(metadata), document))
        else:
            for doc_id, metadata, document in zip(ids, metadatas, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, embedding, _as_json(metadata), document))
    else:
        columns = ["id", "embedding", "documents"]
        if embeddings is not None:
            for doc_id, embedding, document in zip(ids, embeddings, documents):
                sql_values.append((doc_id, embedding, document))
        else:
            for doc_id, document in zip(ids, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, embedding, document))

    placeholders = ", ".join(["%s"] * len(columns))
    # EXCLUDED refers to the row proposed for insertion, so every non-key column is
    # refreshed without binding each value twice. Previously the plain-documents branch
    # updated only `documents` and left the old embedding in place.
    assignments = ", ".join(f"{column} = EXCLUDED.{column}" for column in columns[1:])
    sql_string = (
        f"INSERT INTO {self.name} ({', '.join(columns)})\n"
        f"VALUES ({placeholders})\n"
        f"ON CONFLICT (id)\n"
        f"DO UPDATE SET {assignments};\n"
    )
    logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
    cursor = self.client.cursor()
    try:
        cursor.executemany(sql_string, sql_values)
    finally:
        cursor.close()
192
-
193
def count(self) -> int:
    """
    Get the total number of documents in the collection.

    Returns:
        int: The row count of the collection table, or None when the fetched value
        cannot be converted to an int.
    """
    cursor = self.client.cursor()
    cursor.execute(f"SELECT COUNT(*) FROM {self.name}")
    raw_total = cursor.fetchone()[0]
    cursor.close()
    try:
        return int(raw_total)
    except (TypeError, ValueError):
        return None
210
-
211
def table_exists(self, table_name: str) -> bool:
    """
    Check if a table exists in the PostgreSQL database.

    The lookup goes through ``information_schema.tables`` and matches on the table name
    only, so a table of that name in any visible schema counts.

    Args:
        table_name (str): The name of the table to check.

    Returns:
        bool: True if the table exists, False otherwise.
    """
    cursor = self.client.cursor()
    try:
        cursor.execute(
            """
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_name = %s
            )
            """,
            (table_name,),
        )
        exists = cursor.fetchone()[0]
    finally:
        # The cursor was previously never closed on any path.
        cursor.close()
    return exists
235
-
236
def get(
    self,
    ids: Optional[str] = None,
    include: Optional[str] = None,
    where: Optional[str] = None,
    limit: Optional[Union[int, str]] = None,
    offset: Optional[Union[int, str]] = None,
) -> List[Document]:
    """
    Retrieve documents from the collection.

    If the table (or a requested column) does not exist, the collection table is created
    and an empty list is returned.

    Args:
        ids (Optional): A document ID or a list of document IDs. When given, ``where`` is ignored.
        include (Optional): A column name or list of column names to fetch in addition to
            ``id`` and ``embedding``. Defaults to ``metadatas`` and ``documents``.
        where (Optional): Additional filtering criteria, inserted verbatim after WHERE.
        limit (Optional): The maximum number of documents to retrieve.
        offset (Optional): The offset for pagination.

    Returns:
        List: The retrieved documents. Fields whose column was not selected are None.
    """
    # Accept a single id as well as a list; iterating a bare string would yield characters.
    if ids is None:
        id_list = []
    elif isinstance(ids, str):
        id_list = [ids]
    else:
        id_list = list(ids)

    columns = ["id", "metadatas", "documents", "embedding"]
    if include:
        requested = [include] if isinstance(include, str) else list(include)
        columns = ["id", *[name for name in requested if name not in ("id", "embedding")], "embedding"]
    # Rows are mapped by column name so a custom `include` cannot shift fields.
    position = {name: index for index, name in enumerate(columns)}

    parts = [f"SELECT {', '.join(columns)}", f"FROM {self.name}"]
    params = []
    if id_list:
        parts.append(f"WHERE id IN ({', '.join(['%s'] * len(id_list))})")
        params.extend(id_list)
    elif where:
        # NOTE(review): `where` is raw SQL from the caller; it must never carry untrusted input.
        parts.append(f"WHERE {where}")
    # LIMIT/OFFSET values are always bound; previously they were left unbound when ids were given.
    if limit:
        parts.append("LIMIT %s")
        params.append(int(limit))
    if offset:
        parts.append("OFFSET %s")
        params.append(int(offset))
    query = " ".join(parts)

    def _column(row, name):
        return row[position[name]] if name in position else None

    retrieved_documents = []
    cursor = self.client.cursor()
    try:
        cursor.execute(query, params)
        for row in cursor.fetchall():
            retrieved_documents.append(
                Document(
                    id=row[position["id"]].strip(),
                    metadata=_column(row, "metadatas"),
                    content=_column(row, "documents"),
                    embedding=row[position["embedding"]],
                )
            )
    except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
        logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
        self.create_collection(collection_name=self.name, dimension=self.dimension)
        logger.info(f"Created table {self.name}")
    finally:
        cursor.close()
    return retrieved_documents
314
-
315
def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
    """
    Update documents in the collection.

    Implemented as an insert with ON CONFLICT (id) DO UPDATE, so ids that are not present
    yet are inserted.

    Args:
        ids (List): A list of document IDs.
        embeddings (List): A list of document embeddings.
        metadatas (List): A list of document metadatas.
        documents (List): A list of documents.

    Returns:
        None
    """
    import json  # local import: only needed for metadata serialisation

    sql_values = []
    for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
        # Serialise non-string metadata so it reaches the JSONB column as JSON text.
        if not isinstance(metadata, str):
            metadata = json.dumps(metadata, default=str)
        sql_values.append((doc_id, embedding, metadata, document))
    # Column names must match the table created by create_collection (`metadatas`,
    # `documents`); the previous statement used the singular forms.
    sql_string = (
        f"INSERT INTO {self.name} (id, embedding, metadatas, documents) "
        f"VALUES (%s, %s, %s, %s) "
        f"ON CONFLICT (id) "
        f"DO UPDATE SET embedding = EXCLUDED.embedding, "
        f"metadatas = EXCLUDED.metadatas, documents = EXCLUDED.documents;\n"
    )
    logger.debug(f"Update SQL String:\n{sql_string}\n")
    cursor = self.client.cursor()
    try:
        cursor.executemany(sql_string, sql_values)
    finally:
        cursor.close()
342
-
343
- @staticmethod
344
def euclidean_distance(arr1: List[float], arr2: List[float]) -> float:
    """
    Calculate the Euclidean distance between two vectors.

    Parameters:
    - arr1 (List[float]): The first vector.
    - arr2 (List[float]): The second vector.

    Returns:
    - float: The Euclidean distance between arr1 and arr2.
    """
    # Coerce to arrays first: subtracting two plain lists raises TypeError.
    first = np.asarray(arr1, dtype=float)
    second = np.asarray(arr2, dtype=float)
    return float(np.linalg.norm(first - second))
357
-
358
- @staticmethod
359
def cosine_distance(arr1: List[float], arr2: List[float]) -> float:
    """
    Calculate the cosine distance between two vectors.

    Parameters:
    - arr1 (List[float]): The first vector.
    - arr2 (List[float]): The second vector.

    Returns:
    - float: The cosine distance between arr1 and arr2, i.e. 1 minus their cosine
      similarity (0 for the same direction, 1 for orthogonal, 2 for opposite).
    """
    first = np.asarray(arr1, dtype=float)
    second = np.asarray(arr2, dtype=float)
    similarity = np.dot(first, second) / (np.linalg.norm(first) * np.linalg.norm(second))
    # The previous implementation returned the similarity itself, which grows as
    # vectors get closer -- the opposite of a distance.
    return float(1.0 - similarity)
372
-
373
- @staticmethod
374
def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
    """
    Calculate the inner-product distance between two vectors.

    Parameters:
    - arr1 (List[float]): The first vector.
    - arr2 (List[float]): The second vector.

    Returns:
    - float: The negated inner product of arr1 and arr2, so that a larger inner
      product gives a smaller distance.
    """
    # The previous body was a copy of euclidean_distance and never computed an inner product.
    first = np.asarray(arr1, dtype=float)
    second = np.asarray(arr2, dtype=float)
    return float(-np.dot(first, second))
387
-
388
def query(
    self,
    query_texts: List[str],
    collection_name: Optional[str] = None,
    n_results: Optional[int] = 10,
    distance_type: Optional[str] = "euclidean",
    distance_threshold: Optional[float] = -1,
    include_embedding: Optional[bool] = False,
) -> QueryResults:
    """
    Query documents in the collection.

    Args:
        query_texts (List[str]): A list of query texts.
        collection_name (Optional[str]): The name of the collection. When given, it replaces
            the name stored on this object.
        n_results (int): The maximum number of results to return per query text.
        distance_type (Optional[str]): Distance search type - "euclidean", "cosine" or
            "inner-product". Any other value falls back to euclidean.
        distance_threshold (Optional[float]): Only rows closer than this are returned. Values
            that are not strictly positive (including the default -1) disable the filter.
        include_embedding (Optional[bool]): Include embedding values in QueryResults
    Returns:
        QueryResults: One list per query text, each holding (Document, distance) tuples.
    """
    if collection_name:
        self.name = collection_name

    # SQL operator and the matching local distance function for each distance type.
    metrics = {
        "cosine": ("<=>", self.cosine_distance),
        "euclidean": ("<->", self.euclidean_distance),
        "inner-product": ("<#>", self.inner_product_distance),
    }
    operator, distance_fn = metrics.get(distance_type.lower(), metrics["euclidean"])
    # Previously 0 or a negative value other than -1 was spliced into the SQL as a bare number.
    use_threshold = distance_threshold is not None and distance_threshold > 0

    cursor = self.client.cursor()
    results = []
    try:
        for query_text in query_texts:
            # NOTE(review): assumes the embedding function accepts `convert_to_tensor` and
            # returns an object with `.tolist()`, as SentenceTransformer.encode does.
            vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
            vector_literal = str(vector)

            sql = f"SELECT id, documents, embedding, metadatas FROM {self.name} "
            params = []
            if use_threshold:
                sql += f"WHERE embedding {operator} %s::vector < %s "
                params.extend([vector_literal, distance_threshold])
            # Always order by distance so LIMIT keeps the nearest rows, with or without a threshold.
            sql += f"ORDER BY embedding {operator} %s::vector LIMIT %s"
            params.extend([vector_literal, n_results])
            cursor.execute(sql, params)

            result = []
            for row in cursor.fetchall():
                stored_vector = np.asarray(self.convert_string_to_array(array_string=row[2]), dtype=float)
                distance = distance_fn(stored_vector, vector)
                if include_embedding:
                    fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
                else:
                    fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
                result.append((fetched_document, distance))
            results.append(result)
    finally:
        cursor.close()
    logger.debug(f"Query Results: {results}")
    return results
459
-
460
- @staticmethod
461
def convert_string_to_array(array_string: str) -> List[float]:
    """
    Convert a string representation of an array to a list of floats.

    Both comma-separated ("[1,2,3]") and whitespace-separated ("[1 2 3]") forms are accepted.

    Parameters:
    - array_string (str): The string representation of the array.

    Returns:
    - list: A list of floats parsed from the input string. If the input is
      not a string, it returns the input itself.
    """
    if not isinstance(array_string, str):
        return array_string
    # Treat commas as separators too: splitting on whitespace alone leaves "1,2,3"
    # as one token, which float() rejects.
    body = array_string.strip().strip("[]").replace(",", " ")
    return [float(num) for num in body.split()]
477
-
478
def modify(self, metadata, collection_name: Optional[str] = None) -> None:
    """
    Modify metadata for the collection.

    NOTE(review): this updates a table named `collections`, which nothing in this class
    creates -- confirm it exists in the target database.

    Args:
        metadata: The new metadata. Non-string values are serialised to JSON text.
        collection_name: The name of the collection. When given, it replaces the name
            stored on this object.

    Returns:
        None
    """
    import json  # local import: only needed for metadata serialisation

    if collection_name:
        self.name = collection_name
    if not isinstance(metadata, str):
        metadata = json.dumps(metadata, default=str)
    cursor = self.client.cursor()
    try:
        # The previous literals concatenated without spaces ("collectionsSET ...") and
        # wrapped the placeholders in quotes, so the statement could not run.
        cursor.execute(
            "UPDATE collections SET metadata = %s WHERE collection_name = %s;",
            (metadata, self.name),
        )
    finally:
        cursor.close()
496
-
497
def delete(self, ids: List[ItemID], collection_name: Optional[str] = None) -> None:
    """
    Delete documents from the collection.

    Args:
        ids (List[ItemID]): A list of document IDs to delete. An empty list deletes nothing.
        collection_name (str): The name of the collection to delete from. When given, it
            replaces the name stored on this object.

    Returns:
        None
    """
    if collection_name:
        self.name = collection_name
    ids = list(ids)
    if not ids:
        # "WHERE id IN ()" is a syntax error, and there is nothing to delete anyway.
        return
    id_placeholders = ", ".join(["%s"] * len(ids))
    cursor = self.client.cursor()
    try:
        cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
    finally:
        cursor.close()
514
-
515
def delete_collection(self, collection_name: Optional[str] = None) -> None:
    """
    Drop the table backing the collection, if it exists.

    Args:
        collection_name (Optional[str]): The name of the collection to delete. When given,
            it replaces the name stored on this object before the drop.

    Returns:
        None
    """
    if collection_name:
        self.name = collection_name
    drop_statement = f"DROP TABLE IF EXISTS {self.name}"
    cursor = self.client.cursor()
    cursor.execute(drop_statement)
    cursor.close()
530
-
531
def create_collection(
    self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
) -> None:
    """
    Create the collection table together with three HNSW indexes.

    The table has the columns documents, id, metadatas and embedding. One index is created
    for each of the l2, cosine and inner-product operator classes, using the "hnsw:M" and
    "hnsw:construction_ef" values from ``self.metadata``.

    Args:
        collection_name (Optional[str]): The name of the new collection. When given, it
            replaces the name stored on this object.
        dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding
            model. When falsy, the stored dimension is kept, or 384 if that is None.

    Returns:
        None
    """
    if collection_name:
        self.name = collection_name

    if dimension:
        self.dimension = dimension
    elif self.dimension is None:
        self.dimension = 384

    cursor = self.client.cursor()
    hnsw_m = self.metadata["hnsw:M"]
    hnsw_ef = self.metadata["hnsw:construction_ef"]
    statements = [
        f"CREATE TABLE {self.name} ("
        f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
    ]
    for operator_class in ("vector_l2_ops", "vector_cosine_ops", "vector_ip_ops"):
        statements.append(
            f"CREATE INDEX "
            f"ON {self.name} USING hnsw (embedding {operator_class}) WITH (m = {hnsw_m}, "
            f"ef_construction = {hnsw_ef});"
        )
    # All statements are sent in a single execute call, as one concatenated string.
    cursor.execute("".join(statements))
    cursor.close()
567
-
568
-
569
- class PGVectorDB(VectorDB):
570
- """
571
- A vector database that uses PGVector as the backend.
572
- """
573
-
574
def __init__(
    self,
    *,
    conn: Optional[psycopg.Connection] = None,
    connection_string: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[Union[int, str]] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    connect_timeout: Optional[int] = 10,
    embedding_function: Callable = None,
    metadata: Optional[dict] = None,
) -> None:
    """
    Initialize the vector database.

    Note: one of conn, connection_string or host must be specified.

    Args:
        conn: psycopg.Connection | An existing connection object to use for the database.
            A connection object may include additional key/values:
            https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
        connection_string: "postgresql://username:password@hostname:port/database" | The PGVector
            connection string. Default is None.
        host: str | The host to connect to. Default is None.
        port: int | The port to connect to. Default is None.
        dbname: str | The database name to connect to. Default is None.
        username: str | The database username to use. Default is None.
        password: str | The database user password to use. Default is None.
        connect_timeout: int | The timeout to set for the connection. Default is 10.
        embedding_function: Callable | The embedding function used to generate the vector representation.
            Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
            Models can be chosen from:
            https://huggingface.co/models?library=sentence-transformers
        metadata: dict | The metadata of the vector database. Default is None. It is stored
            on the instance unchanged.
            For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
    Returns:
        None
    """
    connection_options = {
        "conn": conn,
        "connection_string": connection_string,
        "host": host,
        "port": port,
        "dbname": dbname,
        "username": username,
        "password": password,
        "connect_timeout": connect_timeout,
    }
    self.client = self.establish_connection(**connection_options)
    self.embedding_function = (
        embedding_function if embedding_function else SentenceTransformer("all-MiniLM-L6-v2").encode
    )
    self.metadata = metadata
    # Register pgvector's type adapters on the connection.
    register_vector(self.client)
    self.active_collection = None
632
-
633
def establish_connection(
    self,
    conn: Optional[psycopg.Connection] = None,
    connection_string: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[Union[int, str]] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    connect_timeout: Optional[int] = 10,
) -> psycopg.Connection:
    """
    Establishes a connection to a PostgreSQL database using psycopg.

    The first of conn, connection_string and host that is set decides how the connection is
    obtained. Connections opened here use autocommit. In every case
    `CREATE EXTENSION IF NOT EXISTS vector` is executed on the connection before returning.

    Args:
        conn: An existing psycopg connection object. If provided, this connection will be used.
        connection_string: A connection URI. User, password, host and database are
            percent-encoded before connecting; any query string is kept.
        host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
        port: The port number to connect to at the server host. Used if connection_string is not provided.
        dbname: The database name. Used if connection_string is not provided.
        username: The username to connect as. Used if connection_string is not provided.
        password: The user's password. Used if connection_string is not provided.
        connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.
            A connect_timeout already present in the URI query string takes precedence.

    Returns:
        A psycopg.Connection object representing the established connection.

    Raises:
        PermissionError if no credentials are supplied
        psycopg.Error: If an error occurs while trying to connect to the database.
    """
    try:
        if conn:
            self.client = conn
        elif connection_string:
            parsed = urllib.parse.urlparse(connection_string)
            # Build "user[:password]@" only from the parts that are present; quoting None raises.
            credentials = ""
            if parsed.username is not None:
                credentials = urllib.parse.quote(parsed.username, safe="")
                if parsed.password is not None:
                    credentials += ":" + urllib.parse.quote(parsed.password, safe="")
                credentials += "@"
            encoded_host = urllib.parse.quote(parsed.hostname or "", safe="")
            # Omit the port when the URI has none, instead of emitting ":None".
            encoded_port = f":{parsed.port}" if parsed.port is not None else ""
            encoded_database = urllib.parse.quote(parsed.path[1:], safe="")
            connection_string_encoded = (
                f"{parsed.scheme}://{credentials}{encoded_host}{encoded_port}/{encoded_database}"
            )
            if parsed.query:
                # Keep options such as sslmode that were previously dropped.
                connection_string_encoded += f"?{parsed.query}"
            extra_options = {}
            if connect_timeout is not None and "connect_timeout" not in urllib.parse.parse_qs(parsed.query):
                extra_options["connect_timeout"] = connect_timeout
            self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True, **extra_options)
        elif host:
            # Pass values as keyword parameters so psycopg escapes them itself. URL-quoting
            # is wrong here: keyword/value connection strings are not percent-decoded.
            connection_params = {"host": host}
            if port:
                connection_params["port"] = port
            if dbname:
                connection_params["dbname"] = dbname
            if username:
                connection_params["user"] = username
            if password:
                connection_params["password"] = password
            if connect_timeout is not None:
                connection_params["connect_timeout"] = connect_timeout
            self.client = psycopg.connect(autocommit=True, **connection_params)
        else:
            logger.error("Credentials were not supplied...")
            raise PermissionError("Credentials were not supplied: pass conn, connection_string or host.")
        self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
    except psycopg.Error as e:
        # The extra positional argument needs a %s placeholder, otherwise logging fails to
        # format the record.
        logger.error("Error connecting to the database: %s", e)
        raise
    return self.client
710
-
711
def create_collection(
    self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
) -> Collection:
    """
    Create a collection in the vector database, or hand back one that is already known.

    The outcome depends on what the lookup for `collection_name` yields:
        1. Nothing is found: a new collection is created.
        2. Something is found and `overwrite` is True: it is deleted and created again.
        3. Something is found, `overwrite` is False and `get_or_create` is True: it is
           returned as is.
        4. Otherwise: a new collection is created if its table does not exist yet, and a
           ValueError is raised if it does.

    Args:
        collection_name: str | The name of the collection.
        overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
        get_or_create: bool | Whether to get the collection if it exists. Default is True.

    Returns:
        Collection | The collection object.

    Raises:
        ValueError: If the collection already exists and neither `overwrite` nor
            `get_or_create` is set.
    """

    def _build_collection():
        # Single place for the construct -> name -> create sequence used by every
        # branch that ends up producing a new collection.
        fresh = Collection(
            client=self.client,
            collection_name=collection_name,
            embedding_function=self.embedding_function,
            get_or_create=get_or_create,
            metadata=self.metadata,
        )
        fresh.set_collection_name(collection_name=collection_name)
        fresh.create_collection(collection_name=collection_name)
        return fresh

    try:
        if self.active_collection and self.active_collection.name == collection_name:
            existing = self.active_collection
        else:
            existing = self.get_collection(collection_name)
    except ValueError:
        # The lookup signals "nothing to reuse" with ValueError.
        existing = None

    if existing is None:
        return _build_collection()
    if overwrite:
        self.delete_collection(collection_name)
        return _build_collection()
    if get_or_create:
        return existing
    if existing.table_exists(table_name=collection_name):
        raise ValueError(f"Collection {collection_name} already exists.")
    return _build_collection()
774
-
775
def get_collection(self, collection_name: str = None) -> Collection:
    """
    Get the collection from the vector database.

    When a name is given and it differs from the active collection's name (or there is
    no active collection), a new `Collection` handle is built for that name and becomes
    the active collection.

    Args:
        collection_name: str | The name of the collection. Default is None. If None, return the
            current active collection.

    Returns:
        Collection | The collection object.

    Raises:
        ValueError: If no name is given and there is no active collection.
    """
    if collection_name is not None:
        current = self.active_collection
        already_active = bool(current) and current.name == collection_name
        if not already_active:
            self.active_collection = Collection(
                client=self.client,
                collection_name=collection_name,
                embedding_function=self.embedding_function,
            )
        return self.active_collection

    if self.active_collection is None:
        raise ValueError("No collection is specified.")
    logger.debug(
        f"No collection is specified. Using current active collection {self.active_collection.name}."
    )
    return self.active_collection
801
-
802
def delete_collection(self, collection_name: str) -> None:
    """
    Delete the collection from the vector database.

    The deletion is issued through the active collection when there is one; otherwise a
    handle is obtained via `get_collection` first. If the active collection ends up being
    the one that was deleted, it is cleared.

    Args:
        collection_name: str | The name of the collection.

    Returns:
        None
    """
    handle = self.active_collection or self.get_collection(collection_name)
    handle.delete_collection(collection_name)

    # Re-read the attribute: get_collection() may have replaced the active collection.
    current = self.active_collection
    if current and current.name == collection_name:
        self.active_collection = None
819
-
820
def _batch_insert(
    self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
) -> None:
    """
    Write documents to `collection` in batches of at most PGVECTOR_MAX_BATCH_SIZE items.

    Args:
        collection: Collection | The collection that receives the documents.
        embeddings: Optional per-document embeddings. When falsy, None is passed for each batch.
        ids: Per-document ids, aligned with `documents`.
        metadatas: Optional per-document metadata. When falsy, every document in a batch
            gets its own copy of the default HNSW metadata.
        documents: The document contents. When empty or None, nothing is written.
        upsert: bool | Use `collection.upsert` instead of `collection.add`. Default is False.

    Returns:
        None
    """
    if not documents:
        # Nothing to write. Returning early also keeps range() from being called
        # with a step of 0, which raises ValueError.
        return

    batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
    default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
    write_batch = collection.upsert if upsert else collection.add

    for start in range(0, len(documents), batch_size):
        # Slicing clamps at the end of the sequence, so the last batch may be shorter.
        stop = start + batch_size
        batch_documents = documents[start:stop]
        if metadatas:
            batch_metadatas = metadatas[start:stop]
        else:
            # One independent dict per document in *this* batch, so the metadata list
            # always matches the batch length and entries are not aliased.
            batch_metadatas = [dict(default_metadata) for _ in batch_documents]
        write_batch(
            documents=batch_documents,
            ids=ids[start:stop],
            metadatas=batch_metadatas,
            embeddings=embeddings[start:stop] if embeddings else None,
        )
838
-
839
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
    """
    Insert documents into the collection of the vector database.

    Only the first document is inspected to decide whether content and id are present and
    whether embeddings and metadata are supplied; the remaining documents are read the
    same way without further checks.

    Args:
        docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
        collection_name: str | The name of the collection. Default is None.
        upsert: bool | Whether to update the document if it exists. Default is False.

    Returns:
        None

    Raises:
        ValueError: If the first document has no content or no id.
    """
    if not docs:
        return

    first = docs[0]
    if first.get("content") is None:
        raise ValueError("The document content is required.")
    if first.get("id") is None:
        raise ValueError("The document id is required.")

    documents = [entry.get("content") for entry in docs]
    ids = [entry.get("id") for entry in docs]

    collection = self.get_collection(collection_name)

    if first.get("embedding") is None:
        logger.debug(
            "No content embedding is provided. "
            "Will use the VectorDB's embedding function to generate the content embedding."
        )
        embeddings = None
    else:
        embeddings = [entry.get("embedding") for entry in docs]

    metadatas = None if first.get("metadata") is None else [entry.get("metadata") for entry in docs]

    self._batch_insert(
        collection,
        embeddings=embeddings,
        ids=ids,
        metadatas=metadatas,
        documents=documents,
        upsert=upsert,
    )
876
-
877
def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
    """
    Update documents in the collection of the vector database.

    This is `insert_docs` with upsert enabled.

    Args:
        docs: List[Document] | A list of documents.
        collection_name: str | The name of the collection. Default is None.

    Returns:
        None
    """
    self.insert_docs(docs=docs, collection_name=collection_name, upsert=True)
889
-
890
def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None:
    """
    Delete documents from the collection of the vector database.

    Args:
        ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
        collection_name: str | The name of the collection. Default is None.

    Returns:
        None
    """
    self.get_collection(collection_name).delete(ids=ids, collection_name=collection_name)
904
-
905
def retrieve_docs(
    self,
    queries: List[str],
    collection_name: str = None,
    n_results: int = 10,
    distance_threshold: float = -1,
) -> QueryResults:
    """
    Retrieve documents from the collection of the vector database based on the queries.

    Args:
        queries: List[str] | A list of queries. Each query is a string. A single string is
            also accepted and is treated as a one-element list.
        collection_name: str | The name of the collection. Default is None.
        n_results: int | The number of relevant documents to return. Default is 10.
        distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
            returned. Don't filter with it if < 0. Default is -1.

    Returns:
        QueryResults | The query results. Each query result is a list of list of tuples containing the document and
            the distance.
    """
    target = self.get_collection(collection_name)
    query_texts = [queries] if isinstance(queries, str) else queries
    results = target.query(
        query_texts=query_texts,
        n_results=n_results,
        distance_threshold=distance_threshold,
    )
    logger.debug(f"Retrieve Docs Results:\n{results}")
    return results
937
-
938
def get_docs_by_ids(
    self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
) -> List[Document]:
    """
    Retrieve documents from the collection of the vector database based on the ids.

    Args:
        ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
        collection_name: str | The name of the collection. Default is None.
        include: List[str] | The fields to include. Default is None.
            If None, will include ["metadatas", "documents"], ids will always be included.
        kwargs: dict | Additional keyword arguments, forwarded to the collection's `get`.

    Returns:
        List[Document] | The results.
    """
    target = self.get_collection(collection_name)
    fields = include or ["metadatas", "documents"]
    results = target.get(ids, include=fields, **kwargs)
    logger.debug(f"Retrieve Documents by ID Results:\n{results}")
    return results