vectordb-bench 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +22 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
- vectordb_bench/cli/vectordbbench.py +5 -0
- vectordb_bench/frontend/components/check_results/data.py +13 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
- vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +173 -9
- vectordb_bench/models.py +18 -6
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +11 -3
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +23 -17
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
@@ -30,7 +30,9 @@ class DB(Enum):
|
|
30
30
|
WeaviateCloud = "WeaviateCloud"
|
31
31
|
PgVector = "PgVector"
|
32
32
|
PgVectoRS = "PgVectoRS"
|
33
|
+
PgVectorScale = "PgVectorScale"
|
33
34
|
Redis = "Redis"
|
35
|
+
MemoryDB = "MemoryDB"
|
34
36
|
Chroma = "Chroma"
|
35
37
|
AWSOpenSearch = "OpenSearch"
|
36
38
|
Test = "test"
|
@@ -70,10 +72,18 @@ class DB(Enum):
|
|
70
72
|
if self == DB.PgVectoRS:
|
71
73
|
from .pgvecto_rs.pgvecto_rs import PgVectoRS
|
72
74
|
return PgVectoRS
|
75
|
+
|
76
|
+
if self == DB.PgVectorScale:
|
77
|
+
from .pgvectorscale.pgvectorscale import PgVectorScale
|
78
|
+
return PgVectorScale
|
73
79
|
|
74
80
|
if self == DB.Redis:
|
75
81
|
from .redis.redis import Redis
|
76
82
|
return Redis
|
83
|
+
|
84
|
+
if self == DB.MemoryDB:
|
85
|
+
from .memorydb.memorydb import MemoryDB
|
86
|
+
return MemoryDB
|
77
87
|
|
78
88
|
if self == DB.Chroma:
|
79
89
|
from .chroma.chroma import ChromaClient
|
@@ -118,9 +128,17 @@ class DB(Enum):
|
|
118
128
|
from .pgvecto_rs.config import PgVectoRSConfig
|
119
129
|
return PgVectoRSConfig
|
120
130
|
|
131
|
+
if self == DB.PgVectorScale:
|
132
|
+
from .pgvectorscale.config import PgVectorScaleConfig
|
133
|
+
return PgVectorScaleConfig
|
134
|
+
|
121
135
|
if self == DB.Redis:
|
122
136
|
from .redis.config import RedisConfig
|
123
137
|
return RedisConfig
|
138
|
+
|
139
|
+
if self == DB.MemoryDB:
|
140
|
+
from .memorydb.config import MemoryDBConfig
|
141
|
+
return MemoryDBConfig
|
124
142
|
|
125
143
|
if self == DB.Chroma:
|
126
144
|
from .chroma.config import ChromaConfig
|
@@ -163,6 +181,10 @@ class DB(Enum):
|
|
163
181
|
from .aws_opensearch.config import AWSOpenSearchIndexConfig
|
164
182
|
return AWSOpenSearchIndexConfig
|
165
183
|
|
184
|
+
if self == DB.PgVectorScale:
|
185
|
+
from .pgvectorscale.config import _pgvectorscale_case_config
|
186
|
+
return _pgvectorscale_case_config.get(index_type)
|
187
|
+
|
166
188
|
# DB.Pinecone, DB.Chroma, DB.Redis
|
167
189
|
return EmptyDBCaseConfig
|
168
190
|
|
@@ -15,6 +15,7 @@ class MetricType(str, Enum):
|
|
15
15
|
class IndexType(str, Enum):
|
16
16
|
HNSW = "HNSW"
|
17
17
|
DISKANN = "DISKANN"
|
18
|
+
STREAMING_DISKANN = "DISKANN"
|
18
19
|
IVFFlat = "IVF_FLAT"
|
19
20
|
IVFSQ8 = "IVF_SQ8"
|
20
21
|
Flat = "FLAT"
|
@@ -38,6 +39,22 @@ class DBConfig(ABC, BaseModel):
|
|
38
39
|
"""
|
39
40
|
|
40
41
|
db_label: str = ""
|
42
|
+
version: str = ""
|
43
|
+
note: str = ""
|
44
|
+
|
45
|
+
@staticmethod
|
46
|
+
def common_short_configs() -> list[str]:
|
47
|
+
"""
|
48
|
+
short input, such as `db_label`, `version`
|
49
|
+
"""
|
50
|
+
return ["version", "db_label"]
|
51
|
+
|
52
|
+
@staticmethod
|
53
|
+
def common_long_configs() -> list[str]:
|
54
|
+
"""
|
55
|
+
long input, such as `note`
|
56
|
+
"""
|
57
|
+
return ["note"]
|
41
58
|
|
42
59
|
@abstractmethod
|
43
60
|
def to_dict(self) -> dict:
|
@@ -45,7 +62,10 @@ class DBConfig(ABC, BaseModel):
|
|
45
62
|
|
46
63
|
@validator("*")
|
47
64
|
def not_empty_field(cls, v, field):
|
48
|
-
if
|
65
|
+
if (
|
66
|
+
field.name in cls.common_short_configs()
|
67
|
+
or field.name in cls.common_long_configs()
|
68
|
+
):
|
49
69
|
return v
|
50
70
|
if not v and isinstance(v, (str, SecretStr)):
|
51
71
|
raise ValueError("Empty string!")
|
@@ -0,0 +1,88 @@
|
|
1
|
+
from typing import Annotated, TypedDict, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
from pydantic import SecretStr
|
5
|
+
|
6
|
+
from ....cli.cli import (
|
7
|
+
CommonTypedDict,
|
8
|
+
HNSWFlavor2,
|
9
|
+
cli,
|
10
|
+
click_parameter_decorators_from_typed_dict,
|
11
|
+
run,
|
12
|
+
)
|
13
|
+
from .. import DB
|
14
|
+
|
15
|
+
|
16
|
+
class MemoryDBTypedDict(TypedDict):
    # Connection / loading options for the MemoryDB CLI command. The click option
    # attached to each field is applied to the command through
    # click_parameter_decorators_from_typed_dict.
    host: Annotated[
        str, click.option("--host", required=True, type=str, help="Db host")
    ]
    # No default: the value is None when --password is omitted.
    password: Annotated[str, click.option("--password", help="Db password", type=str)]
    port: Annotated[int, click.option("--port", default=6379, type=int, help="Db Port")]
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            default=True,
            is_flag=True,
            show_default=True,
            help="Enable or disable SSL for MemoryDB",
        ),
    ]
    ssl_ca_certs: Annotated[
        str,
        click.option(
            "--ssl-ca-certs",
            help="Path to certificate authority file to use for SSL",
            show_default=True,
        ),
    ]
    # True selects a plain single-node Redis client instead of RedisCluster.
    cmd: Annotated[
        bool,
        click.option(
            "--cmd",
            default=False,
            is_flag=True,
            show_default=True,
            help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
        ),
    ]
    # Number of HSETs queued in a pipeline before it is executed.
    insert_batch_size: Annotated[
        int,
        click.option(
            "--insert-batch-size",
            default=10,
            type=int,
            help="Batch size for inserting data. Adjust this as needed, but don't make it too big",
        ),
    ]
|
59
|
+
|
60
|
+
|
61
|
+
class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
    """All options of the MemoryDB HNSW command: common, connection and HNSW ones."""
|
63
|
+
|
64
|
+
|
65
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict)
def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
    # Benchmark MemoryDB with an HNSW index. Comments (not a docstring) are used
    # here on purpose: click would turn a docstring into the command's help text.
    from .config import MemoryDBConfig, MemoryDBHNSWConfig

    raw_password = parameters["password"]
    connection = MemoryDBConfig(
        db_label=parameters["db_label"],
        # An empty/missing password means "no AUTH".
        password=SecretStr(raw_password) if raw_password else None,
        host=SecretStr(parameters["host"]),
        port=parameters["port"],
        ssl=parameters["ssl"],
        ssl_ca_certs=parameters["ssl_ca_certs"],
        cmd=parameters["cmd"],
    )
    hnsw_settings = MemoryDBHNSWConfig(
        M=parameters["m"],
        ef_construction=parameters["ef_construction"],
        ef_runtime=parameters["ef_runtime"],
        insert_batch_size=parameters["insert_batch_size"],
    )
    run(
        db=DB.MemoryDB,
        db_config=connection,
        db_case_config=hnsw_settings,
        **parameters,
    )
|
@@ -0,0 +1,54 @@
|
|
1
|
+
from pydantic import BaseModel, SecretStr
|
2
|
+
|
3
|
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
|
4
|
+
|
5
|
+
|
6
|
+
class MemoryDBConfig(DBConfig):
    """Connection settings for a MemoryDB endpoint.

    ``host`` and ``password`` are held as ``SecretStr`` and are only unwrapped
    by ``to_dict``.
    """

    host: SecretStr
    password: SecretStr | None = None
    port: int | None = None
    ssl: bool | None = None
    cmd: bool | None = None
    ssl_ca_certs: str | None = None

    def to_dict(self) -> dict:
        """Return plain connection kwargs with the secrets unwrapped.

        ``password`` is ``None`` when no (or an empty) password is configured.
        """
        plain_host = self.host.get_secret_value()
        plain_password = None
        if self.password:
            plain_password = self.password.get_secret_value()
        return {
            "host": plain_host,
            "port": self.port,
            "password": plain_password,
            "ssl": self.ssl,
            "cmd": self.cmd,
            "ssl_ca_certs": self.ssl_ca_certs,
        }
|
23
|
+
|
24
|
+
|
25
|
+
class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
|
26
|
+
metric_type: MetricType | None = None
|
27
|
+
insert_batch_size: int | None = None
|
28
|
+
|
29
|
+
def parse_metric(self) -> str:
|
30
|
+
if self.metric_type == MetricType.L2:
|
31
|
+
return "l2"
|
32
|
+
elif self.metric_type == MetricType.IP:
|
33
|
+
return "ip"
|
34
|
+
return "cosine"
|
35
|
+
|
36
|
+
|
37
|
+
class MemoryDBHNSWConfig(MemoryDBIndexConfig):
    """HNSW build and query parameters for a MemoryDB vector index."""

    M: int | None = 16
    ef_construction: int | None = 64
    ef_runtime: int | None = 10
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Build-time parameters: metric, index type name, ``m`` and ``ef_construction``."""
        return dict(
            metric=self.parse_metric(),
            index_type=self.index.value,
            m=self.M,
            ef_construction=self.ef_construction,
        )

    def search_param(self) -> dict:
        """Query-time parameters; only ``ef_runtime``."""
        return dict(ef_runtime=self.ef_runtime)
|
@@ -0,0 +1,254 @@
|
|
1
|
+
import logging, time
|
2
|
+
from contextlib import contextmanager
|
3
|
+
from typing import Any, Generator, Optional, Tuple, Type
|
4
|
+
from ..api import VectorDB, DBCaseConfig, IndexType
|
5
|
+
from .config import MemoryDBIndexConfig
|
6
|
+
import redis
|
7
|
+
from redis import Redis
|
8
|
+
from redis.cluster import RedisCluster
|
9
|
+
from redis.commands.search.field import TagField, VectorField, NumericField
|
10
|
+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
11
|
+
from redis.commands.search.query import Query
|
12
|
+
import numpy as np
|
13
|
+
|
14
|
+
|
15
|
+
# Module-level logger.
log = logging.getLogger(__name__)

# Name of the FT vector index created and queried by MemoryDB.
INDEX_NAME = "index"
|
17
|
+
|
18
|
+
class MemoryDB(VectorDB):
    """Vector-search client for Amazon MemoryDB (Redis-compatible).

    Vectors are stored as hashes with ``id``, ``metadata`` and ``vector`` fields
    and searched through the ``INDEX_NAME`` FT index with KNN queries. When
    ``db_config["cmd"]`` (Cluster Mode Disabled) is falsy the endpoint is
    reached through ``RedisCluster``; otherwise through a single ``Redis``
    connection.
    """

    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: MemoryDBIndexConfig,
        drop_old: bool = False,
        **kwargs
    ):
        """Connect, optionally flush the database, and ensure the index exists.

        Args:
            dim: vector dimensionality, used if the index has to be created.
            db_config: connection settings; this class reads ``host``, ``port``,
                ``password``, ``ssl``, ``ssl_ca_certs`` and ``cmd``.
            db_case_config: index/search parameters and ``insert_batch_size``.
            drop_old: when True, FLUSHALL the database and wait until the
                primary (and, in cluster mode, every replica) reports DBSIZE 0.
            **kwargs: ``num_rows`` is kept as ``self.dbsize``.
        """
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = INDEX_NAME
        self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None
        self.insert_batch_size = db_case_config.insert_batch_size
        self.dbsize = kwargs.get("num_rows")

        # db_config carries the plain-text password; never write it to the log.
        log.info(f"Establishing connection to: {self._redacted_db_config()}")
        conn = self.get_client(primary=True)
        try:
            log.info(f"Connection established: {conn}")
            log.info(conn.execute_command("INFO server"))

            if drop_old:
                self._drop_old(conn)

            self.make_index(dim, conn)
        finally:
            # Connections used for inserts/searches are opened in init().
            conn.close()

    def _redacted_db_config(self) -> dict:
        """Return a copy of ``db_config`` that is safe to log (password masked)."""
        redacted = dict(self.db_config)
        if redacted.get("password"):
            redacted["password"] = "******"
        return redacted

    def _drop_old(self, conn) -> None:
        """FLUSHALL the database and block until it (and every replica) is empty.

        The FT.INFO lookup only logs the current index state; the flush runs
        whether or not the index exists.
        """
        try:
            log.info(f"MemoryDB client getting info for: {INDEX_NAME}")
            info = conn.ft(INDEX_NAME).info()
            log.info(f"Index info: {info}")
        except redis.exceptions.ResponseError as e:
            # Presumably the index does not exist yet; flushing is still fine.
            log.error(e)
        log.info(f"MemoryDB client drop_old collection: {self.collection_name}")

        log.info("Executing FLUSHALL")
        conn.flushall()

        # Since the default behaviour of FLUSHALL is asynchronous, wait for db to be empty
        self.wait_until(self.wait_for_empty_db, 3, "", conn)
        if not self.db_config["cmd"]:
            for rc, host in self.get_client(replicas=True):
                try:
                    self.wait_until(self.wait_for_empty_db, 3, "", rc)
                    log.debug(f"Flushall done in the host: {host}")
                finally:
                    rc.close()

    def make_index(self, vector_dimensions: int, conn: redis.Redis):
        """Create the HNSW FT index ``INDEX_NAME`` unless it already exists.

        The schema indexes ``id`` as TAG, ``metadata`` as NUMERIC and ``vector``
        as an HNSW FLOAT32 vector field over hash keys. ``M``,
        ``EF_CONSTRUCTION`` and ``EF_RUNTIME`` are only passed when truthy.
        """
        try:
            # check to see if index exists; any error means "create it"
            conn.ft(INDEX_NAME).info()
        except Exception as e:
            log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
        else:
            return

        index_param = self.case_config.index_param()
        search_param = self.case_config.search_param()
        vector_parameters = {  # Vector Index Type: FLAT or HNSW
            "TYPE": "FLOAT32",
            "DIM": vector_dimensions,  # Number of Vector Dimensions
            "DISTANCE_METRIC": index_param["metric"],  # Vector Search Distance Metric
        }
        if index_param["m"]:
            vector_parameters["M"] = index_param["m"]
        if index_param["ef_construction"]:
            vector_parameters["EF_CONSTRUCTION"] = index_param["ef_construction"]
        if search_param["ef_runtime"]:
            vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]

        schema = (
            TagField("id"),
            NumericField("metadata"),
            VectorField("vector", "HNSW", vector_parameters),  # Vector Field Name
        )

        definition = IndexDefinition(index_type=IndexType.HASH)
        conn.ft(INDEX_NAME).create_index(schema, definition=definition)

    def get_client(self, **kwargs):
        """Return a client for the configured endpoint, honouring the ``cmd`` flag.

        CMD stands for Cluster Mode Disabled. With ``cmd`` truthy, a plain
        ``Redis`` client on db 0 is PINGed and returned; keyword arguments are
        ignored.

        Otherwise a ``RedisCluster`` is created, every node is PINGed, and
        READONLY is sent to the replicas (if any) so the keyless FT.SEARCH can
        be served by them. Then:

        * ``replicas=True`` returns ``[(connection, host), ...]`` per replica,
        * ``primary=True`` returns the first primary's node connection,
        * otherwise the ``RedisCluster`` client itself is returned.
        """
        if self.db_config["cmd"]:
            client = Redis(
                host=self.db_config["host"],
                port=self.db_config["port"],
                db=0,
                ssl=self.db_config["ssl"],
                password=self.db_config["password"],
                ssl_ca_certs=self.db_config["ssl_ca_certs"],
                ssl_cert_reqs=None,
            )
            client.execute_command("PING")
            return client

        # Cluster mode enabled
        client = RedisCluster(
            host=self.db_config["host"],
            port=self.db_config["port"],
            ssl=self.db_config["ssl"],
            password=self.db_config["password"],
            ssl_ca_certs=self.db_config["ssl_ca_certs"],
            ssl_cert_reqs=None,
        )

        # Ping all nodes to create a connection
        client.execute_command("PING", target_nodes=RedisCluster.ALL_NODES)
        replicas = client.get_replicas()

        if replicas:
            # FT.SEARCH is a keyless command, use READONLY for replica connections
            client.execute_command("READONLY", target_nodes=RedisCluster.REPLICAS)

        if kwargs.get("replicas", False):
            # Return client and host name for each replica
            return [(c.redis_connection, c.host) for c in replicas]

        if kwargs.get("primary", False):
            return client.get_primaries()[0].redis_connection

        return client

    @contextmanager
    def init(self) -> Generator[None, None, None]:
        """Create and destroy the connection used for inserts and searches.

        The connection is closed even if the body of the ``with`` raises.

        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
        """
        self.conn = self.get_client()
        search_param = self.case_config.search_param()
        if search_param["ef_runtime"]:
            self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}'
        else:
            self.ef_runtime_str = ""
        try:
            yield
        finally:
            self.conn.close()
            self.conn = None

    def ready_to_load(self) -> bool:
        """No-op hook; returns None."""
        pass

    def optimize(self) -> None:
        """Wait until background indexing has finished everywhere."""
        self._post_insert()

    def insert_embeddings(
        self,
        embeddings: list[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> Tuple[int, Optional[Exception]]:
        """Insert embeddings into the database as hashes keyed by ``metadata[i]``.

        Should call self.init() first. The pipeline is executed every
        ``insert_batch_size`` rows (when that is a positive int) and once at
        the end.

        Returns:
            ``(rows_inserted, None)`` on success, ``(0, error)`` on failure.
            An empty ``embeddings`` list returns ``(0, None)``.
        """
        # None / 0 means "no intermediate flushes", not a modulo-by-None crash.
        batch_size = self.insert_batch_size or 0
        inserted = 0
        try:
            with self.conn.pipeline(transaction=False) as pipe:
                for i, embedding in enumerate(embeddings):
                    vector = np.array(embedding).astype(np.float32)
                    pipe.hset(metadata[i], mapping={
                        "id": str(metadata[i]),
                        "metadata": metadata[i],
                        "vector": vector.tobytes(),
                    })
                    inserted = i + 1
                    # Execute the pipe so we don't keep too much in memory at once
                    if batch_size > 0 and inserted % batch_size == 0:
                        pipe.execute()

                pipe.execute()
        except Exception as e:
            return 0, e

        return inserted, None

    def _post_insert(self):
        """Wait for background indexing to finish on the primary and replicas."""
        client = self.get_client(primary=True)
        try:
            log.info("Waiting for background indexing to finish")
            self.wait_until(self.wait_for_no_activity, 5, "", client)
        finally:
            client.close()

        if not self.db_config["cmd"]:
            for rc, host_name in self.get_client(replicas=True):
                try:
                    self.wait_until(self.wait_for_no_activity, 5, "", rc)
                    log.debug(f"Background indexing completed in the host: {host_name}")
                finally:
                    rc.close()

    def wait_until(
        self, condition, interval=5, message="Operation took too long", *args
    ):
        """Poll ``condition(*args)`` every ``interval`` seconds until it is truthy.

        ``message`` is accepted but currently unused, and there is no overall
        timeout: this blocks until the condition holds.
        """
        while not condition(*args):
            time.sleep(interval)

    def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
        """True when ``INFO search`` reports no background indexing activity."""
        return (
            client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
        )

    def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
        """True when DBSIZE is 0."""
        return client.execute_command("DBSIZE") == 0

    @staticmethod
    def _filter_expression(filters: dict | None) -> str:
        """Translate benchmark filters into a search pre-filter expression.

        Filters look like ``{'metadata': '>=10000', 'id': 10000}``: ``id`` is an
        exact TAG match and ``metadata`` a lower-bounded numeric range (its
        leading ``'>='`` is stripped). Either key may be absent. Returns ``"*"``
        (match everything) when no usable filter is present.
        """
        if not filters:
            return "*"

        id_value = filters.get("id")
        metadata_filter = filters.get("metadata")
        # Removing '>=' from the metadata value: '>=10000'
        metadata_value = metadata_filter[2:] if metadata_filter else None

        clauses = []
        if metadata_value:
            # greater than or equal to metadata value
            clauses.append(f"@metadata:[{metadata_value} +inf]")
        if id_value:
            # exact match for id (TAG syntax: @id:{value})
            clauses.append(f"@id:{{{id_value}}}")

        if not clauses:
            return "*"
        if len(clauses) == 1:
            return clauses[0]
        return "(" + " ".join(clauses) + ")"

    def search_embedding(
        self,
        query: list[float],
        k: int = 10,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> (list[int]):
        """Return the ids of the ``k`` nearest neighbours of ``query``.

        ``filters`` are applied as a pre-filter (see ``_filter_expression``).
        ``timeout`` is accepted for interface compatibility but not used.
        Requires an open connection from ``init()``.
        """
        assert self.conn is not None

        query_vector = np.array(query).astype(np.float32).tobytes()
        prefilter = self._filter_expression(filters)
        query_obj = (
            Query(f"{prefilter}=>[KNN {k} @vector $vec]")
            .return_fields("id")
            .paging(0, k)
        )
        query_params = {"vec": query_vector}

        res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
        return [int(doc["id"]) for doc in res.docs]
|
@@ -0,0 +1,154 @@
|
|
1
|
+
from typing import Annotated, Optional, Unpack
|
2
|
+
|
3
|
+
import click
|
4
|
+
import os
|
5
|
+
from pydantic import SecretStr
|
6
|
+
|
7
|
+
from ....cli.cli import (
|
8
|
+
CommonTypedDict,
|
9
|
+
HNSWFlavor1,
|
10
|
+
IVFFlatTypedDict,
|
11
|
+
cli,
|
12
|
+
click_parameter_decorators_from_typed_dict,
|
13
|
+
run,
|
14
|
+
)
|
15
|
+
from vectordb_bench.backend.clients import DB
|
16
|
+
|
17
|
+
|
18
|
+
class PgVectoRSTypedDict(CommonTypedDict):
    # Connection and index-build options shared by the pgvecto.rs CLI commands.
    user_name: Annotated[
        str, click.option("--user-name", required=True, type=str, help="Db username")
    ]
    # Falls back to $POSTGRES_PASSWORD (or "") when --password is not given.
    password: Annotated[
        str,
        click.option(
            "--password",
            type=str,
            default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
            show_default="$POSTGRES_PASSWORD",
            help="Postgres database password",
        ),
    ]

    host: Annotated[
        str, click.option("--host", required=True, type=str, help="Db host")
    ]
    db_name: Annotated[
        str, click.option("--db-name", required=True, type=str, help="Db name")
    ]
    max_parallel_workers: Annotated[
        Optional[int],
        click.option(
            "--max-parallel-workers",
            required=False,
            type=int,
            help="Sets the maximum number of parallel processes per maintenance operation (index creation)",
        ),
    ]
    # Optional; None when not given on the command line.
    quantization_type: Annotated[
        str,
        click.option(
            "--quantization-type",
            required=False,
            type=click.Choice(["trivial", "scalar", "product"]),
            help="quantization type for vectors",
        ),
    ]
    quantization_ratio: Annotated[
        str,
        click.option(
            "--quantization-ratio",
            required=False,
            type=click.Choice(["x4", "x8", "x16", "x32", "x64"]),
            help="quantization ratio(for product quantization)",
        ),
    ]
|
66
|
+
|
67
|
+
|
68
|
+
class PgVectoRSFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Options for the PgVectoRSFlat command.

    NOTE(review): this inherits IVFFlatTypedDict, but PgVectoRSFlat only
    forwards the worker/quantization options to PgVectoRSFLATConfig — confirm
    whether the IVF options are meant to be exposed here.
    """
|
69
|
+
|
70
|
+
|
71
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSFlatTypedDict)
def PgVectoRSFlat(
    **parameters: Unpack[PgVectoRSFlatTypedDict],
):
    # Benchmark pgvecto.rs with a FLAT index. Comments rather than a docstring:
    # click would otherwise show a docstring as the command's help text.
    from .config import PgVectoRSConfig, PgVectoRSFLATConfig

    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )
    flat_settings = PgVectoRSFLATConfig(
        max_parallel_workers=parameters["max_parallel_workers"],
        quantization_type=parameters["quantization_type"],
        quantization_ratio=parameters["quantization_ratio"],
    )
    run(
        db=DB.PgVectoRS,
        db_config=connection,
        db_case_config=flat_settings,
        **parameters,
    )
|
94
|
+
|
95
|
+
|
96
|
+
class PgVectoRSIVFFlatTypedDict(PgVectoRSTypedDict, IVFFlatTypedDict):
    """Options for the PgVectoRSIVFFlat command: pgvecto.rs connection/quantization plus IVF options."""
|
97
|
+
|
98
|
+
|
99
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSIVFFlatTypedDict)
def PgVectoRSIVFFlat(
    **parameters: Unpack[PgVectoRSIVFFlatTypedDict],
):
    # Benchmark pgvecto.rs with an IVF_FLAT index; `lists`/`probes` come from
    # the IVF options. Comments are used because click shows docstrings as help.
    from .config import PgVectoRSConfig, PgVectoRSIVFFlatConfig

    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )
    ivf_settings = PgVectoRSIVFFlatConfig(
        max_parallel_workers=parameters["max_parallel_workers"],
        quantization_type=parameters["quantization_type"],
        quantization_ratio=parameters["quantization_ratio"],
        probes=parameters["probes"],
        lists=parameters["lists"],
    )
    run(
        db=DB.PgVectoRS,
        db_config=connection,
        db_case_config=ivf_settings,
        **parameters,
    )
|
124
|
+
|
125
|
+
|
126
|
+
class PgVectoRSHNSWTypedDict(PgVectoRSTypedDict, HNSWFlavor1):
    """Options for the PgVectoRSHNSW command: pgvecto.rs connection/quantization plus HNSW options."""
|
127
|
+
|
128
|
+
|
129
|
+
@cli.command()
@click_parameter_decorators_from_typed_dict(PgVectoRSHNSWTypedDict)
def PgVectoRSHNSW(
    **parameters: Unpack[PgVectoRSHNSWTypedDict],
):
    # Benchmark pgvecto.rs with an HNSW index (m / ef_construction / ef_search).
    # Comments are used because click shows a docstring as the command's help.
    from .config import PgVectoRSConfig, PgVectoRSHNSWConfig

    connection = PgVectoRSConfig(
        db_label=parameters["db_label"],
        user_name=SecretStr(parameters["user_name"]),
        password=SecretStr(parameters["password"]),
        host=parameters["host"],
        db_name=parameters["db_name"],
    )
    hnsw_settings = PgVectoRSHNSWConfig(
        max_parallel_workers=parameters["max_parallel_workers"],
        quantization_type=parameters["quantization_type"],
        quantization_ratio=parameters["quantization_ratio"],
        m=parameters["m"],
        ef_construction=parameters["ef_construction"],
        ef_search=parameters["ef_search"],
    )
    run(
        db=DB.PgVectoRS,
        db_config=connection,
        db_case_config=hnsw_settings,
        **parameters,
    )
|