vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +49 -24
- vectordb_bench/__main__.py +4 -3
- vectordb_bench/backend/assembler.py +12 -13
- vectordb_bench/backend/cases.py +55 -45
- vectordb_bench/backend/clients/__init__.py +85 -14
- vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
- vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
- vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
- vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
- vectordb_bench/backend/clients/alloydb/cli.py +51 -34
- vectordb_bench/backend/clients/alloydb/config.py +30 -30
- vectordb_bench/backend/clients/api.py +13 -24
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
- vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
- vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
- vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
- vectordb_bench/backend/clients/chroma/chroma.py +39 -40
- vectordb_bench/backend/clients/chroma/config.py +4 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
- vectordb_bench/backend/clients/memorydb/cli.py +8 -8
- vectordb_bench/backend/clients/memorydb/config.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
- vectordb_bench/backend/clients/milvus/cli.py +41 -83
- vectordb_bench/backend/clients/milvus/config.py +18 -8
- vectordb_bench/backend/clients/milvus/milvus.py +19 -39
- vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
- vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
- vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
- vectordb_bench/backend/clients/pgvector/cli.py +40 -31
- vectordb_bench/backend/clients/pgvector/config.py +63 -73
- vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
- vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
- vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
- vectordb_bench/backend/clients/pinecone/config.py +1 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
- vectordb_bench/backend/clients/redis/cli.py +6 -12
- vectordb_bench/backend/clients/redis/config.py +7 -5
- vectordb_bench/backend/clients/redis/redis.py +95 -62
- vectordb_bench/backend/clients/test/cli.py +2 -3
- vectordb_bench/backend/clients/test/config.py +2 -2
- vectordb_bench/backend/clients/test/test.py +5 -9
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
- vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/data_source.py +18 -14
- vectordb_bench/backend/dataset.py +47 -27
- vectordb_bench/backend/result_collector.py +2 -3
- vectordb_bench/backend/runner/__init__.py +4 -6
- vectordb_bench/backend/runner/mp_runner.py +56 -23
- vectordb_bench/backend/runner/rate_runner.py +30 -19
- vectordb_bench/backend/runner/read_write_runner.py +46 -22
- vectordb_bench/backend/runner/serial_runner.py +81 -46
- vectordb_bench/backend/runner/util.py +4 -3
- vectordb_bench/backend/task_runner.py +92 -92
- vectordb_bench/backend/utils.py +17 -10
- vectordb_bench/base.py +0 -1
- vectordb_bench/cli/cli.py +65 -60
- vectordb_bench/cli/vectordbbench.py +6 -7
- vectordb_bench/frontend/components/check_results/charts.py +8 -19
- vectordb_bench/frontend/components/check_results/data.py +4 -16
- vectordb_bench/frontend/components/check_results/filters.py +8 -16
- vectordb_bench/frontend/components/check_results/nav.py +4 -4
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
- vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
- vectordb_bench/frontend/components/concurrent/charts.py +12 -12
- vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
- vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
- vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
- vectordb_bench/frontend/components/custom/initStyle.py +1 -1
- vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
- vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
- vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
- vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
- vectordb_bench/frontend/components/tables/data.py +3 -6
- vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
- vectordb_bench/frontend/pages/concurrent.py +3 -5
- vectordb_bench/frontend/pages/custom.py +30 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
- vectordb_bench/frontend/pages/run_test.py +3 -7
- vectordb_bench/frontend/utils.py +1 -1
- vectordb_bench/frontend/vdb_benchmark.py +4 -6
- vectordb_bench/interface.py +45 -24
- vectordb_bench/log_util.py +59 -64
- vectordb_bench/metric.py +10 -11
- vectordb_bench/models.py +26 -43
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
- vectordb_bench-0.0.21.dist-info/RECORD +135 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
- vectordb_bench-0.0.19.dist-info/RECORD +0 -135
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
|
|
1
1
|
"""Wrapper around the alloydb vector database over VectorDB"""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import
|
4
|
+
from collections.abc import Generator, Sequence
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import Any
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
import psycopg
|
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
|
|
11
11
|
from psycopg import Connection, Cursor, sql
|
12
12
|
|
13
13
|
from ..api import VectorDB
|
14
|
-
from .config import AlloyDBConfigDict, AlloyDBIndexConfig
|
14
|
+
from .config import AlloyDBConfigDict, AlloyDBIndexConfig
|
15
15
|
|
16
16
|
log = logging.getLogger(__name__)
|
17
17
|
|
@@ -56,13 +56,14 @@ class AlloyDB(VectorDB):
|
|
56
56
|
(
|
57
57
|
self.case_config.create_index_before_load,
|
58
58
|
self.case_config.create_index_after_load,
|
59
|
-
)
|
59
|
+
),
|
60
60
|
):
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
|
61
|
+
msg = (
|
62
|
+
f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
|
63
|
+
"\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
|
65
64
|
)
|
65
|
+
log.warning(msg)
|
66
|
+
raise RuntimeError(msg)
|
66
67
|
|
67
68
|
if drop_old:
|
68
69
|
self._drop_index()
|
@@ -77,7 +78,7 @@ class AlloyDB(VectorDB):
|
|
77
78
|
self.conn = None
|
78
79
|
|
79
80
|
@staticmethod
|
80
|
-
def _create_connection(**kwargs) ->
|
81
|
+
def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
|
81
82
|
conn = psycopg.connect(**kwargs)
|
82
83
|
register_vector(conn)
|
83
84
|
conn.autocommit = False
|
@@ -86,21 +87,20 @@ class AlloyDB(VectorDB):
|
|
86
87
|
assert conn is not None, "Connection is not initialized"
|
87
88
|
assert cursor is not None, "Cursor is not initialized"
|
88
89
|
return conn, cursor
|
89
|
-
|
90
|
-
def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
|
91
|
-
|
90
|
+
|
91
|
+
def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
|
92
|
+
return sql.Composed(
|
92
93
|
[
|
93
94
|
sql.SQL(
|
94
|
-
"SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
|
95
|
+
"SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
|
95
96
|
).format(
|
96
97
|
table_name=sql.Identifier(self.table_name),
|
97
98
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
98
99
|
),
|
99
100
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
100
101
|
sql.SQL(" %s::vector LIMIT %s::int"),
|
101
|
-
]
|
102
|
+
],
|
102
103
|
)
|
103
|
-
return search_query
|
104
104
|
|
105
105
|
@contextmanager
|
106
106
|
def init(self) -> Generator[None, None, None]:
|
@@ -119,8 +119,8 @@ class AlloyDB(VectorDB):
|
|
119
119
|
if len(session_options) > 0:
|
120
120
|
for setting in session_options:
|
121
121
|
command = sql.SQL("SET {setting_name} " + "= {val};").format(
|
122
|
-
setting_name=sql.Identifier(setting[
|
123
|
-
val=sql.Identifier(str(setting[
|
122
|
+
setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
|
123
|
+
val=sql.Identifier(str(setting["parameter"]["val"])),
|
124
124
|
)
|
125
125
|
log.debug(command.as_string(self.cursor))
|
126
126
|
self.cursor.execute(command)
|
@@ -144,15 +144,12 @@ class AlloyDB(VectorDB):
|
|
144
144
|
|
145
145
|
self.cursor.execute(
|
146
146
|
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
|
147
|
-
table_name=sql.Identifier(self.table_name)
|
148
|
-
)
|
147
|
+
table_name=sql.Identifier(self.table_name),
|
148
|
+
),
|
149
149
|
)
|
150
150
|
self.conn.commit()
|
151
151
|
|
152
|
-
def
|
153
|
-
pass
|
154
|
-
|
155
|
-
def optimize(self):
|
152
|
+
def optimize(self, data_size: int | None = None):
|
156
153
|
self._post_insert()
|
157
154
|
|
158
155
|
def _post_insert(self):
|
@@ -167,7 +164,7 @@ class AlloyDB(VectorDB):
|
|
167
164
|
log.info(f"{self.name} client drop index : {self._index_name}")
|
168
165
|
|
169
166
|
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
|
170
|
-
index_name=sql.Identifier(self._index_name)
|
167
|
+
index_name=sql.Identifier(self._index_name),
|
171
168
|
)
|
172
169
|
log.debug(drop_index_sql.as_string(self.cursor))
|
173
170
|
self.cursor.execute(drop_index_sql)
|
@@ -181,78 +178,64 @@ class AlloyDB(VectorDB):
|
|
181
178
|
|
182
179
|
if index_param["enable_pca"] is not None:
|
183
180
|
self.cursor.execute(
|
184
|
-
sql.SQL("SET scann.enable_pca TO {};").format(
|
185
|
-
index_param["enable_pca"]
|
186
|
-
)
|
181
|
+
sql.SQL("SET scann.enable_pca TO {};").format(index_param["enable_pca"]),
|
187
182
|
)
|
188
183
|
self.cursor.execute(
|
189
184
|
sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format(
|
190
185
|
sql.Identifier(self.db_config["user"]),
|
191
186
|
index_param["enable_pca"],
|
192
|
-
)
|
187
|
+
),
|
193
188
|
)
|
194
189
|
self.conn.commit()
|
195
190
|
|
196
191
|
if index_param["maintenance_work_mem"] is not None:
|
197
192
|
self.cursor.execute(
|
198
193
|
sql.SQL("SET maintenance_work_mem TO {};").format(
|
199
|
-
index_param["maintenance_work_mem"]
|
200
|
-
)
|
194
|
+
index_param["maintenance_work_mem"],
|
195
|
+
),
|
201
196
|
)
|
202
197
|
self.cursor.execute(
|
203
198
|
sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
|
204
199
|
sql.Identifier(self.db_config["user"]),
|
205
200
|
index_param["maintenance_work_mem"],
|
206
|
-
)
|
201
|
+
),
|
207
202
|
)
|
208
203
|
self.conn.commit()
|
209
204
|
|
210
205
|
if index_param["max_parallel_workers"] is not None:
|
211
206
|
self.cursor.execute(
|
212
207
|
sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
|
213
|
-
index_param["max_parallel_workers"]
|
214
|
-
)
|
208
|
+
index_param["max_parallel_workers"],
|
209
|
+
),
|
215
210
|
)
|
216
211
|
self.cursor.execute(
|
217
|
-
sql.SQL(
|
218
|
-
"ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
|
219
|
-
).format(
|
212
|
+
sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
|
220
213
|
sql.Identifier(self.db_config["user"]),
|
221
214
|
index_param["max_parallel_workers"],
|
222
|
-
)
|
215
|
+
),
|
223
216
|
)
|
224
217
|
self.cursor.execute(
|
225
218
|
sql.SQL("SET max_parallel_workers TO '{}';").format(
|
226
|
-
index_param["max_parallel_workers"]
|
227
|
-
)
|
219
|
+
index_param["max_parallel_workers"],
|
220
|
+
),
|
228
221
|
)
|
229
222
|
self.cursor.execute(
|
230
|
-
sql.SQL(
|
231
|
-
"ALTER USER {} SET max_parallel_workers TO '{}';"
|
232
|
-
).format(
|
223
|
+
sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
|
233
224
|
sql.Identifier(self.db_config["user"]),
|
234
225
|
index_param["max_parallel_workers"],
|
235
|
-
)
|
226
|
+
),
|
236
227
|
)
|
237
228
|
self.cursor.execute(
|
238
|
-
sql.SQL(
|
239
|
-
"ALTER TABLE {} SET (parallel_workers = {});"
|
240
|
-
).format(
|
229
|
+
sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
|
241
230
|
sql.Identifier(self.table_name),
|
242
231
|
index_param["max_parallel_workers"],
|
243
|
-
)
|
232
|
+
),
|
244
233
|
)
|
245
234
|
self.conn.commit()
|
246
235
|
|
247
|
-
results = self.cursor.execute(
|
248
|
-
|
249
|
-
).fetchall()
|
250
|
-
results.extend(
|
251
|
-
self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
|
252
|
-
)
|
253
|
-
results.extend(
|
254
|
-
self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
|
255
|
-
)
|
236
|
+
results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
|
237
|
+
results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
|
238
|
+
results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
|
256
239
|
log.info(f"{self.name} parallel index creation parameters: {results}")
|
257
240
|
|
258
241
|
def _create_index(self):
|
@@ -264,23 +247,20 @@ class AlloyDB(VectorDB):
|
|
264
247
|
self._set_parallel_index_build_param()
|
265
248
|
options = []
|
266
249
|
for option in index_param["index_creation_with_options"]:
|
267
|
-
if option[
|
250
|
+
if option["val"] is not None:
|
268
251
|
options.append(
|
269
252
|
sql.SQL("{option_name} = {val}").format(
|
270
|
-
option_name=sql.Identifier(option[
|
271
|
-
val=sql.Identifier(str(option[
|
272
|
-
)
|
253
|
+
option_name=sql.Identifier(option["option_name"]),
|
254
|
+
val=sql.Identifier(str(option["val"])),
|
255
|
+
),
|
273
256
|
)
|
274
|
-
if any(options)
|
275
|
-
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
|
276
|
-
else:
|
277
|
-
with_clause = sql.Composed(())
|
257
|
+
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
|
278
258
|
|
279
259
|
index_create_sql = sql.SQL(
|
280
260
|
"""
|
281
|
-
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
261
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
282
262
|
USING {index_type} (embedding {embedding_metric})
|
283
|
-
"""
|
263
|
+
""",
|
284
264
|
).format(
|
285
265
|
index_name=sql.Identifier(self._index_name),
|
286
266
|
table_name=sql.Identifier(self.table_name),
|
@@ -288,9 +268,7 @@ class AlloyDB(VectorDB):
|
|
288
268
|
embedding_metric=sql.Identifier(index_param["metric"]),
|
289
269
|
)
|
290
270
|
|
291
|
-
index_create_sql_with_with_clause = (
|
292
|
-
index_create_sql + with_clause
|
293
|
-
).join(" ")
|
271
|
+
index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
|
294
272
|
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
|
295
273
|
self.cursor.execute(index_create_sql_with_with_clause)
|
296
274
|
self.conn.commit()
|
@@ -305,14 +283,12 @@ class AlloyDB(VectorDB):
|
|
305
283
|
# create table
|
306
284
|
self.cursor.execute(
|
307
285
|
sql.SQL(
|
308
|
-
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
|
309
|
-
).format(table_name=sql.Identifier(self.table_name), dim=dim)
|
286
|
+
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
|
287
|
+
).format(table_name=sql.Identifier(self.table_name), dim=dim),
|
310
288
|
)
|
311
289
|
self.conn.commit()
|
312
290
|
except Exception as e:
|
313
|
-
log.warning(
|
314
|
-
f"Failed to create alloydb table: {self.table_name} error: {e}"
|
315
|
-
)
|
291
|
+
log.warning(f"Failed to create alloydb table: {self.table_name} error: {e}")
|
316
292
|
raise e from None
|
317
293
|
|
318
294
|
def insert_embeddings(
|
@@ -320,7 +296,7 @@ class AlloyDB(VectorDB):
|
|
320
296
|
embeddings: list[list[float]],
|
321
297
|
metadata: list[int],
|
322
298
|
**kwargs: Any,
|
323
|
-
) ->
|
299
|
+
) -> tuple[int, Exception | None]:
|
324
300
|
assert self.conn is not None, "Connection is not initialized"
|
325
301
|
assert self.cursor is not None, "Cursor is not initialized"
|
326
302
|
|
@@ -330,8 +306,8 @@ class AlloyDB(VectorDB):
|
|
330
306
|
|
331
307
|
with self.cursor.copy(
|
332
308
|
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
333
|
-
table_name=sql.Identifier(self.table_name)
|
334
|
-
)
|
309
|
+
table_name=sql.Identifier(self.table_name),
|
310
|
+
),
|
335
311
|
) as copy:
|
336
312
|
copy.set_types(["bigint", "vector"])
|
337
313
|
for i, row in enumerate(metadata_arr):
|
@@ -343,9 +319,7 @@ class AlloyDB(VectorDB):
|
|
343
319
|
|
344
320
|
return len(metadata), None
|
345
321
|
except Exception as e:
|
346
|
-
log.warning(
|
347
|
-
f"Failed to insert data into alloydb table ({self.table_name}), error: {e}"
|
348
|
-
)
|
322
|
+
log.warning(f"Failed to insert data into alloydb table ({self.table_name}), error: {e}")
|
349
323
|
return 0, e
|
350
324
|
|
351
325
|
def search_embedding(
|
@@ -362,11 +336,12 @@ class AlloyDB(VectorDB):
|
|
362
336
|
if filters:
|
363
337
|
gt = filters.get("id")
|
364
338
|
result = self.cursor.execute(
|
365
|
-
self._filtered_search,
|
339
|
+
self._filtered_search,
|
340
|
+
(gt, q, k),
|
341
|
+
prepare=True,
|
342
|
+
binary=True,
|
366
343
|
)
|
367
344
|
else:
|
368
|
-
result = self.cursor.execute(
|
369
|
-
self._unfiltered_search, (q, k), prepare=True, binary=True
|
370
|
-
)
|
345
|
+
result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
|
371
346
|
|
372
347
|
return [int(i[0]) for i in result.fetchall()]
|
@@ -1,10 +1,10 @@
|
|
1
|
-
|
1
|
+
import os
|
2
|
+
from typing import Annotated, Unpack
|
2
3
|
|
3
4
|
import click
|
4
|
-
import os
|
5
5
|
from pydantic import SecretStr
|
6
6
|
|
7
|
-
from vectordb_bench.backend.clients
|
7
|
+
from vectordb_bench.backend.clients import DB
|
8
8
|
|
9
9
|
from ....cli.cli import (
|
10
10
|
CommonTypedDict,
|
@@ -13,31 +13,28 @@ from ....cli.cli import (
|
|
13
13
|
get_custom_case_config,
|
14
14
|
run,
|
15
15
|
)
|
16
|
-
from vectordb_bench.backend.clients import DB
|
17
16
|
|
18
17
|
|
19
18
|
class AlloyDBTypedDict(CommonTypedDict):
|
20
19
|
user_name: Annotated[
|
21
|
-
str,
|
20
|
+
str,
|
21
|
+
click.option("--user-name", type=str, help="Db username", required=True),
|
22
22
|
]
|
23
23
|
password: Annotated[
|
24
24
|
str,
|
25
|
-
click.option(
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
25
|
+
click.option(
|
26
|
+
"--password",
|
27
|
+
type=str,
|
28
|
+
help="Postgres database password",
|
29
|
+
default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
|
30
|
+
show_default="$POSTGRES_PASSWORD",
|
31
|
+
),
|
31
32
|
]
|
32
33
|
|
33
|
-
host: Annotated[
|
34
|
-
|
35
|
-
]
|
36
|
-
db_name: Annotated[
|
37
|
-
str, click.option("--db-name", type=str, help="Db name", required=True)
|
38
|
-
]
|
34
|
+
host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
|
35
|
+
db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
|
39
36
|
maintenance_work_mem: Annotated[
|
40
|
-
|
37
|
+
str | None,
|
41
38
|
click.option(
|
42
39
|
"--maintenance-work-mem",
|
43
40
|
type=str,
|
@@ -49,7 +46,7 @@ class AlloyDBTypedDict(CommonTypedDict):
|
|
49
46
|
),
|
50
47
|
]
|
51
48
|
max_parallel_workers: Annotated[
|
52
|
-
|
49
|
+
int | None,
|
53
50
|
click.option(
|
54
51
|
"--max-parallel-workers",
|
55
52
|
type=int,
|
@@ -58,32 +55,51 @@ class AlloyDBTypedDict(CommonTypedDict):
|
|
58
55
|
),
|
59
56
|
]
|
60
57
|
|
61
|
-
|
62
58
|
|
63
59
|
class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
|
64
60
|
num_leaves: Annotated[
|
65
61
|
int,
|
66
|
-
click.option("--num-leaves", type=int, help="Number of leaves", required=True)
|
62
|
+
click.option("--num-leaves", type=int, help="Number of leaves", required=True),
|
67
63
|
]
|
68
64
|
num_leaves_to_search: Annotated[
|
69
65
|
int,
|
70
|
-
click.option(
|
66
|
+
click.option(
|
67
|
+
"--num-leaves-to-search",
|
68
|
+
type=int,
|
69
|
+
help="Number of leaves to search",
|
70
|
+
required=True,
|
71
|
+
),
|
71
72
|
]
|
72
73
|
pre_reordering_num_neighbors: Annotated[
|
73
74
|
int,
|
74
|
-
click.option(
|
75
|
+
click.option(
|
76
|
+
"--pre-reordering-num-neighbors",
|
77
|
+
type=int,
|
78
|
+
help="Pre-reordering number of neighbors",
|
79
|
+
default=200,
|
80
|
+
),
|
75
81
|
]
|
76
82
|
max_top_neighbors_buffer_size: Annotated[
|
77
83
|
int,
|
78
|
-
click.option(
|
84
|
+
click.option(
|
85
|
+
"--max-top-neighbors-buffer-size",
|
86
|
+
type=int,
|
87
|
+
help="Maximum top neighbors buffer size",
|
88
|
+
default=20_000,
|
89
|
+
),
|
79
90
|
]
|
80
91
|
num_search_threads: Annotated[
|
81
92
|
int,
|
82
|
-
click.option("--num-search-threads", type=int, help="Number of search threads", default=2)
|
93
|
+
click.option("--num-search-threads", type=int, help="Number of search threads", default=2),
|
83
94
|
]
|
84
95
|
max_num_prefetch_datasets: Annotated[
|
85
96
|
int,
|
86
|
-
click.option(
|
97
|
+
click.option(
|
98
|
+
"--max-num-prefetch-datasets",
|
99
|
+
type=int,
|
100
|
+
help="Maximum number of prefetch datasets",
|
101
|
+
default=100,
|
102
|
+
),
|
87
103
|
]
|
88
104
|
quantizer: Annotated[
|
89
105
|
str,
|
@@ -91,16 +107,17 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
|
|
91
107
|
"--quantizer",
|
92
108
|
type=click.Choice(["SQ8", "FLAT"]),
|
93
109
|
help="Quantizer type",
|
94
|
-
default="SQ8"
|
95
|
-
)
|
110
|
+
default="SQ8",
|
111
|
+
),
|
96
112
|
]
|
97
113
|
enable_pca: Annotated[
|
98
|
-
bool,
|
114
|
+
bool,
|
115
|
+
click.option(
|
99
116
|
"--enable-pca",
|
100
117
|
type=click.Choice(["on", "off"]),
|
101
118
|
help="Enable PCA",
|
102
|
-
default="on"
|
103
|
-
)
|
119
|
+
default="on",
|
120
|
+
),
|
104
121
|
]
|
105
122
|
max_num_levels: Annotated[
|
106
123
|
int,
|
@@ -108,8 +125,8 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
|
|
108
125
|
"--max-num-levels",
|
109
126
|
type=click.Choice(["1", "2"]),
|
110
127
|
help="Maximum number of levels",
|
111
|
-
default=1
|
112
|
-
)
|
128
|
+
default=1,
|
129
|
+
),
|
113
130
|
]
|
114
131
|
|
115
132
|
|
@@ -144,4 +161,4 @@ def AlloyDBScaNN(
|
|
144
161
|
maintenance_work_mem=parameters["maintenance_work_mem"],
|
145
162
|
),
|
146
163
|
**parameters,
|
147
|
-
)
|
164
|
+
)
|
@@ -1,7 +1,9 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
from
|
2
|
+
from collections.abc import Mapping, Sequence
|
3
|
+
from typing import Any, LiteralString, TypedDict
|
4
|
+
|
3
5
|
from pydantic import BaseModel, SecretStr
|
4
|
-
|
6
|
+
|
5
7
|
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
6
8
|
|
7
9
|
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
|
@@ -9,7 +11,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
|
|
9
11
|
|
10
12
|
class AlloyDBConfigDict(TypedDict):
|
11
13
|
"""These keys will be directly used as kwargs in psycopg connection string,
|
12
|
-
|
14
|
+
so the names must match exactly psycopg API"""
|
13
15
|
|
14
16
|
user: str
|
15
17
|
password: str
|
@@ -41,8 +43,8 @@ class AlloyDBIndexParam(TypedDict):
|
|
41
43
|
metric: str
|
42
44
|
index_type: str
|
43
45
|
index_creation_with_options: Sequence[dict[str, Any]]
|
44
|
-
maintenance_work_mem:
|
45
|
-
max_parallel_workers:
|
46
|
+
maintenance_work_mem: str | None
|
47
|
+
max_parallel_workers: int | None
|
46
48
|
|
47
49
|
|
48
50
|
class AlloyDBSearchParam(TypedDict):
|
@@ -61,31 +63,30 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
|
|
61
63
|
def parse_metric(self) -> str:
|
62
64
|
if self.metric_type == MetricType.L2:
|
63
65
|
return "l2"
|
64
|
-
|
66
|
+
if self.metric_type == MetricType.DP:
|
65
67
|
return "dot_product"
|
66
68
|
return "cosine"
|
67
69
|
|
68
70
|
def parse_metric_fun_op(self) -> LiteralString:
|
69
71
|
if self.metric_type == MetricType.L2:
|
70
72
|
return "<->"
|
71
|
-
|
73
|
+
if self.metric_type == MetricType.IP:
|
72
74
|
return "<#>"
|
73
75
|
return "<=>"
|
74
76
|
|
75
77
|
@abstractmethod
|
76
|
-
def index_param(self) -> AlloyDBIndexParam:
|
77
|
-
...
|
78
|
+
def index_param(self) -> AlloyDBIndexParam: ...
|
78
79
|
|
79
80
|
@abstractmethod
|
80
|
-
def search_param(self) -> AlloyDBSearchParam:
|
81
|
-
...
|
81
|
+
def search_param(self) -> AlloyDBSearchParam: ...
|
82
82
|
|
83
83
|
@abstractmethod
|
84
|
-
def session_param(self) -> AlloyDBSessionCommands:
|
85
|
-
...
|
84
|
+
def session_param(self) -> AlloyDBSessionCommands: ...
|
86
85
|
|
87
86
|
@staticmethod
|
88
|
-
def _optionally_build_with_options(
|
87
|
+
def _optionally_build_with_options(
|
88
|
+
with_options: Mapping[str, Any],
|
89
|
+
) -> Sequence[dict[str, Any]]:
|
89
90
|
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
|
90
91
|
options = []
|
91
92
|
for option_name, value in with_options.items():
|
@@ -94,24 +95,25 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
|
|
94
95
|
{
|
95
96
|
"option_name": option_name,
|
96
97
|
"val": str(value),
|
97
|
-
}
|
98
|
+
},
|
98
99
|
)
|
99
100
|
return options
|
100
101
|
|
101
102
|
@staticmethod
|
102
103
|
def _optionally_build_set_options(
|
103
|
-
set_mapping: Mapping[str, Any]
|
104
|
+
set_mapping: Mapping[str, Any],
|
104
105
|
) -> Sequence[dict[str, Any]]:
|
105
106
|
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
|
106
107
|
session_options = []
|
107
108
|
for setting_name, value in set_mapping.items():
|
108
109
|
if value:
|
109
110
|
session_options.append(
|
110
|
-
{
|
111
|
+
{
|
112
|
+
"parameter": {
|
111
113
|
"setting_name": setting_name,
|
112
114
|
"val": str(value),
|
113
115
|
},
|
114
|
-
}
|
116
|
+
},
|
115
117
|
)
|
116
118
|
return session_options
|
117
119
|
|
@@ -124,22 +126,22 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
|
|
124
126
|
max_num_levels: int | None
|
125
127
|
num_leaves_to_search: int | None
|
126
128
|
max_top_neighbors_buffer_size: int | None
|
127
|
-
pre_reordering_num_neighbors: int |
|
128
|
-
num_search_threads: int
|
129
|
+
pre_reordering_num_neighbors: int | None
|
130
|
+
num_search_threads: int | None
|
129
131
|
max_num_prefetch_datasets: int | None
|
130
|
-
maintenance_work_mem:
|
131
|
-
max_parallel_workers:
|
132
|
+
maintenance_work_mem: str | None = None
|
133
|
+
max_parallel_workers: int | None = None
|
132
134
|
|
133
135
|
def index_param(self) -> AlloyDBIndexParam:
|
134
136
|
index_parameters = {
|
135
|
-
"num_leaves": self.num_leaves,
|
137
|
+
"num_leaves": self.num_leaves,
|
138
|
+
"max_num_levels": self.max_num_levels,
|
139
|
+
"quantizer": self.quantizer,
|
136
140
|
}
|
137
141
|
return {
|
138
142
|
"metric": self.parse_metric(),
|
139
143
|
"index_type": self.index.value,
|
140
|
-
"index_creation_with_options": self._optionally_build_with_options(
|
141
|
-
index_parameters
|
142
|
-
),
|
144
|
+
"index_creation_with_options": self._optionally_build_with_options(index_parameters),
|
143
145
|
"maintenance_work_mem": self.maintenance_work_mem,
|
144
146
|
"max_parallel_workers": self.max_parallel_workers,
|
145
147
|
"enable_pca": self.enable_pca,
|
@@ -158,11 +160,9 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
|
|
158
160
|
"scann.num_search_threads": self.num_search_threads,
|
159
161
|
"scann.max_num_prefetch_datasets": self.max_num_prefetch_datasets,
|
160
162
|
}
|
161
|
-
return {
|
162
|
-
"session_options": self._optionally_build_set_options(session_parameters)
|
163
|
-
}
|
163
|
+
return {"session_options": self._optionally_build_set_options(session_parameters)}
|
164
164
|
|
165
165
|
|
166
166
|
_alloydb_case_config = {
|
167
|
-
|
167
|
+
IndexType.SCANN: AlloyDBScaNNConfig,
|
168
168
|
}
|