vectordb-bench 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. vectordb_bench/backend/clients/__init__.py +48 -0
  2. vectordb_bench/backend/clients/api.py +1 -0
  3. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +53 -4
  4. vectordb_bench/backend/clients/aws_opensearch/cli.py +85 -1
  5. vectordb_bench/backend/clients/aws_opensearch/config.py +10 -0
  6. vectordb_bench/backend/clients/mariadb/cli.py +107 -0
  7. vectordb_bench/backend/clients/mariadb/config.py +71 -0
  8. vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
  9. vectordb_bench/backend/clients/milvus/cli.py +50 -0
  10. vectordb_bench/backend/clients/milvus/config.py +33 -0
  11. vectordb_bench/backend/clients/mongodb/config.py +53 -0
  12. vectordb_bench/backend/clients/mongodb/mongodb.py +200 -0
  13. vectordb_bench/backend/clients/pgvector/cli.py +13 -1
  14. vectordb_bench/backend/clients/pgvector/config.py +22 -5
  15. vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
  16. vectordb_bench/backend/clients/tidb/cli.py +98 -0
  17. vectordb_bench/backend/clients/tidb/config.py +49 -0
  18. vectordb_bench/backend/clients/tidb/tidb.py +234 -0
  19. vectordb_bench/cli/vectordbbench.py +4 -0
  20. vectordb_bench/frontend/components/custom/displaypPrams.py +12 -1
  21. vectordb_bench/frontend/components/run_test/submitTask.py +20 -3
  22. vectordb_bench/frontend/config/dbCaseConfigs.py +128 -0
  23. vectordb_bench/frontend/config/styles.py +2 -0
  24. vectordb_bench/log_util.py +15 -2
  25. vectordb_bench/models.py +7 -0
  26. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +67 -3
  27. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +31 -23
  28. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
  29. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
  30. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
  31. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
@@ -38,8 +38,11 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
+    MongoDB = "MongoDB"
+    TiDB = "TiDB"

     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901
@@ -129,6 +132,21 @@ class DB(Enum):

         return AliyunOpenSearch

+        if self == DB.MongoDB:
+            from .mongodb.mongodb import MongoDB
+
+            return MongoDB
+
+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test

@@ -225,6 +243,21 @@ class DB(Enum):

         return AliyunOpenSearchConfig

+        if self == DB.MongoDB:
+            from .mongodb.config import MongoDBConfig
+
+            return MongoDBConfig
+
+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig

@@ -302,6 +335,21 @@ class DB(Enum):

         return AliyunOpenSearchIndexConfig

+        if self == DB.MongoDB:
+            from .mongodb.config import MongoDBIndexConfig
+
+            return MongoDBIndexConfig
+
+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
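All three new backends plug into the same lazy-import dispatch shown in these hunks. A minimal sketch of resolving a client through the registry, assuming `init_cls` and `config_cls` remain properties and `case_config_cls` keeps its `index_type` parameter, as the context lines suggest:

    from vectordb_bench.backend.clients import DB
    from vectordb_bench.backend.clients.api import IndexType

    db = DB.MariaDB
    client_cls = db.init_cls    # lazily imports .mariadb.mariadb.MariaDB
    config_cls = db.config_cls  # lazily imports .mariadb.config.MariaDBConfig
    case_config_cls = db.case_config_cls(IndexType.HNSW)  # MariaDB dispatches per index type

MongoDB and TiDB return a single case-config class; only MariaDB routes through a per-index-type mapping (`_mariadb_case_config`).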
@@ -25,6 +25,7 @@ class IndexType(str, Enum):
     ES_HNSW = "hnsw"
     ES_IVFFlat = "ivfflat"
     GPU_IVF_FLAT = "GPU_IVF_FLAT"
+    GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE"
     GPU_IVF_PQ = "GPU_IVF_PQ"
     GPU_CAGRA = "GPU_CAGRA"
     SCANN = "scann"
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)

 WAITING_FOR_REFRESH_SEC = 30
 WAITING_FOR_FORCE_MERGE_SEC = 30
+SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30


 class AWSOpenSearch(VectorDB):
@@ -52,10 +53,27 @@ class AWSOpenSearch(VectorDB):
         return AWSOpenSearchIndexConfig

     def _create_index(self, client: OpenSearch):
+        cluster_settings_body = {
+            "persistent": {
+                "knn.algo_param.index_thread_qty": self.case_config.index_thread_qty,
+                "knn.memory.circuit_breaker.limit": self.case_config.cb_threshold,
+            }
+        }
+        client.cluster.put_settings(cluster_settings_body)
         settings = {
             "index": {
                 "knn": True,
+                "number_of_shards": self.case_config.number_of_shards,
+                "number_of_replicas": 0,
+                # set the translog flush threshold (5GB by default)
+                "translog.flush_threshold_size": self.case_config.flush_threshold_size,
+                **(
+                    {"knn.algo_param.ef_search": self.case_config.ef_search}
+                    if self.case_config.engine == AWSOS_Engine.nmslib
+                    else {}
+                ),
             },
+            "refresh_interval": self.case_config.refresh_interval,
         }
         mappings = {
             "properties": {
@@ -145,9 +163,9 @@ class AWSOpenSearch(VectorDB):
                 docvalue_fields=[self.id_col_name],
                 stored_fields="_none_",
             )
-            log.info(f"Search took: {resp['took']}")
-            log.info(f"Search shards: {resp['_shards']}")
-            log.info(f"Search hits total: {resp['hits']['total']}")
+            log.debug(f"Search took: {resp['took']}")
+            log.debug(f"Search shards: {resp['_shards']}")
+            log.debug(f"Search hits total: {resp['hits']['total']}")
             return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
         except Exception as e:
             log.warning(f"Failed to search: {self.index_name} error: {e!s}")
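The cluster-level settings applied in `_create_index` above amount to a single `PUT /_cluster/settings` request. A standalone sketch using an opensearch-py client; the endpoint is a placeholder, and the values mirror the new CLI defaults:

    from opensearchpy import OpenSearch

    client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])  # placeholder endpoint
    client.cluster.put_settings({
        "persistent": {
            "knn.algo_param.index_thread_qty": 4,       # --index-thread-qty default
            "knn.memory.circuit_breaker.limit": "50%",  # --cb-threshold default
        }
    })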
@@ -157,12 +175,37 @@ class AWSOpenSearch(VectorDB):
         """optimize will be called between insertion and search in performance cases."""
         # Call refresh first to ensure that all segments are created
         self._refresh_index()
-        self._do_force_merge()
+        if self.case_config.force_merge_enabled:
+            self._do_force_merge()
+            self._refresh_index()
+        self._update_replicas()
         # Call refresh again to ensure that the index is ready after force merge.
         self._refresh_index()
         # ensure that all graphs are loaded in memory and ready for search
         self._load_graphs_to_memory()

+    def _update_replicas(self):
+        index_settings = self.client.indices.get_settings(index=self.index_name)
+        current_number_of_replicas = int(index_settings[self.index_name]["settings"]["index"]["number_of_replicas"])
+        log.info(
+            f"Current number of replicas is {current_number_of_replicas},"
+            f" changing replicas to {self.case_config.number_of_replicas}"
+        )
+        settings_body = {"index": {"number_of_replicas": self.case_config.number_of_replicas}}
+        self.client.indices.put_settings(index=self.index_name, body=settings_body)
+        self._wait_till_green()
+
+    def _wait_till_green(self):
+        log.info("Waiting for index to become green..")
+        while True:
+            res = self.client.cat.indices(index=self.index_name, h="health", format="json")
+            health = res[0]["health"]
+            if health == "green":
+                break
+            log.info(f"The index {self.index_name} has health: {health}, not green yet. Retrying")
+            time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC)
+        log.info(f"Index {self.index_name} is green..")
+
     def _refresh_index(self):
         log.debug(f"Starting refresh for index {self.index_name}")
         while True:
@@ -179,6 +222,12 @@ class AWSOpenSearch(VectorDB):
         log.debug(f"Completed refresh for index {self.index_name}")

     def _do_force_merge(self):
+        log.info(f"Updating the index thread qty to {self.case_config.index_thread_qty_during_force_merge}.")
+
+        cluster_settings_body = {
+            "persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge}
+        }
+        self.client.cluster.put_settings(cluster_settings_body)
         log.debug(f"Starting force merge for index {self.index_name}")
         force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
         force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
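Because the merge is submitted with `wait_for_completion=false`, the caller is left holding a task id to poll. The shipped client's completion loop falls outside this hunk; a minimal polling sketch against the OpenSearch tasks API would look like:

    import time

    def wait_for_force_merge(client, task_id: str, poll_sec: int = WAITING_FOR_FORCE_MERGE_SEC):
        # GET /_tasks/<task_id> until OpenSearch reports the task as completed
        while True:
            status = client.transport.perform_request("GET", f"/_tasks/{task_id}")
            if status.get("completed"):
                return
            time.sleep(poll_sec)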
@@ -18,6 +18,79 @@ class AWSOpenSearchTypedDict(TypedDict):
     port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
     user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
     password: Annotated[str, click.option("--password", type=str, help="Db password")]
+    number_of_shards: Annotated[
+        int,
+        click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1),
+    ]
+    number_of_replicas: Annotated[
+        int,
+        click.option(
+            "--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1
+        ),
+    ]
+    index_thread_qty: Annotated[
+        int,
+        click.option(
+            "--index-thread-qty",
+            type=int,
+            help="Thread count for native engine indexing",
+            default=4,
+        ),
+    ]
+
+    index_thread_qty_during_force_merge: Annotated[
+        int,
+        click.option(
+            "--index-thread-qty-during-force-merge",
+            type=int,
+            help="Thread count during force merge operations",
+            default=4,
+        ),
+    ]
+
+    number_of_indexing_clients: Annotated[
+        int,
+        click.option(
+            "--number-of-indexing-clients",
+            type=int,
+            help="Number of concurrent indexing clients",
+            default=1,
+        ),
+    ]
+
+    number_of_segments: Annotated[
+        int,
+        click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1),
+    ]
+
+    refresh_interval: Annotated[
+        str,
+        click.option(
+            "--refresh-interval", type=str, help="How often to make new data available for search", default="60s"
+        ),
+    ]
+
+    force_merge_enabled: Annotated[
+        bool,
+        click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True),
+    ]
+
+    flush_threshold_size: Annotated[
+        str,
+        click.option(
+            "--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb"
+        ),
+    ]
+
+    cb_threshold: Annotated[
+        str,
+        click.option(
+            "--cb-threshold",
+            type=str,
+            help="k-NN memory circuit breaker threshold",
+            default="50%",
+        ),
+    ]


 class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
@@ -36,6 +109,17 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
             user=parameters["user"],
             password=SecretStr(parameters["password"]),
         ),
-        db_case_config=AWSOpenSearchIndexConfig(),
+        db_case_config=AWSOpenSearchIndexConfig(
+            number_of_shards=parameters["number_of_shards"],
+            number_of_replicas=parameters["number_of_replicas"],
+            index_thread_qty=parameters["index_thread_qty"],
+            number_of_segments=parameters["number_of_segments"],
+            refresh_interval=parameters["refresh_interval"],
+            force_merge_enabled=parameters["force_merge_enabled"],
+            flush_threshold_size=parameters["flush_threshold_size"],
+            number_of_indexing_clients=parameters["number_of_indexing_clients"],
+            index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"],
+            cb_threshold=parameters["cb_threshold"],
+        ),
         **parameters,
     )
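With this change every new knob is reachable from the `awsopensearch` command (click lowercases the function name). A hedged sketch of driving it programmatically through click's test runner, assuming the click group is exposed as `vectordb_bench.cli.vectordbbench.cli` as the console entry point suggests; host and credentials are placeholders:

    from click.testing import CliRunner
    from vectordb_bench.cli.vectordbbench import cli

    CliRunner().invoke(cli, [
        "awsopensearch",
        "--host", "search-example.us-east-1.es.amazonaws.com",  # placeholder
        "--user", "admin",
        "--password", "secret",                                 # placeholder
        "--number-of-shards", "2",
        "--number-of-replicas", "1",
        "--index-thread-qty-during-force-merge", "8",
        "--cb-threshold", "75%",
    ])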
@@ -39,6 +39,16 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
     efConstruction: int = 256
     efSearch: int = 256
     M: int = 16
+    index_thread_qty: int | None = 4
+    number_of_shards: int | None = 1
+    number_of_replicas: int | None = 0
+    number_of_segments: int | None = 1
+    refresh_interval: str | None = "60s"
+    force_merge_enabled: bool | None = True
+    flush_threshold_size: str | None = "5120mb"
+    number_of_indexing_clients: int | None = 1
+    index_thread_qty_during_force_merge: int
+    cb_threshold: str | None = "50%"

     def parse_metric(self) -> str:
         if self.metric_type == MetricType.IP:
@@ -0,0 +1,107 @@
+from typing import Annotated, Optional, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str,
+        click.option(
+            "--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+    host: Annotated[
+        str,
+        click.option(
+            "--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+    port: Annotated[
+        int,
+        click.option(
+            "--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+    storage_engine: Annotated[
+        str,
+        click.option(
+            "--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    m: Annotated[
+        Optional[int],
+        click.option(
+            "--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+    ef_search: Annotated[
+        Optional[int],
+        click.option(
+            "--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_ef_search",
+            required=False,
+        ),
+    ]
+    max_cache_size: Annotated[
+        Optional[int],
+        click.option(
+            "--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
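The command registers as `mariadbhnsw` under the shared `cli` group; `--username`, `--password`, and `--storage-engine` are required. The same hedged invocation sketch as above applies, assuming `vectordb_bench.cli.vectordbbench` imports this command (placeholder credentials):

    from click.testing import CliRunner
    from vectordb_bench.cli.vectordbbench import cli

    CliRunner().invoke(cli, [
        "mariadbhnsw",
        "--username", "root",
        "--password", "secret",  # placeholder
        "--host", "127.0.0.1",
        "--port", "3306",
        "--storage-engine", "InnoDB",
        "--m", "16",
        "--ef-search", "100",
    ])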
@@ -0,0 +1,71 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be used directly as kwargs for the mariadb connection,
+    so the names must exactly match the mariadb API."""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base index config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        elif self.metric_type == MetricType.COSINE:
+            return "cosine"
+        else:
+            raise ValueError(f"Metric type {self.metric_type} is not supported!")
+
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
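For reference, a quick sketch of what the HNSW case config emits once parsed (illustrative values; `IndexType.HNSW.value` is "HNSW"):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.mariadb.config import MariaDBHNSWConfig

    cfg = MariaDBHNSWConfig(M=16, ef_search=100, max_cache_size=None, metric_type=MetricType.COSINE)
    cfg.index_param()
    # {'storage_engine': 'InnoDB', 'metric_type': 'cosine', 'index_type': 'HNSW', 'M': 16, 'max_cache_size': None}
    cfg.search_param()
    # {'metric_type': 'cosine', 'ef_search': 100}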
@@ -0,0 +1,214 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+
+import mariadb
+import numpy as np
+
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class MariaDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "MariaDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+        conn = mariadb.connect(**kwargs)
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping the database to avoid locking issues
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                ) ENGINE={index_param["storage_engine"]}
+                """)
+            self.cursor.execute("COMMIT")
+        except Exception as e:
+            log.warning(f"Failed to create table: {self.table_name} error: {e}")
+            raise e from None
+
+    @contextmanager
+    def init(self) -> None:
+        """Create and destroy connections to the database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize the allowed packet size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["max_cache_size"] is not None:
+                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
+            if search_param["ef_search"] is not None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
+        self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
+        self.select_sql = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} "
+            f"ORDER BY vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        )
+        self.select_sql_with_filter = (
+            f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d "
+            f"ORDER BY vec_distance_{search_param['metric_type']}(v, %s) LIMIT %d"
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                ADD VECTOR KEY v(v) {index_options}
+                """)
+            self.cursor.execute("COMMIT")
+        except Exception as e:
+            log.warning(f"Failed to create index: {self.table_name} error: {e}")
+            raise e from None
+
+    @staticmethod
+    def vector_to_hex(v):
+        # serialize the vector as packed float32 bytes, the binary format VECTOR columns accept
+        return np.array(v, "float32").tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])))
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        if filters:
+            self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), self.vector_to_hex(query), k))
+        else:
+            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+
+        return [id for (id,) in self.cursor.fetchall()]
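Taken together, the client drives MariaDB's native vector support with plain SQL. A condensed sketch of the statements it generates for `dim=4`, cosine distance, `M=16`, and the hard-coded `vectordbbench.vec_collection` names (illustrative only):

    # DDL issued by _create_db_table and optimize()
    create_table = """
        CREATE TABLE vec_collection (
            id INT PRIMARY KEY,
            v VECTOR(4) NOT NULL
        ) ENGINE=InnoDB
    """
    create_index = "ALTER TABLE vectordbbench.vec_collection ADD VECTOR KEY v(v) DISTANCE=cosine M=16"

    # k-NN query template bound in init(); %s carries the float32-packed query vector
    knn_query = "SELECT id FROM vectordbbench.vec_collection ORDER BY vec_distance_cosine(v, %s) LIMIT %d"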