vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. vectordb_bench/backend/clients/__init__.py +65 -1
  2. vectordb_bench/backend/clients/api.py +2 -1
  3. vectordb_bench/backend/clients/chroma/chroma.py +2 -2
  4. vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
  5. vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
  6. vectordb_bench/backend/clients/clickhouse/config.py +60 -0
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
  8. vectordb_bench/backend/clients/mariadb/cli.py +122 -0
  9. vectordb_bench/backend/clients/mariadb/config.py +73 -0
  10. vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
  11. vectordb_bench/backend/clients/milvus/cli.py +32 -0
  12. vectordb_bench/backend/clients/milvus/config.py +32 -0
  13. vectordb_bench/backend/clients/milvus/milvus.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/cli.py +14 -3
  15. vectordb_bench/backend/clients/pgvector/config.py +22 -5
  16. vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
  17. vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
  18. vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
  19. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
  20. vectordb_bench/backend/clients/tidb/cli.py +98 -0
  21. vectordb_bench/backend/clients/tidb/config.py +46 -0
  22. vectordb_bench/backend/clients/tidb/tidb.py +233 -0
  23. vectordb_bench/backend/clients/vespa/cli.py +47 -0
  24. vectordb_bench/backend/clients/vespa/config.py +51 -0
  25. vectordb_bench/backend/clients/vespa/util.py +15 -0
  26. vectordb_bench/backend/clients/vespa/vespa.py +249 -0
  27. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
  28. vectordb_bench/cli/cli.py +20 -17
  29. vectordb_bench/cli/vectordbbench.py +8 -0
  30. vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
  31. vectordb_bench/frontend/config/styles.py +4 -0
  32. vectordb_bench/models.py +8 -6
  33. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
  34. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
  35. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
  36. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
  37. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
  38. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/__init__.py

@@ -38,9 +38,13 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
     MongoDB = "MongoDB"
+    TiDB = "TiDB"
+    Clickhouse = "Clickhouse"
+    Vespa = "Vespa"

     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901
@@ -115,6 +119,11 @@ class DB(Enum):

             return AWSOpenSearch

+        if self == DB.Clickhouse:
+            from .clickhouse.clickhouse import Clickhouse
+
+            return Clickhouse
+
         if self == DB.AlloyDB:
             from .alloydb.alloydb import AlloyDB

@@ -135,11 +144,26 @@ class DB(Enum):

             return MongoDB

+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test

             return Test

+        if self == DB.Vespa:
+            from .vespa.vespa import Vespa
+
+            return Vespa
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)

@@ -216,6 +240,11 @@ class DB(Enum):

             return AWSOpenSearchConfig

+        if self == DB.Clickhouse:
+            from .clickhouse.config import ClickhouseConfig
+
+            return ClickhouseConfig
+
         if self == DB.AlloyDB:
             from .alloydb.config import AlloyDBConfig

@@ -236,15 +265,30 @@ class DB(Enum):

             return MongoDBConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig

             return TestConfig

+        if self == DB.Vespa:
+            from .vespa.config import VespaConfig
+
+            return VespaConfig
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)

-    def case_config_cls(  # noqa: PLR0911
+    def case_config_cls(  # noqa: C901, PLR0911, PLR0912
         self,
         index_type: IndexType | None = None,
     ) -> type[DBCaseConfig]:
@@ -288,6 +332,11 @@ class DB(Enum):

             return AWSOpenSearchIndexConfig

+        if self == DB.Clickhouse:
+            from .clickhouse.config import ClickhouseHNSWConfig
+
+            return ClickhouseHNSWConfig
+
         if self == DB.PgVectorScale:
             from .pgvectorscale.config import _pgvectorscale_case_config

@@ -318,6 +367,21 @@ class DB(Enum):

             return MongoDBIndexConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
+        if self == DB.Vespa:
+            from .vespa.config import VespaHNSWConfig
+
+            return VespaHNSWConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig

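For orientation: the new backends (MariaDB, TiDB, Clickhouse, Vespa) are registered as plain enum members, and their classes are resolved lazily so a driver is only imported when its backend is actually selected. A minimal sketch of that dispatch for the Clickhouse member, assuming vectordb-bench 0.0.24 plus the clickhouse-connect driver are installed and that config_cls, like init_cls, is exposed as a property:

    # Sketch only: lazy class resolution through the DB enum added above.
    from vectordb_bench.backend.clients import DB
    from vectordb_bench.backend.clients.api import IndexType

    client_cls = DB.Clickhouse.init_cls        # imports .clickhouse.clickhouse on first access
    config_cls = DB.Clickhouse.config_cls      # -> ClickhouseConfig
    case_cfg_cls = DB.Clickhouse.case_config_cls(IndexType.HNSW)  # -> ClickhouseHNSWConfig
    print(client_cls.__name__, config_cls.__name__, case_cfg_cls.__name__)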
vectordb_bench/backend/clients/api.py

@@ -25,6 +25,7 @@ class IndexType(str, Enum):
     ES_HNSW = "hnsw"
     ES_IVFFlat = "ivfflat"
     GPU_IVF_FLAT = "GPU_IVF_FLAT"
+    GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE"
     GPU_IVF_PQ = "GPU_IVF_PQ"
     GPU_CAGRA = "GPU_CAGRA"
     SCANN = "scann"
@@ -161,7 +162,7 @@ class VectorDB(ABC):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
         """Insert the embeddings to the vector database. The default number of embeddings for
         each insert_embeddings is 5000.

vectordb_bench/backend/clients/chroma/chroma.py

@@ -65,7 +65,7 @@ class ChromaClient(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
         """Insert embeddings into the database.

         Args:
@@ -74,7 +74,7 @@ class ChromaClient(VectorDB):
             kwargs: other arguments

         Returns:
             tuple[int, Exception]: number of embeddings inserted and exception if any
         """
         ids = [str(i) for i in metadata]
         metadata = [{"id": int(i)} for i in metadata]
vectordb_bench/backend/clients/clickhouse/cli.py (new file)

@@ -0,0 +1,66 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+from .config import ClickhouseHNSWConfig
+
+
+class ClickhouseTypedDict(TypedDict):
+    password: Annotated[str, click.option("--password", type=str, help="DB password")]
+    host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
+    port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")]
+    user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")]
+    ssl: Annotated[
+        bool,
+        click.option(
+            "--ssl/--no-ssl",
+            is_flag=True,
+            show_default=True,
+            default=True,
+            help="Enable or disable SSL for Clickhouse",
+        ),
+    ]
+    ssl_ca_certs: Annotated[
+        str,
+        click.option(
+            "--ssl-ca-certs",
+            show_default=True,
+            help="Path to certificate authority file to use for SSL",
+        ),
+    ]
+
+
+class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2): ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict)
+def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]):
+    from .config import ClickhouseConfig
+
+    run(
+        db=DB.Clickhouse,
+        db_config=ClickhouseConfig(
+            db_label=parameters["db_label"],
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
+            host=parameters["host"],
+            port=parameters["port"],
+            ssl=parameters["ssl"],
+            ssl_ca_certs=parameters["ssl_ca_certs"],
+        ),
+        db_case_config=ClickhouseHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
+        ),
+        **parameters,
+    )
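For reference, the new Clickhouse CLI command forwards its options into a ClickhouseConfig and a ClickhouseHNSWConfig before calling run(). Roughly the same objects can be built programmatically; the connection values below are placeholders:

    # Roughly what the Clickhouse command constructs before calling run(); placeholders only.
    from pydantic import SecretStr

    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseConfig, ClickhouseHNSWConfig

    db_config = ClickhouseConfig(
        host="127.0.0.1",
        port=8123,
        user_name="clickhouse",
        password=SecretStr("example-password"),
    )
    case_config = ClickhouseHNSWConfig(M=16, efConstruction=200, ef=100)

    print(db_config.to_dict())
    # {'host': '127.0.0.1', 'port': 8123, 'dbname': 'default', 'user': 'clickhouse', 'password': 'example-password'}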
vectordb_bench/backend/clients/clickhouse/clickhouse.py (new file)

@@ -0,0 +1,156 @@
+"""Wrapper around the Clickhouse vector database over VectorDB"""
+
+import logging
+from contextlib import contextmanager
+from typing import Any
+
+import clickhouse_connect
+
+from ..api import DBCaseConfig, VectorDB
+
+log = logging.getLogger(__name__)
+
+
+class Clickhouse(VectorDB):
+    """Use SQLAlchemy instructions"""
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        collection_name: str = "CHVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "clickhouse_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        self.conn = clickhouse_connect.get_client(
+            host=self.db_config["host"],
+            port=self.db_config["port"],
+            username=self.db_config["user"],
+            password=self.db_config["password"],
+            database=self.db_config["dbname"],
+        )
+
+        if drop_old:
+            log.info(f"Clickhouse client drop table : {self.table_name}")
+            self._drop_table()
+            self._create_table(dim)
+
+        self.conn.close()
+        self.conn = None
+
+    @contextmanager
+    def init(self):
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+
+        self.conn = clickhouse_connect.get_client(
+            host=self.db_config["host"],
+            port=self.db_config["port"],
+            username=self.db_config["user"],
+            password=self.db_config["password"],
+            database=self.db_config["dbname"],
+        )
+
+        try:
+            yield
+        finally:
+            self.conn.close()
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+
+        self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["dbname"]}.{self.table_name}')
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # create table
+            self.conn.command(
+                f'CREATE TABLE IF NOT EXISTS {self.db_config["dbname"]}.{self.table_name} \
+                (id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;'
+            )
+
+        except Exception as e:
+            log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
+            raise e from None
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self, data_size: int | None = None):
+        pass
+
+    def ready_to_search(self):
+        pass
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception]:
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # do not iterate for bulk insert
+            items = [metadata, embeddings]
+
+            self.conn.insert(
+                table=self.table_name,
+                data=items,
+                column_names=["id", "embedding"],
+                column_type_names=["UInt32", "Array(Float64)"],
+                column_oriented=True,
+            )
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+
+        index_param = self.case_config.index_param()  # noqa: F841
+        search_param = self.case_config.search_param()
+
+        if filters:
+            gt = filters.get("id")
+            filter_sql = (
+                f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '  # noqa: S608
+                f'FROM {self.db_config["dbname"]}.{self.table_name} '
+                f"WHERE id > {gt} "
+                f"ORDER BY score LIMIT {k};"
+            )
+            result = self.conn.query(filter_sql).result_rows
+            return [int(row[0]) for row in result]
+        else:  # noqa: RET505
+            select_sql = (
+                f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '  # noqa: S608
+                f'FROM {self.db_config["dbname"]}.{self.table_name} '
+                f"ORDER BY score LIMIT {k};"
+            )
+            result = self.conn.query(select_sql).result_rows
+            return [int(row[0]) for row in result]
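Note that this client creates a plain MergeTree table and answers searches with a brute-force ORDER BY distance LIMIT k scan; the HNSW parameters in the case config are not turned into a vector index here. A minimal usage sketch against a local server, with placeholder credentials and metric_type set explicitly because the benchmark runner normally supplies it:

    # Sketch only: assumes a reachable ClickHouse server and the clickhouse-connect driver.
    from pydantic import SecretStr

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.clickhouse.clickhouse import Clickhouse
    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseConfig, ClickhouseHNSWConfig

    client = Clickhouse(
        dim=4,
        db_config=ClickhouseConfig(host="127.0.0.1", password=SecretStr("example")).to_dict(),
        db_case_config=ClickhouseHNSWConfig(M=16, efConstruction=200, metric_type=MetricType.COSINE),
        drop_old=True,  # drops and recreates the benchmark table
    )

    with client.init():
        count, err = client.insert_embeddings([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], metadata=[1, 2])
        ids = client.search_embedding([0.1, 0.2, 0.3, 0.4], k=1)  # ORDER BY cosineDistance(...) LIMIT 1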
vectordb_bench/backend/clients/clickhouse/config.py (new file)

@@ -0,0 +1,60 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class ClickhouseConfig(DBConfig):
+    user_name: str = "clickhouse"
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 8123
+    db_name: str = "default"
+
+    def to_dict(self) -> dict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class ClickhouseIndexConfig(BaseModel):
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+    def parse_metric_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "L2Distance"
+        if self.metric_type == MetricType.COSINE:
+            return "cosineDistance"
+        msg = f"Not Support for {self.metric_type}"
+        raise RuntimeError(msg)
+        return None
+
+
+class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig):
+    M: int | None
+    efConstruction: int | None
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric_str(),
+            "index_type": self.index.value,
+            "params": {"M": self.M, "efConstruction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric_str(),
+            "params": {"ef": self.ef},
+        }
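The case config maps the benchmark MetricType onto the ClickHouse SQL distance functions (L2Distance, cosineDistance) that the client interpolates into its queries. An illustrative look at the parameter dicts it produces, with example values:

    # Illustrative only: values are examples, metric_type is normally set by the runner.
    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseHNSWConfig

    cfg = ClickhouseHNSWConfig(metric_type=MetricType.COSINE, M=16, efConstruction=200, ef=100)
    print(cfg.index_param())   # expected roughly: {'metric_type': 'cosineDistance', 'index_type': 'HNSW', 'params': {'M': 16, 'efConstruction': 200}}
    print(cfg.search_param())  # expected roughly: {'metric_type': 'cosineDistance', 'params': {'ef': 100}}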
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py

@@ -81,7 +81,7 @@ class ElasticCloud(VectorDB):
         embeddings: Iterable[list[float]],
         metadata: list[int],
         **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"

vectordb_bench/backend/clients/mariadb/cli.py (new file)

@@ -0,0 +1,122 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+
+    storage_engine: Annotated[
+        int,
+        click.option(
+            "--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    m: Annotated[
+        int | None,
+        click.option(
+            "--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        int | None,
+        click.option(
+            "--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_min_limit",
+            required=False,
+        ),
+    ]
+
+    max_cache_size: Annotated[
+        int | None,
+        click.option(
+            "--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
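As with the Clickhouse command, MariaDBHNSW packs its options into the config objects defined in the next file before handing them to run(); the equivalent programmatic setup looks roughly like this (credentials and tuning values are placeholders):

    # Roughly what MariaDBHNSW constructs before calling run(); placeholders only.
    from pydantic import SecretStr

    from vectordb_bench.backend.clients.mariadb.config import MariaDBConfig, MariaDBHNSWConfig

    db_config = MariaDBConfig(
        user_name="root",
        password=SecretStr("example-password"),
        host="127.0.0.1",
        port=3306,
    )
    case_config = MariaDBHNSWConfig(M=16, ef_search=100, storage_engine="InnoDB", max_cache_size=None)

    print(db_config.to_dict())
    # {'host': '127.0.0.1', 'port': 3306, 'user': 'root', 'password': 'example-password'}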
vectordb_bench/backend/clients/mariadb/config.py (new file)

@@ -0,0 +1,73 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in mariadb connection string,
+    so the names must match exactly mariadb API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        msg = f"Metric type {self.metric_type} is not supported!"
+        raise ValueError(msg)
+
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
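The HNSW case config is reached through the _mariadb_case_config mapping via DB.MariaDB.case_config_cls(...), and only the L2 and COSINE metrics are accepted by parse_metric(). An illustrative lookup, with metric_type set explicitly because the benchmark runner normally supplies it:

    # Illustrative only: MariaDB case-config lookup and the parameter dicts it yields.
    from vectordb_bench.backend.clients import DB
    from vectordb_bench.backend.clients.api import IndexType, MetricType

    cfg_cls = DB.MariaDB.case_config_cls(IndexType.HNSW)  # -> MariaDBHNSWConfig via _mariadb_case_config
    cfg = cfg_cls(metric_type=MetricType.L2, M=16, ef_search=100, max_cache_size=None)

    print(cfg.index_param())   # expected roughly: {'storage_engine': 'InnoDB', 'metric_type': 'euclidean', 'index_type': 'HNSW', 'M': 16, 'max_cache_size': None}
    print(cfg.search_param())  # expected roughly: {'metric_type': 'euclidean', 'ef_search': 100}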