vectordb-bench 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. vectordb_bench/backend/cases.py +1 -1
  2. vectordb_bench/backend/clients/__init__.py +39 -0
  3. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +27 -0
  4. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +19 -0
  5. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +304 -0
  6. vectordb_bench/backend/clients/aliyun_opensearch/config.py +48 -0
  7. vectordb_bench/backend/clients/alloydb/alloydb.py +372 -0
  8. vectordb_bench/backend/clients/alloydb/cli.py +147 -0
  9. vectordb_bench/backend/clients/alloydb/config.py +168 -0
  10. vectordb_bench/backend/clients/api.py +5 -0
  11. vectordb_bench/backend/clients/milvus/cli.py +25 -1
  12. vectordb_bench/backend/clients/milvus/config.py +16 -2
  13. vectordb_bench/backend/clients/milvus/milvus.py +4 -6
  14. vectordb_bench/backend/runner/rate_runner.py +32 -15
  15. vectordb_bench/backend/runner/read_write_runner.py +102 -36
  16. vectordb_bench/backend/runner/serial_runner.py +8 -2
  17. vectordb_bench/backend/runner/util.py +0 -16
  18. vectordb_bench/backend/task_runner.py +4 -3
  19. vectordb_bench/backend/utils.py +1 -0
  20. vectordb_bench/cli/vectordbbench.py +2 -0
  21. vectordb_bench/frontend/config/dbCaseConfigs.py +224 -0
  22. vectordb_bench/models.py +9 -0
  23. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/METADATA +13 -23
  24. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/RECORD +28 -21
  25. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/LICENSE +0 -0
  26. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/WHEEL +0 -0
  27. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/entry_points.txt +0 -0
  28. {vectordb_bench-0.0.17.dist-info → vectordb_bench-0.0.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,372 @@
1
+ """Wrapper around the alloydb vector database over VectorDB"""
2
+
3
+ import logging
4
+ import pprint
5
+ from contextlib import contextmanager
6
+ from typing import Any, Generator, Optional, Tuple, Sequence
7
+
8
+ import numpy as np
9
+ import psycopg
10
+ from pgvector.psycopg import register_vector
11
+ from psycopg import Connection, Cursor, sql
12
+
13
+ from ..api import VectorDB
14
+ from .config import AlloyDBConfigDict, AlloyDBIndexConfig, AlloyDBScaNNConfig
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
class AlloyDB(VectorDB):
    """VectorDB client for AlloyDB, talking to the server through psycopg.

    The constructor opens a short-lived connection to prepare the schema
    (extension, table and, optionally, the index) and closes it again.
    Loading and searching must happen inside ``with self.init():``, which
    opens the connection those calls use and closes it on exit.
    """

    conn: psycopg.Connection[Any] | None = None
    cursor: psycopg.Cursor[Any] | None = None

    _filtered_search: sql.Composed
    _unfiltered_search: sql.Composed

    def __init__(
        self,
        dim: int,
        db_config: AlloyDBConfigDict,
        db_case_config: AlloyDBIndexConfig,
        collection_name: str = "alloydb_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Validate the configuration and prepare the table/index.

        Args:
            dim: dimension of the stored vectors.
            db_config: keyword arguments passed verbatim to ``psycopg.connect``.
            db_case_config: index/search/session configuration.
            collection_name: name of the table holding the vectors.
            drop_old: when true, drop and recreate the index and the table.

        Raises:
            RuntimeError: if the configuration creates the index neither
                before nor after loading.
        """
        self.name = "AlloyDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "alloydb_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # construct basic units
        self.conn, self.cursor = self._create_connection(**self.db_config)

        # Everything below runs under try/finally so the setup connection is
        # released even when validation or a DDL statement fails.
        try:
            # create vector extension
            self.cursor.execute("CREATE EXTENSION IF NOT EXISTS alloydb_scann CASCADE")
            self.conn.commit()

            # Never write the plaintext password to logs or error messages.
            safe_db_config = self._redacted_db_config()
            log.info(f"{self.name} config values: {safe_db_config}\n{self.case_config}")
            if not any(
                (
                    self.case_config.create_index_before_load,
                    self.case_config.create_index_after_load,
                )
            ):
                err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
                log.error(err)
                raise RuntimeError(
                    f"{err}\n{pprint.pformat(safe_db_config)}\n{pprint.pformat(self.case_config)}"
                )

            if drop_old:
                self._drop_index()
                self._drop_table()
                self._create_table(dim)
                if self.case_config.create_index_before_load:
                    self._create_index()
        finally:
            self._close_connection()

    def _redacted_db_config(self) -> dict[str, Any]:
        """Return a copy of the connection settings with the password masked."""
        redacted = dict(self.db_config)
        if "password" in redacted:
            redacted["password"] = "******"
        return redacted

    def _close_connection(self) -> None:
        """Close cursor and connection (if open) and reset both attributes."""
        cursor, conn = self.cursor, self.conn
        self.cursor = None
        self.conn = None
        try:
            if cursor is not None:
                cursor.close()
        finally:
            # Close the connection even if closing the cursor failed.
            if conn is not None:
                conn.close()

    @staticmethod
    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
        """Open a non-autocommit connection with pgvector types registered."""
        conn = psycopg.connect(**kwargs)
        register_vector(conn)
        conn.autocommit = False
        cursor = conn.cursor()
        return conn, cursor

    def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
        """Build the k-NN SELECT statement.

        Placeholders, in order: the ``id`` lower bound (only when
        ``filtered``), the query vector, and the row limit.
        """
        where_clause = sql.SQL("WHERE id >= %s") if filtered else sql.SQL("")
        select_part = sql.SQL(
            "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
        ).format(
            table_name=sql.Identifier(self.table_name),
            where_clause=where_clause,
        )
        return sql.Composed(
            [
                select_part,
                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
                sql.SQL(" %s::vector LIMIT %s::int"),
            ]
        )

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """Open the working connection for the duration of the ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)

        # Session setup is inside the try block so that a failing SET does not
        # leak the freshly opened connection.
        try:
            # index configuration may have commands defined that we should set
            # during each client session
            session_options: Sequence[dict[str, Any]] = self.case_config.session_param()["session_options"]

            for setting in session_options:
                command = sql.SQL("SET {setting_name} " + "= {val};").format(
                    setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
                    val=sql.Identifier(str(setting["parameter"]["val"])),
                )
                log.debug(command.as_string(self.cursor))
                self.cursor.execute(command)
            if session_options:
                self.conn.commit()

            self._filtered_search = self._generate_search_query(filtered=True)
            self._unfiltered_search = self._generate_search_query()

            yield
        finally:
            self._close_connection()

    def _drop_table(self):
        """Drop the vector table if it exists and commit."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop table : {self.table_name}")

        self.cursor.execute(
            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
                table_name=sql.Identifier(self.table_name)
            )
        )
        self.conn.commit()

    def ready_to_load(self):
        """No preparation is required before loading."""
        pass

    def optimize(self):
        """Run the post-insert step (index rebuild when configured)."""
        self._post_insert()

    def _post_insert(self):
        """Drop and recreate the index when ``create_index_after_load`` is set."""
        log.info(f"{self.name} post insert before optimize")
        if self.case_config.create_index_after_load:
            self._drop_index()
            self._create_index()

    def _drop_index(self):
        """Drop the index if it exists and commit."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop index : {self._index_name}")

        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
            index_name=sql.Identifier(self._index_name)
        )
        log.debug(drop_index_sql.as_string(self.cursor))
        self.cursor.execute(drop_index_sql)
        self.conn.commit()

    def _execute_and_commit(self, *statements: sql.Composable) -> None:
        """Execute the statements in order on the open cursor, then commit once."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        for statement in statements:
            self.cursor.execute(statement)
        self.conn.commit()

    def _set_parallel_index_build_param(self):
        """Apply the index-build settings from the case config.

        Each configured value is set for the current session and also stored
        on the connecting role via ``ALTER USER``. Values that are ``None``
        are left untouched. The resulting server values are logged.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()
        user = sql.Identifier(self.db_config["user"])

        enable_pca = index_param["enable_pca"]
        if enable_pca is not None:
            self._execute_and_commit(
                sql.SQL("SET scann.enable_pca TO {};").format(enable_pca),
                sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format(
                    user, enable_pca
                ),
            )

        work_mem = index_param["maintenance_work_mem"]
        if work_mem is not None:
            self._execute_and_commit(
                sql.SQL("SET maintenance_work_mem TO {};").format(work_mem),
                sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
                    user, work_mem
                ),
            )

        workers = index_param["max_parallel_workers"]
        if workers is not None:
            self._execute_and_commit(
                sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
                    workers
                ),
                sql.SQL(
                    "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
                ).format(user, workers),
                sql.SQL("SET max_parallel_workers TO '{}';").format(workers),
                sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
                    user, workers
                ),
                sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
                    sql.Identifier(self.table_name), workers
                ),
            )

        results = self.cursor.execute(
            sql.SQL("SHOW max_parallel_maintenance_workers;")
        ).fetchall()
        results.extend(
            self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
        )
        results.extend(
            self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
        )
        log.info(f"{self.name} parallel index creation parameters: {results}")

    def _create_index(self):
        """Create the index described by the case config (if it does not exist)."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client create index : {self._index_name}")

        index_param = self.case_config.index_param()
        self._set_parallel_index_build_param()

        # One "name = value" fragment per option that actually has a value.
        options = [
            sql.SQL("{option_name} = {val}").format(
                option_name=sql.Identifier(option["option_name"]),
                val=sql.Identifier(str(option["val"])),
            )
            for option in index_param["index_creation_with_options"]
            if option["val"] is not None
        ]
        if options:
            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
        else:
            with_clause = sql.Composed(())

        index_create_sql = sql.SQL(
            """
            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
            USING {index_type} (embedding {embedding_metric})
            """
        ).format(
            index_name=sql.Identifier(self._index_name),
            table_name=sql.Identifier(self.table_name),
            index_type=sql.Identifier(index_param["index_type"]),
            embedding_metric=sql.Identifier(index_param["metric"]),
        )

        index_create_sql_with_with_clause = (
            index_create_sql + with_clause
        ).join(" ")
        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
        self.cursor.execute(index_create_sql_with_with_clause)
        self.conn.commit()

    def _create_table(self, dim: int):
        """Create the ``(id BIGINT, embedding vector(dim))`` table if missing.

        Any failure is logged and re-raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            log.info(f"{self.name} client create table : {self.table_name}")

            # create table
            self.cursor.execute(
                sql.SQL(
                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
            )
            self.conn.commit()
        except Exception as e:
            log.warning(
                f"Failed to create alloydb table: {self.table_name} error: {e}"
            )
            raise

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        """Bulk-load one batch through binary ``COPY``.

        Args:
            embeddings: vectors to insert, parallel to ``metadata``.
            metadata: primary-key ids, one per vector.
            **kwargs: ``last_batch=True`` triggers the post-insert step.

        Returns:
            ``(number_of_rows, None)`` on success, ``(0, exception)`` on failure.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            metadata_arr = np.array(metadata)
            embeddings_arr = np.array(embeddings)

            with self.cursor.copy(
                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
                    table_name=sql.Identifier(self.table_name)
                )
            ) as copy:
                copy.set_types(["bigint", "vector"])
                for i, row in enumerate(metadata_arr):
                    copy.write_row((row, embeddings_arr[i]))
            self.conn.commit()

            if kwargs.get("last_batch"):
                self._post_insert()

            return len(metadata), None
        except Exception as e:
            log.warning(
                f"Failed to insert data into alloydb table ({self.table_name}), error: {e}"
            )
            # A failed statement leaves the transaction aborted; roll back so
            # the connection stays usable for the next batch or retry.
            try:
                self.conn.rollback()
            except Exception as rollback_error:
                log.warning(
                    f"Failed to roll back after insert error: {rollback_error}"
                )
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` nearest rows to ``query``.

        When ``filters`` is non-empty, only rows with ``id >= filters["id"]``
        are considered. ``timeout`` is accepted but not used here.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        q = np.asarray(query)
        if filters:
            gt = filters.get("id")
            result = self.cursor.execute(
                self._filtered_search, (gt, q, k), prepare=True, binary=True
            )
        else:
            result = self.cursor.execute(
                self._unfiltered_search, (q, k), prepare=True, binary=True
            )

        return [int(i[0]) for i in result.fetchall()]
@@ -0,0 +1,147 @@
1
+ from typing import Annotated, Optional, TypedDict, Unpack
2
+
3
+ import click
4
+ import os
5
+ from pydantic import SecretStr
6
+
7
+ from vectordb_bench.backend.clients.api import MetricType
8
+
9
+ from ....cli.cli import (
10
+ CommonTypedDict,
11
+ cli,
12
+ click_parameter_decorators_from_typed_dict,
13
+ get_custom_case_config,
14
+ run,
15
+ )
16
+ from vectordb_bench.backend.clients import DB
17
+
18
+
19
class AlloyDBTypedDict(CommonTypedDict):
    """Command-line options common to the AlloyDB commands.

    Each field pairs the Python type of the parsed value with the
    ``click.option`` that produces it.
    """

    user_name: Annotated[
        str, click.option("--user-name", type=str, help="Db username", required=True)
    ]
    # The password defaults to $POSTGRES_PASSWORD; the callable default is
    # evaluated by click only when the option is not given.
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Postgres database password",
            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
            show_default="$POSTGRES_PASSWORD",
        ),
    ]

    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    db_name: Annotated[
        str, click.option("--db-name", type=str, help="Db name", required=True)
    ]
    maintenance_work_mem: Annotated[
        Optional[str],
        click.option(
            "--maintenance-work-mem",
            type=str,
            help="Sets the maximum memory to be used for maintenance operations (index creation). "
            "Can be entered as string with unit like '64GB' or as an integer number of KB.",
            required=False,
        ),
    ]
    max_parallel_workers: Annotated[
        Optional[int],
        click.option(
            "--max-parallel-workers",
            type=int,
            help="Sets the maximum number of parallel processes per maintenance operation (index creation). "
            "This will set the parameters: max_parallel_maintenance_workers,"
            " max_parallel_workers & table(parallel_workers)",
            required=False,
        ),
    ]
60
+
61
+
62
+
63
class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
    """Command-line options specific to the AlloyDB ScaNN index."""

    num_leaves: Annotated[
        int,
        click.option("--num-leaves", type=int, help="Number of leaves", required=True),
    ]
    num_leaves_to_search: Annotated[
        int,
        click.option(
            "--num-leaves-to-search",
            type=int,
            help="Number of leaves to search",
            required=True,
        ),
    ]
    pre_reordering_num_neighbors: Annotated[
        int,
        click.option(
            "--pre-reordering-num-neighbors",
            type=int,
            help="Pre-reordering number of neighbors",
            default=200,
        ),
    ]
    max_top_neighbors_buffer_size: Annotated[
        int,
        click.option(
            "--max-top-neighbors-buffer-size",
            type=int,
            help="Maximum top neighbors buffer size",
            default=20_000,
        ),
    ]
    num_search_threads: Annotated[
        int,
        click.option(
            "--num-search-threads",
            type=int,
            help="Number of search threads",
            default=2,
        ),
    ]
    max_num_prefetch_datasets: Annotated[
        int,
        click.option(
            "--max-num-prefetch-datasets",
            type=int,
            help="Maximum number of prefetch datasets",
            default=100,
        ),
    ]
    quantizer: Annotated[
        str,
        click.option(
            "--quantizer",
            type=click.Choice(["SQ8", "FLAT"]),
            help="Quantizer type",
            default="SQ8",
        ),
    ]
    # click.Choice yields the chosen string ("on"/"off"), not a bool.
    enable_pca: Annotated[
        str,
        click.option(
            "--enable-pca",
            type=click.Choice(["on", "off"]),
            help="Enable PCA",
            default="on",
        ),
    ]
    # IntRange keeps the value an int and makes the integer default valid;
    # a Choice of strings rejected (or stringified) the default of 1.
    max_num_levels: Annotated[
        int,
        click.option(
            "--max-num-levels",
            type=click.IntRange(1, 2),
            help="Maximum number of levels",
            default=1,
        ),
    ]
114
+
115
+
116
@cli.command()
@click_parameter_decorators_from_typed_dict(AlloyDBScaNNTypedDict)
def AlloyDBScaNN(
    **parameters: Unpack[AlloyDBScaNNTypedDict],
):
    # NOTE: deliberately no docstring here; click would surface it as the
    # command's help text.
    # Config classes are imported lazily, at command invocation time.
    from .config import AlloyDBConfig, AlloyDBScaNNConfig

    parameters["custom_case"] = get_custom_case_config(parameters)

    # Connection settings; the port is not exposed on the command line, so the
    # config's own default applies.
    connection_config = AlloyDBConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )

    # ScaNN index/search settings, copied one-to-one from the parsed options.
    scann_option_names = (
        "num_leaves",
        "quantizer",
        "enable_pca",
        "max_num_levels",
        "num_leaves_to_search",
        "max_top_neighbors_buffer_size",
        "pre_reordering_num_neighbors",
        "num_search_threads",
        "max_num_prefetch_datasets",
        "max_parallel_workers",
        "maintenance_work_mem",
    )
    index_config = AlloyDBScaNNConfig(
        **{option: parameters[option] for option in scann_option_names}
    )

    run(
        db=DB.AlloyDB,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
@@ -0,0 +1,168 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Mapping, Optional, Sequence, TypedDict
3
+ from pydantic import BaseModel, SecretStr
4
+ from typing_extensions import LiteralString
5
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
+
7
# %-style template for a PostgreSQL connection URL. The four slots are
# presumably user, password, host and database name; the constant is not
# referenced in the visible code -- TODO confirm it is still needed.
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
8
+
9
+
10
class AlloyDBConfigDict(TypedDict):
    """Connection settings in the shape expected by psycopg.

    These keys are passed directly as keyword arguments when opening the
    psycopg connection, so their names must match the psycopg API exactly.
    """

    user: str
    password: str
    host: str
    port: int
    dbname: str
19
+
20
+
21
class AlloyDBConfig(DBConfig):
    # Connection settings for an AlloyDB instance. User name and password are
    # held as SecretStr and only unwrapped in to_dict().
    user_name: SecretStr = SecretStr("postgres")
    password: SecretStr
    host: str = "localhost"
    port: int = 5432
    db_name: str

    def to_dict(self) -> AlloyDBConfigDict:
        """Return the settings as psycopg connection keyword arguments."""
        return AlloyDBConfigDict(
            host=self.host,
            port=self.port,
            dbname=self.db_name,
            user=self.user_name.get_secret_value(),
            password=self.password.get_secret_value(),
        )
38
+
39
+
40
class AlloyDBIndexParam(TypedDict):
    """Parameters used when building the index.

    ``index_creation_with_options`` holds ``{"option_name": ..., "val": ...}``
    entries for the index's WITH clause. The remaining optional keys are
    settings applied before the build; ``None`` means "leave unchanged".
    """

    metric: str
    index_type: str
    index_creation_with_options: Sequence[dict[str, Any]]
    maintenance_work_mem: Optional[str]
    max_parallel_workers: Optional[int]
    # Returned by AlloyDBScaNNConfig.index_param() and read by the client
    # when it applies the index-build settings.
    enable_pca: Optional[str]
46
+
47
+
48
class AlloyDBSearchParam(TypedDict):
    # SQL distance operator placed between the embedding column and the query
    # vector in the search statement's ORDER BY (e.g. "<->", "<#>", "<=>").
    metric_fun_op: LiteralString
50
+
51
+
52
class AlloyDBSessionCommands(TypedDict):
    # One entry per session setting, shaped as
    # {"parameter": {"setting_name": ..., "val": ...}}.
    session_options: Sequence[dict[str, Any]]
54
+
55
+
56
class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
    # Base configuration shared by AlloyDB index types: the metric plus the
    # flags controlling whether the index is built before and/or after loading.
    metric_type: MetricType | None = None
    create_index_before_load: bool = False
    create_index_after_load: bool = True

    def _is_inner_product(self) -> bool:
        """Tell whether the metric is an inner-product metric.

        ``MetricType.IP`` and ``MetricType.DP`` are treated alike so that the
        index metric and the query operator always agree.
        """
        return self.metric_type in (MetricType.IP, MetricType.DP)

    def parse_metric(self) -> str:
        """Return the metric name used in ``CREATE INDEX`` (cosine by default)."""
        if self.metric_type == MetricType.L2:
            return "l2"
        if self._is_inner_product():
            return "dot_product"
        return "cosine"

    def parse_metric_fun_op(self) -> LiteralString:
        """Return the distance operator matching :meth:`parse_metric`."""
        if self.metric_type == MetricType.L2:
            return "<->"
        if self._is_inner_product():
            return "<#>"
        return "<=>"

    @abstractmethod
    def index_param(self) -> AlloyDBIndexParam:
        ...

    @abstractmethod
    def search_param(self) -> AlloyDBSearchParam:
        ...

    @abstractmethod
    def session_param(self) -> AlloyDBSessionCommands:
        ...

    @staticmethod
    def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
        """Turn a mapping into ``{"option_name", "val"}`` entries for a WITH clause.

        Entries whose value is ``None`` are skipped; values are stringified.
        """
        return [
            {"option_name": option_name, "val": str(value)}
            for option_name, value in with_options.items()
            if value is not None
        ]

    @staticmethod
    def _optionally_build_set_options(
        set_mapping: Mapping[str, Any]
    ) -> Sequence[dict[str, Any]]:
        """Turn a mapping into entries describing session ``SET`` commands.

        Only truthy values are kept, so ``None``, ``0`` and ``""`` are all
        skipped; values are stringified.
        """
        return [
            {"parameter": {"setting_name": setting_name, "val": str(value)}}
            for setting_name, value in set_mapping.items()
            if value
        ]
117
+
118
+
119
class AlloyDBScaNNConfig(AlloyDBIndexConfig):
    # Configuration for the ScaNN index: build-time options, search-time
    # session settings and optional index-build tuning.
    index: IndexType = IndexType.SCANN
    num_leaves: int | None
    quantizer: str | None
    enable_pca: str | None
    max_num_levels: int | None
    num_leaves_to_search: int | None
    max_top_neighbors_buffer_size: int | None
    pre_reordering_num_neighbors: int | None
    num_search_threads: int | None
    max_num_prefetch_datasets: int | None
    maintenance_work_mem: Optional[str] = None
    max_parallel_workers: Optional[int] = None

    def index_param(self) -> AlloyDBIndexParam:
        """Return the metric, index type, WITH options and build settings."""
        metric = self.parse_metric()
        index_type = self.index.value
        with_options = self._optionally_build_with_options(
            {
                "num_leaves": self.num_leaves,
                "max_num_levels": self.max_num_levels,
                "quantizer": self.quantizer,
            }
        )
        return dict(
            metric=metric,
            index_type=index_type,
            index_creation_with_options=with_options,
            maintenance_work_mem=self.maintenance_work_mem,
            max_parallel_workers=self.max_parallel_workers,
            enable_pca=self.enable_pca,
        )

    def search_param(self) -> AlloyDBSearchParam:
        """Return the distance operator to use in search queries."""
        return dict(metric_fun_op=self.parse_metric_fun_op())

    def session_param(self) -> AlloyDBSessionCommands:
        """Return the ``scann.*`` settings to apply to each client session."""
        tunables = (
            "num_leaves_to_search",
            "max_top_neighbors_buffer_size",
            "pre_reordering_num_neighbors",
            "num_search_threads",
            "max_num_prefetch_datasets",
        )
        # Each setting is named "scann.<field name>".
        session_parameters = {f"scann.{field}": getattr(self, field) for field in tunables}
        return dict(
            session_options=self._optionally_build_set_options(session_parameters)
        )
164
+
165
+
166
# Maps each supported index type to the case-config class that describes it.
_alloydb_case_config = {
    IndexType.SCANN: AlloyDBScaNNConfig,
}