vectordb-bench 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. vectordb_bench/backend/clients/__init__.py +65 -1
  2. vectordb_bench/backend/clients/api.py +2 -1
  3. vectordb_bench/backend/clients/chroma/chroma.py +2 -2
  4. vectordb_bench/backend/clients/clickhouse/cli.py +66 -0
  5. vectordb_bench/backend/clients/clickhouse/clickhouse.py +156 -0
  6. vectordb_bench/backend/clients/clickhouse/config.py +60 -0
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +1 -1
  8. vectordb_bench/backend/clients/mariadb/cli.py +122 -0
  9. vectordb_bench/backend/clients/mariadb/config.py +73 -0
  10. vectordb_bench/backend/clients/mariadb/mariadb.py +208 -0
  11. vectordb_bench/backend/clients/milvus/cli.py +32 -0
  12. vectordb_bench/backend/clients/milvus/config.py +32 -0
  13. vectordb_bench/backend/clients/milvus/milvus.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/cli.py +14 -3
  15. vectordb_bench/backend/clients/pgvector/config.py +22 -5
  16. vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
  17. vectordb_bench/backend/clients/pinecone/pinecone.py +1 -1
  18. vectordb_bench/backend/clients/qdrant_cloud/config.py +1 -9
  19. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +1 -1
  20. vectordb_bench/backend/clients/tidb/cli.py +98 -0
  21. vectordb_bench/backend/clients/tidb/config.py +46 -0
  22. vectordb_bench/backend/clients/tidb/tidb.py +233 -0
  23. vectordb_bench/backend/clients/vespa/cli.py +47 -0
  24. vectordb_bench/backend/clients/vespa/config.py +51 -0
  25. vectordb_bench/backend/clients/vespa/util.py +15 -0
  26. vectordb_bench/backend/clients/vespa/vespa.py +249 -0
  27. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
  28. vectordb_bench/cli/cli.py +20 -17
  29. vectordb_bench/cli/vectordbbench.py +8 -0
  30. vectordb_bench/frontend/config/dbCaseConfigs.py +147 -0
  31. vectordb_bench/frontend/config/styles.py +4 -0
  32. vectordb_bench/models.py +8 -6
  33. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/METADATA +22 -3
  34. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/RECORD +38 -25
  35. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/WHEEL +1 -1
  36. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/entry_points.txt +0 -0
  37. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info/licenses}/LICENSE +0 -0
  38. {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.24.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
1
+ import logging
2
+ from contextlib import contextmanager
3
+
4
+ import mariadb
5
+ import numpy as np
6
+
7
+ from ..api import VectorDB
8
+ from .config import MariaDBConfigDict, MariaDBIndexConfig
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
class MariaDB(VectorDB):
    """VectorDBBench client backed by MariaDB's ``VECTOR`` column type and vector index.

    ``__init__`` opens a short-lived connection to (optionally) drop and recreate
    the benchmark database/table, then closes it again. The connection used for
    loading and searching is opened and closed by :meth:`init`.
    """

    def __init__(
        self,
        dim: int,
        db_config: MariaDBConfigDict,
        db_case_config: MariaDBIndexConfig,
        collection_name: str = "vec_collection",
        drop_old: bool = False,
        **kwargs,
    ):
        self.name = "MariaDB"
        self.db_config = db_config
        self.case_config = db_case_config
        self.db_name = "vectordbbench"
        self.table_name = collection_name
        self.dim = dim

        # construct basic units
        self.conn, self.cursor = self._create_connection(**self.db_config)

        try:
            if drop_old:
                self._drop_db()
                self._create_db_table(dim)
        finally:
            # Release the setup connection even if dropping/creating failed.
            self._close_connection()

    @staticmethod
    def _create_connection(**kwargs) -> tuple[mariadb.Connection, mariadb.Cursor]:
        """Open a connection via ``mariadb.connect(**kwargs)`` and return it with a new cursor."""
        conn = mariadb.connect(**kwargs)
        cursor = conn.cursor()

        assert conn is not None, "Connection is not initialized"
        assert cursor is not None, "Cursor is not initialized"

        return conn, cursor

    def _close_connection(self) -> None:
        """Close the current cursor and connection (if any) and reset both attributes to ``None``."""
        try:
            if self.cursor is not None:
                self.cursor.close()
        finally:
            # Close the connection even if closing the cursor raised.
            if self.conn is not None:
                self.conn.close()
            self.cursor = None
            self.conn = None

    def _drop_db(self):
        """Drop the benchmark database if it exists."""
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"
        log.info(f"{self.name} client drop db : {self.db_name}")

        # flush tables before dropping database to avoid some locking issue
        self.cursor.execute("FLUSH TABLES")
        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
        self.cursor.execute("COMMIT")
        self.cursor.execute("FLUSH TABLES")

    def _create_db_table(self, dim: int):
        """Create the benchmark database and a table with an ``id`` key and a ``VECTOR(dim)`` column.

        The storage engine comes from ``index_param()["storage_engine"]``.

        Raises:
            Exception: whatever the driver raised; it is logged and re-raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            log.info(f"{self.name} client create database : {self.db_name}")
            self.cursor.execute(f"CREATE DATABASE {self.db_name}")

            log.info(f"{self.name} client create table : {self.table_name}")
            self.cursor.execute(f"USE {self.db_name}")

            self.cursor.execute(
                f"""
                CREATE TABLE {self.table_name} (
                    id INT PRIMARY KEY,
                    v VECTOR({dim}) NOT NULL
                ) ENGINE={index_param["storage_engine"]}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create table: {self.table_name} error: {e}")
            raise

    @contextmanager
    def init(self):
        """Create and destroy the connection used for loading and searching.

        Session/global settings are applied and the INSERT/SELECT statements are
        prepared as strings before control is yielded. The connection is always
        closed on exit, including when a setup statement fails.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn, self.cursor = self._create_connection(**self.db_config)

        try:
            index_param = self.case_config.index_param()
            search_param = self.case_config.search_param()

            # maximize allowed package size (1073741824 bytes = 1 GiB)
            self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")

            if index_param["index_type"] == "HNSW":
                if index_param["max_cache_size"] is not None:
                    self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param['max_cache_size']}")
                if search_param["ef_search"] is not None:
                    self.cursor.execute(f"SET mhnsw_ef_search = {search_param['ef_search']}")
                self.cursor.execute("COMMIT")

            table = f"{self.db_name}.{self.table_name}"
            distance_fn = f"vec_distance_{search_param['metric_type']}"
            self.insert_sql = f"INSERT INTO {table} (id, v) VALUES (%s, %s)"  # noqa: S608
            # Note the trailing space before ORDER: the two f-strings are concatenated.
            self.select_sql = (
                f"SELECT id FROM {table} "  # noqa: S608
                f"ORDER by {distance_fn}(v, %s) LIMIT %d"
            )
            self.select_sql_with_filter = (
                f"SELECT id FROM {table} WHERE id >= %d "  # noqa: S608
                f"ORDER by {distance_fn}(v, %s) LIMIT %d"
            )

            yield
        finally:
            self._close_connection()

    def ready_to_load(self) -> bool:
        """No-op hook; this client needs no preparation before loading (returns ``None``)."""

    def optimize(self) -> None:
        """Add the vector index on column ``v`` once the data is loaded.

        ``DISTANCE`` always comes from ``index_param()["metric_type"]``. ``M`` is
        added only for HNSW when it is configured.

        Raises:
            Exception: whatever the driver raised; it is logged and re-raised.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        index_param = self.case_config.index_param()

        try:
            index_options = f"DISTANCE={index_param['metric_type']}"
            if index_param["index_type"] == "HNSW" and index_param["M"] is not None:
                index_options += f" M={index_param['M']}"

            self.cursor.execute(
                f"""
                ALTER TABLE {self.db_name}.{self.table_name}
                ADD VECTOR KEY v(v) {index_options}
                """
            )
            self.cursor.execute("COMMIT")

        except Exception as e:
            log.warning(f"Failed to create index: {self.table_name} error: {e}")
            raise

    @staticmethod
    def vector_to_hex(v):  # noqa: ANN001
        """Return *v* as raw little-endian-agnostic float32 bytes (despite the name, not a hex string)."""
        return np.asarray(v, dtype=np.float32).tobytes()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs,
    ) -> tuple[int, Exception | None]:
        """Insert ``(id, vector)`` rows in one ``executemany`` batch and commit.

        Should call self.init() first.

        Returns:
            ``(len(metadata), None)`` on success, or ``(0, error)`` on failure.
            Mismatched ``embeddings``/``metadata`` lengths are reported as a failure.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        try:
            batch_data = [
                (int(row_id), self.vector_to_hex(vector))
                for row_id, vector in zip(metadata, embeddings, strict=True)
            ]

            self.cursor.executemany(self.insert_sql, batch_data)
            self.cursor.execute("COMMIT")
            self.cursor.execute("FLUSH TABLES")

            return len(metadata), None
        except Exception as e:
            log.warning(f"Failed to insert data into Vector table ({self.table_name}), error: {e}")
            return 0, e

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs,
    ) -> list[int]:
        """Return the ids of the ``k`` nearest rows to *query*.

        When *filters* is truthy, only rows with ``id >= filters["id"]`` are
        considered. *timeout* is accepted for interface compatibility and not used.
        """
        assert self.conn is not None, "Connection is not initialized"
        assert self.cursor is not None, "Cursor is not initialized"

        query_bytes = self.vector_to_hex(query)
        if filters:
            self.cursor.execute(self.select_sql_with_filter, (filters.get("id"), query_bytes, k))
        else:
            self.cursor.execute(self.select_sql, (query_bytes, k))

        return [row_id for (row_id,) in self.cursor.fetchall()]
@@ -195,6 +195,38 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
195
195
  )
196
196
 
197
197
 
198
class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
    # CLI options for the MilvusGPUBruteForce command, on top of the common and
    # Milvus connection options. Each Annotated field carries the click.option
    # that click_parameter_decorators_from_typed_dict turns into a CLI flag.

    # Forwarded as GPUBruteForceConfig.metric_type.
    metric_type: Annotated[
        str,
        click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
    ]
    # Forwarded as GPUBruteForceConfig.limit (top-k placed in the search params).
    limit: Annotated[
        int,
        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
    ]
207
+
208
+
209
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
    # Benchmark Milvus with a GPU_BRUTE_FORCE index: build the connection config
    # and the case config from the parsed CLI options, then hand both to run().
    from .config import GPUBruteForceConfig, MilvusConfig

    milvus_connection = MilvusConfig(
        db_label=parameters["db_label"],
        uri=SecretStr(parameters["uri"]),
        user=parameters["user_name"],
        password=SecretStr(parameters["password"]),
    )
    brute_force_case = GPUBruteForceConfig(
        metric_type=parameters["metric_type"],
        limit=parameters["limit"],  # top-k for search
    )

    run(
        db=DBTYPE,
        db_config=milvus_connection,
        db_case_config=brute_force_case,
        **parameters,
    )
228
+
229
+
198
230
  class MilvusGPUIVFPQTypedDict(
199
231
  CommonTypedDict,
200
232
  MilvusTypedDict,
@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
40
40
  IndexType.GPU_CAGRA,
41
41
  IndexType.GPU_IVF_FLAT,
42
42
  IndexType.GPU_IVF_PQ,
43
+ IndexType.GPU_BRUTE_FORCE,
43
44
  ]
44
45
 
45
46
  def parse_metric(self) -> str:
@@ -184,6 +185,36 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
184
185
  }
185
186
 
186
187
 
188
class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
    # Case config for Milvus's GPU_BRUTE_FORCE index (exhaustive search on GPU).
    limit: int = 10  # top-k forwarded inside the search params
    metric_type: str  # metric name such as 'L2' or 'IP'; translated by parse_metric()
    index: IndexType = IndexType.GPU_BRUTE_FORCE  # fixed index type for this config

    def index_param(self) -> dict:
        """Build the index-creation parameters for GPU_BRUTE_FORCE.

        Only the metric and index type are set; the build ``params`` dict is empty.
        """
        distance = self.parse_metric()
        build_params: dict = {}
        return {
            "metric_type": distance,
            "index_type": self.index.value,
            "params": build_params,
        }

    def search_param(self) -> dict:
        """Build the search parameters for GPU_BRUTE_FORCE.

        ``params`` carries a fixed ``nprobe`` of 1 and the configured top-k ``limit``.
        """
        distance = self.parse_metric()
        knobs = {"nprobe": 1, "limit": self.limit}
        return {"metric_type": distance, "params": knobs}
216
+
217
+
187
218
  class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
188
219
  nlist: int = 1024
189
220
  m: int = 0
@@ -261,4 +292,5 @@ _milvus_case_config = {
261
292
  IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
262
293
  IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
263
294
  IndexType.GPU_CAGRA: GPUCAGRAConfig,
295
+ IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
264
296
  }
@@ -155,7 +155,7 @@ class Milvus(VectorDB):
155
155
  embeddings: Iterable[list[float]],
156
156
  metadata: list[int],
157
157
  **kwargs,
158
- ) -> (int, Exception):
158
+ ) -> tuple[int, Exception]:
159
159
  """Insert embeddings into Milvus. should call self.init() first"""
160
160
  # use the first insert_embeddings to init collection
161
161
  assert self.col is not None
@@ -18,8 +18,7 @@ from ....cli.cli import (
18
18
  )
19
19
 
20
20
 
21
- # ruff: noqa
22
- def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
21
+ def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): # noqa: ARG001
23
22
  if ctx.params.get("reranking") and value is None:
24
23
  # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
25
24
  # 100 is default value for quantized_fetch_limit for IVFFlat.
@@ -82,7 +81,17 @@ class PgVectorTypedDict(CommonTypedDict):
82
81
  click.option(
83
82
  "--quantization-type",
84
83
  type=click.Choice(["none", "bit", "halfvec"]),
85
- help="quantization type for vectors",
84
+ help="quantization type for vectors (in index)",
85
+ required=False,
86
+ ),
87
+ ]
88
+ table_quantization_type: Annotated[
89
+ str | None,
90
+ click.option(
91
+ "--table-quantization-type",
92
+ type=click.Choice(["none", "bit", "halfvec"]),
93
+ help="quantization type for vectors (in table). "
94
+ "If equal to bit, the parameter quantization_type will be set to bit too.",
86
95
  required=False,
87
96
  ),
88
97
  ]
@@ -146,6 +155,7 @@ def PgVectorIVFFlat(
146
155
  lists=parameters["lists"],
147
156
  probes=parameters["probes"],
148
157
  quantization_type=parameters["quantization_type"],
158
+ table_quantization_type=parameters["table_quantization_type"],
149
159
  reranking=parameters["reranking"],
150
160
  reranking_metric=parameters["reranking_metric"],
151
161
  quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -182,6 +192,7 @@ def PgVectorHNSW(
182
192
  maintenance_work_mem=parameters["maintenance_work_mem"],
183
193
  max_parallel_workers=parameters["max_parallel_workers"],
184
194
  quantization_type=parameters["quantization_type"],
195
+ table_quantization_type=parameters["table_quantization_type"],
185
196
  reranking=parameters["reranking"],
186
197
  reranking_metric=parameters["reranking_metric"],
187
198
  quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
80
80
 
81
81
  if d.get(self.quantization_type) is None:
82
82
  return d.get("_fallback").get(self.metric_type)
83
- return d.get(self.quantization_type).get(self.metric_type)
83
+ metric = d.get(self.quantization_type).get(self.metric_type)
84
+ # If using binary quantization for the index, use a bit metric
85
+ # no matter what metric was selected for vector or halfvec data
86
+ if self.quantization_type == "bit" and metric is None:
87
+ return "bit_hamming_ops"
88
+ return metric
84
89
 
85
90
  def parse_metric_fun_op(self) -> LiteralString:
86
91
  if self.quantization_type == "bit":
@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
168
173
  maintenance_work_mem: str | None = None
169
174
  max_parallel_workers: int | None = None
170
175
  quantization_type: str | None = None
176
+ table_quantization_type: str | None
171
177
  reranking: bool | None = None
172
178
  quantized_fetch_limit: int | None = None
173
179
  reranking_metric: str | None = None
174
180
 
175
181
  def index_param(self) -> PgVectorIndexParam:
176
182
  index_parameters = {"lists": self.lists}
177
- if self.quantization_type == "none":
178
- self.quantization_type = None
183
+ if self.quantization_type == "none" or self.quantization_type is None:
184
+ self.quantization_type = "vector"
185
+ if self.table_quantization_type == "none" or self.table_quantization_type is None:
186
+ self.table_quantization_type = "vector"
187
+ if self.table_quantization_type == "bit":
188
+ self.quantization_type = "bit"
179
189
  return {
180
190
  "metric": self.parse_metric(),
181
191
  "index_type": self.index.value,
@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
183
193
  "maintenance_work_mem": self.maintenance_work_mem,
184
194
  "max_parallel_workers": self.max_parallel_workers,
185
195
  "quantization_type": self.quantization_type,
196
+ "table_quantization_type": self.table_quantization_type,
186
197
  }
187
198
 
188
199
  def search_param(self) -> PgVectorSearchParam:
@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
212
223
  maintenance_work_mem: str | None = None
213
224
  max_parallel_workers: int | None = None
214
225
  quantization_type: str | None = None
226
+ table_quantization_type: str | None
215
227
  reranking: bool | None = None
216
228
  quantized_fetch_limit: int | None = None
217
229
  reranking_metric: str | None = None
218
230
 
219
231
  def index_param(self) -> PgVectorIndexParam:
220
232
  index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
221
- if self.quantization_type == "none":
222
- self.quantization_type = None
233
+ if self.quantization_type == "none" or self.quantization_type is None:
234
+ self.quantization_type = "vector"
235
+ if self.table_quantization_type == "none" or self.table_quantization_type is None:
236
+ self.table_quantization_type = "vector"
237
+ if self.table_quantization_type == "bit":
238
+ self.quantization_type = "bit"
223
239
  return {
224
240
  "metric": self.parse_metric(),
225
241
  "index_type": self.index.value,
@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
227
243
  "maintenance_work_mem": self.maintenance_work_mem,
228
244
  "max_parallel_workers": self.max_parallel_workers,
229
245
  "quantization_type": self.quantization_type,
246
+ "table_quantization_type": self.table_quantization_type,
230
247
  }
231
248
 
232
249
  def search_param(self) -> PgVectorSearchParam:
@@ -94,7 +94,7 @@ class PgVector(VectorDB):
94
94
  reranking = self.case_config.search_param()["reranking"]
95
95
  column_name = (
96
96
  sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
97
- if index_param["quantization_type"] == "bit"
97
+ if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
98
98
  else sql.SQL("embedding")
99
99
  )
100
100
  search_vector = (
@@ -104,7 +104,8 @@ class PgVector(VectorDB):
104
104
  )
105
105
 
106
106
  # The following sections assume that the quantization_type value matches the quantization function name
107
- if index_param["quantization_type"] is not None:
107
+ if index_param["quantization_type"] != index_param["table_quantization_type"]:
108
+ # Reranking makes sense only if table quantization is not "bit"
108
109
  if index_param["quantization_type"] == "bit" and reranking:
109
110
  # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
110
111
  search_query = sql.Composed(
@@ -113,7 +114,7 @@ class PgVector(VectorDB):
113
114
  """
114
115
  SELECT i.id
115
116
  FROM (
116
- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
117
+ SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
117
118
  FROM public.{table_name} {where_clause}
118
119
  ORDER BY {column_name}::{quantization_type}({dim})
119
120
  """,
@@ -123,6 +124,8 @@ class PgVector(VectorDB):
123
124
  reranking_metric_fun_op=sql.SQL(
124
125
  self.case_config.search_param()["reranking_metric_fun_op"],
125
126
  ),
127
+ search_vector=search_vector,
128
+ table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
126
129
  quantization_type=sql.SQL(index_param["quantization_type"]),
127
130
  dim=sql.Literal(self.dim),
128
131
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -130,7 +133,7 @@ class PgVector(VectorDB):
130
133
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
131
134
  sql.SQL(
132
135
  """
133
- {search_vector}
136
+ {search_vector}::{quantization_type}({dim})
134
137
  LIMIT {quantized_fetch_limit}
135
138
  ) i
136
139
  ORDER BY i.distance
@@ -138,6 +141,8 @@ class PgVector(VectorDB):
138
141
  """,
139
142
  ).format(
140
143
  search_vector=search_vector,
144
+ quantization_type=sql.SQL(index_param["quantization_type"]),
145
+ dim=sql.Literal(self.dim),
141
146
  quantized_fetch_limit=sql.Literal(
142
147
  self.case_config.search_param()["quantized_fetch_limit"],
143
148
  ),
@@ -160,10 +165,12 @@ class PgVector(VectorDB):
160
165
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
161
166
  ),
162
167
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
163
- sql.SQL(" {search_vector} LIMIT %s::int").format(
168
+ sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
164
169
  search_vector=search_vector,
170
+ quantization_type=sql.SQL(index_param["quantization_type"]),
171
+ dim=sql.Literal(self.dim),
165
172
  ),
166
- ],
173
+ ]
167
174
  )
168
175
  else:
169
176
  search_query = sql.Composed(
@@ -175,8 +182,12 @@ class PgVector(VectorDB):
175
182
  where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
176
183
  ),
177
184
  sql.SQL(self.case_config.search_param()["metric_fun_op"]),
178
- sql.SQL(" %s::vector LIMIT %s::int"),
179
- ],
185
+ sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
186
+ search_vector=search_vector,
187
+ quantization_type=sql.SQL(index_param["quantization_type"]),
188
+ dim=sql.Literal(self.dim),
189
+ ),
190
+ ]
180
191
  )
181
192
 
182
193
  return search_query
@@ -323,7 +334,7 @@ class PgVector(VectorDB):
323
334
  )
324
335
  with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
325
336
 
326
- if index_param["quantization_type"] is not None:
337
+ if index_param["quantization_type"] != index_param["table_quantization_type"]:
327
338
  index_create_sql = sql.SQL(
328
339
  """
329
340
  CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
@@ -365,14 +376,23 @@ class PgVector(VectorDB):
365
376
  assert self.conn is not None, "Connection is not initialized"
366
377
  assert self.cursor is not None, "Cursor is not initialized"
367
378
 
379
+ index_param = self.case_config.index_param()
380
+
368
381
  try:
369
382
  log.info(f"{self.name} client create table : {self.table_name}")
370
383
 
371
384
  # create table
372
385
  self.cursor.execute(
373
386
  sql.SQL(
374
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
375
- ).format(table_name=sql.Identifier(self.table_name), dim=dim),
387
+ """
388
+ CREATE TABLE IF NOT EXISTS public.{table_name}
389
+ (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));
390
+ """
391
+ ).format(
392
+ table_name=sql.Identifier(self.table_name),
393
+ table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
394
+ dim=dim,
395
+ )
376
396
  )
377
397
  self.cursor.execute(
378
398
  sql.SQL(
@@ -393,18 +413,41 @@ class PgVector(VectorDB):
393
413
  assert self.conn is not None, "Connection is not initialized"
394
414
  assert self.cursor is not None, "Cursor is not initialized"
395
415
 
416
+ index_param = self.case_config.index_param()
417
+
396
418
  try:
397
419
  metadata_arr = np.array(metadata)
398
420
  embeddings_arr = np.array(embeddings)
399
421
 
400
- with self.cursor.copy(
401
- sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
402
- table_name=sql.Identifier(self.table_name),
403
- ),
404
- ) as copy:
405
- copy.set_types(["bigint", "vector"])
406
- for i, row in enumerate(metadata_arr):
407
- copy.write_row((row, embeddings_arr[i]))
422
+ if index_param["table_quantization_type"] == "bit":
423
+ with self.cursor.copy(
424
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
425
+ table_name=sql.Identifier(self.table_name)
426
+ )
427
+ ) as copy:
428
+ # Same logic as pgvector binary_quantize
429
+ for i, row in enumerate(metadata_arr):
430
+ embeddings_bit = ""
431
+ for embedding in embeddings_arr[i]:
432
+ if embedding > 0:
433
+ embeddings_bit += "1"
434
+ else:
435
+ embeddings_bit += "0"
436
+ copy.write_row((str(row), embeddings_bit))
437
+ else:
438
+ with self.cursor.copy(
439
+ sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
440
+ table_name=sql.Identifier(self.table_name)
441
+ )
442
+ ) as copy:
443
+ if index_param["table_quantization_type"] == "halfvec":
444
+ copy.set_types(["bigint", "halfvec"])
445
+ for i, row in enumerate(metadata_arr):
446
+ copy.write_row((row, np.float16(embeddings_arr[i])))
447
+ else:
448
+ copy.set_types(["bigint", "vector"])
449
+ for i, row in enumerate(metadata_arr):
450
+ copy.write_row((row, embeddings_arr[i]))
408
451
  self.conn.commit()
409
452
 
410
453
  if kwargs.get("last_batch"):
@@ -67,7 +67,7 @@ class Pinecone(VectorDB):
67
67
  embeddings: list[list[float]],
68
68
  metadata: list[int],
69
69
  **kwargs,
70
- ) -> (int, Exception):
70
+ ) -> tuple[int, Exception]:
71
71
  assert len(embeddings) == len(metadata)
72
72
  insert_count = 0
73
73
  try:
@@ -1,4 +1,4 @@
1
- from pydantic import BaseModel, SecretStr, validator
1
+ from pydantic import BaseModel, SecretStr
2
2
 
3
3
  from ..api import DBCaseConfig, DBConfig, MetricType
4
4
 
@@ -20,14 +20,6 @@ class QdrantConfig(DBConfig):
20
20
  "url": self.url.get_secret_value(),
21
21
  }
22
22
 
23
- @validator("*")
24
- def not_empty_field(cls, v: any, field: any):
25
- if field.name in ["api_key", "db_label"]:
26
- return v
27
- if isinstance(v, str | SecretStr) and len(v) == 0:
28
- raise ValueError("Empty string!")
29
- return v
30
-
31
23
 
32
24
  class QdrantIndexConfig(BaseModel, DBCaseConfig):
33
25
  metric_type: MetricType | None = None
@@ -111,7 +111,7 @@ class QdrantCloud(VectorDB):
111
111
  embeddings: list[list[float]],
112
112
  metadata: list[int],
113
113
  **kwargs,
114
- ) -> (int, Exception):
114
+ ) -> tuple[int, Exception]:
115
115
  """Insert embeddings into Milvus. should call self.init() first"""
116
116
  assert self.qdrant_client is not None
117
117
  try: