vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the Pgvector vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator, Sequence
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple, Sequence
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
11
11
  from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
- from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig
14
+ from .config import PgVectorConfigDict, PgVectorIndexConfig
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17
 
@@ -56,13 +56,14 @@ class PgVector(VectorDB):
56
56
  (
57
57
  self.case_config.create_index_before_load,
58
58
  self.case_config.create_index_after_load,
59
- )
59
+ ),
60
60
  ):
61
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
62
- log.error(err)
63
- raise RuntimeError(
64
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ msg = (
62
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
63
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
65
64
  )
65
+ log.error(msg)
66
+ raise RuntimeError(msg)
66
67
 
67
68
  if drop_old:
68
69
  self._drop_index()
@@ -77,7 +78,7 @@ class PgVector(VectorDB):
77
78
  self.conn = None
78
79
 
79
80
  @staticmethod
80
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
81
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
81
82
  conn = psycopg.connect(**kwargs)
82
83
  register_vector(conn)
83
84
  conn.autocommit = False
@@ -87,8 +88,8 @@ class PgVector(VectorDB):
87
88
  assert cursor is not None, "Cursor is not initialized"
88
89
 
89
90
  return conn, cursor
90
-
91
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
91
+
92
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
92
93
  index_param = self.case_config.index_param()
93
94
  reranking = self.case_config.search_param()["reranking"]
94
95
  column_name = (
@@ -103,23 +104,25 @@ class PgVector(VectorDB):
103
104
  )
104
105
 
105
106
  # The following sections assume that the quantization_type value matches the quantization function name
106
- if index_param["quantization_type"] != None:
107
+ if index_param["quantization_type"] is not None:
107
108
  if index_param["quantization_type"] == "bit" and reranking:
108
109
  # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
109
110
  search_query = sql.Composed(
110
111
  [
111
112
  sql.SQL(
112
113
  """
113
- SELECT i.id
114
+ SELECT i.id
114
115
  FROM (
115
- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
116
+ SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
116
117
  FROM public.{table_name} {where_clause}
117
118
  ORDER BY {column_name}::{quantization_type}({dim})
118
- """
119
+ """,
119
120
  ).format(
120
121
  table_name=sql.Identifier(self.table_name),
121
122
  column_name=column_name,
122
- reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
123
+ reranking_metric_fun_op=sql.SQL(
124
+ self.case_config.search_param()["reranking_metric_fun_op"],
125
+ ),
123
126
  quantization_type=sql.SQL(index_param["quantization_type"]),
124
127
  dim=sql.Literal(self.dim),
125
128
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -127,25 +130,28 @@ class PgVector(VectorDB):
127
130
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
128
131
  sql.SQL(
129
132
  """
130
- {search_vector}
133
+ {search_vector}
131
134
  LIMIT {quantized_fetch_limit}
132
135
  ) i
133
- ORDER BY i.distance
136
+ ORDER BY i.distance
134
137
  LIMIT %s::int
135
- """
138
+ """,
136
139
  ).format(
137
140
  search_vector=search_vector,
138
141
  quantized_fetch_limit=sql.Literal(
139
- self.case_config.search_param()["quantized_fetch_limit"]
142
+ self.case_config.search_param()["quantized_fetch_limit"],
140
143
  ),
141
144
  ),
142
- ]
145
+ ],
143
146
  )
144
147
  else:
145
148
  search_query = sql.Composed(
146
149
  [
147
150
  sql.SQL(
148
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
151
+ """
152
+ SELECT id FROM public.{table_name}
153
+ {where_clause} ORDER BY {column_name}::{quantization_type}({dim})
154
+ """,
149
155
  ).format(
150
156
  table_name=sql.Identifier(self.table_name),
151
157
  column_name=column_name,
@@ -154,25 +160,26 @@ class PgVector(VectorDB):
154
160
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
155
161
  ),
156
162
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
157
- sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
158
- ]
163
+ sql.SQL(" {search_vector} LIMIT %s::int").format(
164
+ search_vector=search_vector,
165
+ ),
166
+ ],
159
167
  )
160
168
  else:
161
169
  search_query = sql.Composed(
162
170
  [
163
171
  sql.SQL(
164
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
172
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
165
173
  ).format(
166
174
  table_name=sql.Identifier(self.table_name),
167
175
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
168
176
  ),
169
177
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
170
178
  sql.SQL(" %s::vector LIMIT %s::int"),
171
- ]
179
+ ],
172
180
  )
173
-
181
+
174
182
  return search_query
175
-
176
183
 
177
184
  @contextmanager
178
185
  def init(self) -> Generator[None, None, None]:
@@ -191,8 +198,8 @@ class PgVector(VectorDB):
191
198
  if len(session_options) > 0:
192
199
  for setting in session_options:
193
200
  command = sql.SQL("SET {setting_name} " + "= {val};").format(
194
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
195
- val=sql.Identifier(str(setting['parameter']['val'])),
201
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
202
+ val=sql.Identifier(str(setting["parameter"]["val"])),
196
203
  )
197
204
  log.debug(command.as_string(self.cursor))
198
205
  self.cursor.execute(command)
@@ -216,15 +223,12 @@ class PgVector(VectorDB):
216
223
 
217
224
  self.cursor.execute(
218
225
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
219
- table_name=sql.Identifier(self.table_name)
220
- )
226
+ table_name=sql.Identifier(self.table_name),
227
+ ),
221
228
  )
222
229
  self.conn.commit()
223
230
 
224
- def ready_to_load(self):
225
- pass
226
-
227
- def optimize(self):
231
+ def optimize(self, data_size: int | None = None):
228
232
  self._post_insert()
229
233
 
230
234
  def _post_insert(self):
@@ -239,7 +243,7 @@ class PgVector(VectorDB):
239
243
  log.info(f"{self.name} client drop index : {self._index_name}")
240
244
 
241
245
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
242
- index_name=sql.Identifier(self._index_name)
246
+ index_name=sql.Identifier(self._index_name),
243
247
  )
244
248
  log.debug(drop_index_sql.as_string(self.cursor))
245
249
  self.cursor.execute(drop_index_sql)
@@ -254,63 +258,51 @@ class PgVector(VectorDB):
254
258
  if index_param["maintenance_work_mem"] is not None:
255
259
  self.cursor.execute(
256
260
  sql.SQL("SET maintenance_work_mem TO {};").format(
257
- index_param["maintenance_work_mem"]
258
- )
261
+ index_param["maintenance_work_mem"],
262
+ ),
259
263
  )
260
264
  self.cursor.execute(
261
265
  sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
262
266
  sql.Identifier(self.db_config["user"]),
263
267
  index_param["maintenance_work_mem"],
264
- )
268
+ ),
265
269
  )
266
270
  self.conn.commit()
267
271
 
268
272
  if index_param["max_parallel_workers"] is not None:
269
273
  self.cursor.execute(
270
274
  sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
271
- index_param["max_parallel_workers"]
272
- )
275
+ index_param["max_parallel_workers"],
276
+ ),
273
277
  )
274
278
  self.cursor.execute(
275
- sql.SQL(
276
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
277
- ).format(
279
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
278
280
  sql.Identifier(self.db_config["user"]),
279
281
  index_param["max_parallel_workers"],
280
- )
282
+ ),
281
283
  )
282
284
  self.cursor.execute(
283
285
  sql.SQL("SET max_parallel_workers TO '{}';").format(
284
- index_param["max_parallel_workers"]
285
- )
286
+ index_param["max_parallel_workers"],
287
+ ),
286
288
  )
287
289
  self.cursor.execute(
288
- sql.SQL(
289
- "ALTER USER {} SET max_parallel_workers TO '{}';"
290
- ).format(
290
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
291
291
  sql.Identifier(self.db_config["user"]),
292
292
  index_param["max_parallel_workers"],
293
- )
293
+ ),
294
294
  )
295
295
  self.cursor.execute(
296
- sql.SQL(
297
- "ALTER TABLE {} SET (parallel_workers = {});"
298
- ).format(
296
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
299
297
  sql.Identifier(self.table_name),
300
298
  index_param["max_parallel_workers"],
301
- )
299
+ ),
302
300
  )
303
301
  self.conn.commit()
304
302
 
305
- results = self.cursor.execute(
306
- sql.SQL("SHOW max_parallel_maintenance_workers;")
307
- ).fetchall()
308
- results.extend(
309
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
310
- )
311
- results.extend(
312
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
313
- )
303
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
304
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
305
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
314
306
  log.info(f"{self.name} parallel index creation parameters: {results}")
315
307
 
316
308
  def _create_index(self):
@@ -322,24 +314,21 @@ class PgVector(VectorDB):
322
314
  self._set_parallel_index_build_param()
323
315
  options = []
324
316
  for option in index_param["index_creation_with_options"]:
325
- if option['val'] is not None:
317
+ if option["val"] is not None:
326
318
  options.append(
327
319
  sql.SQL("{option_name} = {val}").format(
328
- option_name=sql.Identifier(option['option_name']),
329
- val=sql.Identifier(str(option['val'])),
330
- )
320
+ option_name=sql.Identifier(option["option_name"]),
321
+ val=sql.Identifier(str(option["val"])),
322
+ ),
331
323
  )
332
- if any(options):
333
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
334
- else:
335
- with_clause = sql.Composed(())
324
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
336
325
 
337
- if index_param["quantization_type"] != None:
326
+ if index_param["quantization_type"] is not None:
338
327
  index_create_sql = sql.SQL(
339
328
  """
340
329
  CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
341
330
  USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
342
- """
331
+ """,
343
332
  ).format(
344
333
  index_name=sql.Identifier(self._index_name),
345
334
  table_name=sql.Identifier(self.table_name),
@@ -357,9 +346,9 @@ class PgVector(VectorDB):
357
346
  else:
358
347
  index_create_sql = sql.SQL(
359
348
  """
360
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
349
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
361
350
  USING {index_type} (embedding {embedding_metric})
362
- """
351
+ """,
363
352
  ).format(
364
353
  index_name=sql.Identifier(self._index_name),
365
354
  table_name=sql.Identifier(self.table_name),
@@ -367,9 +356,7 @@ class PgVector(VectorDB):
367
356
  embedding_metric=sql.Identifier(index_param["metric"]),
368
357
  )
369
358
 
370
- index_create_sql_with_with_clause = (
371
- index_create_sql + with_clause
372
- ).join(" ")
359
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
373
360
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
374
361
  self.cursor.execute(index_create_sql_with_with_clause)
375
362
  self.conn.commit()
@@ -384,19 +371,17 @@ class PgVector(VectorDB):
384
371
  # create table
385
372
  self.cursor.execute(
386
373
  sql.SQL(
387
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
388
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
374
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
375
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
389
376
  )
390
377
  self.cursor.execute(
391
378
  sql.SQL(
392
- "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;"
393
- ).format(table_name=sql.Identifier(self.table_name))
379
+ "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;",
380
+ ).format(table_name=sql.Identifier(self.table_name)),
394
381
  )
395
382
  self.conn.commit()
396
383
  except Exception as e:
397
- log.warning(
398
- f"Failed to create pgvector table: {self.table_name} error: {e}"
399
- )
384
+ log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
400
385
  raise e from None
401
386
 
402
387
  def insert_embeddings(
@@ -404,7 +389,7 @@ class PgVector(VectorDB):
404
389
  embeddings: list[list[float]],
405
390
  metadata: list[int],
406
391
  **kwargs: Any,
407
- ) -> Tuple[int, Optional[Exception]]:
392
+ ) -> tuple[int, Exception | None]:
408
393
  assert self.conn is not None, "Connection is not initialized"
409
394
  assert self.cursor is not None, "Cursor is not initialized"
410
395
 
@@ -414,8 +399,8 @@ class PgVector(VectorDB):
414
399
 
415
400
  with self.cursor.copy(
416
401
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
417
- table_name=sql.Identifier(self.table_name)
418
- )
402
+ table_name=sql.Identifier(self.table_name),
403
+ ),
419
404
  ) as copy:
420
405
  copy.set_types(["bigint", "vector"])
421
406
  for i, row in enumerate(metadata_arr):
@@ -427,9 +412,7 @@ class PgVector(VectorDB):
427
412
 
428
413
  return len(metadata), None
429
414
  except Exception as e:
430
- log.warning(
431
- f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
432
- )
415
+ log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
433
416
  return 0, e
434
417
 
435
418
  def search_embedding(
@@ -449,21 +432,32 @@ class PgVector(VectorDB):
449
432
  gt = filters.get("id")
450
433
  if index_param["quantization_type"] == "bit" and search_param["reranking"]:
451
434
  result = self.cursor.execute(
452
- self._filtered_search, (q, gt, q, k), prepare=True, binary=True
435
+ self._filtered_search,
436
+ (q, gt, q, k),
437
+ prepare=True,
438
+ binary=True,
453
439
  )
454
440
  else:
455
441
  result = self.cursor.execute(
456
- self._filtered_search, (gt, q, k), prepare=True, binary=True
442
+ self._filtered_search,
443
+ (gt, q, k),
444
+ prepare=True,
445
+ binary=True,
457
446
  )
458
-
447
+
448
+ elif index_param["quantization_type"] == "bit" and search_param["reranking"]:
449
+ result = self.cursor.execute(
450
+ self._unfiltered_search,
451
+ (q, q, k),
452
+ prepare=True,
453
+ binary=True,
454
+ )
459
455
  else:
460
- if index_param["quantization_type"] == "bit" and search_param["reranking"]:
461
- result = self.cursor.execute(
462
- self._unfiltered_search, (q, q, k), prepare=True, binary=True
463
- )
464
- else:
465
- result = self.cursor.execute(
466
- self._unfiltered_search, (q, k), prepare=True, binary=True
467
- )
456
+ result = self.cursor.execute(
457
+ self._unfiltered_search,
458
+ (q, k),
459
+ prepare=True,
460
+ binary=True,
461
+ )
468
462
 
469
463
  return [int(i[0]) for i in result.fetchall()]
@@ -1,80 +1,94 @@
1
- import click
2
1
  import os
2
+ from typing import Annotated, Unpack
3
+
4
+ import click
3
5
  from pydantic import SecretStr
4
6
 
7
+ from vectordb_bench.backend.clients import DB
8
+
5
9
  from ....cli.cli import (
6
10
  CommonTypedDict,
7
11
  cli,
8
12
  click_parameter_decorators_from_typed_dict,
9
13
  run,
10
14
  )
11
- from typing import Annotated, Unpack
12
- from vectordb_bench.backend.clients import DB
13
15
 
14
16
 
15
17
  class PgVectorScaleTypedDict(CommonTypedDict):
16
18
  user_name: Annotated[
17
- str, click.option("--user-name", type=str, help="Db username", required=True)
19
+ str,
20
+ click.option("--user-name", type=str, help="Db username", required=True),
18
21
  ]
19
22
  password: Annotated[
20
23
  str,
21
- click.option("--password",
22
- type=str,
23
- help="Postgres database password",
24
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
25
- show_default="$POSTGRES_PASSWORD",
26
- ),
24
+ click.option(
25
+ "--password",
26
+ type=str,
27
+ help="Postgres database password",
28
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
29
+ show_default="$POSTGRES_PASSWORD",
30
+ ),
27
31
  ]
28
32
 
29
- host: Annotated[
30
- str, click.option("--host", type=str, help="Db host", required=True)
31
- ]
32
- db_name: Annotated[
33
- str, click.option("--db-name", type=str, help="Db name", required=True)
34
- ]
33
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
34
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
35
35
 
36
36
 
37
37
  class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
38
38
  storage_layout: Annotated[
39
39
  str,
40
40
  click.option(
41
- "--storage-layout", type=str, help="Streaming DiskANN storage layout",
41
+ "--storage-layout",
42
+ type=str,
43
+ help="Streaming DiskANN storage layout",
42
44
  ),
43
45
  ]
44
46
  num_neighbors: Annotated[
45
47
  int,
46
48
  click.option(
47
- "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
49
+ "--num-neighbors",
50
+ type=int,
51
+ help="Streaming DiskANN num neighbors",
48
52
  ),
49
53
  ]
50
54
  search_list_size: Annotated[
51
55
  int,
52
56
  click.option(
53
- "--search-list-size", type=int, help="Streaming DiskANN search list size",
57
+ "--search-list-size",
58
+ type=int,
59
+ help="Streaming DiskANN search list size",
54
60
  ),
55
61
  ]
56
62
  max_alpha: Annotated[
57
63
  float,
58
64
  click.option(
59
- "--max-alpha", type=float, help="Streaming DiskANN max alpha",
65
+ "--max-alpha",
66
+ type=float,
67
+ help="Streaming DiskANN max alpha",
60
68
  ),
61
69
  ]
62
70
  num_dimensions: Annotated[
63
71
  int,
64
72
  click.option(
65
- "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
73
+ "--num-dimensions",
74
+ type=int,
75
+ help="Streaming DiskANN num dimensions",
66
76
  ),
67
77
  ]
68
78
  query_search_list_size: Annotated[
69
79
  int,
70
80
  click.option(
71
- "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
81
+ "--query-search-list-size",
82
+ type=int,
83
+ help="Streaming DiskANN query search list size",
72
84
  ),
73
85
  ]
74
86
  query_rescore: Annotated[
75
87
  int,
76
88
  click.option(
77
- "--query-rescore", type=int, help="Streaming DiskANN query rescore",
89
+ "--query-rescore",
90
+ type=int,
91
+ help="Streaming DiskANN query rescore",
78
92
  ),
79
93
  ]
80
94
 
@@ -105,4 +119,4 @@ def PgVectorScaleDiskAnn(
105
119
  query_rescore=parameters["query_rescore"],
106
120
  ),
107
121
  **parameters,
108
- )
122
+ )
@@ -1,7 +1,8 @@
1
1
  from abc import abstractmethod
2
- from typing import TypedDict
2
+ from typing import LiteralString, TypedDict
3
+
3
4
  from pydantic import BaseModel, SecretStr
4
- from typing_extensions import LiteralString
5
+
5
6
  from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
7
 
7
8
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +10,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
9
10
 
10
11
  class PgVectorScaleConfigDict(TypedDict):
11
12
  """These keys will be directly used as kwargs in psycopg connection string,
12
- so the names must match exactly psycopg API"""
13
+ so the names must match exactly psycopg API"""
13
14
 
14
15
  user: str
15
16
  password: str
@@ -46,7 +47,7 @@ class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
46
47
  if self.metric_type == MetricType.COSINE:
47
48
  return "vector_cosine_ops"
48
49
  return ""
49
-
50
+
50
51
  def parse_metric_fun_op(self) -> LiteralString:
51
52
  if self.metric_type == MetricType.COSINE:
52
53
  return "<=>"
@@ -56,19 +57,16 @@ class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
56
57
  if self.metric_type == MetricType.COSINE:
57
58
  return "cosine_distance"
58
59
  return ""
59
-
60
+
60
61
  @abstractmethod
61
- def index_param(self) -> dict:
62
- ...
62
+ def index_param(self) -> dict: ...
63
63
 
64
64
  @abstractmethod
65
- def search_param(self) -> dict:
66
- ...
65
+ def search_param(self) -> dict: ...
67
66
 
68
67
  @abstractmethod
69
- def session_param(self) -> dict:
70
- ...
71
-
68
+ def session_param(self) -> dict: ...
69
+
72
70
 
73
71
  class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
74
72
  index: IndexType = IndexType.STREAMING_DISKANN
@@ -93,19 +91,20 @@ class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
93
91
  "num_dimensions": self.num_dimensions,
94
92
  },
95
93
  }
96
-
94
+
97
95
  def search_param(self) -> dict:
98
96
  return {
99
97
  "metric": self.parse_metric(),
100
98
  "metric_fun_op": self.parse_metric_fun_op(),
101
99
  }
102
-
100
+
103
101
  def session_param(self) -> dict:
104
102
  return {
105
103
  "diskann.query_search_list_size": self.query_search_list_size,
106
104
  "diskann.query_rescore": self.query_rescore,
107
105
  }
108
-
106
+
107
+
109
108
  _pgvectorscale_case_config = {
110
109
  IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
111
110
  }