vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +65 -1
- vectordb_bench/backend/clients/api.py +2 -1
- vectordb_bench/backend/clients/chroma/chroma.py +2 -2
- vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
- vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
- vectordb_bench/backend/clients/clickhouse/config.py +60 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
- vectordb_bench/backend/clients/mariadb/cli.py +122 -0
- vectordb_bench/backend/clients/mariadb/config.py +73 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
- vectordb_bench/backend/clients/milvus/cli.py +32 -0
- vectordb_bench/backend/clients/milvus/config.py +32 -0
- vectordb_bench/backend/clients/milvus/milvus.py +1 -1
- vectordb_bench/backend/clients/pgvector/cli.py +14 -3
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
- vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +46 -0
- vectordb_bench/backend/clients/tidb/tidb.py +233 -0
- vectordb_bench/backend/clients/vespa/cli.py +47 -0
- vectordb_bench/backend/clients/vespa/config.py +51 -0
- vectordb_bench/backend/clients/vespa/util.py +15 -0
- vectordb_bench/backend/clients/vespa/vespa.py +249 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/cli/cli.py +20 -17
- vectordb_bench/cli/vectordbbench.py +8 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
- vectordb_bench/frontend/config/styles.py +4 -0
- vectordb_bench/models.py +8 -6
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
|
|
1
|
+
import logging
|
2
|
+
from contextlib import contextmanager
|
3
|
+
|
4
|
+
import mariadb
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from ..api import VectorDB
|
8
|
+
from .config import MariaDBConfigDict, MariaDBIndexConfig
|
9
|
+
|
10
|
+
log = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class MariaDB(VectorDB):
    """VectorDBBench client for MariaDB's ``VECTOR`` column type and vector index.

    ``__init__`` opens a short-lived connection only to (optionally) drop and
    recreate the benchmark database/table, then closes it. All benchmark
    operations (``insert_embeddings``, ``optimize``, ``search_embedding``) must
    run inside ``with self.init():``, which owns the working connection.
    """

    def __init__(
        self,
        dim: int,
        db_config: MariaDBConfigDict,
        db_case_config: MariaDBIndexConfig,
        collection_name: str = "vec_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and, when ``drop_old`` is set, rebuild the schema.

        Args:
            dim: Vector dimensionality used for the ``VECTOR(dim)`` column.
            db_config: Keyword arguments forwarded to ``mariadb.connect``.
            db_case_config: Index/search configuration for this case.
            collection_name: Name of the table holding ``(id, v)`` rows.
            drop_old: If true, drop the ``vectordbbench`` database and recreate the table.
        """
        self.name = "MariaDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.db_name = "vectordbbench"
        self.table_name = collection_name
        self.dim = dim

        # construct basic units
        self.conn, self.cursor = self._create_connection(**self.db_config)

        try:
            if drop_old:
                self._drop_db()
                self._create_db_table(dim)
        finally:
            # The setup connection is never reused; close it even when the
            # drop/create step fails so no server connection is leaked.
            self._close_connection()

    @staticmethod
    def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
        """Open a connection via ``mariadb.connect(**kwargs)`` and return it with a fresh cursor."""
        conn = mariadb.connect(**kwargs)
        assert conn is not None, "Connection is not initialized"
        cursor = conn.cursor()
        assert cursor is not None, "Cursor is not initialized"
        return conn, cursor

    def _close_connection(self) -> None:
        """Close the current cursor and connection (cursor first) and reset both to ``None``."""
        try:
            if self.cursor is not None:
                self.cursor.close()
        finally:
            # Close the connection even if closing the cursor raised.
            try:
                if self.conn is not None:
                    self.conn.close()
            finally:
                self.cursor = None
                self.conn = None

    def _drop_db(self):
        """Drop the whole benchmark database (``self.db_name``) if it exists."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop db : {self.db_name}")

        # flush tables before dropping database to avoid some locking issue
        self.cursor.execute("FLUSH TABLES")
        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
        self.cursor.execute("COMMIT")
        self.cursor.execute("FLUSH TABLES")

    def _create_db_table(self, dim: int):
        """Create the benchmark database and a ``(id INT PRIMARY KEY, v VECTOR(dim))`` table.

        The storage engine comes from ``index_param()["storage_engine"]``.

        Raises:
            Exception: Whatever the driver raises; it is logged and re-raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            log.info(f"{self.name} client create database : {self.db_name}")
            self.cursor.execute(f"CREATE DATABASE {self.db_name}")

            log.info(f"{self.name} client create table : {self.table_name}")
            self.cursor.execute(f"USE {self.db_name}")

            self.cursor.execute(
                f"""
                CREATE TABLE {self.table_name} (
                    id INT PRIMARY KEY,
                    v VECTOR({dim}) NOT NULL
                ) ENGINE={index_param["storage_engine"]}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create table: {self.table_name} error: {e}")
            raise

    @contextmanager
    def init(self):
        """Create and destroy the working connection to the database.

        While the ``with`` block is active, the session/global variables from the
        case config are applied. The INSERT/SELECT statements used by
        ``insert_embeddings`` and ``search_embedding`` are also prepared.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)

        # Everything after connecting is inside try/finally so a failing SET
        # statement cannot leak the connection.
        try:
            index_param = self.case_config.index_param()
            search_param = self.case_config.search_param()

            # maximize allowed package size
            # NOTE(review): MariaDB documents the session value of max_allowed_packet
            # as read-only (initialised from the global at connect time), so this may
            # only affect connections opened later — confirm.
            self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")

            if index_param["index_type"] == "HNSW":
                if index_param["max_cache_size"] is not None:
                    self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
                if search_param["ef_search"] is not None:
                    self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
            self.cursor.execute("COMMIT")

            metric = search_param["metric_type"]
            self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"  # noqa: S608
            # Adjacent f-strings are concatenated with no separator, so the first
            # piece must end with a space before ORDER BY. All placeholders use the
            # connector's "format" paramstyle (%s); %d is not a supported placeholder.
            self.select_sql = (
                f"SELECT id FROM {self.db_name}.{self.table_name} "  # noqa: S608
                f"ORDER BY vec_distance_{metric}(v, %s) LIMIT %s"
            )
            self.select_sql_with_filter = (
                f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %s "  # noqa: S608
                f"ORDER BY vec_distance_{metric}(v, %s) LIMIT %s"
            )

            yield
        finally:
            self._close_connection()

    def ready_to_load(self) -> bool:
        """No preparation is needed before loading; returns ``None`` as before."""
        pass

    def optimize(self) -> None:
        """Build the vector index on column ``v`` after the data has been loaded.

        Uses ``DISTANCE=<metric_type>`` and, for HNSW with a configured ``M``,
        adds ``M=<M>`` to the index options.

        Raises:
            Exception: Whatever the driver raises; it is logged and re-raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            index_options = f"DISTANCE={index_param['metric_type']}"
            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
                index_options += f" M={index_param['M']}"

            self.cursor.execute(
                f"""
                ALTER TABLE {self.db_name}.{self.table_name}
                ADD VECTOR KEY v(v) {index_options}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create index: {self.table_name} error: {e}")
            raise

    @staticmethod
    def vector_to_hex(v):  # noqa: ANN001
        """Serialize ``v`` as raw little-endian-per-platform float32 bytes (despite the name, not hex text)."""
        return np.array(v, "float32").tobytes()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception]:
        """Insert embeddings into the database.

        Should call self.init() first.

        Returns:
            ``(number_of_rows_inserted, None)`` on success, ``(0, exception)`` on failure.
            A length mismatch between ``metadata`` and ``embeddings`` is reported
            as a failure rather than silently truncated.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            batch_data = [
                (int(row_id), self.vector_to_hex(vector))
                for row_id, vector in zip(metadata, embeddings, strict=True)
            ]

            self.cursor.executemany(self.insert_sql, batch_data)
            self.cursor.execute("COMMIT")
            self.cursor.execute("FLUSH TABLES")

            return len(batch_data), None
        except Exception as e:
            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs,
    ) -> list[int]:
        """Return the ids of the ``k`` nearest rows to ``query``.

        When ``filters`` is truthy, only rows with ``id >= filters["id"]`` are considered.
        ``timeout`` is accepted for interface compatibility and not used.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        query_bytes = self.vector_to_hex(query)
        if filters:
            self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), query_bytes, k))
        else:
            self.cursor.execute(self.select_sql, (query_bytes, k))

        return [row_id for (row_id,) in self.cursor.fetchall()]
|
@@ -195,6 +195,38 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
|
|
195
195
|
)
|
196
196
|
|
197
197
|
|
198
|
+
class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
    """CLI parameters for the ``MilvusGPUBruteForce`` command.

    Adds the two options below. The other keys the command reads (``db_label``,
    ``uri``, ``user_name``, ``password``) presumably come from the two base
    TypedDicts — verify there.
    """

    # Passed as-is to GPUBruteForceConfig(metric_type=...) by the command.
    metric_type: Annotated[
        str,
        click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
    ]
    # Passed to GPUBruteForceConfig(limit=...); required on the CLI even though the
    # config model itself has a default.
    limit: Annotated[
        int,
        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
    ]
|
207
|
+
|
208
|
+
|
209
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
    """Run the benchmark against Milvus using the GPU_BRUTE_FORCE index.

    Builds the Milvus connection config and the brute-force case config from the
    parsed CLI options. It then hands them, together with all raw options, to ``run``.
    """
    from .config import GPUBruteForceConfig, MilvusConfig

    connection_config = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
        user=parameters["user_name"],
        password=SecretStr(parameters["password"]),
    )
    brute_force_config = GPUBruteForceConfig(
        metric_type=parameters["metric_type"],
        limit=parameters["limit"],  # top-k for search
    )

    run(
        db=DBTYPE,
        db_config=connection_config,
        db_case_config=brute_force_config,
        **parameters,
    )
|
228
|
+
|
229
|
+
|
198
230
|
class MilvusGPUIVFPQTypedDict(
|
199
231
|
CommonTypedDict,
|
200
232
|
MilvusTypedDict,
|
@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
|
|
40
40
|
IndexType.GPU_CAGRA,
|
41
41
|
IndexType.GPU_IVF_FLAT,
|
42
42
|
IndexType.GPU_IVF_PQ,
|
43
|
+
IndexType.GPU_BRUTE_FORCE,
|
43
44
|
]
|
44
45
|
|
45
46
|
def parse_metric(self) -> str:
|
@@ -184,6 +185,36 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
|
|
184
185
|
}
|
185
186
|
|
186
187
|
|
188
|
+
class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
    """Case configuration for Milvus' GPU_BRUTE_FORCE (exhaustive GPU search) index."""

    limit: int = 10  # Default top-k for search
    metric_type: str  # Metric type (e.g., 'L2', 'IP', etc.)
    index: IndexType = IndexType.GPU_BRUTE_FORCE  # Index type set to GPU_BRUTE_FORCE

    def index_param(self) -> dict:
        """
        Build the index-creation parameters for GPU_BRUTE_FORCE.

        The index takes no build-time tuning knobs, so ``params`` is always empty.
        """
        build_params: dict = {}
        return dict(
            metric_type=self.parse_metric(),
            index_type=self.index.value,
            params=build_params,
        )

    def search_param(self) -> dict:
        """
        Build the search-time parameters for GPU_BRUTE_FORCE.

        Carries the metric plus a fixed ``nprobe`` of 1 and the configured top-k ``limit``.
        """
        metric = self.parse_metric()
        tuning = {"nprobe": 1, "limit": self.limit}
        return {"metric_type": metric, "params": tuning}
|
216
|
+
|
217
|
+
|
187
218
|
class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
|
188
219
|
nlist: int = 1024
|
189
220
|
m: int = 0
|
@@ -261,4 +292,5 @@ _milvus_case_config = {
|
|
261
292
|
IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
|
262
293
|
IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
|
263
294
|
IndexType.GPU_CAGRA: GPUCAGRAConfig,
|
295
|
+
IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
|
264
296
|
}
|
@@ -155,7 +155,7 @@ class Milvus(VectorDB):
|
|
155
155
|
embeddings: Iterable[list[float]],
|
156
156
|
metadata: list[int],
|
157
157
|
**kwargs,
|
158
|
-
) ->
|
158
|
+
) -> tuple[int, Exception]:
|
159
159
|
"""Insert embeddings into Milvus. should call self.init() first"""
|
160
160
|
# use the first insert_embeddings to init collection
|
161
161
|
assert self.col is not None
|
@@ -18,8 +18,7 @@ from ....cli.cli import (
|
|
18
18
|
)
|
19
19
|
|
20
20
|
|
21
|
-
#
|
22
|
-
def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
|
21
|
+
def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): # noqa: ARG001
|
23
22
|
if ctx.params.get("reranking") and value is None:
|
24
23
|
# ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
|
25
24
|
# 100 is default value for quantized_fetch_limit for IVFFlat.
|
@@ -82,7 +81,17 @@ class PgVectorTypedDict(CommonTypedDict):
|
|
82
81
|
click.option(
|
83
82
|
"--quantization-type",
|
84
83
|
type=click.Choice(["none", "bit", "halfvec"]),
|
85
|
-
help="quantization type for vectors",
|
84
|
+
help="quantization type for vectors (in index)",
|
85
|
+
required=False,
|
86
|
+
),
|
87
|
+
]
|
88
|
+
table_quantization_type: Annotated[
|
89
|
+
str | None,
|
90
|
+
click.option(
|
91
|
+
"--table-quantization-type",
|
92
|
+
type=click.Choice(["none", "bit", "halfvec"]),
|
93
|
+
help="quantization type for vectors (in table). "
|
94
|
+
"If equal to bit, the parameter quantization_type will be set to bit too.",
|
86
95
|
required=False,
|
87
96
|
),
|
88
97
|
]
|
@@ -146,6 +155,7 @@ def PgVectorIVFFlat(
|
|
146
155
|
lists=parameters["lists"],
|
147
156
|
probes=parameters["probes"],
|
148
157
|
quantization_type=parameters["quantization_type"],
|
158
|
+
table_quantization_type=parameters["table_quantization_type"],
|
149
159
|
reranking=parameters["reranking"],
|
150
160
|
reranking_metric=parameters["reranking_metric"],
|
151
161
|
quantized_fetch_limit=parameters["quantized_fetch_limit"],
|
@@ -182,6 +192,7 @@ def PgVectorHNSW(
|
|
182
192
|
maintenance_work_mem=parameters["maintenance_work_mem"],
|
183
193
|
max_parallel_workers=parameters["max_parallel_workers"],
|
184
194
|
quantization_type=parameters["quantization_type"],
|
195
|
+
table_quantization_type=parameters["table_quantization_type"],
|
185
196
|
reranking=parameters["reranking"],
|
186
197
|
reranking_metric=parameters["reranking_metric"],
|
187
198
|
quantized_fetch_limit=parameters["quantized_fetch_limit"],
|
@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
|
|
80
80
|
|
81
81
|
if d.get(self.quantization_type) is None:
|
82
82
|
return d.get("_fallback").get(self.metric_type)
|
83
|
-
|
83
|
+
metric = d.get(self.quantization_type).get(self.metric_type)
|
84
|
+
# If using binary quantization for the index, use a bit metric
|
85
|
+
# no matter what metric was selected for vector or halfvec data
|
86
|
+
if self.quantization_type == "bit" and metric is None:
|
87
|
+
return "bit_hamming_ops"
|
88
|
+
return metric
|
84
89
|
|
85
90
|
def parse_metric_fun_op(self) -> LiteralString:
|
86
91
|
if self.quantization_type == "bit":
|
@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
168
173
|
maintenance_work_mem: str | None = None
|
169
174
|
max_parallel_workers: int | None = None
|
170
175
|
quantization_type: str | None = None
|
176
|
+
table_quantization_type: str | None
|
171
177
|
reranking: bool | None = None
|
172
178
|
quantized_fetch_limit: int | None = None
|
173
179
|
reranking_metric: str | None = None
|
174
180
|
|
175
181
|
def index_param(self) -> PgVectorIndexParam:
|
176
182
|
index_parameters = {"lists": self.lists}
|
177
|
-
if self.quantization_type == "none":
|
178
|
-
self.quantization_type =
|
183
|
+
if self.quantization_type == "none" or self.quantization_type is None:
|
184
|
+
self.quantization_type = "vector"
|
185
|
+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
|
186
|
+
self.table_quantization_type = "vector"
|
187
|
+
if self.table_quantization_type == "bit":
|
188
|
+
self.quantization_type = "bit"
|
179
189
|
return {
|
180
190
|
"metric": self.parse_metric(),
|
181
191
|
"index_type": self.index.value,
|
@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
183
193
|
"maintenance_work_mem": self.maintenance_work_mem,
|
184
194
|
"max_parallel_workers": self.max_parallel_workers,
|
185
195
|
"quantization_type": self.quantization_type,
|
196
|
+
"table_quantization_type": self.table_quantization_type,
|
186
197
|
}
|
187
198
|
|
188
199
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
212
223
|
maintenance_work_mem: str | None = None
|
213
224
|
max_parallel_workers: int | None = None
|
214
225
|
quantization_type: str | None = None
|
226
|
+
table_quantization_type: str | None
|
215
227
|
reranking: bool | None = None
|
216
228
|
quantized_fetch_limit: int | None = None
|
217
229
|
reranking_metric: str | None = None
|
218
230
|
|
219
231
|
def index_param(self) -> PgVectorIndexParam:
|
220
232
|
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
|
221
|
-
if self.quantization_type == "none":
|
222
|
-
self.quantization_type =
|
233
|
+
if self.quantization_type == "none" or self.quantization_type is None:
|
234
|
+
self.quantization_type = "vector"
|
235
|
+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
|
236
|
+
self.table_quantization_type = "vector"
|
237
|
+
if self.table_quantization_type == "bit":
|
238
|
+
self.quantization_type = "bit"
|
223
239
|
return {
|
224
240
|
"metric": self.parse_metric(),
|
225
241
|
"index_type": self.index.value,
|
@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
227
243
|
"maintenance_work_mem": self.maintenance_work_mem,
|
228
244
|
"max_parallel_workers": self.max_parallel_workers,
|
229
245
|
"quantization_type": self.quantization_type,
|
246
|
+
"table_quantization_type": self.table_quantization_type,
|
230
247
|
}
|
231
248
|
|
232
249
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
|
|
94
94
|
reranking = self.case_config.search_param()["reranking"]
|
95
95
|
column_name = (
|
96
96
|
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
|
97
|
-
if index_param["quantization_type"] == "bit"
|
97
|
+
if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
|
98
98
|
else sql.SQL("embedding")
|
99
99
|
)
|
100
100
|
search_vector = (
|
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
|
|
104
104
|
)
|
105
105
|
|
106
106
|
# The following sections assume that the quantization_type value matches the quantization function name
|
107
|
-
if index_param["quantization_type"]
|
107
|
+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
|
108
|
+
# Reranking makes sense only if table quantization is not "bit"
|
108
109
|
if index_param["quantization_type"] == "bit" and reranking:
|
109
110
|
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
|
110
111
|
search_query = sql.Composed(
|
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
|
|
113
114
|
"""
|
114
115
|
SELECT i.id
|
115
116
|
FROM (
|
116
|
-
SELECT id, embedding {reranking_metric_fun_op} %s::
|
117
|
+
SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
|
117
118
|
FROM public.{table_name} {where_clause}
|
118
119
|
ORDER BY {column_name}::{quantization_type}({dim})
|
119
120
|
""",
|
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
|
|
123
124
|
reranking_metric_fun_op=sql.SQL(
|
124
125
|
self.case_config.search_param()["reranking_metric_fun_op"],
|
125
126
|
),
|
127
|
+
search_vector=search_vector,
|
128
|
+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
|
126
129
|
quantization_type=sql.SQL(index_param["quantization_type"]),
|
127
130
|
dim=sql.Literal(self.dim),
|
128
131
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
|
|
130
133
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
131
134
|
sql.SQL(
|
132
135
|
"""
|
133
|
-
{search_vector}
|
136
|
+
{search_vector}::{quantization_type}({dim})
|
134
137
|
LIMIT {quantized_fetch_limit}
|
135
138
|
) i
|
136
139
|
ORDER BY i.distance
|
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
|
|
138
141
|
""",
|
139
142
|
).format(
|
140
143
|
search_vector=search_vector,
|
144
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
145
|
+
dim=sql.Literal(self.dim),
|
141
146
|
quantized_fetch_limit=sql.Literal(
|
142
147
|
self.case_config.search_param()["quantized_fetch_limit"],
|
143
148
|
),
|
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
|
|
160
165
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
161
166
|
),
|
162
167
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
163
|
-
sql.SQL(" {search_vector} LIMIT %s::int").format(
|
168
|
+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
|
164
169
|
search_vector=search_vector,
|
170
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
171
|
+
dim=sql.Literal(self.dim),
|
165
172
|
),
|
166
|
-
]
|
173
|
+
]
|
167
174
|
)
|
168
175
|
else:
|
169
176
|
search_query = sql.Composed(
|
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
|
|
175
182
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
176
183
|
),
|
177
184
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
178
|
-
sql.SQL("
|
179
|
-
|
185
|
+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
|
186
|
+
search_vector=search_vector,
|
187
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
188
|
+
dim=sql.Literal(self.dim),
|
189
|
+
),
|
190
|
+
]
|
180
191
|
)
|
181
192
|
|
182
193
|
return search_query
|
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
|
|
323
334
|
)
|
324
335
|
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
|
325
336
|
|
326
|
-
if index_param["quantization_type"]
|
337
|
+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
|
327
338
|
index_create_sql = sql.SQL(
|
328
339
|
"""
|
329
340
|
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
|
|
365
376
|
assert self.conn is not None, "Connection is not initialized"
|
366
377
|
assert self.cursor is not None, "Cursor is not initialized"
|
367
378
|
|
379
|
+
index_param = self.case_config.index_param()
|
380
|
+
|
368
381
|
try:
|
369
382
|
log.info(f"{self.name} client create table : {self.table_name}")
|
370
383
|
|
371
384
|
# create table
|
372
385
|
self.cursor.execute(
|
373
386
|
sql.SQL(
|
374
|
-
"
|
375
|
-
|
387
|
+
"""
|
388
|
+
CREATE TABLE IF NOT EXISTS public.{table_name}
|
389
|
+
(id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
|
390
|
+
"""
|
391
|
+
).format(
|
392
|
+
table_name=sql.Identifier(self.table_name),
|
393
|
+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
|
394
|
+
dim=dim,
|
395
|
+
)
|
376
396
|
)
|
377
397
|
self.cursor.execute(
|
378
398
|
sql.SQL(
|
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
|
|
393
413
|
assert self.conn is not None, "Connection is not initialized"
|
394
414
|
assert self.cursor is not None, "Cursor is not initialized"
|
395
415
|
|
416
|
+
index_param = self.case_config.index_param()
|
417
|
+
|
396
418
|
try:
|
397
419
|
metadata_arr = np.array(metadata)
|
398
420
|
embeddings_arr = np.array(embeddings)
|
399
421
|
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
copy
|
406
|
-
|
407
|
-
|
422
|
+
if index_param["table_quantization_type"] == "bit":
|
423
|
+
with self.cursor.copy(
|
424
|
+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
|
425
|
+
table_name=sql.Identifier(self.table_name)
|
426
|
+
)
|
427
|
+
) as copy:
|
428
|
+
# Same logic as pgvector binary_quantize
|
429
|
+
for i, row in enumerate(metadata_arr):
|
430
|
+
embeddings_bit = ""
|
431
|
+
for embedding in embeddings_arr[i]:
|
432
|
+
if embedding > 0:
|
433
|
+
embeddings_bit += "1"
|
434
|
+
else:
|
435
|
+
embeddings_bit += "0"
|
436
|
+
copy.write_row((str(row), embeddings_bit))
|
437
|
+
else:
|
438
|
+
with self.cursor.copy(
|
439
|
+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
440
|
+
table_name=sql.Identifier(self.table_name)
|
441
|
+
)
|
442
|
+
) as copy:
|
443
|
+
if index_param["table_quantization_type"] == "halfvec":
|
444
|
+
copy.set_types(["bigint", "halfvec"])
|
445
|
+
for i, row in enumerate(metadata_arr):
|
446
|
+
copy.write_row((row, np.float16(embeddings_arr[i])))
|
447
|
+
else:
|
448
|
+
copy.set_types(["bigint", "vector"])
|
449
|
+
for i, row in enumerate(metadata_arr):
|
450
|
+
copy.write_row((row, embeddings_arr[i]))
|
408
451
|
self.conn.commit()
|
409
452
|
|
410
453
|
if kwargs.get("last_batch"):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from pydantic import BaseModel, SecretStr
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
2
|
|
3
3
|
from ..api import DBCaseConfig, DBConfig, MetricType
|
4
4
|
|
@@ -20,14 +20,6 @@ class QdrantConfig(DBConfig):
|
|
20
20
|
"url": self.url.get_secret_value(),
|
21
21
|
}
|
22
22
|
|
23
|
-
@validator("*")
|
24
|
-
def not_empty_field(cls, v: any, field: any):
|
25
|
-
if field.name in ["api_key", "db_label"]:
|
26
|
-
return v
|
27
|
-
if isinstance(v, str | SecretStr) and len(v) == 0:
|
28
|
-
raise ValueError("Empty string!")
|
29
|
-
return v
|
30
|
-
|
31
23
|
|
32
24
|
class QdrantIndexConfig(BaseModel, DBCaseConfig):
|
33
25
|
metric_type: MetricType | None = None
|
@@ -111,7 +111,7 @@ class QdrantCloud(VectorDB):
|
|
111
111
|
embeddings: list[list[float]],
|
112
112
|
metadata: list[int],
|
113
113
|
**kwargs,
|
114
|
-
) ->
|
114
|
+
) -> tuple[int, Exception]:
|
115
115
|
"""Insert embeddings into Milvus. should call self.init() first"""
|
116
116
|
assert self.qdrant_client is not None
|
117
117
|
try:
|