vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +55 -45
- vectordb_bench/backend/clients/__init__.py +75 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
- vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
- vectordb_bench/backend/clients/alloydb/cli.py +51 -34
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +5 -9
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +38 -36
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
- vectordb_bench/backend/clients/milvus/cli.py +41 -83
- vectordb_bench/backend/clients/milvus/config.py +18 -8
- vectordb_bench/backend/clients/milvus/milvus.py +18 -19
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +94 -58
- vectordb_bench/backend/clients/test/cli.py +1 -2
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +4 -5
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +30 -18
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +85 -34
- vectordb_bench/backend/runner/rate_runner.py +30 -19
- vectordb_bench/backend/runner/read_write_runner.py +51 -23
- vectordb_bench/backend/runner/serial_runner.py +91 -48
- vectordb_bench/backend/runner/util.py +4 -3
- vectordb_bench/backend/task_runner.py +92 -72
- vectordb_bench/backend/utils.py +17 -10
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +56 -26
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
- vectordb_bench-0.0.20.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.19.dist-info/RECORD +0 -135
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
|
|
1
1
|
"""Wrapper around the pg_diskann vector database over VectorDB"""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import
|
4
|
+
from collections.abc import Generator
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import Any
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import psycopg
|
@@ -44,20 +44,21 @@ class PgDiskANN(VectorDB):
|
|
44
44
|
self._primary_field = "id"
|
45
45
|
self._vector_field = "embedding"
|
46
46
|
|
47
|
-
self.conn, self.cursor = self._create_connection(**self.db_config)
|
47
|
+
self.conn, self.cursor = self._create_connection(**self.db_config)
|
48
48
|
|
49
49
|
log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
|
50
50
|
if not any(
|
51
51
|
(
|
52
52
|
self.case_config.create_index_before_load,
|
53
53
|
self.case_config.create_index_after_load,
|
54
|
-
)
|
54
|
+
),
|
55
55
|
):
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
|
56
|
+
msg = (
|
57
|
+
f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
|
58
|
+
f"{self.name} config values: {self.db_config}\n{self.case_config}"
|
60
59
|
)
|
60
|
+
log.error(msg)
|
61
|
+
raise RuntimeError(msg)
|
61
62
|
|
62
63
|
if drop_old:
|
63
64
|
self._drop_index()
|
@@ -72,7 +73,7 @@ class PgDiskANN(VectorDB):
|
|
72
73
|
self.conn = None
|
73
74
|
|
74
75
|
@staticmethod
|
75
|
-
def _create_connection(**kwargs) ->
|
76
|
+
def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
|
76
77
|
conn = psycopg.connect(**kwargs)
|
77
78
|
conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
|
78
79
|
conn.commit()
|
@@ -101,25 +102,25 @@ class PgDiskANN(VectorDB):
|
|
101
102
|
log.debug(command.as_string(self.cursor))
|
102
103
|
self.cursor.execute(command)
|
103
104
|
self.conn.commit()
|
104
|
-
|
105
|
+
|
105
106
|
self._filtered_search = sql.Composed(
|
106
107
|
[
|
107
108
|
sql.SQL(
|
108
|
-
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
|
109
|
-
|
109
|
+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
|
110
|
+
).format(table_name=sql.Identifier(self.table_name)),
|
110
111
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
111
112
|
sql.SQL(" %s::vector LIMIT %s::int"),
|
112
|
-
]
|
113
|
+
],
|
113
114
|
)
|
114
115
|
|
115
116
|
self._unfiltered_search = sql.Composed(
|
116
117
|
[
|
117
118
|
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
|
118
|
-
sql.Identifier(self.table_name)
|
119
|
+
sql.Identifier(self.table_name),
|
119
120
|
),
|
120
121
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
121
122
|
sql.SQL(" %s::vector LIMIT %s::int"),
|
122
|
-
]
|
123
|
+
],
|
123
124
|
)
|
124
125
|
|
125
126
|
try:
|
@@ -137,8 +138,8 @@ class PgDiskANN(VectorDB):
|
|
137
138
|
|
138
139
|
self.cursor.execute(
|
139
140
|
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
|
140
|
-
table_name=sql.Identifier(self.table_name)
|
141
|
-
)
|
141
|
+
table_name=sql.Identifier(self.table_name),
|
142
|
+
),
|
142
143
|
)
|
143
144
|
self.conn.commit()
|
144
145
|
|
@@ -160,7 +161,7 @@ class PgDiskANN(VectorDB):
|
|
160
161
|
log.info(f"{self.name} client drop index : {self._index_name}")
|
161
162
|
|
162
163
|
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
|
163
|
-
index_name=sql.Identifier(self._index_name)
|
164
|
+
index_name=sql.Identifier(self._index_name),
|
164
165
|
)
|
165
166
|
log.debug(drop_index_sql.as_string(self.cursor))
|
166
167
|
self.cursor.execute(drop_index_sql)
|
@@ -175,64 +176,53 @@ class PgDiskANN(VectorDB):
|
|
175
176
|
if index_param["maintenance_work_mem"] is not None:
|
176
177
|
self.cursor.execute(
|
177
178
|
sql.SQL("SET maintenance_work_mem TO {};").format(
|
178
|
-
index_param["maintenance_work_mem"]
|
179
|
-
)
|
179
|
+
index_param["maintenance_work_mem"],
|
180
|
+
),
|
180
181
|
)
|
181
182
|
self.cursor.execute(
|
182
183
|
sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
|
183
184
|
sql.Identifier(self.db_config["user"]),
|
184
185
|
index_param["maintenance_work_mem"],
|
185
|
-
)
|
186
|
+
),
|
186
187
|
)
|
187
188
|
self.conn.commit()
|
188
189
|
|
189
190
|
if index_param["max_parallel_workers"] is not None:
|
190
191
|
self.cursor.execute(
|
191
192
|
sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
|
192
|
-
index_param["max_parallel_workers"]
|
193
|
-
)
|
193
|
+
index_param["max_parallel_workers"],
|
194
|
+
),
|
194
195
|
)
|
195
196
|
self.cursor.execute(
|
196
|
-
sql.SQL(
|
197
|
-
"ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
|
198
|
-
).format(
|
197
|
+
sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
|
199
198
|
sql.Identifier(self.db_config["user"]),
|
200
199
|
index_param["max_parallel_workers"],
|
201
|
-
)
|
200
|
+
),
|
202
201
|
)
|
203
202
|
self.cursor.execute(
|
204
203
|
sql.SQL("SET max_parallel_workers TO '{}';").format(
|
205
|
-
index_param["max_parallel_workers"]
|
206
|
-
)
|
204
|
+
index_param["max_parallel_workers"],
|
205
|
+
),
|
207
206
|
)
|
208
207
|
self.cursor.execute(
|
209
|
-
sql.SQL(
|
210
|
-
"ALTER USER {} SET max_parallel_workers TO '{}';"
|
211
|
-
).format(
|
208
|
+
sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
|
212
209
|
sql.Identifier(self.db_config["user"]),
|
213
210
|
index_param["max_parallel_workers"],
|
214
|
-
)
|
211
|
+
),
|
215
212
|
)
|
216
213
|
self.cursor.execute(
|
217
|
-
sql.SQL(
|
218
|
-
"ALTER TABLE {} SET (parallel_workers = {});"
|
219
|
-
).format(
|
214
|
+
sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
|
220
215
|
sql.Identifier(self.table_name),
|
221
216
|
index_param["max_parallel_workers"],
|
222
|
-
)
|
217
|
+
),
|
223
218
|
)
|
224
219
|
self.conn.commit()
|
225
220
|
|
226
|
-
results = self.cursor.execute(
|
227
|
-
|
228
|
-
).fetchall()
|
229
|
-
results.extend(
|
230
|
-
self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
|
231
|
-
)
|
232
|
-
results.extend(
|
233
|
-
self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
|
234
|
-
)
|
221
|
+
results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
|
222
|
+
results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
|
223
|
+
results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
|
235
224
|
log.info(f"{self.name} parallel index creation parameters: {results}")
|
225
|
+
|
236
226
|
def _create_index(self):
|
237
227
|
assert self.conn is not None, "Connection is not initialized"
|
238
228
|
assert self.cursor is not None, "Cursor is not initialized"
|
@@ -248,28 +238,23 @@ class PgDiskANN(VectorDB):
|
|
248
238
|
sql.SQL("{option_name} = {val}").format(
|
249
239
|
option_name=sql.Identifier(option_name),
|
250
240
|
val=sql.Identifier(str(option_val)),
|
251
|
-
)
|
241
|
+
),
|
252
242
|
)
|
253
|
-
|
254
|
-
if any(options)
|
255
|
-
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
|
256
|
-
else:
|
257
|
-
with_clause = sql.Composed(())
|
243
|
+
|
244
|
+
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
|
258
245
|
|
259
246
|
index_create_sql = sql.SQL(
|
260
247
|
"""
|
261
|
-
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
248
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
262
249
|
USING {index_type} (embedding {embedding_metric})
|
263
|
-
"""
|
250
|
+
""",
|
264
251
|
).format(
|
265
252
|
index_name=sql.Identifier(self._index_name),
|
266
253
|
table_name=sql.Identifier(self.table_name),
|
267
254
|
index_type=sql.Identifier(index_param["index_type"].lower()),
|
268
255
|
embedding_metric=sql.Identifier(index_param["metric"]),
|
269
256
|
)
|
270
|
-
index_create_sql_with_with_clause = (
|
271
|
-
index_create_sql + with_clause
|
272
|
-
).join(" ")
|
257
|
+
index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
|
273
258
|
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
|
274
259
|
self.cursor.execute(index_create_sql_with_with_clause)
|
275
260
|
self.conn.commit()
|
@@ -283,14 +268,12 @@ class PgDiskANN(VectorDB):
|
|
283
268
|
|
284
269
|
self.cursor.execute(
|
285
270
|
sql.SQL(
|
286
|
-
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
|
287
|
-
).format(table_name=sql.Identifier(self.table_name), dim=dim)
|
271
|
+
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
|
272
|
+
).format(table_name=sql.Identifier(self.table_name), dim=dim),
|
288
273
|
)
|
289
274
|
self.conn.commit()
|
290
275
|
except Exception as e:
|
291
|
-
log.warning(
|
292
|
-
f"Failed to create pgdiskann table: {self.table_name} error: {e}"
|
293
|
-
)
|
276
|
+
log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
|
294
277
|
raise e from None
|
295
278
|
|
296
279
|
def insert_embeddings(
|
@@ -298,7 +281,7 @@ class PgDiskANN(VectorDB):
|
|
298
281
|
embeddings: list[list[float]],
|
299
282
|
metadata: list[int],
|
300
283
|
**kwargs: Any,
|
301
|
-
) ->
|
284
|
+
) -> tuple[int, Exception | None]:
|
302
285
|
assert self.conn is not None, "Connection is not initialized"
|
303
286
|
assert self.cursor is not None, "Cursor is not initialized"
|
304
287
|
|
@@ -308,8 +291,8 @@ class PgDiskANN(VectorDB):
|
|
308
291
|
|
309
292
|
with self.cursor.copy(
|
310
293
|
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
311
|
-
table_name=sql.Identifier(self.table_name)
|
312
|
-
)
|
294
|
+
table_name=sql.Identifier(self.table_name),
|
295
|
+
),
|
313
296
|
) as copy:
|
314
297
|
copy.set_types(["bigint", "vector"])
|
315
298
|
for i, row in enumerate(metadata_arr):
|
@@ -321,9 +304,7 @@ class PgDiskANN(VectorDB):
|
|
321
304
|
|
322
305
|
return len(metadata), None
|
323
306
|
except Exception as e:
|
324
|
-
log.warning(
|
325
|
-
f"Failed to insert data into table ({self.table_name}), error: {e}"
|
326
|
-
)
|
307
|
+
log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
|
327
308
|
return 0, e
|
328
309
|
|
329
310
|
def search_embedding(
|
@@ -340,11 +321,12 @@ class PgDiskANN(VectorDB):
|
|
340
321
|
if filters:
|
341
322
|
gt = filters.get("id")
|
342
323
|
result = self.cursor.execute(
|
343
|
-
|
344
|
-
|
324
|
+
self._filtered_search,
|
325
|
+
(gt, q, k),
|
326
|
+
prepare=True,
|
327
|
+
binary=True,
|
328
|
+
)
|
345
329
|
else:
|
346
|
-
result = self.cursor.execute(
|
347
|
-
self._unfiltered_search, (q, k), prepare=True, binary=True
|
348
|
-
)
|
330
|
+
result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
|
349
331
|
|
350
332
|
return [int(i[0]) for i in result.fetchall()]
|
@@ -1,9 +1,11 @@
|
|
1
|
-
|
1
|
+
import os
|
2
|
+
from typing import Annotated, Unpack
|
2
3
|
|
3
4
|
import click
|
4
|
-
import os
|
5
5
|
from pydantic import SecretStr
|
6
6
|
|
7
|
+
from vectordb_bench.backend.clients import DB
|
8
|
+
|
7
9
|
from ....cli.cli import (
|
8
10
|
CommonTypedDict,
|
9
11
|
HNSWFlavor1,
|
@@ -12,12 +14,12 @@ from ....cli.cli import (
|
|
12
14
|
click_parameter_decorators_from_typed_dict,
|
13
15
|
run,
|
14
16
|
)
|
15
|
-
from vectordb_bench.backend.clients import DB
|
16
17
|
|
17
18
|
|
18
19
|
class PgVectoRSTypedDict(CommonTypedDict):
|
19
20
|
user_name: Annotated[
|
20
|
-
str,
|
21
|
+
str,
|
22
|
+
click.option("--user-name", type=str, help="Db username", required=True),
|
21
23
|
]
|
22
24
|
password: Annotated[
|
23
25
|
str,
|
@@ -30,14 +32,10 @@ class PgVectoRSTypedDict(CommonTypedDict):
|
|
30
32
|
),
|
31
33
|
]
|
32
34
|
|
33
|
-
host: Annotated[
|
34
|
-
|
35
|
-
]
|
36
|
-
db_name: Annotated[
|
37
|
-
str, click.option("--db-name", type=str, help="Db name", required=True)
|
38
|
-
]
|
35
|
+
host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
|
36
|
+
db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
|
39
37
|
max_parallel_workers: Annotated[
|
40
|
-
|
38
|
+
int | None,
|
41
39
|
click.option(
|
42
40
|
"--max-parallel-workers",
|
43
41
|
type=int,
|
@@ -1,11 +1,11 @@
|
|
1
1
|
from abc import abstractmethod
|
2
2
|
from typing import TypedDict
|
3
3
|
|
4
|
+
from pgvecto_rs.types import Flat, Hnsw, IndexOption, Ivf, Quantization
|
5
|
+
from pgvecto_rs.types.index import QuantizationRatio, QuantizationType
|
4
6
|
from pydantic import BaseModel, SecretStr
|
5
|
-
from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
|
6
|
-
from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
|
7
7
|
|
8
|
-
from ..api import
|
8
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
9
9
|
|
10
10
|
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
|
11
11
|
|
@@ -52,14 +52,14 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
|
|
52
52
|
def parse_metric(self) -> str:
|
53
53
|
if self.metric_type == MetricType.L2:
|
54
54
|
return "vector_l2_ops"
|
55
|
-
|
55
|
+
if self.metric_type == MetricType.IP:
|
56
56
|
return "vector_dot_ops"
|
57
57
|
return "vector_cos_ops"
|
58
58
|
|
59
59
|
def parse_metric_fun_op(self) -> str:
|
60
60
|
if self.metric_type == MetricType.L2:
|
61
61
|
return "<->"
|
62
|
-
|
62
|
+
if self.metric_type == MetricType.IP:
|
63
63
|
return "<#>"
|
64
64
|
return "<=>"
|
65
65
|
|
@@ -85,9 +85,7 @@ class PgVectoRSHNSWConfig(PgVectoRSIndexConfig):
|
|
85
85
|
if self.quantization_type is None:
|
86
86
|
quantization = None
|
87
87
|
else:
|
88
|
-
quantization = Quantization(
|
89
|
-
typ=self.quantization_type, ratio=self.quantization_ratio
|
90
|
-
)
|
88
|
+
quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
|
91
89
|
|
92
90
|
option = IndexOption(
|
93
91
|
index=Hnsw(
|
@@ -115,9 +113,7 @@ class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig):
|
|
115
113
|
if self.quantization_type is None:
|
116
114
|
quantization = None
|
117
115
|
else:
|
118
|
-
quantization = Quantization(
|
119
|
-
typ=self.quantization_type, ratio=self.quantization_ratio
|
120
|
-
)
|
116
|
+
quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
|
121
117
|
|
122
118
|
option = IndexOption(
|
123
119
|
index=Ivf(nlist=self.lists, quantization=quantization),
|
@@ -139,9 +135,7 @@ class PgVectoRSFLATConfig(PgVectoRSIndexConfig):
|
|
139
135
|
if self.quantization_type is None:
|
140
136
|
quantization = None
|
141
137
|
else:
|
142
|
-
quantization = Quantization(
|
143
|
-
typ=self.quantization_type, ratio=self.quantization_ratio
|
144
|
-
)
|
138
|
+
quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
|
145
139
|
|
146
140
|
option = IndexOption(
|
147
141
|
index=Flat(
|
@@ -1,14 +1,14 @@
|
|
1
1
|
"""Wrapper around the Pgvecto.rs vector database over VectorDB"""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import
|
4
|
+
from collections.abc import Generator
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import Any
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import psycopg
|
10
|
-
from psycopg import Connection, Cursor, sql
|
11
10
|
from pgvecto_rs.psycopg import register_vector
|
11
|
+
from psycopg import Connection, Cursor, sql
|
12
12
|
|
13
13
|
from ..api import VectorDB
|
14
14
|
from .config import PgVectoRSConfig, PgVectoRSIndexConfig
|
@@ -33,7 +33,6 @@ class PgVectoRS(VectorDB):
|
|
33
33
|
drop_old: bool = False,
|
34
34
|
**kwargs,
|
35
35
|
):
|
36
|
-
|
37
36
|
self.name = "PgVectorRS"
|
38
37
|
self.db_config = db_config
|
39
38
|
self.case_config = db_case_config
|
@@ -52,13 +51,14 @@ class PgVectoRS(VectorDB):
|
|
52
51
|
(
|
53
52
|
self.case_config.create_index_before_load,
|
54
53
|
self.case_config.create_index_after_load,
|
55
|
-
)
|
54
|
+
),
|
56
55
|
):
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
|
56
|
+
msg = (
|
57
|
+
f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
|
58
|
+
f"{self.name} config values: {self.db_config}\n{self.case_config}"
|
61
59
|
)
|
60
|
+
log.error(msg)
|
61
|
+
raise RuntimeError(msg)
|
62
62
|
|
63
63
|
if drop_old:
|
64
64
|
log.info(f"Pgvecto.rs client drop table : {self.table_name}")
|
@@ -74,7 +74,7 @@ class PgVectoRS(VectorDB):
|
|
74
74
|
self.conn = None
|
75
75
|
|
76
76
|
@staticmethod
|
77
|
-
def _create_connection(**kwargs) ->
|
77
|
+
def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
|
78
78
|
conn = psycopg.connect(**kwargs)
|
79
79
|
|
80
80
|
# create vector extension
|
@@ -116,21 +116,21 @@ class PgVectoRS(VectorDB):
|
|
116
116
|
self._filtered_search = sql.Composed(
|
117
117
|
[
|
118
118
|
sql.SQL(
|
119
|
-
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
|
119
|
+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
|
120
120
|
).format(table_name=sql.Identifier(self.table_name)),
|
121
121
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
122
122
|
sql.SQL(" %s::vector LIMIT %s::int"),
|
123
|
-
]
|
123
|
+
],
|
124
124
|
)
|
125
125
|
|
126
126
|
self._unfiltered_search = sql.Composed(
|
127
127
|
[
|
128
|
-
sql.SQL(
|
129
|
-
|
130
|
-
)
|
128
|
+
sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
|
129
|
+
table_name=sql.Identifier(self.table_name),
|
130
|
+
),
|
131
131
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
132
132
|
sql.SQL(" %s::vector LIMIT %s::int"),
|
133
|
-
]
|
133
|
+
],
|
134
134
|
)
|
135
135
|
|
136
136
|
try:
|
@@ -148,8 +148,8 @@ class PgVectoRS(VectorDB):
|
|
148
148
|
|
149
149
|
self.cursor.execute(
|
150
150
|
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
|
151
|
-
table_name=sql.Identifier(self.table_name)
|
152
|
-
)
|
151
|
+
table_name=sql.Identifier(self.table_name),
|
152
|
+
),
|
153
153
|
)
|
154
154
|
self.conn.commit()
|
155
155
|
|
@@ -171,7 +171,7 @@ class PgVectoRS(VectorDB):
|
|
171
171
|
log.info(f"{self.name} client drop index : {self._index_name}")
|
172
172
|
|
173
173
|
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
|
174
|
-
index_name=sql.Identifier(self._index_name)
|
174
|
+
index_name=sql.Identifier(self._index_name),
|
175
175
|
)
|
176
176
|
log.debug(drop_index_sql.as_string(self.cursor))
|
177
177
|
self.cursor.execute(drop_index_sql)
|
@@ -186,9 +186,9 @@ class PgVectoRS(VectorDB):
|
|
186
186
|
|
187
187
|
index_create_sql = sql.SQL(
|
188
188
|
"""
|
189
|
-
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
189
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
190
190
|
USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
|
191
|
-
"""
|
191
|
+
""",
|
192
192
|
).format(
|
193
193
|
index_name=sql.Identifier(self._index_name),
|
194
194
|
table_name=sql.Identifier(self.table_name),
|
@@ -202,7 +202,7 @@ class PgVectoRS(VectorDB):
|
|
202
202
|
except Exception as e:
|
203
203
|
log.warning(
|
204
204
|
f"Failed to create pgvecto.rs index {self._index_name} \
|
205
|
-
at table {self.table_name} error: {e}"
|
205
|
+
at table {self.table_name} error: {e}",
|
206
206
|
)
|
207
207
|
raise e from None
|
208
208
|
|
@@ -214,7 +214,7 @@ class PgVectoRS(VectorDB):
|
|
214
214
|
"""
|
215
215
|
CREATE TABLE IF NOT EXISTS public.{table_name}
|
216
216
|
(id BIGINT PRIMARY KEY, embedding vector({dim}))
|
217
|
-
"""
|
217
|
+
""",
|
218
218
|
).format(
|
219
219
|
table_name=sql.Identifier(self.table_name),
|
220
220
|
dim=dim,
|
@@ -224,9 +224,7 @@ class PgVectoRS(VectorDB):
|
|
224
224
|
self.cursor.execute(table_create_sql)
|
225
225
|
self.conn.commit()
|
226
226
|
except Exception as e:
|
227
|
-
log.warning(
|
228
|
-
f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
|
229
|
-
)
|
227
|
+
log.warning(f"Failed to create pgvecto.rs table: {self.table_name} error: {e}")
|
230
228
|
raise e from None
|
231
229
|
|
232
230
|
def insert_embeddings(
|
@@ -234,7 +232,7 @@ class PgVectoRS(VectorDB):
|
|
234
232
|
embeddings: list[list[float]],
|
235
233
|
metadata: list[int],
|
236
234
|
**kwargs: Any,
|
237
|
-
) ->
|
235
|
+
) -> tuple[int, Exception | None]:
|
238
236
|
assert self.conn is not None, "Connection is not initialized"
|
239
237
|
assert self.cursor is not None, "Cursor is not initialized"
|
240
238
|
|
@@ -247,8 +245,8 @@ class PgVectoRS(VectorDB):
|
|
247
245
|
|
248
246
|
with self.cursor.copy(
|
249
247
|
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
250
|
-
table_name=sql.Identifier(self.table_name)
|
251
|
-
)
|
248
|
+
table_name=sql.Identifier(self.table_name),
|
249
|
+
),
|
252
250
|
) as copy:
|
253
251
|
copy.set_types(["bigint", "vector"])
|
254
252
|
for i, row in enumerate(metadata_arr):
|
@@ -261,7 +259,7 @@ class PgVectoRS(VectorDB):
|
|
261
259
|
return len(metadata), None
|
262
260
|
except Exception as e:
|
263
261
|
log.warning(
|
264
|
-
f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
|
262
|
+
f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}",
|
265
263
|
)
|
266
264
|
return 0, e
|
267
265
|
|
@@ -281,12 +279,13 @@ class PgVectoRS(VectorDB):
|
|
281
279
|
log.debug(self._filtered_search.as_string(self.cursor))
|
282
280
|
gt = filters.get("id")
|
283
281
|
result = self.cursor.execute(
|
284
|
-
self._filtered_search,
|
282
|
+
self._filtered_search,
|
283
|
+
(gt, q, k),
|
284
|
+
prepare=True,
|
285
|
+
binary=True,
|
285
286
|
)
|
286
287
|
else:
|
287
288
|
log.debug(self._unfiltered_search.as_string(self.cursor))
|
288
|
-
result = self.cursor.execute(
|
289
|
-
self._unfiltered_search, (q, k), prepare=True, binary=True
|
290
|
-
)
|
289
|
+
result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
|
291
290
|
|
292
291
|
return [int(i[0]) for i in result.fetchall()]
|