vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +75 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +5 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +18 -19
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +51 -23
  63. vectordb_bench/backend/runner/serial_runner.py +91 -48
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -72
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
1
  import logging
2
- from enum import Enum
3
- from pydantic import SecretStr, BaseModel
4
2
 
5
- from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
3
+ from pydantic import BaseModel, SecretStr
4
+
5
+ from ..api import DBCaseConfig, DBConfig, MetricType
6
6
 
7
7
  log = logging.getLogger(__name__)
8
8
 
@@ -26,18 +26,17 @@ class AliyunOpenSearchConfig(DBConfig, BaseModel):
26
26
  "control_host": self.control_host,
27
27
  }
28
28
 
29
+
29
30
  class AliyunOpenSearchIndexConfig(BaseModel, DBCaseConfig):
30
31
  metric_type: MetricType = MetricType.L2
31
- efConstruction: int = 500
32
+ ef_construction: int = 500
32
33
  M: int = 100
33
34
  ef_search: int = 40
34
35
 
35
36
  def distance_type(self) -> str:
36
37
  if self.metric_type == MetricType.L2:
37
38
  return "SquaredEuclidean"
38
- elif self.metric_type == MetricType.IP:
39
- return "InnerProduct"
40
- elif self.metric_type == MetricType.COSINE:
39
+ if self.metric_type in (MetricType.IP, MetricType.COSINE):
41
40
  return "InnerProduct"
42
41
  return "SquaredEuclidean"
43
42
 
@@ -1,9 +1,9 @@
1
1
  """Wrapper around the alloydb vector database over VectorDB"""
2
2
 
3
3
  import logging
4
- import pprint
4
+ from collections.abc import Generator, Sequence
5
5
  from contextlib import contextmanager
6
- from typing import Any, Generator, Optional, Tuple, Sequence
6
+ from typing import Any
7
7
 
8
8
  import numpy as np
9
9
  import psycopg
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
11
11
  from psycopg import Connection, Cursor, sql
12
12
 
13
13
  from ..api import VectorDB
14
- from .config import AlloyDBConfigDict, AlloyDBIndexConfig, AlloyDBScaNNConfig
14
+ from .config import AlloyDBConfigDict, AlloyDBIndexConfig
15
15
 
16
16
  log = logging.getLogger(__name__)
17
17
 
@@ -56,13 +56,14 @@ class AlloyDB(VectorDB):
56
56
  (
57
57
  self.case_config.create_index_before_load,
58
58
  self.case_config.create_index_after_load,
59
- )
59
+ ),
60
60
  ):
61
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
62
- log.error(err)
63
- raise RuntimeError(
64
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
61
+ msg = (
62
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
63
+ "\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
65
64
  )
65
+ log.warning(msg)
66
+ raise RuntimeError(msg)
66
67
 
67
68
  if drop_old:
68
69
  self._drop_index()
@@ -77,7 +78,7 @@ class AlloyDB(VectorDB):
77
78
  self.conn = None
78
79
 
79
80
  @staticmethod
80
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
81
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
81
82
  conn = psycopg.connect(**kwargs)
82
83
  register_vector(conn)
83
84
  conn.autocommit = False
@@ -86,21 +87,20 @@ class AlloyDB(VectorDB):
86
87
  assert conn is not None, "Connection is not initialized"
87
88
  assert cursor is not None, "Cursor is not initialized"
88
89
  return conn, cursor
89
-
90
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
91
- search_query = sql.Composed(
90
+
91
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
92
+ return sql.Composed(
92
93
  [
93
94
  sql.SQL(
94
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
95
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
95
96
  ).format(
96
97
  table_name=sql.Identifier(self.table_name),
97
98
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
98
99
  ),
99
100
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
100
101
  sql.SQL(" %s::vector LIMIT %s::int"),
101
- ]
102
+ ],
102
103
  )
103
- return search_query
104
104
 
105
105
  @contextmanager
106
106
  def init(self) -> Generator[None, None, None]:
@@ -119,8 +119,8 @@ class AlloyDB(VectorDB):
119
119
  if len(session_options) > 0:
120
120
  for setting in session_options:
121
121
  command = sql.SQL("SET {setting_name} " + "= {val};").format(
122
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
123
- val=sql.Identifier(str(setting['parameter']['val'])),
122
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
123
+ val=sql.Identifier(str(setting["parameter"]["val"])),
124
124
  )
125
125
  log.debug(command.as_string(self.cursor))
126
126
  self.cursor.execute(command)
@@ -144,8 +144,8 @@ class AlloyDB(VectorDB):
144
144
 
145
145
  self.cursor.execute(
146
146
  sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
147
- table_name=sql.Identifier(self.table_name)
148
- )
147
+ table_name=sql.Identifier(self.table_name),
148
+ ),
149
149
  )
150
150
  self.conn.commit()
151
151
 
@@ -167,7 +167,7 @@ class AlloyDB(VectorDB):
167
167
  log.info(f"{self.name} client drop index : {self._index_name}")
168
168
 
169
169
  drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
170
- index_name=sql.Identifier(self._index_name)
170
+ index_name=sql.Identifier(self._index_name),
171
171
  )
172
172
  log.debug(drop_index_sql.as_string(self.cursor))
173
173
  self.cursor.execute(drop_index_sql)
@@ -181,78 +181,64 @@ class AlloyDB(VectorDB):
181
181
 
182
182
  if index_param["enable_pca"] is not None:
183
183
  self.cursor.execute(
184
- sql.SQL("SET scann.enable_pca TO {};").format(
185
- index_param["enable_pca"]
186
- )
184
+ sql.SQL("SET scann.enable_pca TO {};").format(index_param["enable_pca"]),
187
185
  )
188
186
  self.cursor.execute(
189
187
  sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format(
190
188
  sql.Identifier(self.db_config["user"]),
191
189
  index_param["enable_pca"],
192
- )
190
+ ),
193
191
  )
194
192
  self.conn.commit()
195
193
 
196
194
  if index_param["maintenance_work_mem"] is not None:
197
195
  self.cursor.execute(
198
196
  sql.SQL("SET maintenance_work_mem TO {};").format(
199
- index_param["maintenance_work_mem"]
200
- )
197
+ index_param["maintenance_work_mem"],
198
+ ),
201
199
  )
202
200
  self.cursor.execute(
203
201
  sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
204
202
  sql.Identifier(self.db_config["user"]),
205
203
  index_param["maintenance_work_mem"],
206
- )
204
+ ),
207
205
  )
208
206
  self.conn.commit()
209
207
 
210
208
  if index_param["max_parallel_workers"] is not None:
211
209
  self.cursor.execute(
212
210
  sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
213
- index_param["max_parallel_workers"]
214
- )
211
+ index_param["max_parallel_workers"],
212
+ ),
215
213
  )
216
214
  self.cursor.execute(
217
- sql.SQL(
218
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
219
- ).format(
215
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
220
216
  sql.Identifier(self.db_config["user"]),
221
217
  index_param["max_parallel_workers"],
222
- )
218
+ ),
223
219
  )
224
220
  self.cursor.execute(
225
221
  sql.SQL("SET max_parallel_workers TO '{}';").format(
226
- index_param["max_parallel_workers"]
227
- )
222
+ index_param["max_parallel_workers"],
223
+ ),
228
224
  )
229
225
  self.cursor.execute(
230
- sql.SQL(
231
- "ALTER USER {} SET max_parallel_workers TO '{}';"
232
- ).format(
226
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
233
227
  sql.Identifier(self.db_config["user"]),
234
228
  index_param["max_parallel_workers"],
235
- )
229
+ ),
236
230
  )
237
231
  self.cursor.execute(
238
- sql.SQL(
239
- "ALTER TABLE {} SET (parallel_workers = {});"
240
- ).format(
232
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
241
233
  sql.Identifier(self.table_name),
242
234
  index_param["max_parallel_workers"],
243
- )
235
+ ),
244
236
  )
245
237
  self.conn.commit()
246
238
 
247
- results = self.cursor.execute(
248
- sql.SQL("SHOW max_parallel_maintenance_workers;")
249
- ).fetchall()
250
- results.extend(
251
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
252
- )
253
- results.extend(
254
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
255
- )
239
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
240
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
241
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
256
242
  log.info(f"{self.name} parallel index creation parameters: {results}")
257
243
 
258
244
  def _create_index(self):
@@ -264,23 +250,20 @@ class AlloyDB(VectorDB):
264
250
  self._set_parallel_index_build_param()
265
251
  options = []
266
252
  for option in index_param["index_creation_with_options"]:
267
- if option['val'] is not None:
253
+ if option["val"] is not None:
268
254
  options.append(
269
255
  sql.SQL("{option_name} = {val}").format(
270
- option_name=sql.Identifier(option['option_name']),
271
- val=sql.Identifier(str(option['val'])),
272
- )
256
+ option_name=sql.Identifier(option["option_name"]),
257
+ val=sql.Identifier(str(option["val"])),
258
+ ),
273
259
  )
274
- if any(options):
275
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
276
- else:
277
- with_clause = sql.Composed(())
260
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
278
261
 
279
262
  index_create_sql = sql.SQL(
280
263
  """
281
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
264
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
282
265
  USING {index_type} (embedding {embedding_metric})
283
- """
266
+ """,
284
267
  ).format(
285
268
  index_name=sql.Identifier(self._index_name),
286
269
  table_name=sql.Identifier(self.table_name),
@@ -288,9 +271,7 @@ class AlloyDB(VectorDB):
288
271
  embedding_metric=sql.Identifier(index_param["metric"]),
289
272
  )
290
273
 
291
- index_create_sql_with_with_clause = (
292
- index_create_sql + with_clause
293
- ).join(" ")
274
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
294
275
  log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
295
276
  self.cursor.execute(index_create_sql_with_with_clause)
296
277
  self.conn.commit()
@@ -305,14 +286,12 @@ class AlloyDB(VectorDB):
305
286
  # create table
306
287
  self.cursor.execute(
307
288
  sql.SQL(
308
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
309
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
289
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
290
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
310
291
  )
311
292
  self.conn.commit()
312
293
  except Exception as e:
313
- log.warning(
314
- f"Failed to create alloydb table: {self.table_name} error: {e}"
315
- )
294
+ log.warning(f"Failed to create alloydb table: {self.table_name} error: {e}")
316
295
  raise e from None
317
296
 
318
297
  def insert_embeddings(
@@ -320,7 +299,7 @@ class AlloyDB(VectorDB):
320
299
  embeddings: list[list[float]],
321
300
  metadata: list[int],
322
301
  **kwargs: Any,
323
- ) -> Tuple[int, Optional[Exception]]:
302
+ ) -> tuple[int, Exception | None]:
324
303
  assert self.conn is not None, "Connection is not initialized"
325
304
  assert self.cursor is not None, "Cursor is not initialized"
326
305
 
@@ -330,8 +309,8 @@ class AlloyDB(VectorDB):
330
309
 
331
310
  with self.cursor.copy(
332
311
  sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
333
- table_name=sql.Identifier(self.table_name)
334
- )
312
+ table_name=sql.Identifier(self.table_name),
313
+ ),
335
314
  ) as copy:
336
315
  copy.set_types(["bigint", "vector"])
337
316
  for i, row in enumerate(metadata_arr):
@@ -343,9 +322,7 @@ class AlloyDB(VectorDB):
343
322
 
344
323
  return len(metadata), None
345
324
  except Exception as e:
346
- log.warning(
347
- f"Failed to insert data into alloydb table ({self.table_name}), error: {e}"
348
- )
325
+ log.warning(f"Failed to insert data into alloydb table ({self.table_name}), error: {e}")
349
326
  return 0, e
350
327
 
351
328
  def search_embedding(
@@ -362,11 +339,12 @@ class AlloyDB(VectorDB):
362
339
  if filters:
363
340
  gt = filters.get("id")
364
341
  result = self.cursor.execute(
365
- self._filtered_search, (gt, q, k), prepare=True, binary=True
342
+ self._filtered_search,
343
+ (gt, q, k),
344
+ prepare=True,
345
+ binary=True,
366
346
  )
367
347
  else:
368
- result = self.cursor.execute(
369
- self._unfiltered_search, (q, k), prepare=True, binary=True
370
- )
348
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
371
349
 
372
350
  return [int(i[0]) for i in result.fetchall()]
@@ -1,10 +1,10 @@
1
- from typing import Annotated, Optional, TypedDict, Unpack
1
+ import os
2
+ from typing import Annotated, Unpack
2
3
 
3
4
  import click
4
- import os
5
5
  from pydantic import SecretStr
6
6
 
7
- from vectordb_bench.backend.clients.api import MetricType
7
+ from vectordb_bench.backend.clients import DB
8
8
 
9
9
  from ....cli.cli import (
10
10
  CommonTypedDict,
@@ -13,31 +13,28 @@ from ....cli.cli import (
13
13
  get_custom_case_config,
14
14
  run,
15
15
  )
16
- from vectordb_bench.backend.clients import DB
17
16
 
18
17
 
19
18
  class AlloyDBTypedDict(CommonTypedDict):
20
19
  user_name: Annotated[
21
- str, click.option("--user-name", type=str, help="Db username", required=True)
20
+ str,
21
+ click.option("--user-name", type=str, help="Db username", required=True),
22
22
  ]
23
23
  password: Annotated[
24
24
  str,
25
- click.option("--password",
26
- type=str,
27
- help="Postgres database password",
28
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
29
- show_default="$POSTGRES_PASSWORD",
30
- ),
25
+ click.option(
26
+ "--password",
27
+ type=str,
28
+ help="Postgres database password",
29
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
30
+ show_default="$POSTGRES_PASSWORD",
31
+ ),
31
32
  ]
32
33
 
33
- host: Annotated[
34
- str, click.option("--host", type=str, help="Db host", required=True)
35
- ]
36
- db_name: Annotated[
37
- str, click.option("--db-name", type=str, help="Db name", required=True)
38
- ]
34
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
35
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
39
36
  maintenance_work_mem: Annotated[
40
- Optional[str],
37
+ str | None,
41
38
  click.option(
42
39
  "--maintenance-work-mem",
43
40
  type=str,
@@ -49,7 +46,7 @@ class AlloyDBTypedDict(CommonTypedDict):
49
46
  ),
50
47
  ]
51
48
  max_parallel_workers: Annotated[
52
- Optional[int],
49
+ int | None,
53
50
  click.option(
54
51
  "--max-parallel-workers",
55
52
  type=int,
@@ -58,32 +55,51 @@ class AlloyDBTypedDict(CommonTypedDict):
58
55
  ),
59
56
  ]
60
57
 
61
-
62
58
 
63
59
  class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
64
60
  num_leaves: Annotated[
65
61
  int,
66
- click.option("--num-leaves", type=int, help="Number of leaves", required=True)
62
+ click.option("--num-leaves", type=int, help="Number of leaves", required=True),
67
63
  ]
68
64
  num_leaves_to_search: Annotated[
69
65
  int,
70
- click.option("--num-leaves-to-search", type=int, help="Number of leaves to search", required=True)
66
+ click.option(
67
+ "--num-leaves-to-search",
68
+ type=int,
69
+ help="Number of leaves to search",
70
+ required=True,
71
+ ),
71
72
  ]
72
73
  pre_reordering_num_neighbors: Annotated[
73
74
  int,
74
- click.option("--pre-reordering-num-neighbors", type=int, help="Pre-reordering number of neighbors", default=200)
75
+ click.option(
76
+ "--pre-reordering-num-neighbors",
77
+ type=int,
78
+ help="Pre-reordering number of neighbors",
79
+ default=200,
80
+ ),
75
81
  ]
76
82
  max_top_neighbors_buffer_size: Annotated[
77
83
  int,
78
- click.option("--max-top-neighbors-buffer-size", type=int, help="Maximum top neighbors buffer size", default=20_000)
84
+ click.option(
85
+ "--max-top-neighbors-buffer-size",
86
+ type=int,
87
+ help="Maximum top neighbors buffer size",
88
+ default=20_000,
89
+ ),
79
90
  ]
80
91
  num_search_threads: Annotated[
81
92
  int,
82
- click.option("--num-search-threads", type=int, help="Number of search threads", default=2)
93
+ click.option("--num-search-threads", type=int, help="Number of search threads", default=2),
83
94
  ]
84
95
  max_num_prefetch_datasets: Annotated[
85
96
  int,
86
- click.option("--max-num-prefetch-datasets", type=int, help="Maximum number of prefetch datasets", default=100)
97
+ click.option(
98
+ "--max-num-prefetch-datasets",
99
+ type=int,
100
+ help="Maximum number of prefetch datasets",
101
+ default=100,
102
+ ),
87
103
  ]
88
104
  quantizer: Annotated[
89
105
  str,
@@ -91,16 +107,17 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
91
107
  "--quantizer",
92
108
  type=click.Choice(["SQ8", "FLAT"]),
93
109
  help="Quantizer type",
94
- default="SQ8"
95
- )
110
+ default="SQ8",
111
+ ),
96
112
  ]
97
113
  enable_pca: Annotated[
98
- bool, click.option(
114
+ bool,
115
+ click.option(
99
116
  "--enable-pca",
100
117
  type=click.Choice(["on", "off"]),
101
118
  help="Enable PCA",
102
- default="on"
103
- )
119
+ default="on",
120
+ ),
104
121
  ]
105
122
  max_num_levels: Annotated[
106
123
  int,
@@ -108,8 +125,8 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
108
125
  "--max-num-levels",
109
126
  type=click.Choice(["1", "2"]),
110
127
  help="Maximum number of levels",
111
- default=1
112
- )
128
+ default=1,
129
+ ),
113
130
  ]
114
131
 
115
132
 
@@ -144,4 +161,4 @@ def AlloyDBScaNN(
144
161
  maintenance_work_mem=parameters["maintenance_work_mem"],
145
162
  ),
146
163
  **parameters,
147
- )
164
+ )
@@ -1,7 +1,9 @@
1
1
  from abc import abstractmethod
2
- from typing import Any, Mapping, Optional, Sequence, TypedDict
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import Any, LiteralString, TypedDict
4
+
3
5
  from pydantic import BaseModel, SecretStr
4
- from typing_extensions import LiteralString
6
+
5
7
  from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
8
 
7
9
  POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@ POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
9
11
 
10
12
  class AlloyDBConfigDict(TypedDict):
11
13
  """These keys will be directly used as kwargs in psycopg connection string,
12
- so the names must match exactly psycopg API"""
14
+ so the names must match exactly psycopg API"""
13
15
 
14
16
  user: str
15
17
  password: str
@@ -41,8 +43,8 @@ class AlloyDBIndexParam(TypedDict):
41
43
  metric: str
42
44
  index_type: str
43
45
  index_creation_with_options: Sequence[dict[str, Any]]
44
- maintenance_work_mem: Optional[str]
45
- max_parallel_workers: Optional[int]
46
+ maintenance_work_mem: str | None
47
+ max_parallel_workers: int | None
46
48
 
47
49
 
48
50
  class AlloyDBSearchParam(TypedDict):
@@ -61,31 +63,30 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
61
63
  def parse_metric(self) -> str:
62
64
  if self.metric_type == MetricType.L2:
63
65
  return "l2"
64
- elif self.metric_type == MetricType.DP:
66
+ if self.metric_type == MetricType.DP:
65
67
  return "dot_product"
66
68
  return "cosine"
67
69
 
68
70
  def parse_metric_fun_op(self) -> LiteralString:
69
71
  if self.metric_type == MetricType.L2:
70
72
  return "<->"
71
- elif self.metric_type == MetricType.IP:
73
+ if self.metric_type == MetricType.IP:
72
74
  return "<#>"
73
75
  return "<=>"
74
76
 
75
77
  @abstractmethod
76
- def index_param(self) -> AlloyDBIndexParam:
77
- ...
78
+ def index_param(self) -> AlloyDBIndexParam: ...
78
79
 
79
80
  @abstractmethod
80
- def search_param(self) -> AlloyDBSearchParam:
81
- ...
81
+ def search_param(self) -> AlloyDBSearchParam: ...
82
82
 
83
83
  @abstractmethod
84
- def session_param(self) -> AlloyDBSessionCommands:
85
- ...
84
+ def session_param(self) -> AlloyDBSessionCommands: ...
86
85
 
87
86
  @staticmethod
88
- def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
87
+ def _optionally_build_with_options(
88
+ with_options: Mapping[str, Any],
89
+ ) -> Sequence[dict[str, Any]]:
89
90
  """Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
90
91
  options = []
91
92
  for option_name, value in with_options.items():
@@ -94,24 +95,25 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
94
95
  {
95
96
  "option_name": option_name,
96
97
  "val": str(value),
97
- }
98
+ },
98
99
  )
99
100
  return options
100
101
 
101
102
  @staticmethod
102
103
  def _optionally_build_set_options(
103
- set_mapping: Mapping[str, Any]
104
+ set_mapping: Mapping[str, Any],
104
105
  ) -> Sequence[dict[str, Any]]:
105
106
  """Walk through options, creating 'SET 'key1 = "value1";' list"""
106
107
  session_options = []
107
108
  for setting_name, value in set_mapping.items():
108
109
  if value:
109
110
  session_options.append(
110
- {"parameter": {
111
+ {
112
+ "parameter": {
111
113
  "setting_name": setting_name,
112
114
  "val": str(value),
113
115
  },
114
- }
116
+ },
115
117
  )
116
118
  return session_options
117
119
 
@@ -124,22 +126,22 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
124
126
  max_num_levels: int | None
125
127
  num_leaves_to_search: int | None
126
128
  max_top_neighbors_buffer_size: int | None
127
- pre_reordering_num_neighbors: int | None
128
- num_search_threads: int | None
129
+ pre_reordering_num_neighbors: int | None
130
+ num_search_threads: int | None
129
131
  max_num_prefetch_datasets: int | None
130
- maintenance_work_mem: Optional[str] = None
131
- max_parallel_workers: Optional[int] = None
132
+ maintenance_work_mem: str | None = None
133
+ max_parallel_workers: int | None = None
132
134
 
133
135
  def index_param(self) -> AlloyDBIndexParam:
134
136
  index_parameters = {
135
- "num_leaves": self.num_leaves, "max_num_levels": self.max_num_levels, "quantizer": self.quantizer,
137
+ "num_leaves": self.num_leaves,
138
+ "max_num_levels": self.max_num_levels,
139
+ "quantizer": self.quantizer,
136
140
  }
137
141
  return {
138
142
  "metric": self.parse_metric(),
139
143
  "index_type": self.index.value,
140
- "index_creation_with_options": self._optionally_build_with_options(
141
- index_parameters
142
- ),
144
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
143
145
  "maintenance_work_mem": self.maintenance_work_mem,
144
146
  "max_parallel_workers": self.max_parallel_workers,
145
147
  "enable_pca": self.enable_pca,
@@ -158,11 +160,9 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
158
160
  "scann.num_search_threads": self.num_search_threads,
159
161
  "scann.max_num_prefetch_datasets": self.max_num_prefetch_datasets,
160
162
  }
161
- return {
162
- "session_options": self._optionally_build_set_options(session_parameters)
163
- }
163
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
164
164
 
165
165
 
166
166
  _alloydb_case_config = {
167
- IndexType.SCANN: AlloyDBScaNNConfig,
167
+ IndexType.SCANN: AlloyDBScaNNConfig,
168
168
  }