vectordb-bench 0.0.30__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -27
- vectordb_bench/__main__.py +1 -1
- vectordb_bench/backend/assembler.py +19 -6
- vectordb_bench/backend/cases.py +186 -23
- vectordb_bench/backend/clients/__init__.py +16 -0
- vectordb_bench/backend/clients/api.py +22 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +82 -41
- vectordb_bench/backend/clients/aws_opensearch/config.py +37 -4
- vectordb_bench/backend/clients/chroma/chroma.py +6 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +31 -1
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
- vectordb_bench/backend/clients/milvus/config.py +1 -0
- vectordb_bench/backend/clients/milvus/milvus.py +75 -23
- vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
- vectordb_bench/backend/clients/oceanbase/config.py +125 -0
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +73 -3
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
- vectordb_bench/backend/dataset.py +146 -27
- vectordb_bench/backend/filter.py +76 -0
- vectordb_bench/backend/runner/__init__.py +3 -3
- vectordb_bench/backend/runner/mp_runner.py +52 -39
- vectordb_bench/backend/runner/rate_runner.py +68 -52
- vectordb_bench/backend/runner/read_write_runner.py +125 -68
- vectordb_bench/backend/runner/serial_runner.py +56 -23
- vectordb_bench/backend/task_runner.py +59 -20
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +3 -0
- vectordb_bench/frontend/components/check_results/data.py +16 -11
- vectordb_bench/frontend/components/check_results/filters.py +53 -25
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -13
- vectordb_bench/frontend/components/check_results/nav.py +20 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
- vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
- vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
- vectordb_bench/frontend/components/label_filter/charts.py +60 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
- vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
- vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
- vectordb_bench/frontend/components/streaming/charts.py +253 -0
- vectordb_bench/frontend/components/streaming/data.py +62 -0
- vectordb_bench/frontend/components/tables/data.py +1 -1
- vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
- vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +309 -42
- vectordb_bench/frontend/config/styles.py +34 -4
- vectordb_bench/frontend/pages/concurrent.py +5 -1
- vectordb_bench/frontend/pages/custom.py +4 -0
- vectordb_bench/frontend/pages/label_filter.py +56 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
- vectordb_bench/frontend/{vdb_benchmark.py → pages/results.py} +10 -4
- vectordb_bench/frontend/pages/run_test.py +3 -3
- vectordb_bench/frontend/pages/streaming.py +135 -0
- vectordb_bench/frontend/pages/tables.py +4 -0
- vectordb_bench/frontend/vdbbench.py +31 -0
- vectordb_bench/interface.py +8 -3
- vectordb_bench/metric.py +15 -1
- vectordb_bench/models.py +31 -11
- vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
- vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
- vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
- vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
- vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
- vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
- vectordb_bench/results/dbPrices.json +12 -4
- vectordb_bench/results/getLeaderboardDataV2.py +59 -0
- vectordb_bench/results/leaderboard_v2.json +2662 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/METADATA +93 -40
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/RECORD +77 -58
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.1.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/milvus/milvus.py

@@ -7,6 +7,8 @@ from contextlib import contextmanager
 
 from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, utility
 
+from vectordb_bench.backend.filter import Filter, FilterOp
+
 from ..api import VectorDB
 from .config import MilvusIndexConfig
 
@@ -16,14 +18,21 @@ MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024
 
 
 class Milvus(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
         dim: int,
         db_config: dict,
         db_case_config: MilvusIndexConfig,
-        collection_name: str = "
+        collection_name: str = "VDBBench",
         drop_old: bool = False,
         name: str = "Milvus",
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
@@ -32,11 +41,15 @@ class Milvus(VectorDB):
         self.case_config = db_case_config
         self.collection_name = collection_name
         self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim * 4))
+        self.with_scalar_labels = with_scalar_labels
 
         self._primary_field = "pk"
-        self.
+        self._scalar_id_field = "id"
+        self._scalar_label_field = "label"
         self._vector_field = "vector"
-        self.
+        self._vector_index_name = "vector_idx"
+        self._scalar_id_index_name = "id_sort_idx"
+        self._scalar_labels_index_name = "labels_idx"
 
         from pymilvus import connections
 
@@ -53,9 +66,20 @@ class Milvus(VectorDB):
         if not utility.has_collection(self.collection_name):
             fields = [
                 FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
-                FieldSchema(self.
+                FieldSchema(self._scalar_id_field, DataType.INT64),
                 FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim),
             ]
+            if self.with_scalar_labels:
+                is_partition_key = db_case_config.use_partition_key
+                log.info(f"with_scalar_labels, add a new varchar field, as partition_key: {is_partition_key}")
+                fields.append(
+                    FieldSchema(
+                        self._scalar_label_field,
+                        DataType.VARCHAR,
+                        max_length=256,
+                        is_partition_key=is_partition_key,
+                    )
+                )
 
             log.info(f"{self.name} create collection: {self.collection_name}")
 
@@ -67,16 +91,37 @@ class Milvus(VectorDB):
                 num_shards=self.db_config.get("num_shards"),
             )
 
-
-            col.create_index(
-                self._vector_field,
-                self.case_config.index_param(),
-                index_name=self._index_name,
-            )
+            self.create_index()
             col.load()
 
         connections.disconnect("default")
 
+    def create_index(self):
+        col = Collection(self.collection_name)
+        # vector index
+        col.create_index(
+            self._vector_field,
+            self.case_config.index_param(),
+            index_name=self._vector_index_name,
+        )
+        # scalar index for range-expr (int-filter)
+        col.create_index(
+            self._scalar_id_field,
+            index_params={
+                "index_type": "STL_SORT",
+            },
+            index_name=self._scalar_id_index_name,
+        )
+        # scalar index for varchar (label-filter)
+        if self.with_scalar_labels:
+            col.create_index(
+                self._scalar_label_field,
+                index_params={
+                    "index_type": "BITMAP",
+                },
+                index_name=self._scalar_labels_index_name,
+            )
+
     @contextmanager
     def init(self):
         """
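For orientation, the dictionaries handed to pymilvus by the new `create_index()` look roughly like the sketch below. The two scalar-index dicts are taken verbatim from the hunk above; the vector-index dict comes from `MilvusIndexConfig.index_param()`, which is not part of this diff, so the HNSW shape shown is only an illustrative assumption.

```python
# Illustrative only: the vector-index dict is an assumed HNSW example, since
# MilvusIndexConfig.index_param() is defined outside this hunk. The two scalar
# index dicts mirror the literals used in create_index() above.
vector_index_params = {
    "index_type": "HNSW",
    "metric_type": "COSINE",
    "params": {"M": 30, "efConstruction": 360},
}
scalar_id_index_params = {"index_type": "STL_SORT"}   # range filters on the int "id" field
scalar_label_index_params = {"index_type": "BITMAP"}  # equality filters on the "label" varchar
```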
@@ -109,17 +154,13 @@ class Milvus(VectorDB):
         try:
             self.col.flush()
             # wait for index done and load refresh
-            self.
-                self._vector_field,
-                self.case_config.index_param(),
-                index_name=self._index_name,
-            )
+            self.create_index()
 
-            utility.wait_for_index_building_complete(self.collection_name)
+            utility.wait_for_index_building_complete(self.collection_name, index_name=self._vector_index_name)
 
             def wait_index():
                 while True:
-                    progress = utility.index_building_progress(self.collection_name)
+                    progress = utility.index_building_progress(self.collection_name, index_name=self._vector_index_name)
                     if progress.get("pending_index_rows", -1) == 0:
                         break
                     time.sleep(5)
@@ -162,6 +203,7 @@ class Milvus(VectorDB):
         self,
         embeddings: Iterable[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs,
     ) -> tuple[int, Exception]:
         """Insert embeddings into Milvus. should call self.init() first"""
@@ -177,32 +219,42 @@ class Milvus(VectorDB):
                     metadata[batch_start_offset:batch_end_offset],
                     embeddings[batch_start_offset:batch_end_offset],
                 ]
+                if self.with_scalar_labels:
+                    insert_data.append(labels_data[batch_start_offset:batch_end_offset])
                 res = self.col.insert(insert_data)
                 insert_count += len(res.primary_keys)
         except MilvusException as e:
             log.info(f"Failed to insert data: {e}")
-            return
-        return
+            return insert_count, e
+        return insert_count, None
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.expr = ""
+        elif filters.type == FilterOp.NumGE:
+            self.expr = f"{self._scalar_id_field} >= {filters.int_value}"
+        elif filters.type == FilterOp.StrEqual:
+            self.expr = f"{self._scalar_label_field} == '{filters.label_value}'"
+        else:
+            msg = f"Not support Filter for Milvus - {filters}"
+            raise ValueError(msg)
 
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-        filters: dict | None = None,
         timeout: int | None = None,
     ) -> list[int]:
         """Perform a search on a query embedding and return results."""
         assert self.col is not None
 
-        expr = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""
-
         # Perform the search.
         res = self.col.search(
             data=[query],
             anns_field=self._vector_field,
             param=self.case_config.search_param(),
             limit=k,
-            expr=expr,
+            expr=self.expr,
         )
 
         # Organize results.
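A small, self-contained sketch of the expressions the new `prepare_filter()` caches for each supported `FilterOp`. The helper below is hypothetical (it is not part of the package), but the string mapping and the default field names (`id`, `label`) follow the code above.

```python
# Hypothetical standalone helper mirroring Milvus.prepare_filter() above;
# field names match the defaults set in __init__ ("id" and "label").
from vectordb_bench.backend.filter import FilterOp

def milvus_filter_expr(op: FilterOp, int_value: int = 0, label_value: str = "") -> str:
    if op == FilterOp.NonFilter:
        return ""
    if op == FilterOp.NumGE:
        return f"id >= {int_value}"
    if op == FilterOp.StrEqual:
        return f"label == '{label_value}'"
    msg = f"Not support Filter for Milvus - {op}"
    raise ValueError(msg)

print(milvus_filter_expr(FilterOp.NumGE, int_value=50_000))          # id >= 50000
print(milvus_filter_expr(FilterOp.StrEqual, label_value="label_1"))  # label == 'label_1'
```

In a run, `prepare_filter()` is called once per case and every subsequent `search_embedding()` reuses the cached `self.expr`, which is why the old per-query `filters` argument was dropped from the search signature.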
vectordb_bench/backend/clients/oceanbase/cli.py (new file)

@@ -0,0 +1,100 @@
+import os
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from vectordb_bench.backend.clients import DB
+from vectordb_bench.cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor4,
+    OceanBaseIVFTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+
+from ..api import IndexType
+
+
+class OceanBaseTypedDict(CommonTypedDict):
+    host: Annotated[str, click.option("--host", type=str, help="OceanBase host", default="")]
+    user: Annotated[str, click.option("--user", type=str, help="OceanBase username", required=True)]
+    password: Annotated[
+        str,
+        click.option(
+            "--password",
+            type=str,
+            help="OceanBase database password",
+            default=lambda: os.environ.get("OB_PASSWORD", ""),
+        ),
+    ]
+    database: Annotated[str, click.option("--database", type=str, help="DataBase name", required=True)]
+    port: Annotated[int, click.option("--port", type=int, help="OceanBase port", required=True)]
+
+
+class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4): ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(OceanBaseHNSWTypedDict)
+def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
+    from .config import OceanBaseConfig, OceanBaseHNSWConfig
+
+    run(
+        db=DB.OceanBase,
+        db_config=OceanBaseConfig(
+            db_label=parameters["db_label"],
+            user=SecretStr(parameters["user"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+            database=parameters["database"],
+        ),
+        db_case_config=OceanBaseHNSWConfig(
+            m=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef_search=parameters["ef_search"],
+            index=parameters["index_type"],
+        ),
+        **parameters,
+    )
+
+
+class OceanBaseIVFTypedDict(CommonTypedDict, OceanBaseTypedDict, OceanBaseIVFTypedDict): ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(OceanBaseIVFTypedDict)
+def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
+    from .config import OceanBaseConfig, OceanBaseIVFConfig
+
+    type_str = parameters["index_type"]
+    if type_str == "IVF_FLAT":
+        input_index_type = IndexType.IVFFlat
+    elif type_str == "IVF_PQ":
+        input_index_type = IndexType.IVFPQ
+    elif type_str == "IVF_SQ8":
+        input_index_type = IndexType.IVFSQ8
+
+    input_m = 0 if parameters["m"] is None else parameters["m"]
+
+    run(
+        db=DB.OceanBase,
+        db_config=OceanBaseConfig(
+            db_label=parameters["db_label"],
+            user=SecretStr(parameters["user"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+            database=parameters["database"],
+        ),
+        db_case_config=OceanBaseIVFConfig(
+            m=input_m,
+            nlist=parameters["nlist"],
+            sample_per_nlist=parameters["sample_per_nlist"],
+            index=input_index_type,
+            ivf_nprobes=parameters["ivf_nprobes"],
+        ),
+        **parameters,
+    )
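Because both commands are registered with `@cli.command()`, importing this module adds them to the shared click group. Below is a minimal sketch of exercising that registration without the `vectordbbench` console script; the command name `oceanbasehnsw` is an assumption based on click's default derivation from the `OceanBaseHNSW` function name, and the exact flag set depends on `CommonTypedDict`/`HNSWFlavor4`, which are not shown in this diff.

```python
# Sketch: list the new OceanBase command's options via click's test runner.
from click.testing import CliRunner

import vectordb_bench.backend.clients.oceanbase.cli  # noqa: F401  (importing registers the commands on the group)
from vectordb_bench.cli.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["oceanbasehnsw", "--help"])
print(result.output)  # expected to include --host/--user/--password/--database/--port plus the HNSW flags
```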
vectordb_bench/backend/clients/oceanbase/config.py (new file)

@@ -0,0 +1,125 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, SecretStr, validator
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class OceanBaseConfigDict(TypedDict):
+    user: str
+    host: str
+    port: str
+    password: str
+    database: str
+
+
+class OceanBaseConfig(DBConfig):
+    user: SecretStr = SecretStr("root@perf")
+    password: SecretStr
+    host: str
+    port: int
+    database: str
+
+    def to_dict(self) -> OceanBaseConfigDict:
+        user_str = self.user.get_secret_value()
+        pwd_str = self.password.get_secret_value()
+        return {
+            "user": user_str,
+            "host": self.host,
+            "port": self.port,
+            "password": pwd_str,
+            "database": self.database,
+        }
+
+    @validator("*")
+    def not_empty_field(cls, v: any, field: any):
+        if field.name in ["password", "host", "db_label"]:
+            return v
+        if isinstance(v, str | SecretStr) and len(v) == 0:
+            raise ValueError("Empty string!")
+        return v
+
+
+class OceanBaseIndexConfig(BaseModel):
+    index: IndexType
+    metric_type: MetricType | None = None
+    lib: str = "vsag"
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2 or (
+            self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
+        ):
+            return "l2"
+        if self.metric_type == MetricType.IP:
+            return "inner_product"
+        return "cosine"
+
+    def parse_metric_func_str(self) -> str:
+        if self.metric_type == MetricType.L2 or (
+            self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
+        ):
+            return "l2_distance"
+        if self.metric_type == MetricType.IP:
+            return "negative_inner_product"
+        return "cosine_distance"
+
+
+class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
+    m: int
+    efConstruction: int
+    ef_search: int | None = None
+    index: IndexType
+
+    def index_param(self) -> dict:
+        return {
+            "lib": self.lib,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "params": {"m": self.m, "ef_construction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {"metric_type": self.parse_metric_func_str(), "params": {"ef_search": self.ef_search}}
+
+
+class OceanBaseIVFConfig(OceanBaseIndexConfig, DBCaseConfig):
+    m: int
+    sample_per_nlist: int
+    nlist: int
+    index: IndexType
+    ivf_nprobes: int | None = None
+
+    def index_param(self) -> dict:
+        if self.index == IndexType.IVFPQ:
+            return {
+                "lib": "OB",
+                "metric_type": self.parse_metric(),
+                "index_type": self.index.value,
+                "params": {
+                    "m": self.M,
+                    "sample_per_nlist": self.sample_per_nlist,
+                    "nlist": self.nlist,
+                },
+            }
+        return {
+            "lib": "OB",
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "params": {
+                "sample_per_nlist": self.sample_per_nlist,
+                "nlist": self.nlist,
+            },
+        }
+
+    def search_param(self) -> dict:
+        return {"metric_type": self.metric_type, "params": {"ivf_nprobes": self.ivf_nprobes}}
+
+
+_oceanbase_case_config = {
+    IndexType.HNSW_SQ: OceanBaseHNSWConfig,
+    IndexType.HNSW: OceanBaseHNSWConfig,
+    IndexType.HNSW_BQ: OceanBaseHNSWConfig,
+    IndexType.IVFFlat: OceanBaseIVFConfig,
+    IndexType.IVFPQ: OceanBaseIVFConfig,
+    IndexType.IVFSQ8: OceanBaseIVFConfig,
+}
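As a quick illustration of what these case configs emit, built directly from the classes above (the `IndexType`/`MetricType` enums come from `backend/clients/api.py`, which is not shown in this diff, so the printed `index_type` value is shown under that assumption):

```python
from vectordb_bench.backend.clients.api import IndexType, MetricType
from vectordb_bench.backend.clients.oceanbase.config import OceanBaseHNSWConfig

cfg = OceanBaseHNSWConfig(
    index=IndexType.HNSW,
    metric_type=MetricType.COSINE,
    m=16,
    efConstruction=200,
    ef_search=64,
)
print(cfg.index_param())
# expected (index_type is IndexType.HNSW.value):
# {'lib': 'vsag', 'metric_type': 'cosine', 'index_type': 'HNSW', 'params': {'m': 16, 'ef_construction': 200}}
print(cfg.search_param())
# {'metric_type': 'cosine_distance', 'params': {'ef_search': 64}}
```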
vectordb_bench/backend/clients/oceanbase/oceanbase.py (new file)

@@ -0,0 +1,215 @@
+import logging
+import struct
+import time
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any
+
+import mysql.connector as mysql
+
+from ..api import IndexType, VectorDB
+from .config import OceanBaseConfigDict, OceanBaseHNSWConfig
+
+log = logging.getLogger(__name__)
+
+OCEANBASE_DEFAULT_LOAD_BATCH_SIZE = 256
+
+
+class OceanBase(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: OceanBaseConfigDict,
+        db_case_config: OceanBaseHNSWConfig,
+        collection_name: str = "items",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "OceanBase"
+        self.dim = dim
+        self.db_config = db_config
+        self.db_case_config = db_case_config
+        self.table_name = collection_name
+        self.load_batch_size = OCEANBASE_DEFAULT_LOAD_BATCH_SIZE
+        self._index_name = "vidx"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        log.info(
+            f"{self.name} initialized with config:\nDatabase: {self.db_config}\nCase Config: {self.db_case_config}"
+        )
+
+        self._conn = None
+        self._cursor = None
+
+        try:
+            self._connect()
+            if drop_old:
+                self._drop_table()
+                self._create_table()
+        finally:
+            self._disconnect()
+
+    def _connect(self):
+        try:
+            self._conn = mysql.connect(
+                host=self.db_config["host"],
+                user=self.db_config["user"],
+                port=self.db_config["port"],
+                password=self.db_config["password"],
+                database=self.db_config["database"],
+            )
+            self._cursor = self._conn.cursor()
+        except mysql.Error:
+            log.exception("Failed to connect to the database")
+            raise
+
+    def _disconnect(self):
+        if self._cursor:
+            self._cursor.close()
+            self._cursor = None
+        if self._conn:
+            self._conn.close()
+            self._conn = None
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        try:
+            self._connect()
+            self._cursor.execute("SET autocommit=1")
+
+            if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
+                self._cursor.execute(
+                    f"SET ob_hnsw_ef_search={(self.db_case_config.search_param())['params']['ef_search']}"
+                )
+            else:
+                self._cursor.execute(
+                    f"SET ob_ivf_nprobes={(self.db_case_config.search_param())['params']['ivf_nprobes']}"
+                )
+            yield
+        finally:
+            self._disconnect()
+
+    def _drop_table(self):
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        log.info(f"Dropping table {self.table_name}")
+        self._cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_table(self):
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        log.info(f"Creating table {self.table_name}")
+        create_table_query = f"""
+            CREATE TABLE {self.table_name} (
+                id INT PRIMARY KEY,
+                embedding VECTOR({self.dim})
+            );
+        """
+        self._cursor.execute(create_table_query)
+
+    def optimize(self, data_size: int):
+        index_params = self.db_case_config.index_param()
+        index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
+        index_query = (
+            f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
+            f"ON {self.table_name}(embedding) "
+            f"WITH (distance={self.db_case_config.parse_metric()}, "
+            f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
+        )
+
+        if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
+            index_query += ", extra_info_max_size=32"
+
+        index_query += ")"
+
+        log.info("Create index query: %s", index_query)
+
+        try:
+            log.info("Creating index...")
+            start_time = time.time()
+            self._cursor.execute(index_query)
+            log.info(f"Index created in {time.time() - start_time:.2f} seconds")
+
+            log.info("Performing major freeze...")
+            self._cursor.execute("ALTER SYSTEM MAJOR FREEZE;")
+            time.sleep(10)
+            self._wait_for_major_compaction()
+
+            log.info("Gathering schema statistics...")
+            self._cursor.execute("CALL dbms_stats.gather_schema_stats('test', degree => 96);")
+        except mysql.Error:
+            log.exception("Failed to optimize index")
+            raise
+
+    def need_normalize_cosine(self) -> bool:
+        if self.db_case_config.index == IndexType.HNSW_BQ:
+            log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
+            return True
+
+        return False
+
+    def _wait_for_major_compaction(self):
+        while True:
+            self._cursor.execute(
+                "SELECT IF(COUNT(*) = COUNT(STATUS = 'IDLE' OR NULL), 'TRUE', 'FALSE') "
+                "AS all_status_idle FROM oceanbase.DBA_OB_ZONE_MAJOR_COMPACTION;"
+            )
+            all_status_idle = self._cursor.fetchone()[0]
+            if all_status_idle == "TRUE":
+                break
+            time.sleep(10)
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception | None]:
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        insert_count = 0
+        try:
+            for batch_start in range(0, len(embeddings), self.load_batch_size):
+                batch_end = min(batch_start + self.load_batch_size, len(embeddings))
+                batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
+                values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
+                self._cursor.execute(
+                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"  # noqa: S608
+                )
+                insert_count += len(batch)
+        except mysql.Error:
+            log.exception("Failed to insert embeddings")
+            raise
+
+        return insert_count, None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict[str, Any] | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        packed = struct.pack(f"<{len(query)}f", *query)
+        hex_vec = packed.hex()
+        filter_clause = f"WHERE id >= {filters['id']}" if filters else ""
+        query_str = (
+            f"SELECT id FROM {self.table_name} "  # noqa: S608
+            f"{filter_clause} ORDER BY "
+            f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
+            f"APPROXIMATE LIMIT {k}"
+        )
+
+        try:
+            self._cursor.execute(query_str)
+            return [row[0] for row in self._cursor.fetchall()]
+        except mysql.Error:
+            log.exception("Failed to execute search query")
+            raise
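To make the query path concrete, the snippet below rebuilds the ANN SQL that `search_embedding()` issues, using the default table/column names from `__init__` and the `l2_distance` function name from `parse_metric_func_str()`. It is a standalone sketch, not code from the package, and omits the optional `WHERE id >= ...` filter clause.

```python
# Standalone sketch of the top-k SQL built by OceanBase.search_embedding() above.
import struct

query = [0.1, 0.2, 0.3, 0.4]  # toy 4-dim query vector
k = 10

# Vectors are sent as little-endian float32 bytes, hex-encoded into an X'...' literal.
hex_vec = struct.pack(f"<{len(query)}f", *query).hex()

sql = (
    "SELECT id FROM items ORDER BY "
    f"l2_distance(embedding, X'{hex_vec}') "
    f"APPROXIMATE LIMIT {k}"
)
print(sql)
```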