vectordb-bench 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -38,9 +38,11 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
     MongoDB = "MongoDB"
+    TiDB = "TiDB"

     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901
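The two new enum members resolve through the same lazy-import dispatch that `init_cls` applies to the existing databases (extended in the hunks below). A minimal usage sketch, assuming only what the diff shows:

    from vectordb_bench.backend.clients import DB

    # Sketch: each property imports the client module on first use.
    assert DB("MariaDB") is DB.MariaDB
    mariadb_cls = DB.MariaDB.init_cls  # imports .mariadb.mariadb, returns the MariaDB client class
    tidb_cls = DB.TiDB.init_cls        # imports .tidb.tidb, returns the TiDB client class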
@@ -135,6 +137,16 @@ class DB(Enum):

             return MongoDB

+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test

@@ -236,6 +248,16 @@ class DB(Enum):

             return MongoDBConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig

@@ -318,6 +340,16 @@ class DB(Enum):

             return MongoDBIndexConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig

@@ -25,6 +25,7 @@ class IndexType(str, Enum):
     ES_HNSW = "hnsw"
     ES_IVFFlat = "ivfflat"
     GPU_IVF_FLAT = "GPU_IVF_FLAT"
+    GPU_BRUTE_FORCE = "GPU_BRUTE_FORCE"
     GPU_IVF_PQ = "GPU_IVF_PQ"
     GPU_CAGRA = "GPU_CAGRA"
     SCANN = "scann"
@@ -0,0 +1,107 @@
+from typing import Annotated, Optional, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor1,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str, click.option("--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int, click.option("--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+
+    storage_engine: Annotated[
+        int, click.option("--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    ...
+    m: Annotated[
+        Optional[int], click.option("--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        Optional[int], click.option("--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_min_limit",
+            required=False,
+        ),
+    ]
+
+    max_cache_size: Annotated[
+        Optional[int], click.option("--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
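For orientation, a sketch of what the `MariaDBHNSW` command body receives. The exact invocation (console script, lower-cased command name, and common flags such as the db label and case selection) comes from the shared CLI and is an assumption here; only the MariaDB-specific flags are taken from the hunk above:

    # Hypothetical invocation:
    #   vectordbbench mariadbhnsw --username root --password secret \
    #       --storage-engine InnoDB --m 16 --ef-search 64
    # inside the command body, `parameters` then looks roughly like:
    parameters = {
        "username": "root",
        "password": "secret",
        "host": "127.0.0.1",         # default
        "port": 3306,                # default
        "storage_engine": "InnoDB",
        "m": 16,
        "ef_search": 64,
        "max_cache_size": None,
        # ...plus the CommonTypedDict options (db_label, case selection, ...)
    }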
@@ -0,0 +1,71 @@
+from pydantic import SecretStr, BaseModel
+from typing import TypedDict
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in mariadb connection string,
+    so the names must match exactly mariadb API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        elif self.metric_type == MetricType.COSINE:
+            return "cosine"
+        else:
+            raise ValueError(f"Metric type {self.metric_type} is not supported!")
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
+
+
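A short sketch of what these config classes produce, with illustrative values; the dict contents follow directly from index_param()/search_param() above:

    from vectordb_bench.backend.clients.api import MetricType

    cfg = MariaDBHNSWConfig(metric_type=MetricType.COSINE, M=16, ef_search=64, max_cache_size=None)
    cfg.index_param()
    # {'storage_engine': 'InnoDB', 'metric_type': 'cosine', 'index_type': 'HNSW',
    #  'M': 16, 'max_cache_size': None}
    cfg.search_param()
    # {'metric_type': 'cosine', 'ef_search': 64}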
@@ -0,0 +1,214 @@
+from ..api import VectorDB
+
+import logging
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
+import numpy as np
+
+import mariadb
+
+log = logging.getLogger(__name__)
+
+class MariaDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+
+        self.name = "MariaDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+        conn = mariadb.connect(**kwargs)
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping database to avoid some locking issue
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                ) ENGINE={index_param["storage_engine"]}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+
+    @contextmanager
+    def init(self) -> None:
+        """ create and destory connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize allowed package size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["max_cache_size"] != None:
+                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param["max_cache_size"]}")
+            if search_param["ef_search"] != None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param["ef_search"]}")
+            self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
+        self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
+        self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] != None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                    ADD VECTOR KEY v(v) {index_options}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create index: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+        pass
+
+    @staticmethod
+    def vector_to_hex(v):
+        return np.array(v, 'float32').tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])));
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into Vector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> (list[int]):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        search_param = self.case_config.search_param()
+
+        if filters:
+            self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k))
+        else:
+            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+
+        return [id for id, in self.cursor.fetchall()]
+
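Read together with the CLI and config files above, the client follows the framework's usual lifecycle: construct (optionally dropping and recreating the schema), open a connection with init(), bulk-insert, build the vector index in optimize(), then search. A minimal usage sketch with placeholder connection details (illustrative only, not part of the diff):

    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.mariadb.config import MariaDBHNSWConfig
    from vectordb_bench.backend.clients.mariadb.mariadb import MariaDB

    db = MariaDB(
        dim=4,
        db_config={"user": "root", "password": "secret", "host": "127.0.0.1", "port": 3306},
        db_case_config=MariaDBHNSWConfig(metric_type=MetricType.L2, M=16, ef_search=None, max_cache_size=None),
        drop_old=True,
    )
    with db.init():
        db.insert_embeddings([[0.1, 0.2, 0.3, 0.4]], [1])      # executemany + COMMIT
        db.optimize()                                           # ALTER TABLE ... ADD VECTOR KEY
        ids = db.search_embedding([0.1, 0.2, 0.3, 0.4], k=10)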
@@ -194,6 +194,56 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
         **parameters,
     )

+@cli.command()
+@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
+def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
+    from .config import GPUBruteForceConfig, MilvusConfig
+
+    run(
+        db=DBTYPE,
+        db_config=MilvusConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=GPUBruteForceConfig(
+            metric_type=parameters["metric_type"],
+            limit=parameters["limit"],  # top-k for search
+        ),
+        **parameters,
+    )
+
+class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
+    metric_type: Annotated[
+        str,
+        click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
+    ]
+    limit: Annotated[
+        int,
+        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
+    ]
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
+def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
+    from .config import GPUBruteForceConfig, MilvusConfig
+
+    run(
+        db=DBTYPE,
+        db_config=MilvusConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=GPUBruteForceConfig(
+            metric_type=parameters["metric_type"],
+            limit=parameters["limit"],  # top-k for search
+        ),
+        **parameters,
+    )
+

 class MilvusGPUIVFPQTypedDict(
     CommonTypedDict,
@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
         IndexType.GPU_CAGRA,
         IndexType.GPU_IVF_FLAT,
         IndexType.GPU_IVF_PQ,
+        IndexType.GPU_BRUTE_FORCE,
     ]

     def parse_metric(self) -> str:
@@ -184,6 +185,37 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
         }


+class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
+    limit: int = 10  # Default top-k for search
+    metric_type: str  # Metric type (e.g., 'L2', 'IP', etc.)
+    index: IndexType = IndexType.GPU_BRUTE_FORCE  # Index type set to GPU_BRUTE_FORCE
+
+    def index_param(self) -> dict:
+        """
+        Returns the parameters for creating the GPU_BRUTE_FORCE index.
+        No additional parameters required for index building.
+        """
+        return {
+            "metric_type": self.parse_metric(),  # Metric type for distance calculation (L2, IP, etc.)
+            "index_type": self.index.value,  # GPU_BRUTE_FORCE index type
+            "params": {},  # No additional parameters for GPU_BRUTE_FORCE
+        }
+
+    def search_param(self) -> dict:
+        """
+        Returns the parameters for performing a search on the GPU_BRUTE_FORCE index.
+        Only metric_type and top-k (limit) are needed for search.
+        """
+        return {
+            "metric_type": self.parse_metric(),  # Metric type for search
+            "params": {
+                "nprobe": 1,  # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search)
+                "limit": self.limit,  # Top-k for search
+            },
+        }
+
+
+
 class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
     nlist: int = 1024
     m: int = 0
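A hedged sketch of the dicts the new GPU_BRUTE_FORCE config yields, assuming parse_metric() resolves the metric to the plain Milvus string; values shown are approximate:

    cfg = GPUBruteForceConfig(metric_type="L2")
    cfg.index_param()   # roughly {"metric_type": "L2", "index_type": "GPU_BRUTE_FORCE", "params": {}}
    cfg.search_param()  # roughly {"metric_type": "L2", "params": {"nprobe": 1, "limit": 10}}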
@@ -261,4 +293,5 @@ _milvus_case_config = {
     IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
     IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
     IndexType.GPU_CAGRA: GPUCAGRAConfig,
+    IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
 }
@@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict):
         click.option(
             "--quantization-type",
             type=click.Choice(["none", "bit", "halfvec"]),
-            help="quantization type for vectors",
+            help="quantization type for vectors (in index)",
+            required=False,
+        ),
+    ]
+    table_quantization_type: Annotated[
+        str | None,
+        click.option(
+            "--table-quantization-type",
+            type=click.Choice(["none", "bit", "halfvec"]),
+            help="quantization type for vectors (in table). "
+            "If equal to bit, the parameter quantization_type will be set to bit too.",
             required=False,
         ),
     ]
@@ -146,6 +156,7 @@ def PgVectorIVFFlat(
             lists=parameters["lists"],
             probes=parameters["probes"],
             quantization_type=parameters["quantization_type"],
+            table_quantization_type=parameters["table_quantization_type"],
             reranking=parameters["reranking"],
             reranking_metric=parameters["reranking_metric"],
             quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -182,6 +193,7 @@ def PgVectorHNSW(
             maintenance_work_mem=parameters["maintenance_work_mem"],
             max_parallel_workers=parameters["max_parallel_workers"],
             quantization_type=parameters["quantization_type"],
+            table_quantization_type=parameters["table_quantization_type"],
             reranking=parameters["reranking"],
             reranking_metric=parameters["reranking_metric"],
             quantized_fetch_limit=parameters["quantized_fetch_limit"],
@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):

         if d.get(self.quantization_type) is None:
             return d.get("_fallback").get(self.metric_type)
-        return d.get(self.quantization_type).get(self.metric_type)
+        metric = d.get(self.quantization_type).get(self.metric_type)
+        # If using binary quantization for the index, use a bit metric
+        # no matter what metric was selected for vector or halfvec data
+        if self.quantization_type == "bit" and metric is None:
+            return "bit_hamming_ops"
+        return metric

     def parse_metric_fun_op(self) -> LiteralString:
         if self.quantization_type == "bit":
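A toy illustration of the new fallback above (the lookup-table contents here are made up; only the control flow mirrors the hunk): with a bit-quantized index and a metric that has no bit-specific operator class, the method now returns "bit_hamming_ops" instead of None.

    d = {"bit": {"COSINE": None}, "_fallback": {"COSINE": "vector_cosine_ops"}}
    quantization_type, metric_type = "bit", "COSINE"
    metric = d[quantization_type].get(metric_type)
    ops = "bit_hamming_ops" if quantization_type == "bit" and metric is None else metric
    # ops == "bit_hamming_ops"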
@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
     maintenance_work_mem: str | None = None
     max_parallel_workers: int | None = None
     quantization_type: str | None = None
+    table_quantization_type: str | None
     reranking: bool | None = None
     quantized_fetch_limit: int | None = None
     reranking_metric: str | None = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"lists": self.lists}
-        if self.quantization_type == "none":
-            self.quantization_type = None
+        if self.quantization_type == "none" or self.quantization_type is None:
+            self.quantization_type = "vector"
+        if self.table_quantization_type == "none" or self.table_quantization_type is None:
+            self.table_quantization_type = "vector"
+        if self.table_quantization_type == "bit":
+            self.quantization_type = "bit"
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
             "quantization_type": self.quantization_type,
+            "table_quantization_type": self.table_quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:
@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
     maintenance_work_mem: str | None = None
     max_parallel_workers: int | None = None
     quantization_type: str | None = None
+    table_quantization_type: str | None
     reranking: bool | None = None
     quantized_fetch_limit: int | None = None
     reranking_metric: str | None = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
-        if self.quantization_type == "none":
-            self.quantization_type = None
+        if self.quantization_type == "none" or self.quantization_type is None:
+            self.quantization_type = "vector"
+        if self.table_quantization_type == "none" or self.table_quantization_type is None:
+            self.table_quantization_type = "vector"
+        if self.table_quantization_type == "bit":
+            self.quantization_type = "bit"
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
             "quantization_type": self.quantization_type,
+            "table_quantization_type": self.table_quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:
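Taken together, the pgvector hunks mean index_param() now always reports both quantization types ("none" or unset normalizes to "vector"), and choosing bit storage for the table forces the index quantization to bit as well. A toy illustration of that normalization on plain variables, not the real config object:

    quantization_type, table_quantization_type = None, "bit"
    if quantization_type in ("none", None):
        quantization_type = "vector"
    if table_quantization_type in ("none", None):
        table_quantization_type = "vector"
    if table_quantization_type == "bit":
        quantization_type = "bit"
    # -> quantization_type == "bit", table_quantization_type == "bit"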