vectordb-bench 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +48 -0
- vectordb_bench/backend/clients/api.py +1 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +53 -4
- vectordb_bench/backend/clients/aws_opensearch/cli.py +85 -1
- vectordb_bench/backend/clients/aws_opensearch/config.py +10 -0
- vectordb_bench/backend/clients/mariadb/cli.py +107 -0
- vectordb_bench/backend/clients/mariadb/config.py +71 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
- vectordb_bench/backend/clients/milvus/cli.py +50 -0
- vectordb_bench/backend/clients/milvus/config.py +33 -0
- vectordb_bench/backend/clients/mongodb/config.py +53 -0
- vectordb_bench/backend/clients/mongodb/mongodb.py +200 -0
- vectordb_bench/backend/clients/pgvector/cli.py +13 -1
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +49 -0
- vectordb_bench/backend/clients/tidb/tidb.py +234 -0
- vectordb_bench/cli/vectordbbench.py +4 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +12 -1
- vectordb_bench/frontend/components/run_test/submitTask.py +20 -3
- vectordb_bench/frontend/config/dbCaseConfigs.py +128 -0
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/log_util.py +15 -2
- vectordb_bench/models.py +7 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +67 -3
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +31 -23
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
@@ -194,6 +194,56 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
|
|
194
194
|
**parameters,
|
195
195
|
)
|
196
196
|
|
197
|
+
class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
    # CLI options specific to the GPU_BRUTE_FORCE command, on top of the
    # common and Milvus connection options inherited from the base TypedDicts.
    metric_type: Annotated[
        str,
        click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
    ]
    limit: Annotated[
        int,
        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
    ]


# NOTE: the options TypedDict must be defined *before* this command, because the
# decorator argument below is evaluated at import time. The command is also
# registered exactly once; a second identical definition would re-register the
# same command name and shadow this function.
# (Deliberately a comment rather than a docstring: click would surface a
# docstring as the command's --help text.)
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
    # Imported lazily so the config module is only loaded when the command runs.
    from .config import GPUBruteForceConfig, MilvusConfig

    run(
        db=DBTYPE,
        db_config=MilvusConfig(
            db_label=parameters["db_label"],
            uri=SecretStr(parameters["uri"]),
            user=parameters["user_name"],
            password=SecretStr(parameters["password"]),
        ),
        db_case_config=GPUBruteForceConfig(
            metric_type=parameters["metric_type"],
            limit=parameters["limit"],  # top-k for search
        ),
        **parameters,
    )
|
246
|
+
|
197
247
|
|
198
248
|
class MilvusGPUIVFPQTypedDict(
|
199
249
|
CommonTypedDict,
|
@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
|
|
40
40
|
IndexType.GPU_CAGRA,
|
41
41
|
IndexType.GPU_IVF_FLAT,
|
42
42
|
IndexType.GPU_IVF_PQ,
|
43
|
+
IndexType.GPU_BRUTE_FORCE,
|
43
44
|
]
|
44
45
|
|
45
46
|
def parse_metric(self) -> str:
|
@@ -184,6 +185,37 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
|
|
184
185
|
}
|
185
186
|
|
186
187
|
|
188
|
+
class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
    """Case configuration for a Milvus ``GPU_BRUTE_FORCE`` index."""

    # Top-k value that is forwarded inside the search params.
    limit: int = 10
    # Metric name (e.g. "L2", "IP"); required, there is no default.
    metric_type: str
    # The index type is pinned to GPU_BRUTE_FORCE.
    index: IndexType = IndexType.GPU_BRUTE_FORCE

    def index_param(self) -> dict:
        """Return the index-creation parameters.

        The ``params`` entry is empty: no build-time parameters are passed.
        """
        metric = self.parse_metric()
        index_type = self.index.value
        return {"metric_type": metric, "index_type": index_type, "params": {}}

    def search_param(self) -> dict:
        """Return the search parameters.

        Contains the parsed metric plus a ``params`` dict carrying a fixed
        ``nprobe`` of 1 and the configured ``limit``.
        """
        metric = self.parse_metric()
        inner = {"nprobe": 1, "limit": self.limit}
        return {"metric_type": metric, "params": inner}
|
216
|
+
|
217
|
+
|
218
|
+
|
187
219
|
class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
|
188
220
|
nlist: int = 1024
|
189
221
|
m: int = 0
|
@@ -261,4 +293,5 @@ _milvus_case_config = {
|
|
261
293
|
IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
|
262
294
|
IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
|
263
295
|
IndexType.GPU_CAGRA: GPUCAGRAConfig,
|
296
|
+
IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
|
264
297
|
}
|
@@ -0,0 +1,53 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
|
3
|
+
from pydantic import BaseModel, SecretStr
|
4
|
+
|
5
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
6
|
+
|
7
|
+
|
8
|
+
class QuantizationType(Enum):
    """Quantization choices for the vector field of the search index."""

    # The member value is the string written into the index definition.
    NONE = "none"
    BINARY = "binary"
    SCALAR = "scalar"
|
12
|
+
|
13
|
+
|
14
|
+
class MongoDBConfig(DBConfig, BaseModel):
    """Connection settings for the MongoDB client."""

    # Placeholder URI; <user>, <password> and <cluster_name> must be replaced.
    connection_string: SecretStr = "mongodb+srv://<user>:<password>@<cluster_name>.heatl.mongodb.net"
    # Database that holds the benchmark collection.
    database: str = "vdb_bench"

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with the connection string unwrapped."""
        uri = self.connection_string.get_secret_value()
        return dict(connection_string=uri, database=self.database)
|
23
|
+
|
24
|
+
|
25
|
+
class MongoDBIndexConfig(BaseModel, DBCaseConfig):
|
26
|
+
index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search
|
27
|
+
metric_type: MetricType = MetricType.COSINE
|
28
|
+
num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search
|
29
|
+
quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable
|
30
|
+
|
31
|
+
def parse_metric(self) -> str:
|
32
|
+
if self.metric_type == MetricType.L2:
|
33
|
+
return "euclidean"
|
34
|
+
if self.metric_type == MetricType.IP:
|
35
|
+
return "dotProduct"
|
36
|
+
return "cosine" # Default to cosine similarity
|
37
|
+
|
38
|
+
def index_param(self) -> dict:
|
39
|
+
return {
|
40
|
+
"type": "vectorSearch",
|
41
|
+
"fields": [
|
42
|
+
{
|
43
|
+
"type": "vector",
|
44
|
+
"similarity": self.parse_metric(),
|
45
|
+
"numDimensions": None, # Will be set in MongoDB class
|
46
|
+
"path": "vector", # Vector field name
|
47
|
+
"quantization": self.quantization.value,
|
48
|
+
}
|
49
|
+
],
|
50
|
+
}
|
51
|
+
|
52
|
+
def search_param(self) -> dict:
|
53
|
+
return {"num_candidates_ratio": self.num_candidates_ratio}
|
@@ -0,0 +1,200 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from contextlib import contextmanager
|
4
|
+
|
5
|
+
from pymongo import MongoClient
|
6
|
+
from pymongo.operations import SearchIndexModel
|
7
|
+
|
8
|
+
from ..api import VectorDB
|
9
|
+
from .config import MongoDBIndexConfig
|
10
|
+
|
11
|
+
log = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
class MongoDBError(Exception):
    """Raised by the MongoDB client, e.g. when a search index fails to build."""
|
16
|
+
|
17
|
+
|
18
|
+
class MongoDB(VectorDB):
    """VectorDB client that stores embeddings in a MongoDB collection.

    Documents are inserted with ``insert_many`` and queried through a
    ``$vectorSearch`` aggregation stage against a search index named
    ``vector_index``. A live connection only exists inside the ``init()``
    context manager.
    """

    # Name of the search index that this client creates, polls and queries.
    _INDEX_NAME = "vector_index"
    # Seconds to pause between polls while waiting for a dropped index to disappear.
    _DROP_POLL_INTERVAL = 1

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: MongoDBIndexConfig,
        collection_name: str = "vdb_bench_collection",
        id_field: str = "id",
        vector_field: str = "vector",
        drop_old: bool = False,
        **kwargs,
    ):
        """Store the configuration and optionally drop a pre-existing collection.

        Args:
            dim: Vector dimension; written into the index definition.
            db_config: Dict with ``connection_string`` and ``database`` keys.
            db_case_config: Supplies the index definition and search parameters.
            collection_name: Collection that holds the benchmark documents.
            id_field: Document field that stores the integer id.
            vector_field: Document field that stores the embedding.
            drop_old: When true, drop ``collection_name`` if it already exists.
            **kwargs: Ignored.
        """
        self.dim = dim
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name
        self.id_field = id_field
        self.vector_field = vector_field
        self.drop_old = drop_old

        # Update index dimensions
        index_params = self.case_config.index_param()
        log.info(f"index params: {index_params}")
        index_params["fields"][0]["numDimensions"] = dim
        self.index_params = index_params

        # The long-lived handles are only populated inside init().
        self.client = None
        self.db = None
        self.collection = None

        # Use a short-lived client for the optional cleanup and always close
        # it, so the constructor does not leak a connection.
        client = MongoClient(self.db_config["connection_string"])
        try:
            if self.drop_old:
                db = client[self.db_config["database"]]
                if self.collection_name in db.list_collection_names():
                    log.info(f"MongoDB client dropping old collection: {self.collection_name}")
                    db.drop_collection(self.collection_name)
        finally:
            client.close()

    @contextmanager
    def init(self):
        """Open the client for the duration of the ``with`` block and close it afterwards."""
        try:
            uri = self.db_config["connection_string"]
            self.client = MongoClient(uri)
            self.db = self.client[self.db_config["database"]]
            self.collection = self.db[self.collection_name]

            yield
        finally:
            if self.client is not None:
                self.client.close()
            self.client = None
            self.db = None
            self.collection = None

    def _create_index(self) -> None:
        """Drop any existing search index of the same name, then (re)create it.

        Also creates a regular index on the id field. Errors while dropping are
        logged and ignored (best effort); errors while creating are logged and
        re-raised.
        """
        index_name = self._INDEX_NAME
        index_params = self.index_params
        log.info(f"index params {index_params}")
        # drop index if already exists
        # NOTE(review): list_indexes() returns a cursor object, which is
        # presumably always truthy - confirm whether this guard is needed.
        if self.collection.list_indexes():
            all_indexes = list(self.collection.list_search_indexes())
            if any(idx.get("name") == index_name for idx in all_indexes):
                log.info(f"Drop index: {index_name}")
                try:
                    self.collection.drop_search_index(index_name)
                    while True:
                        indices = list(self.collection.list_search_indexes())
                        indices = [idx for idx in indices if idx["name"] == index_name]
                        log.debug(f"index status {indices}")
                        if len(indices) == 0:
                            break
                        log.info(f"index deleting {indices}")
                        # Pause between polls instead of hammering the server.
                        time.sleep(self._DROP_POLL_INTERVAL)
                except Exception:
                    log.exception(f"Error dropping index {index_name}")
        try:
            # Create vector search index
            search_index = SearchIndexModel(definition=index_params, name=index_name, type="vectorSearch")

            self.collection.create_search_index(search_index)
            log.info(f"Created vector search index: {index_name}")
            self._wait_for_index_ready(index_name)

            # Create regular index on id field for faster lookups
            self.collection.create_index(self.id_field)
            log.info(f"Created index on {self.id_field} field")

        except Exception:
            log.exception(f"Error creating index {index_name}")
            raise

    def _wait_for_index_ready(self, index_name: str, check_interval: int = 5) -> None:
        """Block until the named search index reports ``queryable``.

        Args:
            index_name: Name of the search index to wait for.
            check_interval: Seconds to sleep between polls.

        Raises:
            MongoDBError: If the index reports status ``FAILED``.
        """
        while True:
            indices = list(self.collection.list_search_indexes())
            log.debug(f"index status {indices}")
            matching = [idx for idx in indices if idx.get("name") == index_name]
            if any(idx.get("queryable") for idx in matching):
                break
            if any(idx.get("status") == "FAILED" for idx in matching):
                error_msg = f"Index {index_name} failed to build"
                raise MongoDBError(error_msg)

            time.sleep(check_interval)
        log.info(f"Index {index_name} is ready")

    def need_normalize_cosine(self) -> bool:
        """Report that vectors do not need to be normalized by the caller."""
        return False

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert one batch of embeddings.

        Args:
            embeddings: Vectors to insert.
            metadata: Ids, paired positionally with ``embeddings``.
            **kwargs: Ignored.

        Returns:
            ``(number_of_documents, None)`` on success, or ``(0, error)`` if
            ``insert_many`` raised.
        """
        # Prepare documents in bulk
        documents = [
            {
                self.id_field: id_,
                self.vector_field: embedding,
            }
            for id_, embedding in zip(metadata, embeddings, strict=False)
        ]

        # Use ordered=False for better insert performance
        try:
            self.collection.insert_many(documents, ordered=False)
        except Exception as e:
            return 0, e
        return len(documents), None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        **kwargs,
    ) -> list[int]:
        """Return the ids of the ``k`` nearest documents to ``query``.

        Args:
            query: Query vector.
            k: Number of results requested (``limit`` of the search stage).
            filters: Optional dict; when given, its ``"id"`` value is used as
                a lower bound on the id field.
            **kwargs: Ignored.
        """
        search_params = self.case_config.search_param()

        vector_search = {
            "queryVector": query,
            "index": self._INDEX_NAME,
            "path": self.vector_field,
            "limit": k,
        }

        # "exact" is optional: the case config may not provide it, in which
        # case an approximate search with numCandidates is used.
        if search_params.get("exact"):
            vector_search["exact"] = True
        else:
            # numCandidates scales with k and is capped at 10000.
            num_candidates = min(10000, k * search_params["num_candidates_ratio"])
            vector_search["numCandidates"] = num_candidates

        # Add filter if specified
        if filters:
            log.info(f"Applying filter: {filters}")
            # The filter is an MQL expression, so the operator needs the "$" prefix.
            # NOTE(review): the index definition visible here declares no filter
            # field for the id - confirm the server accepts this filter.
            vector_search["filter"] = {
                self.id_field: {"$gte": filters["id"]},
            }
        pipeline = [
            {"$vectorSearch": vector_search},
            {
                "$project": {
                    "_id": 0,
                    self.id_field: 1,
                    "score": {"$meta": "vectorSearchScore"},  # Include similarity score
                }
            },
        ]

        results = list(self.collection.aggregate(pipeline))
        return [doc[self.id_field] for doc in results]

    def optimize(self, data_size: int | None = None) -> None:
        """Build the search index; returns once it is queryable.

        ``_create_index`` already waits for the index to become ready, so no
        additional wait is needed here.
        """
        log.info("optimize for search")
        self._create_index()

    def ready_to_load(self) -> None:
        """MongoDB is always ready to load"""
|
@@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict):
|
|
82
82
|
click.option(
|
83
83
|
"--quantization-type",
|
84
84
|
type=click.Choice(["none", "bit", "halfvec"]),
|
85
|
-
help="quantization type for vectors",
|
85
|
+
help="quantization type for vectors (in index)",
|
86
|
+
required=False,
|
87
|
+
),
|
88
|
+
]
|
89
|
+
table_quantization_type: Annotated[
|
90
|
+
str | None,
|
91
|
+
click.option(
|
92
|
+
"--table-quantization-type",
|
93
|
+
type=click.Choice(["none", "bit", "halfvec"]),
|
94
|
+
help="quantization type for vectors (in table). "
|
95
|
+
"If equal to bit, the parameter quantization_type will be set to bit too.",
|
86
96
|
required=False,
|
87
97
|
),
|
88
98
|
]
|
@@ -146,6 +156,7 @@ def PgVectorIVFFlat(
|
|
146
156
|
lists=parameters["lists"],
|
147
157
|
probes=parameters["probes"],
|
148
158
|
quantization_type=parameters["quantization_type"],
|
159
|
+
table_quantization_type=parameters["table_quantization_type"],
|
149
160
|
reranking=parameters["reranking"],
|
150
161
|
reranking_metric=parameters["reranking_metric"],
|
151
162
|
quantized_fetch_limit=parameters["quantized_fetch_limit"],
|
@@ -182,6 +193,7 @@ def PgVectorHNSW(
|
|
182
193
|
maintenance_work_mem=parameters["maintenance_work_mem"],
|
183
194
|
max_parallel_workers=parameters["max_parallel_workers"],
|
184
195
|
quantization_type=parameters["quantization_type"],
|
196
|
+
table_quantization_type=parameters["table_quantization_type"],
|
185
197
|
reranking=parameters["reranking"],
|
186
198
|
reranking_metric=parameters["reranking_metric"],
|
187
199
|
quantized_fetch_limit=parameters["quantized_fetch_limit"],
|
@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
|
|
80
80
|
|
81
81
|
if d.get(self.quantization_type) is None:
|
82
82
|
return d.get("_fallback").get(self.metric_type)
|
83
|
-
|
83
|
+
metric = d.get(self.quantization_type).get(self.metric_type)
|
84
|
+
# If using binary quantization for the index, use a bit metric
|
85
|
+
# no matter what metric was selected for vector or halfvec data
|
86
|
+
if self.quantization_type == "bit" and metric is None:
|
87
|
+
return "bit_hamming_ops"
|
88
|
+
return metric
|
84
89
|
|
85
90
|
def parse_metric_fun_op(self) -> LiteralString:
|
86
91
|
if self.quantization_type == "bit":
|
@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
168
173
|
maintenance_work_mem: str | None = None
|
169
174
|
max_parallel_workers: int | None = None
|
170
175
|
quantization_type: str | None = None
|
176
|
+
table_quantization_type: str | None
|
171
177
|
reranking: bool | None = None
|
172
178
|
quantized_fetch_limit: int | None = None
|
173
179
|
reranking_metric: str | None = None
|
174
180
|
|
175
181
|
def index_param(self) -> PgVectorIndexParam:
|
176
182
|
index_parameters = {"lists": self.lists}
|
177
|
-
if self.quantization_type == "none":
|
178
|
-
self.quantization_type =
|
183
|
+
if self.quantization_type == "none" or self.quantization_type is None:
|
184
|
+
self.quantization_type = "vector"
|
185
|
+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
|
186
|
+
self.table_quantization_type = "vector"
|
187
|
+
if self.table_quantization_type == "bit":
|
188
|
+
self.quantization_type = "bit"
|
179
189
|
return {
|
180
190
|
"metric": self.parse_metric(),
|
181
191
|
"index_type": self.index.value,
|
@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
183
193
|
"maintenance_work_mem": self.maintenance_work_mem,
|
184
194
|
"max_parallel_workers": self.max_parallel_workers,
|
185
195
|
"quantization_type": self.quantization_type,
|
196
|
+
"table_quantization_type": self.table_quantization_type,
|
186
197
|
}
|
187
198
|
|
188
199
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
212
223
|
maintenance_work_mem: str | None = None
|
213
224
|
max_parallel_workers: int | None = None
|
214
225
|
quantization_type: str | None = None
|
226
|
+
table_quantization_type: str | None
|
215
227
|
reranking: bool | None = None
|
216
228
|
quantized_fetch_limit: int | None = None
|
217
229
|
reranking_metric: str | None = None
|
218
230
|
|
219
231
|
def index_param(self) -> PgVectorIndexParam:
|
220
232
|
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
|
221
|
-
if self.quantization_type == "none":
|
222
|
-
self.quantization_type =
|
233
|
+
if self.quantization_type == "none" or self.quantization_type is None:
|
234
|
+
self.quantization_type = "vector"
|
235
|
+
if self.table_quantization_type == "none" or self.table_quantization_type is None:
|
236
|
+
self.table_quantization_type = "vector"
|
237
|
+
if self.table_quantization_type == "bit":
|
238
|
+
self.quantization_type = "bit"
|
223
239
|
return {
|
224
240
|
"metric": self.parse_metric(),
|
225
241
|
"index_type": self.index.value,
|
@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
227
243
|
"maintenance_work_mem": self.maintenance_work_mem,
|
228
244
|
"max_parallel_workers": self.max_parallel_workers,
|
229
245
|
"quantization_type": self.quantization_type,
|
246
|
+
"table_quantization_type": self.table_quantization_type,
|
230
247
|
}
|
231
248
|
|
232
249
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
|
|
94
94
|
reranking = self.case_config.search_param()["reranking"]
|
95
95
|
column_name = (
|
96
96
|
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
|
97
|
-
if index_param["quantization_type"] == "bit"
|
97
|
+
if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
|
98
98
|
else sql.SQL("embedding")
|
99
99
|
)
|
100
100
|
search_vector = (
|
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
|
|
104
104
|
)
|
105
105
|
|
106
106
|
# The following sections assume that the quantization_type value matches the quantization function name
|
107
|
-
if index_param["quantization_type"]
|
107
|
+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
|
108
|
+
# Reranking makes sense only if table quantization is not "bit"
|
108
109
|
if index_param["quantization_type"] == "bit" and reranking:
|
109
110
|
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
|
110
111
|
search_query = sql.Composed(
|
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
|
|
113
114
|
"""
|
114
115
|
SELECT i.id
|
115
116
|
FROM (
|
116
|
-
SELECT id, embedding {reranking_metric_fun_op} %s::
|
117
|
+
SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
|
117
118
|
FROM public.{table_name} {where_clause}
|
118
119
|
ORDER BY {column_name}::{quantization_type}({dim})
|
119
120
|
""",
|
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
|
|
123
124
|
reranking_metric_fun_op=sql.SQL(
|
124
125
|
self.case_config.search_param()["reranking_metric_fun_op"],
|
125
126
|
),
|
127
|
+
search_vector=search_vector,
|
128
|
+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
|
126
129
|
quantization_type=sql.SQL(index_param["quantization_type"]),
|
127
130
|
dim=sql.Literal(self.dim),
|
128
131
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
|
|
130
133
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
131
134
|
sql.SQL(
|
132
135
|
"""
|
133
|
-
{search_vector}
|
136
|
+
{search_vector}::{quantization_type}({dim})
|
134
137
|
LIMIT {quantized_fetch_limit}
|
135
138
|
) i
|
136
139
|
ORDER BY i.distance
|
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
|
|
138
141
|
""",
|
139
142
|
).format(
|
140
143
|
search_vector=search_vector,
|
144
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
145
|
+
dim=sql.Literal(self.dim),
|
141
146
|
quantized_fetch_limit=sql.Literal(
|
142
147
|
self.case_config.search_param()["quantized_fetch_limit"],
|
143
148
|
),
|
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
|
|
160
165
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
161
166
|
),
|
162
167
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
163
|
-
sql.SQL(" {search_vector} LIMIT %s::int").format(
|
168
|
+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
|
164
169
|
search_vector=search_vector,
|
170
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
171
|
+
dim=sql.Literal(self.dim),
|
165
172
|
),
|
166
|
-
]
|
173
|
+
]
|
167
174
|
)
|
168
175
|
else:
|
169
176
|
search_query = sql.Composed(
|
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
|
|
175
182
|
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
|
176
183
|
),
|
177
184
|
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
178
|
-
sql.SQL("
|
179
|
-
|
185
|
+
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
|
186
|
+
search_vector=search_vector,
|
187
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
188
|
+
dim=sql.Literal(self.dim),
|
189
|
+
),
|
190
|
+
]
|
180
191
|
)
|
181
192
|
|
182
193
|
return search_query
|
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
|
|
323
334
|
)
|
324
335
|
with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
|
325
336
|
|
326
|
-
if index_param["quantization_type"]
|
337
|
+
if index_param["quantization_type"] != index_param["table_quantization_type"]:
|
327
338
|
index_create_sql = sql.SQL(
|
328
339
|
"""
|
329
340
|
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
|
|
365
376
|
assert self.conn is not None, "Connection is not initialized"
|
366
377
|
assert self.cursor is not None, "Cursor is not initialized"
|
367
378
|
|
379
|
+
index_param = self.case_config.index_param()
|
380
|
+
|
368
381
|
try:
|
369
382
|
log.info(f"{self.name} client create table : {self.table_name}")
|
370
383
|
|
371
384
|
# create table
|
372
385
|
self.cursor.execute(
|
373
386
|
sql.SQL(
|
374
|
-
"
|
375
|
-
|
387
|
+
"""
|
388
|
+
CREATE TABLE IF NOT EXISTS public.{table_name}
|
389
|
+
(id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
|
390
|
+
"""
|
391
|
+
).format(
|
392
|
+
table_name=sql.Identifier(self.table_name),
|
393
|
+
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
|
394
|
+
dim=dim,
|
395
|
+
)
|
376
396
|
)
|
377
397
|
self.cursor.execute(
|
378
398
|
sql.SQL(
|
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
|
|
393
413
|
assert self.conn is not None, "Connection is not initialized"
|
394
414
|
assert self.cursor is not None, "Cursor is not initialized"
|
395
415
|
|
416
|
+
index_param = self.case_config.index_param()
|
417
|
+
|
396
418
|
try:
|
397
419
|
metadata_arr = np.array(metadata)
|
398
420
|
embeddings_arr = np.array(embeddings)
|
399
421
|
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
copy
|
406
|
-
|
407
|
-
|
422
|
+
if index_param["table_quantization_type"] == "bit":
|
423
|
+
with self.cursor.copy(
|
424
|
+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
|
425
|
+
table_name=sql.Identifier(self.table_name)
|
426
|
+
)
|
427
|
+
) as copy:
|
428
|
+
# Same logic as pgvector binary_quantize
|
429
|
+
for i, row in enumerate(metadata_arr):
|
430
|
+
embeddings_bit = ""
|
431
|
+
for embedding in embeddings_arr[i]:
|
432
|
+
if embedding > 0:
|
433
|
+
embeddings_bit += "1"
|
434
|
+
else:
|
435
|
+
embeddings_bit += "0"
|
436
|
+
copy.write_row((str(row), embeddings_bit))
|
437
|
+
else:
|
438
|
+
with self.cursor.copy(
|
439
|
+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
440
|
+
table_name=sql.Identifier(self.table_name)
|
441
|
+
)
|
442
|
+
) as copy:
|
443
|
+
if index_param["table_quantization_type"] == "halfvec":
|
444
|
+
copy.set_types(["bigint", "halfvec"])
|
445
|
+
for i, row in enumerate(metadata_arr):
|
446
|
+
copy.write_row((row, np.float16(embeddings_arr[i])))
|
447
|
+
else:
|
448
|
+
copy.set_types(["bigint", "vector"])
|
449
|
+
for i, row in enumerate(metadata_arr):
|
450
|
+
copy.write_row((row, embeddings_arr[i]))
|
408
451
|
self.conn.commit()
|
409
452
|
|
410
453
|
if kwargs.get("last_batch"):
|