vectordb-bench 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +33 -1
- vectordb_bench/backend/clients/api.py +1 -1
- vectordb_bench/backend/clients/chroma/chroma.py +2 -2
- vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
- vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
- vectordb_bench/backend/clients/clickhouse/config.py +60 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
- vectordb_bench/backend/clients/mariadb/cli.py +60 -45
- vectordb_bench/backend/clients/mariadb/config.py +11 -9
- vectordb_bench/backend/clients/mariadb/mariadb.py +52 -58
- vectordb_bench/backend/clients/milvus/cli.py +1 -19
- vectordb_bench/backend/clients/milvus/config.py +0 -1
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/pgvector/cli.py +1 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
- vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
- vectordb_bench/backend/clients/tidb/config.py +6 -9
- vectordb_bench/backend/clients/tidb/tidb.py +17 -18
- vectordb_bench/backend/clients/vespa/cli.py +47 -0
- vectordb_bench/backend/clients/vespa/config.py +51 -0
- vectordb_bench/backend/clients/vespa/util.py +15 -0
- vectordb_bench/backend/clients/vespa/vespa.py +249 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/cli/cli.py +21 -17
- vectordb_bench/cli/vectordbbench.py +5 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +58 -7
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/models.py +5 -6
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/METADATA +11 -3
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/RECORD +35 -28
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info/licenses}/LICENSE +0 -0
- {vectordb_bench-0.0.23.dist-info → vectordb_bench-0.0.25.dist-info}/top_level.txt +0 -0
@@ -43,6 +43,8 @@ class DB(Enum):
|
|
43
43
|
AliyunOpenSearch = "AliyunOpenSearch"
|
44
44
|
MongoDB = "MongoDB"
|
45
45
|
TiDB = "TiDB"
|
46
|
+
Clickhouse = "Clickhouse"
|
47
|
+
Vespa = "Vespa"
|
46
48
|
|
47
49
|
@property
|
48
50
|
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
|
@@ -117,6 +119,11 @@ class DB(Enum):
|
|
117
119
|
|
118
120
|
return AWSOpenSearch
|
119
121
|
|
122
|
+
if self == DB.Clickhouse:
|
123
|
+
from .clickhouse.clickhouse import Clickhouse
|
124
|
+
|
125
|
+
return Clickhouse
|
126
|
+
|
120
127
|
if self == DB.AlloyDB:
|
121
128
|
from .alloydb.alloydb import AlloyDB
|
122
129
|
|
@@ -152,6 +159,11 @@ class DB(Enum):
|
|
152
159
|
|
153
160
|
return Test
|
154
161
|
|
162
|
+
if self == DB.Vespa:
|
163
|
+
from .vespa.vespa import Vespa
|
164
|
+
|
165
|
+
return Vespa
|
166
|
+
|
155
167
|
msg = f"Unknown DB: {self.name}"
|
156
168
|
raise ValueError(msg)
|
157
169
|
|
@@ -228,6 +240,11 @@ class DB(Enum):
|
|
228
240
|
|
229
241
|
return AWSOpenSearchConfig
|
230
242
|
|
243
|
+
if self == DB.Clickhouse:
|
244
|
+
from .clickhouse.config import ClickhouseConfig
|
245
|
+
|
246
|
+
return ClickhouseConfig
|
247
|
+
|
231
248
|
if self == DB.AlloyDB:
|
232
249
|
from .alloydb.config import AlloyDBConfig
|
233
250
|
|
@@ -263,10 +280,15 @@ class DB(Enum):
|
|
263
280
|
|
264
281
|
return TestConfig
|
265
282
|
|
283
|
+
if self == DB.Vespa:
|
284
|
+
from .vespa.config import VespaConfig
|
285
|
+
|
286
|
+
return VespaConfig
|
287
|
+
|
266
288
|
msg = f"Unknown DB: {self.name}"
|
267
289
|
raise ValueError(msg)
|
268
290
|
|
269
|
-
def case_config_cls( # noqa: PLR0911
|
291
|
+
def case_config_cls( # noqa: C901, PLR0911, PLR0912
|
270
292
|
self,
|
271
293
|
index_type: IndexType | None = None,
|
272
294
|
) -> type[DBCaseConfig]:
|
@@ -310,6 +332,11 @@ class DB(Enum):
|
|
310
332
|
|
311
333
|
return AWSOpenSearchIndexConfig
|
312
334
|
|
335
|
+
if self == DB.Clickhouse:
|
336
|
+
from .clickhouse.config import ClickhouseHNSWConfig
|
337
|
+
|
338
|
+
return ClickhouseHNSWConfig
|
339
|
+
|
313
340
|
if self == DB.PgVectorScale:
|
314
341
|
from .pgvectorscale.config import _pgvectorscale_case_config
|
315
342
|
|
@@ -350,6 +377,11 @@ class DB(Enum):
|
|
350
377
|
|
351
378
|
return TiDBIndexConfig
|
352
379
|
|
380
|
+
if self == DB.Vespa:
|
381
|
+
from .vespa.config import VespaHNSWConfig
|
382
|
+
|
383
|
+
return VespaHNSWConfig
|
384
|
+
|
353
385
|
# DB.Pinecone, DB.Chroma, DB.Redis
|
354
386
|
return EmptyDBCaseConfig
|
355
387
|
|
@@ -162,7 +162,7 @@ class VectorDB(ABC):
|
|
162
162
|
embeddings: list[list[float]],
|
163
163
|
metadata: list[int],
|
164
164
|
**kwargs,
|
165
|
-
) ->
|
165
|
+
) -> tuple[int, Exception]:
|
166
166
|
"""Insert the embeddings to the vector database. The default number of embeddings for
|
167
167
|
each insert_embeddings is 5000.
|
168
168
|
|
@@ -65,7 +65,7 @@ class ChromaClient(VectorDB):
|
|
65
65
|
embeddings: list[list[float]],
|
66
66
|
metadata: list[int],
|
67
67
|
**kwargs: Any,
|
68
|
-
) ->
|
68
|
+
) -> tuple[int, Exception]:
|
69
69
|
"""Insert embeddings into the database.
|
70
70
|
|
71
71
|
Args:
|
@@ -74,7 +74,7 @@ class ChromaClient(VectorDB):
|
|
74
74
|
kwargs: other arguments
|
75
75
|
|
76
76
|
Returns:
|
77
|
-
|
77
|
+
tuple[int, Exception]: number of embeddings inserted and exception if any
|
78
78
|
"""
|
79
79
|
ids = [str(i) for i in metadata]
|
80
80
|
metadata = [{"id": int(i)} for i in metadata]
|
@@ -0,0 +1,66 @@
|
|
1
|
+
from typing import Annotated, TypedDict, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from ....cli.cli import (
|
7
|
+
CommonTypedDict,
|
8
|
+
HNSWFlavor2,
|
9
|
+
cli,
|
10
|
+
click_parameter_decorators_from_typed_dict,
|
11
|
+
run,
|
12
|
+
)
|
13
|
+
from .. import DB
|
14
|
+
from .config import ClickhouseHNSWConfig
|
15
|
+
|
16
|
+
|
17
|
+
class ClickhouseTypedDict(TypedDict):
    """Connection options for the ``Clickhouse`` CLI command."""

    # No default: the value is ``None`` when ``--password`` is omitted.
    password: Annotated[str | None, click.option("--password", type=str, help="DB password")]
    host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
    port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")]
    # Parsed with ``type=str``, so the annotation is ``str`` (it was ``int``).
    user: Annotated[str, click.option("--user", type=str, default="clickhouse", help="DB user")]
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            show_default=True,
            default=True,
            help="Enable or disable SSL for Clickhouse",
        ),
    ]
    # No default: the value is ``None`` when ``--ssl-ca-certs`` is omitted.
    ssl_ca_certs: Annotated[
        str | None,
        click.option(
            "--ssl-ca-certs",
            show_default=True,
            help="Path to certificate authority file to use for SSL",
        ),
    ]
|
40
|
+
|
41
|
+
|
42
|
+
class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2):
    """Full option set of the ``Clickhouse`` CLI command.

    Combines the common benchmark options, the ClickHouse connection options and the
    HNSW tuning options from ``HNSWFlavor2`` (read by the command as ``m``,
    ``ef_construction`` and ``ef_runtime``).
    """
|
43
|
+
|
44
|
+
|
45
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict)
def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]):
    """Run the benchmark against ClickHouse using the HNSW case config."""
    from .config import ClickhouseConfig

    run(
        db=DB.Clickhouse,
        db_config=ClickhouseConfig(
            db_label=parameters["db_label"],
            # Forward ``--user``: it used to be parsed but dropped, so the config
            # always fell back to its default user name.
            user_name=parameters["user"],
            password=SecretStr(parameters["password"]) if parameters["password"] else None,
            host=parameters["host"],
            port=parameters["port"],
            # NOTE(review): ``ssl`` / ``ssl_ca_certs`` are not declared on
            # ClickhouseConfig as shown in config.py; whether pydantic ignores or
            # rejects them depends on DBConfig's extra-field policy — confirm.
            ssl=parameters["ssl"],
            ssl_ca_certs=parameters["ssl_ca_certs"],
        ),
        db_case_config=ClickhouseHNSWConfig(
            M=parameters["m"],
            efConstruction=parameters["ef_construction"],
            ef=parameters["ef_runtime"],
        ),
        **parameters,
    )
|
@@ -0,0 +1,156 @@
|
|
1
|
+
"""Wrapper around the Clickhouse vector database over VectorDB"""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from contextlib import contextmanager
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import clickhouse_connect
|
8
|
+
|
9
|
+
from ..api import DBCaseConfig, VectorDB
|
10
|
+
|
11
|
+
log = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
class Clickhouse(VectorDB):
    """VectorDB client for ClickHouse built on the ``clickhouse_connect`` driver.

    Rows live in the MergeTree table ``<dbname>.<collection_name>``, which has an
    ``id UInt32`` column and an ``embedding Array(Float64)`` column. This client
    creates no vector index. Each search evaluates the configured distance function
    over the table and returns the ids of the ``k`` smallest scores.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "CHVectorCollection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and, if requested, recreate the table.

        Args:
            dim: vector dimension (stored; the column type itself has no fixed length).
            db_config: dict with ``host``, ``port``, ``user``, ``password`` and
                ``dbname`` keys, as produced by ``ClickhouseConfig.to_dict``.
            db_case_config: case config whose ``search_param()`` supplies the
                distance function name under ``"metric_type"``.
            collection_name: name of the table holding the vectors.
            drop_old: drop and recreate the table before the run.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.table_name = collection_name
        self.dim = dim

        self._index_name = "clickhouse_index"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # Connection used only for table setup; ``init()`` opens the one used for
        # inserts and searches.
        self.conn = self._connect()
        try:
            if drop_old:
                log.info(f"Clickhouse client drop table : {self.table_name}")
                self._drop_table()
                self._create_table(dim)
        finally:
            # Close even when setup fails, so the connection is not leaked.
            self.conn.close()
            self.conn = None

    def _connect(self):
        """Create a new ``clickhouse_connect`` client from ``self.db_config``."""
        return clickhouse_connect.get_client(
            host=self.db_config["host"],
            port=self.db_config["port"],
            username=self.db_config["user"],
            password=self.db_config["password"],
            database=self.db_config["dbname"],
        )

    def _table_path(self) -> str:
        """Return the fully-qualified ``<dbname>.<table>`` name used in SQL."""
        return f'{self.db_config["dbname"]}.{self.table_name}'

    @contextmanager
    def init(self):
        """Open a connection that lives for the duration of the ``with`` block.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        self.conn = self._connect()
        try:
            yield
        finally:
            self.conn.close()
            self.conn = None

    def _drop_table(self):
        """Drop the vector table if it exists."""
        assert self.conn is not None, "Connection is not initialized"

        self.conn.command(f"DROP TABLE IF EXISTS {self._table_path()}")

    def _create_table(self, dim: int):
        """Create the vector table if it does not exist.

        ``dim`` is accepted for interface symmetry; the ``Array(Float64)`` column
        does not encode it.

        Raises:
            Exception: whatever the driver raised, after logging a warning.
        """
        assert self.conn is not None, "Connection is not initialized"

        try:
            self.conn.command(
                f"CREATE TABLE IF NOT EXISTS {self._table_path()} "
                "(id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;"
            )
        except Exception as e:
            log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
            raise

    def ready_to_load(self):
        pass

    def optimize(self, data_size: int | None = None):
        pass

    def ready_to_search(self):
        pass

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception]:
        """Bulk-insert one batch of vectors with their integer ids.

        Returns:
            ``(len(metadata), None)`` on success, or ``(0, exception)`` on failure.
        """
        assert self.conn is not None, "Connection is not initialized"

        try:
            # Column-oriented payload: one list per column, sent in a single insert.
            self.conn.insert(
                table=self.table_name,
                data=[metadata, embeddings],
                column_names=["id", "embedding"],
                column_type_names=["UInt32", "Array(Float64)"],
                column_oriented=True,
            )
        except Exception as e:
            log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
            return 0, e
        return len(metadata), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` rows with the smallest distance to ``query``.

        If ``filters`` is given, only rows with ``id > filters["id"]`` are considered.
        ``timeout`` is accepted for interface compatibility and is not used.
        """
        assert self.conn is not None, "Connection is not initialized"

        search_param = self.case_config.search_param()
        where_clause = f"WHERE id > {filters.get('id')} " if filters else ""
        # The query vector and k come from the benchmark itself, not from users.
        sql = (
            f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '  # noqa: S608
            f"FROM {self._table_path()} "
            f"{where_clause}"
            f"ORDER BY score LIMIT {k};"
        )
        result = self.conn.query(sql).result_rows
        return [int(row[0]) for row in result]
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
|
+
|
3
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
4
|
+
|
5
|
+
|
6
|
+
class ClickhouseConfig(DBConfig):
    """Connection settings for a ClickHouse server."""

    user_name: str = "clickhouse"
    # Optional: the CLI passes ``None`` when ``--password`` is omitted, which a
    # required ``SecretStr`` field would reject.
    password: SecretStr | None = None
    host: str = "localhost"
    port: int = 8123
    db_name: str = "default"

    def to_dict(self) -> dict:
        """Return the connection kwargs read by the Clickhouse client.

        Keys are ``host``, ``port``, ``dbname``, ``user`` and ``password``. A missing
        password is sent as an empty string.
        """
        pwd_str = self.password.get_secret_value() if self.password is not None else ""
        return {
            "host": self.host,
            "port": self.port,
            "dbname": self.db_name,
            "user": self.user_name,
            "password": pwd_str,
        }
|
22
|
+
|
23
|
+
|
24
|
+
class ClickhouseIndexConfig(BaseModel):
    """Metric handling shared by the ClickHouse case configs."""

    metric_type: MetricType | None = None

    def parse_metric(self) -> str:
        """Return the metric enum's value, or ``""`` when no metric is set."""
        if not self.metric_type:
            return ""
        return self.metric_type.value

    def parse_metric_str(self) -> str:
        """Return the ClickHouse distance function name for the configured metric.

        Raises:
            RuntimeError: if the metric is neither L2 nor COSINE.
        """
        if self.metric_type == MetricType.L2:
            return "L2Distance"
        if self.metric_type == MetricType.COSINE:
            return "cosineDistance"
        msg = f"Not Support for {self.metric_type}"
        raise RuntimeError(msg)
|
41
|
+
|
42
|
+
|
43
|
+
class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig):
    """HNSW case config for ClickHouse.

    Build parameters are ``M`` and ``efConstruction``; the search parameter is ``ef``.
    """

    M: int | None
    efConstruction: int | None
    ef: int | None = None
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the distance function name, index type and build parameters."""
        return {
            "metric_type": self.parse_metric_str(),
            "index_type": self.index.value,
            "params": {"M": self.M, "efConstruction": self.efConstruction},
        }

    def search_param(self) -> dict:
        """Return the distance function name and search parameters.

        ``Clickhouse.search_embedding`` reads ``search_param()["metric_type"]``. This
        key previously contained a stray non-ASCII character, which made every
        search fail with ``KeyError``.
        """
        return {
            "metric_type": self.parse_metric_str(),
            "params": {"ef": self.ef},
        }
|
@@ -81,7 +81,7 @@ class ElasticCloud(VectorDB):
|
|
81
81
|
embeddings: Iterable[list[float]],
|
82
82
|
metadata: list[int],
|
83
83
|
**kwargs,
|
84
|
-
) ->
|
84
|
+
) -> tuple[int, Exception]:
|
85
85
|
"""Insert the embeddings to the elasticsearch."""
|
86
86
|
assert self.client is not None, "should self.init() first"
|
87
87
|
|
@@ -1,83 +1,98 @@
|
|
1
|
-
from typing import Annotated,
|
1
|
+
from typing import Annotated, Unpack
|
2
2
|
|
3
3
|
import click
|
4
|
-
import os
|
5
4
|
from pydantic import SecretStr
|
6
5
|
|
6
|
+
from vectordb_bench.backend.clients import DB
|
7
|
+
|
7
8
|
from ....cli.cli import (
|
8
9
|
CommonTypedDict,
|
9
|
-
HNSWFlavor1,
|
10
10
|
cli,
|
11
11
|
click_parameter_decorators_from_typed_dict,
|
12
12
|
run,
|
13
13
|
)
|
14
|
-
from vectordb_bench.backend.clients import DB
|
15
14
|
|
16
15
|
|
17
16
|
class MariaDBTypedDict(CommonTypedDict):
|
18
17
|
user_name: Annotated[
|
19
|
-
str,
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
18
|
+
str,
|
19
|
+
click.option(
|
20
|
+
"--username",
|
21
|
+
type=str,
|
22
|
+
help="Username",
|
23
|
+
required=True,
|
24
|
+
),
|
24
25
|
]
|
25
26
|
password: Annotated[
|
26
|
-
str,
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
27
|
+
str,
|
28
|
+
click.option(
|
29
|
+
"--password",
|
30
|
+
type=str,
|
31
|
+
help="Password",
|
32
|
+
required=True,
|
33
|
+
),
|
31
34
|
]
|
32
35
|
|
33
36
|
host: Annotated[
|
34
|
-
str,
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
37
|
+
str,
|
38
|
+
click.option(
|
39
|
+
"--host",
|
40
|
+
type=str,
|
41
|
+
help="Db host",
|
42
|
+
default="127.0.0.1",
|
43
|
+
),
|
39
44
|
]
|
40
45
|
|
41
46
|
port: Annotated[
|
42
|
-
int,
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
+
int,
|
48
|
+
click.option(
|
49
|
+
"--port",
|
50
|
+
type=int,
|
51
|
+
default=3306,
|
52
|
+
help="Db Port",
|
53
|
+
),
|
47
54
|
]
|
48
55
|
|
49
56
|
storage_engine: Annotated[
|
50
|
-
int,
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
57
|
+
int,
|
58
|
+
click.option(
|
59
|
+
"--storage-engine",
|
60
|
+
type=click.Choice(["InnoDB", "MyISAM"]),
|
61
|
+
help="DB storage engine",
|
62
|
+
required=True,
|
63
|
+
),
|
55
64
|
]
|
56
65
|
|
66
|
+
|
57
67
|
class MariaDBHNSWTypedDict(MariaDBTypedDict):
|
58
|
-
...
|
59
68
|
m: Annotated[
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
69
|
+
int | None,
|
70
|
+
click.option(
|
71
|
+
"--m",
|
72
|
+
type=int,
|
73
|
+
help="M parameter in MHNSW vector indexing",
|
74
|
+
required=False,
|
75
|
+
),
|
65
76
|
]
|
66
77
|
|
67
78
|
ef_search: Annotated[
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
79
|
+
int | None,
|
80
|
+
click.option(
|
81
|
+
"--ef-search",
|
82
|
+
type=int,
|
83
|
+
help="MariaDB system variable mhnsw_min_limit",
|
84
|
+
required=False,
|
85
|
+
),
|
73
86
|
]
|
74
87
|
|
75
88
|
max_cache_size: Annotated[
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
89
|
+
int | None,
|
90
|
+
click.option(
|
91
|
+
"--max-cache-size",
|
92
|
+
type=int,
|
93
|
+
help="MariaDB system variable mhnsw_max_cache_size",
|
94
|
+
required=False,
|
95
|
+
),
|
81
96
|
]
|
82
97
|
|
83
98
|
|
@@ -1,10 +1,13 @@
|
|
1
|
-
from pydantic import SecretStr, BaseModel
|
2
1
|
from typing import TypedDict
|
3
|
-
|
2
|
+
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
|
5
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
6
|
+
|
4
7
|
|
5
8
|
class MariaDBConfigDict(TypedDict):
|
6
9
|
"""These keys will be directly used as kwargs in mariadb connection string,
|
7
|
-
|
10
|
+
so the names must match exactly mariadb API"""
|
8
11
|
|
9
12
|
user: str
|
10
13
|
password: str
|
@@ -36,10 +39,11 @@ class MariaDBIndexConfig(BaseModel):
|
|
36
39
|
def parse_metric(self) -> str:
|
37
40
|
if self.metric_type == MetricType.L2:
|
38
41
|
return "euclidean"
|
39
|
-
|
42
|
+
if self.metric_type == MetricType.COSINE:
|
40
43
|
return "cosine"
|
41
|
-
|
42
|
-
|
44
|
+
msg = f"Metric type {self.metric_type} is not supported!"
|
45
|
+
raise ValueError(msg)
|
46
|
+
|
43
47
|
|
44
48
|
class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
|
45
49
|
M: int | None
|
@@ -65,7 +69,5 @@ class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
|
|
65
69
|
|
66
70
|
|
67
71
|
_mariadb_case_config = {
|
68
|
-
|
72
|
+
IndexType.HNSW: MariaDBHNSWConfig,
|
69
73
|
}
|
70
|
-
|
71
|
-
|