vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105):
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +56 -46
  5. vectordb_bench/backend/clients/__init__.py +101 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +52 -35
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +8 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +62 -80
  26. vectordb_bench/backend/clients/milvus/config.py +31 -7
  27. vectordb_bench/backend/clients/milvus/milvus.py +23 -26
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +51 -23
  62. vectordb_bench/backend/runner/read_write_runner.py +140 -46
  63. vectordb_bench/backend/runner/serial_runner.py +99 -50
  64. vectordb_bench/backend/runner/util.py +4 -19
  65. vectordb_bench/backend/task_runner.py +95 -74
  66. vectordb_bench/backend/utils.py +17 -9
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.18.dist-info/RECORD +0 -131
  103. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the Pgvector vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator, Sequence
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple, Sequence
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
11
11
  from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
- from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig
14
+ from .config import PgVectorConfigDict, PgVectorIndexConfig
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17
 
@@ -56,13 +56,14 @@ class PgVector(VectorDB):
56
56
  (
57
57
  self.case_config.create_index_before_load,
58
58
  self.case_config.create_index_after_load,
59
- )
59
+ ),
60
60
  ):
61
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
62
- log.error(err)
63
- raise RuntimeError(
64
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ msg = (
62
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
63
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
65
64
  )
65
+ log.error(msg)
66
+ raise RuntimeError(msg)
66
67
 
67
68
  if drop_old:
68
69
  self._drop_index()
@@ -77,7 +78,7 @@ class PgVector(VectorDB):
77
78
  self.conn = None
78
79
 
79
80
  @staticmethod
80
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
81
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
81
82
  conn = psycopg.connect(**kwargs)
82
83
  register_vector(conn)
83
84
  conn.autocommit = False
@@ -87,8 +88,8 @@ class PgVector(VectorDB):
87
88
  assert cursor is not None, "Cursor is not initialized"
88
89
 
89
90
  return conn, cursor
90
-
91
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
91
+
92
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
92
93
  index_param = self.case_config.index_param()
93
94
  reranking = self.case_config.search_param()["reranking"]
94
95
  column_name = (
@@ -103,23 +104,25 @@ class PgVector(VectorDB):
103
104
  )
104
105
 
105
106
  # The following sections assume that the quantization_type value matches the quantization function name
106
- if index_param["quantization_type"] != None:
107
+ if index_param["quantization_type"] is not None:
107
108
  if index_param["quantization_type"] == "bit" and reranking:
108
109
  # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
109
110
  search_query = sql.Composed(
110
111
  [
111
112
  sql.SQL(
112
113
  """
113
- SELECT i.id
114
+ SELECT i.id
114
115
  FROM (
115
- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
116
+ SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
116
117
  FROM public.{table_name} {where_clause}
117
118
  ORDER BY {column_name}::{quantization_type}({dim})
118
- """
119
+ """,
119
120
  ).format(
120
121
  table_name=sql.Identifier(self.table_name),
121
122
  column_name=column_name,
122
- reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
123
+ reranking_metric_fun_op=sql.SQL(
124
+ self.case_config.search_param()["reranking_metric_fun_op"],
125
+ ),
123
126
  quantization_type=sql.SQL(index_param["quantization_type"]),
124
127
  dim=sql.Literal(self.dim),
125
128
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -127,25 +130,28 @@ class PgVector(VectorDB):
127
130
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
128
131
  sql.SQL(
129
132
  """
130
- {search_vector}
133
+ {search_vector}
131
134
  LIMIT {quantized_fetch_limit}
132
135
  ) i
133
- ORDER BY i.distance
136
+ ORDER BY i.distance
134
137
  LIMIT %s::int
135
- """
138
+ """,
136
139
  ).format(
137
140
  search_vector=search_vector,
138
141
  quantized_fetch_limit=sql.Literal(
139
- self.case_config.search_param()["quantized_fetch_limit"]
142
+ self.case_config.search_param()["quantized_fetch_limit"],
140
143
  ),
141
144
  ),
142
- ]
145
+ ],
143
146
  )
144
147
  else:
145
148
  search_query = sql.Composed(
146
149
  [
147
150
  sql.SQL(
148
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
151
+ """
152
+ SELECT id FROM public.{table_name}
153
+ {where_clause} ORDER BY {column_name}::{quantization_type}({dim})
154
+ """,
149
155
  ).format(
150
156
  table_name=sql.Identifier(self.table_name),
151
157
  column_name=column_name,
@@ -154,25 +160,26 @@ class PgVector(VectorDB):
154
160
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
155
161
  ),
156
162
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
157
- sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
158
- ]
163
+ sql.SQL(" {search_vector} LIMIT %s::int").format(
164
+ search_vector=search_vector,
165
+ ),
166
+ ],
159
167
  )
160
168
  else:
161
169
  search_query = sql.Composed(
162
170
  [
163
171
  sql.SQL(
164
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
172
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
165
173
  ).format(
166
174
  table_name=sql.Identifier(self.table_name),
167
175
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
168
176
  ),
169
177
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
170
178
  sql.SQL(" %s::vector LIMIT %s::int"),
171
- ]
179
+ ],
172
180
  )
173
-
181
+
174
182
  return search_query
175
-
176
183
 
177
184
  @contextmanager
178
185
  def init(self) -> Generator[None, None, None]:
@@ -191,8 +198,8 @@ class PgVector(VectorDB):
191
198
  if len(session_options) > 0:
192
199
  for setting in session_options:
193
200
  command = sql.SQL("SET {setting_name} " + "= {val};").format(
194
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
195
- val=sql.Identifier(str(setting['parameter']['val'])),
201
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
202
+ val=sql.Identifier(str(setting["parameter"]["val"])),
196
203
  )
197
204
  log.debug(command.as_string(self.cursor))
198
205
  self.cursor.execute(command)
@@ -216,8 +223,8 @@ class PgVector(VectorDB):
216
223
 
217
224
  self.cursor.execute(
218
225
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
219
- table_name=sql.Identifier(self.table_name)
220
- )
226
+ table_name=sql.Identifier(self.table_name),
227
+ ),
221
228
  )
222
229
  self.conn.commit()
223
230
 
@@ -239,7 +246,7 @@ class PgVector(VectorDB):
239
246
  log.info(f"{self.name} client drop index : {self._index_name}")
240
247
 
241
248
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
242
- index_name=sql.Identifier(self._index_name)
249
+ index_name=sql.Identifier(self._index_name),
243
250
  )
244
251
  log.debug(drop_index_sql.as_string(self.cursor))
245
252
  self.cursor.execute(drop_index_sql)
@@ -254,63 +261,51 @@ class PgVector(VectorDB):
254
261
  if index_param["maintenance_work_mem"] is not None:
255
262
  self.cursor.execute(
256
263
  sql.SQL("SET maintenance_work_mem TO {};").format(
257
- index_param["maintenance_work_mem"]
258
- )
264
+ index_param["maintenance_work_mem"],
265
+ ),
259
266
  )
260
267
  self.cursor.execute(
261
268
  sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
262
269
  sql.Identifier(self.db_config["user"]),
263
270
  index_param["maintenance_work_mem"],
264
- )
271
+ ),
265
272
  )
266
273
  self.conn.commit()
267
274
 
268
275
  if index_param["max_parallel_workers"] is not None:
269
276
  self.cursor.execute(
270
277
  sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
271
- index_param["max_parallel_workers"]
272
- )
278
+ index_param["max_parallel_workers"],
279
+ ),
273
280
  )
274
281
  self.cursor.execute(
275
- sql.SQL(
276
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
277
- ).format(
282
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
278
283
  sql.Identifier(self.db_config["user"]),
279
284
  index_param["max_parallel_workers"],
280
- )
285
+ ),
281
286
  )
282
287
  self.cursor.execute(
283
288
  sql.SQL("SET max_parallel_workers TO '{}';").format(
284
- index_param["max_parallel_workers"]
285
- )
289
+ index_param["max_parallel_workers"],
290
+ ),
286
291
  )
287
292
  self.cursor.execute(
288
- sql.SQL(
289
- "ALTER USER {} SET max_parallel_workers TO '{}';"
290
- ).format(
293
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
291
294
  sql.Identifier(self.db_config["user"]),
292
295
  index_param["max_parallel_workers"],
293
- )
296
+ ),
294
297
  )
295
298
  self.cursor.execute(
296
- sql.SQL(
297
- "ALTER TABLE {} SET (parallel_workers = {});"
298
- ).format(
299
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
299
300
  sql.Identifier(self.table_name),
300
301
  index_param["max_parallel_workers"],
301
- )
302
+ ),
302
303
  )
303
304
  self.conn.commit()
304
305
 
305
- results = self.cursor.execute(
306
- sql.SQL("SHOW max_parallel_maintenance_workers;")
307
- ).fetchall()
308
- results.extend(
309
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
310
- )
311
- results.extend(
312
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
313
- )
306
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
307
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
308
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
314
309
  log.info(f"{self.name} parallel index creation parameters: {results}")
315
310
 
316
311
  def _create_index(self):
@@ -322,24 +317,21 @@ class PgVector(VectorDB):
322
317
  self._set_parallel_index_build_param()
323
318
  options = []
324
319
  for option in index_param["index_creation_with_options"]:
325
- if option['val'] is not None:
320
+ if option["val"] is not None:
326
321
  options.append(
327
322
  sql.SQL("{option_name} = {val}").format(
328
- option_name=sql.Identifier(option['option_name']),
329
- val=sql.Identifier(str(option['val'])),
330
- )
323
+ option_name=sql.Identifier(option["option_name"]),
324
+ val=sql.Identifier(str(option["val"])),
325
+ ),
331
326
  )
332
- if any(options):
333
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
334
- else:
335
- with_clause = sql.Composed(())
327
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
336
328
 
337
- if index_param["quantization_type"] != None:
329
+ if index_param["quantization_type"] is not None:
338
330
  index_create_sql = sql.SQL(
339
331
  """
340
332
  CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
341
333
  USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
342
- """
334
+ """,
343
335
  ).format(
344
336
  index_name=sql.Identifier(self._index_name),
345
337
  table_name=sql.Identifier(self.table_name),
@@ -357,9 +349,9 @@ class PgVector(VectorDB):
357
349
  else:
358
350
  index_create_sql = sql.SQL(
359
351
  """
360
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
352
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
361
353
  USING {index_type} (embedding {embedding_metric})
362
- """
354
+ """,
363
355
  ).format(
364
356
  index_name=sql.Identifier(self._index_name),
365
357
  table_name=sql.Identifier(self.table_name),
@@ -367,9 +359,7 @@ class PgVector(VectorDB):
367
359
  embedding_metric=sql.Identifier(index_param["metric"]),
368
360
  )
369
361
 
370
- index_create_sql_with_with_clause = (
371
- index_create_sql + with_clause
372
- ).join(" ")
362
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
373
363
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
374
364
  self.cursor.execute(index_create_sql_with_with_clause)
375
365
  self.conn.commit()
@@ -384,19 +374,17 @@ class PgVector(VectorDB):
384
374
  # create table
385
375
  self.cursor.execute(
386
376
  sql.SQL(
387
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
388
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
377
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
378
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
389
379
  )
390
380
  self.cursor.execute(
391
381
  sql.SQL(
392
- "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;"
393
- ).format(table_name=sql.Identifier(self.table_name))
382
+ "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;",
383
+ ).format(table_name=sql.Identifier(self.table_name)),
394
384
  )
395
385
  self.conn.commit()
396
386
  except Exception as e:
397
- log.warning(
398
- f"Failed to create pgvector table: {self.table_name} error: {e}"
399
- )
387
+ log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
400
388
  raise e from None
401
389
 
402
390
  def insert_embeddings(
@@ -404,7 +392,7 @@ class PgVector(VectorDB):
404
392
  embeddings: list[list[float]],
405
393
  metadata: list[int],
406
394
  **kwargs: Any,
407
- ) -> Tuple[int, Optional[Exception]]:
395
+ ) -> tuple[int, Exception | None]:
408
396
  assert self.conn is not None, "Connection is not initialized"
409
397
  assert self.cursor is not None, "Cursor is not initialized"
410
398
 
@@ -414,8 +402,8 @@ class PgVector(VectorDB):
414
402
 
415
403
  with self.cursor.copy(
416
404
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
417
- table_name=sql.Identifier(self.table_name)
418
- )
405
+ table_name=sql.Identifier(self.table_name),
406
+ ),
419
407
  ) as copy:
420
408
  copy.set_types(["bigint", "vector"])
421
409
  for i, row in enumerate(metadata_arr):
@@ -428,7 +416,7 @@ class PgVector(VectorDB):
428
416
  return len(metadata), None
429
417
  except Exception as e:
430
418
  log.warning(
431
- f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
419
+ f"Failed to insert data into pgvector table ({self.table_name}), error: {e}",
432
420
  )
433
421
  return 0, e
434
422
 
@@ -449,21 +437,32 @@ class PgVector(VectorDB):
449
437
  gt = filters.get("id")
450
438
  if index_param["quantization_type"] == "bit" and search_param["reranking"]:
451
439
  result = self.cursor.execute(
452
- self._filtered_search, (q, gt, q, k), prepare=True, binary=True
440
+ self._filtered_search,
441
+ (q, gt, q, k),
442
+ prepare=True,
443
+ binary=True,
453
444
  )
454
445
  else:
455
446
  result = self.cursor.execute(
456
- self._filtered_search, (gt, q, k), prepare=True, binary=True
447
+ self._filtered_search,
448
+ (gt, q, k),
449
+ prepare=True,
450
+ binary=True,
457
451
  )
458
-
452
+
453
+ elif index_param["quantization_type"] == "bit" and search_param["reranking"]:
454
+ result = self.cursor.execute(
455
+ self._unfiltered_search,
456
+ (q, q, k),
457
+ prepare=True,
458
+ binary=True,
459
+ )
459
460
  else:
460
- if index_param["quantization_type"] == "bit" and search_param["reranking"]:
461
- result = self.cursor.execute(
462
- self._unfiltered_search, (q, q, k), prepare=True, binary=True
463
- )
464
- else:
465
- result = self.cursor.execute(
466
- self._unfiltered_search, (q, k), prepare=True, binary=True
467
- )
461
+ result = self.cursor.execute(
462
+ self._unfiltered_search,
463
+ (q, k),
464
+ prepare=True,
465
+ binary=True,
466
+ )
468
467
 
469
468
  return [int(i[0]) for i in result.fetchall()]
@@ -1,80 +1,94 @@
1
- import click
2
1
  import os
2
+ from typing import Annotated, Unpack
3
+
4
+ import click
3
5
  from pydantic import SecretStr
4
6
 
7
+ from vectordb_bench.backend.clients import DB
8
+
5
9
  from ....cli.cli import (
6
10
  CommonTypedDict,
7
11
  cli,
8
12
  click_parameter_decorators_from_typed_dict,
9
13
  run,
10
14
  )
11
- from typing import Annotated, Unpack
12
- from vectordb_bench.backend.clients import DB
13
15
 
14
16
 
15
17
  class PgVectorScaleTypedDict(CommonTypedDict):
16
18
  user_name: Annotated[
17
- str, click.option("--user-name", type=str, help="Db username", required=True)
19
+ str,
20
+ click.option("--user-name", type=str, help="Db username", required=True),
18
21
  ]
19
22
  password: Annotated[
20
23
  str,
21
- click.option("--password",
22
- type=str,
23
- help="Postgres database password",
24
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
25
- show_default="$POSTGRES_PASSWORD",
26
- ),
24
+ click.option(
25
+ "--password",
26
+ type=str,
27
+ help="Postgres database password",
28
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
29
+ show_default="$POSTGRES_PASSWORD",
30
+ ),
27
31
  ]
28
32
 
29
- host: Annotated[
30
- str, click.option("--host", type=str, help="Db host", required=True)
31
- ]
32
- db_name: Annotated[
33
- str, click.option("--db-name", type=str, help="Db name", required=True)
34
- ]
33
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
34
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
35
35
 
36
36
 
37
37
  class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
38
38
  storage_layout: Annotated[
39
39
  str,
40
40
  click.option(
41
- "--storage-layout", type=str, help="Streaming DiskANN storage layout",
41
+ "--storage-layout",
42
+ type=str,
43
+ help="Streaming DiskANN storage layout",
42
44
  ),
43
45
  ]
44
46
  num_neighbors: Annotated[
45
47
  int,
46
48
  click.option(
47
- "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
49
+ "--num-neighbors",
50
+ type=int,
51
+ help="Streaming DiskANN num neighbors",
48
52
  ),
49
53
  ]
50
54
  search_list_size: Annotated[
51
55
  int,
52
56
  click.option(
53
- "--search-list-size", type=int, help="Streaming DiskANN search list size",
57
+ "--search-list-size",
58
+ type=int,
59
+ help="Streaming DiskANN search list size",
54
60
  ),
55
61
  ]
56
62
  max_alpha: Annotated[
57
63
  float,
58
64
  click.option(
59
- "--max-alpha", type=float, help="Streaming DiskANN max alpha",
65
+ "--max-alpha",
66
+ type=float,
67
+ help="Streaming DiskANN max alpha",
60
68
  ),
61
69
  ]
62
70
  num_dimensions: Annotated[
63
71
  int,
64
72
  click.option(
65
- "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
73
+ "--num-dimensions",
74
+ type=int,
75
+ help="Streaming DiskANN num dimensions",
66
76
  ),
67
77
  ]
68
78
  query_search_list_size: Annotated[
69
79
  int,
70
80
  click.option(
71
- "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
81
+ "--query-search-list-size",
82
+ type=int,
83
+ help="Streaming DiskANN query search list size",
72
84
  ),
73
85
  ]
74
86
  query_rescore: Annotated[
75
87
  int,
76
88
  click.option(
77
- "--query-rescore", type=int, help="Streaming DiskANN query rescore",
89
+ "--query-rescore",
90
+ type=int,
91
+ help="Streaming DiskANN query rescore",
78
92
  ),
79
93
  ]
80
94
 
@@ -105,4 +119,4 @@ def PgVectorScaleDiskAnn(
105
119
  query_rescore=parameters["query_rescore"],
106
120
  ),
107
121
  **parameters,
108
- )
122
+ )
@@ -1,7 +1,8 @@
1
1
  from abc import abstractmethod
2
- from typing import TypedDict
2
+ from typing import LiteralString, TypedDict
3
+
3
4
  from pydantic import BaseModel, SecretStr
4
- from typing_extensions import LiteralString
5
+
5
6
  from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
7
 
7
8
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +10,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
9
10
 
10
11
  class PgVectorScaleConfigDict(TypedDict):
11
12
  """These keys will be directly used as kwargs in psycopg connection string,
12
- so the names must match exactly psycopg API"""
13
+ so the names must match exactly psycopg API"""
13
14
 
14
15
  user: str
15
16
  password: str
@@ -46,7 +47,7 @@ class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
46
47
  if self.metric_type == MetricType.COSINE:
47
48
  return "vector_cosine_ops"
48
49
  return ""
49
-
50
+
50
51
  def parse_metric_fun_op(self) -> LiteralString:
51
52
  if self.metric_type == MetricType.COSINE:
52
53
  return "<=>"
@@ -56,19 +57,16 @@ class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig):
56
57
  if self.metric_type == MetricType.COSINE:
57
58
  return "cosine_distance"
58
59
  return ""
59
-
60
+
60
61
  @abstractmethod
61
- def index_param(self) -> dict:
62
- ...
62
+ def index_param(self) -> dict: ...
63
63
 
64
64
  @abstractmethod
65
- def search_param(self) -> dict:
66
- ...
65
+ def search_param(self) -> dict: ...
67
66
 
68
67
  @abstractmethod
69
- def session_param(self) -> dict:
70
- ...
71
-
68
+ def session_param(self) -> dict: ...
69
+
72
70
 
73
71
  class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
74
72
  index: IndexType = IndexType.STREAMING_DISKANN
@@ -93,19 +91,20 @@ class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
93
91
  "num_dimensions": self.num_dimensions,
94
92
  },
95
93
  }
96
-
94
+
97
95
  def search_param(self) -> dict:
98
96
  return {
99
97
  "metric": self.parse_metric(),
100
98
  "metric_fun_op": self.parse_metric_fun_op(),
101
99
  }
102
-
100
+
103
101
  def session_param(self) -> dict:
104
102
  return {
105
103
  "diskann.query_search_list_size": self.query_search_list_size,
106
104
  "diskann.query_rescore": self.query_rescore,
107
105
  }
108
-
106
+
107
+
109
108
  _pgvectorscale_case_config = {
110
109
  IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
111
110
  }