vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105):
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +75 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +5 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +18 -19
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +51 -23
  63. vectordb_bench/backend/runner/serial_runner.py +91 -48
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -72
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the pg_diskann vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -44,20 +44,21 @@ class PgDiskANN(VectorDB):
44
44
  self._primary_field = "id"
45
45
  self._vector_field = "embedding"
46
46
 
47
- self.conn, self.cursor = self._create_connection(**self.db_config)
47
+ self.conn, self.cursor = self._create_connection(**self.db_config)
48
48
 
49
49
  log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
50
50
  if not any(
51
51
  (
52
52
  self.case_config.create_index_before_load,
53
53
  self.case_config.create_index_after_load,
54
- )
54
+ ),
55
55
  ):
56
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
57
- log.error(err)
58
- raise RuntimeError(
59
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
56
+ msg = (
57
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
60
59
  )
60
+ log.error(msg)
61
+ raise RuntimeError(msg)
61
62
 
62
63
  if drop_old:
63
64
  self._drop_index()
@@ -72,7 +73,7 @@ class PgDiskANN(VectorDB):
72
73
  self.conn = None
73
74
 
74
75
  @staticmethod
75
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
76
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
76
77
  conn = psycopg.connect(**kwargs)
77
78
  conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
78
79
  conn.commit()
@@ -101,25 +102,25 @@ class PgDiskANN(VectorDB):
101
102
  log.debug(command.as_string(self.cursor))
102
103
  self.cursor.execute(command)
103
104
  self.conn.commit()
104
-
105
+
105
106
  self._filtered_search = sql.Composed(
106
107
  [
107
108
  sql.SQL(
108
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
109
- ).format(table_name=sql.Identifier(self.table_name)),
109
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
110
+ ).format(table_name=sql.Identifier(self.table_name)),
110
111
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
111
112
  sql.SQL(" %s::vector LIMIT %s::int"),
112
- ]
113
+ ],
113
114
  )
114
115
 
115
116
  self._unfiltered_search = sql.Composed(
116
117
  [
117
118
  sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
118
- sql.Identifier(self.table_name)
119
+ sql.Identifier(self.table_name),
119
120
  ),
120
121
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
121
122
  sql.SQL(" %s::vector LIMIT %s::int"),
122
- ]
123
+ ],
123
124
  )
124
125
 
125
126
  try:
@@ -137,8 +138,8 @@ class PgDiskANN(VectorDB):
137
138
 
138
139
  self.cursor.execute(
139
140
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
140
- table_name=sql.Identifier(self.table_name)
141
- )
141
+ table_name=sql.Identifier(self.table_name),
142
+ ),
142
143
  )
143
144
  self.conn.commit()
144
145
 
@@ -160,7 +161,7 @@ class PgDiskANN(VectorDB):
160
161
  log.info(f"{self.name} client drop index : {self._index_name}")
161
162
 
162
163
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
163
- index_name=sql.Identifier(self._index_name)
164
+ index_name=sql.Identifier(self._index_name),
164
165
  )
165
166
  log.debug(drop_index_sql.as_string(self.cursor))
166
167
  self.cursor.execute(drop_index_sql)
@@ -175,64 +176,53 @@ class PgDiskANN(VectorDB):
175
176
  if index_param["maintenance_work_mem"] is not None:
176
177
  self.cursor.execute(
177
178
  sql.SQL("SET maintenance_work_mem TO {};").format(
178
- index_param["maintenance_work_mem"]
179
- )
179
+ index_param["maintenance_work_mem"],
180
+ ),
180
181
  )
181
182
  self.cursor.execute(
182
183
  sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
183
184
  sql.Identifier(self.db_config["user"]),
184
185
  index_param["maintenance_work_mem"],
185
- )
186
+ ),
186
187
  )
187
188
  self.conn.commit()
188
189
 
189
190
  if index_param["max_parallel_workers"] is not None:
190
191
  self.cursor.execute(
191
192
  sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
192
- index_param["max_parallel_workers"]
193
- )
193
+ index_param["max_parallel_workers"],
194
+ ),
194
195
  )
195
196
  self.cursor.execute(
196
- sql.SQL(
197
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
198
- ).format(
197
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
199
198
  sql.Identifier(self.db_config["user"]),
200
199
  index_param["max_parallel_workers"],
201
- )
200
+ ),
202
201
  )
203
202
  self.cursor.execute(
204
203
  sql.SQL("SET max_parallel_workers TO '{}';").format(
205
- index_param["max_parallel_workers"]
206
- )
204
+ index_param["max_parallel_workers"],
205
+ ),
207
206
  )
208
207
  self.cursor.execute(
209
- sql.SQL(
210
- "ALTER USER {} SET max_parallel_workers TO '{}';"
211
- ).format(
208
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
212
209
  sql.Identifier(self.db_config["user"]),
213
210
  index_param["max_parallel_workers"],
214
- )
211
+ ),
215
212
  )
216
213
  self.cursor.execute(
217
- sql.SQL(
218
- "ALTER TABLE {} SET (parallel_workers = {});"
219
- ).format(
214
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
220
215
  sql.Identifier(self.table_name),
221
216
  index_param["max_parallel_workers"],
222
- )
217
+ ),
223
218
  )
224
219
  self.conn.commit()
225
220
 
226
- results = self.cursor.execute(
227
- sql.SQL("SHOW max_parallel_maintenance_workers;")
228
- ).fetchall()
229
- results.extend(
230
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
231
- )
232
- results.extend(
233
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
234
- )
221
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
222
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
223
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
235
224
  log.info(f"{self.name} parallel index creation parameters: {results}")
225
+
236
226
  def _create_index(self):
237
227
  assert self.conn is not None, "Connection is not initialized"
238
228
  assert self.cursor is not None, "Cursor is not initialized"
@@ -248,28 +238,23 @@ class PgDiskANN(VectorDB):
248
238
  sql.SQL("{option_name} = {val}").format(
249
239
  option_name=sql.Identifier(option_name),
250
240
  val=sql.Identifier(str(option_val)),
251
- )
241
+ ),
252
242
  )
253
-
254
- if any(options):
255
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
256
- else:
257
- with_clause = sql.Composed(())
243
+
244
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
258
245
 
259
246
  index_create_sql = sql.SQL(
260
247
  """
261
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
248
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
262
249
  USING {index_type} (embedding {embedding_metric})
263
- """
250
+ """,
264
251
  ).format(
265
252
  index_name=sql.Identifier(self._index_name),
266
253
  table_name=sql.Identifier(self.table_name),
267
254
  index_type=sql.Identifier(index_param["index_type"].lower()),
268
255
  embedding_metric=sql.Identifier(index_param["metric"]),
269
256
  )
270
- index_create_sql_with_with_clause = (
271
- index_create_sql + with_clause
272
- ).join(" ")
257
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
273
258
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
274
259
  self.cursor.execute(index_create_sql_with_with_clause)
275
260
  self.conn.commit()
@@ -283,14 +268,12 @@ class PgDiskANN(VectorDB):
283
268
 
284
269
  self.cursor.execute(
285
270
  sql.SQL(
286
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
287
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
271
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
272
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
288
273
  )
289
274
  self.conn.commit()
290
275
  except Exception as e:
291
- log.warning(
292
- f"Failed to create pgdiskann table: {self.table_name} error: {e}"
293
- )
276
+ log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
294
277
  raise e from None
295
278
 
296
279
  def insert_embeddings(
@@ -298,7 +281,7 @@ class PgDiskANN(VectorDB):
298
281
  embeddings: list[list[float]],
299
282
  metadata: list[int],
300
283
  **kwargs: Any,
301
- ) -> Tuple[int, Optional[Exception]]:
284
+ ) -> tuple[int, Exception | None]:
302
285
  assert self.conn is not None, "Connection is not initialized"
303
286
  assert self.cursor is not None, "Cursor is not initialized"
304
287
 
@@ -308,8 +291,8 @@ class PgDiskANN(VectorDB):
308
291
 
309
292
  with self.cursor.copy(
310
293
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
311
- table_name=sql.Identifier(self.table_name)
312
- )
294
+ table_name=sql.Identifier(self.table_name),
295
+ ),
313
296
  ) as copy:
314
297
  copy.set_types(["bigint", "vector"])
315
298
  for i, row in enumerate(metadata_arr):
@@ -321,9 +304,7 @@ class PgDiskANN(VectorDB):
321
304
 
322
305
  return len(metadata), None
323
306
  except Exception as e:
324
- log.warning(
325
- f"Failed to insert data into table ({self.table_name}), error: {e}"
326
- )
307
+ log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
327
308
  return 0, e
328
309
 
329
310
  def search_embedding(
@@ -340,11 +321,12 @@ class PgDiskANN(VectorDB):
340
321
  if filters:
341
322
  gt = filters.get("id")
342
323
  result = self.cursor.execute(
343
- self._filtered_search, (gt, q, k), prepare=True, binary=True
344
- )
324
+ self._filtered_search,
325
+ (gt, q, k),
326
+ prepare=True,
327
+ binary=True,
328
+ )
345
329
  else:
346
- result = self.cursor.execute(
347
- self._unfiltered_search, (q, k), prepare=True, binary=True
348
- )
330
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
349
331
 
350
332
  return [int(i[0]) for i in result.fetchall()]
@@ -1,9 +1,11 @@
1
- from typing import Annotated, Optional, Unpack
1
+ import os
2
+ from typing import Annotated, Unpack
2
3
 
3
4
  import click
4
- import os
5
5
  from pydantic import SecretStr
6
6
 
7
+ from vectordb_bench.backend.clients import DB
8
+
7
9
  from ....cli.cli import (
8
10
  CommonTypedDict,
9
11
  HNSWFlavor1,
@@ -12,12 +14,12 @@ from ....cli.cli import (
12
14
  click_parameter_decorators_from_typed_dict,
13
15
  run,
14
16
  )
15
- from vectordb_bench.backend.clients import DB
16
17
 
17
18
 
18
19
  class PgVectoRSTypedDict(CommonTypedDict):
19
20
  user_name: Annotated[
20
- str, click.option("--user-name", type=str, help="Db username", required=True)
21
+ str,
22
+ click.option("--user-name", type=str, help="Db username", required=True),
21
23
  ]
22
24
  password: Annotated[
23
25
  str,
@@ -30,14 +32,10 @@ class PgVectoRSTypedDict(CommonTypedDict):
30
32
  ),
31
33
  ]
32
34
 
33
- host: Annotated[
34
- str, click.option("--host", type=str, help="Db host", required=True)
35
- ]
36
- db_name: Annotated[
37
- str, click.option("--db-name", type=str, help="Db name", required=True)
38
- ]
35
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
36
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
39
37
  max_parallel_workers: Annotated[
40
- Optional[int],
38
+ int | None,
41
39
  click.option(
42
40
  "--max-parallel-workers",
43
41
  type=int,
@@ -1,11 +1,11 @@
1
1
  from abc import abstractmethod
2
2
  from typing import TypedDict
3
3
 
4
+ from pgvecto_rs.types import Flat, Hnsw, IndexOption, Ivf, Quantization
5
+ from pgvecto_rs.types.index import QuantizationRatio, QuantizationType
4
6
  from pydantic import BaseModel, SecretStr
5
- from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
6
- from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
7
7
 
8
- from ..api import DBConfig, DBCaseConfig, IndexType, MetricType
8
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
9
9
 
10
10
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
11
11
 
@@ -52,14 +52,14 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
52
52
  def parse_metric(self) -> str:
53
53
  if self.metric_type == MetricType.L2:
54
54
  return "vector_l2_ops"
55
- elif self.metric_type == MetricType.IP:
55
+ if self.metric_type == MetricType.IP:
56
56
  return "vector_dot_ops"
57
57
  return "vector_cos_ops"
58
58
 
59
59
  def parse_metric_fun_op(self) -> str:
60
60
  if self.metric_type == MetricType.L2:
61
61
  return "<->"
62
- elif self.metric_type == MetricType.IP:
62
+ if self.metric_type == MetricType.IP:
63
63
  return "<#>"
64
64
  return "<=>"
65
65
 
@@ -85,9 +85,7 @@ class PgVectoRSHNSWConfig(PgVectoRSIndexConfig):
85
85
  if self.quantization_type is None:
86
86
  quantization = None
87
87
  else:
88
- quantization = Quantization(
89
- typ=self.quantization_type, ratio=self.quantization_ratio
90
- )
88
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
91
89
 
92
90
  option = IndexOption(
93
91
  index=Hnsw(
@@ -115,9 +113,7 @@ class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig):
115
113
  if self.quantization_type is None:
116
114
  quantization = None
117
115
  else:
118
- quantization = Quantization(
119
- typ=self.quantization_type, ratio=self.quantization_ratio
120
- )
116
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
121
117
 
122
118
  option = IndexOption(
123
119
  index=Ivf(nlist=self.lists, quantization=quantization),
@@ -139,9 +135,7 @@ class PgVectoRSFLATConfig(PgVectoRSIndexConfig):
139
135
  if self.quantization_type is None:
140
136
  quantization = None
141
137
  else:
142
- quantization = Quantization(
143
- typ=self.quantization_type, ratio=self.quantization_ratio
144
- )
138
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
145
139
 
146
140
  option = IndexOption(
147
141
  index=Flat(
@@ -1,14 +1,14 @@
1
1
  """Wrapper around the Pgvecto.rs vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
10
- from psycopg import Connection, Cursor, sql
11
10
  from pgvecto_rs.psycopg import register_vector
11
+ from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
14
  from .config import PgVectoRSConfig, PgVectoRSIndexConfig
@@ -33,7 +33,6 @@ class PgVectoRS(VectorDB):
33
33
  drop_old: bool = False,
34
34
  **kwargs,
35
35
  ):
36
-
37
36
  self.name = "PgVectorRS"
38
37
  self.db_config = db_config
39
38
  self.case_config = db_case_config
@@ -52,13 +51,14 @@ class PgVectoRS(VectorDB):
52
51
  (
53
52
  self.case_config.create_index_before_load,
54
53
  self.case_config.create_index_after_load,
55
- )
54
+ ),
56
55
  ):
57
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
- log.error(err)
59
- raise RuntimeError(
60
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
56
+ msg = (
57
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
61
59
  )
60
+ log.error(msg)
61
+ raise RuntimeError(msg)
62
62
 
63
63
  if drop_old:
64
64
  log.info(f"Pgvecto.rs client drop table : {self.table_name}")
@@ -74,7 +74,7 @@ class PgVectoRS(VectorDB):
74
74
  self.conn = None
75
75
 
76
76
  @staticmethod
77
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
77
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
78
78
  conn = psycopg.connect(**kwargs)
79
79
 
80
80
  # create vector extension
@@ -116,21 +116,21 @@ class PgVectoRS(VectorDB):
116
116
  self._filtered_search = sql.Composed(
117
117
  [
118
118
  sql.SQL(
119
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
119
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
120
120
  ).format(table_name=sql.Identifier(self.table_name)),
121
121
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
122
122
  sql.SQL(" %s::vector LIMIT %s::int"),
123
- ]
123
+ ],
124
124
  )
125
125
 
126
126
  self._unfiltered_search = sql.Composed(
127
127
  [
128
- sql.SQL(
129
- "SELECT id FROM public.{table_name} ORDER BY embedding "
130
- ).format(table_name=sql.Identifier(self.table_name)),
128
+ sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
129
+ table_name=sql.Identifier(self.table_name),
130
+ ),
131
131
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
132
132
  sql.SQL(" %s::vector LIMIT %s::int"),
133
- ]
133
+ ],
134
134
  )
135
135
 
136
136
  try:
@@ -148,8 +148,8 @@ class PgVectoRS(VectorDB):
148
148
 
149
149
  self.cursor.execute(
150
150
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
151
- table_name=sql.Identifier(self.table_name)
152
- )
151
+ table_name=sql.Identifier(self.table_name),
152
+ ),
153
153
  )
154
154
  self.conn.commit()
155
155
 
@@ -171,7 +171,7 @@ class PgVectoRS(VectorDB):
171
171
  log.info(f"{self.name} client drop index : {self._index_name}")
172
172
 
173
173
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
174
- index_name=sql.Identifier(self._index_name)
174
+ index_name=sql.Identifier(self._index_name),
175
175
  )
176
176
  log.debug(drop_index_sql.as_string(self.cursor))
177
177
  self.cursor.execute(drop_index_sql)
@@ -186,9 +186,9 @@ class PgVectoRS(VectorDB):
186
186
 
187
187
  index_create_sql = sql.SQL(
188
188
  """
189
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
189
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
190
190
  USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
191
- """
191
+ """,
192
192
  ).format(
193
193
  index_name=sql.Identifier(self._index_name),
194
194
  table_name=sql.Identifier(self.table_name),
@@ -202,7 +202,7 @@ class PgVectoRS(VectorDB):
202
202
  except Exception as e:
203
203
  log.warning(
204
204
  f"Failed to create pgvecto.rs index {self._index_name} \
205
- at table {self.table_name} error: {e}"
205
+ at table {self.table_name} error: {e}",
206
206
  )
207
207
  raise e from None
208
208
 
@@ -214,7 +214,7 @@ class PgVectoRS(VectorDB):
214
214
  """
215
215
  CREATE TABLE IF NOT EXISTS public.{table_name}
216
216
  (id BIGINT PRIMARY KEY, embedding vector({dim}))
217
- """
217
+ """,
218
218
  ).format(
219
219
  table_name=sql.Identifier(self.table_name),
220
220
  dim=dim,
@@ -224,9 +224,7 @@ class PgVectoRS(VectorDB):
224
224
  self.cursor.execute(table_create_sql)
225
225
  self.conn.commit()
226
226
  except Exception as e:
227
- log.warning(
228
- f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
229
- )
227
+ log.warning(f"Failed to create pgvecto.rs table: {self.table_name} error: {e}")
230
228
  raise e from None
231
229
 
232
230
  def insert_embeddings(
@@ -234,7 +232,7 @@ class PgVectoRS(VectorDB):
234
232
  embeddings: list[list[float]],
235
233
  metadata: list[int],
236
234
  **kwargs: Any,
237
- ) -> Tuple[int, Optional[Exception]]:
235
+ ) -> tuple[int, Exception | None]:
238
236
  assert self.conn is not None, "Connection is not initialized"
239
237
  assert self.cursor is not None, "Cursor is not initialized"
240
238
 
@@ -247,8 +245,8 @@ class PgVectoRS(VectorDB):
247
245
 
248
246
  with self.cursor.copy(
249
247
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
250
- table_name=sql.Identifier(self.table_name)
251
- )
248
+ table_name=sql.Identifier(self.table_name),
249
+ ),
252
250
  ) as copy:
253
251
  copy.set_types(["bigint", "vector"])
254
252
  for i, row in enumerate(metadata_arr):
@@ -261,7 +259,7 @@ class PgVectoRS(VectorDB):
261
259
  return len(metadata), None
262
260
  except Exception as e:
263
261
  log.warning(
264
- f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
262
+ f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}",
265
263
  )
266
264
  return 0, e
267
265
 
@@ -281,12 +279,13 @@ class PgVectoRS(VectorDB):
281
279
  log.debug(self._filtered_search.as_string(self.cursor))
282
280
  gt = filters.get("id")
283
281
  result = self.cursor.execute(
284
- self._filtered_search, (gt, q, k), prepare=True, binary=True
282
+ self._filtered_search,
283
+ (gt, q, k),
284
+ prepare=True,
285
+ binary=True,
285
286
  )
286
287
  else:
287
288
  log.debug(self._unfiltered_search.as_string(self.cursor))
288
- result = self.cursor.execute(
289
- self._unfiltered_search, (q, k), prepare=True, binary=True
290
- )
289
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
291
290
 
292
291
  return [int(i[0]) for i in result.fetchall()]