vectordb-bench 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- vectordb_bench/__init__.py +14 -13
- vectordb_bench/backend/clients/__init__.py +13 -0
- vectordb_bench/backend/clients/api.py +2 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/pgdiskann/cli.py +99 -0
- vectordb_bench/backend/clients/pgdiskann/config.py +145 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +350 -0
- vectordb_bench/backend/clients/pgvector/cli.py +62 -1
- vectordb_bench/backend/clients/pgvector/config.py +48 -10
- vectordb_bench/backend/clients/pgvector/pgvector.py +145 -26
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +22 -4
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +4 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +6 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +165 -1
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +13 -3
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/METADATA +65 -35
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/RECORD +42 -38
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgdiskann/pgdiskann.py (new file):

```diff
@@ -0,0 +1,350 @@
+"""Wrapper around the pg_diskann vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgDiskANNConfigDict, PgDiskANNIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgDiskANN(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    coursor: psycopg.Cursor[Any] | None = None
+
+    _filtered_search: sql.Composed
+    _unfiltered_search: sql.Composed
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgDiskANNConfigDict,
+        db_case_config: PgDiskANNIndexConfig,
+        collection_name: str = "pg_diskann_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgDiskANN"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgdiskann_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL(
+                    "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
+                ).format(table_name=sql.Identifier(self.table_name)),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _set_parallel_index_build_param(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        index_param = self.case_config.index_param()
+
+        if index_param["maintenance_work_mem"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET maintenance_work_mem TO {};").format(
+                    index_param["maintenance_work_mem"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["maintenance_work_mem"],
+                )
+            )
+            self.conn.commit()
+
+        if index_param["max_parallel_workers"] is not None:
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL("SET max_parallel_workers TO '{}';").format(
+                    index_param["max_parallel_workers"]
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER USER {} SET max_parallel_workers TO '{}';"
+                ).format(
+                    sql.Identifier(self.db_config["user"]),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.cursor.execute(
+                sql.SQL(
+                    "ALTER TABLE {} SET (parallel_workers = {});"
+                ).format(
+                    sql.Identifier(self.table_name),
+                    index_param["max_parallel_workers"],
+                )
+            )
+            self.conn.commit()
+
+        results = self.cursor.execute(
+            sql.SQL("SHOW max_parallel_maintenance_workers;")
+        ).fetchall()
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
+        )
+        results.extend(
+            self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
+        )
+        log.info(f"{self.name} parallel index creation parameters: {results}")
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+        self._set_parallel_index_build_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgdiskann table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
+
+        return [int(i[0]) for i in result.fetchall()]
```
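For orientation, the two psycopg patterns the new client relies on — binary COPY for bulk loading and a composed, prepared search statement — can be exercised standalone. A minimal sketch, assuming a local Postgres that already has the pgvector extension; the connection values, table name, dimension, and the cosine operator `<=>` are illustrative stand-ins (the client takes its operator from `case_config.search_param()["metric_fun_op"]`):

```python
import numpy as np
import psycopg
from pgvector.psycopg import register_vector
from psycopg import sql

# Example connection values; the vector extension must already exist server-side.
conn = psycopg.connect(host="localhost", dbname="test", user="postgres", password="postgres")
register_vector(conn)
table = sql.Identifier("pg_diskann_collection")

with conn.cursor() as cur:
    cur.execute(
        sql.SQL(
            "CREATE TABLE IF NOT EXISTS public.{} (id BIGINT PRIMARY KEY, embedding vector(8))"
        ).format(table)
    )
    # Bulk load via binary COPY, mirroring insert_embeddings above.
    with cur.copy(
        sql.SQL("COPY public.{} FROM STDIN (FORMAT BINARY)").format(table)
    ) as copy:
        copy.set_types(["bigint", "vector"])
        for i, vec in enumerate(np.random.rand(4, 8).astype(np.float32)):
            copy.write_row((i, vec))
    conn.commit()

    # The composed search statement; "<=>" stands in for the configured
    # metric_fun_op.
    search = sql.Composed([
        sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(table),
        sql.SQL("<=>"),
        sql.SQL(" %s::vector LIMIT %s::int"),
    ])
    q = np.random.rand(8).astype(np.float32)
    rows = cur.execute(search, (q, 2), prepare=True, binary=True).fetchall()
    print([int(r[0]) for r in rows])
```

Composing identifiers with `sql.Identifier` keeps table and index names safely quoted, and `prepare=True` lets psycopg reuse the server-side plan across the benchmark's repeated searches.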
vectordb_bench/backend/clients/pgvector/cli.py:

```diff
@@ -4,17 +4,27 @@ import click
 import os
 from pydantic import SecretStr
 
+from vectordb_bench.backend.clients.api import MetricType
+
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor1,
     IVFFlatTypedDict,
     cli,
     click_parameter_decorators_from_typed_dict,
+    get_custom_case_config,
     run,
 )
 from vectordb_bench.backend.clients import DB
 
 
+
+def set_default_quantized_fetch_limit(ctx, param, value):
+    if ctx.params.get("reranking") and value is None:
+        # ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
+        return ctx.params["ef_search"]
+    return value
+
 class PgVectorTypedDict(CommonTypedDict):
     user_name: Annotated[
         str, click.option("--user-name", type=str, help="Db username", required=True)
```
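The `set_default_quantized_fetch_limit` callback added above derives one option's default from options parsed earlier in the same invocation. A self-contained sketch of that click pattern (the command name and defaults here are illustrative, not the package's CLI):

```python
import click


def default_limit_from_ef_search(ctx, param, value):
    # When reranking is enabled and no explicit limit was given,
    # fall back to ef_search, since the fetch limit is bound by it.
    if ctx.params.get("reranking") and value is None:
        return ctx.params["ef_search"]
    return value


@click.command()
@click.option("--ef-search", type=int, default=100, show_default=True)
@click.option("--reranking/--skip-reranking", default=False)
@click.option(
    "--quantized-fetch-limit",
    type=int,
    required=False,
    callback=default_limit_from_ef_search,
)
def bench(ef_search, reranking, quantized_fetch_limit):
    click.echo(f"{ef_search=} {reranking=} {quantized_fetch_limit=}")


if __name__ == "__main__":
    bench()  # e.g. passing only --reranking yields quantized_fetch_limit=100
```

The pattern relies on `--reranking` and `--ef-search` having landed in `ctx.params` before the callback fires: click processes options given on the command line in the order they appear, then the rest in declaration order.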
```diff
@@ -56,7 +66,49 @@ class PgVectorTypedDict(CommonTypedDict):
             required=False,
         ),
     ]
+    quantization_type: Annotated[
+        Optional[str],
+        click.option(
+            "--quantization-type",
+            type=click.Choice(["none", "bit", "halfvec"]),
+            help="quantization type for vectors",
+            required=False,
+        ),
+    ]
+    reranking: Annotated[
+        Optional[bool],
+        click.option(
+            "--reranking/--skip-reranking",
+            type=bool,
+            help="Enable reranking for HNSW search for binary quantization",
+            default=False,
+        ),
+    ]
+    reranking_metric: Annotated[
+        Optional[str],
+        click.option(
+            "--reranking-metric",
+            type=click.Choice(
+                [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]]
+            ),
+            help="Distance metric for reranking",
+            default="COSINE",
+            show_default=True,
+        ),
+    ]
+    quantized_fetch_limit: Annotated[
+        Optional[int],
+        click.option(
+            "--quantized-fetch-limit",
+            type=int,
+            help="Limit of fetching quantized vector ranked by distance for reranking \
+                -- bound by ef_search",
+            required=False,
+            callback=set_default_quantized_fetch_limit,
+        )
+    ]
 
+
 
 class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
     ...
```
```diff
@@ -69,6 +121,7 @@ def PgVectorIVFFlat(
 ):
     from .config import PgVectorConfig, PgVectorIVFFlatConfig
 
+    parameters["custom_case"] = get_custom_case_config(parameters)
     run(
         db=DB.PgVector,
         db_config=PgVectorConfig(
@@ -79,7 +132,10 @@ def PgVectorIVFFlat(
             db_name=parameters["db_name"],
         ),
         db_case_config=PgVectorIVFFlatConfig(
-            metric_type=None,
+            metric_type=None,
+            lists=parameters["lists"],
+            probes=parameters["probes"],
+            quantization_type=parameters["quantization_type"],
         ),
         **parameters,
     )
@@ -96,6 +152,7 @@ def PgVectorHNSW(
 ):
     from .config import PgVectorConfig, PgVectorHNSWConfig
 
+    parameters["custom_case"] = get_custom_case_config(parameters)
     run(
         db=DB.PgVector,
         db_config=PgVectorConfig(
@@ -111,6 +168,10 @@ def PgVectorHNSW(
             ef_search=parameters["ef_search"],
             maintenance_work_mem=parameters["maintenance_work_mem"],
             max_parallel_workers=parameters["max_parallel_workers"],
+            quantization_type=parameters["quantization_type"],
+            reranking=parameters["reranking"],
+            reranking_metric=parameters["reranking_metric"],
+            quantized_fetch_limit=parameters["quantized_fetch_limit"],
         ),
         **parameters,
     )
```
vectordb_bench/backend/clients/pgvector/config.py (the removed lines, truncated by the diff viewer, are restored here from the unchanged fall-through branch of the new code):

```diff
@@ -59,18 +59,34 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
     create_index_after_load: bool = True
 
     def parse_metric(self) -> str:
-        if self.metric_type == MetricType.L2:
-            return "vector_l2_ops"
-        elif self.metric_type == MetricType.IP:
-            return "vector_ip_ops"
-        return "vector_cosine_ops"
+        if self.quantization_type == "halfvec":
+            if self.metric_type == MetricType.L2:
+                return "halfvec_l2_ops"
+            elif self.metric_type == MetricType.IP:
+                return "halfvec_ip_ops"
+            return "halfvec_cosine_ops"
+        elif self.quantization_type == "bit":
+            if self.metric_type == MetricType.JACCARD:
+                return "bit_jaccard_ops"
+            return "bit_hamming_ops"
+        else:
+            if self.metric_type == MetricType.L2:
+                return "vector_l2_ops"
+            elif self.metric_type == MetricType.IP:
+                return "vector_ip_ops"
+            return "vector_cosine_ops"
 
     def parse_metric_fun_op(self) -> LiteralString:
-        if self.metric_type == MetricType.L2:
-            return "<->"
-        elif self.metric_type == MetricType.IP:
-            return "<#>"
-        return "<=>"
+        if self.quantization_type == "bit":
+            if self.metric_type == MetricType.JACCARD:
+                return "<%>"
+            return "<~>"
+        else:
+            if self.metric_type == MetricType.L2:
+                return "<->"
+            elif self.metric_type == MetricType.IP:
+                return "<#>"
+            return "<=>"
 
     def parse_metric_fun_str(self) -> str:
         if self.metric_type == MetricType.L2:
@@ -78,6 +94,14 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
         elif self.metric_type == MetricType.IP:
             return "max_inner_product"
         return "cosine_distance"
+
+    def parse_reranking_metric_fun_op(self) -> LiteralString:
+        if self.reranking_metric == MetricType.L2:
+            return "<->"
+        elif self.reranking_metric == MetricType.IP:
+            return "<#>"
+        return "<=>"
+
 
     @abstractmethod
     def index_param(self) -> PgVectorIndexParam:
```
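Condensed, the two selectors above implement the following mapping from quantization type and metric to pgvector's operator class (used at index build time) and distance operator (used in queries). This is a restatement for reference, not code from the package; metrics not listed fall through to the last row of their group:

```python
# (quantization_type, metric) -> (index opclass, search operator)
PGVECTOR_OPS = {
    ("halfvec", "L2"):      ("halfvec_l2_ops",     "<->"),
    ("halfvec", "IP"):      ("halfvec_ip_ops",     "<#>"),
    ("halfvec", "COSINE"):  ("halfvec_cosine_ops", "<=>"),
    ("bit",     "JACCARD"): ("bit_jaccard_ops",    "<%>"),
    ("bit",     "HAMMING"): ("bit_hamming_ops",    "<~>"),  # any non-Jaccard metric
    (None,      "L2"):      ("vector_l2_ops",      "<->"),
    (None,      "IP"):      ("vector_ip_ops",      "<#>"),
    (None,      "COSINE"):  ("vector_cosine_ops",  "<=>"),
}
```

Note that `parse_metric_fun_op` treats only `bit` specially: `halfvec` indexes are still queried with the ordinary `vector` distance operators.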
```diff
@@ -143,9 +167,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
     index: IndexType = IndexType.ES_IVFFlat
     maintenance_work_mem: Optional[str] = None
     max_parallel_workers: Optional[int] = None
+    quantization_type: Optional[str] = None
 
     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"lists": self.lists}
+        if self.quantization_type == "none":
+            self.quantization_type = None
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -154,6 +181,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
             ),
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
+            "quantization_type": self.quantization_type,
         }
 
     def search_param(self) -> PgVectorSearchParam:
@@ -183,9 +211,15 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
     index: IndexType = IndexType.ES_HNSW
     maintenance_work_mem: Optional[str] = None
     max_parallel_workers: Optional[int] = None
+    quantization_type: Optional[str] = None
+    reranking: Optional[bool] = None
+    quantized_fetch_limit: Optional[int] = None
+    reranking_metric: Optional[str] = None
 
     def index_param(self) -> PgVectorIndexParam:
         index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
+        if self.quantization_type == "none":
+            self.quantization_type = None
         return {
             "metric": self.parse_metric(),
             "index_type": self.index.value,
@@ -194,11 +228,15 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
             ),
             "maintenance_work_mem": self.maintenance_work_mem,
             "max_parallel_workers": self.max_parallel_workers,
+            "quantization_type": self.quantization_type,
         }
 
     def search_param(self) -> PgVectorSearchParam:
         return {
             "metric_fun_op": self.parse_metric_fun_op(),
+            "reranking": self.reranking,
+            "reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
+            "quantized_fetch_limit": self.quantized_fetch_limit,
         }
 
     def session_param(self) -> PgVectorSessionCommands:
```
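The three new search parameters are consumed in `pgvector/pgvector.py` (its +145/-26 hunk is not part of this excerpt). Judging from the parameters alone, a binary-quantized HNSW search with reranking follows pgvector's documented two-stage shape — an illustrative sketch, not the package's actual query:

```python
# Stage 1 orders by the quantized operator (metric_fun_op, "<~>" for
# Hamming) over binary_quantize()'d vectors, over-fetching
# quantized_fetch_limit candidates (which is why the CLI bounds it by
# ef_search). Stage 2 reranks those candidates on the full-precision
# embeddings with reranking_metric_fun_op ("<=>" for cosine).
# {table}, {dim}, and {quantized_fetch_limit} are placeholders.
RERANKED_QUERY = """
SELECT id FROM (
    SELECT id, embedding
    FROM public.{table}
    ORDER BY binary_quantize(embedding)::bit({dim}) <~> binary_quantize(%s::vector)
    LIMIT {quantized_fetch_limit}
) candidates
ORDER BY embedding <=> %s::vector
LIMIT %s::int;
"""
```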