vectordb-bench 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/api.py +1 -0
- vectordb_bench/backend/clients/mariadb/cli.py +107 -0
- vectordb_bench/backend/clients/mariadb/config.py +71 -0
- vectordb_bench/backend/clients/mariadb/mariadb.py +214 -0
- vectordb_bench/backend/clients/milvus/cli.py +50 -0
- vectordb_bench/backend/clients/milvus/config.py +33 -0
- vectordb_bench/backend/clients/pgvector/cli.py +13 -1
- vectordb_bench/backend/clients/pgvector/config.py +22 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +62 -19
- vectordb_bench/backend/clients/tidb/cli.py +98 -0
- vectordb_bench/backend/clients/tidb/config.py +49 -0
- vectordb_bench/backend/clients/tidb/tidb.py +234 -0
- vectordb_bench/cli/vectordbbench.py +4 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +96 -0
- vectordb_bench/frontend/config/styles.py +2 -0
- vectordb_bench/models.py +3 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/METADATA +13 -2
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/RECORD +23 -17
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.22.dist-info → vectordb_bench-0.0.23.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/__init__.py

@@ -38,9 +38,11 @@ class DB(Enum):
     Chroma = "Chroma"
     AWSOpenSearch = "OpenSearch"
     AliyunElasticsearch = "AliyunElasticsearch"
+    MariaDB = "MariaDB"
     Test = "test"
     AliyunOpenSearch = "AliyunOpenSearch"
     MongoDB = "MongoDB"
+    TiDB = "TiDB"

     @property
     def init_cls(self) -> type[VectorDB]:  # noqa: PLR0911, PLR0912, C901

@@ -135,6 +137,16 @@ class DB(Enum):

             return MongoDB

+        if self == DB.MariaDB:
+            from .mariadb.mariadb import MariaDB
+
+            return MariaDB
+
+        if self == DB.TiDB:
+            from .tidb.tidb import TiDB
+
+            return TiDB
+
         if self == DB.Test:
             from .test.test import Test

@@ -236,6 +248,16 @@ class DB(Enum):

             return MongoDBConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import MariaDBConfig
+
+            return MariaDBConfig
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBConfig
+
+            return TiDBConfig
+
         if self == DB.Test:
             from .test.config import TestConfig

@@ -318,6 +340,16 @@ class DB(Enum):

             return MongoDBIndexConfig

+        if self == DB.MariaDB:
+            from .mariadb.config import _mariadb_case_config
+
+            return _mariadb_case_config.get(index_type)
+
+        if self == DB.TiDB:
+            from .tidb.config import TiDBIndexConfig
+
+            return TiDBIndexConfig
+
         # DB.Pinecone, DB.Chroma, DB.Redis
         return EmptyDBCaseConfig
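The new enum members resolve their client and config classes lazily. A minimal sketch of looking them up through the `init_cls` / `config_cls` properties shown above; it assumes the optional MariaDB/TiDB client dependencies are installed, since the imports are deferred until the property is accessed:

```python
from vectordb_bench.backend.clients import DB

# Accessing these properties triggers the deferred imports in the hunks above,
# so the mariadb connector (and the TiDB driver) must be installed for them to succeed.
mariadb_client_cls = DB.MariaDB.init_cls    # client class from .mariadb.mariadb
mariadb_config_cls = DB.MariaDB.config_cls  # MariaDBConfig from .mariadb.config
tidb_client_cls = DB.TiDB.init_cls          # client class from .tidb.tidb
tidb_config_cls = DB.TiDB.config_cls        # TiDBConfig from .tidb.config
```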
vectordb_bench/backend/clients/mariadb/cli.py (new file)

@@ -0,0 +1,107 @@
+from typing import Annotated, Optional, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    HNSWFlavor1,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class MariaDBTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--username",
+            type=str,
+            help="Username",
+            required=True,
+        ),
+    ]
+    password: Annotated[
+        str, click.option("--password",
+            type=str,
+            help="Password",
+            required=True,
+        ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host",
+            type=str,
+            help="Db host",
+            default="127.0.0.1",
+        ),
+    ]
+
+    port: Annotated[
+        int, click.option("--port",
+            type=int,
+            default=3306,
+            help="Db Port",
+        ),
+    ]
+
+    storage_engine: Annotated[
+        int, click.option("--storage-engine",
+            type=click.Choice(["InnoDB", "MyISAM"]),
+            help="DB storage engine",
+            required=True,
+        ),
+    ]
+
+class MariaDBHNSWTypedDict(MariaDBTypedDict):
+    ...
+    m: Annotated[
+        Optional[int], click.option("--m",
+            type=int,
+            help="M parameter in MHNSW vector indexing",
+            required=False,
+        ),
+    ]
+
+    ef_search: Annotated[
+        Optional[int], click.option("--ef-search",
+            type=int,
+            help="MariaDB system variable mhnsw_min_limit",
+            required=False,
+        ),
+    ]
+
+    max_cache_size: Annotated[
+        Optional[int], click.option("--max-cache-size",
+            type=int,
+            help="MariaDB system variable mhnsw_max_cache_size",
+            required=False,
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict)
+def MariaDBHNSW(
+    **parameters: Unpack[MariaDBHNSWTypedDict],
+):
+    from .config import MariaDBConfig, MariaDBHNSWConfig
+
+    run(
+        db=DB.MariaDB,
+        db_config=MariaDBConfig(
+            db_label=parameters["db_label"],
+            user_name=parameters["username"],
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            port=parameters["port"],
+        ),
+        db_case_config=MariaDBHNSWConfig(
+            M=parameters["m"],
+            ef_search=parameters["ef_search"],
+            storage_engine=parameters["storage_engine"],
+            max_cache_size=parameters["max_cache_size"],
+        ),
+        **parameters,
+    )
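Because the command is created with `@cli.command()`, click exposes it under the lowercased function name once it is wired into the main `vectordbbench` entry point. A small sketch that prints the generated option list without running a benchmark; invoking the command object directly avoids assumptions about how it is registered:

```python
from click.testing import CliRunner

# After decoration, MariaDBHNSW is a click.Command carrying the options
# declared in MariaDBTypedDict / MariaDBHNSWTypedDict above.
from vectordb_bench.backend.clients.mariadb.cli import MariaDBHNSW

runner = CliRunner()
result = runner.invoke(MariaDBHNSW, ["--help"])  # --help exits before any benchmark logic
print(result.output)  # lists --username/--password/--host/--port/--storage-engine/--m/--ef-search/--max-cache-size plus the common options
```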
vectordb_bench/backend/clients/mariadb/config.py (new file)

@@ -0,0 +1,71 @@
+from pydantic import SecretStr, BaseModel
+from typing import TypedDict
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+class MariaDBConfigDict(TypedDict):
+    """These keys will be directly used as kwargs in mariadb connection string,
+    so the names must match exactly mariadb API"""
+
+    user: str
+    password: str
+    host: str
+    port: int
+
+
+class MariaDBConfig(DBConfig):
+    user_name: str = "root"
+    password: SecretStr
+    host: str = "127.0.0.1"
+    port: int = 3306
+
+    def to_dict(self) -> MariaDBConfigDict:
+        pwd_str = self.password.get_secret_value()
+        return {
+            "host": self.host,
+            "port": self.port,
+            "user": self.user_name,
+            "password": pwd_str,
+        }
+
+
+class MariaDBIndexConfig(BaseModel):
+    """Base config for MariaDB"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if self.metric_type == MetricType.L2:
+            return "euclidean"
+        elif self.metric_type == MetricType.COSINE:
+            return "cosine"
+        else:
+            raise ValueError(f"Metric type {self.metric_type} is not supported!")
+
+class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig):
+    M: int | None
+    ef_search: int | None
+    index: IndexType = IndexType.HNSW
+    storage_engine: str = "InnoDB"
+    max_cache_size: int | None
+
+    def index_param(self) -> dict:
+        return {
+            "storage_engine": self.storage_engine,
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "M": self.M,
+            "max_cache_size": self.max_cache_size,
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "ef_search": self.ef_search,
+        }
+
+
+_mariadb_case_config = {
+    IndexType.HNSW: MariaDBHNSWConfig,
+}
+
+
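A minimal sketch of how these two config objects are used together; the values are illustrative only, and `MetricType` and the `DBConfig` base defaults come from the shared `api` module rather than from this hunk:

```python
from pydantic import SecretStr

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.mariadb.config import MariaDBConfig, MariaDBHNSWConfig

# Connection settings become kwargs for mariadb.connect() via to_dict().
db_config = MariaDBConfig(password=SecretStr("secret"), host="127.0.0.1", port=3306)
print(db_config.to_dict())
# {'host': '127.0.0.1', 'port': 3306, 'user': 'root', 'password': 'secret'}

# M, ef_search and max_cache_size have no defaults, so they must be passed
# explicitly; None leaves the corresponding server knob untouched.
case_config = MariaDBHNSWConfig(
    metric_type=MetricType.L2,
    M=16,
    ef_search=64,
    max_cache_size=None,
)
print(case_config.index_param())   # storage_engine, metric_type, index_type, M, max_cache_size
print(case_config.search_param())  # {'metric_type': 'euclidean', 'ef_search': 64}
```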
vectordb_bench/backend/clients/mariadb/mariadb.py (new file)

@@ -0,0 +1,214 @@
+from ..api import VectorDB
+
+import logging
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple
+from ..api import VectorDB
+from .config import MariaDBConfigDict, MariaDBIndexConfig
+import numpy as np
+
+import mariadb
+
+log = logging.getLogger(__name__)
+
+class MariaDB(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: MariaDBConfigDict,
+        db_case_config: MariaDBIndexConfig,
+        collection_name: str = "vec_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+
+        self.name = "MariaDB"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.db_name = "vectordbbench"
+        self.table_name = collection_name
+        self.dim = dim
+
+        # construct basic units
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        if drop_old:
+            self._drop_db()
+            self._create_db_table(dim)
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[mariadb.Connection, mariadb.Cursor]:
+        conn = mariadb.connect(**kwargs)
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+
+    def _drop_db(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop db : {self.db_name}")
+
+        # flush tables before dropping database to avoid some locking issue
+        self.cursor.execute("FLUSH TABLES")
+        self.cursor.execute(f"DROP DATABASE IF EXISTS {self.db_name}")
+        self.cursor.execute("COMMIT")
+        self.cursor.execute("FLUSH TABLES")
+
+    def _create_db_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            log.info(f"{self.name} client create database : {self.db_name}")
+            self.cursor.execute(f"CREATE DATABASE {self.db_name}")
+
+            log.info(f"{self.name} client create table : {self.table_name}")
+            self.cursor.execute(f"USE {self.db_name}")
+
+            self.cursor.execute(f"""
+                CREATE TABLE {self.table_name} (
+                    id INT PRIMARY KEY,
+                    v VECTOR({self.dim}) NOT NULL
+                ) ENGINE={index_param["storage_engine"]}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+
+    @contextmanager
+    def init(self) -> None:
+        """ create and destory connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
+
+        # maximize allowed package size
+        self.cursor.execute("SET GLOBAL max_allowed_packet = 1073741824")
+
+        if index_param["index_type"] == "HNSW":
+            if index_param["max_cache_size"] != None:
+                self.cursor.execute(f"SET GLOBAL mhnsw_max_cache_size = {index_param["max_cache_size"]}")
+            if search_param["ef_search"] != None:
+                self.cursor.execute(f"SET mhnsw_ef_search = {search_param["ef_search"]}")
+            self.cursor.execute("COMMIT")
+
+        self.insert_sql = f"INSERT INTO {self.db_name}.{self.table_name} (id, v) VALUES (%s, %s)"
+        self.select_sql = f"SELECT id FROM {self.db_name}.{self.table_name} ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
+        self.select_sql_with_filter = f"SELECT id FROM {self.db_name}.{self.table_name} WHERE id >= %d ORDER by vec_distance_{search_param["metric_type"]}(v, %s) LIMIT %d"
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+
+    def ready_to_load(self) -> bool:
+        pass
+
+    def optimize(self) -> None:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        try:
+            index_options = f"DISTANCE={index_param['metric_type']}"
+            if index_param["index_type"] == "HNSW" and index_param["M"] != None:
+                index_options += f" M={index_param['M']}"
+
+            self.cursor.execute(f"""
+                ALTER TABLE {self.db_name}.{self.table_name}
+                    ADD VECTOR KEY v(v) {index_options}
+            """)
+            self.cursor.execute("COMMIT")
+
+        except Exception as e:
+            log.warning(
+                f"Failed to create index: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+        pass
+
+    @staticmethod
+    def vector_to_hex(v):
+        return np.array(v, 'float32').tobytes()
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            batch_data = []
+            for i, row in enumerate(metadata_arr):
+                batch_data.append((int(row), self.vector_to_hex(embeddings_arr[i])));
+
+            self.cursor.executemany(self.insert_sql, batch_data)
+            self.cursor.execute("COMMIT")
+            self.cursor.execute("FLUSH TABLES")
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into Vector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> (list[int]):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        search_param = self.case_config.search_param()
+
+        if filters:
+            self.cursor.execute(self.select_sql_with_filter, (filters.get('id'), self.vector_to_hex(query), k))
+        else:
+            self.cursor.execute(self.select_sql, (self.vector_to_hex(query), k))
+
+        return [id for id, in self.cursor.fetchall()]
+
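A condensed sketch of the client lifecycle the benchmark runner goes through, assuming a reachable MariaDB server that supports the VECTOR type and the mhnsw system variables (the client issues SET GLOBAL, so the account needs the corresponding privilege); connection values and the 4-dimensional vectors are placeholders:

```python
from pydantic import SecretStr

from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.mariadb.config import MariaDBConfig, MariaDBHNSWConfig
from vectordb_bench.backend.clients.mariadb.mariadb import MariaDB

db_config = MariaDBConfig(password=SecretStr("secret"), host="127.0.0.1", port=3306)
case_config = MariaDBHNSWConfig(metric_type=MetricType.COSINE, M=16, ef_search=64, max_cache_size=None)

# drop_old=True recreates the `vectordbbench` database and the vector table.
client = MariaDB(dim=4, db_config=db_config.to_dict(), db_case_config=case_config, drop_old=True)

with client.init():
    client.insert_embeddings(embeddings=[[0.1, 0.2, 0.3, 0.4]], metadata=[1])
    client.optimize()  # adds the VECTOR KEY with the DISTANCE/M options
    ids = client.search_embedding(query=[0.1, 0.2, 0.3, 0.4], k=1)
    print(ids)
```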
vectordb_bench/backend/clients/milvus/cli.py

@@ -194,6 +194,56 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
         **parameters,
     )

+@cli.command()
+@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
+def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
+    from .config import GPUBruteForceConfig, MilvusConfig
+
+    run(
+        db=DBTYPE,
+        db_config=MilvusConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=GPUBruteForceConfig(
+            metric_type=parameters["metric_type"],
+            limit=parameters["limit"],  # top-k for search
+        ),
+        **parameters,
+    )
+
+class MilvusGPUBruteForceTypedDict(CommonTypedDict, MilvusTypedDict):
+    metric_type: Annotated[
+        str,
+        click.option("--metric-type", type=str, required=True, help="Metric type for brute force search"),
+    ]
+    limit: Annotated[
+        int,
+        click.option("--limit", type=int, required=True, help="Top-k limit for search"),
+    ]
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(MilvusGPUBruteForceTypedDict)
+def MilvusGPUBruteForce(**parameters: Unpack[MilvusGPUBruteForceTypedDict]):
+    from .config import GPUBruteForceConfig, MilvusConfig
+
+    run(
+        db=DBTYPE,
+        db_config=MilvusConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=GPUBruteForceConfig(
+            metric_type=parameters["metric_type"],
+            limit=parameters["limit"],  # top-k for search
+        ),
+        **parameters,
+    )
+

 class MilvusGPUIVFPQTypedDict(
     CommonTypedDict,
vectordb_bench/backend/clients/milvus/config.py

@@ -40,6 +40,7 @@ class MilvusIndexConfig(BaseModel):
         IndexType.GPU_CAGRA,
         IndexType.GPU_IVF_FLAT,
         IndexType.GPU_IVF_PQ,
+        IndexType.GPU_BRUTE_FORCE,
     ]

     def parse_metric(self) -> str:

@@ -184,6 +185,37 @@ class GPUIVFFlatConfig(MilvusIndexConfig, DBCaseConfig):
         }


+class GPUBruteForceConfig(MilvusIndexConfig, DBCaseConfig):
+    limit: int = 10  # Default top-k for search
+    metric_type: str  # Metric type (e.g., 'L2', 'IP', etc.)
+    index: IndexType = IndexType.GPU_BRUTE_FORCE  # Index type set to GPU_BRUTE_FORCE
+
+    def index_param(self) -> dict:
+        """
+        Returns the parameters for creating the GPU_BRUTE_FORCE index.
+        No additional parameters required for index building.
+        """
+        return {
+            "metric_type": self.parse_metric(),  # Metric type for distance calculation (L2, IP, etc.)
+            "index_type": self.index.value,  # GPU_BRUTE_FORCE index type
+            "params": {},  # No additional parameters for GPU_BRUTE_FORCE
+        }
+
+    def search_param(self) -> dict:
+        """
+        Returns the parameters for performing a search on the GPU_BRUTE_FORCE index.
+        Only metric_type and top-k (limit) are needed for search.
+        """
+        return {
+            "metric_type": self.parse_metric(),  # Metric type for search
+            "params": {
+                "nprobe": 1,  # For GPU_BRUTE_FORCE, set nprobe to 1 (brute force search)
+                "limit": self.limit,  # Top-k for search
+            },
+        }
+
+
+
 class GPUIVFPQConfig(MilvusIndexConfig, DBCaseConfig):
     nlist: int = 1024
     m: int = 0

@@ -261,4 +293,5 @@ _milvus_case_config = {
     IndexType.GPU_IVF_FLAT: GPUIVFFlatConfig,
     IndexType.GPU_IVF_PQ: GPUIVFPQConfig,
     IndexType.GPU_CAGRA: GPUCAGRAConfig,
+    IndexType.GPU_BRUTE_FORCE: GPUBruteForceConfig,
 }
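A small sketch of how the new index type is resolved through the case-config registry; constructing the config is shown without calling its methods, since `parse_metric()` lives in `MilvusIndexConfig` outside this hunk and its exact metric handling is not visible here:

```python
from vectordb_bench.backend.clients.api import IndexType
from vectordb_bench.backend.clients.milvus.config import GPUBruteForceConfig, _milvus_case_config

# The registry now maps GPU_BRUTE_FORCE to the new class; this is how the
# frontend/CLI pick the right DBCaseConfig for a selected index type.
config_cls = _milvus_case_config[IndexType.GPU_BRUTE_FORCE]
assert config_cls is GPUBruteForceConfig

# Per the class above, index_param() sends an empty "params" dict (nothing to
# tune at build time) and search_param() forwards nprobe=1 plus the top-k limit.
cfg = config_cls(metric_type="L2", limit=100)
```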
vectordb_bench/backend/clients/pgvector/cli.py

@@ -82,7 +82,17 @@ class PgVectorTypedDict(CommonTypedDict):
         click.option(
             "--quantization-type",
             type=click.Choice(["none", "bit", "halfvec"]),
-            help="quantization type for vectors",
+            help="quantization type for vectors (in index)",
+            required=False,
+        ),
+    ]
+    table_quantization_type: Annotated[
+        str | None,
+        click.option(
+            "--table-quantization-type",
+            type=click.Choice(["none", "bit", "halfvec"]),
+            help="quantization type for vectors (in table). "
+            "If equal to bit, the parameter quantization_type will be set to bit too.",
             required=False,
         ),
     ]

@@ -146,6 +156,7 @@ def PgVectorIVFFlat(
             lists=parameters["lists"],
             probes=parameters["probes"],
             quantization_type=parameters["quantization_type"],
+            table_quantization_type=parameters["table_quantization_type"],
             reranking=parameters["reranking"],
             reranking_metric=parameters["reranking_metric"],
             quantized_fetch_limit=parameters["quantized_fetch_limit"],

@@ -182,6 +193,7 @@ def PgVectorHNSW(
             maintenance_work_mem=parameters["maintenance_work_mem"],
             max_parallel_workers=parameters["max_parallel_workers"],
             quantization_type=parameters["quantization_type"],
+            table_quantization_type=parameters["table_quantization_type"],
             reranking=parameters["reranking"],
             reranking_metric=parameters["reranking_metric"],
             quantized_fetch_limit=parameters["quantized_fetch_limit"],
vectordb_bench/backend/clients/pgvector/config.py

@@ -80,7 +80,12 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):

         if d.get(self.quantization_type) is None:
             return d.get("_fallback").get(self.metric_type)
-        return d.get(self.quantization_type).get(self.metric_type)
+        metric = d.get(self.quantization_type).get(self.metric_type)
+        # If using binary quantization for the index, use a bit metric
+        # no matter what metric was selected for vector or halfvec data
+        if self.quantization_type == "bit" and metric is None:
+            return "bit_hamming_ops"
+        return metric

     def parse_metric_fun_op(self) -> LiteralString:
         if self.quantization_type == "bit":

@@ -168,14 +173,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
     maintenance_work_mem: str | None = None
     max_parallel_workers: int | None = None
     quantization_type: str | None = None
+    table_quantization_type: str | None
     reranking: bool | None = None
     quantized_fetch_limit: int | None = None
     reranking_metric: str | None = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"lists": self.lists}
-        if self.quantization_type == "none":
-            self.quantization_type = "vector"
+        if self.quantization_type == "none" or self.quantization_type is None:
+            self.quantization_type = "vector"
+        if self.table_quantization_type == "none" or self.table_quantization_type is None:
+            self.table_quantization_type = "vector"
+        if self.table_quantization_type == "bit":
+            self.quantization_type = "bit"
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,

@@ -183,6 +193,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
             "quantization_type": self.quantization_type,
+            "table_quantization_type": self.table_quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:

@@ -212,14 +223,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
     maintenance_work_mem: str | None = None
     max_parallel_workers: int | None = None
     quantization_type: str | None = None
+    table_quantization_type: str | None
     reranking: bool | None = None
     quantized_fetch_limit: int | None = None
     reranking_metric: str | None = None

     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
-        if self.quantization_type == "none":
-            self.quantization_type = "vector"
+        if self.quantization_type == "none" or self.quantization_type is None:
+            self.quantization_type = "vector"
+        if self.table_quantization_type == "none" or self.table_quantization_type is None:
+            self.table_quantization_type = "vector"
+        if self.table_quantization_type == "bit":
+            self.quantization_type = "bit"
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,

@@ -227,6 +243,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
             "quantization_type": self.quantization_type,
+            "table_quantization_type": self.table_quantization_type,
         }

     def search_param(self) -> PgVectorSearchParam:
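A sketch of the new table-level quantization knob. The constructor arguments other than `table_quantization_type` are assumptions based on the surrounding class (`m`, `ef_construction`, and `ef_search` appear to be the remaining required fields), so treat this as illustrative rather than exact:

```python
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.pgvector.config import PgVectorHNSWConfig

# Storing the table as bit vectors forces the index to bit quantization as well,
# and the new parse_metric() fallback then reports a bit operator class.
cfg = PgVectorHNSWConfig(
    metric_type=MetricType.COSINE,
    m=16,
    ef_construction=64,
    ef_search=40,
    quantization_type="halfvec",
    table_quantization_type="bit",
)

params = cfg.index_param()
print(params["quantization_type"])        # overridden to "bit" by the table setting
print(params["table_quantization_type"])  # "bit"
print(params["metric"])                   # likely "bit_hamming_ops" via the new fallback
```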