vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- vectordb_bench/__init__.py +14 -27
- vectordb_bench/backend/assembler.py +19 -6
- vectordb_bench/backend/cases.py +186 -23
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/api.py +22 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
- vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
- vectordb_bench/backend/clients/chroma/chroma.py +6 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
- vectordb_bench/backend/clients/lancedb/cli.py +62 -8
- vectordb_bench/backend/clients/lancedb/config.py +14 -1
- vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
- vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
- vectordb_bench/backend/clients/milvus/cli.py +30 -9
- vectordb_bench/backend/clients/milvus/config.py +3 -0
- vectordb_bench/backend/clients/milvus/milvus.py +81 -23
- vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
- vectordb_bench/backend/clients/oceanbase/config.py +125 -0
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
- vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
- vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
- vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
- vectordb_bench/backend/dataset.py +143 -27
- vectordb_bench/backend/filter.py +76 -0
- vectordb_bench/backend/runner/__init__.py +3 -3
- vectordb_bench/backend/runner/mp_runner.py +52 -39
- vectordb_bench/backend/runner/rate_runner.py +68 -52
- vectordb_bench/backend/runner/read_write_runner.py +125 -68
- vectordb_bench/backend/runner/serial_runner.py +56 -23
- vectordb_bench/backend/task_runner.py +48 -20
- vectordb_bench/cli/batch_cli.py +121 -0
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +7 -0
- vectordb_bench/config-files/batch_sample_config.yml +17 -0
- vectordb_bench/frontend/components/check_results/data.py +16 -11
- vectordb_bench/frontend/components/check_results/filters.py +53 -25
- vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
- vectordb_bench/frontend/components/check_results/nav.py +20 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
- vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
- vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
- vectordb_bench/frontend/components/label_filter/charts.py +60 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
- vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
- vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
- vectordb_bench/frontend/components/streaming/charts.py +253 -0
- vectordb_bench/frontend/components/streaming/data.py +62 -0
- vectordb_bench/frontend/components/tables/data.py +1 -1
- vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
- vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
- vectordb_bench/frontend/config/styles.py +32 -2
- vectordb_bench/frontend/pages/concurrent.py +5 -1
- vectordb_bench/frontend/pages/custom.py +4 -0
- vectordb_bench/frontend/pages/label_filter.py +56 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
- vectordb_bench/frontend/pages/results.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +3 -3
- vectordb_bench/frontend/pages/streaming.py +135 -0
- vectordb_bench/frontend/pages/tables.py +4 -0
- vectordb_bench/frontend/vdb_benchmark.py +16 -41
- vectordb_bench/interface.py +6 -2
- vectordb_bench/metric.py +15 -1
- vectordb_bench/models.py +38 -11
- vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
- vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
- vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
- vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
- vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
- vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
- vectordb_bench/results/dbPrices.json +12 -4
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
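Two additions tie most of the client diffs below together: a new filter abstraction (vectordb_bench/backend/filter.py) and a with_scalar_labels insert path. Each client now declares which FilterOp values it supports and converts a Filter into its native expression once, via prepare_filter, before the query loop. The following is a minimal sketch of that contract, inferred from the Pinecone and Qdrant hunks below; the helper function itself is illustrative and not part of the package:

from vectordb_bench.backend.filter import Filter, FilterOp

def run_filtered_queries(client, queries: list[list[float]], flt: Filter, k: int = 100):
    # Clients advertise supported operators as a class attribute.
    if flt.type not in client.supported_filter_types:
        msg = f"{type(client).__name__} does not support {flt.type}"
        raise ValueError(msg)
    # prepare_filter() is called once per case; it caches the native filter
    # expression (a Pinecone dict, a Qdrant Filter, a SQL WHERE clause, ...).
    client.prepare_filter(flt)
    with client.init():
        return [client.search_embedding(q, k=k) for q in queries]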
vectordb_bench/backend/clients/oceanbase/oceanbase.py (new file)
@@ -0,0 +1,215 @@
+import logging
+import struct
+import time
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any
+
+import mysql.connector as mysql
+
+from ..api import IndexType, VectorDB
+from .config import OceanBaseConfigDict, OceanBaseHNSWConfig
+
+log = logging.getLogger(__name__)
+
+OCEANBASE_DEFAULT_LOAD_BATCH_SIZE = 256
+
+
+class OceanBase(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: OceanBaseConfigDict,
+        db_case_config: OceanBaseHNSWConfig,
+        collection_name: str = "items",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "OceanBase"
+        self.dim = dim
+        self.db_config = db_config
+        self.db_case_config = db_case_config
+        self.table_name = collection_name
+        self.load_batch_size = OCEANBASE_DEFAULT_LOAD_BATCH_SIZE
+        self._index_name = "vidx"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        log.info(
+            f"{self.name} initialized with config:\nDatabase: {self.db_config}\nCase Config: {self.db_case_config}"
+        )
+
+        self._conn = None
+        self._cursor = None
+
+        try:
+            self._connect()
+            if drop_old:
+                self._drop_table()
+                self._create_table()
+        finally:
+            self._disconnect()
+
+    def _connect(self):
+        try:
+            self._conn = mysql.connect(
+                host=self.db_config["host"],
+                user=self.db_config["user"],
+                port=self.db_config["port"],
+                password=self.db_config["password"],
+                database=self.db_config["database"],
+            )
+            self._cursor = self._conn.cursor()
+        except mysql.Error:
+            log.exception("Failed to connect to the database")
+            raise
+
+    def _disconnect(self):
+        if self._cursor:
+            self._cursor.close()
+            self._cursor = None
+        if self._conn:
+            self._conn.close()
+            self._conn = None
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        try:
+            self._connect()
+            self._cursor.execute("SET autocommit=1")
+
+            if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
+                self._cursor.execute(
+                    f"SET ob_hnsw_ef_search={(self.db_case_config.search_param())['params']['ef_search']}"
+                )
+            else:
+                self._cursor.execute(
+                    f"SET ob_ivf_nprobes={(self.db_case_config.search_param())['params']['ivf_nprobes']}"
+                )
+            yield
+        finally:
+            self._disconnect()
+
+    def _drop_table(self):
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        log.info(f"Dropping table {self.table_name}")
+        self._cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+    def _create_table(self):
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        log.info(f"Creating table {self.table_name}")
+        create_table_query = f"""
+            CREATE TABLE {self.table_name} (
+                id INT PRIMARY KEY,
+                embedding VECTOR({self.dim})
+            );
+        """
+        self._cursor.execute(create_table_query)
+
+    def optimize(self, data_size: int):
+        index_params = self.db_case_config.index_param()
+        index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
+        index_query = (
+            f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
+            f"ON {self.table_name}(embedding) "
+            f"WITH (distance={self.db_case_config.parse_metric()}, "
+            f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
+        )
+
+        if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
+            index_query += ", extra_info_max_size=32"
+
+        index_query += ")"
+
+        log.info("Create index query: %s", index_query)
+
+        try:
+            log.info("Creating index...")
+            start_time = time.time()
+            self._cursor.execute(index_query)
+            log.info(f"Index created in {time.time() - start_time:.2f} seconds")
+
+            log.info("Performing major freeze...")
+            self._cursor.execute("ALTER SYSTEM MAJOR FREEZE;")
+            time.sleep(10)
+            self._wait_for_major_compaction()
+
+            log.info("Gathering schema statistics...")
+            self._cursor.execute("CALL dbms_stats.gather_schema_stats('test', degree => 96);")
+        except mysql.Error:
+            log.exception("Failed to optimize index")
+            raise
+
+    def need_normalize_cosine(self) -> bool:
+        if self.db_case_config.index == IndexType.HNSW_BQ:
+            log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
+            return True
+
+        return False
+
+    def _wait_for_major_compaction(self):
+        while True:
+            self._cursor.execute(
+                "SELECT IF(COUNT(*) = COUNT(STATUS = 'IDLE' OR NULL), 'TRUE', 'FALSE') "
+                "AS all_status_idle FROM oceanbase.DBA_OB_ZONE_MAJOR_COMPACTION;"
+            )
+            all_status_idle = self._cursor.fetchone()[0]
+            if all_status_idle == "TRUE":
+                break
+            time.sleep(10)
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> tuple[int, Exception | None]:
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        insert_count = 0
+        try:
+            for batch_start in range(0, len(embeddings), self.load_batch_size):
+                batch_end = min(batch_start + self.load_batch_size, len(embeddings))
+                batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
+                values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
+                self._cursor.execute(
+                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"  # noqa: S608
+                )
+                insert_count += len(batch)
+        except mysql.Error:
+            log.exception("Failed to insert embeddings")
+            raise
+
+        return insert_count, None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict[str, Any] | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        if not self._cursor:
+            raise ValueError("Cursor is not initialized")
+
+        packed = struct.pack(f"<{len(query)}f", *query)
+        hex_vec = packed.hex()
+        filter_clause = f"WHERE id >= {filters['id']}" if filters else ""
+        query_str = (
+            f"SELECT id FROM {self.table_name} "  # noqa: S608
+            f"{filter_clause} ORDER BY "
+            f"{self.db_case_config.parse_metric_func_str()}(embedding, X'{hex_vec}') "
+            f"APPROXIMATE LIMIT {k}"
+        )

+        try:
+            self._cursor.execute(query_str)
+            return [row[0] for row in self._cursor.fetchall()]
+        except mysql.Error:
+            log.exception("Failed to execute search query")
+            raise
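Worth noting in the listing above: the OceanBase client writes vectors as text literals ('[v1,v2,...]') on the insert path, but searches with a hex-encoded little-endian float32 blob (X'...'). A small sketch of the two encodings, assuming a table shaped like the one created above; l2_distance stands in for whatever parse_metric_func_str() returns for the case's metric:

import struct

def to_text_vector(vec: list[float]) -> str:
    # Insert path: plain text literal, spliced into the VALUES clause.
    return "[" + ",".join(map(str, vec)) + "]"

def to_hex_vector(vec: list[float]) -> str:
    # Search path: little-endian float32, hex-encoded for an X'...' literal.
    return struct.pack(f"<{len(vec)}f", *vec).hex()

vec = [0.1, 0.2, 0.3]
insert_sql = f"INSERT INTO items VALUES (1, '{to_text_vector(vec)}')"
search_sql = (
    f"SELECT id FROM items ORDER BY "
    f"l2_distance(embedding, X'{to_hex_vector(vec)}') APPROXIMATE LIMIT 10"
)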
vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -5,8 +5,9 @@ from contextlib import contextmanager
 
 import pinecone
 
-from …
-…
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
 
@@ -15,12 +16,19 @@ PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
 
 
 class Pinecone(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
         dim: int,
         db_config: dict,
         db_case_config: DBCaseConfig,
         drop_old: bool = False,
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
@@ -33,6 +41,7 @@ class Pinecone(VectorDB):
         pc = pinecone.Pinecone(api_key=self.api_key)
         index = pc.Index(self.index_name)
 
+        self.with_scalar_labels = with_scalar_labels
         if drop_old:
             index_stats = index.describe_index_stats()
             index_dim = index_stats["dimension"]
@@ -43,15 +52,8 @@ class Pinecone(VectorDB):
             log.info(f"Pinecone index delete namespace: {namespace}")
             index.delete(delete_all=True, namespace=namespace)
 
-        self.…
-
-    @classmethod
-    def config_cls(cls) -> type[DBConfig]:
-        return PineconeConfig
-
-    @classmethod
-    def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
-        return EmptyDBCaseConfig
+        self._scalar_id_field = "meta"
+        self._scalar_label_field = "label"
 
     @contextmanager
     def init(self):
@@ -66,6 +68,7 @@ class Pinecone(VectorDB):
         self,
         embeddings: list[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs,
     ) -> tuple[int, Exception]:
         assert len(embeddings) == len(metadata)
@@ -75,33 +78,44 @@ class Pinecone(VectorDB):
             batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
             insert_datas = []
             for i in range(batch_start_offset, batch_end_offset):
+                metadata_dict = {self._scalar_id_field: metadata[i]}
+                if self.with_scalar_labels:
+                    metadata_dict[self._scalar_label_field] = labels_data[i]
                 insert_data = (
                     str(metadata[i]),
                     embeddings[i],
-                    …
+                    metadata_dict,
                 )
                 insert_datas.append(insert_data)
             self.index.upsert(insert_datas)
             insert_count += batch_end_offset - batch_start_offset
         except Exception as e:
-            return …
-        return …
+            return insert_count, e
+        return len(embeddings), None
 
     def search_embedding(
         self,
         query: list[float],
         k: int = 100,
-        filters: dict | None = None,
         timeout: int | None = None,
     ) -> list[int]:
-        pinecone_filters = …
-        …
-        )["matches"]
-        except Exception as e:
-            log.warning(f"Error querying index: {e}")
-            raise e from e
+        pinecone_filters = self.expr
+        res = self.index.query(
+            top_k=k,
+            vector=query,
+            filter=pinecone_filters,
+        )["matches"]
         return [int(one_res["id"]) for one_res in res]
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.expr = None
+        elif filters.type == FilterOp.NumGE:
+            self.expr = {self._scalar_id_field: {"$gte": filters.int_value}}
+        elif filters.type == FilterOp.StrEqual:
+            # both "in" and "==" are supported
+            # for example, self.expr = {self._scalar_label_field: {"$in": [filters.label_value]}}
+            self.expr = {self._scalar_label_field: {"$eq": filters.label_value}}
+        else:
+            msg = f"Not support Filter for Pinecone - {filters}"
+            raise ValueError(msg)
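The expressions built by prepare_filter above are ordinary Pinecone metadata-filter dicts. For reference, these are the two shapes it produces; the field names match the _scalar_id_field and _scalar_label_field set in __init__, and the concrete values are illustrative:

# FilterOp.NumGE: range filter on the integer "meta" field
num_ge_expr = {"meta": {"$gte": 5000}}

# FilterOp.StrEqual: equality on the "label" keyword field;
# {"label": {"$in": ["red"]}} would be equivalent, per the comment above
str_eq_expr = {"label": {"$eq": "red"}}

# Both are passed straight through in search_embedding:
#   self.index.query(top_k=k, vector=query, filter=self.expr)["matches"]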
vectordb_bench/backend/clients/qdrant_cloud/config.py
@@ -1,7 +1,12 @@
-from …
+from typing import TypeVar
+
+from pydantic import BaseModel, SecretStr, validator
 
 from ..api import DBCaseConfig, DBConfig, MetricType
 
+# define type "SearchParams"
+SearchParams = TypeVar("SearchParams")
+
 
 # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
 class QdrantConfig(DBConfig):
@@ -20,9 +25,43 @@
         "url": self.url.get_secret_value(),
     }
 
+    @validator("*")
+    def not_empty_field(cls, v: any, field: any):
+        if field.name in ["api_key"]:
+            return v
+        return super().not_empty_field(v, field)
+
 
 class QdrantIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
+    m: int = 16
+    payload_m: int = 16  # only for label_filter cases
+    create_payload_int_index: bool = False
+    create_payload_keyword_index: bool = False
+    is_tenant: bool = False
+    use_scalar_quant: bool = False
+    sq_quantile: float = 0.99
+    default_segment_number: int = 0
+
+    use_rescore: bool = False
+    oversampling: float = 1.0
+    indexed_only: bool = False
+    hnsw_ef: int | None = 100
+    exact: bool = False
+
+    with_payload: bool = False
+
+    def __eq__(self, obj: any):
+        return (
+            self.m == obj.m
+            and self.payload_m == obj.payload_m
+            and self.create_payload_int_index == obj.create_payload_int_index
+            and self.create_payload_keyword_index == obj.create_payload_keyword_index
+            and self.is_tenant == obj.is_tenant
+            and self.use_scalar_quant == obj.use_scalar_quant
+            and self.sq_quantile == obj.sq_quantile
+            and self.default_segment_number == obj.default_segment_number
+        )
 
     def parse_metric(self) -> str:
         if self.metric_type == MetricType.L2:
@@ -36,5 +75,22 @@ class QdrantIndexConfig(BaseModel, DBCaseConfig):
     def index_param(self) -> dict:
         return {"distance": self.parse_metric()}
 
-    def search_param(self) -> …
-…
+    def search_param(self) -> SearchParams:
+        # Import while in use
+        from qdrant_client.http.models import QuantizationSearchParams, SearchParams
+
+        quantization = (
+            QuantizationSearchParams(
+                ignore=False,
+                rescore=True,
+                oversampling=self.oversampling,
+            )
+            if self.use_rescore
+            else None
+        )
+        return SearchParams(
+            hnsw_ef=self.hnsw_ef,
+            exact=self.exact,
+            indexed_only=self.indexed_only,
+            quantization=quantization,
+        )
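The TypeVar at the top of the file is just a forward declaration so search_param() can carry a return annotation without importing qdrant_client at module load; the real class is imported lazily inside the method. As a sketch, a quantized, rescoring config resolves to the following (field values are illustrative; QuantizationSearchParams and SearchParams are the qdrant_client models used above):

from qdrant_client.http.models import QuantizationSearchParams, SearchParams

# Roughly what QdrantIndexConfig(use_scalar_quant=True, use_rescore=True,
# oversampling=2.0, hnsw_ef=128).search_param() returns:
params = SearchParams(
    hnsw_ef=128,
    exact=False,
    indexed_only=False,
    quantization=QuantizationSearchParams(ignore=False, rescore=True, oversampling=2.0),
)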
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
@@ -9,13 +9,24 @@ from qdrant_client.http.models import (
     Batch,
     CollectionStatus,
     FieldCondition,
-    …
+    HnswConfigDiff,
+    KeywordIndexParams,
+    OptimizersConfigDiff,
     PayloadSchemaType,
     Range,
+    ScalarQuantization,
+    ScalarQuantizationConfig,
+    ScalarType,
     VectorParams,
 )
+from qdrant_client.http.models import (
+    Filter as QdrantFilter,
+)
 
-from …
+from vectordb_bench.backend.clients.qdrant_cloud.config import QdrantIndexConfig
+from vectordb_bench.backend.filter import Filter, FilterOp
+
+from ..api import VectorDB
 
 log = logging.getLogger(__name__)
 
@@ -25,24 +36,33 @@ QDRANT_BATCH_SIZE = 500
 
 
 class QdrantCloud(VectorDB):
+    supported_filter_types: list[FilterOp] = [
+        FilterOp.NonFilter,
+        FilterOp.NumGE,
+        FilterOp.StrEqual,
+    ]
+
     def __init__(
         self,
        dim: int,
         db_config: dict,
-        db_case_config: …
+        db_case_config: QdrantIndexConfig,
         collection_name: str = "QdrantCloudCollection",
         drop_old: bool = False,
+        with_scalar_labels: bool = False,
         **kwargs,
     ):
         """Initialize wrapper around the QdrantCloud vector database."""
         self.db_config = db_config
-        self.…
+        self.db_case_config = db_case_config
         self.collection_name = collection_name
 
         self._primary_field = "pk"
+        self._scalar_label_field = "label"
         self._vector_field = "vector"
 
         tmp_client = QdrantClient(**self.db_config)
+        self.with_scalar_labels = with_scalar_labels
         if drop_old:
             log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
             tmp_client.delete_collection(self.collection_name)
@@ -50,7 +70,7 @@ class QdrantCloud(VectorDB):
         tmp_client = None
 
     @contextmanager
-    def init(self)…
+    def init(self):
         """
         Examples:
             >>> with self.init():
@@ -74,7 +94,7 @@ class QdrantCloud(VectorDB):
         if info.status == CollectionStatus.GREEN:
             msg = (
                 f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, "
-                f"Collection status: {info.…
+                f"Collection status: {info.status}, Segment counts: {info.segments_count}"
             )
             log.info(msg)
             return
@@ -86,19 +106,48 @@ class QdrantCloud(VectorDB):
         log.info(f"Create collection: {self.collection_name}")
 
         try:
+            # whether to use quant (SQ8)
+            quantization_config = None
+            if self.db_case_config.use_scalar_quant:
+                quantization_config = ScalarQuantization(
+                    scalar=ScalarQuantizationConfig(
+                        type=ScalarType.INT8,
+                        quantile=self.db_case_config.sq_quantile,
+                        always_ram=True,
+                    )
+                )
+
+            # create collection
             qdrant_client.create_collection(
                 collection_name=self.collection_name,
                 vectors_config=VectorParams(
                     size=dim,
-                    distance=self.…
+                    distance=self.db_case_config.parse_metric(),
+                ),
+                hnsw_config=HnswConfigDiff(m=self.db_case_config.m, payload_m=self.db_case_config.payload_m),
+                optimizers_config=OptimizersConfigDiff(
+                    default_segment_number=self.db_case_config.default_segment_number
                 ),
+                quantization_config=quantization_config,
             )
 
-…
+            # create payload_index for int-field
+            if self.db_case_config.create_payload_int_index:
+                qdrant_client.create_payload_index(
+                    collection_name=self.collection_name,
+                    field_name=self._primary_field,
+                    field_schema=PayloadSchemaType.INTEGER,
+                )
+
+            # create payload_index for str-field
+            if self.with_scalar_labels and self.db_case_config.create_payload_keyword_index:
+                qdrant_client.create_payload_index(
+                    collection_name=self.collection_name,
+                    field_name=self._scalar_label_field,
+                    field_schema=KeywordIndexParams(
+                        type=PayloadSchemaType.KEYWORD, is_tenant=self.db_case_config.is_tenant
+                    ),
+                )
 
         except Exception as e:
             if "already exists!" in str(e):
@@ -110,16 +159,22 @@ class QdrantCloud(VectorDB):
         self,
         embeddings: list[list[float]],
         metadata: list[int],
+        labels_data: list[str] | None = None,
         **kwargs,
     ) -> tuple[int, Exception]:
         """Insert embeddings into Milvus. should call self.init() first"""
         assert self.qdrant_client is not None
         try:
-            # TODO: counts
             for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
                 vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
                 ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
-                …
+                if self.with_scalar_labels:
+                    labels = labels_data[offset : offset + QDRANT_BATCH_SIZE]
+                    payloads = [
+                        {self._primary_field: pk, self._scalar_label_field: labels[i]} for i, pk in enumerate(ids)
+                    ]
+                else:
+                    payloads = [{self._primary_field: pk} for i, pk in enumerate(ids)]
                 _ = self.qdrant_client.upsert(
                     collection_name=self.collection_name,
                     wait=True,
@@ -135,34 +190,46 @@ class QdrantCloud(VectorDB):
         self,
         query: list[float],
         k: int = 100,
-        filters: dict | None = None,
         timeout: int | None = None,
+        **kwargs,
     ) -> list[int]:
         """Perform a search on a query embedding and return results with score.
         Should call self.init() first.
         """
         assert self.qdrant_client is not None
 
-…
+        res = self.qdrant_client.search(
+            collection_name=self.collection_name,
+            query_vector=query,
+            limit=k,
+            query_filter=self.query_filter,
+            search_params=self.db_case_config.search_param(),
+            with_payload=self.db_case_config.with_payload,
+        )
+
+        return [r.id for r in res]
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterOp.NonFilter:
+            self.query_filter = None
+        elif filters.type == FilterOp.NumGE:
+            self.query_filter = QdrantFilter(
                 must=[
                     FieldCondition(
                         key=self._primary_field,
-                        range=Range(
-                            gt=filters.get("id"),
-                        ),
+                        range=Range(gte=filters.int_value),
                     ),
-                ]
+                ]
             )
-…
+        elif filters.type == FilterOp.StrEqual:
+            self.query_filter = QdrantFilter(
+                must=[
+                    FieldCondition(
+                        key=self._scalar_label_field,
+                        match={"value": filters.label_value},
+                    ),
+                ]
+            )
+        else:
+            msg = f"Not support Filter for Qdrant - {filters}"
+            raise ValueError(msg)
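Putting the pieces together, the intended call sequence for the new QdrantCloud flow is: construct the client, bind a filter once with prepare_filter, then search inside init(). A sketch using only methods shown in the diff; the SimpleNamespace stand-in mimics a Filter from backend/filter.py (only the attributes prepare_filter reads are provided), and the URL and values are placeholders:

from types import SimpleNamespace

from vectordb_bench.backend.clients.qdrant_cloud.config import QdrantIndexConfig
from vectordb_bench.backend.clients.qdrant_cloud.qdrant_cloud import QdrantCloud
from vectordb_bench.backend.filter import FilterOp

db = QdrantCloud(
    dim=768,
    db_config={"url": "https://xxx.cloud.qdrant.io", "api_key": "..."},
    db_case_config=QdrantIndexConfig(create_payload_keyword_index=True, is_tenant=True),
    with_scalar_labels=True,
)

# Stand-in for a runner-supplied Filter; a real one carries .type,
# .int_value, and .label_value as matched in prepare_filter above.
flt = SimpleNamespace(type=FilterOp.StrEqual, label_value="label_1")
db.prepare_filter(flt)  # caches db.query_filter for the query loop

query_vector = [0.0] * 768  # placeholder query embedding
with db.init():
    ids = db.search_embedding(query_vector, k=100)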