vectordb-bench 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff shows the published contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
- vectordb_bench/backend/clients/__init__.py +48 -0
- vectordb_bench/backend/clients/api.py +1 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +53 -4
- vectordb_bench/backend/clients/aws_opensearch/cli.py +85 -1
- vectordb_bench/backend/clients/aws_opensearch/config.py +10 -0
- vectordb_bench/backend/clients/mariadb/cli.py +107 -0
- vectordb_bench/backend/clients/mariadb/config.py +71 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
- vectordb_bench/backend/clients/milvus/cli.py +50 -0
- vectordb_bench/backend/clients/milvus/config.py +33 -0
- vectordb_bench/backend/clients/mongodb/config.py +53 -0
- vectordb_bench/backend/clients/mongodb/mongodb.py +200 -0
- vectordb_bench/backend/clients/pgvector/cli.py +13 -1
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +49 -0
- vectordb_bench/backend/clients/tidb/tidb.py +234 -0
- vectordb_bench/cli/vectordbbench.py +4 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +12 -1
- vectordb_bench/frontend/components/run_test/submitTask.py +20 -3
- vectordb_bench/frontend/config/dbCaseConfigs.py +128 -0
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/log_util.py +15 -2
- vectordb_bench/models.py +7 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +67 -3
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +31 -23
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
```diff
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -38,8 +38,11 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
+    MongoDB = "MongoDB"
+    TiDB = "TiDB"
 
     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901
@@ -129,6 +132,21 @@ class DB(Enum):
 
             return AliyunOpenSearch
 
+        if self == DB.MongoDB:
+            from .mongodb.mongodb import MongoDB
+
+            return MongoDB
+
+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test
 
@@ -225,6 +243,21 @@ class DB(Enum):
 
             return AliyunOpenSearchConfig
 
+        if self == DB.MongoDB:
+            from .mongodb.config import MongoDBConfig
+
+            return MongoDBConfig
+
+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig
 
@@ -302,6 +335,21 @@ class DB(Enum):
 
             return AliyunOpenSearchIndexConfig
 
+        if self == DB.MongoDB:
+            from .mongodb.config import MongoDBIndexConfig
+
+            return MongoDBIndexConfig
+
+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
 
```
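All three new backends hook into the same lazy-import dispatch: the enum member is always present, but the driver module is only imported when the corresponding property is resolved. A minimal sketch of what that buys a caller (assuming `config_cls` is the property behind the second hunk, mirroring `init_cls`):

```python
from vectordb_bench.backend.clients import DB

# The enum member is cheap to reference; the heavy driver import only
# happens when the property is resolved.
client_cls = DB.MariaDB.init_cls    # runs `from .mariadb.mariadb import MariaDB`
config_cls = DB.MariaDB.config_cls  # runs `from .mariadb.config import MariaDBConfig`
```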
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
 
 WAITING_FOR_REFRESH_SEC = 30
 WAITING_FOR_FORCE_MERGE_SEC = 30
+SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30
 
 
 class AWSOpenSearch(VectorDB):
@@ -52,10 +53,27 @@ class AWSOpenSearch(VectorDB):
         return AWSOpenSearchIndexConfig
 
     def _create_index(self, client: OpenSearch):
+        cluster_settings_body = {
+            "persistent": {
+                "knn.algo_param.index_thread_qty": self.case_config.index_thread_qty,
+                "knn.memory.circuit_breaker.limit": self.case_config.cb_threshold,
+            }
+        }
+        client.cluster.put_settings(cluster_settings_body)
         settings = {
             "index": {
                 "knn": True,
+                "number_of_shards": self.case_config.number_of_shards,
+                "number_of_replicas": 0,
+                "translog.flush_threshold_size": self.case_config.flush_threshold_size,
+                # Setting trans log threshold to 5GB
+                **(
+                    {"knn.algo_param.ef_search": self.case_config.ef_search}
+                    if self.case_config.engine == AWSOS_Engine.nmslib
+                    else {}
+                ),
             },
+            "refresh_interval": self.case_config.refresh_interval,
         }
         mappings = {
             "properties": {
```
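Combined with the defaults added in `config.py` (see the hunk further down), the index is now created with replicas disabled and a long refresh interval so bulk loading stays cheap. Roughly the resolved `settings` body for a non-nmslib engine under those defaults (a sketch derived from the dict above, not captured client output):

```python
# Sketch: the `settings` body _create_index resolves to with the new
# defaults (non-nmslib engine), before mappings are attached.
settings = {
    "index": {
        "knn": True,
        "number_of_shards": 1,
        "number_of_replicas": 0,  # replicas are added later, in optimize()
        "translog.flush_threshold_size": "5120mb",
    },
    "refresh_interval": "60s",
}
```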
```diff
@@ -145,9 +163,9 @@
                 docvalue_fields=[self.id_col_name],
                 stored_fields="_none_",
             )
-            log.
-            log.
-            log.
+            log.debug(f"Search took: {resp['took']}")
+            log.debug(f"Search shards: {resp['_shards']}")
+            log.debug(f"Search hits total: {resp['hits']['total']}")
             return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
         except Exception as e:
             log.warning(f"Failed to search: {self.index_name} error: {e!s}")
```
```diff
@@ -157,12 +175,37 @@
         """optimize will be called between insertion and search in performance cases."""
         # Call refresh first to ensure that all segments are created
         self._refresh_index()
-        self._do_force_merge()
+        if self.case_config.force_merge_enabled:
+            self._do_force_merge()
+            self._refresh_index()
+        self._update_replicas()
         # Call refresh again to ensure that the index is ready after force merge.
         self._refresh_index()
         # ensure that all graphs are loaded in memory and ready for search
         self._load_graphs_to_memory()
 
+    def _update_replicas(self):
+        index_settings = self.client.indices.get_settings(index=self.index_name)
+        current_number_of_replicas = int(index_settings[self.index_name]["settings"]["index"]["number_of_replicas"])
+        log.info(
+            f"Current Number of replicas are {current_number_of_replicas}"
+            f" and changing the replicas to {self.case_config.number_of_replicas}"
+        )
+        settings_body = {"index": {"number_of_replicas": self.case_config.number_of_replicas}}
+        self.client.indices.put_settings(index=self.index_name, body=settings_body)
+        self._wait_till_green()
+
+    def _wait_till_green(self):
+        log.info("Wait for index to become green..")
+        while True:
+            res = self.client.cat.indices(index=self.index_name, h="health", format="json")
+            health = res[0]["health"]
+            if health == "green":
+                break
+            log.info(f"The index {self.index_name} has health : {health} and is not green. Retrying")
+            time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC)
+        log.info(f"Index {self.index_name} is green..")
+
     def _refresh_index(self):
         log.debug(f"Starting refresh for index {self.index_name}")
         while True:
```
```diff
@@ -179,6 +222,12 @@
         log.debug(f"Completed refresh for index {self.index_name}")
 
     def _do_force_merge(self):
+        log.info(f"Updating the Index thread qty to {self.case_config.index_thread_qty_during_force_merge}.")
+
+        cluster_settings_body = {
+            "persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge}
+        }
+        self.client.cluster.put_settings(cluster_settings_body)
         log.debug(f"Starting force merge for index {self.index_name}")
         force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
         force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
```
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/cli.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py
@@ -18,6 +18,79 @@ class AWSOpenSearchTypedDict(TypedDict):
     port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
     user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
+    number_of_shards: Annotated[
+        int,
+        click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1),
+    ]
+    number_of_replicas: Annotated[
+        int,
+        click.option(
+            "--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1
+        ),
+    ]
+    index_thread_qty: Annotated[
+        int,
+        click.option(
+            "--index-thread-qty",
+            type=int,
+            help="Thread count for native engine indexing",
+            default=4,
+        ),
+    ]
+
+    index_thread_qty_during_force_merge: Annotated[
+        int,
+        click.option(
+            "--index-thread-qty-during-force-merge",
+            type=int,
+            help="Thread count during force merge operations",
+            default=4,
+        ),
+    ]
+
+    number_of_indexing_clients: Annotated[
+        int,
+        click.option(
+            "--number-of-indexing-clients",
+            type=int,
+            help="Number of concurrent indexing clients",
+            default=1,
+        ),
+    ]
+
+    number_of_segments: Annotated[
+        int,
+        click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1),
+    ]
+
+    refresh_interval: Annotated[
+        int,
+        click.option(
+            "--refresh-interval", type=str, help="How often to make new data available for search", default="60s"
+        ),
+    ]
+
+    force_merge_enabled: Annotated[
+        int,
+        click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True),
+    ]
+
+    flush_threshold_size: Annotated[
+        int,
+        click.option(
+            "--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb"
+        ),
+    ]
+
+    cb_threshold: Annotated[
+        int,
+        click.option(
+            "--cb-threshold",
+            type=str,
+            help="k-NN Memory circuit breaker threshold",
+            default="50%",
+        ),
+    ]
 
 
 class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
@@ -36,6 +109,17 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
             user=parameters["user"],
             password=SecretStr(parameters["password"]),
         ),
-        db_case_config=AWSOpenSearchIndexConfig(
+        db_case_config=AWSOpenSearchIndexConfig(
+            number_of_shards=parameters["number_of_shards"],
+            number_of_replicas=parameters["number_of_replicas"],
+            index_thread_qty=parameters["index_thread_qty"],
+            number_of_segments=parameters["number_of_segments"],
+            refresh_interval=parameters["refresh_interval"],
+            force_merge_enabled=parameters["force_merge_enabled"],
+            flush_threshold_size=parameters["flush_threshold_size"],
+            number_of_indexing_clients=parameters["number_of_indexing_clients"],
+            index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"],
+            cb_threshold=parameters["cb_threshold"],
+        ),
         **parameters,
     )
```
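A quick way to sanity-check that the new options are exposed end to end (a sketch: the lowercase command name follows click's default naming for the `AWSOpenSearch` function, and the location of the `cli` group in `vectordb_bench/cli/vectordbbench.py` is assumed from the file list above):

```python
from click.testing import CliRunner

# Assumed location of the click group, per the file list above.
from vectordb_bench.cli.vectordbbench import cli

runner = CliRunner()
result = runner.invoke(cli, ["awsopensearch", "--help"])
# Expect the new options in the help text: --number-of-shards,
# --index-thread-qty, --force-merge-enabled, --cb-threshold, ...
print(result.output)
```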
```diff
--- a/vectordb_bench/backend/clients/aws_opensearch/config.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/config.py
@@ -39,6 +39,16 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     efConstruction: int = 256
     efSearch: int = 256
     M: int = 16
+    index_thread_qty: int | None = 4
+    number_of_shards: int | None = 1
+    number_of_replicas: int | None = 0
+    number_of_segments: int | None = 1
+    refresh_interval: str | None = "60s"
+    force_merge_enabled: bool | None = True
+    flush_threshold_size: str | None = "5120mb"
+    number_of_indexing_clients: int | None = 1
+    index_thread_qty_during_force_merge: int
+    cb_threshold: str | None = "50%"
 
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.IP:
```
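Note that `index_thread_qty_during_force_merge` is the only new field without a default, so constructing the model directly must supply it. A hedged sketch (field names from the hunk above; the `MetricType` import path is assumed from `api.py`):

```python
from vectordb_bench.backend.clients.api import MetricType  # assumed import path
from vectordb_bench.backend.clients.aws_opensearch.config import AWSOpenSearchIndexConfig

case_config = AWSOpenSearchIndexConfig(
    metric_type=MetricType.COSINE,
    number_of_shards=2,
    number_of_replicas=1,
    index_thread_qty=8,
    index_thread_qty_during_force_merge=8,  # required: the only new field with no default
    refresh_interval="30s",
    cb_threshold="75%",
)
```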
```diff
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/cli.py
@@ -0,0 +1,107 @@
+from typing import Annotated, Optional, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor1,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str, click.option("--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int, click.option("--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+
+    storage_engine: Annotated[
+        int, click.option("--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    ...
+    m: Annotated[
+        Optional[int], click.option("--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        Optional[int], click.option("--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_min_limit",
+            required=False,
+        ),
+    ]
+
+    max_cache_size: Annotated[
+        Optional[int], click.option("--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
```
```diff
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/config.py
@@ -0,0 +1,71 @@
+from pydantic import SecretStr, BaseModel
+from typing import TypedDict
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in mariadb connection string,
+    so the names must match exactly mariadb API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        elif self.metric_type == MetricType.COSINE:
+            return "cosine"
+        else:
+            raise ValueError(f"Metric type {self.metric_type} is not supported!")
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
+
+
```
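For reference, roughly what the two param methods feed into the client below (a sketch; it assumes `IndexType.HNSW.value` is the string `"HNSW"`):

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.mariadb.config import MariaDBHNSWConfig

cfg = MariaDBHNSWConfig(metric_type=MetricType.COSINE, M=16, ef_search=100, max_cache_size=None)
cfg.index_param()
# -> {'storage_engine': 'InnoDB', 'metric_type': 'cosine', 'index_type': 'HNSW',
#     'M': 16, 'max_cache_size': None}
cfg.search_param()
# -> {'metric_type': 'cosine', 'ef_search': 100}
```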
```diff
--- /dev/null
+++ b/vectordb_bench/backend/clients/mariadb/mariadb.py
@@ -0,0 +1,214 @@
+from ..api import VectorDB
+
+import logging
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
+import numpy as np
+
+import mariadb
+
+log = logging.getLogger(__name__)
+
+class MariaDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+
+        self.name = "MariaDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+        conn = mariadb.connect(**kwargs)
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping database to avoid some locking issue
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                ) ENGINE={index_param["storage_engine"]}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+
+    @contextmanager
+    def init(self) -> None:
+        """create and destroy connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize allowed packet size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["max_cache_size"] != None:
+                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
+            if search_param["ef_search"] != None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
+        self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
+        self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] != None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                    ADD VECTOR KEY v(v) {index_options}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create index: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+        pass
+
+    @staticmethod
+    def vector_to_hex(v):
+        return np.array(v, 'float32').tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into Vector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        search_param = self.case_config.search_param()
+
+        if filters:
+            self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k))
+        else:
+            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+
+        return [id for id, in self.cursor.fetchall()]
+
```
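Despite its name, `vector_to_hex` returns raw float32 bytes rather than hex text; that packed form is what the client binds to the `VECTOR(n)` column and to the `vec_distance_*()` calls. A standalone sketch of the encoding:

```python
import numpy as np

def vector_to_hex(v):
    # Packs the vector as contiguous float32 bytes (native byte order).
    return np.array(v, "float32").tobytes()

# On a little-endian host, float32 1.0 is 0x3f800000 stored as b"\x00\x00\x80?":
assert vector_to_hex([1.0, 0.0]) == b"\x00\x00\x80?\x00\x00\x00\x00"
```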