vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +56 -46
  5. vectordb_bench/backend/clients/__init__.py +101 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +52 -35
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +8 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +62 -80
  26. vectordb_bench/backend/clients/milvus/config.py +31 -7
  27. vectordb_bench/backend/clients/milvus/milvus.py +23 -26
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +51 -23
  62. vectordb_bench/backend/runner/read_write_runner.py +140 -46
  63. vectordb_bench/backend/runner/serial_runner.py +99 -50
  64. vectordb_bench/backend/runner/util.py +4 -19
  65. vectordb_bench/backend/task_runner.py +95 -74
  66. vectordb_bench/backend/utils.py +17 -9
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.18.dist-info/RECORD +0 -131
  103. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,14 @@
1
1
  """Wrapper around the Pgvecto.rs vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
10
- from psycopg import Connection, Cursor, sql
11
10
  from pgvecto_rs.psycopg import register_vector
11
+ from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
14
  from .config import PgVectoRSConfig, PgVectoRSIndexConfig
@@ -33,7 +33,6 @@ class PgVectoRS(VectorDB):
33
33
  drop_old: bool = False,
34
34
  **kwargs,
35
35
  ):
36
-
37
36
  self.name = "PgVectorRS"
38
37
  self.db_config = db_config
39
38
  self.case_config = db_case_config
@@ -52,13 +51,14 @@ class PgVectoRS(VectorDB):
52
51
  (
53
52
  self.case_config.create_index_before_load,
54
53
  self.case_config.create_index_after_load,
55
- )
54
+ ),
56
55
  ):
57
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
- log.error(err)
59
- raise RuntimeError(
60
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
56
+ msg = (
57
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
58
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
61
59
  )
60
+ log.error(msg)
61
+ raise RuntimeError(msg)
62
62
 
63
63
  if drop_old:
64
64
  log.info(f"Pgvecto.rs client drop table : {self.table_name}")
@@ -74,7 +74,7 @@ class PgVectoRS(VectorDB):
74
74
  self.conn = None
75
75
 
76
76
  @staticmethod
77
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
77
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
78
78
  conn = psycopg.connect(**kwargs)
79
79
 
80
80
  # create vector extension
@@ -116,21 +116,21 @@ class PgVectoRS(VectorDB):
116
116
  self._filtered_search = sql.Composed(
117
117
  [
118
118
  sql.SQL(
119
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
119
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
120
120
  ).format(table_name=sql.Identifier(self.table_name)),
121
121
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
122
122
  sql.SQL(" %s::vector LIMIT %s::int"),
123
- ]
123
+ ],
124
124
  )
125
125
 
126
126
  self._unfiltered_search = sql.Composed(
127
127
  [
128
- sql.SQL(
129
- "SELECT id FROM public.{table_name} ORDER BY embedding "
130
- ).format(table_name=sql.Identifier(self.table_name)),
128
+ sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
129
+ table_name=sql.Identifier(self.table_name),
130
+ ),
131
131
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
132
132
  sql.SQL(" %s::vector LIMIT %s::int"),
133
- ]
133
+ ],
134
134
  )
135
135
 
136
136
  try:
@@ -148,8 +148,8 @@ class PgVectoRS(VectorDB):
148
148
 
149
149
  self.cursor.execute(
150
150
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
151
- table_name=sql.Identifier(self.table_name)
152
- )
151
+ table_name=sql.Identifier(self.table_name),
152
+ ),
153
153
  )
154
154
  self.conn.commit()
155
155
 
@@ -171,7 +171,7 @@ class PgVectoRS(VectorDB):
171
171
  log.info(f"{self.name} client drop index : {self._index_name}")
172
172
 
173
173
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
174
- index_name=sql.Identifier(self._index_name)
174
+ index_name=sql.Identifier(self._index_name),
175
175
  )
176
176
  log.debug(drop_index_sql.as_string(self.cursor))
177
177
  self.cursor.execute(drop_index_sql)
@@ -186,9 +186,9 @@ class PgVectoRS(VectorDB):
186
186
 
187
187
  index_create_sql = sql.SQL(
188
188
  """
189
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
189
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
190
190
  USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
191
- """
191
+ """,
192
192
  ).format(
193
193
  index_name=sql.Identifier(self._index_name),
194
194
  table_name=sql.Identifier(self.table_name),
@@ -202,7 +202,7 @@ class PgVectoRS(VectorDB):
202
202
  except Exception as e:
203
203
  log.warning(
204
204
  f"Failed to create pgvecto.rs index {self._index_name} \
205
- at table {self.table_name} error: {e}"
205
+ at table {self.table_name} error: {e}",
206
206
  )
207
207
  raise e from None
208
208
 
@@ -214,7 +214,7 @@ class PgVectoRS(VectorDB):
214
214
  """
215
215
  CREATE TABLE IF NOT EXISTS public.{table_name}
216
216
  (id BIGINT PRIMARY KEY, embedding vector({dim}))
217
- """
217
+ """,
218
218
  ).format(
219
219
  table_name=sql.Identifier(self.table_name),
220
220
  dim=dim,
@@ -224,9 +224,7 @@ class PgVectoRS(VectorDB):
224
224
  self.cursor.execute(table_create_sql)
225
225
  self.conn.commit()
226
226
  except Exception as e:
227
- log.warning(
228
- f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
229
- )
227
+ log.warning(f"Failed to create pgvecto.rs table: {self.table_name} error: {e}")
230
228
  raise e from None
231
229
 
232
230
  def insert_embeddings(
@@ -234,7 +232,7 @@ class PgVectoRS(VectorDB):
234
232
  embeddings: list[list[float]],
235
233
  metadata: list[int],
236
234
  **kwargs: Any,
237
- ) -> Tuple[int, Optional[Exception]]:
235
+ ) -> tuple[int, Exception | None]:
238
236
  assert self.conn is not None, "Connection is not initialized"
239
237
  assert self.cursor is not None, "Cursor is not initialized"
240
238
 
@@ -247,8 +245,8 @@ class PgVectoRS(VectorDB):
247
245
 
248
246
  with self.cursor.copy(
249
247
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
250
- table_name=sql.Identifier(self.table_name)
251
- )
248
+ table_name=sql.Identifier(self.table_name),
249
+ ),
252
250
  ) as copy:
253
251
  copy.set_types(["bigint", "vector"])
254
252
  for i, row in enumerate(metadata_arr):
@@ -261,7 +259,7 @@ class PgVectoRS(VectorDB):
261
259
  return len(metadata), None
262
260
  except Exception as e:
263
261
  log.warning(
264
- f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
262
+ f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}",
265
263
  )
266
264
  return 0, e
267
265
 
@@ -281,12 +279,13 @@ class PgVectoRS(VectorDB):
281
279
  log.debug(self._filtered_search.as_string(self.cursor))
282
280
  gt = filters.get("id")
283
281
  result = self.cursor.execute(
284
- self._filtered_search, (gt, q, k), prepare=True, binary=True
282
+ self._filtered_search,
283
+ (gt, q, k),
284
+ prepare=True,
285
+ binary=True,
285
286
  )
286
287
  else:
287
288
  log.debug(self._unfiltered_search.as_string(self.cursor))
288
- result = self.cursor.execute(
289
- self._unfiltered_search, (q, k), prepare=True, binary=True
290
- )
289
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
291
290
 
292
291
  return [int(i[0]) for i in result.fetchall()]
@@ -1,9 +1,10 @@
1
- from typing import Annotated, Optional, TypedDict, Unpack
1
+ import os
2
+ from typing import Annotated, Unpack
2
3
 
3
4
  import click
4
- import os
5
5
  from pydantic import SecretStr
6
6
 
7
+ from vectordb_bench.backend.clients import DB
7
8
  from vectordb_bench.backend.clients.api import MetricType
8
9
 
9
10
  from ....cli.cli import (
@@ -15,39 +16,48 @@ from ....cli.cli import (
15
16
  get_custom_case_config,
16
17
  run,
17
18
  )
18
- from vectordb_bench.backend.clients import DB
19
19
 
20
20
 
21
- def set_default_quantized_fetch_limit(ctx, param, value):
21
+ # ruff: noqa
22
+ def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
22
23
  if ctx.params.get("reranking") and value is None:
23
24
  # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
24
25
  # 100 is default value for quantized_fetch_limit for IVFFlat.
25
- default_value = ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100
26
- return default_value
26
+ return ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100
27
27
  return value
28
28
 
29
+
29
30
  class PgVectorTypedDict(CommonTypedDict):
30
31
  user_name: Annotated[
31
- str, click.option("--user-name", type=str, help="Db username", required=True)
32
+ str,
33
+ click.option("--user-name", type=str, help="Db username", required=True),
32
34
  ]
33
35
  password: Annotated[
34
36
  str,
35
- click.option("--password",
36
- type=str,
37
- help="Postgres database password",
38
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
39
- show_default="$POSTGRES_PASSWORD",
40
- ),
37
+ click.option(
38
+ "--password",
39
+ type=str,
40
+ help="Postgres database password",
41
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
42
+ show_default="$POSTGRES_PASSWORD",
43
+ ),
41
44
  ]
42
45
 
43
- host: Annotated[
44
- str, click.option("--host", type=str, help="Db host", required=True)
45
- ]
46
- db_name: Annotated[
47
- str, click.option("--db-name", type=str, help="Db name", required=True)
46
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
47
+ port: Annotated[
48
+ int,
49
+ click.option(
50
+ "--port",
51
+ type=int,
52
+ help="Postgres database port",
53
+ default=5432,
54
+ show_default=True,
55
+ required=False,
56
+ ),
48
57
  ]
58
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
49
59
  maintenance_work_mem: Annotated[
50
- Optional[str],
60
+ str | None,
51
61
  click.option(
52
62
  "--maintenance-work-mem",
53
63
  type=str,
@@ -59,7 +69,7 @@ class PgVectorTypedDict(CommonTypedDict):
59
69
  ),
60
70
  ]
61
71
  max_parallel_workers: Annotated[
62
- Optional[int],
72
+ int | None,
63
73
  click.option(
64
74
  "--max-parallel-workers",
65
75
  type=int,
@@ -68,7 +78,7 @@ class PgVectorTypedDict(CommonTypedDict):
68
78
  ),
69
79
  ]
70
80
  quantization_type: Annotated[
71
- Optional[str],
81
+ str | None,
72
82
  click.option(
73
83
  "--quantization-type",
74
84
  type=click.Choice(["none", "bit", "halfvec"]),
@@ -77,7 +87,7 @@ class PgVectorTypedDict(CommonTypedDict):
77
87
  ),
78
88
  ]
79
89
  reranking: Annotated[
80
- Optional[bool],
90
+ bool | None,
81
91
  click.option(
82
92
  "--reranking/--skip-reranking",
83
93
  type=bool,
@@ -86,11 +96,11 @@ class PgVectorTypedDict(CommonTypedDict):
86
96
  ),
87
97
  ]
88
98
  reranking_metric: Annotated[
89
- Optional[str],
99
+ str | None,
90
100
  click.option(
91
101
  "--reranking-metric",
92
102
  type=click.Choice(
93
- [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]]
103
+ [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]],
94
104
  ),
95
105
  help="Distance metric for reranking",
96
106
  default="COSINE",
@@ -98,7 +108,7 @@ class PgVectorTypedDict(CommonTypedDict):
98
108
  ),
99
109
  ]
100
110
  quantized_fetch_limit: Annotated[
101
- Optional[int],
111
+ int | None,
102
112
  click.option(
103
113
  "--quantized-fetch-limit",
104
114
  type=int,
@@ -106,13 +116,11 @@ class PgVectorTypedDict(CommonTypedDict):
106
116
  -- bound by ef_search",
107
117
  required=False,
108
118
  callback=set_default_quantized_fetch_limit,
109
- )
119
+ ),
110
120
  ]
111
121
 
112
-
113
122
 
114
- class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
115
- ...
123
+ class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): ...
116
124
 
117
125
 
118
126
  @cli.command()
@@ -130,6 +138,7 @@ def PgVectorIVFFlat(
130
138
  user_name=SecretStr(parameters["user_name"]),
131
139
  password=SecretStr(parameters["password"]),
132
140
  host=parameters["host"],
141
+ port=parameters["port"],
133
142
  db_name=parameters["db_name"],
134
143
  ),
135
144
  db_case_config=PgVectorIVFFlatConfig(
@@ -145,8 +154,7 @@ def PgVectorIVFFlat(
145
154
  )
146
155
 
147
156
 
148
- class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1):
149
- ...
157
+ class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): ...
150
158
 
151
159
 
152
160
  @cli.command()
@@ -164,6 +172,7 @@ def PgVectorHNSW(
164
172
  user_name=SecretStr(parameters["user_name"]),
165
173
  password=SecretStr(parameters["password"]),
166
174
  host=parameters["host"],
175
+ port=parameters["port"],
167
176
  db_name=parameters["db_name"],
168
177
  ),
169
178
  db_case_config=PgVectorHNSWConfig(
@@ -1,7 +1,9 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Mapping, Optional, Sequence, TypedDict
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import Any, LiteralString, TypedDict
4
+
3
5
  from pydantic import BaseModel, SecretStr
4
- from typing_extensions import LiteralString
6
+
5
7
  from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
8
 
7
9
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
9
11
 
10
12
  class PgVectorConfigDict(TypedDict):
11
13
  """These keys will be directly used as kwargs in psycopg connection string,
12
- so the names must match exactly psycopg API"""
14
+ so the names must match exactly psycopg API"""
13
15
 
14
16
  user: str
15
17
  password: str
@@ -41,8 +43,8 @@ class PgVectorIndexParam(TypedDict):
41
43
  metric: str
42
44
  index_type: str
43
45
  index_creation_with_options: Sequence[dict[str, Any]]
44
- maintenance_work_mem: Optional[str]
45
- max_parallel_workers: Optional[int]
46
+ maintenance_work_mem: str | None
47
+ max_parallel_workers: int | None
46
48
 
47
49
 
48
50
  class PgVectorSearchParam(TypedDict):
@@ -59,61 +61,60 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
59
61
  create_index_after_load: bool = True
60
62
 
61
63
  def parse_metric(self) -> str:
62
- if self.quantization_type == "halfvec":
63
- if self.metric_type == MetricType.L2:
64
- return "halfvec_l2_ops"
65
- elif self.metric_type == MetricType.IP:
66
- return "halfvec_ip_ops"
67
- return "halfvec_cosine_ops"
68
- elif self.quantization_type == "bit":
69
- if self.metric_type == MetricType.JACCARD:
70
- return "bit_jaccard_ops"
71
- return "bit_hamming_ops"
72
- else:
73
- if self.metric_type == MetricType.L2:
74
- return "vector_l2_ops"
75
- elif self.metric_type == MetricType.IP:
76
- return "vector_ip_ops"
77
- return "vector_cosine_ops"
64
+ d = {
65
+ "halfvec": {
66
+ MetricType.L2: "halfvec_l2_ops",
67
+ MetricType.IP: "halfvec_ip_ops",
68
+ MetricType.COSINE: "halfvec_cosine_ops",
69
+ },
70
+ "bit": {
71
+ MetricType.JACCARD: "bit_jaccard_ops",
72
+ MetricType.HAMMING: "bit_hamming_ops",
73
+ },
74
+ "_fallback": {
75
+ MetricType.L2: "vector_l2_ops",
76
+ MetricType.IP: "vector_ip_ops",
77
+ MetricType.COSINE: "vector_cosine_ops",
78
+ },
79
+ }
80
+
81
+ if d.get(self.quantization_type) is None:
82
+ return d.get("_fallback").get(self.metric_type)
83
+ return d.get(self.quantization_type).get(self.metric_type)
78
84
 
79
85
  def parse_metric_fun_op(self) -> LiteralString:
80
86
  if self.quantization_type == "bit":
81
87
  if self.metric_type == MetricType.JACCARD:
82
88
  return "<%>"
83
89
  return "<~>"
84
- else:
85
- if self.metric_type == MetricType.L2:
86
- return "<->"
87
- elif self.metric_type == MetricType.IP:
88
- return "<#>"
89
- return "<=>"
90
+ if self.metric_type == MetricType.L2:
91
+ return "<->"
92
+ if self.metric_type == MetricType.IP:
93
+ return "<#>"
94
+ return "<=>"
90
95
 
91
96
  def parse_metric_fun_str(self) -> str:
92
97
  if self.metric_type == MetricType.L2:
93
98
  return "l2_distance"
94
- elif self.metric_type == MetricType.IP:
99
+ if self.metric_type == MetricType.IP:
95
100
  return "max_inner_product"
96
101
  return "cosine_distance"
97
-
102
+
98
103
  def parse_reranking_metric_fun_op(self) -> LiteralString:
99
104
  if self.reranking_metric == MetricType.L2:
100
105
  return "<->"
101
- elif self.reranking_metric == MetricType.IP:
106
+ if self.reranking_metric == MetricType.IP:
102
107
  return "<#>"
103
108
  return "<=>"
104
109
 
105
-
106
110
  @abstractmethod
107
- def index_param(self) -> PgVectorIndexParam:
108
- ...
111
+ def index_param(self) -> PgVectorIndexParam: ...
109
112
 
110
113
  @abstractmethod
111
- def search_param(self) -> PgVectorSearchParam:
112
- ...
114
+ def search_param(self) -> PgVectorSearchParam: ...
113
115
 
114
116
  @abstractmethod
115
- def session_param(self) -> PgVectorSessionCommands:
116
- ...
117
+ def session_param(self) -> PgVectorSessionCommands: ...
117
118
 
118
119
  @staticmethod
119
120
  def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
@@ -125,24 +126,23 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
125
126
  {
126
127
  "option_name": option_name,
127
128
  "val": str(value),
128
- }
129
+ },
129
130
  )
130
131
  return options
131
132
 
132
133
  @staticmethod
133
- def _optionally_build_set_options(
134
- set_mapping: Mapping[str, Any]
135
- ) -> Sequence[dict[str, Any]]:
134
+ def _optionally_build_set_options(set_mapping: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
136
135
  """Walk through options, creating 'SET 'key1 = "value1";' list"""
137
136
  session_options = []
138
137
  for setting_name, value in set_mapping.items():
139
138
  if value:
140
139
  session_options.append(
141
- {"parameter": {
140
+ {
141
+ "parameter": {
142
142
  "setting_name": setting_name,
143
143
  "val": str(value),
144
144
  },
145
- }
145
+ },
146
146
  )
147
147
  return session_options
148
148
 
@@ -165,12 +165,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
165
165
  lists: int | None
166
166
  probes: int | None
167
167
  index: IndexType = IndexType.ES_IVFFlat
168
- maintenance_work_mem: Optional[str] = None
169
- max_parallel_workers: Optional[int] = None
170
- quantization_type: Optional[str] = None
171
- reranking: Optional[bool] = None
172
- quantized_fetch_limit: Optional[int] = None
173
- reranking_metric: Optional[str] = None
168
+ maintenance_work_mem: str | None = None
169
+ max_parallel_workers: int | None = None
170
+ quantization_type: str | None = None
171
+ reranking: bool | None = None
172
+ quantized_fetch_limit: int | None = None
173
+ reranking_metric: str | None = None
174
174
 
175
175
  def index_param(self) -> PgVectorIndexParam:
176
176
  index_parameters = {"lists": self.lists}
@@ -179,9 +179,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
179
179
  return {
180
180
  "metric": self.parse_metric(),
181
181
  "index_type": self.index.value,
182
- "index_creation_with_options": self._optionally_build_with_options(
183
- index_parameters
184
- ),
182
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
185
183
  "maintenance_work_mem": self.maintenance_work_mem,
186
184
  "max_parallel_workers": self.max_parallel_workers,
187
185
  "quantization_type": self.quantization_type,
@@ -197,9 +195,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
197
195
 
198
196
  def session_param(self) -> PgVectorSessionCommands:
199
197
  session_parameters = {"ivfflat.probes": self.probes}
200
- return {
201
- "session_options": self._optionally_build_set_options(session_parameters)
202
- }
198
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
203
199
 
204
200
 
205
201
  class PgVectorHNSWConfig(PgVectorIndexConfig):
@@ -210,17 +206,15 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
210
206
  """
211
207
 
212
208
  m: int | None # DETAIL: Valid values are between "2" and "100".
213
- ef_construction: (
214
- int | None
215
- ) # ef_construction must be greater than or equal to 2 * m
209
+ ef_construction: int | None # ef_construction must be greater than or equal to 2 * m
216
210
  ef_search: int | None
217
211
  index: IndexType = IndexType.ES_HNSW
218
- maintenance_work_mem: Optional[str] = None
219
- max_parallel_workers: Optional[int] = None
220
- quantization_type: Optional[str] = None
221
- reranking: Optional[bool] = None
222
- quantized_fetch_limit: Optional[int] = None
223
- reranking_metric: Optional[str] = None
212
+ maintenance_work_mem: str | None = None
213
+ max_parallel_workers: int | None = None
214
+ quantization_type: str | None = None
215
+ reranking: bool | None = None
216
+ quantized_fetch_limit: int | None = None
217
+ reranking_metric: str | None = None
224
218
 
225
219
  def index_param(self) -> PgVectorIndexParam:
226
220
  index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
@@ -229,9 +223,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
229
223
  return {
230
224
  "metric": self.parse_metric(),
231
225
  "index_type": self.index.value,
232
- "index_creation_with_options": self._optionally_build_with_options(
233
- index_parameters
234
- ),
226
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
235
227
  "maintenance_work_mem": self.maintenance_work_mem,
236
228
  "max_parallel_workers": self.max_parallel_workers,
237
229
  "quantization_type": self.quantization_type,
@@ -247,13 +239,11 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
247
239
 
248
240
  def session_param(self) -> PgVectorSessionCommands:
249
241
  session_parameters = {"hnsw.ef_search": self.ef_search}
250
- return {
251
- "session_options": self._optionally_build_set_options(session_parameters)
252
- }
242
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
253
243
 
254
244
 
255
245
  _pgvector_case_config = {
256
- IndexType.HNSW: PgVectorHNSWConfig,
257
- IndexType.ES_HNSW: PgVectorHNSWConfig,
258
- IndexType.IVFFlat: PgVectorIVFFlatConfig,
246
+ IndexType.HNSW: PgVectorHNSWConfig,
247
+ IndexType.ES_HNSW: PgVectorHNSWConfig,
248
+ IndexType.IVFFlat: PgVectorIVFFlatConfig,
259
249
  }