vectordb-bench 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. vectordb_bench/backend/clients/__init__.py +48 -0
  2. vectordb_bench/backend/clients/api.py +1 -0
  3. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +53 -4
  4. vectordb_bench/backend/clients/aws_opensearch/cli.py +85 -1
  5. vectordb_bench/backend/clients/aws_opensearch/config.py +10 -0
  6. vectordb_bench/backend/clients/mariadb/cli.py +107 -0
  7. vectordb_bench/backend/clients/mariadb/config.py +71 -0
  8. vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
  9. vectordb_bench/backend/clients/milvus/cli.py +50 -0
  10. vectordb_bench/backend/clients/milvus/config.py +33 -0
  11. vectordb_bench/backend/clients/mongodb/config.py +53 -0
  12. vectordb_bench/backend/clients/mongodb/mongodb.py +200 -0
  13. vectordb_bench/backend/clients/pgvector/cli.py +13 -1
  14. vectordb_bench/backend/clients/pgvector/config.py +22 -5
  15. vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
  16. vectordb_bench/backend/clients/tidb/cli.py +98 -0
  17. vectordb_bench/backend/clients/tidb/config.py +49 -0
  18. vectordb_bench/backend/clients/tidb/tidb.py +234 -0
  19. vectordb_bench/cli/vectordbbench.py +4 -0
  20. vectordb_bench/frontend/components/custom/displaypPrams.py +12 -1
  21. vectordb_bench/frontend/components/run_test/submitTask.py +20 -3
  22. vectordb_bench/frontend/config/dbCaseConfigs.py +128 -0
  23. vectordb_bench/frontend/config/styles.py +2 -0
  24. vectordb_bench/log_util.py +15 -2
  25. vectordb_bench/models.py +7 -0
  26. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +67 -3
  27. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +31 -23
  28. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
  29. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
  30. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
  31. {vectordb_bench-0.0.21.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
@@ -194,6 +194,56 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
194
194
  **parameters,
195
195
  )
196
196
 
197
+ @cli.command()
198
+ @click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
199
+ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
200
+ from .config import GPUBruteForceConfig, MilvusConfig
201
+
202
+ run(
203
+ db=DBTYPE,
204
+ db_config=MilvusConfig(
205
+ db_label=parameters["db_label"],
206
+ uri=SecretStr(parameters["uri"]),
207
+ user=parameters["user_name"],
208
+ password=SecretStr(parameters["password"]),
209
+ ),
210
+ db_case_config=GPUBruteForceConfig(
211
+ metric_type=parameters["metric_type"],
212
+ limit=parameters["limit"], # top-k for search
213
+ ),
214
+ **parameters,
215
+ )
216
+
217
+ class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
218
+ metric_type: Annotated[
219
+ str,
220
+ click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
221
+ ]
222
+ limit: Annotated[
223
+ int,
224
+ click.option("--limit", type=int, required=True, help="Top-k limit for search"),
225
+ ]
226
+
227
+ @cli.command()
228
+ @click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
229
+ def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
230
+ from .config import GPUBruteForceConfig, MilvusConfig
231
+
232
+ run(
233
+ db=DBTYPE,
234
+ db_config=MilvusConfig(
235
+ db_label=parameters["db_label"],
236
+ uri=SecretStr(parameters["uri"]),
237
+ user=parameters["user_name"],
238
+ password=SecretStr(parameters["password"]),
239
+ ),
240
+ db_case_config=GPUBruteForceConfig(
241
+ metric_type=parameters["metric_type"],
242
+ limit=parameters["limit"], # top-k for search
243
+ ),
244
+ **parameters,
245
+ )
246
+
197
247
 
198
248
  class MilvusGPUIVFPQTypedDict(
199
249
  CommonTypedDict,
@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
40
40
  IndexType.GPU_CAGRA,
41
41
  IndexType.GPU_IVF_FLAT,
42
42
  IndexType.GPU_IVF_PQ,
43
+ IndexType.GPU_BRUTE_FORCE,
43
44
  ]
44
45
 
45
46
  def parse_metric(self) -> str:
@@ -184,6 +185,37 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
184
185
  }
185
186
 
186
187
 
188
+ class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
189
+ limit: int = 10 # Default top-k for search
190
+ metric_type: str # Metric type (e.g., 'L2', 'IP', etc.)
191
+ index: IndexType = IndexType.GPU_BRUTE_FORCE # Index type set to GPU_BRUTE_FORCE
192
+
193
+ def index_param(self) -> dict:
194
+ """
195
+ Returns the parameters for creating the GPU_BRUTE_FORCE index.
196
+ No additional parameters required for index building.
197
+ """
198
+ return {
199
+ "metric_type": self.parse_metric(), # Metric type for distance calculation (L2, IP, etc.)
200
+ "index_type": self.index.value, # GPU_BRUTE_FORCE index type
201
+ "params": {}, # No additional parameters for GPU_BRUTE_FORCE
202
+ }
203
+
204
+ def search_param(self) -> dict:
205
+ """
206
+ Returns the parameters for performing a search on the GPU_BRUTE_FORCE index.
207
+ Only metric_type and top-k (limit) are needed for search.
208
+ """
209
+ return {
210
+ "metric_type": self.parse_metric(), # Metric type for search
211
+ "params": {
212
+ "nprobe": 1, # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search)
213
+ "limit": self.limit, # Top-k for search
214
+ },
215
+ }
216
+
217
+
218
+
187
219
  class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
188
220
  nlist: int = 1024
189
221
  m: int = 0
@@ -261,4 +293,5 @@ _milvus_case_config = {
261
293
  IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
262
294
  IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
263
295
  IndexType.GPU_CAGRA: GPUCAGRAConfig,
296
+ IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
264
297
  }
@@ -0,0 +1,53 @@
1
+ from enum import Enum
2
+
3
+ from pydantic import BaseModel, SecretStr
4
+
5
+ from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
6
+
7
+
8
+ class QuantizationType(Enum):
9
+ NONE = "none"
10
+ BINARY = "binary"
11
+ SCALAR = "scalar"
12
+
13
+
14
+ class MongoDBConfig(DBConfig, BaseModel):
15
+ connection_string: SecretStr = "mongodb+srv://<user>:<password>@<cluster_name>.heatl.mongodb.net"
16
+ database: str = "vdb_bench"
17
+
18
+ def to_dict(self) -> dict:
19
+ return {
20
+ "connection_string": self.connection_string.get_secret_value(),
21
+ "database": self.database,
22
+ }
23
+
24
+
25
+ class MongoDBIndexConfig(BaseModel, DBCaseConfig):
26
+ index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search
27
+ metric_type: MetricType = MetricType.COSINE
28
+ num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search
29
+ quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable
30
+
31
+ def parse_metric(self) -> str:
32
+ if self.metric_type == MetricType.L2:
33
+ return "euclidean"
34
+ if self.metric_type == MetricType.IP:
35
+ return "dotProduct"
36
+ return "cosine" # Default to cosine similarity
37
+
38
+ def index_param(self) -> dict:
39
+ return {
40
+ "type": "vectorSearch",
41
+ "fields": [
42
+ {
43
+ "type": "vector",
44
+ "similarity": self.parse_metric(),
45
+ "numDimensions": None, # Will be set in MongoDB class
46
+ "path": "vector", # Vector field name
47
+ "quantization": self.quantization.value,
48
+ }
49
+ ],
50
+ }
51
+
52
+ def search_param(self) -> dict:
53
+ return {"num_candidates_ratio": self.num_candidates_ratio}
@@ -0,0 +1,200 @@
1
+ import logging
2
+ import time
3
+ from contextlib import contextmanager
4
+
5
+ from pymongo import MongoClient
6
+ from pymongo.operations import SearchIndexModel
7
+
8
+ from ..api import VectorDB
9
+ from .config import MongoDBIndexConfig
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
+
14
+ class MongoDBError(Exception):
15
+ """Custom exception class for MongoDB client errors."""
16
+
17
+
18
+ class MongoDB(VectorDB):
19
+ def __init__(
20
+ self,
21
+ dim: int,
22
+ db_config: dict,
23
+ db_case_config: MongoDBIndexConfig,
24
+ collection_name: str = "vdb_bench_collection",
25
+ id_field: str = "id",
26
+ vector_field: str = "vector",
27
+ drop_old: bool = False,
28
+ **kwargs,
29
+ ):
30
+ self.dim = dim
31
+ self.db_config = db_config
32
+ self.case_config = db_case_config
33
+ self.collection_name = collection_name
34
+ self.id_field = id_field
35
+ self.vector_field = vector_field
36
+ self.drop_old = drop_old
37
+
38
+ # Update index dimensions
39
+ index_params = self.case_config.index_param()
40
+ log.info(f"index params: {index_params}")
41
+ index_params["fields"][0]["numDimensions"] = dim
42
+ self.index_params = index_params
43
+
44
+ # Initialize - they'll also be set in init()
45
+ uri = self.db_config["connection_string"]
46
+ self.client = MongoClient(uri)
47
+ self.db = self.client[self.db_config["database"]]
48
+ self.collection = self.db[self.collection_name]
49
+ if self.drop_old and self.collection_name in self.db.list_collection_names():
50
+ log.info(f"MongoDB client dropping old collection: {self.collection_name}")
51
+ self.db.drop_collection(self.collection_name)
52
+ self.client = None
53
+ self.db = None
54
+ self.collection = None
55
+
56
+ @contextmanager
57
+ def init(self):
58
+ """Initialize MongoDB client and cleanup when done"""
59
+ try:
60
+ uri = self.db_config["connection_string"]
61
+ self.client = MongoClient(uri)
62
+ self.db = self.client[self.db_config["database"]]
63
+ self.collection = self.db[self.collection_name]
64
+
65
+ yield
66
+ finally:
67
+ if self.client is not None:
68
+ self.client.close()
69
+ self.client = None
70
+ self.db = None
71
+ self.collection = None
72
+
73
+ def _create_index(self) -> None:
74
+ """Create vector search index"""
75
+ index_name = "vector_index"
76
+ index_params = self.index_params
77
+ log.info(f"index params {index_params}")
78
+ # drop index if already exists
79
+ if self.collection.list_indexes():
80
+ all_indexes = self.collection.list_search_indexes()
81
+ if any(idx.get("name") == index_name for idx in all_indexes):
82
+ log.info(f"Drop index: {index_name}")
83
+ try:
84
+ self.collection.drop_search_index(index_name)
85
+ while True:
86
+ indices = list(self.collection.list_search_indexes())
87
+ indices = [idx for idx in indices if idx["name"] == index_name]
88
+ log.debug(f"index status {indices}")
89
+ if len(indices) == 0:
90
+ break
91
+ log.info(f"index deleting {indices}")
92
+ except Exception:
93
+ log.exception(f"Error dropping index {index_name}")
94
+ try:
95
+ # Create vector search index
96
+ search_index = SearchIndexModel(definition=index_params, name=index_name, type="vectorSearch")
97
+
98
+ self.collection.create_search_index(search_index)
99
+ log.info(f"Created vector search index: {index_name}")
100
+ self._wait_for_index_ready(index_name)
101
+
102
+ # Create regular index on id field for faster lookups
103
+ self.collection.create_index(self.id_field)
104
+ log.info(f"Created index on {self.id_field} field")
105
+
106
+ except Exception:
107
+ log.exception(f"Error creating index {index_name}")
108
+ raise
109
+
110
+ def _wait_for_index_ready(self, index_name: str, check_interval: int = 5) -> None:
111
+ """Wait for index to be ready"""
112
+ while True:
113
+ indices = list(self.collection.list_search_indexes())
114
+ log.debug(f"index status {indices}")
115
+ if indices and any(idx.get("name") == index_name and idx.get("queryable") for idx in indices):
116
+ break
117
+ for idx in indices:
118
+ if idx.get("name") == index_name and idx.get("status") == "FAILED":
119
+ error_msg = f"Index {index_name} failed to build"
120
+ raise MongoDBError(error_msg)
121
+
122
+ time.sleep(check_interval)
123
+ log.info(f"Index {index_name} is ready")
124
+
125
+ def need_normalize_cosine(self) -> bool:
126
+ return False
127
+
128
+ def insert_embeddings(
129
+ self,
130
+ embeddings: list[list[float]],
131
+ metadata: list[int],
132
+ **kwargs,
133
+ ) -> (int, Exception | None):
134
+ """Insert embeddings into MongoDB"""
135
+
136
+ # Prepare documents in bulk
137
+ documents = [
138
+ {
139
+ self.id_field: id_,
140
+ self.vector_field: embedding,
141
+ }
142
+ for id_, embedding in zip(metadata, embeddings, strict=False)
143
+ ]
144
+
145
+ # Use ordered=False for better insert performance
146
+ try:
147
+ self.collection.insert_many(documents, ordered=False)
148
+ except Exception as e:
149
+ return 0, e
150
+ return len(documents), None
151
+
152
+ def search_embedding(
153
+ self,
154
+ query: list[float],
155
+ k: int = 100,
156
+ filters: dict | None = None,
157
+ **kwargs,
158
+ ) -> list[int]:
159
+ """Search for similar vectors"""
160
+ search_params = self.case_config.search_param()
161
+
162
+ vector_search = {"queryVector": query, "index": "vector_index", "path": self.vector_field, "limit": k}
163
+
164
+ # Add exact search parameter if specified
165
+ if search_params["exact"]:
166
+ vector_search["exact"] = True
167
+ else:
168
+ # Set numCandidates based on k value and data size
169
+ # For 50K dataset, use higher multiplier for better recall
170
+ num_candidates = min(10000, k * search_params["num_candidates_ratio"])
171
+ vector_search["numCandidates"] = num_candidates
172
+
173
+ # Add filter if specified
174
+ if filters:
175
+ log.info(f"Applying filter: {filters}")
176
+ vector_search["filter"] = {
177
+ "id": {"gte": filters["id"]},
178
+ }
179
+ pipeline = [
180
+ {"$vectorSearch": vector_search},
181
+ {
182
+ "$project": {
183
+ "_id": 0,
184
+ self.id_field: 1,
185
+ "score": {"$meta": "vectorSearchScore"}, # Include similarity score
186
+ }
187
+ },
188
+ ]
189
+
190
+ results = list(self.collection.aggregate(pipeline))
191
+ return [doc[self.id_field] for doc in results]
192
+
193
+ def optimize(self, data_size: int | None = None) -> None:
194
+ """MongoDB vector search indexes are self-optimizing"""
195
+ log.info("optimize for search")
196
+ self._create_index()
197
+ self._wait_for_index_ready("vector_index")
198
+
199
+ def ready_to_load(self) -> None:
200
+ """MongoDB is always ready to load"""
@@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict):
82
82
  click.option(
83
83
  "--quantization-type",
84
84
  type=click.Choice(["none", "bit", "halfvec"]),
85
- help="quantization type for vectors",
85
+ help="quantization type for vectors (in index)",
86
+ required=False,
87
+ ),
88
+ ]
89
+ table_quantization_type: Annotated[
90
+ str | None,
91
+ click.option(
92
+ "--table-quantization-type",
93
+ type=click.Choice(["none", "bit", "halfvec"]),
94
+ help="quantization type for vectors (in table). "
95
+ "If equal to bit, the parameter quantization_type will be set to bit too.",
86
96
  required=False,
87
97
  ),
88
98
  ]
@@ -146,6 +156,7 @@ def PgVectorIVFFlat(
146
156
  lists=parameters["lists"],
147
157
  probes=parameters["probes"],
148
158
  quantization_type=parameters["quantization_type"],
159
+ table_quantization_type=parameters["table_quantization_type"],
149
160
  reranking=parameters["reranking"],
150
161
  reranking_metric=parameters["reranking_metric"],
151
162
  quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -182,6 +193,7 @@ def PgVectorHNSW(
182
193
  maintenance_work_mem=parameters["maintenance_work_mem"],
183
194
  max_parallel_workers=parameters["max_parallel_workers"],
184
195
  quantization_type=parameters["quantization_type"],
196
+ table_quantization_type=parameters["table_quantization_type"],
185
197
  reranking=parameters["reranking"],
186
198
  reranking_metric=parameters["reranking_metric"],
187
199
  quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
80
80
 
81
81
  if d.get(self.quantization_type) is None:
82
82
  return d.get("_fallback").get(self.metric_type)
83
- return d.get(self.quantization_type).get(self.metric_type)
83
+ metric = d.get(self.quantization_type).get(self.metric_type)
84
+ # If using binary quantization for the index, use a bit metric
85
+ # no matter what metric was selected for vector or halfvec data
86
+ if self.quantization_type == "bit" and metric is None:
87
+ return "bit_hamming_ops"
88
+ return metric
84
89
 
85
90
  def parse_metric_fun_op(self) -> LiteralString:
86
91
  if self.quantization_type == "bit":
@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
168
173
  maintenance_work_mem: str | None = None
169
174
  max_parallel_workers: int | None = None
170
175
  quantization_type: str | None = None
176
+ table_quantization_type: str | None
171
177
  reranking: bool | None = None
172
178
  quantized_fetch_limit: int | None = None
173
179
  reranking_metric: str | None = None
174
180
 
175
181
  def index_param(self) -> PgVectorIndexParam:
176
182
  index_parameters = {"lists": self.lists}
177
- if self.quantization_type == "none":
178
- self.quantization_type = None
183
+ if self.quantization_type == "none" or self.quantization_type is None:
184
+ self.quantization_type = "vector"
185
+ if self.table_quantization_type == "none" or self.table_quantization_type is None:
186
+ self.table_quantization_type = "vector"
187
+ if self.table_quantization_type == "bit":
188
+ self.quantization_type = "bit"
179
189
  return {
180
190
  "metric": self.parse_metric(),
181
191
  "index_type": self.index.value,
@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
183
193
  "maintenance_work_mem": self.maintenance_work_mem,
184
194
  "max_parallel_workers": self.max_parallel_workers,
185
195
  "quantization_type": self.quantization_type,
196
+ "table_quantization_type": self.table_quantization_type,
186
197
  }
187
198
 
188
199
  def search_param(self) -> PgVectorSearchParam:
@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
212
223
  maintenance_work_mem: str | None = None
213
224
  max_parallel_workers: int | None = None
214
225
  quantization_type: str | None = None
226
+ table_quantization_type: str | None
215
227
  reranking: bool | None = None
216
228
  quantized_fetch_limit: int | None = None
217
229
  reranking_metric: str | None = None
218
230
 
219
231
  def index_param(self) -> PgVectorIndexParam:
220
232
  index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
221
- if self.quantization_type == "none":
222
- self.quantization_type = None
233
+ if self.quantization_type == "none" or self.quantization_type is None:
234
+ self.quantization_type = "vector"
235
+ if self.table_quantization_type == "none" or self.table_quantization_type is None:
236
+ self.table_quantization_type = "vector"
237
+ if self.table_quantization_type == "bit":
238
+ self.quantization_type = "bit"
223
239
  return {
224
240
  "metric": self.parse_metric(),
225
241
  "index_type": self.index.value,
@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
227
243
  "maintenance_work_mem": self.maintenance_work_mem,
228
244
  "max_parallel_workers": self.max_parallel_workers,
229
245
  "quantization_type": self.quantization_type,
246
+ "table_quantization_type": self.table_quantization_type,
230
247
  }
231
248
 
232
249
  def search_param(self) -> PgVectorSearchParam:
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
94
94
  reranking = self.case_config.search_param()["reranking"]
95
95
  column_name = (
96
96
  sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
97
- if index_param["quantization_type"] == "bit"
97
+ if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
98
98
  else sql.SQL("embedding")
99
99
  )
100
100
  search_vector = (
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
104
104
  )
105
105
 
106
106
  # The following sections assume that the quantization_type value matches the quantization function name
107
- if index_param["quantization_type"] is not None:
107
+ if index_param["quantization_type"] != index_param["table_quantization_type"]:
108
+ # Reranking makes sense only if table quantization is not "bit"
108
109
  if index_param["quantization_type"] == "bit" and reranking:
109
110
  # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
110
111
  search_query = sql.Composed(
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
113
114
  """
114
115
  SELECT i.id
115
116
  FROM (
116
- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
117
+ SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
117
118
  FROM public.{table_name} {where_clause}
118
119
  ORDER BY {column_name}::{quantization_type}({dim})
119
120
  """,
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
123
124
  reranking_metric_fun_op=sql.SQL(
124
125
  self.case_config.search_param()["reranking_metric_fun_op"],
125
126
  ),
127
+ search_vector=search_vector,
128
+ table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
126
129
  quantization_type=sql.SQL(index_param["quantization_type"]),
127
130
  dim=sql.Literal(self.dim),
128
131
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
130
133
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
131
134
  sql.SQL(
132
135
  """
133
- {search_vector}
136
+ {search_vector}::{quantization_type}({dim})
134
137
  LIMIT {quantized_fetch_limit}
135
138
  ) i
136
139
  ORDER BY i.distance
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
138
141
  """,
139
142
  ).format(
140
143
  search_vector=search_vector,
144
+ quantization_type=sql.SQL(index_param["quantization_type"]),
145
+ dim=sql.Literal(self.dim),
141
146
  quantized_fetch_limit=sql.Literal(
142
147
  self.case_config.search_param()["quantized_fetch_limit"],
143
148
  ),
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
160
165
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
161
166
  ),
162
167
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
163
- sql.SQL(" {search_vector} LIMIT %s::int").format(
168
+ sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
164
169
  search_vector=search_vector,
170
+ quantization_type=sql.SQL(index_param["quantization_type"]),
171
+ dim=sql.Literal(self.dim),
165
172
  ),
166
- ],
173
+ ]
167
174
  )
168
175
  else:
169
176
  search_query = sql.Composed(
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
175
182
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
176
183
  ),
177
184
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
178
- sql.SQL(" %s::vector LIMIT %s::int"),
179
- ],
185
+ sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
186
+ search_vector=search_vector,
187
+ quantization_type=sql.SQL(index_param["quantization_type"]),
188
+ dim=sql.Literal(self.dim),
189
+ ),
190
+ ]
180
191
  )
181
192
 
182
193
  return search_query
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
323
334
  )
324
335
  with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
325
336
 
326
- if index_param["quantization_type"] is not None:
337
+ if index_param["quantization_type"] != index_param["table_quantization_type"]:
327
338
  index_create_sql = sql.SQL(
328
339
  """
329
340
  CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
365
376
  assert self.conn is not None, "Connection is not initialized"
366
377
  assert self.cursor is not None, "Cursor is not initialized"
367
378
 
379
+ index_param = self.case_config.index_param()
380
+
368
381
  try:
369
382
  log.info(f"{self.name} client create table : {self.table_name}")
370
383
 
371
384
  # create table
372
385
  self.cursor.execute(
373
386
  sql.SQL(
374
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
375
- ).format(table_name=sql.Identifier(self.table_name), dim=dim),
387
+ """
388
+ CREATE TABLE IF NOT EXISTS public.{table_name}
389
+ (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
390
+ """
391
+ ).format(
392
+ table_name=sql.Identifier(self.table_name),
393
+ table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
394
+ dim=dim,
395
+ )
376
396
  )
377
397
  self.cursor.execute(
378
398
  sql.SQL(
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
393
413
  assert self.conn is not None, "Connection is not initialized"
394
414
  assert self.cursor is not None, "Cursor is not initialized"
395
415
 
416
+ index_param = self.case_config.index_param()
417
+
396
418
  try:
397
419
  metadata_arr = np.array(metadata)
398
420
  embeddings_arr = np.array(embeddings)
399
421
 
400
- with self.cursor.copy(
401
- sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
402
- table_name=sql.Identifier(self.table_name),
403
- ),
404
- ) as copy:
405
- copy.set_types(["bigint", "vector"])
406
- for i, row in enumerate(metadata_arr):
407
- copy.write_row((row, embeddings_arr[i]))
422
+ if index_param["table_quantization_type"] == "bit":
423
+ with self.cursor.copy(
424
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
425
+ table_name=sql.Identifier(self.table_name)
426
+ )
427
+ ) as copy:
428
+ # Same logic as pgvector binary_quantize
429
+ for i, row in enumerate(metadata_arr):
430
+ embeddings_bit = ""
431
+ for embedding in embeddings_arr[i]:
432
+ if embedding > 0:
433
+ embeddings_bit += "1"
434
+ else:
435
+ embeddings_bit += "0"
436
+ copy.write_row((str(row), embeddings_bit))
437
+ else:
438
+ with self.cursor.copy(
439
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
440
+ table_name=sql.Identifier(self.table_name)
441
+ )
442
+ ) as copy:
443
+ if index_param["table_quantization_type"] == "halfvec":
444
+ copy.set_types(["bigint", "halfvec"])
445
+ for i, row in enumerate(metadata_arr):
446
+ copy.write_row((row, np.float16(embeddings_arr[i])))
447
+ else:
448
+ copy.set_types(["bigint", "vector"])
449
+ for i, row in enumerate(metadata_arr):
450
+ copy.write_row((row, embeddings_arr[i]))
408
451
  self.conn.commit()
409
452
 
410
453
  if kwargs.get("last_batch"):