vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +65 -1
- vectordb_bench/backend/clients/api.py +2 -1
- vectordb_bench/backend/clients/chroma/chroma.py +2 -2
- vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
- vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
- vectordb_bench/backend/clients/clickhouse/config.py +60 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
- vectordb_bench/backend/clients/mariadb/cli.py +122 -0
- vectordb_bench/backend/clients/mariadb/config.py +73 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
- vectordb_bench/backend/clients/milvus/cli.py +32 -0
- vectordb_bench/backend/clients/milvus/config.py +32 -0
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/pgvector/cli.py +14 -3
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
- vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +46 -0
- vectordb_bench/backend/clients/tidb/tidb.py +233 -0
- vectordb_bench/backend/clients/vespa/cli.py +47 -0
- vectordb_bench/backend/clients/vespa/config.py +51 -0
- vectordb_bench/backend/clients/vespa/util.py +15 -0
- vectordb_bench/backend/clients/vespa/vespa.py +249 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/cli/cli.py +20 -17
- vectordb_bench/cli/vectordbbench.py +8 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
- vectordb_bench/frontend/config/styles.py +4 -0
- vectordb_bench/models.py +8 -6
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/__init__.py

@@ -38,9 +38,13 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
     MongoDB = "MongoDB"
+    TiDB = "TiDB"
+    Clickhouse = "Clickhouse"
+    Vespa = "Vespa"

     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901
@@ -115,6 +119,11 @@ class DB(Enum):

             return AWSOpenSearch

+        if self == DB.Clickhouse:
+            from .clickhouse.clickhouse import Clickhouse
+
+            return Clickhouse
+
         if self == DB.AlloyDB:
             from .alloydb.alloydb import AlloyDB

@@ -135,11 +144,26 @@ class DB(Enum):

             return MongoDB

+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test

             return Test

+        if self == DB.Vespa:
+            from .vespa.vespa import Vespa
+
+            return Vespa
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)

@@ -216,6 +240,11 @@ class DB(Enum):

             return AWSOpenSearchConfig

+        if self == DB.Clickhouse:
+            from .clickhouse.config import ClickhouseConfig
+
+            return ClickhouseConfig
+
         if self == DB.AlloyDB:
             from .alloydb.config import AlloyDBConfig

@@ -236,15 +265,30 @@ class DB(Enum):

             return MongoDBConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig

             return TestConfig

+        if self == DB.Vespa:
+            from .vespa.config import VespaConfig
+
+            return VespaConfig
+
         msg = f"Unknown DB: {self.name}"
         raise ValueError(msg)

-    def case_config_cls(  # noqa: PLR0911
+    def case_config_cls(  # noqa: C901, PLR0911, PLR0912
         self,
         index_type: IndexType | None = None,
     ) -> type[DBCaseConfig]:
@@ -288,6 +332,11 @@ class DB(Enum):

             return AWSOpenSearchIndexConfig

+        if self == DB.Clickhouse:
+            from .clickhouse.config import ClickhouseHNSWConfig
+
+            return ClickhouseHNSWConfig
+
         if self == DB.PgVectorScale:
             from .pgvectorscale.config import _pgvectorscale_case_config

@@ -318,6 +367,21 @@ class DB(Enum):

             return MongoDBIndexConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
+        if self == DB.Vespa:
+            from .vespa.config import VespaHNSWConfig
+
+            return VespaHNSWConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig

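Taken together, the new branches let the benchmark resolve the added backends lazily through the DB enum. A minimal usage sketch (not part of the diff; it assumes vectordb_bench 0.0.24 and the optional Clickhouse client dependency are installed):

    from vectordb_bench.backend.clients import DB

    db = DB.Clickhouse
    client_cls = db.init_cls             # lazily imports .clickhouse.clickhouse.Clickhouse
    case_cfg_cls = db.case_config_cls()  # resolves to ClickhouseHNSWConfig
    print(client_cls.__name__, case_cfg_cls.__name__)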
vectordb_bench/backend/clients/api.py

@@ -25,6 +25,7 @@ class IndexType(str, Enum):
     ES_HNSW = "hnsw"
     ES_IVFFlat = "ivfflat"
     GPU_IVF_FLAT = "GPU_IVF_FLAT"
+    GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE"
     GPU_IVF_PQ = "GPU_IVF_PQ"
     GPU_CAGRA = "GPU_CAGRA"
     SCANN = "scann"
@@ -161,7 +162,7 @@ class VectorDB(ABC):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs,
-    ) ->
+    ) -> tuple[int, Exception]:
         """Insert the embeddings to the vector database. The default number of embeddings for
         each insert_embeddings is 5000.

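The tightened annotation formalizes the convention that insert_embeddings reports how many rows were written plus an error instead of raising. A hedged sketch of a caller that relies on this (the helper itself is illustrative, not from the package):

    import logging

    log = logging.getLogger(__name__)

    def load_batch(client, embeddings: list[list[float]], ids: list[int]) -> int:
        # insert_embeddings returns (inserted_count, exception) per the new annotation;
        # the exception slot is None on success in the client implementations below.
        inserted, err = client.insert_embeddings(embeddings=embeddings, metadata=ids)
        if err is not None:
            log.warning("insert stopped after %d rows: %s", inserted, err)
        return inserted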
vectordb_bench/backend/clients/chroma/chroma.py

@@ -65,7 +65,7 @@ class ChromaClient(VectorDB):
         embeddings: list[list[float]],
         metadata: list[int],
         **kwargs: Any,
-    ) ->
+    ) -> tuple[int, Exception]:
         """Insert embeddings into the database.

         Args:
@@ -74,7 +74,7 @@ class ChromaClient(VectorDB):
             kwargs: other arguments

         Returns:
-
+            tuple[int, Exception]: number of embeddings inserted and exception if any
         """
         ids = [str(i) for i in metadata]
         metadata = [{"id": int(i)} for i in metadata]
vectordb_bench/backend/clients/clickhouse/cli.py

@@ -0,0 +1,66 @@
+from typing import Annotated, TypedDict, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor2,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+from .config import ClickhouseHNSWConfig
+
+
+class ClickhouseTypedDict(TypedDict):
+    password: Annotated[str, click.option("--password", type=str, help="DB password")]
+    host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)]
+    port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")]
+    user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")]
+    ssl: Annotated[
+        bool,
+        click.option(
+            "--ssl/--no-ssl",
+            is_flag=True,
+            show_default=True,
+            default=True,
+            help="Enable or disable SSL for Clickhouse",
+        ),
+    ]
+    ssl_ca_certs: Annotated[
+        str,
+        click.option(
+            "--ssl-ca-certs",
+            show_default=True,
+            help="Path to certificate authority file to use for SSL",
+        ),
+    ]
+
+
+class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2): ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict)
+def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]):
+    from .config import ClickhouseConfig
+
+    run(
+        db=DB.Clickhouse,
+        db_config=ClickhouseConfig(
+            db_label=parameters["db_label"],
+            password=SecretStr(parameters["password"]) if parameters["password"] else None,
+            host=parameters["host"],
+            port=parameters["port"],
+            ssl=parameters["ssl"],
+            ssl_ca_certs=parameters["ssl_ca_certs"],
+        ),
+        db_case_config=ClickhouseHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
+        ),
+        **parameters,
+    )
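For contributors testing outside the CLI, roughly the same object can be built directly. This sketch mirrors the command body above with placeholder connection values, assuming DBConfig supplies defaults for its remaining fields:

    from pydantic import SecretStr
    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseConfig

    db_config = ClickhouseConfig(
        host="localhost",
        port=8123,
        user_name="clickhouse",
        password=SecretStr("example-password"),  # placeholder credential
    )
    print(db_config.to_dict())  # connection kwargs consumed by the Clickhouse client wrapper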
vectordb_bench/backend/clients/clickhouse/clickhouse.py

@@ -0,0 +1,156 @@
+"""Wrapper around the Clickhouse vector database over VectorDB"""
+
+import logging
+from contextlib import contextmanager
+from typing import Any
+
+import clickhouse_connect
+
+from ..api import DBCaseConfig, VectorDB
+
+log = logging.getLogger(__name__)
+
+
+class Clickhouse(VectorDB):
+    """Use SQLAlchemy instructions"""
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        collection_name: str = "CHVectorCollection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "clickhouse_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        self.conn = clickhouse_connect.get_client(
+            host=self.db_config["host"],
+            port=self.db_config["port"],
+            username=self.db_config["user"],
+            password=self.db_config["password"],
+            database=self.db_config["dbname"],
+        )
+
+        if drop_old:
+            log.info(f"Clickhouse client drop table : {self.table_name}")
+            self._drop_table()
+            self._create_table(dim)
+
+        self.conn.close()
+        self.conn = None
+
+    @contextmanager
+    def init(self):
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+
+        self.conn = clickhouse_connect.get_client(
+            host=self.db_config["host"],
+            port=self.db_config["port"],
+            username=self.db_config["user"],
+            password=self.db_config["password"],
+            database=self.db_config["dbname"],
+        )
+
+        try:
+            yield
+        finally:
+            self.conn.close()
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+
+        self.conn.command(f'DROP TABLE IF EXISTS {self.db_config["dbname"]}.{self.table_name}')
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # create table
+            self.conn.command(
+                f'CREATE TABLE IF NOT EXISTS {self.db_config["dbname"]}.{self.table_name} \
+                (id UInt32, embedding Array(Float64)) ENGINE = MergeTree() ORDER BY id;'
+            )
+
+        except Exception as e:
+            log.warning(f"Failed to create Clickhouse table: {self.table_name} error: {e}")
+            raise e from None
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self, data_size: int | None = None):
+        pass
+
+    def ready_to_search(self):
+        pass
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception]:
+        assert self.conn is not None, "Connection is not initialized"
+
+        try:
+            # do not iterate for bulk insert
+            items = [metadata, embeddings]
+
+            self.conn.insert(
+                table=self.table_name,
+                data=items,
+                column_names=["id", "embedding"],
+                column_type_names=["UInt32", "Array(Float64)"],
+                column_oriented=True,
+            )
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Clickhouse table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+
+        index_param = self.case_config.index_param()  # noqa: F841
+        search_param = self.case_config.search_param()
+
+        if filters:
+            gt = filters.get("id")
+            filter_sql = (
+                f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '  # noqa: S608
+                f'FROM {self.db_config["dbname"]}.{self.table_name} '
+                f"WHERE id > {gt} "
+                f"ORDER BY score LIMIT {k};"
+            )
+            result = self.conn.query(filter_sql).result_rows
+            return [int(row[0]) for row in result]
+        else:  # noqa: RET505
+            select_sql = (
+                f'SELECT id, {search_param["metric_type"]}(embedding,{query}) AS score '  # noqa: S608
+                f'FROM {self.db_config["dbname"]}.{self.table_name} '
+                f"ORDER BY score LIMIT {k};"
+            )
+            result = self.conn.query(select_sql).result_rows
+            return [int(row[0]) for row in result]
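A minimal end-to-end sketch of the wrapper's lifecycle (drop/create on construction, reconnect inside init()); it assumes a reachable ClickHouse server at localhost:8123 and uses placeholder credentials and data:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.clickhouse.clickhouse import Clickhouse
    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseHNSWConfig

    client = Clickhouse(
        dim=4,
        db_config={"host": "localhost", "port": 8123, "user": "clickhouse",
                   "password": "example", "dbname": "default"},
        db_case_config=ClickhouseHNSWConfig(M=16, efConstruction=128, metric_type=MetricType.COSINE),
        drop_old=True,  # drops and recreates the benchmark table
    )
    with client.init():
        client.insert_embeddings([[0.1, 0.2, 0.3, 0.4]], [1])
        print(client.search_embedding([0.1, 0.2, 0.3, 0.4], k=1))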
vectordb_bench/backend/clients/clickhouse/config.py

@@ -0,0 +1,60 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class ClickhouseConfig(DBConfig):
+    user_name: str = "clickhouse"
+    password: SecretStr
+    host: str = "localhost"
+    port: int = 8123
+    db_name: str = "default"
+
+    def to_dict(self) -> dict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "dbname": self.db_name,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class ClickhouseIndexConfig(BaseModel):
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+    def parse_metric_str(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "L2Distance"
+        if self.metric_type == MetricType.COSINE:
+            return "cosineDistance"
+        msg = f"Not Support for {self.metric_type}"
+        raise RuntimeError(msg)
+        return None
+
+
+class ClickhouseHNSWConfig(ClickhouseIndexConfig, DBCaseConfig):
+    M: int | None
+    efConstruction: int | None
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric_str(),
+            "index_type": self.index.value,
+            "params": {"M": self.M, "efConstruction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric_str(),
+            "params": {"ef": self.ef},
+        }
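Both param methods feed the SQL built in clickhouse.py; only the parsed metric name is actually consumed by the search query (index_param is computed but unused, hence the noqa: F841), since no vector index is created on the MergeTree table. A quick sketch of the emitted dictionaries, with placeholder values:

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.clickhouse.config import ClickhouseHNSWConfig

    cfg = ClickhouseHNSWConfig(M=16, efConstruction=128, ef=64, metric_type=MetricType.COSINE)
    print(cfg.index_param())   # 'cosineDistance' metric, index type, {'M': 16, 'efConstruction': 128}
    print(cfg.search_param())  # {'metric_type': 'cosineDistance', 'params': {'ef': 64}}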
vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py

@@ -81,7 +81,7 @@ class ElasticCloud(VectorDB):
         embeddings: Iterable[list[float]],
         metadata: list[int],
         **kwargs,
-    ) ->
+    ) -> tuple[int, Exception]:
         """Insert the embeddings to the elasticsearch."""
         assert self.client is not None, "should self.init() first"

vectordb_bench/backend/clients/mariadb/cli.py

@@ -0,0 +1,122 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+
+    storage_engine: Annotated[
+        int,
+        click.option(
+            "--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    m: Annotated[
+        int | None,
+        click.option(
+            "--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        int | None,
+        click.option(
+            "--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_min_limit",
+            required=False,
+        ),
+    ]
+
+    max_cache_size: Annotated[
+        int | None,
+        click.option(
+            "--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
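As with the Clickhouse command, the same config object can be assembled without the CLI. A hedged equivalent of what MariaDBHNSW passes to run(), with placeholder credentials and the assumption that DBConfig's remaining fields have defaults:

    from pydantic import SecretStr
    from vectordb_bench.backend.clients.mariadb.config import MariaDBConfig

    db_config = MariaDBConfig(
        user_name="root",
        password=SecretStr("example-password"),  # placeholder, not a real credential
        host="127.0.0.1",
        port=3306,
    )
    print(db_config.to_dict())  # connection kwargs for the MariaDB client wrapper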
vectordb_bench/backend/clients/mariadb/config.py

@@ -0,0 +1,73 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in mariadb connection string,
+    so the names must match exactly mariadb API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        if self.metric_type == MetricType.COSINE:
+            return "cosine"
+        msg = f"Metric type {self.metric_type} is not supported!"
+        raise ValueError(msg)
+
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
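The _mariadb_case_config mapping is what the new case_config_cls branch in __init__.py looks up. A short sketch of that path and the resulting parameter dictionaries, with placeholder values:

    from vectordb_bench.backend.clients.api import IndexType, MetricType
    from vectordb_bench.backend.clients.mariadb.config import _mariadb_case_config

    case_cfg_cls = _mariadb_case_config.get(IndexType.HNSW)  # -> MariaDBHNSWConfig
    cfg = case_cfg_cls(M=16, ef_search=64, max_cache_size=None, metric_type=MetricType.L2)
    print(cfg.index_param())   # storage engine, 'euclidean' metric, index type, M, cache size
    print(cfg.search_param())  # {'metric_type': 'euclidean', 'ef_search': 64}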