vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the alloydb vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator, Sequence
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple, Sequence
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
11
11
  from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
- from .config import AlloyDBConfigDict, AlloyDBIndexConfig, AlloyDBScaNNConfig
14
+ from .config import AlloyDBConfigDict, AlloyDBIndexConfig
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17
 
@@ -56,13 +56,14 @@ class AlloyDB(VectorDB):
56
56
  (
57
57
  self.case_config.create_index_before_load,
58
58
  self.case_config.create_index_after_load,
59
- )
59
+ ),
60
60
  ):
61
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
62
- log.error(err)
63
- raise RuntimeError(
64
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ msg = (
62
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
63
+ "\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
65
64
  )
65
+ log.warning(msg)
66
+ raise RuntimeError(msg)
66
67
 
67
68
  if drop_old:
68
69
  self._drop_index()
@@ -77,7 +78,7 @@ class AlloyDB(VectorDB):
77
78
  self.conn = None
78
79
 
79
80
  @staticmethod
80
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
81
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
81
82
  conn = psycopg.connect(**kwargs)
82
83
  register_vector(conn)
83
84
  conn.autocommit = False
@@ -86,21 +87,20 @@ class AlloyDB(VectorDB):
86
87
  assert conn is not None, "Connection is not initialized"
87
88
  assert cursor is not None, "Cursor is not initialized"
88
89
  return conn, cursor
89
-
90
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
91
- search_query = sql.Composed(
90
+
91
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
92
+ return sql.Composed(
92
93
  [
93
94
  sql.SQL(
94
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
95
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
95
96
  ).format(
96
97
  table_name=sql.Identifier(self.table_name),
97
98
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
98
99
  ),
99
100
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
100
101
  sql.SQL(" %s::vector LIMIT %s::int"),
101
- ]
102
+ ],
102
103
  )
103
- return search_query
104
104
 
105
105
  @contextmanager
106
106
  def init(self) -> Generator[None, None, None]:
@@ -119,8 +119,8 @@ class AlloyDB(VectorDB):
119
119
  if len(session_options) > 0:
120
120
  for setting in session_options:
121
121
  command = sql.SQL("SET {setting_name} " + "= {val};").format(
122
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
123
- val=sql.Identifier(str(setting['parameter']['val'])),
122
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
123
+ val=sql.Identifier(str(setting["parameter"]["val"])),
124
124
  )
125
125
  log.debug(command.as_string(self.cursor))
126
126
  self.cursor.execute(command)
@@ -144,15 +144,12 @@ class AlloyDB(VectorDB):
144
144
 
145
145
  self.cursor.execute(
146
146
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
147
- table_name=sql.Identifier(self.table_name)
148
- )
147
+ table_name=sql.Identifier(self.table_name),
148
+ ),
149
149
  )
150
150
  self.conn.commit()
151
151
 
152
- def ready_to_load(self):
153
- pass
154
-
155
- def optimize(self):
152
+ def optimize(self, data_size: int | None = None):
156
153
  self._post_insert()
157
154
 
158
155
  def _post_insert(self):
@@ -167,7 +164,7 @@ class AlloyDB(VectorDB):
167
164
  log.info(f"{self.name} client drop index : {self._index_name}")
168
165
 
169
166
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
170
- index_name=sql.Identifier(self._index_name)
167
+ index_name=sql.Identifier(self._index_name),
171
168
  )
172
169
  log.debug(drop_index_sql.as_string(self.cursor))
173
170
  self.cursor.execute(drop_index_sql)
@@ -181,78 +178,64 @@ class AlloyDB(VectorDB):
181
178
 
182
179
  if index_param["enable_pca"] is not None:
183
180
  self.cursor.execute(
184
- sql.SQL("SET scann.enable_pca TO {};").format(
185
- index_param["enable_pca"]
186
- )
181
+ sql.SQL("SET scann.enable_pca TO {};").format(index_param["enable_pca"]),
187
182
  )
188
183
  self.cursor.execute(
189
184
  sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format(
190
185
  sql.Identifier(self.db_config["user"]),
191
186
  index_param["enable_pca"],
192
- )
187
+ ),
193
188
  )
194
189
  self.conn.commit()
195
190
 
196
191
  if index_param["maintenance_work_mem"] is not None:
197
192
  self.cursor.execute(
198
193
  sql.SQL("SET maintenance_work_mem TO {};").format(
199
- index_param["maintenance_work_mem"]
200
- )
194
+ index_param["maintenance_work_mem"],
195
+ ),
201
196
  )
202
197
  self.cursor.execute(
203
198
  sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
204
199
  sql.Identifier(self.db_config["user"]),
205
200
  index_param["maintenance_work_mem"],
206
- )
201
+ ),
207
202
  )
208
203
  self.conn.commit()
209
204
 
210
205
  if index_param["max_parallel_workers"] is not None:
211
206
  self.cursor.execute(
212
207
  sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
213
- index_param["max_parallel_workers"]
214
- )
208
+ index_param["max_parallel_workers"],
209
+ ),
215
210
  )
216
211
  self.cursor.execute(
217
- sql.SQL(
218
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
219
- ).format(
212
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
220
213
  sql.Identifier(self.db_config["user"]),
221
214
  index_param["max_parallel_workers"],
222
- )
215
+ ),
223
216
  )
224
217
  self.cursor.execute(
225
218
  sql.SQL("SET max_parallel_workers TO '{}';").format(
226
- index_param["max_parallel_workers"]
227
- )
219
+ index_param["max_parallel_workers"],
220
+ ),
228
221
  )
229
222
  self.cursor.execute(
230
- sql.SQL(
231
- "ALTER USER {} SET max_parallel_workers TO '{}';"
232
- ).format(
223
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
233
224
  sql.Identifier(self.db_config["user"]),
234
225
  index_param["max_parallel_workers"],
235
- )
226
+ ),
236
227
  )
237
228
  self.cursor.execute(
238
- sql.SQL(
239
- "ALTER TABLE {} SET (parallel_workers = {});"
240
- ).format(
229
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
241
230
  sql.Identifier(self.table_name),
242
231
  index_param["max_parallel_workers"],
243
- )
232
+ ),
244
233
  )
245
234
  self.conn.commit()
246
235
 
247
- results = self.cursor.execute(
248
- sql.SQL("SHOW max_parallel_maintenance_workers;")
249
- ).fetchall()
250
- results.extend(
251
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
252
- )
253
- results.extend(
254
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
255
- )
236
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
237
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
238
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
256
239
  log.info(f"{self.name} parallel index creation parameters: {results}")
257
240
 
258
241
  def _create_index(self):
@@ -264,23 +247,20 @@ class AlloyDB(VectorDB):
264
247
  self._set_parallel_index_build_param()
265
248
  options = []
266
249
  for option in index_param["index_creation_with_options"]:
267
- if option['val'] is not None:
250
+ if option["val"] is not None:
268
251
  options.append(
269
252
  sql.SQL("{option_name} = {val}").format(
270
- option_name=sql.Identifier(option['option_name']),
271
- val=sql.Identifier(str(option['val'])),
272
- )
253
+ option_name=sql.Identifier(option["option_name"]),
254
+ val=sql.Identifier(str(option["val"])),
255
+ ),
273
256
  )
274
- if any(options):
275
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
276
- else:
277
- with_clause = sql.Composed(())
257
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
278
258
 
279
259
  index_create_sql = sql.SQL(
280
260
  """
281
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
261
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
282
262
  USING {index_type} (embedding {embedding_metric})
283
- """
263
+ """,
284
264
  ).format(
285
265
  index_name=sql.Identifier(self._index_name),
286
266
  table_name=sql.Identifier(self.table_name),
@@ -288,9 +268,7 @@ class AlloyDB(VectorDB):
288
268
  embedding_metric=sql.Identifier(index_param["metric"]),
289
269
  )
290
270
 
291
- index_create_sql_with_with_clause = (
292
- index_create_sql + with_clause
293
- ).join(" ")
271
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
294
272
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
295
273
  self.cursor.execute(index_create_sql_with_with_clause)
296
274
  self.conn.commit()
@@ -305,14 +283,12 @@ class AlloyDB(VectorDB):
305
283
  # create table
306
284
  self.cursor.execute(
307
285
  sql.SQL(
308
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
309
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
286
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
287
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
310
288
  )
311
289
  self.conn.commit()
312
290
  except Exception as e:
313
- log.warning(
314
- f"Failed to create alloydb table: {self.table_name} error: {e}"
315
- )
291
+ log.warning(f"Failed to create alloydb table: {self.table_name} error: {e}")
316
292
  raise e from None
317
293
 
318
294
  def insert_embeddings(
@@ -320,7 +296,7 @@ class AlloyDB(VectorDB):
320
296
  embeddings: list[list[float]],
321
297
  metadata: list[int],
322
298
  **kwargs: Any,
323
- ) -> Tuple[int, Optional[Exception]]:
299
+ ) -> tuple[int, Exception | None]:
324
300
  assert self.conn is not None, "Connection is not initialized"
325
301
  assert self.cursor is not None, "Cursor is not initialized"
326
302
 
@@ -330,8 +306,8 @@ class AlloyDB(VectorDB):
330
306
 
331
307
  with self.cursor.copy(
332
308
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
333
- table_name=sql.Identifier(self.table_name)
334
- )
309
+ table_name=sql.Identifier(self.table_name),
310
+ ),
335
311
  ) as copy:
336
312
  copy.set_types(["bigint", "vector"])
337
313
  for i, row in enumerate(metadata_arr):
@@ -343,9 +319,7 @@ class AlloyDB(VectorDB):
343
319
 
344
320
  return len(metadata), None
345
321
  except Exception as e:
346
- log.warning(
347
- f"Failed to insert data into alloydb table ({self.table_name}), error: {e}"
348
- )
322
+ log.warning(f"Failed to insert data into alloydb table ({self.table_name}), error: {e}")
349
323
  return 0, e
350
324
 
351
325
  def search_embedding(
@@ -362,11 +336,12 @@ class AlloyDB(VectorDB):
362
336
  if filters:
363
337
  gt = filters.get("id")
364
338
  result = self.cursor.execute(
365
- self._filtered_search, (gt, q, k), prepare=True, binary=True
339
+ self._filtered_search,
340
+ (gt, q, k),
341
+ prepare=True,
342
+ binary=True,
366
343
  )
367
344
  else:
368
- result = self.cursor.execute(
369
- self._unfiltered_search, (q, k), prepare=True, binary=True
370
- )
345
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
371
346
 
372
347
  return [int(i[0]) for i in result.fetchall()]
@@ -1,10 +1,10 @@
1
- from typing import Annotated, Optional, TypedDict, Unpack
1
+ import os
2
+ from typing import Annotated, Unpack
2
3
 
3
4
  import click
4
- import os
5
5
  from pydantic import SecretStr
6
6
 
7
- from vectordb_bench.backend.clients.api import MetricType
7
+ from vectordb_bench.backend.clients import DB
8
8
 
9
9
  from ....cli.cli import (
10
10
  CommonTypedDict,
@@ -13,31 +13,28 @@ from ....cli.cli import (
13
13
  get_custom_case_config,
14
14
  run,
15
15
  )
16
- from vectordb_bench.backend.clients import DB
17
16
 
18
17
 
19
18
  class AlloyDBTypedDict(CommonTypedDict):
20
19
  user_name: Annotated[
21
- str, click.option("--user-name", type=str, help="Db username", required=True)
20
+ str,
21
+ click.option("--user-name", type=str, help="Db username", required=True),
22
22
  ]
23
23
  password: Annotated[
24
24
  str,
25
- click.option("--password",
26
- type=str,
27
- help="Postgres database password",
28
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
29
- show_default="$POSTGRES_PASSWORD",
30
- ),
25
+ click.option(
26
+ "--password",
27
+ type=str,
28
+ help="Postgres database password",
29
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
30
+ show_default="$POSTGRES_PASSWORD",
31
+ ),
31
32
  ]
32
33
 
33
- host: Annotated[
34
- str, click.option("--host", type=str, help="Db host", required=True)
35
- ]
36
- db_name: Annotated[
37
- str, click.option("--db-name", type=str, help="Db name", required=True)
38
- ]
34
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
35
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
39
36
  maintenance_work_mem: Annotated[
40
- Optional[str],
37
+ str | None,
41
38
  click.option(
42
39
  "--maintenance-work-mem",
43
40
  type=str,
@@ -49,7 +46,7 @@ class AlloyDBTypedDict(CommonTypedDict):
49
46
  ),
50
47
  ]
51
48
  max_parallel_workers: Annotated[
52
- Optional[int],
49
+ int | None,
53
50
  click.option(
54
51
  "--max-parallel-workers",
55
52
  type=int,
@@ -58,32 +55,51 @@ class AlloyDBTypedDict(CommonTypedDict):
58
55
  ),
59
56
  ]
60
57
 
61
-
62
58
 
63
59
  class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
64
60
  num_leaves: Annotated[
65
61
  int,
66
- click.option("--num-leaves", type=int, help="Number of leaves", required=True)
62
+ click.option("--num-leaves", type=int, help="Number of leaves", required=True),
67
63
  ]
68
64
  num_leaves_to_search: Annotated[
69
65
  int,
70
- click.option("--num-leaves-to-search", type=int, help="Number of leaves to search", required=True)
66
+ click.option(
67
+ "--num-leaves-to-search",
68
+ type=int,
69
+ help="Number of leaves to search",
70
+ required=True,
71
+ ),
71
72
  ]
72
73
  pre_reordering_num_neighbors: Annotated[
73
74
  int,
74
- click.option("--pre-reordering-num-neighbors", type=int, help="Pre-reordering number of neighbors", default=200)
75
+ click.option(
76
+ "--pre-reordering-num-neighbors",
77
+ type=int,
78
+ help="Pre-reordering number of neighbors",
79
+ default=200,
80
+ ),
75
81
  ]
76
82
  max_top_neighbors_buffer_size: Annotated[
77
83
  int,
78
- click.option("--max-top-neighbors-buffer-size", type=int, help="Maximum top neighbors buffer size", default=20_000)
84
+ click.option(
85
+ "--max-top-neighbors-buffer-size",
86
+ type=int,
87
+ help="Maximum top neighbors buffer size",
88
+ default=20_000,
89
+ ),
79
90
  ]
80
91
  num_search_threads: Annotated[
81
92
  int,
82
- click.option("--num-search-threads", type=int, help="Number of search threads", default=2)
93
+ click.option("--num-search-threads", type=int, help="Number of search threads", default=2),
83
94
  ]
84
95
  max_num_prefetch_datasets: Annotated[
85
96
  int,
86
- click.option("--max-num-prefetch-datasets", type=int, help="Maximum number of prefetch datasets", default=100)
97
+ click.option(
98
+ "--max-num-prefetch-datasets",
99
+ type=int,
100
+ help="Maximum number of prefetch datasets",
101
+ default=100,
102
+ ),
87
103
  ]
88
104
  quantizer: Annotated[
89
105
  str,
@@ -91,16 +107,17 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
91
107
  "--quantizer",
92
108
  type=click.Choice(["SQ8", "FLAT"]),
93
109
  help="Quantizer type",
94
- default="SQ8"
95
- )
110
+ default="SQ8",
111
+ ),
96
112
  ]
97
113
  enable_pca: Annotated[
98
- bool, click.option(
114
+ bool,
115
+ click.option(
99
116
  "--enable-pca",
100
117
  type=click.Choice(["on", "off"]),
101
118
  help="Enable PCA",
102
- default="on"
103
- )
119
+ default="on",
120
+ ),
104
121
  ]
105
122
  max_num_levels: Annotated[
106
123
  int,
@@ -108,8 +125,8 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
108
125
  "--max-num-levels",
109
126
  type=click.Choice(["1", "2"]),
110
127
  help="Maximum number of levels",
111
- default=1
112
- )
128
+ default=1,
129
+ ),
113
130
  ]
114
131
 
115
132
 
@@ -144,4 +161,4 @@ def AlloyDBScaNN(
144
161
  maintenance_work_mem=parameters["maintenance_work_mem"],
145
162
  ),
146
163
  **parameters,
147
- )
164
+ )
@@ -1,7 +1,9 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Mapping, Optional, Sequence, TypedDict
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import Any, LiteralString, TypedDict
4
+
3
5
  from pydantic import BaseModel, SecretStr
4
- from typing_extensions import LiteralString
6
+
5
7
  from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
8
 
7
9
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
9
11
 
10
12
  class AlloyDBConfigDict(TypedDict):
11
13
  """These keys will be directly used as kwargs in psycopg connection string,
12
- so the names must match exactly psycopg API"""
14
+ so the names must match exactly psycopg API"""
13
15
 
14
16
  user: str
15
17
  password: str
@@ -41,8 +43,8 @@ class AlloyDBIndexParam(TypedDict):
41
43
  metric: str
42
44
  index_type: str
43
45
  index_creation_with_options: Sequence[dict[str, Any]]
44
- maintenance_work_mem: Optional[str]
45
- max_parallel_workers: Optional[int]
46
+ maintenance_work_mem: str | None
47
+ max_parallel_workers: int | None
46
48
 
47
49
 
48
50
  class AlloyDBSearchParam(TypedDict):
@@ -61,31 +63,30 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
61
63
  def parse_metric(self) -> str:
62
64
  if self.metric_type == MetricType.L2:
63
65
  return "l2"
64
- elif self.metric_type == MetricType.DP:
66
+ if self.metric_type == MetricType.DP:
65
67
  return "dot_product"
66
68
  return "cosine"
67
69
 
68
70
  def parse_metric_fun_op(self) -> LiteralString:
69
71
  if self.metric_type == MetricType.L2:
70
72
  return "<->"
71
- elif self.metric_type == MetricType.IP:
73
+ if self.metric_type == MetricType.IP:
72
74
  return "<#>"
73
75
  return "<=>"
74
76
 
75
77
  @abstractmethod
76
- def index_param(self) -> AlloyDBIndexParam:
77
- ...
78
+ def index_param(self) -> AlloyDBIndexParam: ...
78
79
 
79
80
  @abstractmethod
80
- def search_param(self) -> AlloyDBSearchParam:
81
- ...
81
+ def search_param(self) -> AlloyDBSearchParam: ...
82
82
 
83
83
  @abstractmethod
84
- def session_param(self) -> AlloyDBSessionCommands:
85
- ...
84
+ def session_param(self) -> AlloyDBSessionCommands: ...
86
85
 
87
86
  @staticmethod
88
- def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
87
+ def _optionally_build_with_options(
88
+ with_options: Mapping[str, Any],
89
+ ) -> Sequence[dict[str, Any]]:
89
90
  """Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
90
91
  options = []
91
92
  for option_name, value in with_options.items():
@@ -94,24 +95,25 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
94
95
  {
95
96
  "option_name": option_name,
96
97
  "val": str(value),
97
- }
98
+ },
98
99
  )
99
100
  return options
100
101
 
101
102
  @staticmethod
102
103
  def _optionally_build_set_options(
103
- set_mapping: Mapping[str, Any]
104
+ set_mapping: Mapping[str, Any],
104
105
  ) -> Sequence[dict[str, Any]]:
105
106
  """Walk through options, creating 'SET 'key1 = "value1";' list"""
106
107
  session_options = []
107
108
  for setting_name, value in set_mapping.items():
108
109
  if value:
109
110
  session_options.append(
110
- {"parameter": {
111
+ {
112
+ "parameter": {
111
113
  "setting_name": setting_name,
112
114
  "val": str(value),
113
115
  },
114
- }
116
+ },
115
117
  )
116
118
  return session_options
117
119
 
@@ -124,22 +126,22 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
124
126
  max_num_levels: int | None
125
127
  num_leaves_to_search: int | None
126
128
  max_top_neighbors_buffer_size: int | None
127
- pre_reordering_num_neighbors: int | None
128
- num_search_threads: int | None
129
+ pre_reordering_num_neighbors: int | None
130
+ num_search_threads: int | None
129
131
  max_num_prefetch_datasets: int | None
130
- maintenance_work_mem: Optional[str] = None
131
- max_parallel_workers: Optional[int] = None
132
+ maintenance_work_mem: str | None = None
133
+ max_parallel_workers: int | None = None
132
134
 
133
135
  def index_param(self) -> AlloyDBIndexParam:
134
136
  index_parameters = {
135
- "num_leaves": self.num_leaves, "max_num_levels": self.max_num_levels, "quantizer": self.quantizer,
137
+ "num_leaves": self.num_leaves,
138
+ "max_num_levels": self.max_num_levels,
139
+ "quantizer": self.quantizer,
136
140
  }
137
141
  return {
138
142
  "metric": self.parse_metric(),
139
143
  "index_type": self.index.value,
140
- "index_creation_with_options": self._optionally_build_with_options(
141
- index_parameters
142
- ),
144
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
143
145
  "maintenance_work_mem": self.maintenance_work_mem,
144
146
  "max_parallel_workers": self.max_parallel_workers,
145
147
  "enable_pca": self.enable_pca,
@@ -158,11 +160,9 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
158
160
  "scann.num_search_threads": self.num_search_threads,
159
161
  "scann.max_num_prefetch_datasets": self.max_num_prefetch_datasets,
160
162
  }
161
- return {
162
- "session_options": self._optionally_build_set_options(session_parameters)
163
- }
163
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
164
164
 
165
165
 
166
166
  _alloydb_case_config = {
167
- IndexType.SCANN: AlloyDBScaNNConfig,
167
+ IndexType.SCANN: AlloyDBScaNNConfig,
168
168
  }