vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +22 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvector/cli.py +17 -2
- vectordb_bench/backend/clients/pgvector/config.py +20 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +7 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/check_results/data.py +13 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
- vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +25 -9
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
|
|
1
|
+
import logging, time
|
2
|
+
from contextlib import contextmanager
|
3
|
+
from typing import Any, Generator, Optional, Tuple, Type
|
4
|
+
from ..api import VectorDB, DBCaseConfig, IndexType
|
5
|
+
from .config import MemoryDBIndexConfig
|
6
|
+
import redis
|
7
|
+
from redis import Redis
|
8
|
+
from redis.cluster import RedisCluster
|
9
|
+
from redis.commands.search.field import TagField, VectorField, NumericField
|
10
|
+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
11
|
+
from redis.commands.search.query import Query
|
12
|
+
import numpy as np
|
13
|
+
|
14
|
+
|
15
|
+
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
16
|
+
# Name of the search index that this module creates and queries.
INDEX_NAME = "index" # Vector Index Name
|
17
|
+
|
18
|
+
class MemoryDB(VectorDB):
    """Benchmark client that talks to MemoryDB through the ``redis`` package.

    Every vector is stored as a hash and indexed by one search index named
    ``INDEX_NAME``.  ``db_config["cmd"]`` ("cluster mode disabled") selects a
    plain ``Redis`` connection; when it is falsy a ``RedisCluster`` connection
    is used instead.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: MemoryDBIndexConfig,
        drop_old: bool = False,
        **kwargs
    ):
        """Connect once, optionally wipe the database, and create the index.

        Args:
            dim: Number of dimensions of the vectors to index.
            db_config: Connection settings; the keys read by this class are
                ``cmd``, ``host``, ``port``, ``ssl``, ``password`` and
                ``ssl_ca_certs``.
            db_case_config: Supplies ``insert_batch_size``, ``index_param()``
                and ``search_param()``.
            drop_old: When true, run FLUSHALL and wait until the primary (and,
                in cluster mode, every replica) reports an empty database.
            **kwargs: Only ``num_rows`` is read; it is stored as ``dbsize``.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = INDEX_NAME
        self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None
        self.insert_batch_size = db_case_config.insert_batch_size
        self.dbsize = kwargs.get("num_rows")

        # The configuration carries the password, so mask it before logging.
        loggable_config = {
            key: ("******" if key == "password" and value else value)
            for key, value in self.db_config.items()
        }
        log.info(f"Establishing connection to: {loggable_config}")
        conn = self.get_client(primary=True)
        try:
            log.info(f"Connection established: {conn}")
            log.info(conn.execute_command("INFO server"))

            if drop_old:
                self._drop_old(conn)

            self.make_index(dim, conn)
        finally:
            # Release the setup connection even when setup fails half-way.
            conn.close()

    def _drop_old(self, conn: redis.Redis) -> None:
        """Flush the database and block until every node reports it empty."""
        try:
            log.info(f"MemoryDB client getting info for: {INDEX_NAME}")
            info = conn.ft(INDEX_NAME).info()
            log.info(f"Index info: {info}")
        except redis.exceptions.ResponseError as e:
            # A missing index is not fatal here: the flush below still runs.
            log.error(e)
            log.info(f"MemoryDB client drop_old collection: {self.collection_name}")

        log.info("Executing FLUSHALL")
        conn.flushall()

        # Since the default behaviour of FLUSHALL is asynchronous, wait for db to be empty
        self.wait_until(self.wait_for_empty_db, 3, "", conn)
        if not self.db_config["cmd"]:
            for rc, host in self.get_client(replicas=True):
                self.wait_until(self.wait_for_empty_db, 3, "", rc)
                log.debug(f"Flushall done in the host: {host}")
                rc.close()

    def make_index(self, vector_dimensions: int, conn: redis.Redis):
        """Create the vector index unless ``FT.INFO`` already succeeds for it.

        Args:
            vector_dimensions: Value written to the index's ``DIM`` parameter.
            conn: Connection used both for the existence check and the create.
        """
        try:
            # check to see if index exists
            conn.ft(INDEX_NAME).info()
        except Exception as e:
            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
            index_param = self.case_config.index_param()
            search_param = self.case_config.search_param()
            vector_parameters = {
                "TYPE": "FLOAT32",
                "DIM": vector_dimensions,  # Number of Vector Dimensions
                "DISTANCE_METRIC": index_param["metric"],  # Vector Search Distance Metric
            }
            # Optional tuning knobs are only sent when the case config sets them.
            if index_param["m"]:
                vector_parameters["M"] = index_param["m"]
            if index_param["ef_construction"]:
                vector_parameters["EF_CONSTRUCTION"] = index_param["ef_construction"]
            if search_param["ef_runtime"]:
                vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]

            schema = (
                TagField("id"),
                NumericField("metadata"),
                VectorField("vector", "HNSW", vector_parameters),  # Vector Field Name
            )

            # ``IndexType`` is the name bound by the file's last import of it.
            definition = IndexDefinition(index_type=IndexType.HASH)
            conn.ft(INDEX_NAME).create_index(schema, definition=definition)

    def get_client(self, **kwargs):
        """Open a connection according to the ``cmd`` flag.

        CMD stands for Cluster Mode Disabled and is a "mode".

        Keyword Args:
            primary: Cluster mode only. Return the connection of the first
                primary node instead of the cluster client.
            replicas: Cluster mode only. Return a list of
                ``(connection, host)`` pairs, one per replica node.

        Returns:
            A ``Redis`` client when ``cmd`` is set; otherwise the cluster
            client, a single node connection, or the replica list as above.
        """
        if not self.db_config["cmd"]:
            # Cluster mode enabled
            client = RedisCluster(
                host=self.db_config["host"],
                port=self.db_config["port"],
                ssl=self.db_config["ssl"],
                password=self.db_config["password"],
                ssl_ca_certs=self.db_config["ssl_ca_certs"],
                ssl_cert_reqs=None,
            )

            # Ping all nodes to create a connection
            client.execute_command("PING", target_nodes=RedisCluster.ALL_NODES)
            replicas = client.get_replicas()

            if len(replicas) > 0:
                # FT.SEARCH is a keyless command, use READONLY for replica connections
                client.execute_command("READONLY", target_nodes=RedisCluster.REPLICAS)

            if kwargs.get("primary", False):
                client = client.get_primaries()[0].redis_connection

            if kwargs.get("replicas", False):
                # Return client and host name for each replica
                return [(c.redis_connection, c.host) for c in replicas]

        else:
            client = Redis(
                host=self.db_config["host"],
                port=self.db_config["port"],
                db=0,
                ssl=self.db_config["ssl"],
                password=self.db_config["password"],
                ssl_ca_certs=self.db_config["ssl_ca_certs"],
                ssl_cert_reqs=None,
            )
            client.execute_command("PING")
        return client

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """Create and destroy the connection used for inserts and searches.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn = self.get_client()
        try:
            search_param = self.case_config.search_param()
            # Stored on the instance; nothing else in this class reads it.
            if search_param["ef_runtime"]:
                self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}'
            else:
                self.ef_runtime_str = ""
            yield
        finally:
            # Close the connection even when the ``with`` body raised.
            self.conn.close()
            self.conn = None

    def ready_to_load(self) -> bool:
        """No-op: the body is empty, so this currently returns ``None``."""
        pass

    def optimize(self) -> None:
        """Block until background indexing has finished."""
        self._post_insert()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        """Insert embeddings into the database.
        Should call self.init() first.

        Each embedding is written as a hash keyed by its ``metadata`` entry and
        the pipeline is flushed every ``insert_batch_size`` rows.

        Returns:
            ``(number_of_rows, None)`` on success (``(0, None)`` for an empty
            batch), or ``(0, exception)`` when anything failed.
        """
        inserted = 0
        try:
            with self.conn.pipeline(transaction=False) as pipe:
                for i, embedding in enumerate(embeddings):
                    embedding = np.array(embedding).astype(np.float32)
                    pipe.hset(metadata[i], mapping={
                        "id": str(metadata[i]),
                        "metadata": metadata[i],
                        "vector": embedding.tobytes(),
                    })
                    inserted = i + 1
                    # Execute the pipe so we don't keep too much in memory at once
                    if inserted % self.insert_batch_size == 0:
                        pipe.execute()

                # Send whatever is left over from the last partial batch.
                pipe.execute()
        except Exception as e:
            return 0, e

        return inserted, None

    def _post_insert(self):
        """Wait for indexing to finish"""
        client = self.get_client(primary=True)
        try:
            log.info("Waiting for background indexing to finish")
            self.wait_until(self.wait_for_no_activity, 5, "", client)
        finally:
            client.close()

        if not self.db_config["cmd"]:
            for rc, host_name in self.get_client(replicas=True):
                self.wait_until(self.wait_for_no_activity, 5, "", rc)
                log.debug(f"Background indexing completed in the host: {host_name}")
                rc.close()

    def wait_until(
        self, condition, interval=5, message="Operation took too long", *args
    ):
        """Poll ``condition(*args)`` every ``interval`` seconds until it is truthy.

        There is no timeout, and ``message`` is accepted but never used.
        """
        while not condition(*args):
            time.sleep(interval)

    def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
        """Return True when the node reports no background indexing activity."""
        return (
            client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
        )

    def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
        """Return True when ``DBSIZE`` reports zero keys."""
        return client.execute_command("DBSIZE") == 0

    @staticmethod
    def _filter_expression(filters: dict | None) -> str:
        """Translate benchmark filters into the pre-filter part of a KNN query.

        Filters look like ``{'metadata': '>=10000', 'id': 10000}``; either key
        may be missing.  ``*`` (match everything) is returned when there is
        nothing to filter on.
        """
        if not filters:
            return "*"

        clauses = []
        metadata_filter = filters.get("metadata")
        # Removing '>=' from the metadata filter: '>=10000'
        lower_bound = metadata_filter[2:] if metadata_filter else ""
        if lower_bound:
            # greater than or equal to the metadata value
            clauses.append(f"@metadata:[{lower_bound} +inf]")

        id_value = filters.get("id")
        if id_value is not None:
            # exact match on the id tag; doubled braces emit literal '{' and '}'
            clauses.append(f"@id:{{{id_value}}}")

        if not clauses:
            return "*"
        if len(clauses) == 1:
            return clauses[0]
        return f"({' '.join(clauses)})"

    def search_embedding(
        self,
        query: list[float],
        k: int = 10,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Return the ids of the ``k`` nearest neighbours of ``query``.

        Must be called inside ``self.init()``.  ``timeout`` is accepted but not
        used.

        Args:
            query: Query vector; it is sent as float32 bytes.
            k: Number of neighbours to request.
            filters: Optional ``{'metadata': '>=N', 'id': N}`` style filter.
        """
        assert self.conn is not None

        query_vector = np.array(query).astype(np.float32).tobytes()
        filter_expr = self._filter_expression(filters)
        query_obj = (
            Query(f"{filter_expr}=>[KNN {k} @vector $vec]")
            .return_fields("id")
            .paging(0, k)
        )
        query_params = {"vec": query_vector}

        res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
        return [int(doc["id"]) for doc in res.docs]
|
@@ -0,0 +1,154 @@
|
|
1
|
+
from typing import Annotated, Optional, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
import os
|
5
|
+
from pydantic import SecretStr
|
6
|
+
|
7
|
+
from ....cli.cli import (
|
8
|
+
CommonTypedDict,
|
9
|
+
HNSWFlavor1,
|
10
|
+
IVFFlatTypedDict,
|
11
|
+
cli,
|
12
|
+
click_parameter_decorators_from_typed_dict,
|
13
|
+
run,
|
14
|
+
)
|
15
|
+
from vectordb_bench.backend.clients import DB
|
16
|
+
|
17
|
+
|
18
|
+
class PgVectoRSTypedDict(CommonTypedDict):
    """Command-line options shared by the pgvecto.rs benchmark commands.

    Each field pairs the value's type with the ``click.option`` that produces
    it.  Options declared with ``required=False`` and no default are typed
    ``Optional`` because their value is ``None`` when the flag is omitted.
    """

    user_name: Annotated[
        str, click.option("--user-name", type=str, help="Db username", required=True)
    ]
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            help="Postgres database password",
            # Read the environment lazily, when the command is invoked.
            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
            show_default="$POSTGRES_PASSWORD",
        ),
    ]

    host: Annotated[
        str, click.option("--host", type=str, help="Db host", required=True)
    ]
    db_name: Annotated[
        str, click.option("--db-name", type=str, help="Db name", required=True)
    ]
    max_parallel_workers: Annotated[
        Optional[int],
        click.option(
            "--max-parallel-workers",
            type=int,
            help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
            required=False,
        ),
    ]
    quantization_type: Annotated[
        Optional[str],
        click.option(
            "--quantization-type",
            type=click.Choice(["trivial", "scalar", "product"]),
            help="quantization type for vectors",
            required=False,
        ),
    ]
    quantization_ratio: Annotated[
        Optional[str],
        click.option(
            "--quantization-ratio",
            type=click.Choice(["x4", "x8", "x16", "x32", "x64"]),
            help="quantization ratio(for product quantization)",
            required=False,
        ),
    ]
|
66
|
+
|
67
|
+
|
68
|
+
class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Options accepted by the ``PgVectoRSFlat`` command."""


# No docstring on the command on purpose: click would show it as help text.
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict)
def PgVectoRSFlat(
    **parameters: Unpack[PgVectoRSFlatTypedDict],
):
    # Imported here rather than at module level, as in the sibling commands.
    from .config import PgVectoRSConfig, PgVectoRSFLATConfig

    run(
        db=DB.PgVectoRS,
        db_config=PgVectoRSConfig(
            db_label=parameters["db_label"],
            # ``user_name`` is declared as a plain ``str`` on PgVectoRSConfig,
            # so it is passed as-is instead of being wrapped in SecretStr.
            user_name=parameters["user_name"],
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            db_name=parameters["db_name"],
        ),
        db_case_config=PgVectoRSFLATConfig(
            max_parallel_workers=parameters["max_parallel_workers"],
            quantization_type=parameters["quantization_type"],
            quantization_ratio=parameters["quantization_ratio"],
        ),
        **parameters,
    )
|
94
|
+
|
95
|
+
|
96
|
+
class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Options accepted by the ``PgVectoRSIVFFlat`` command."""


# No docstring on the command on purpose: click would show it as help text.
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict)
def PgVectoRSIVFFlat(
    **parameters: Unpack[PgVectoRSIVFFlatTypedDict],
):
    # Imported here rather than at module level, as in the sibling commands.
    from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig

    run(
        db=DB.PgVectoRS,
        db_config=PgVectoRSConfig(
            db_label=parameters["db_label"],
            # ``user_name`` is declared as a plain ``str`` on PgVectoRSConfig,
            # so it is passed as-is instead of being wrapped in SecretStr.
            user_name=parameters["user_name"],
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            db_name=parameters["db_name"],
        ),
        db_case_config=PgVectoRSIVFFlatConfig(
            max_parallel_workers=parameters["max_parallel_workers"],
            quantization_type=parameters["quantization_type"],
            quantization_ratio=parameters["quantization_ratio"],
            probes=parameters["probes"],
            lists=parameters["lists"],
        ),
        **parameters,
    )
|
124
|
+
|
125
|
+
|
126
|
+
class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1):
    """Options accepted by the ``PgVectoRSHNSW`` command."""


# No docstring on the command on purpose: click would show it as help text.
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict)
def PgVectoRSHNSW(
    **parameters: Unpack[PgVectoRSHNSWTypedDict],
):
    # Imported here rather than at module level, as in the sibling commands.
    from .config import PgVectoRSConfig, PgVectoRSHNSWConfig

    run(
        db=DB.PgVectoRS,
        db_config=PgVectoRSConfig(
            db_label=parameters["db_label"],
            # ``user_name`` is declared as a plain ``str`` on PgVectoRSConfig,
            # so it is passed as-is instead of being wrapped in SecretStr.
            user_name=parameters["user_name"],
            password=SecretStr(parameters["password"]),
            host=parameters["host"],
            db_name=parameters["db_name"],
        ),
        db_case_config=PgVectoRSHNSWConfig(
            max_parallel_workers=parameters["max_parallel_workers"],
            quantization_type=parameters["quantization_type"],
            quantization_ratio=parameters["quantization_ratio"],
            m=parameters["m"],
            ef_construction=parameters["ef_construction"],
            ef_search=parameters["ef_search"],
        ),
        **parameters,
    )
|
@@ -1,30 +1,53 @@
|
|
1
|
-
from
|
1
|
+
from abc import abstractmethod
|
2
|
+
from typing import TypedDict
|
3
|
+
|
2
4
|
from pydantic import BaseModel, SecretStr
|
3
|
-
from
|
5
|
+
from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
|
6
|
+
from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
|
7
|
+
|
8
|
+
from ..api import DBConfig, DBCaseConfig, IndexType, MetricType
|
4
9
|
|
5
10
|
# %-style template for a PostgreSQL URL with four string slots
# (presumably user, password, host and database name - confirm against callers).
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
|
6
11
|
|
7
12
|
|
13
|
+
class PgVectorRSConfigDict(TypedDict):
    """Connection settings in the shape psycopg expects.

    Every key is handed to psycopg unchanged as a keyword argument, so each
    name has to be spelled exactly as the psycopg API spells it.
    """

    user: str
    password: str
    host: str
    port: int
    dbname: str
|
22
|
+
|
23
|
+
|
8
24
|
class PgVectoRSConfig(DBConfig):
    """Connection settings for a pgvecto.rs (PostgreSQL) benchmark target."""

    user_name: str = "postgres"
    password: SecretStr
    host: str = "localhost"
    port: int = 5432
    db_name: str

    def to_dict(self) -> dict:
        """Return the settings as a dict keyed by host, port, dbname, user and password.

        The password is unwrapped from its ``SecretStr``, so the result holds
        it in clear text.
        """
        return dict(
            host=self.host,
            port=self.port,
            dbname=self.db_name,
            user=self.user_name,
            password=self.password.get_secret_value(),
        )
|
25
41
|
|
42
|
+
|
26
43
|
class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
|
27
44
|
metric_type: MetricType | None = None
|
45
|
+
create_index_before_load: bool = False
|
46
|
+
create_index_after_load: bool = True
|
47
|
+
|
48
|
+
max_parallel_workers: int | None = None
|
49
|
+
quantization_type: QuantizationType | None = None
|
50
|
+
quantization_ratio: QuantizationRatio | None = None
|
28
51
|
|
29
52
|
def parse_metric(self) -> str:
|
30
53
|
if self.metric_type == MetricType.L2:
|
@@ -40,88 +63,100 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
|
|
40
63
|
return "<#>"
|
41
64
|
return "<=>"
|
42
65
|
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
def parse_quantization(self) -> str:
|
48
|
-
if self.quantizationType == "trivial":
|
49
|
-
return "quantization = { trivial = { } }"
|
50
|
-
elif self.quantizationType == "scalar":
|
51
|
-
return "quantization = { scalar = { } }"
|
52
|
-
else:
|
53
|
-
return f'quantization = {{ product = {{ ratio = "{self.quantizationRatio}" }} }}'
|
54
|
-
|
66
|
+
def search_param(self) -> dict:
|
67
|
+
return {
|
68
|
+
"metric_fun_op": self.parse_metric_fun_op(),
|
69
|
+
}
|
55
70
|
|
56
|
-
|
57
|
-
|
58
|
-
efConstruction: int
|
59
|
-
index: IndexType = IndexType.HNSW
|
71
|
+
@abstractmethod
|
72
|
+
def index_param(self) -> dict[str, str]: ...
|
60
73
|
|
61
|
-
|
62
|
-
|
63
|
-
[indexing.hnsw]
|
64
|
-
m = {self.M}
|
65
|
-
ef_construction = {self.efConstruction}
|
66
|
-
{self.parse_quantization()}
|
67
|
-
"""
|
68
|
-
return {"options": options, "metric": self.parse_metric()}
|
74
|
+
@abstractmethod
|
75
|
+
def session_param(self) -> dict[str, str | int]: ...
|
69
76
|
|
70
|
-
def search_param(self) -> dict:
|
71
|
-
return {"metrics_op": self.parse_metric_fun_op()}
|
72
77
|
|
78
|
+
class PgVectoRSHNSWConfig(PgVectoRSIndexConfig):
    """Case configuration for a pgvecto.rs HNSW index."""

    index: IndexType = IndexType.HNSW
    m: int | None = None
    # Explicit default, matching the other optional tuning fields.
    ef_search: int | None = None
    ef_construction: int | None = None

    def index_param(self) -> dict[str, str]:
        """Return the serialized index options and the metric.

        The result has two keys: ``options`` (the dumped ``IndexOption``) and
        ``metric`` (the value of ``parse_metric()``).
        """
        # Quantization is only configured when a type was chosen.
        quantization = (
            None
            if self.quantization_type is None
            else Quantization(
                typ=self.quantization_type, ratio=self.quantization_ratio
            )
        )

        option = IndexOption(
            index=Hnsw(
                m=self.m,
                ef_construction=self.ef_construction,
                quantization=quantization,
            ),
            threads=self.max_parallel_workers,
        )
        return {"options": option.dumps(), "metric": self.parse_metric()}

    def session_param(self) -> dict[str, str | int]:
        """Return the session setting for ``ef_search``, or ``{}`` when unset."""
        if self.ef_search is None:
            return {}
        return {"vectors.hnsw_ef_search": str(self.ef_search)}
|
107
|
+
|
108
|
+
|
109
|
+
class PgVectoRSIVFFlatConfig(PgVectoRSIndexConfig):
    """Case configuration for a pgvecto.rs IVF index."""

    index: IndexType = IndexType.IVFFlat
    probes: int | None
    lists: int | None

    def index_param(self) -> dict[str, str]:
        """Return the serialized index options and the metric.

        The result has two keys: ``options`` (the dumped ``IndexOption``) and
        ``metric`` (the value of ``parse_metric()``).
        """
        # Quantization is only configured when a type was chosen.
        quantization = (
            None
            if self.quantization_type is None
            else Quantization(
                typ=self.quantization_type, ratio=self.quantization_ratio
            )
        )

        ivf_index = Ivf(nlist=self.lists, quantization=quantization)
        option = IndexOption(index=ivf_index, threads=self.max_parallel_workers)
        return {"options": option.dumps(), "metric": self.parse_metric()}

    def session_param(self) -> dict[str, str | int]:
        """Return the session setting for ``probes``, or ``{}`` when unset."""
        if self.probes is None:
            return {}
        return {"vectors.ivf_nprobe": str(self.probes)}
|
104
133
|
|
105
|
-
def search_param(self) -> dict:
|
106
|
-
return {"metrics_op": self.parse_metric_fun_op()}
|
107
134
|
|
108
|
-
class
|
135
|
+
class PgVectoRSFLATConfig(PgVectoRSIndexConfig):
    """Case configuration for a pgvecto.rs flat index."""

    index: IndexType = IndexType.Flat

    def index_param(self) -> dict[str, str]:
        """Return the serialized index options and the metric.

        The result has two keys: ``options`` (the dumped ``IndexOption``) and
        ``metric`` (the value of ``parse_metric()``).
        """
        # Quantization is only configured when a type was chosen.
        quantization = (
            None
            if self.quantization_type is None
            else Quantization(
                typ=self.quantization_type, ratio=self.quantization_ratio
            )
        )

        flat_index = Flat(quantization=quantization)
        option = IndexOption(index=flat_index, threads=self.max_parallel_workers)
        return {"options": option.dumps(), "metric": self.parse_metric()}

    def session_param(self) -> dict[str, str | int]:
        """Return an empty dict: this class sets no session parameters."""
        return {}
|
120
156
|
|
121
157
|
|
122
158
|
# Lookup table from index type to the case-config class defined for it.
_pgvecto_rs_case_config: dict = {
    IndexType.HNSW: PgVectoRSHNSWConfig,
    IndexType.IVFFlat: PgVectoRSIVFFlatConfig,
    IndexType.Flat: PgVectoRSFLATConfig,
}
|