vectordb-bench 0.0.30__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +16 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +82 -41
  7. vectordb_bench/backend/clients/aws_opensearch/config.py +23 -4
  8. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  9. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  10. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  11. vectordb_bench/backend/clients/milvus/config.py +1 -0
  12. vectordb_bench/backend/clients/milvus/milvus.py +74 -22
  13. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  14. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  15. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  16. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  17. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  18. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  19. vectordb_bench/backend/dataset.py +143 -27
  20. vectordb_bench/backend/filter.py +76 -0
  21. vectordb_bench/backend/runner/__init__.py +3 -3
  22. vectordb_bench/backend/runner/mp_runner.py +52 -39
  23. vectordb_bench/backend/runner/rate_runner.py +68 -52
  24. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  25. vectordb_bench/backend/runner/serial_runner.py +56 -23
  26. vectordb_bench/backend/task_runner.py +48 -20
  27. vectordb_bench/cli/cli.py +59 -1
  28. vectordb_bench/cli/vectordbbench.py +3 -0
  29. vectordb_bench/frontend/components/check_results/data.py +16 -11
  30. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  32. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  33. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  34. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  35. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  36. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  37. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  38. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  39. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  41. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  42. vectordb_bench/frontend/components/streaming/data.py +62 -0
  43. vectordb_bench/frontend/components/tables/data.py +1 -1
  44. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  45. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  46. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  47. vectordb_bench/frontend/config/dbCaseConfigs.py +307 -40
  48. vectordb_bench/frontend/config/styles.py +32 -2
  49. vectordb_bench/frontend/pages/concurrent.py +5 -1
  50. vectordb_bench/frontend/pages/custom.py +4 -0
  51. vectordb_bench/frontend/pages/label_filter.py +56 -0
  52. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  53. vectordb_bench/frontend/pages/results.py +60 -0
  54. vectordb_bench/frontend/pages/run_test.py +3 -3
  55. vectordb_bench/frontend/pages/streaming.py +135 -0
  56. vectordb_bench/frontend/pages/tables.py +4 -0
  57. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  58. vectordb_bench/interface.py +6 -2
  59. vectordb_bench/metric.py +15 -1
  60. vectordb_bench/models.py +31 -11
  61. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  62. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  63. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  64. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  65. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  66. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  67. vectordb_bench/results/dbPrices.json +12 -4
  68. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +85 -32
  69. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +73 -56
  70. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  71. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  72. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  73. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +0 -0
  74. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  76. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,8 @@ from contextlib import contextmanager
7
7
 
8
8
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, utility
9
9
 
10
+ from vectordb_bench.backend.filter import Filter, FilterOp
11
+
10
12
  from ..api import VectorDB
11
13
  from .config import MilvusIndexConfig
12
14
 
@@ -16,6 +18,12 @@ MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024
16
18
 
17
19
 
18
20
  class Milvus(VectorDB):
21
+ supported_filter_types: list[FilterOp] = [
22
+ FilterOp.NonFilter,
23
+ FilterOp.NumGE,
24
+ FilterOp.StrEqual,
25
+ ]
26
+
19
27
  def __init__(
20
28
  self,
21
29
  dim: int,
@@ -24,6 +32,7 @@ class Milvus(VectorDB):
24
32
  collection_name: str = "VectorDBBenchCollection",
25
33
  drop_old: bool = False,
26
34
  name: str = "Milvus",
35
+ with_scalar_labels: bool = False,
27
36
  **kwargs,
28
37
  ):
29
38
  """Initialize wrapper around the milvus vector database."""
@@ -32,11 +41,15 @@ class Milvus(VectorDB):
32
41
  self.case_config = db_case_config
33
42
  self.collection_name = collection_name
34
43
  self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim * 4))
44
+ self.with_scalar_labels = with_scalar_labels
35
45
 
36
46
  self._primary_field = "pk"
37
- self._scalar_field = "id"
47
+ self._scalar_id_field = "id"
48
+ self._scalar_label_field = "label"
38
49
  self._vector_field = "vector"
39
- self._index_name = "vector_idx"
50
+ self._vector_index_name = "vector_idx"
51
+ self._scalar_id_index_name = "id_sort_idx"
52
+ self._scalar_labels_index_name = "labels_idx"
40
53
 
41
54
  from pymilvus import connections
42
55
 
@@ -53,9 +66,20 @@ class Milvus(VectorDB):
53
66
  if not utility.has_collection(self.collection_name):
54
67
  fields = [
55
68
  FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
56
- FieldSchema(self._scalar_field, DataType.INT64),
69
+ FieldSchema(self._scalar_id_field, DataType.INT64),
57
70
  FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim),
58
71
  ]
72
+ if self.with_scalar_labels:
73
+ is_partition_key = db_case_config.use_partition_key
74
+ log.info(f"with_scalar_labels, add a new varchar field, as partition_key: {is_partition_key}")
75
+ fields.append(
76
+ FieldSchema(
77
+ self._scalar_label_field,
78
+ DataType.VARCHAR,
79
+ max_length=256,
80
+ is_partition_key=is_partition_key,
81
+ )
82
+ )
59
83
 
60
84
  log.info(f"{self.name} create collection: {self.collection_name}")
61
85
 
@@ -67,16 +91,37 @@ class Milvus(VectorDB):
67
91
  num_shards=self.db_config.get("num_shards"),
68
92
  )
69
93
 
70
- log.info(f"{self.name} create index: index_params: {self.case_config.index_param()}")
71
- col.create_index(
72
- self._vector_field,
73
- self.case_config.index_param(),
74
- index_name=self._index_name,
75
- )
94
+ self.create_index()
76
95
  col.load()
77
96
 
78
97
  connections.disconnect("default")
79
98
 
99
+ def create_index(self):
100
+ col = Collection(self.collection_name)
101
+ # vector index
102
+ col.create_index(
103
+ self._vector_field,
104
+ self.case_config.index_param(),
105
+ index_name=self._vector_index_name,
106
+ )
107
+ # scalar index for range-expr (int-filter)
108
+ col.create_index(
109
+ self._scalar_id_field,
110
+ index_params={
111
+ "index_type": "STL_SORT",
112
+ },
113
+ index_name=self._scalar_id_index_name,
114
+ )
115
+ # scalar index for varchar (label-filter)
116
+ if self.with_scalar_labels:
117
+ col.create_index(
118
+ self._scalar_label_field,
119
+ index_params={
120
+ "index_type": "BITMAP",
121
+ },
122
+ index_name=self._scalar_labels_index_name,
123
+ )
124
+
80
125
  @contextmanager
81
126
  def init(self):
82
127
  """
@@ -109,17 +154,13 @@ class Milvus(VectorDB):
109
154
  try:
110
155
  self.col.flush()
111
156
  # wait for index done and load refresh
112
- self.col.create_index(
113
- self._vector_field,
114
- self.case_config.index_param(),
115
- index_name=self._index_name,
116
- )
157
+ self.create_index()
117
158
 
118
- utility.wait_for_index_building_complete(self.collection_name)
159
+ utility.wait_for_index_building_complete(self.collection_name, index_name=self._vector_index_name)
119
160
 
120
161
  def wait_index():
121
162
  while True:
122
- progress = utility.index_building_progress(self.collection_name)
163
+ progress = utility.index_building_progress(self.collection_name, index_name=self._vector_index_name)
123
164
  if progress.get("pending_index_rows", -1) == 0:
124
165
  break
125
166
  time.sleep(5)
@@ -162,6 +203,7 @@ class Milvus(VectorDB):
162
203
  self,
163
204
  embeddings: Iterable[list[float]],
164
205
  metadata: list[int],
206
+ labels_data: list[str] | None = None,
165
207
  **kwargs,
166
208
  ) -> tuple[int, Exception]:
167
209
  """Insert embeddings into Milvus. should call self.init() first"""
@@ -177,32 +219,42 @@ class Milvus(VectorDB):
177
219
  metadata[batch_start_offset:batch_end_offset],
178
220
  embeddings[batch_start_offset:batch_end_offset],
179
221
  ]
222
+ if self.with_scalar_labels:
223
+ insert_data.append(labels_data[batch_start_offset:batch_end_offset])
180
224
  res = self.col.insert(insert_data)
181
225
  insert_count += len(res.primary_keys)
182
226
  except MilvusException as e:
183
227
  log.info(f"Failed to insert data: {e}")
184
- return (insert_count, e)
185
- return (insert_count, None)
228
+ return insert_count, e
229
+ return insert_count, None
230
+
231
def prepare_filter(self, filters: Filter):
    """Translate a benchmark ``Filter`` into the Milvus expression used by searches.

    Stores the result in ``self.expr``:
      * ``NonFilter`` -> ``""`` (no filtering)
      * ``NumGE``     -> ``"<id field> >= <int_value>"``
      * ``StrEqual``  -> ``"<label field> == '<label_value>'"``

    Raises:
        ValueError: for any other filter type.
    """
    op = filters.type
    if op == FilterOp.NonFilter:
        expr = ""
    elif op == FilterOp.NumGE:
        expr = f"{self._scalar_id_field} >= {filters.int_value}"
    elif op == FilterOp.StrEqual:
        # NOTE(review): label_value is interpolated without escaping; a label containing
        # a single quote would produce an invalid expression.
        expr = f"{self._scalar_label_field} == '{filters.label_value}'"
    else:
        msg = f"Not support Filter for Milvus - {filters}"
        raise ValueError(msg)
    self.expr = expr
186
241
 
187
242
  def search_embedding(
188
243
  self,
189
244
  query: list[float],
190
245
  k: int = 100,
191
- filters: dict | None = None,
192
246
  timeout: int | None = None,
193
247
  ) -> list[int]:
194
248
  """Perform a search on a query embedding and return results."""
195
249
  assert self.col is not None
196
250
 
197
- expr = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""
198
-
199
251
  # Perform the search.
200
252
  res = self.col.search(
201
253
  data=[query],
202
254
  anns_field=self._vector_field,
203
255
  param=self.case_config.search_param(),
204
256
  limit=k,
205
- expr=expr,
257
+ expr=self.expr,
206
258
  )
207
259
 
208
260
  # Organize results.
@@ -0,0 +1,100 @@
1
+ import os
2
+ from typing import Annotated, Unpack
3
+
4
+ import click
5
+ from pydantic import SecretStr
6
+
7
+ from vectordb_bench.backend.clients import DB
8
+ from vectordb_bench.cli.cli import (
9
+ CommonTypedDict,
10
+ HNSWFlavor4,
11
+ OceanBaseIVFTypedDict,
12
+ cli,
13
+ click_parameter_decorators_from_typed_dict,
14
+ run,
15
+ )
16
+
17
+ from ..api import IndexType
18
+
19
+
20
class OceanBaseTypedDict(CommonTypedDict):
    """OceanBase connection options for the CLI.

    Each field's ``Annotated`` metadata is the ``click.option`` that supplies it.
    """

    # Optional; defaults to an empty string when --host is not given.
    host: Annotated[str, click.option("--host", type=str, help="OceanBase host", default="")]
    user: Annotated[str, click.option("--user", type=str, help="OceanBase username", required=True)]
    # When --password is omitted, falls back to the OB_PASSWORD environment variable, then to "".
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="OceanBase database password",
            default=lambda: os.environ.get("OB_PASSWORD", ""),
        ),
    ]
    database: Annotated[str, click.option("--database", type=str, help="DataBase name", required=True)]
    port: Annotated[int, click.option("--port", type=int, help="OceanBase port", required=True)]
34
+
35
+
36
class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4):
    """CLI options for ``OceanBaseHNSW``: common, OceanBase connection and ``HNSWFlavor4`` options."""
37
+
38
+
39
@cli.command()
@click_parameter_decorators_from_typed_dict(OceanBaseHNSWTypedDict)
def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
    """Run a benchmark against OceanBase with an HNSW-family vector index.

    Connection options populate ``OceanBaseConfig`` (user and password wrapped in
    ``SecretStr``); ``m``, ``ef_construction``, ``ef_search`` and ``index_type``
    populate ``OceanBaseHNSWConfig``.
    """
    from .config import OceanBaseConfig, OceanBaseHNSWConfig

    connection_config = OceanBaseConfig(
        db_label=parameters["db_label"],
        user=SecretStr(parameters["user"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        port=parameters["port"],
        database=parameters["database"],
    )
    index_config = OceanBaseHNSWConfig(
        m=parameters["m"],
        efConstruction=parameters["ef_construction"],
        ef_search=parameters["ef_search"],
        index=parameters["index_type"],
    )
    run(
        db=DB.OceanBase,
        db_config=connection_config,
        db_case_config=index_config,
        **parameters,
    )
62
+
63
+
64
class OceanBaseIVFTypedDict(CommonTypedDict, OceanBaseTypedDict, OceanBaseIVFTypedDict):
    """CLI options for ``OceanBaseIVF``: common, OceanBase connection and IVF options.

    NOTE(review): this rebinds the name ``OceanBaseIVFTypedDict`` imported from
    ``vectordb_bench.cli.cli``. The base list is evaluated before the rebinding,
    so the imported dict's fields are still inherited, but a distinct local name
    would be clearer.
    """
65
+
66
+
67
@cli.command()
@click_parameter_decorators_from_typed_dict(OceanBaseIVFTypedDict)
def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
    """Run a benchmark against OceanBase with an IVF-family vector index.

    The ``index_type`` parameter selects the variant (``IVF_FLAT``, ``IVF_PQ`` or
    ``IVF_SQ8``). The ``m`` parameter defaults to 0 when it is not provided.

    Raises:
        click.BadParameter: if ``index_type`` is not one of the supported IVF
            variants (previously this surfaced as an ``UnboundLocalError``).
    """
    from .config import OceanBaseConfig, OceanBaseIVFConfig

    # CLI spelling -> IndexType member.
    ivf_index_types = {
        "IVF_FLAT": IndexType.IVFFlat,
        "IVF_PQ": IndexType.IVFPQ,
        "IVF_SQ8": IndexType.IVFSQ8,
    }
    type_str = parameters["index_type"]
    input_index_type = ivf_index_types.get(type_str)
    if input_index_type is None:
        supported = ", ".join(ivf_index_types)
        msg = f"Unsupported OceanBase IVF index type {type_str!r}; expected one of: {supported}"
        raise click.BadParameter(msg)

    input_m = 0 if parameters["m"] is None else parameters["m"]

    run(
        db=DB.OceanBase,
        db_config=OceanBaseConfig(
            db_label=parameters["db_label"],
            user=SecretStr(parameters["user"]),
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            port=parameters["port"],
            database=parameters["database"],
        ),
        db_case_config=OceanBaseIVFConfig(
            m=input_m,
            nlist=parameters["nlist"],
            sample_per_nlist=parameters["sample_per_nlist"],
            index=input_index_type,
            ivf_nprobes=parameters["ivf_nprobes"],
        ),
        **parameters,
    )
@@ -0,0 +1,125 @@
1
+ from typing import TypedDict
2
+
3
+ from pydantic import BaseModel, SecretStr, validator
4
+
5
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
+
7
+
8
class OceanBaseConfigDict(TypedDict):
    """Plain-dict form of the OceanBase connection settings (output of ``OceanBaseConfig.to_dict``).

    Values are revealed secrets, consumed as connection keyword arguments.
    """

    user: str
    host: str
    # ``OceanBaseConfig.port`` is an int and is passed through unchanged, so annotate it as int.
    port: int
    password: str
    database: str
14
+
15
+
16
class OceanBaseConfig(DBConfig):
    """Connection settings for an OceanBase instance.

    ``user`` defaults to ``root@perf``. Empty strings are rejected for every field
    except ``password``, ``host`` and ``db_label``.
    """

    user: SecretStr = SecretStr("root@perf")
    password: SecretStr
    host: str
    port: int
    database: str

    def to_dict(self) -> OceanBaseConfigDict:
        """Return the settings as a plain dict with secret values revealed."""
        return OceanBaseConfigDict(
            user=self.user.get_secret_value(),
            host=self.host,
            port=self.port,
            password=self.password.get_secret_value(),
            database=self.database,
        )

    @validator("*")
    def not_empty_field(cls, v: any, field: any):
        """Reject empty strings / secrets, except for fields allowed to be blank."""
        blank_allowed = ("password", "host", "db_label")
        is_empty_text = isinstance(v, str | SecretStr) and len(v) == 0
        if is_empty_text and field.name not in blank_allowed:
            raise ValueError("Empty string!")
        return v
41
+
42
+
43
class OceanBaseIndexConfig(BaseModel):
    """Index settings shared by the OceanBase HNSW and IVF case configs.

    Translates ``metric_type`` into the names OceanBase expects, both for the
    index ``distance=`` option and for the SQL distance function used when
    ordering search results.
    """

    index: IndexType
    metric_type: MetricType | None = None
    lib: str = "vsag"

    def _uses_l2(self) -> bool:
        # L2 metric, or HNSW_BQ with a cosine metric (which is also mapped to L2).
        if self.metric_type == MetricType.L2:
            return True
        return self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE

    def parse_metric(self) -> str:
        """Return the metric name for the index ``distance=`` option."""
        if self._uses_l2():
            return "l2"
        return "inner_product" if self.metric_type == MetricType.IP else "cosine"

    def parse_metric_func_str(self) -> str:
        """Return the SQL distance function name used to order search results."""
        if self._uses_l2():
            return "l2_distance"
        return "negative_inner_product" if self.metric_type == MetricType.IP else "cosine_distance"
65
+
66
+
67
class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
    """Build/search parameters for the OceanBase HNSW-family indexes (HNSW, HNSW_SQ, HNSW_BQ)."""

    m: int
    efConstruction: int
    ef_search: int | None = None
    index: IndexType

    def index_param(self) -> dict:
        """Return library, metric, index type and HNSW build parameters."""
        build_params = {"m": self.m, "ef_construction": self.efConstruction}
        return dict(
            lib=self.lib,
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params=build_params,
        )

    def search_param(self) -> dict:
        """Return the SQL distance function name and the ``ef_search`` setting."""
        return dict(
            metric_type=self.parse_metric_func_str(),
            params={"ef_search": self.ef_search},
        )
83
+
84
+
85
class OceanBaseIVFConfig(OceanBaseIndexConfig, DBCaseConfig):
    """Build/search parameters for the OceanBase IVF-family indexes (IVF_FLAT, IVF_PQ, IVF_SQ8)."""

    m: int
    sample_per_nlist: int
    nlist: int
    index: IndexType
    ivf_nprobes: int | None = None

    def index_param(self) -> dict:
        """Return library, metric, index type and IVF build parameters.

        IVF indexes use the ``OB`` library rather than the inherited ``lib``
        default. ``m`` is only included for ``IVF_PQ``.
        """
        params: dict = {"sample_per_nlist": self.sample_per_nlist, "nlist": self.nlist}
        if self.index == IndexType.IVFPQ:
            # Previously read ``self.M``, which is not a field and raised AttributeError.
            params = {"m": self.m, **params}
        return {
            "lib": "OB",
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": params,
        }

    def search_param(self) -> dict:
        """Return the SQL distance function name and the ``ivf_nprobes`` setting.

        Mirrors ``OceanBaseHNSWConfig.search_param``: ``metric_type`` is the SQL
        distance function name, not the raw ``MetricType`` enum.
        """
        return {"metric_type": self.parse_metric_func_str(), "params": {"ivf_nprobes": self.ivf_nprobes}}
116
+
117
+
118
# Maps each supported OceanBase index type to the case-config model that describes it.
_oceanbase_case_config = {
    **dict.fromkeys((IndexType.HNSW_SQ, IndexType.HNSW, IndexType.HNSW_BQ), OceanBaseHNSWConfig),
    **dict.fromkeys((IndexType.IVFFlat, IndexType.IVFPQ, IndexType.IVFSQ8), OceanBaseIVFConfig),
}
@@ -0,0 +1,215 @@
1
+ import logging
2
+ import struct
3
+ import time
4
+ from collections.abc import Generator
5
+ from contextlib import contextmanager
6
+ from typing import Any
7
+
8
+ import mysql.connector as mysql
9
+
10
+ from ..api import IndexType, VectorDB
11
+ from .config import OceanBaseConfigDict, OceanBaseHNSWConfig
12
+
13
log = logging.getLogger(__name__)

# Number of rows sent per multi-row INSERT statement by OceanBase.insert_embeddings.
OCEANBASE_DEFAULT_LOAD_BATCH_SIZE = 256
16
+
17
+
18
class OceanBase(VectorDB):
    """VectorDBBench client for OceanBase vector search, driven through ``mysql.connector``.

    Rows live in a table with an ``INT`` primary key and a ``VECTOR(dim)`` column.
    The vector index is created after loading, in :meth:`optimize`.
    """

    # Index types configured through the HNSW session variable / index options;
    # every other index type is treated as an IVF variant.
    _hnsw_index_types = frozenset({IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ})

    def __init__(
        self,
        dim: int,
        db_config: OceanBaseConfigDict,
        db_case_config: OceanBaseHNSWConfig,
        collection_name: str = "items",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store configuration and make sure the benchmark table exists.

        Args:
            dim: dimensionality of the stored vectors.
            db_config: connection parameters (host, user, port, password, database).
            db_case_config: index build/search configuration (HNSW or IVF family).
            collection_name: name of the table holding the vectors.
            drop_old: drop any existing table with that name before creating it.
        """
        self.name = "OceanBase"
        self.dim = dim
        self.db_config = db_config
        self.db_case_config = db_case_config
        self.table_name = collection_name
        self.load_batch_size = OCEANBASE_DEFAULT_LOAD_BATCH_SIZE
        self._index_name = "vidx"
        self._primary_field = "id"
        self._vector_field = "embedding"

        # db_config holds the plaintext password used for mysql.connect; never log it.
        redacted_db_config = {k: ("***" if k == "password" else v) for k, v in self.db_config.items()}
        log.info(
            f"{self.name} initialized with config:\nDatabase: {redacted_db_config}\nCase Config: {self.db_case_config}"
        )

        self._conn = None
        self._cursor = None

        try:
            self._connect()
            if drop_old:
                self._drop_table()
            self._create_table()
        finally:
            self._disconnect()

    def _connect(self):
        """Open a connection and cursor using ``db_config``; re-raises connection errors."""
        try:
            self._conn = mysql.connect(
                host=self.db_config["host"],
                user=self.db_config["user"],
                port=self.db_config["port"],
                password=self.db_config["password"],
                database=self.db_config["database"],
            )
            self._cursor = self._conn.cursor()
        except mysql.Error:
            log.exception("Failed to connect to the database")
            raise

    def _disconnect(self):
        """Close the cursor and connection if they are open."""
        if self._cursor:
            self._cursor.close()
            self._cursor = None
        if self._conn:
            self._conn.close()
            self._conn = None

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """Connect, enable autocommit and apply search settings for the duration of the block."""
        try:
            self._connect()
            self._cursor.execute("SET autocommit=1")
            self._apply_search_params()
            yield
        finally:
            self._disconnect()

    def _apply_search_params(self):
        """Set the session variable that controls search breadth for the configured index family."""
        params = self.db_case_config.search_param()["params"]
        if self.db_case_config.index in self._hnsw_index_types:
            variable, value = "ob_hnsw_ef_search", params.get("ef_search")
        else:
            variable, value = "ob_ivf_nprobes", params.get("ivf_nprobes")
        if value is None:
            # The config allows leaving this unset; keep the server default instead of
            # issuing an invalid "SET ...=None".
            log.info(f"{variable} not configured, keeping server default")
            return
        self._cursor.execute(f"SET {variable}={int(value)}")

    def _drop_table(self):
        """Drop the benchmark table if it exists."""
        if not self._cursor:
            raise ValueError("Cursor is not initialized")

        log.info(f"Dropping table {self.table_name}")
        self._cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")

    def _create_table(self):
        """Create the benchmark table unless it already exists (e.g. when ``drop_old`` is False)."""
        if not self._cursor:
            raise ValueError("Cursor is not initialized")

        log.info(f"Creating table {self.table_name}")
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            {self._primary_field} INT PRIMARY KEY,
            {self._vector_field} VECTOR({self.dim})
        );
        """
        self._cursor.execute(create_table_query)

    def optimize(self, data_size: int):
        """Build the vector index, compact, and refresh optimizer statistics.

        Runs ``CREATE VECTOR INDEX``, then ``ALTER SYSTEM MAJOR FREEZE``, waits until
        every zone's major compaction reports ``IDLE``, and finally gathers schema
        statistics for the configured database.
        """
        if not self._cursor:
            raise ValueError("Cursor is not initialized")

        index_params = self.db_case_config.index_param()
        index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
        with_clause = (
            f"distance={self.db_case_config.parse_metric()}, "
            f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
        )
        if self.db_case_config.index in self._hnsw_index_types:
            with_clause += ", extra_info_max_size=32"

        index_query = (
            f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX {self._index_name} "
            f"ON {self.table_name}({self._vector_field}) "
            f"WITH ({with_clause})"
        )

        log.info("Create index query: %s", index_query)

        try:
            log.info("Creating index...")
            start_time = time.time()
            self._cursor.execute(index_query)
            log.info(f"Index created in {time.time() - start_time:.2f} seconds")

            log.info("Performing major freeze...")
            self._cursor.execute("ALTER SYSTEM MAJOR FREEZE;")
            time.sleep(10)
            self._wait_for_major_compaction()

            log.info("Gathering schema statistics...")
            # Previously hard-coded to the 'test' schema; use the configured database.
            self._cursor.execute(
                "CALL dbms_stats.gather_schema_stats(%s, degree => 96);",
                (self.db_config["database"],),
            )
        except mysql.Error:
            log.exception("Failed to optimize index")
            raise

    def need_normalize_cosine(self) -> bool:
        """Return True for HNSW_BQ, which only supports L2 and so needs normalized cosine data."""
        if self.db_case_config.index == IndexType.HNSW_BQ:
            log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
            return True

        return False

    def _wait_for_major_compaction(self):
        """Poll every 10 seconds until all zones report an IDLE major-compaction status."""
        while True:
            self._cursor.execute(
                "SELECT IF(COUNT(*) = COUNT(STATUS = 'IDLE' OR NULL), 'TRUE', 'FALSE') "
                "AS all_status_idle FROM oceanbase.DBA_OB_ZONE_MAJOR_COMPACTION;"
            )
            all_status_idle = self._cursor.fetchone()[0]
            if all_status_idle == "TRUE":
                break
            time.sleep(10)

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> tuple[int, Exception | None]:
        """Insert ``(metadata[i], embeddings[i])`` rows in batches of ``load_batch_size``.

        Returns:
            ``(inserted_count, None)``. Database errors are logged and re-raised.
        """
        if not self._cursor:
            raise ValueError("Cursor is not initialized")

        total = len(embeddings)
        insert_count = 0
        try:
            for batch_start in range(0, total, self.load_batch_size):
                batch_end = min(batch_start + self.load_batch_size, total)
                batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
                values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
                self._cursor.execute(
                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"  # noqa: S608
                )
                insert_count += len(batch)
        except mysql.Error:
            log.exception("Failed to insert embeddings")
            raise

        return insert_count, None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict[str, Any] | None = None,
        timeout: int | None = None,
    ) -> list[int]:
        """Return the ids of the ``k`` approximate nearest neighbours of ``query``.

        The query vector is sent as a little-endian float32 hex literal. When
        ``filters`` is given, only rows with ``id >= filters['id']`` are considered.
        """
        if not self._cursor:
            raise ValueError("Cursor is not initialized")

        packed = struct.pack(f"<{len(query)}f", *query)
        hex_vec = packed.hex()
        # int() keeps the interpolated filter value numeric.
        filter_clause = f"WHERE {self._primary_field} >= {int(filters['id'])}" if filters else ""
        query_str = (
            f"SELECT {self._primary_field} FROM {self.table_name} "  # noqa: S608
            f"{filter_clause} ORDER BY "
            f"{self.db_case_config.parse_metric_func_str()}({self._vector_field}, X'{hex_vec}') "
            f"APPROXIMATE LIMIT {k}"
        )

        try:
            self._cursor.execute(query_str)
            return [row[0] for row in self._cursor.fetchall()]
        except mysql.Error:
            log.exception("Failed to execute search query")
            raise