vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +22 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvector/cli.py +17 -2
- vectordb_bench/backend/clients/pgvector/config.py +20 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +7 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/check_results/data.py +13 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
- vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +25 -9
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
@@ -1,73 +1,138 @@
|
|
1
1
|
"""Wrapper around the Pgvecto.rs vector database over VectorDB"""
|
2
2
|
|
3
|
-
import io
|
4
3
|
import logging
|
4
|
+
import pprint
|
5
5
|
from contextlib import contextmanager
|
6
|
-
from typing import Any
|
7
|
-
import pandas as pd
|
8
|
-
import psycopg2
|
9
|
-
import psycopg2.extras
|
6
|
+
from typing import Any, Generator, Optional, Tuple
|
10
7
|
|
11
|
-
|
8
|
+
import numpy as np
|
9
|
+
import psycopg
|
10
|
+
from psycopg import Connection, Cursor, sql
|
11
|
+
from pgvecto_rs.psycopg import register_vector
|
12
|
+
|
13
|
+
from ..api import VectorDB
|
14
|
+
from .config import PgVectoRSConfig, PgVectoRSIndexConfig
|
12
15
|
|
13
16
|
log = logging.getLogger(__name__)
|
14
17
|
|
18
|
+
|
15
19
|
class PgVectoRS(VectorDB):
|
16
|
-
"""Use
|
20
|
+
"""Use psycopg instructions"""
|
21
|
+
|
22
|
+
conn: psycopg.Connection[Any] | None = None
|
23
|
+
cursor: psycopg.Cursor[Any] | None = None
|
24
|
+
_unfiltered_search: sql.Composed
|
25
|
+
_filtered_search: sql.Composed
|
17
26
|
|
18
27
|
def __init__(
|
19
28
|
self,
|
20
29
|
dim: int,
|
21
|
-
db_config:
|
22
|
-
db_case_config:
|
23
|
-
collection_name: str = "
|
30
|
+
db_config: PgVectoRSConfig,
|
31
|
+
db_case_config: PgVectoRSIndexConfig,
|
32
|
+
collection_name: str = "PgVectoRSCollection",
|
24
33
|
drop_old: bool = False,
|
25
34
|
**kwargs,
|
26
35
|
):
|
36
|
+
|
37
|
+
self.name = "PgVectorRS"
|
27
38
|
self.db_config = db_config
|
28
39
|
self.case_config = db_case_config
|
29
40
|
self.table_name = collection_name
|
30
41
|
self.dim = dim
|
31
42
|
|
32
|
-
self._index_name = "
|
43
|
+
self._index_name = "pgvectors_index"
|
33
44
|
self._primary_field = "id"
|
34
45
|
self._vector_field = "embedding"
|
35
46
|
|
36
47
|
# construct basic units
|
37
|
-
self.conn =
|
38
|
-
self.conn.autocommit = False
|
39
|
-
self.cursor = self.conn.cursor()
|
48
|
+
self.conn, self.cursor = self._create_connection(**self.db_config)
|
40
49
|
|
41
|
-
|
42
|
-
|
43
|
-
|
50
|
+
log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
|
51
|
+
if not any(
|
52
|
+
(
|
53
|
+
self.case_config.create_index_before_load,
|
54
|
+
self.case_config.create_index_after_load,
|
55
|
+
)
|
56
|
+
):
|
57
|
+
err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
|
58
|
+
log.error(err)
|
59
|
+
raise RuntimeError(
|
60
|
+
f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
|
61
|
+
)
|
44
62
|
|
45
63
|
if drop_old:
|
46
64
|
log.info(f"Pgvecto.rs client drop table : {self.table_name}")
|
47
65
|
self._drop_index()
|
48
66
|
self._drop_table()
|
49
67
|
self._create_table(dim)
|
50
|
-
self.
|
68
|
+
if self.case_config.create_index_before_load:
|
69
|
+
self._create_index()
|
51
70
|
|
52
71
|
self.cursor.close()
|
53
72
|
self.conn.close()
|
54
73
|
self.cursor = None
|
55
74
|
self.conn = None
|
56
75
|
|
76
|
+
@staticmethod
|
77
|
+
def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
|
78
|
+
conn = psycopg.connect(**kwargs)
|
79
|
+
|
80
|
+
# create vector extension
|
81
|
+
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
|
82
|
+
conn.commit()
|
83
|
+
register_vector(conn)
|
84
|
+
|
85
|
+
conn.autocommit = False
|
86
|
+
cursor = conn.cursor()
|
87
|
+
|
88
|
+
assert conn is not None, "Connection is not initialized"
|
89
|
+
assert cursor is not None, "Cursor is not initialized"
|
90
|
+
|
91
|
+
return conn, cursor
|
92
|
+
|
57
93
|
@contextmanager
|
58
|
-
def init(self) -> None:
|
94
|
+
def init(self) -> Generator[None, None, None]:
|
59
95
|
"""
|
60
96
|
Examples:
|
61
97
|
>>> with self.init():
|
62
98
|
>>> self.insert_embeddings()
|
63
99
|
>>> self.search_embedding()
|
64
100
|
"""
|
65
|
-
|
66
|
-
self.conn.
|
67
|
-
|
68
|
-
|
101
|
+
|
102
|
+
self.conn, self.cursor = self._create_connection(**self.db_config)
|
103
|
+
|
104
|
+
# index configuration may have commands defined that we should set during each client session
|
105
|
+
session_options = self.case_config.session_param()
|
106
|
+
|
107
|
+
for key, val in session_options.items():
|
108
|
+
command = sql.SQL("SET {setting_name} " + "= {val};").format(
|
109
|
+
setting_name=sql.Identifier(key),
|
110
|
+
val=val,
|
111
|
+
)
|
112
|
+
log.debug(command.as_string(self.cursor))
|
113
|
+
self.cursor.execute(command)
|
69
114
|
self.conn.commit()
|
70
115
|
|
116
|
+
self._filtered_search = sql.Composed(
|
117
|
+
[
|
118
|
+
sql.SQL(
|
119
|
+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
|
120
|
+
).format(table_name=sql.Identifier(self.table_name)),
|
121
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
122
|
+
sql.SQL(" %s::vector LIMIT %s::int"),
|
123
|
+
]
|
124
|
+
)
|
125
|
+
|
126
|
+
self._unfiltered_search = sql.Composed(
|
127
|
+
[
|
128
|
+
sql.SQL(
|
129
|
+
"SELECT id FROM public.{table_name} ORDER BY embedding "
|
130
|
+
).format(table_name=sql.Identifier(self.table_name)),
|
131
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
132
|
+
sql.SQL(" %s::vector LIMIT %s::int"),
|
133
|
+
]
|
134
|
+
)
|
135
|
+
|
71
136
|
try:
|
72
137
|
yield
|
73
138
|
finally:
|
@@ -79,42 +144,65 @@ class PgVectoRS(VectorDB):
|
|
79
144
|
def _drop_table(self):
|
80
145
|
assert self.conn is not None, "Connection is not initialized"
|
81
146
|
assert self.cursor is not None, "Cursor is not initialized"
|
147
|
+
log.info(f"{self.name} client drop table : {self.table_name}")
|
82
148
|
|
83
|
-
self.cursor.execute(
|
149
|
+
self.cursor.execute(
|
150
|
+
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
|
151
|
+
table_name=sql.Identifier(self.table_name)
|
152
|
+
)
|
153
|
+
)
|
84
154
|
self.conn.commit()
|
85
155
|
|
86
156
|
def ready_to_load(self):
|
87
157
|
pass
|
88
158
|
|
89
159
|
def optimize(self):
|
90
|
-
|
160
|
+
self._post_insert()
|
91
161
|
|
92
|
-
def
|
93
|
-
|
162
|
+
def _post_insert(self):
|
163
|
+
log.info(f"{self.name} post insert before optimize")
|
164
|
+
if self.case_config.create_index_after_load:
|
165
|
+
self._drop_index()
|
166
|
+
self._create_index()
|
94
167
|
|
95
168
|
def _drop_index(self):
|
96
169
|
assert self.conn is not None, "Connection is not initialized"
|
97
170
|
assert self.cursor is not None, "Cursor is not initialized"
|
171
|
+
log.info(f"{self.name} client drop index : {self._index_name}")
|
98
172
|
|
99
|
-
|
173
|
+
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
|
174
|
+
index_name=sql.Identifier(self._index_name)
|
175
|
+
)
|
176
|
+
log.debug(drop_index_sql.as_string(self.cursor))
|
177
|
+
self.cursor.execute(drop_index_sql)
|
100
178
|
self.conn.commit()
|
101
179
|
|
102
180
|
def _create_index(self):
|
103
181
|
assert self.conn is not None, "Connection is not initialized"
|
104
182
|
assert self.cursor is not None, "Cursor is not initialized"
|
183
|
+
log.info(f"{self.name} client create index : {self._index_name}")
|
105
184
|
|
106
185
|
index_param = self.case_config.index_param()
|
107
186
|
|
187
|
+
index_create_sql = sql.SQL(
|
188
|
+
"""
|
189
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
190
|
+
USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
|
191
|
+
"""
|
192
|
+
).format(
|
193
|
+
index_name=sql.Identifier(self._index_name),
|
194
|
+
table_name=sql.Identifier(self.table_name),
|
195
|
+
embedding_metric=sql.Identifier(index_param["metric"]),
|
196
|
+
index_options=index_param["options"],
|
197
|
+
)
|
108
198
|
try:
|
109
|
-
|
110
|
-
self.cursor.execute(
|
111
|
-
f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" \
|
112
|
-
USING vectors (embedding {index_param["metric"]}) WITH (options = $${index_param["options"]}$$);'
|
113
|
-
)
|
199
|
+
log.debug(index_create_sql.as_string(self.cursor))
|
200
|
+
self.cursor.execute(index_create_sql)
|
114
201
|
self.conn.commit()
|
115
202
|
except Exception as e:
|
116
203
|
log.warning(
|
117
|
-
f"Failed to create pgvecto.rs
|
204
|
+
f"Failed to create pgvecto.rs index {self._index_name} \
|
205
|
+
at table {self.table_name} error: {e}"
|
118
206
|
)
|
119
207
|
raise e from None
|
120
208
|
|
@@ -122,12 +210,18 @@ class PgVectoRS(VectorDB):
|
|
122
210
|
assert self.conn is not None, "Connection is not initialized"
|
123
211
|
assert self.cursor is not None, "Cursor is not initialized"
|
124
212
|
|
213
|
+
table_create_sql = sql.SQL(
|
214
|
+
"""
|
215
|
+
CREATE TABLE IF NOT EXISTS public.{table_name}
|
216
|
+
(id BIGINT PRIMARY KEY, embedding vector({dim}))
|
217
|
+
"""
|
218
|
+
).format(
|
219
|
+
table_name=sql.Identifier(self.table_name),
|
220
|
+
dim=dim,
|
221
|
+
)
|
125
222
|
try:
|
126
223
|
# create table
|
127
|
-
self.cursor.execute(
|
128
|
-
f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" \
|
129
|
-
(id Integer PRIMARY KEY, embedding vector({dim}));'
|
130
|
-
)
|
224
|
+
self.cursor.execute(table_create_sql)
|
131
225
|
self.conn.commit()
|
132
226
|
except Exception as e:
|
133
227
|
log.warning(
|
@@ -140,7 +234,7 @@ class PgVectoRS(VectorDB):
|
|
140
234
|
embeddings: list[list[float]],
|
141
235
|
metadata: list[int],
|
142
236
|
**kwargs: Any,
|
143
|
-
) ->
|
237
|
+
) -> Tuple[int, Optional[Exception]]:
|
144
238
|
assert self.conn is not None, "Connection is not initialized"
|
145
239
|
assert self.cursor is not None, "Cursor is not initialized"
|
146
240
|
|
@@ -148,19 +242,27 @@ class PgVectoRS(VectorDB):
|
|
148
242
|
assert self.cursor is not None, "Cursor is not initialized"
|
149
243
|
|
150
244
|
try:
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
245
|
+
metadata_arr = np.array(metadata)
|
246
|
+
embeddings_arr = np.array(embeddings)
|
247
|
+
|
248
|
+
with self.cursor.copy(
|
249
|
+
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
|
250
|
+
table_name=sql.Identifier(self.table_name)
|
251
|
+
)
|
252
|
+
) as copy:
|
253
|
+
copy.set_types(["bigint", "vector"])
|
254
|
+
for i, row in enumerate(metadata_arr):
|
255
|
+
copy.write_row((row, embeddings_arr[i]))
|
160
256
|
self.conn.commit()
|
257
|
+
|
258
|
+
if kwargs.get("last_batch"):
|
259
|
+
self._post_insert()
|
260
|
+
|
161
261
|
return len(metadata), None
|
162
262
|
except Exception as e:
|
163
|
-
log.warning(
|
263
|
+
log.warning(
|
264
|
+
f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
|
265
|
+
)
|
164
266
|
return 0, e
|
165
267
|
|
166
268
|
def search_embedding(
|
@@ -173,20 +275,18 @@ class PgVectoRS(VectorDB):
|
|
173
275
|
assert self.conn is not None, "Connection is not initialized"
|
174
276
|
assert self.cursor is not None, "Cursor is not initialized"
|
175
277
|
|
176
|
-
|
278
|
+
q = np.asarray(query)
|
177
279
|
|
178
280
|
if filters:
|
281
|
+
log.debug(self._filtered_search.as_string(self.cursor))
|
179
282
|
gt = filters.get("id")
|
180
|
-
self.cursor.execute(
|
181
|
-
|
182
|
-
{search_param['metrics_op']} '{query}' LIMIT {k}) AS X WHERE id > {gt} ;"
|
283
|
+
result = self.cursor.execute(
|
284
|
+
self._filtered_search, (gt, q, k), prepare=True, binary=True
|
183
285
|
)
|
184
286
|
else:
|
185
|
-
self.
|
186
|
-
|
187
|
-
|
287
|
+
log.debug(self._unfiltered_search.as_string(self.cursor))
|
288
|
+
result = self.cursor.execute(
|
289
|
+
self._unfiltered_search, (q, k), prepare=True, binary=True
|
188
290
|
)
|
189
|
-
self.conn.commit()
|
190
|
-
result = self.cursor.fetchall()
|
191
291
|
|
192
|
-
return [int(i[0]) for i in result]
|
292
|
+
return [int(i[0]) for i in result.fetchall()]
|
@@ -10,6 +10,7 @@ from ....cli.cli import (
|
|
10
10
|
IVFFlatTypedDict,
|
11
11
|
cli,
|
12
12
|
click_parameter_decorators_from_typed_dict,
|
13
|
+
get_custom_case_config,
|
13
14
|
run,
|
14
15
|
)
|
15
16
|
from vectordb_bench.backend.clients import DB
|
@@ -56,7 +57,15 @@ class PgVectorTypedDict(CommonTypedDict):
|
|
56
57
|
required=False,
|
57
58
|
),
|
58
59
|
]
|
59
|
-
|
60
|
+
quantization_type: Annotated[
|
61
|
+
Optional[str],
|
62
|
+
click.option(
|
63
|
+
"--quantization-type",
|
64
|
+
type=click.Choice(["none", "halfvec"]),
|
65
|
+
help="quantization type for vectors",
|
66
|
+
required=False,
|
67
|
+
),
|
68
|
+
]
|
60
69
|
|
61
70
|
class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
|
62
71
|
...
|
@@ -69,6 +78,7 @@ def PgVectorIVFFlat(
|
|
69
78
|
):
|
70
79
|
from .config import PgVectorConfig, PgVectorIVFFlatConfig
|
71
80
|
|
81
|
+
parameters["custom_case"] = get_custom_case_config(parameters)
|
72
82
|
run(
|
73
83
|
db=DB.PgVector,
|
74
84
|
db_config=PgVectorConfig(
|
@@ -79,7 +89,10 @@ def PgVectorIVFFlat(
|
|
79
89
|
db_name=parameters["db_name"],
|
80
90
|
),
|
81
91
|
db_case_config=PgVectorIVFFlatConfig(
|
82
|
-
metric_type=None,
|
92
|
+
metric_type=None,
|
93
|
+
lists=parameters["lists"],
|
94
|
+
probes=parameters["probes"],
|
95
|
+
quantization_type=parameters["quantization_type"],
|
83
96
|
),
|
84
97
|
**parameters,
|
85
98
|
)
|
@@ -96,6 +109,7 @@ def PgVectorHNSW(
|
|
96
109
|
):
|
97
110
|
from .config import PgVectorConfig, PgVectorHNSWConfig
|
98
111
|
|
112
|
+
parameters["custom_case"] = get_custom_case_config(parameters)
|
99
113
|
run(
|
100
114
|
db=DB.PgVector,
|
101
115
|
db_config=PgVectorConfig(
|
@@ -111,6 +125,7 @@ def PgVectorHNSW(
|
|
111
125
|
ef_search=parameters["ef_search"],
|
112
126
|
maintenance_work_mem=parameters["maintenance_work_mem"],
|
113
127
|
max_parallel_workers=parameters["max_parallel_workers"],
|
128
|
+
quantization_type=parameters["quantization_type"],
|
114
129
|
),
|
115
130
|
**parameters,
|
116
131
|
)
|
@@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
|
|
59
59
|
create_index_after_load: bool = True
|
60
60
|
|
61
61
|
def parse_metric(self) -> str:
|
62
|
-
if self.
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
62
|
+
if self.quantization_type == "halfvec":
|
63
|
+
if self.metric_type == MetricType.L2:
|
64
|
+
return "halfvec_l2_ops"
|
65
|
+
elif self.metric_type == MetricType.IP:
|
66
|
+
return "halfvec_ip_ops"
|
67
|
+
return "halfvec_cosine_ops"
|
68
|
+
else:
|
69
|
+
if self.metric_type == MetricType.L2:
|
70
|
+
return "vector_l2_ops"
|
71
|
+
elif self.metric_type == MetricType.IP:
|
72
|
+
return "vector_ip_ops"
|
73
|
+
return "vector_cosine_ops"
|
67
74
|
|
68
75
|
def parse_metric_fun_op(self) -> LiteralString:
|
69
76
|
if self.metric_type == MetricType.L2:
|
@@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
143
150
|
index: IndexType = IndexType.ES_IVFFlat
|
144
151
|
maintenance_work_mem: Optional[str] = None
|
145
152
|
max_parallel_workers: Optional[int] = None
|
153
|
+
quantization_type: Optional[str] = None
|
146
154
|
|
147
155
|
def index_param(self) -> PgVectorIndexParam:
|
148
156
|
index_parameters = {"lists": self.lists}
|
157
|
+
if self.quantization_type == "none":
|
158
|
+
self.quantization_type = None
|
149
159
|
return {
|
150
160
|
"metric": self.parse_metric(),
|
151
161
|
"index_type": self.index.value,
|
@@ -154,6 +164,7 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
|
|
154
164
|
),
|
155
165
|
"maintenance_work_mem": self.maintenance_work_mem,
|
156
166
|
"max_parallel_workers": self.max_parallel_workers,
|
167
|
+
"quantization_type": self.quantization_type,
|
157
168
|
}
|
158
169
|
|
159
170
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
183
194
|
index: IndexType = IndexType.ES_HNSW
|
184
195
|
maintenance_work_mem: Optional[str] = None
|
185
196
|
max_parallel_workers: Optional[int] = None
|
197
|
+
quantization_type: Optional[str] = None
|
186
198
|
|
187
199
|
def index_param(self) -> PgVectorIndexParam:
|
188
200
|
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
|
201
|
+
if self.quantization_type == "none":
|
202
|
+
self.quantization_type = None
|
189
203
|
return {
|
190
204
|
"metric": self.parse_metric(),
|
191
205
|
"index_type": self.index.value,
|
@@ -194,6 +208,7 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
|
|
194
208
|
),
|
195
209
|
"maintenance_work_mem": self.maintenance_work_mem,
|
196
210
|
"max_parallel_workers": self.max_parallel_workers,
|
211
|
+
"quantization_type": self.quantization_type,
|
197
212
|
}
|
198
213
|
|
199
214
|
def search_param(self) -> PgVectorSearchParam:
|
@@ -22,7 +22,7 @@ class PgVector(VectorDB):
|
|
22
22
|
conn: psycopg.Connection[Any] | None = None
|
23
23
|
cursor: psycopg.Cursor[Any] | None = None
|
24
24
|
|
25
|
-
|
25
|
+
_filtered_search: sql.Composed
|
26
26
|
_unfiltered_search: sql.Composed
|
27
27
|
|
28
28
|
def __init__(
|
@@ -112,15 +112,63 @@ class PgVector(VectorDB):
|
|
112
112
|
self.cursor.execute(command)
|
113
113
|
self.conn.commit()
|
114
114
|
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
115
|
+
index_param = self.case_config.index_param()
|
116
|
+
# The following sections assume that the quantization_type value matches the quantization function name
|
117
|
+
if index_param["quantization_type"] != None:
|
118
|
+
self._filtered_search = sql.Composed(
|
119
|
+
[
|
120
|
+
sql.SQL(
|
121
|
+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
|
122
|
+
).format(
|
123
|
+
table_name=sql.Identifier(self.table_name),
|
124
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
125
|
+
dim=sql.Literal(self.dim),
|
126
|
+
),
|
127
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
128
|
+
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
|
129
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
130
|
+
dim=sql.Literal(self.dim),
|
131
|
+
),
|
132
|
+
]
|
133
|
+
)
|
134
|
+
else:
|
135
|
+
self._filtered_search = sql.Composed(
|
136
|
+
[
|
137
|
+
sql.SQL(
|
138
|
+
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
|
139
|
+
).format(table_name=sql.Identifier(self.table_name)),
|
140
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
141
|
+
sql.SQL(" %s::vector LIMIT %s::int"),
|
142
|
+
]
|
143
|
+
)
|
144
|
+
|
145
|
+
if index_param["quantization_type"] != None:
|
146
|
+
self._unfiltered_search = sql.Composed(
|
147
|
+
[
|
148
|
+
sql.SQL(
|
149
|
+
"SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
|
150
|
+
).format(
|
151
|
+
table_name=sql.Identifier(self.table_name),
|
152
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
153
|
+
dim=sql.Literal(self.dim),
|
154
|
+
),
|
155
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
156
|
+
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
|
157
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
158
|
+
dim=sql.Literal(self.dim),
|
159
|
+
),
|
160
|
+
]
|
161
|
+
)
|
162
|
+
else:
|
163
|
+
self._unfiltered_search = sql.Composed(
|
164
|
+
[
|
165
|
+
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
|
166
|
+
sql.Identifier(self.table_name)
|
167
|
+
),
|
168
|
+
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
|
169
|
+
sql.SQL(" %s::vector LIMIT %s::int"),
|
170
|
+
]
|
171
|
+
)
|
124
172
|
|
125
173
|
try:
|
126
174
|
yield
|
@@ -255,17 +303,34 @@ class PgVector(VectorDB):
|
|
255
303
|
else:
|
256
304
|
with_clause = sql.Composed(())
|
257
305
|
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
306
|
+
if index_param["quantization_type"] != None:
|
307
|
+
index_create_sql = sql.SQL(
|
308
|
+
"""
|
309
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
310
|
+
USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
|
311
|
+
"""
|
312
|
+
).format(
|
313
|
+
index_name=sql.Identifier(self._index_name),
|
314
|
+
table_name=sql.Identifier(self.table_name),
|
315
|
+
index_type=sql.Identifier(index_param["index_type"]),
|
316
|
+
# This assumes that the quantization_type value matches the quantization function name
|
317
|
+
quantization_type=sql.SQL(index_param["quantization_type"]),
|
318
|
+
dim=self.dim,
|
319
|
+
embedding_metric=sql.Identifier(index_param["metric"]),
|
320
|
+
)
|
321
|
+
else:
|
322
|
+
index_create_sql = sql.SQL(
|
323
|
+
"""
|
324
|
+
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
|
325
|
+
USING {index_type} (embedding {embedding_metric})
|
326
|
+
"""
|
327
|
+
).format(
|
328
|
+
index_name=sql.Identifier(self._index_name),
|
329
|
+
table_name=sql.Identifier(self.table_name),
|
330
|
+
index_type=sql.Identifier(index_param["index_type"]),
|
331
|
+
embedding_metric=sql.Identifier(index_param["metric"]),
|
332
|
+
)
|
333
|
+
|
269
334
|
index_create_sql_with_with_clause = (
|
270
335
|
index_create_sql + with_clause
|
271
336
|
).join(" ")
|
@@ -342,9 +407,14 @@ class PgVector(VectorDB):
|
|
342
407
|
assert self.cursor is not None, "Cursor is not initialized"
|
343
408
|
|
344
409
|
q = np.asarray(query)
|
345
|
-
|
346
|
-
|
347
|
-
self.
|
348
|
-
|
410
|
+
if filters:
|
411
|
+
gt = filters.get("id")
|
412
|
+
result = self.cursor.execute(
|
413
|
+
self._filtered_search, (gt, q, k), prepare=True, binary=True
|
414
|
+
)
|
415
|
+
else:
|
416
|
+
result = self.cursor.execute(
|
417
|
+
self._unfiltered_search, (q, k), prepare=True, binary=True
|
418
|
+
)
|
349
419
|
|
350
420
|
return [int(i[0]) for i in result.fetchall()]
|