vectordb-bench 1.0.5__py3-none-any.whl → 1.0.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their public registries.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/clients/__init__.py +15 -0
- vectordb_bench/backend/clients/api.py +2 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +104 -40
- vectordb_bench/backend/clients/aws_opensearch/cli.py +52 -15
- vectordb_bench/backend/clients/aws_opensearch/config.py +27 -7
- vectordb_bench/backend/clients/hologres/cli.py +50 -0
- vectordb_bench/backend/clients/hologres/config.py +120 -0
- vectordb_bench/backend/clients/hologres/hologres.py +385 -0
- vectordb_bench/backend/clients/lancedb/lancedb.py +1 -0
- vectordb_bench/backend/clients/milvus/cli.py +25 -0
- vectordb_bench/backend/clients/milvus/config.py +2 -1
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/oceanbase/cli.py +1 -0
- vectordb_bench/backend/clients/oceanbase/config.py +3 -1
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +20 -4
- vectordb_bench/backend/clients/pgdiskann/cli.py +45 -0
- vectordb_bench/backend/clients/pgdiskann/config.py +16 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +94 -26
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -1
- vectordb_bench/backend/clients/zilliz_cloud/config.py +4 -1
- vectordb_bench/backend/runner/rate_runner.py +23 -11
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +2 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +82 -3
- vectordb_bench/frontend/config/styles.py +1 -0
- vectordb_bench/interface.py +5 -1
- vectordb_bench/models.py +4 -0
- vectordb_bench/results/getLeaderboardDataV2.py +23 -2
- vectordb_bench/results/leaderboard_v2.json +200 -0
- vectordb_bench/results/leaderboard_v2_streaming.json +128 -0
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/METADATA +40 -8
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/RECORD +37 -33
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/WHEEL +0 -0
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-1.0.5.dist-info → vectordb_bench-1.0.8.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/hologres/hologres.py
@@ -0,0 +1,385 @@
+"""Wrapper around the Hologres vector database over VectorDB"""
+
+import json
+import logging
+from collections.abc import Generator
+from contextlib import contextmanager
+from io import StringIO
+from typing import Any
+
+import psycopg
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import HologresConfig, HologresIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class Hologres(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
+
+    _tg_name: str = "vdb_bench_tg_1"
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: HologresConfig,
+        db_case_config: HologresIndexConfig,
+        collection_name: str = "vector_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "Alibaba Cloud Hologres"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # create vector extension
+        if self.case_config.is_proxima():
+            self.cursor.execute("CREATE EXTENSION proxima;")
+            self.conn.commit()
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            ),
+        ):
+            msg = (
+                f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+                f"{self.name} config values: {self.db_config}\n{self.case_config}"
+            )
+            log.error(msg)
+            raise RuntimeError(msg)
+
+        if drop_old:
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.autocommit = True
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        """
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+            >>>     self.search_embedding()
+        """
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        self._set_search_guc()
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _set_search_guc(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        sql_guc = sql.SQL(f"SET hg_vector_ef_search = {self.case_config.ef_search};")
+        log.info(f"{self.name} client set search guc: {sql_guc.as_string()}")
+        self.cursor.execute(sql_guc)
+        self.conn.commit()
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        log.info(f"{self.name} client drop table : {self.table_name}")
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS {table_name};").format(
+                table_name=sql.Identifier(self.table_name),
+            ),
+        )
+        self.conn.commit()
+
+        try:
+            log.info(f"{self.name} client purge table recycle bin: {self.table_name}")
+            self.cursor.execute(
+                sql.SQL("purge TABLE {table_name};").format(
+                    table_name=sql.Identifier(self.table_name),
+                ),
+            )
+        except Exception as e:
+            log.info(f"{self.name} client purge table {self.table_name} recycle bin failed, error: {e}, ignore.")
+        finally:
+            self.conn.commit()
+
+        try:
+            log.info(f"{self.name} client drop table group : {self._tg_name}")
+            self.cursor.execute(sql.SQL(f"CALL HG_DROP_TABLE_GROUP('{self._tg_name}');"))
+        except Exception as e:
+            log.info(f"{self.name} client drop table group : {self._tg_name} failed, error: {e}, ignore.")
+        finally:
+            self.conn.commit()
+
+        try:
+            log.info(f"{self.name} client free cache")
+            self.cursor.execute("select hg_admin_command('freecache');")
+        except Exception as e:
+            log.info(f"{self.name} client free cache failed, error: {e}, ignore.")
+        finally:
+            self.conn.commit()
+
+    def optimize(self, data_size: int | None = None):
+        if self.case_config.create_index_after_load:
+            self._create_index()
+        self._full_compact()
+        self._analyze()
+
+    def _vacuum(self):
+        log.info(f"{self.name} client vacuum table : {self.table_name}")
+        try:
+            # VACUUM cannot run inside a transaction block
+            # it's better to new a connection
+            self.conn.autocommit = True
+            with self.conn.cursor() as cursor:
+                cursor.execute(
+                    sql.SQL(
+                        """
+                        VACUUM {table_name};
+                        """
+                    ).format(
+                        table_name=sql.Identifier(self.table_name),
+                    )
+                )
+            log.info(f"{self.name} client vacuum table : {self.table_name} done")
+        except Exception as e:
+            log.warning(f"Failed to vacuum table: {self.table_name} error: {e}")
+            raise e from None
+        finally:
+            self.conn.autocommit = True
+
+    def _analyze(self):
+        log.info(f"{self.name} client analyze table : {self.table_name}")
+        self.cursor.execute(sql.SQL(f"ANALYZE {self.table_name};"))
+        log.info(f"{self.name} client analyze table : {self.table_name} done")
+
+    def _full_compact(self):
+        log.info(f"{self.name} client full compact table : {self.table_name}")
+        self.cursor.execute(
+            sql.SQL(
+                """
+                SELECT hologres.hg_full_compact_table(
+                    '{table_name}',
+                    'max_file_size_mb={full_compact_max_file_size_mb}'
+                );
+                """
+            ).format(
+                table_name=sql.SQL(self.table_name),
+                full_compact_max_file_size_mb=sql.SQL(str(self.case_config.full_compact_max_file_size_mb)),
+            )
+        )
+        log.info(f"{self.name} client full compact table : {self.table_name} done")
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        sql_index = sql.SQL(
+            """
+            CALL set_table_property ('{table_name}', 'vectors', '{{
+                "embedding": {{
+                    "algorithm": "{algorithm}",
+                    "distance_method": "{distance_method}",
+                    "builder_params": {builder_params}
+                }}
+            }}');
+            """
+        ).format(
+            table_name=sql.Identifier(self.table_name),
+            algorithm=sql.SQL(self.case_config.algorithm()),
+            distance_method=sql.SQL(self.case_config.distance_method()),
+            builder_params=sql.SQL(json.dumps(self.case_config.builder_params())),
+        )
+
+        log.info(f"{self.name} client create index on table : {self.table_name}, with sql: {sql_index.as_string()}")
+        try:
+            self.cursor.execute(sql_index)
+            self.conn.commit()
+        except Exception as e:
+            log.warning(f"Failed to create index on table: {self.table_name} error: {e}")
+            raise e from None
+
+    def _set_replica_count(self, replica_count: int = 2):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            # non-warehouse mode by default
+            sql_tg_replica = sql.SQL(
+                f"CALL hg_set_table_group_property('{self._tg_name}', 'replica_count', '{replica_count}');"
+            )
+
+            # check warehouse mode
+            sql_check = sql.SQL("select count(*) from hologres.hg_warehouses;")
+            log.info(f"check warehouse mode with sql: {sql_check}")
+            self.cursor.execute(sql_check)
+            result_check = self.cursor.fetchone()[0]
+            if result_check > 0:
+                # get warehouse name
+                sql_get_warehouse_name = sql.SQL("select current_warehouse();")
+                log.info(f"get warehouse name with sql: {sql_get_warehouse_name}")
+                self.cursor.execute(sql_get_warehouse_name)
+                sql_tg_replica = sql.SQL(
+                    """
+                    CALL hg_table_group_set_warehouse_replica_count (
+                        '{dbname}.{tg_name}',
+                        {replica_count},
+                        '{warehouse_name}'
+                    );
+                    """
+                ).format(
+                    tg_name=sql.SQL(self._tg_name),
+                    warehouse_name=sql.SQL(self.cursor.fetchone()[0]),
+                    dbname=sql.SQL(self.db_config["dbname"]),
+                    replica_count=replica_count,
+                )
+            log.info(f"{self.name} client set table group replica: {self._tg_name}, with sql: {sql_tg_replica}")
+            self.cursor.execute(sql_tg_replica)
+        except Exception as e:
+            log.warning(f"Failed to set replica count, error: {e}, ignore")
+        finally:
+            self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        sql_tg = sql.SQL(f"CALL HG_CREATE_TABLE_GROUP ('{self._tg_name}', 1);")
+        log.info(f"{self.name} client create table group : {self._tg_name}, with sql: {sql_tg}")
+        try:
+            self.cursor.execute(sql_tg)
+        except Exception as e:
+            log.warning(f"Failed to create table group : {self._tg_name} error: {e}, ignore")
+        finally:
+            self.conn.commit()
+
+        self._set_replica_count(replica_count=2)
+
+        sql_table = sql.SQL(
+            """
+            CREATE TABLE IF NOT EXISTS {table_name} (
+                id BIGINT PRIMARY KEY,
+                embedding FLOAT4[] CHECK (array_ndims(embedding) = 1 AND array_length(embedding, 1) = {dim})
+            )
+            WITH (table_group = {tg_name});
+            """
+        ).format(
+            table_name=sql.Identifier(self.table_name),
+            dim=dim,
+            tg_name=sql.SQL(self._tg_name),
+        )
+        log.info(f"{self.name} client create table : {self.table_name}, with sql: {sql_table.as_string()}")
+        try:
+            self.cursor.execute(sql_table)
+            self.conn.commit()
+        except Exception as e:
+            log.warning(f"Failed to create table : {self.table_name} error: {e}")
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception | None]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            buffer = StringIO()
+            for i in range(len(metadata)):
+                buffer.write("%d\t%s\n" % (metadata[i], "{" + ",".join("%f" % x for x in embeddings[i]) + "}"))
+            buffer.seek(0)
+
+            with self.cursor.copy(
+                sql.SQL("COPY {table_name} FROM STDIN").format(table_name=sql.Identifier(self.table_name))
+            ) as copy:
+                copy.write(buffer.getvalue())
+            self.conn.commit()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
+            return 0, e
+
+    def _compose_query_and_params(self, vec: list[float], topk: int, ge_id: int | None = None):
+        params = []
+
+        where_clause = sql.SQL("")
+        if ge_id is not None:
+            where_clause = sql.SQL(" WHERE id >= %s ")
+            params.append(ge_id)
+
+        vec_float4 = [psycopg._wrappers.Float4(i) for i in vec]
+        params.append(vec_float4)
+        params.append(topk)
+
+        query = sql.SQL(
+            """
+            SELECT id
+            FROM {table_name}
+            {where_clause}
+            ORDER BY {distance_function}(embedding, %b)
+            {order_direction}
+            LIMIT %s;
+            """
+        ).format(
+            table_name=sql.Identifier(self.table_name),
+            distance_function=sql.SQL(self.case_config.distance_function()),
+            where_clause=where_clause,
+            order_direction=sql.SQL(self.case_config.order_direction()),
+        )
+
+        return query, params
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        ge = filters.get("id") if filters else None
+        q, params = self._compose_query_and_params(query, k, ge)
+        result = self.cursor.execute(q, params, prepare=True, binary=True)
+        return [int(i[0]) for i in result.fetchall()]
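For orientation, a minimal sketch of how a caller might drive the new Hologres client above, under stated assumptions: the db_config mapping is unpacked straight into psycopg.connect(**...) by _create_connection, so its keys must be valid psycopg connection parameters (the endpoint below is hypothetical), and the HologresIndexConfig field names follow the accompanying hologres/config.py.

    # Hedged sketch, not part of the package: exercises the lifecycle visible
    # in the diff above (construct -> init() context -> COPY insert -> search).
    from vectordb_bench.backend.clients.hologres.config import HologresIndexConfig
    from vectordb_bench.backend.clients.hologres.hologres import Hologres

    db = Hologres(
        dim=4,
        db_config={  # assumption: forwarded verbatim to psycopg.connect
            "host": "my-instance.hologres.aliyuncs.com",  # hypothetical endpoint
            "port": 80,
            "dbname": "vdb_bench",
            "user": "my_user",
            "password": "my_password",
        },
        db_case_config=HologresIndexConfig(create_index_before_load=True),  # field names assumed
        drop_old=True,
    )
    with db.init():  # opens a fresh connection and applies hg_vector_ef_search
        db.insert_embeddings(embeddings=[[0.1, 0.2, 0.3, 0.4]], metadata=[1])
        print(db.search_embedding(query=[0.1, 0.2, 0.3, 0.4], k=10))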
vectordb_bench/backend/clients/milvus/cli.py
@@ -40,6 +40,17 @@ class MilvusTypedDict(TypedDict):
             show_default=True,
         ),
     ]
+    replica_number: Annotated[
+        int,
+        click.option(
+            "--replica-number",
+            type=int,
+            help="Number of replicas",
+            required=False,
+            default=1,
+            show_default=True,
+        ),
+    ]
 
 
 class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): ...
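Each Milvus subcommand below forwards this new option into MilvusConfig with an identical one-line change. The companion milvus/milvus.py change is summarized as +1/-1 in the file list above, so its exact call site is not visible here; as an assumption, a replica count of this kind is what pymilvus accepts when loading a collection:

    # Hedged sketch: where a replica count typically lands in pymilvus.
    # The actual call site inside vectordb_bench's milvus.py is an assumption.
    from pymilvus import Collection

    col = Collection("vector_collection")
    col.load(replica_number=2)  # load the collection with two in-memory replicas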
@@ -58,6 +69,7 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=AutoIndexConfig(),
         **parameters,
@@ -77,6 +89,7 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=FLATConfig(),
         **parameters,
@@ -99,6 +112,7 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=HNSWConfig(
             M=parameters["m"],
@@ -163,6 +177,7 @@ def MilvusHNSWPQ(**parameters: Unpack[MilvusHNSWPQTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=HNSWPQConfig(
             M=parameters["m"],
@@ -206,6 +221,7 @@ def MilvusHNSWPRQ(**parameters: Unpack[MilvusHNSWPRQTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=HNSWPRQConfig(
             M=parameters["m"],
@@ -246,6 +262,7 @@ def MilvusHNSWSQ(**parameters: Unpack[MilvusHNSWSQTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=HNSWSQConfig(
             M=parameters["m"],
@@ -276,6 +293,7 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=IVFFlatConfig(
             nlist=parameters["nlist"],
@@ -298,6 +316,7 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=IVFSQ8Config(
             nlist=parameters["nlist"],
@@ -359,6 +378,7 @@ def MilvusIVFRabitQ(**parameters: Unpack[MilvusIVFRABITQTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=IVFRABITQConfig(
             nlist=parameters["nlist"],
@@ -389,6 +409,7 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=DISKANNConfig(
             search_list=parameters["search_list"],
@@ -418,6 +439,7 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=GPUIVFFlatConfig(
             nlist=parameters["nlist"],
@@ -453,6 +475,7 @@ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=GPUBruteForceConfig(
             metric_type=parameters["metric_type"],
@@ -485,6 +508,7 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=GPUIVFPQConfig(
             nlist=parameters["nlist"],
@@ -525,6 +549,7 @@ def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]):
             user=parameters["user_name"],
             password=SecretStr(parameters["password"]) if parameters["password"] else None,
             num_shards=int(parameters["num_shards"]),
+            replica_number=int(parameters["replica_number"]),
         ),
         db_case_config=GPUCAGRAConfig(
             intermediate_graph_degree=parameters["intermediate_graph_degree"],
vectordb_bench/backend/clients/milvus/config.py
@@ -8,6 +8,7 @@ class MilvusConfig(DBConfig):
     user: str | None = None
     password: SecretStr | None = None
     num_shards: int = 1
+    replica_number: int = 1
 
     def to_dict(self) -> dict:
         return {
@@ -15,6 +16,7 @@ class MilvusConfig(DBConfig):
             "user": self.user if self.user else None,
             "password": self.password.get_secret_value() if self.password else None,
             "num_shards": self.num_shards,
+            "replica_number": self.replica_number,
         }
 
     @validator("*")
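A quick round-trip check of the new field as defined above; a minimal sketch assuming MilvusConfig can be constructed with just these fields (uri is wrapped in pydantic's SecretStr, matching the CLI hunks):

    # Hedged sketch: replica_number defaults to 1 and surfaces in to_dict();
    # any other required MilvusConfig fields are omitted here by assumption.
    from pydantic import SecretStr
    from vectordb_bench.backend.clients.milvus.config import MilvusConfig

    cfg = MilvusConfig(uri=SecretStr("http://localhost:19530"), replica_number=2)
    assert cfg.to_dict()["replica_number"] == 2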
@@ -318,7 +320,6 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
 
 class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
     limit: int = 10  # Default top-k for search
-    metric_type: str  # Metric type (e.g., 'L2', 'IP', etc.)
     index: IndexType = IndexType.GPU_BRUTE_FORCE  # Index type set to GPU_BRUTE_FORCE
 
     def index_param(self) -> dict:
vectordb_bench/backend/clients/oceanbase/cli.py
@@ -93,6 +93,7 @@ def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
             m=input_m,
             nlist=parameters["nlist"],
             sample_per_nlist=parameters["sample_per_nlist"],
+            nbits=parameters["nbits"],
             index=input_index_type,
             ivf_nprobes=parameters["ivf_nprobes"],
         ),
vectordb_bench/backend/clients/oceanbase/config.py
@@ -85,6 +85,7 @@ class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
 class OceanBaseIVFConfig(OceanBaseIndexConfig, DBCaseConfig):
     m: int
     sample_per_nlist: int
+    nbits: int | None = None
     nlist: int
     index: IndexType
     ivf_nprobes: int | None = None
@@ -96,8 +97,9 @@ class OceanBaseIVFConfig(OceanBaseIndexConfig, DBCaseConfig):
             "metric_type": self.parse_metric(),
             "index_type": self.index.value,
             "params": {
-                "m": self.
+                "m": self.m,
                 "sample_per_nlist": self.sample_per_nlist,
+                "nbits": self.nbits,
                 "nlist": self.nlist,
             },
         }
vectordb_bench/backend/clients/oceanbase/oceanbase.py
@@ -7,6 +7,8 @@ from typing import Any
 
 import mysql.connector as mysql
 
+from vectordb_bench.backend.filter import Filter, FilterOp
+
 from ..api import IndexType, VectorDB
 from .config import OceanBaseConfigDict, OceanBaseHNSWConfig
 
@@ -16,6 +18,12 @@ OCEANBASE_DEFAULT_LOAD_BATCH_SIZE = 256
 
 
 class OceanBase(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
         dim: int,
@@ -187,22 +195,30 @@ class OceanBase(VectorDB):
 
         return insert_count, None
 
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.expr = ""
+        elif filters.type == FilterOp.NumGE:
+            self.expr = f"WHERE id >= {filters.int_value}"
+        elif filters.type == FilterOp.StrEqual:
+            self.expr = f"WHERE id == '{filters.label_value}'"
+        else:
+            msg = f"Not support Filter for Oceanbase - {filters}"
+            raise ValueError(msg)
+
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-        filters: dict[str, Any] | None = None,
-        timeout: int | None = None,
     ) -> list[int]:
         if not self._cursor:
             raise ValueError("Cursor is not initialized")
 
         packed = struct.pack(f"<{len(query)}f", *query)
         hex_vec = packed.hex()
-        filter_clause = f"WHERE id >= {filters['id']}" if filters else ""
         query_str = (
             f"SELECT id FROM {self.table_name} "  # noqa: S608
-            f"{filter_clause} ORDER BY "
+            f"{self.expr} ORDER BY "
             f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
             f"APPROXIMATE LIMIT {k}"
         )
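The OceanBase client thus moves filtering out of the search call and into the Filter protocol declared at the top of the file: the harness calls prepare_filter(...) once per case, and every subsequent search_embedding interpolates self.expr into the SQL. A hedged sketch of that flow, assuming IntFilter is a concrete Filter exported by vectordb_bench.backend.filter whose type is FilterOp.NumGE and whose constructor arguments are as shown:

    # Hedged sketch: `client` stands for an OceanBase instance built by the
    # benchmark harness; IntFilter's exact signature is an assumption.
    from vectordb_bench.backend.filter import IntFilter

    client.prepare_filter(IntFilter(filter_rate=0.9, int_value=900_000))
    # self.expr is now "WHERE id >= 900000"
    ids = client.search_embedding(query=[0.0] * 768, k=100)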
|