vectordb-bench 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +64 -18
- vectordb_bench/backend/clients/__init__.py +35 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/cli/vectordbbench.py +7 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +18 -11
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +26 -29
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +50 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -19
- vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +16 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +311 -40
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +11 -18
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +2 -2
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/models.py +26 -10
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +46 -15
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +57 -40
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
**vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py** (new file)

```diff
@@ -0,0 +1,272 @@
+"""Wrapper around the Pgvectorscale vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgVectorScaleConfigDict, PgVectorScaleIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgVectorScale(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    coursor: psycopg.Cursor[Any] | None = None
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgVectorScaleConfigDict,
+        db_case_config: PgVectorScaleIndexConfig,
+        collection_name: str = "pg_vectorscale_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgVectorScale"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgvectorscale_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        num_bits_per_dimension = "2" if self.dim < 900 else "1"
+        options.append(
+            sql.SQL("{option_name} = {val}").format(
+                option_name=sql.Identifier("num_bits_per_dimension"),
+                val=sql.Identifier(num_bits_per_dimension),
+            )
+        )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+                self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        # TODO add filters support
+        result = self.cursor.execute(
+            self._unfiltered_search, (q, k), prepare=True, binary=True
+        )
+
+        return [int(i[0]) for i in result.fetchall()]
```
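For orientation, a minimal usage sketch of this client follows. It assumes a local PostgreSQL instance with the vectorscale extension available, a psycopg-style connection dict, and an already-built `PgVectorScaleIndexConfig` (here `index_cfg`); the values are placeholders, not a tested configuration.

```python
# Hedged sketch: drive the new client through its lifecycle. `index_cfg`
# stands in for a PgVectorScaleIndexConfig instance and is assumed to set
# create_index_before_load or create_index_after_load.
from vectordb_bench.backend.clients.pgvectorscale.pgvectorscale import PgVectorScale

client = PgVectorScale(
    dim=128,
    db_config={"host": "localhost", "port": 5432, "user": "postgres",
               "password": "postgres", "dbname": "postgres"},
    db_case_config=index_cfg,
    drop_old=True,
)
with client.init():  # opens a fresh connection and applies session params
    client.insert_embeddings([[0.1] * 128, [0.2] * 128], metadata=[1, 2])
    ids = client.search_embedding([0.1] * 128, k=10)
```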
**vectordb_bench/backend/dataset.py**

```diff
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
     use_shuffled: bool
     with_gt: bool = False
     _size_label: dict[int, SizeLabel] = PrivateAttr()
+    isCustom: bool = False

     @validator("size")
     def verify_size(cls, v):
```
```diff
@@ -52,7 +53,27 @@ class BaseDataset(BaseModel):
     def file_count(self) -> int:
         return self._size_label.get(self.size).file_count

+class CustomDataset(BaseDataset):
+    dir: str
+    file_num: int
+    isCustom: bool = True
+
+    @validator("size")
+    def verify_size(cls, v):
+        return v
+
+    @property
+    def label(self) -> str:
+        return "Custom"

+    @property
+    def dir_name(self) -> str:
+        return self.dir
+
+    @property
+    def file_count(self) -> int:
+        return self.file_num
+
 class LAION(BaseDataset):
     name: str = "LAION"
     dim: int = 768
```
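A hedged construction sketch follows. The field set beyond `dir`, `file_num`, and `isCustom` is inferred from the `BaseDataset` context lines and from the `dataset_config` keys in `custom_case.json` further below, so treat the exact keyword names (in particular how the JSON's `file_count` maps onto `file_num`) as assumptions.

```python
# Hypothetical sketch: describe a local dataset with the new CustomDataset
# model. Field names other than dir/file_num/isCustom are inferred, not verified.
from vectordb_bench.backend.dataset import CustomDataset

ds = CustomDataset(
    name="My Dataset",
    size=1_000_000,      # accepted as-is: CustomDataset's verify_size is a no-op
    dim=1024,
    metric_type="L2",
    use_shuffled=False,
    with_gt=True,
    dir="/my_dataset_path",
    file_num=1,
)
print(ds.label, ds.dir_name, ds.file_count)  # -> Custom /my_dataset_path 1
```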
```diff
@@ -186,11 +207,12 @@ class DatasetManager(BaseModel):
         gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
         all_files.extend([gt_file, test_file])

-        source.reader().read(
-            dataset=self.data.dir_name.lower(),
-            files=all_files,
-            local_ds_root=self.data_dir,
-        )
+        if not self.data.isCustom:
+            source.reader().read(
+                dataset=self.data.dir_name.lower(),
+                files=all_files,
+                local_ds_root=self.data_dir,
+            )

         if gt_file is not None and test_file is not None:
             self.test_data = self._read_file(test_file)
```
**vectordb_bench/cli/vectordbbench.py**

```diff
@@ -1,19 +1,26 @@
 from ..backend.clients.pgvector.cli import PgVectorHNSW
+from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
 from ..backend.clients.redis.cli import Redis
+from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
 from ..backend.clients.milvus.cli import MilvusAutoIndex
+from ..backend.clients.aws_opensearch.cli import AWSOpenSearch


 from .cli import cli

 cli.add_command(PgVectorHNSW)
+cli.add_command(PgVectoRSHNSW)
+cli.add_command(PgVectoRSIVFFlat)
 cli.add_command(Redis)
+cli.add_command(MemoryDB)
 cli.add_command(Weaviate)
 cli.add_command(Test)
 cli.add_command(ZillizAutoIndex)
 cli.add_command(MilvusAutoIndex)
+cli.add_command(AWSOpenSearch)


 if __name__ == "__main__":
```
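The newly registered commands can be confirmed from Python without touching a live database; here is a small sketch using click's test runner (the exact command names printed depend on each client's cli module and are not asserted here):

```python
# Sketch: list the CLI commands registered above via click's CliRunner.
from click.testing import CliRunner
from vectordb_bench.cli.vectordbbench import cli

result = CliRunner().invoke(cli, ["--help"])
print(result.output)  # should now include entries for the memorydb,
                      # aws_opensearch, and pgvecto_rs commands, among the rest
```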
**vectordb_bench/custom/custom_case.json** (new file)

```diff
@@ -0,0 +1,18 @@
+[
+    {
+        "name": "My Dataset (Performace Case)",
+        "description": "this is a customized dataset.",
+        "load_timeout": 36000,
+        "optimize_timeout": 36000,
+        "dataset_config": {
+            "name": "My Dataset",
+            "dir": "/my_dataset_path",
+            "size": 1000000,
+            "dim": 1024,
+            "metric_type": "L2",
+            "file_count": 1,
+            "use_shuffled": false,
+            "with_gt": true
+        }
+    }
+]
```
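A hedged sketch of reading this file follows; the package's actual loader is the new `getCustomConfig.py` component, whose API is not shown in this diff, so plain `json` access mirroring the structure above is used instead:

```python
# Sketch: parse custom_case.json directly; field access follows the file
# structure above rather than any confirmed loader API.
import json
from pathlib import Path

cases = json.loads(Path("vectordb_bench/custom/custom_case.json").read_text())
for case in cases:
    cfg = case["dataset_config"]
    print(case["name"], cfg["dir"], cfg["dim"], cfg["metric_type"])
```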
**vectordb_bench/frontend/components/check_results/charts.py**

```diff
@@ -1,19 +1,19 @@
 from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
 from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
 from vectordb_bench.models import ResultLabel
 import plotly.express as px


-def drawCharts(st, allData, failedTasks,
+def drawCharts(st, allData, failedTasks, caseNames: list[str]):
     initMainExpanderStyle(st)
-    for
-    chartContainer = st.expander(
-    data = [data for data in allData if data["case_name"] ==
+    for caseName in caseNames:
+        chartContainer = st.expander(caseName, True)
+        data = [data for data in allData if data["case_name"] == caseName]
         drawChart(data, chartContainer)

-        errorDBs = failedTasks[
+        errorDBs = failedTasks[caseName]
         showFailedDBs(chartContainer, errorDBs)


```
**vectordb_bench/frontend/components/check_results/data.py**

```diff
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
 def getChartData(
     tasks: list[CaseResult],
     dbNames: list[str],
-    cases: list[Case],
+    caseNames: list[str],
 ):
-    filterTasks = getFilterTasks(tasks, dbNames, cases)
+    filterTasks = getFilterTasks(tasks, dbNames, caseNames)
     mergedTasks, failedTasks = mergeTasks(filterTasks)
     return mergedTasks, failedTasks

```
```diff
@@ -18,14 +18,16 @@ def getChartData(
 def getFilterTasks(
     tasks: list[CaseResult],
     dbNames: list[str],
-    cases: list[Case],
+    caseNames: list[str],
 ) -> list[CaseResult]:
-    case_ids = [case.case_id for case in cases]
     filterTasks = [
         task
         for task in tasks
         if task.task_config.db_name in dbNames
-        and task.task_config.case_config.case_id in case_ids
+        and task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        ).name
+        in caseNames
     ]
     return filterTasks

```
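The filter now keys on a case's display name, resolved from its enum id plus the optional custom-case payload; the same expression reappears in `mergeTasks` below. A hypothetical helper (not part of the package) makes the pattern easier to read:

```python
# Hypothetical helper, shown only to clarify the repeated expression;
# vectordb-bench itself inlines this logic.
def resolve_case_name(case_config) -> str:
    case = case_config.case_id.case_cls(case_config.custom_case)
    return case.name
```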
```diff
@@ -36,16 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
         db_name = task.task_config.db_name
         db = task.task_config.db.value
         db_label = task.task_config.db_config.db_label or ""
-        case_id = task.task_config.case_config.case_id
-        dbCaseMetricsMap[db_name][case_id] = {
+        version = task.task_config.db_config.version or ""
+        case = task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        )
+        dbCaseMetricsMap[db_name][case.name] = {
             "db": db,
             "db_label": db_label,
+            "version": version,
             "metrics": mergeMetrics(
-                dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
+                dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
                 asdict(task.metrics),
             ),
             "label": getBetterLabel(
-                dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
+                dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
                 task.label,
             ),
         }
```
```diff
@@ -53,18 +59,19 @@ def mergeTasks(tasks: list[CaseResult]):
     mergedTasks = []
     failedTasks = defaultdict(lambda: defaultdict(str))
     for db_name, caseMetricsMap in dbCaseMetricsMap.items():
-        for case_id, metricInfo in caseMetricsMap.items():
+        for case_name, metricInfo in caseMetricsMap.items():
             metrics = metricInfo["metrics"]
             db = metricInfo["db"]
             db_label = metricInfo["db_label"]
+            version = metricInfo["version"]
             label = metricInfo["label"]
-            case_name = case_id.case_name
             if label == ResultLabel.NORMAL:
                 mergedTasks.append(
                     {
                         "db_name": db_name,
                         "db": db,
                         "db_label": db_label,
+                        "version": version,
                         "case_name": case_name,
                         "metricsSet": set(metrics.keys()),
                         **metrics,
```
**vectordb_bench/frontend/components/check_results/expanderStyle.py**

```diff
@@ -1,7 +1,7 @@
 def initMainExpanderStyle(st):
     st.markdown(
         """<style>
-            .main
+            .main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
             .main div[data-testid='stExpander'] {
                 background-color: #F6F8FA;
                 border: 1px solid #A9BDD140;
```
**vectordb_bench/frontend/components/check_results/filters.py**

```diff
@@ -1,8 +1,8 @@
 from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.data import getChartData
 from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
-from vectordb_bench.frontend.
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
+from vectordb_bench.frontend.config.styles import *
 import streamlit as st

 from vectordb_bench.models import CaseResult, TestResult
```
```diff
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
     st.header("Filters")

     shownResults = getshownResults(results, st)
-    showDBNames,
+    showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)

-    shownData, failedTasks = getChartData(
+    shownData, failedTasks = getChartData(
+        shownResults, showDBNames, showCaseNames)

-    return shownData, failedTasks,
+    return shownData, failedTasks, showCaseNames


 def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
```
```diff
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
     return selectedResult


-def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Case]]:
+def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
     initSidebarExanderStyle(st)
     allDbNames = list(set({res.task_config.db_name for res in result}))
     allDbNames.sort()
-
-
+    allCases: list[Case] = [
+        res.task_config.case_config.case_id.case_cls(
+            res.task_config.case_config.custom_case)
+        for res in result
+    ]
+    allCaseNameSet = set({case.name for case in allCases})
+    allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
+        [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]

     # DB Filter
     dbFilterContainer = st.container()
```
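The two list comprehensions order the filter options: names found in `CASE_NAME_ORDER` come first, in that canonical order, and any remaining names (for example, custom cases) are appended. A tiny worked example with made-up names:

```python
# Made-up values, illustrating the ordering only.
CASE_NAME_ORDER = ["Search Performance", "Filtering Search", "Capacity"]
allCaseNameSet = {"Capacity", "Search Performance", "My Dataset (Performace Case)"}
ordered = [n for n in CASE_NAME_ORDER if n in allCaseNameSet] + \
    [n for n in allCaseNameSet if n not in CASE_NAME_ORDER]
# -> ["Search Performance", "Capacity", "My Dataset (Performace Case)"]
# (the tail comes from a set, so its relative order is arbitrary)
```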
```diff
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca

     # Case Filter
     caseFilterContainer = st.container()
-
+    showCaseNames = filterView(
         caseFilterContainer,
         "Case Filter",
-        [
+        [caseName for caseName in allCaseNames],
         col=1,
-        optionLables=[case.name for case in allCases],
     )

-    return showDBNames,
+    return showDBNames, showCaseNames


 def filterView(container, header, options, col, optionLables=None):
```
```diff
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
     )
     if optionLables is None:
         optionLables = options
-    isActive = {option: st.session_state[selectAllState] for option in optionLables}
+    isActive = {option: st.session_state[selectAllState]
+                for option in optionLables}
     for i, option in enumerate(optionLables):
         isActive[option] = columns[i % col].checkbox(
             optionLables[i],
```
**vectordb_bench/frontend/components/check_results/priceTable.py**

```diff
@@ -3,7 +3,7 @@ import pandas as pd
 from collections import defaultdict
 import streamlit as st

-from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
+from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE


 def priceTable(container, data):
```
**vectordb_bench/frontend/components/concurrent/charts.py**

```diff
@@ -1,26 +1,27 @@
-
-
-
-from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
+from vectordb_bench.frontend.components.check_results.expanderStyle import (
+    initMainExpanderStyle,
+)
 import plotly.express as px

-from vectordb_bench.frontend.
+from vectordb_bench.frontend.config.styles import COLOR_MAP


-def drawChartsByCase(allData,
+def drawChartsByCase(allData, showCaseNames: list[str], st):
     initMainExpanderStyle(st)
-    for
-    chartContainer = st.expander(
-    caseDataList = [
-
-
-
-
-
-
-
-
-
+    for caseName in showCaseNames:
+        chartContainer = st.expander(caseName, True)
+        caseDataList = [data for data in allData if data["case_name"] == caseName]
+        data = [
+            {
+                "conc_num": caseData["conc_num_list"][i],
+                "qps": caseData["conc_qps_list"][i],
+                "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
+                "db_name": caseData["db_name"],
+                "db": caseData["db"],
+            }
+            for caseData in caseDataList
+            for i in range(len(caseData["conc_num_list"]))
+        ]
         drawChart(data, chartContainer)


```
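The nested comprehension flattens each case's parallel per-concurrency lists into one row per (db, concurrency) pair, converting latency from seconds to milliseconds. A worked example with invented numbers:

```python
# Invented sample: one DB measured at concurrency 1 and 8.
caseDataList = [
    {
        "db_name": "db-A",
        "db": "DBA",
        "conc_num_list": [1, 8],
        "conc_qps_list": [120.0, 640.0],
        "conc_latency_p99_list": [0.004, 0.021],  # seconds
    }
]
data = [
    {
        "conc_num": caseData["conc_num_list"][i],
        "qps": caseData["conc_qps_list"][i],
        "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,  # -> ms
        "db_name": caseData["db_name"],
        "db": caseData["db"],
    }
    for caseData in caseDataList
    for i in range(len(caseData["conc_num_list"]))
]
# data[0] == {"conc_num": 1, "qps": 120.0, "latency_p99": 4.0, ...}
# data[1] == {"conc_num": 8, "qps": 640.0, "latency_p99": 21.0, ...}
```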
```diff
@@ -38,7 +39,7 @@ def getRange(metric, data, padding_multipliers):
 def drawChart(data, st):
     if len(data) == 0:
         return
-
+
     x = "latency_p99"
     xrange = getRange(x, data, [0.05, 0.1])

```
```diff
@@ -63,7 +64,6 @@ def drawChart(data, st):
         line_group=line_group,
         text=text,
         markers=True,
-        # color_discrete_map=color_discrete_map,
         hover_data={
             "conc_num": True,
         },
```
```diff
@@ -71,12 +71,9 @@ def drawChart(data, st):
     )
     fig.update_xaxes(range=xrange, title_text="Latency P99 (ms)")
     fig.update_yaxes(range=yrange, title_text="QPS")
-    fig.update_traces(textposition="bottom right",
-
-
-
-
-
-    # ),
-    # )
-    st.plotly_chart(fig, use_container_width=True,)
+    fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}")
+
+    st.plotly_chart(
+        fig,
+        use_container_width=True,
+    )
```