vectordb-bench 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/assembler.py +2 -2
- vectordb_bench/backend/clients/__init__.py +28 -2
- vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +1 -7
- vectordb_bench/backend/clients/alloydb/alloydb.py +1 -4
- vectordb_bench/backend/clients/api.py +8 -15
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +54 -8
- vectordb_bench/backend/clients/aws_opensearch/cli.py +85 -1
- vectordb_bench/backend/clients/aws_opensearch/config.py +10 -0
- vectordb_bench/backend/clients/chroma/chroma.py +1 -4
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -4
- vectordb_bench/backend/clients/memorydb/cli.py +2 -2
- vectordb_bench/backend/clients/memorydb/memorydb.py +2 -5
- vectordb_bench/backend/clients/milvus/milvus.py +1 -20
- vectordb_bench/backend/clients/mongodb/config.py +53 -0
- vectordb_bench/backend/clients/mongodb/mongodb.py +200 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +1 -4
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +3 -11
- vectordb_bench/backend/clients/pgvector/pgvector.py +2 -7
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +2 -7
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -4
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +3 -6
- vectordb_bench/backend/clients/redis/redis.py +1 -4
- vectordb_bench/backend/clients/test/cli.py +1 -1
- vectordb_bench/backend/clients/test/test.py +1 -4
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -4
- vectordb_bench/backend/data_source.py +4 -12
- vectordb_bench/backend/runner/mp_runner.py +16 -34
- vectordb_bench/backend/runner/rate_runner.py +4 -4
- vectordb_bench/backend/runner/read_write_runner.py +11 -15
- vectordb_bench/backend/runner/serial_runner.py +20 -28
- vectordb_bench/backend/task_runner.py +6 -26
- vectordb_bench/frontend/components/custom/displaypPrams.py +12 -1
- vectordb_bench/frontend/components/run_test/submitTask.py +20 -3
- vectordb_bench/frontend/config/dbCaseConfigs.py +32 -0
- vectordb_bench/interface.py +10 -19
- vectordb_bench/log_util.py +15 -2
- vectordb_bench/models.py +4 -0
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/METADATA +55 -2
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/RECORD +43 -41
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.20.dist-info → vectordb_bench-0.0.22.dist-info}/top_level.txt +0 -0
@@ -53,8 +53,8 @@ class Assembler:
|
|
53
53
|
_ = k.init_cls
|
54
54
|
|
55
55
|
# sort by dataset size
|
56
|
-
for
|
57
|
-
|
56
|
+
for _, runner in db2runner.items():
|
57
|
+
runner.sort(key=lambda x: x.ca.dataset.data.size)
|
58
58
|
|
59
59
|
all_runners = []
|
60
60
|
all_runners.extend(load_runners)
|
@@ -40,9 +40,10 @@ class DB(Enum):
|
|
40
40
|
AliyunElasticsearch = "AliyunElasticsearch"
|
41
41
|
Test = "test"
|
42
42
|
AliyunOpenSearch = "AliyunOpenSearch"
|
43
|
+
MongoDB = "MongoDB"
|
43
44
|
|
44
45
|
@property
|
45
|
-
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912
|
46
|
+
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901
|
46
47
|
"""Import while in use"""
|
47
48
|
if self == DB.Milvus:
|
48
49
|
from .milvus.milvus import Milvus
|
@@ -129,11 +130,21 @@ class DB(Enum):
|
|
129
130
|
|
130
131
|
return AliyunOpenSearch
|
131
132
|
|
133
|
+
if self == DB.MongoDB:
|
134
|
+
from .mongodb.mongodb import MongoDB
|
135
|
+
|
136
|
+
return MongoDB
|
137
|
+
|
138
|
+
if self == DB.Test:
|
139
|
+
from .test.test import Test
|
140
|
+
|
141
|
+
return Test
|
142
|
+
|
132
143
|
msg = f"Unknown DB: {self.name}"
|
133
144
|
raise ValueError(msg)
|
134
145
|
|
135
146
|
@property
|
136
|
-
def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912
|
147
|
+
def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901
|
137
148
|
"""Import while in use"""
|
138
149
|
if self == DB.Milvus:
|
139
150
|
from .milvus.config import MilvusConfig
|
@@ -220,6 +231,16 @@ class DB(Enum):
|
|
220
231
|
|
221
232
|
return AliyunOpenSearchConfig
|
222
233
|
|
234
|
+
if self == DB.MongoDB:
|
235
|
+
from .mongodb.config import MongoDBConfig
|
236
|
+
|
237
|
+
return MongoDBConfig
|
238
|
+
|
239
|
+
if self == DB.Test:
|
240
|
+
from .test.config import TestConfig
|
241
|
+
|
242
|
+
return TestConfig
|
243
|
+
|
223
244
|
msg = f"Unknown DB: {self.name}"
|
224
245
|
raise ValueError(msg)
|
225
246
|
|
@@ -292,6 +313,11 @@ class DB(Enum):
|
|
292
313
|
|
293
314
|
return AliyunOpenSearchIndexConfig
|
294
315
|
|
316
|
+
if self == DB.MongoDB:
|
317
|
+
from .mongodb.config import MongoDBIndexConfig
|
318
|
+
|
319
|
+
return MongoDBIndexConfig
|
320
|
+
|
295
321
|
# DB.Pinecone, DB.Chroma, DB.Redis
|
296
322
|
return EmptyDBCaseConfig
|
297
323
|
|
@@ -325,10 +325,7 @@ class AliyunOpenSearch(VectorDB):
|
|
325
325
|
|
326
326
|
return False
|
327
327
|
|
328
|
-
def optimize(self):
|
329
|
-
pass
|
330
|
-
|
331
|
-
def optimize_with_size(self, data_size: int):
|
328
|
+
def optimize(self, data_size: int):
|
332
329
|
log.info(f"optimize count: {data_size}")
|
333
330
|
retry_times = 0
|
334
331
|
while True:
|
@@ -340,6 +337,3 @@ class AliyunOpenSearch(VectorDB):
|
|
340
337
|
if total_count == data_size:
|
341
338
|
log.info("optimize table finish.")
|
342
339
|
return
|
343
|
-
|
344
|
-
def ready_to_load(self):
|
345
|
-
"""ready_to_load will be called before load in load cases."""
|
@@ -137,6 +137,13 @@ class VectorDB(ABC):
|
|
137
137
|
@contextmanager
|
138
138
|
def init(self) -> None:
|
139
139
|
"""create and destory connections to database.
|
140
|
+
Why contextmanager:
|
141
|
+
|
142
|
+
In multiprocessing search tasks, vectordbbench might init
|
143
|
+
totally hundreds of thousands of connections with DB server.
|
144
|
+
|
145
|
+
Too many connections may drain local FDs or server connection resources.
|
146
|
+
If the DB client doesn't have `close()` method, just set the object to None.
|
140
147
|
|
141
148
|
Examples:
|
142
149
|
>>> with self.init():
|
@@ -187,9 +194,8 @@ class VectorDB(ABC):
|
|
187
194
|
"""
|
188
195
|
raise NotImplementedError
|
189
196
|
|
190
|
-
# TODO: remove
|
191
197
|
@abstractmethod
|
192
|
-
def optimize(self):
|
198
|
+
def optimize(self, data_size: int | None = None):
|
193
199
|
"""optimize will be called between insertion and search in performance cases.
|
194
200
|
|
195
201
|
Should be blocked until the vectorDB is ready to be tested on
|
@@ -199,16 +205,3 @@ class VectorDB(ABC):
|
|
199
205
|
Optimize's execution time is limited, the limited time is based on cases.
|
200
206
|
"""
|
201
207
|
raise NotImplementedError
|
202
|
-
|
203
|
-
def optimize_with_size(self, data_size: int):
|
204
|
-
self.optimize()
|
205
|
-
|
206
|
-
# TODO: remove
|
207
|
-
@abstractmethod
|
208
|
-
def ready_to_load(self):
|
209
|
-
"""ready_to_load will be called before load in load cases.
|
210
|
-
|
211
|
-
Should be blocked until the vectorDB is ready to be tested on
|
212
|
-
heavy load cases.
|
213
|
-
"""
|
214
|
-
raise NotImplementedError
|
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
|
|
12
12
|
|
13
13
|
WAITING_FOR_REFRESH_SEC = 30
|
14
14
|
WAITING_FOR_FORCE_MERGE_SEC = 30
|
15
|
+
SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC = 30
|
15
16
|
|
16
17
|
|
17
18
|
class AWSOpenSearch(VectorDB):
|
@@ -52,10 +53,27 @@ class AWSOpenSearch(VectorDB):
|
|
52
53
|
return AWSOpenSearchIndexConfig
|
53
54
|
|
54
55
|
def _create_index(self, client: OpenSearch):
|
56
|
+
cluster_settings_body = {
|
57
|
+
"persistent": {
|
58
|
+
"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty,
|
59
|
+
"knn.memory.circuit_breaker.limit": self.case_config.cb_threshold,
|
60
|
+
}
|
61
|
+
}
|
62
|
+
client.cluster.put_settings(cluster_settings_body)
|
55
63
|
settings = {
|
56
64
|
"index": {
|
57
65
|
"knn": True,
|
66
|
+
"number_of_shards": self.case_config.number_of_shards,
|
67
|
+
"number_of_replicas": 0,
|
68
|
+
"translog.flush_threshold_size": self.case_config.flush_threshold_size,
|
69
|
+
# Setting trans log threshold to 5GB
|
70
|
+
**(
|
71
|
+
{"knn.algo_param.ef_search": self.case_config.ef_search}
|
72
|
+
if self.case_config.engine == AWSOS_Engine.nmslib
|
73
|
+
else {}
|
74
|
+
),
|
58
75
|
},
|
76
|
+
"refresh_interval": self.case_config.refresh_interval,
|
59
77
|
}
|
60
78
|
mappings = {
|
61
79
|
"properties": {
|
@@ -145,24 +163,49 @@ class AWSOpenSearch(VectorDB):
|
|
145
163
|
docvalue_fields=[self.id_col_name],
|
146
164
|
stored_fields="_none_",
|
147
165
|
)
|
148
|
-
log.
|
149
|
-
log.
|
150
|
-
log.
|
166
|
+
log.debug(f"Search took: {resp['took']}")
|
167
|
+
log.debug(f"Search shards: {resp['_shards']}")
|
168
|
+
log.debug(f"Search hits total: {resp['hits']['total']}")
|
151
169
|
return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
|
152
170
|
except Exception as e:
|
153
171
|
log.warning(f"Failed to search: {self.index_name} error: {e!s}")
|
154
172
|
raise e from None
|
155
173
|
|
156
|
-
def optimize(self):
|
174
|
+
def optimize(self, data_size: int | None = None):
|
157
175
|
"""optimize will be called between insertion and search in performance cases."""
|
158
176
|
# Call refresh first to ensure that all segments are created
|
159
177
|
self._refresh_index()
|
160
|
-
self.
|
178
|
+
if self.case_config.force_merge_enabled:
|
179
|
+
self._do_force_merge()
|
180
|
+
self._refresh_index()
|
181
|
+
self._update_replicas()
|
161
182
|
# Call refresh again to ensure that the index is ready after force merge.
|
162
183
|
self._refresh_index()
|
163
184
|
# ensure that all graphs are loaded in memory and ready for search
|
164
185
|
self._load_graphs_to_memory()
|
165
186
|
|
187
|
+
def _update_replicas(self):
|
188
|
+
index_settings = self.client.indices.get_settings(index=self.index_name)
|
189
|
+
current_number_of_replicas = int(index_settings[self.index_name]["settings"]["index"]["number_of_replicas"])
|
190
|
+
log.info(
|
191
|
+
f"Current Number of replicas are {current_number_of_replicas}"
|
192
|
+
f" and changing the replicas to {self.case_config.number_of_replicas}"
|
193
|
+
)
|
194
|
+
settings_body = {"index": {"number_of_replicas": self.case_config.number_of_replicas}}
|
195
|
+
self.client.indices.put_settings(index=self.index_name, body=settings_body)
|
196
|
+
self._wait_till_green()
|
197
|
+
|
198
|
+
def _wait_till_green(self):
|
199
|
+
log.info("Wait for index to become green..")
|
200
|
+
while True:
|
201
|
+
res = self.client.cat.indices(index=self.index_name, h="health", format="json")
|
202
|
+
health = res[0]["health"]
|
203
|
+
if health != "green":
|
204
|
+
break
|
205
|
+
log.info(f"The index {self.index_name} has health : {health} and is not green. Retrying")
|
206
|
+
time.sleep(SECONDS_WAITING_FOR_REPLICAS_TO_BE_ENABLED_SEC)
|
207
|
+
log.info(f"Index {self.index_name} is green..")
|
208
|
+
|
166
209
|
def _refresh_index(self):
|
167
210
|
log.debug(f"Starting refresh for index {self.index_name}")
|
168
211
|
while True:
|
@@ -179,6 +222,12 @@ class AWSOpenSearch(VectorDB):
|
|
179
222
|
log.debug(f"Completed refresh for index {self.index_name}")
|
180
223
|
|
181
224
|
def _do_force_merge(self):
|
225
|
+
log.info(f"Updating the Index thread qty to {self.case_config.index_thread_qty_during_force_merge}.")
|
226
|
+
|
227
|
+
cluster_settings_body = {
|
228
|
+
"persistent": {"knn.algo_param.index_thread_qty": self.case_config.index_thread_qty_during_force_merge}
|
229
|
+
}
|
230
|
+
self.client.cluster.put_settings(cluster_settings_body)
|
182
231
|
log.debug(f"Starting force merge for index {self.index_name}")
|
183
232
|
force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
|
184
233
|
force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
|
@@ -194,6 +243,3 @@ class AWSOpenSearch(VectorDB):
|
|
194
243
|
log.info("Calling warmup API to load graphs into memory")
|
195
244
|
warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
|
196
245
|
self.client.transport.perform_request("GET", warmup_endpoint)
|
197
|
-
|
198
|
-
def ready_to_load(self):
|
199
|
-
"""ready_to_load will be called before load in load cases."""
|
@@ -18,6 +18,79 @@ class AWSOpenSearchTypedDict(TypedDict):
|
|
18
18
|
port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
|
19
19
|
user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
|
20
20
|
password: Annotated[str, click.option("--password", type=str, help="Db password")]
|
21
|
+
number_of_shards: Annotated[
|
22
|
+
int,
|
23
|
+
click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1),
|
24
|
+
]
|
25
|
+
number_of_replicas: Annotated[
|
26
|
+
int,
|
27
|
+
click.option(
|
28
|
+
"--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1
|
29
|
+
),
|
30
|
+
]
|
31
|
+
index_thread_qty: Annotated[
|
32
|
+
int,
|
33
|
+
click.option(
|
34
|
+
"--index-thread-qty",
|
35
|
+
type=int,
|
36
|
+
help="Thread count for native engine indexing",
|
37
|
+
default=4,
|
38
|
+
),
|
39
|
+
]
|
40
|
+
|
41
|
+
index_thread_qty_during_force_merge: Annotated[
|
42
|
+
int,
|
43
|
+
click.option(
|
44
|
+
"--index-thread-qty-during-force-merge",
|
45
|
+
type=int,
|
46
|
+
help="Thread count during force merge operations",
|
47
|
+
default=4,
|
48
|
+
),
|
49
|
+
]
|
50
|
+
|
51
|
+
number_of_indexing_clients: Annotated[
|
52
|
+
int,
|
53
|
+
click.option(
|
54
|
+
"--number-of-indexing-clients",
|
55
|
+
type=int,
|
56
|
+
help="Number of concurrent indexing clients",
|
57
|
+
default=1,
|
58
|
+
),
|
59
|
+
]
|
60
|
+
|
61
|
+
number_of_segments: Annotated[
|
62
|
+
int,
|
63
|
+
click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1),
|
64
|
+
]
|
65
|
+
|
66
|
+
refresh_interval: Annotated[
|
67
|
+
int,
|
68
|
+
click.option(
|
69
|
+
"--refresh-interval", type=str, help="How often to make new data available for search", default="60s"
|
70
|
+
),
|
71
|
+
]
|
72
|
+
|
73
|
+
force_merge_enabled: Annotated[
|
74
|
+
int,
|
75
|
+
click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True),
|
76
|
+
]
|
77
|
+
|
78
|
+
flush_threshold_size: Annotated[
|
79
|
+
int,
|
80
|
+
click.option(
|
81
|
+
"--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb"
|
82
|
+
),
|
83
|
+
]
|
84
|
+
|
85
|
+
cb_threshold: Annotated[
|
86
|
+
int,
|
87
|
+
click.option(
|
88
|
+
"--cb-threshold",
|
89
|
+
type=str,
|
90
|
+
help="k-NN Memory circuit breaker threshold",
|
91
|
+
default="50%",
|
92
|
+
),
|
93
|
+
]
|
21
94
|
|
22
95
|
|
23
96
|
class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
|
@@ -36,6 +109,17 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
|
|
36
109
|
user=parameters["user"],
|
37
110
|
password=SecretStr(parameters["password"]),
|
38
111
|
),
|
39
|
-
db_case_config=AWSOpenSearchIndexConfig(
|
112
|
+
db_case_config=AWSOpenSearchIndexConfig(
|
113
|
+
number_of_shards=parameters["number_of_shards"],
|
114
|
+
number_of_replicas=parameters["number_of_replicas"],
|
115
|
+
index_thread_qty=parameters["index_thread_qty"],
|
116
|
+
number_of_segments=parameters["number_of_segments"],
|
117
|
+
refresh_interval=parameters["refresh_interval"],
|
118
|
+
force_merge_enabled=parameters["force_merge_enabled"],
|
119
|
+
flush_threshold_size=parameters["flush_threshold_size"],
|
120
|
+
number_of_indexing_clients=parameters["number_of_indexing_clients"],
|
121
|
+
index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"],
|
122
|
+
cb_threshold=parameters["cb_threshold"],
|
123
|
+
),
|
40
124
|
**parameters,
|
41
125
|
)
|
@@ -39,6 +39,16 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
|
|
39
39
|
efConstruction: int = 256
|
40
40
|
efSearch: int = 256
|
41
41
|
M: int = 16
|
42
|
+
index_thread_qty: int | None = 4
|
43
|
+
number_of_shards: int | None = 1
|
44
|
+
number_of_replicas: int | None = 0
|
45
|
+
number_of_segments: int | None = 1
|
46
|
+
refresh_interval: str | None = "60s"
|
47
|
+
force_merge_enabled: bool | None = True
|
48
|
+
flush_threshold_size: str | None = "5120mb"
|
49
|
+
number_of_indexing_clients: int | None = 1
|
50
|
+
index_thread_qty_during_force_merge: int
|
51
|
+
cb_threshold: str | None = "50%"
|
42
52
|
|
43
53
|
def parse_metric(self) -> str:
|
44
54
|
if self.metric_type == MetricType.IP:
|
@@ -143,7 +143,7 @@ class ElasticCloud(VectorDB):
|
|
143
143
|
log.warning(f"Failed to search: {self.indice} error: {e!s}")
|
144
144
|
raise e from None
|
145
145
|
|
146
|
-
def optimize(self):
|
146
|
+
def optimize(self, data_size: int | None = None):
|
147
147
|
"""optimize will be called between insertion and search in performance cases."""
|
148
148
|
assert self.client is not None, "should self.init() first"
|
149
149
|
self.client.indices.refresh(index=self.indice)
|
@@ -158,6 +158,3 @@ class ElasticCloud(VectorDB):
|
|
158
158
|
task_status = self.client.tasks.get(task_id=force_merge_task_id)
|
159
159
|
if task_status["completed"]:
|
160
160
|
return
|
161
|
-
|
162
|
-
def ready_to_load(self):
|
163
|
-
"""ready_to_load will be called before load in load cases."""
|
@@ -43,8 +43,8 @@ class MemoryDBTypedDict(TypedDict):
|
|
43
43
|
show_default=True,
|
44
44
|
default=False,
|
45
45
|
help=(
|
46
|
-
"Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance."
|
47
|
-
" In production, MemoryDB only supports cluster mode (CME)"
|
46
|
+
"Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance."
|
47
|
+
" In production, MemoryDB only supports cluster mode (CME)"
|
48
48
|
),
|
49
49
|
),
|
50
50
|
]
|
@@ -157,17 +157,14 @@ class MemoryDB(VectorDB):
|
|
157
157
|
self.conn = self.get_client()
|
158
158
|
search_param = self.case_config.search_param()
|
159
159
|
if search_param["ef_runtime"]:
|
160
|
-
self.ef_runtime_str = f
|
160
|
+
self.ef_runtime_str = f"EF_RUNTIME {search_param['ef_runtime']}"
|
161
161
|
else:
|
162
162
|
self.ef_runtime_str = ""
|
163
163
|
yield
|
164
164
|
self.conn.close()
|
165
165
|
self.conn = None
|
166
166
|
|
167
|
-
def
|
168
|
-
pass
|
169
|
-
|
170
|
-
def optimize(self) -> None:
|
167
|
+
def optimize(self, data_size: int | None = None):
|
171
168
|
self._post_insert()
|
172
169
|
|
173
170
|
def insert_embeddings(
|
@@ -138,26 +138,7 @@ class Milvus(VectorDB):
|
|
138
138
|
log.warning(f"{self.name} optimize error: {e}")
|
139
139
|
raise e from None
|
140
140
|
|
141
|
-
def
|
142
|
-
assert self.col, "Please call self.init() before"
|
143
|
-
self._pre_load(self.col)
|
144
|
-
|
145
|
-
def _pre_load(self, coll: Collection):
|
146
|
-
try:
|
147
|
-
if not coll.has_index(index_name=self._index_name):
|
148
|
-
log.info(f"{self.name} create index")
|
149
|
-
coll.create_index(
|
150
|
-
self._vector_field,
|
151
|
-
self.case_config.index_param(),
|
152
|
-
index_name=self._index_name,
|
153
|
-
)
|
154
|
-
coll.load()
|
155
|
-
log.info(f"{self.name} load")
|
156
|
-
except Exception as e:
|
157
|
-
log.warning(f"{self.name} pre load error: {e}")
|
158
|
-
raise e from None
|
159
|
-
|
160
|
-
def optimize(self):
|
141
|
+
def optimize(self, data_size: int | None = None):
|
161
142
|
assert self.col, "Please call self.init() before"
|
162
143
|
self._optimize()
|
163
144
|
|
@@ -0,0 +1,53 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
|
5
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
6
|
+
|
7
|
+
|
8
|
+
class QuantizationType(Enum):
|
9
|
+
NONE = "none"
|
10
|
+
BINARY = "binary"
|
11
|
+
SCALAR = "scalar"
|
12
|
+
|
13
|
+
|
14
|
+
class MongoDBConfig(DBConfig, BaseModel):
|
15
|
+
connection_string: SecretStr = "mongodb+srv://<user>:<password>@<cluster_name>.heatl.mongodb.net"
|
16
|
+
database: str = "vdb_bench"
|
17
|
+
|
18
|
+
def to_dict(self) -> dict:
|
19
|
+
return {
|
20
|
+
"connection_string": self.connection_string.get_secret_value(),
|
21
|
+
"database": self.database,
|
22
|
+
}
|
23
|
+
|
24
|
+
|
25
|
+
class MongoDBIndexConfig(BaseModel, DBCaseConfig):
|
26
|
+
index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search
|
27
|
+
metric_type: MetricType = MetricType.COSINE
|
28
|
+
num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search
|
29
|
+
quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable
|
30
|
+
|
31
|
+
def parse_metric(self) -> str:
|
32
|
+
if self.metric_type == MetricType.L2:
|
33
|
+
return "euclidean"
|
34
|
+
if self.metric_type == MetricType.IP:
|
35
|
+
return "dotProduct"
|
36
|
+
return "cosine" # Default to cosine similarity
|
37
|
+
|
38
|
+
def index_param(self) -> dict:
|
39
|
+
return {
|
40
|
+
"type": "vectorSearch",
|
41
|
+
"fields": [
|
42
|
+
{
|
43
|
+
"type": "vector",
|
44
|
+
"similarity": self.parse_metric(),
|
45
|
+
"numDimensions": None, # Will be set in MongoDB class
|
46
|
+
"path": "vector", # Vector field name
|
47
|
+
"quantization": self.quantization.value,
|
48
|
+
}
|
49
|
+
],
|
50
|
+
}
|
51
|
+
|
52
|
+
def search_param(self) -> dict:
|
53
|
+
return {"num_candidates_ratio": self.num_candidates_ratio}
|