vectordb-bench 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -13
- vectordb_bench/backend/clients/__init__.py +13 -0
- vectordb_bench/backend/clients/api.py +2 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/pgdiskann/cli.py +99 -0
- vectordb_bench/backend/clients/pgdiskann/config.py +145 -0
- vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +350 -0
- vectordb_bench/backend/clients/pgvector/cli.py +62 -1
- vectordb_bench/backend/clients/pgvector/config.py +48 -10
- vectordb_bench/backend/clients/pgvector/pgvector.py +145 -26
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +22 -4
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +1 -1
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +4 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +6 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +165 -1
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +13 -3
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/METADATA +65 -35
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/RECORD +42 -38
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.13.dist-info → vectordb_bench-0.0.15.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvector/pgvector.py:

```diff
@@ -11,7 +11,7 @@ from pgvector.psycopg import register_vector
 from psycopg import Connection, Cursor, sql
 
 from ..api import VectorDB
-from .config import PgVectorConfigDict, PgVectorIndexConfig
+from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig
 
 log = logging.getLogger(__name__)
 
```
```diff
@@ -22,7 +22,7 @@ class PgVector(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     cursor: psycopg.Cursor[Any] | None = None
 
-
+    _filtered_search: sql.Composed
     _unfiltered_search: sql.Composed
 
     def __init__(
```
```diff
@@ -87,6 +87,92 @@ class PgVector(VectorDB):
         assert cursor is not None, "Cursor is not initialized"
 
         return conn, cursor
+
+    def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
+        index_param = self.case_config.index_param()
+        reranking = self.case_config.search_param()["reranking"]
+        column_name = (
+            sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
+            if index_param["quantization_type"] == "bit"
+            else sql.SQL("embedding")
+        )
+        search_vector = (
+            sql.SQL("binary_quantize({0})").format(sql.Placeholder())
+            if index_param["quantization_type"] == "bit"
+            else sql.Placeholder()
+        )
+
+        # The following sections assume that the quantization_type value matches the quantization function name
+        if index_param["quantization_type"] != None:
+            if index_param["quantization_type"] == "bit" and reranking:
+                # Embeddings needs to be passed to binary_quantize function if quantization_type is bit
+                search_query = sql.Composed(
+                    [
+                        sql.SQL(
+                            """
+                            SELECT i.id
+                            FROM (
+                                SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
+                                FROM public.{table_name} {where_clause}
+                                ORDER BY {column_name}::{quantization_type}({dim})
+                            """
+                        ).format(
+                            table_name=sql.Identifier(self.table_name),
+                            column_name=column_name,
+                            reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
+                            where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                        ),
+                        sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                        sql.SQL(
+                            """
+                            {search_vector}
+                            LIMIT {quantized_fetch_limit}
+                            ) i
+                            ORDER BY i.distance
+                            LIMIT %s::int
+                            """
+                        ).format(
+                            search_vector=search_vector,
+                            quantized_fetch_limit=sql.Literal(
+                                self.case_config.search_param()["quantized_fetch_limit"]
+                            ),
+                        ),
+                    ]
+                )
+            else:
+                search_query = sql.Composed(
+                    [
+                        sql.SQL(
+                            "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
+                        ).format(
+                            table_name=sql.Identifier(self.table_name),
+                            column_name=column_name,
+                            quantization_type=sql.SQL(index_param["quantization_type"]),
+                            dim=sql.Literal(self.dim),
+                            where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                        ),
+                        sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                        sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
+                    ]
+                )
+        else:
+            search_query = sql.Composed(
+                [
+                    sql.SQL(
+                        "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
+                    ).format(
+                        table_name=sql.Identifier(self.table_name),
+                        where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
+                    ),
+                    sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                    sql.SQL(" %s::vector LIMIT %s::int"),
+                ]
+            )
+
+        return search_query
+
 
     @contextmanager
     def init(self) -> Generator[None, None, None]:
```
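When `quantization_type` is `"bit"` and reranking is enabled, the composed statement is a two-stage query: an inner scan orders by distance over the bit-quantized embedding and pulls a pool of `quantized_fetch_limit` candidates, then an outer pass reranks those candidates by exact full-precision distance. A sketch of roughly what it renders to, with illustrative values (the table name, operators, `dim=4`, and a limit of 200 are assumptions, not taken from a real run):

```python
# Approximate rendering of the reranking query (unfiltered case).
# Placeholders bind as (query_vector, query_vector, k), matching the
# (q, q, k) execute() call later in this diff.
RENDERED_RERANK_QUERY = """
SELECT i.id
FROM (
    SELECT id, embedding <=> %s::vector AS distance  -- exact distance, used for reranking
    FROM public."pgv_items"
    ORDER BY binary_quantize("embedding")::bit(4) <~> binary_quantize(%s)  -- cheap quantized scan
    LIMIT 200  -- quantized_fetch_limit candidates
) i
ORDER BY i.distance  -- rerank candidates by exact distance
LIMIT %s::int  -- final top-k
"""
```

A larger `quantized_fetch_limit` gives the exact-distance rerank more candidates to recover from quantization error, at the cost of more distance computations per query.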
```diff
@@ -112,15 +198,8 @@ class PgVector(VectorDB):
         self.cursor.execute(command)
         self.conn.commit()
 
-        self._unfiltered_search = sql.Composed(
-            [
-                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
-                    sql.Identifier(self.table_name)
-                ),
-                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
-                sql.SQL(" %s::vector LIMIT %s::int"),
-            ]
-        )
+        self._filtered_search = self._generate_search_query(filtered=True)
+        self._unfiltered_search = self._generate_search_query()
 
         try:
             yield
```
```diff
@@ -255,17 +334,39 @@ class PgVector(VectorDB):
         else:
             with_clause = sql.Composed(())
 
-        index_create_sql = sql.SQL(
-            """
-            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
-            USING {index_type} (embedding {embedding_metric})
-            """
-        ).format(
-            index_name=sql.Identifier(self._index_name),
-            table_name=sql.Identifier(self.table_name),
-            index_type=sql.Identifier(index_param["index_type"]),
-            embedding_metric=sql.Identifier(index_param["metric"]),
-        )
+        if index_param["quantization_type"] != None:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                column_name=(
+                    sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
+                    if index_param["quantization_type"] == "bit"
+                    else sql.Identifier("embedding")
+                ),
+                index_type=sql.Identifier(index_param["index_type"]),
+                # This assumes that the quantization_type value matches the quantization function name
+                quantization_type=sql.SQL(index_param["quantization_type"]),
+                dim=self.dim,
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+        else:
+            index_create_sql = sql.SQL(
+                """
+                CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+                USING {index_type} (embedding {embedding_metric})
+                """
+            ).format(
+                index_name=sql.Identifier(self._index_name),
+                table_name=sql.Identifier(self.table_name),
+                index_type=sql.Identifier(index_param["index_type"]),
+                embedding_metric=sql.Identifier(index_param["metric"]),
+            )
+
         index_create_sql_with_with_clause = (
             index_create_sql + with_clause
         ).join(" ")
```
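The build side mirrors the search side: when quantization is on, the index is created over the same casted expression the query orders by, so the planner can match the two. Roughly, the quantized branch renders to something like this (HNSW, bit quantization, `dim=4`, and the `bit_hamming_ops` opclass are illustrative assumptions):

```python
# Approximate rendering of the quantized CREATE INDEX statement; the index
# type and opclass come through sql.Identifier, hence the double quotes.
RENDERED_INDEX_SQL = """
CREATE INDEX IF NOT EXISTS "pgv_index" ON public."pgv_items"
USING "hnsw" ((binary_quantize("embedding")::bit(4)) "bit_hamming_ops")
"""
```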
```diff
@@ -341,10 +442,28 @@ class PgVector(VectorDB):
         assert self.conn is not None, "Connection is not initialized"
         assert self.cursor is not None, "Cursor is not initialized"
 
+        index_param = self.case_config.index_param()
+        search_param = self.case_config.search_param()
         q = np.asarray(query)
-
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            if index_param["quantization_type"] == "bit" and search_param["reranking"]:
+                result = self.cursor.execute(
+                    self._filtered_search, (q, gt, q, k), prepare=True, binary=True
+                )
+            else:
+                result = self.cursor.execute(
+                    self._filtered_search, (gt, q, k), prepare=True, binary=True
+                )
+
+        else:
+            if index_param["quantization_type"] == "bit" and search_param["reranking"]:
+                result = self.cursor.execute(
+                    self._unfiltered_search, (q, q, k), prepare=True, binary=True
+                )
+            else:
+                result = self.cursor.execute(
+                    self._unfiltered_search, (q, k), prepare=True, binary=True
+                )
 
         return [int(i[0]) for i in result.fetchall()]
```
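The four `execute()` variants differ only in how many placeholders the prepared statement carries: `(q, gt, q, k)` for filtered bit-plus-reranking, `(gt, q, k)` filtered, `(q, q, k)` unfiltered with reranking, and `(q, k)` unfiltered. `prepare=True` makes psycopg reuse a server-side prepared statement across queries, and `binary=True` ships vector parameters in binary rather than text format. A minimal standalone sketch of the same pattern (DSN, table name, and operator are assumptions):

```python
import numpy as np
import psycopg
from pgvector.psycopg import register_vector

# Illustrative connection string and table; substitute your own.
with psycopg.connect("postgresql://localhost/vectordb") as conn:
    register_vector(conn)  # adapt numpy arrays to the pgvector type
    with conn.cursor() as cur:
        q = np.random.rand(4).astype(np.float32)
        cur.execute(
            'SELECT id FROM public."pgv_items" ORDER BY embedding <=> %s::vector LIMIT %s::int',
            (q, 10),
            prepare=True,  # parse/plan once server-side, reuse thereafter
            binary=True,   # binary wire format for the vector parameter
        )
        ids = [row[0] for row in cur.fetchall()]
```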
vectordb_bench/backend/clients/pgvectorscale/cli.py (new file):

```diff
@@ -0,0 +1,108 @@
+import click
+import os
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from typing import Annotated, Unpack
+from vectordb_bench.backend.clients import DB
+
+
+class PgVectorScaleTypedDict(CommonTypedDict):
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Postgres database password",
+                     default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+                     show_default="$POSTGRES_PASSWORD",
+                     ),
+    ]
+
+    host: Annotated[
+        str, click.option("--host", type=str, help="Db host", required=True)
+    ]
+    db_name: Annotated[
+        str, click.option("--db-name", type=str, help="Db name", required=True)
+    ]
+
+
+class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
+    storage_layout: Annotated[
+        str,
+        click.option(
+            "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+        ),
+    ]
+    num_neighbors: Annotated[
+        int,
+        click.option(
+            "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+        ),
+    ]
+    search_list_size: Annotated[
+        int,
+        click.option(
+            "--search-list-size", type=int, help="Streaming DiskANN search list size",
+        ),
+    ]
+    max_alpha: Annotated[
+        float,
+        click.option(
+            "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+        ),
+    ]
+    num_dimensions: Annotated[
+        int,
+        click.option(
+            "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+        ),
+    ]
+    query_search_list_size: Annotated[
+        int,
+        click.option(
+            "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
+        ),
+    ]
+    query_rescore: Annotated[
+        int,
+        click.option(
+            "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+        ),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict)
+def PgVectorScaleDiskAnn(
+    **parameters: Unpack[PgVectorScaleDiskAnnTypedDict],
+):
+    from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig
+
+    run(
+        db=DB.PgVectorScale,
+        db_config=PgVectorScaleConfig(
+            db_label=parameters["db_label"],
+            user_name=SecretStr(parameters["user_name"]),
+            password=SecretStr(parameters["password"]),
+            host=parameters["host"],
+            db_name=parameters["db_name"],
+        ),
+        db_case_config=PgVectorScaleStreamingDiskANNConfig(
+            storage_layout=parameters["storage_layout"],
+            num_neighbors=parameters["num_neighbors"],
+            search_list_size=parameters["search_list_size"],
+            max_alpha=parameters["max_alpha"],
+            num_dimensions=parameters["num_dimensions"],
+            query_search_list_size=parameters["query_search_list_size"],
+            query_rescore=parameters["query_rescore"],
+        ),
+        **parameters,
+    )
```
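Once the wheel is installed, the command is exposed through the `vectordbbench` entry point; click derives the subcommand name from the function name, so the invocation is presumably `vectordbbench pgvectorscalediskann --user-name ... --host ... --db-name ...`. Building the same configuration programmatically looks roughly like this (every parameter value below is illustrative, not a recommendation):

```python
from pydantic import SecretStr

from vectordb_bench.backend.clients.pgvectorscale.config import (
    PgVectorScaleConfig,
    PgVectorScaleStreamingDiskANNConfig,
)

# Connection settings, mirroring what the command assembles from its flags.
db_config = PgVectorScaleConfig(
    db_label="diskann-run",
    user_name=SecretStr("postgres"),
    password=SecretStr("example"),
    host="localhost",
    db_name="vectordb",
)

# Streaming DiskANN build/search knobs, mirroring --storage-layout,
# --num-neighbors, --search-list-size, --max-alpha, --num-dimensions,
# --query-search-list-size, and --query-rescore.
case_config = PgVectorScaleStreamingDiskANNConfig(
    storage_layout="memory_optimized",
    num_neighbors=50,
    search_list_size=100,
    max_alpha=1.2,
    num_dimensions=0,
    query_search_list_size=100,
    query_rescore=50,
)
```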
vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py:

```diff
@@ -22,6 +22,9 @@ class PgVectorScale(VectorDB):
     conn: psycopg.Connection[Any] | None = None
     coursor: psycopg.Cursor[Any] | None = None
 
+    _unfiltered_search: sql.Composed
+    _filtered_search: sql.Composed
+
     def __init__(
         self,
         dim: int,
```
```diff
@@ -99,6 +102,16 @@ class PgVectorScale(VectorDB):
         self.cursor.execute(command)
         self.conn.commit()
 
+        self._filtered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name),
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int")
+            ]
+        )
+
         self._unfiltered_search = sql.Composed(
             [
                 sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
```
```diff
@@ -264,9 +277,14 @@ class PgVectorScale(VectorDB):
         assert self.cursor is not None, "Cursor is not initialized"
 
         q = np.asarray(query)
-
-        result = self.cursor.execute(
-            self._unfiltered_search, (q, k), prepare=True, binary=True
-        )
+        if filters:
+            gt = filters.get("id")
+            result = self.cursor.execute(
+                self._filtered_search, (gt, q, k), prepare=True, binary=True
+            )
+        else:
+            result = self.cursor.execute(
+                self._unfiltered_search, (q, k), prepare=True, binary=True
+            )
 
         return [int(i[0]) for i in result.fetchall()]
```
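VectorDBBench's filter cases pass an integer cutoff as `filters["id"]`, so the filtered statement differs from the unfiltered one only by a `WHERE id >= %s` clause and one extra bound parameter ahead of the vector:

```python
# Shapes of the two prepared statements (table name and cosine operator
# are illustrative; the metric operator comes from search_param()).
UNFILTERED = 'SELECT id FROM public."pgvs_items" ORDER BY embedding <=> %s::vector LIMIT %s::int'
FILTERED = 'SELECT id FROM public."pgvs_items" WHERE id >= %s ORDER BY embedding <=> %s::vector LIMIT %s::int'
# Bound as (q, k) and (gt, q, k) respectively, matching the execute() calls above.
```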
vectordb_bench/backend/clients/pinecone/config.py:

```diff
@@ -4,12 +4,10 @@ from ..api import DBConfig
 
 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str
 
     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
```
vectordb_bench/backend/clients/pinecone/pinecone.py:

```diff
@@ -3,7 +3,7 @@
 import logging
 from contextlib import contextmanager
 from typing import Type
-
+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
 
```
```diff
@@ -11,7 +11,8 @@ from .config import PineconeConfig
 log = logging.getLogger(__name__)
 
 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
+
 
 class Pinecone(VectorDB):
     def __init__(
```
```diff
@@ -23,30 +24,25 @@ class Pinecone(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
-        self.index_name = db_config["index_name"]
-        self.api_key = db_config["api_key"]
-        self.environment = db_config["environment"]
-        self.batch_size = int(
-            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
-        )
-
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)
+
         if drop_old:
-            list_indexes = pinecone.list_indexes()
-            if self.index_name in list_indexes:
-                index = pinecone.Index(self.index_name)
-                index_dim = index.describe_index_stats()["dimension"]
-                if (index_dim != dim):
-                    raise ValueError(
-                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-                log.info(
-                    f"Pinecone client delete old index: {self.index_name}")
-                index.delete(delete_all=True)
-                index.close()
-            else:
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
                 raise ValueError(
-                    f"Pinecone index {self.index_name}
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)
 
         self._metadata_key = "meta"
 
```
```diff
@@ -59,13 +55,10 @@ class Pinecone(VectorDB):
         return EmptyDBCaseConfig
 
     @contextmanager
-    def init(self)
-
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()
 
     def ready_to_load(self):
         pass
```
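These pinecone hunks track the v3 `pinecone-client` API: the module-level `pinecone.init(api_key=..., environment=...)` plus `pinecone.Index(...)` pattern is gone, replaced by a `Pinecone` client object, and the pod `environment` setting disappears entirely (hence the field dropped from `PineconeConfig` above). A minimal sketch of the new-style calls the diff relies on (API key and index name are placeholders):

```python
import pinecone

pc = pinecone.Pinecone(api_key="YOUR_API_KEY")  # no environment, no init()
index = pc.Index("my-index")                    # illustrative index name

stats = index.describe_index_stats()
print(stats["dimension"])                       # the index's vector dimension
for namespace in stats["namespaces"]:           # wipe data namespace by namespace
    index.delete(delete_all=True, namespace=namespace)
```

Note that the `index.close()` calls are gone as well; the new client manages its own connections.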
```diff
@@ -83,11 +76,16 @@ class Pinecone(VectorDB):
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (
-                        str(metadata[i]), embeddings[i], {self._metadata_key: metadata[i]})
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
```
```diff
@@ -101,7 +99,7 @@ class Pinecone(VectorDB):
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
```
```diff
@@ -111,9 +109,9 @@ class Pinecone(VectorDB):
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )[
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res[
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res
```
vectordb_bench/backend/clients/redis/cli.py:

```diff
@@ -3,6 +3,9 @@ from typing import Annotated, TypedDict, Unpack
 import click
 from pydantic import SecretStr
 
+from .config import RedisHNSWConfig
+
+
 from ....cli.cli import (
     CommonTypedDict,
     HNSWFlavor2,
```
```diff
@@ -69,6 +72,11 @@ def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
             ssl=parameters["ssl"],
             ssl_ca_certs=parameters["ssl_ca_certs"],
             cmd=parameters["cmd"],
+        ),
+        db_case_config=RedisHNSWConfig(
+            M=parameters["m"],
+            efConstruction=parameters["ef_construction"],
+            ef=parameters["ef_runtime"],
         ),
         **parameters,
     )
```
vectordb_bench/backend/clients/redis/config.py:

```diff
@@ -1,14 +1,45 @@
-from pydantic import SecretStr
-from ..api import DBConfig
+from pydantic import SecretStr, BaseModel
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
 
 class RedisConfig(DBConfig):
-    password: SecretStr
+    password: SecretStr | None = None
     host: SecretStr
-    port: int = None
+    port: int | None = None
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
-            "password": self.password.get_secret_value(),
-        }
+            "password": self.password.get_secret_value() if self.password is not None else None,
+        }
+
+
+
+class RedisIndexConfig(BaseModel):
+    """Base config for milvus"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
+    M: int
+    efConstruction: int
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "params": {"M": self.M, "efConstruction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "params": {"ef": self.ef},
+        }
```
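This gives the Redis CLI path a real case config: `index_param()` carries the HNSW build arguments and `search_param()` the runtime `ef`. A quick sketch of what the new class produces (values illustrative; the exact `index_type` string depends on how `IndexType.HNSW` is defined in `..api`):

```python
from vectordb_bench.backend.clients.redis.config import RedisHNSWConfig

cfg = RedisHNSWConfig(M=16, efConstruction=200, ef=128)

cfg.index_param()
# e.g. {"metric_type": "", "index_type": "HNSW", "params": {"M": 16, "efConstruction": 200}}
cfg.search_param()
# e.g. {"metric_type": "", "params": {"ef": 128}}
```

`metric_type` renders as an empty string here because none was set; passing one of the `MetricType` members defined in `..api` fills it in.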
vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py:

```diff
@@ -23,7 +23,7 @@ class WeaviateCloud(VectorDB):
         **kwargs,
     ):
         """Initialize wrapper around the weaviate vector database."""
-        db_config.update("auth_client_secret"
+        db_config.update({"auth_client_secret": weaviate.AuthApiKey(api_key=db_config.get("auth_client_secret"))})
         self.db_config = db_config
         self.case_config = db_case_config
         self.collection_name = collection_name
```
vectordb_bench/backend/runner/mp_runner.py:

```diff
@@ -2,6 +2,7 @@ import time
 import traceback
 import concurrent
 import multiprocessing as mp
+import random
 import logging
 from typing import Iterable
 import numpy as np
```
```diff
@@ -46,7 +47,7 @@ class MultiProcessingSearchRunner:
                 cond.wait()
 
             with self.db.init():
-                num, idx = len(test_data), 0
+                num, idx = len(test_data), random.randint(0, len(test_data) - 1)
 
                 start_time = time.perf_counter()
                 count = 0
```
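Previously every worker started at `idx = 0`, so all concurrent processes replayed the shared query set in the same order and tended to issue identical queries at the same moment, which flatters server-side caches. A random starting offset decorrelates the workers. A sketch of the wrap-around traversal this pairs with (the modulo step is an assumption based on the surrounding runner loop):

```python
import random

test_data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # stand-in for the shared query set
num, idx = len(test_data), random.randint(0, len(test_data) - 1)

for _ in range(num):        # one full pass over the queries
    query = test_data[idx]
    idx = (idx + 1) % num   # wrap so every query is still visited exactly once
```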
|