vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +19 -5
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +93 -27
- vectordb_bench/backend/clients/__init__.py +14 -0
- vectordb_bench/backend/clients/api.py +1 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/clients/milvus/cli.py +291 -0
- vectordb_bench/backend/clients/milvus/milvus.py +13 -6
- vectordb_bench/backend/clients/pgvector/cli.py +116 -0
- vectordb_bench/backend/clients/pgvector/config.py +1 -1
- vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
- vectordb_bench/backend/clients/redis/cli.py +74 -0
- vectordb_bench/backend/clients/test/cli.py +25 -0
- vectordb_bench/backend/clients/test/config.py +18 -0
- vectordb_bench/backend/clients/test/test.py +62 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/backend/runner/mp_runner.py +14 -3
- vectordb_bench/backend/runner/serial_runner.py +7 -3
- vectordb_bench/backend/task_runner.py +76 -26
- vectordb_bench/cli/__init__.py +0 -0
- vectordb_bench/cli/cli.py +362 -0
- vectordb_bench/cli/vectordbbench.py +22 -0
- vectordb_bench/config-files/sample_config.yml +17 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +23 -20
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +79 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
- vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
- vectordb_bench/frontend/components/tables/data.py +44 -0
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +65 -0
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +24 -0
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/interface.py +21 -25
- vectordb_bench/metric.py +23 -1
- vectordb_bench/models.py +45 -1
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
- vectordb_bench-0.0.12.dist-info/RECORD +115 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
- vectordb_bench-0.0.10.dist-info/RECORD +0 -88
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,362 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from concurrent.futures import wait
|
4
|
+
from datetime import datetime
|
5
|
+
from pprint import pformat
|
6
|
+
from typing import (
|
7
|
+
Annotated,
|
8
|
+
Callable,
|
9
|
+
List,
|
10
|
+
Optional,
|
11
|
+
Type,
|
12
|
+
TypedDict,
|
13
|
+
Unpack,
|
14
|
+
get_origin,
|
15
|
+
get_type_hints,
|
16
|
+
Dict,
|
17
|
+
Any,
|
18
|
+
)
|
19
|
+
import click
|
20
|
+
from .. import config
|
21
|
+
from ..backend.clients import DB
|
22
|
+
from ..interface import benchMarkRunner, global_result_future
|
23
|
+
from ..models import (
|
24
|
+
CaseConfig,
|
25
|
+
CaseType,
|
26
|
+
ConcurrencySearchConfig,
|
27
|
+
DBCaseConfig,
|
28
|
+
DBConfig,
|
29
|
+
TaskConfig,
|
30
|
+
TaskStage,
|
31
|
+
)
|
32
|
+
import os
|
33
|
+
from yaml import load
|
34
|
+
try:
|
35
|
+
from yaml import CLoader as Loader
|
36
|
+
except ImportError:
|
37
|
+
from yaml import Loader
|
38
|
+
|
39
|
+
|
40
|
+
def click_get_defaults_from_file(ctx, param, value):
    """Eager click callback: preload option defaults from a YAML config file.

    Tries *value* as a literal path first and falls back to looking it up
    under ``config.CONFIG_LOCAL_DIR``.  The YAML document is expected to map
    command names to ``{option_name: value}`` dictionaries; the section for
    the currently running command becomes ``ctx.default_map`` so file values
    act as defaults that explicit CLI flags still override.

    Args:
        ctx: click context of the command being invoked.
        param: the triggering click option (unused, required by click).
        value: path passed to ``--config-file``; may be ``None``/empty.

    Returns:
        The original ``value`` unchanged, per click callback convention.

    Raises:
        click.BadParameter: if the file cannot be read or parsed.
    """
    if value:
        if os.path.exists(value):
            input_file = value
        else:
            input_file = os.path.join(config.CONFIG_LOCAL_DIR, value)
        try:
            # NOTE(review): uses the full YAML loader; config files are
            # assumed to be trusted local input — confirm before accepting
            # paths from untrusted sources.
            with open(input_file, 'r') as f:
                _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader)
            ctx.default_map = _config.get(ctx.command.name, {})
        except Exception as e:
            # Chain the original error so the YAML/IO traceback is preserved
            # instead of being swallowed by the BadParameter wrapper.
            raise click.BadParameter(f"Failed to load config file: {e}") from e
    return value
|
53
|
+
|
54
|
+
|
55
|
+
def click_parameter_decorators_from_typed_dict(
    typed_dict: Type,
) -> Callable[[click.decorators.FC], click.decorators.FC]:
    """Build one composite decorator from a TypedDict of Annotated click options.

    Each key of *typed_dict* must be declared as
    ``Annotated[<type>, click.option(...)/click.argument(...)]`` with exactly
    one piece of metadata, and that metadata object must come from
    ``click.decorators``.  The collected decorators are re-applied (in
    reverse, mirroring normal stacking order) by the returned decorator.

    The TypedDict keys only document the expected parameter types; the actual
    function parameters are controlled by the click.option definitions, so
    keep the two aligned manually.

    Example:
        ```
        class FooTypedDict(TypedDict):
            x: Annotated[int, click.option("--x", type=int, default=1)]

        @cli.command()
        @click_parameter_decorators_from_typed_dict(FooTypedDict)
        def foo(**parameters: Unpack[FooTypedDict]):
            ...
        ```

    Args:
        typed_dict: a TypedDict whose values are Annotated click decorators.

    Returns:
        A decorator applying every collected click decorator to its target.

    Raises:
        RuntimeError: if any annotation is not exactly one click decorator.
    """
    collected = []
    for hint in get_type_hints(typed_dict, include_extras=True).values():
        assert get_origin(hint) is Annotated
        metadata = hint.__metadata__
        # Reject anything that is not exactly Annotated[..., <click decorator>].
        if len(metadata) != 1 or metadata[0].__module__ != "click.decorators":
            raise RuntimeError(
                "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring"
            )
        collected.append(metadata[0])

    def apply_all(f):
        for decorator in reversed(collected):
            f = decorator(f)
        return f

    return apply_all
|
109
|
+
|
110
|
+
|
111
|
+
def click_arg_split(ctx: click.Context, param: click.core.Option, value):
    """Split a comma-separated option value into a list of trimmed strings.

    Args:
        ctx: unused click context.
        param: unused click option.
        value: raw comma-separated string (or None).

    Returns:
        List[str]: the non-empty, whitespace-trimmed pieces; [] for None.
    """
    if value is None:
        return []
    trimmed = (piece.strip() for piece in value.split(","))
    return [piece for piece in trimmed if piece]
|
126
|
+
|
127
|
+
|
128
|
+
def parse_task_stages(
    drop_old: bool,
    load: bool,
    search_serial: bool,
    search_concurrent: bool,
) -> List[TaskStage]:
    """Translate CLI stage flags into the ordered list of TaskStage values.

    Drop-old and load must be enabled together: loading without dropping (or
    dropping without loading) is rejected before any stage is collected.

    Args:
        drop_old: drop pre-existing data before loading.
        load: load the dataset.
        search_serial: run the serial search stage.
        search_concurrent: run the concurrent search stage.

    Returns:
        Stages in canonical execution order (drop, load, serial, concurrent).

    Raises:
        RuntimeError: if drop_old and load disagree.
    """
    if load and not drop_old:
        raise RuntimeError("Dropping old data cannot be skipped if loading data")
    if drop_old and not load:
        raise RuntimeError("Load cannot be skipped if dropping old data")
    selected: List[TaskStage] = []
    # Stage members are resolved lazily so nothing is touched for disabled flags.
    for enabled, stage_attr in (
        (drop_old, "DROP_OLD"),
        (load, "LOAD"),
        (search_serial, "SEARCH_SERIAL"),
        (search_concurrent, "SEARCH_CONCURRENT"),
    ):
        if enabled:
            selected.append(getattr(TaskStage, stage_attr))
    return selected
|
148
|
+
|
149
|
+
|
150
|
+
log = logging.getLogger(__name__)
|
151
|
+
|
152
|
+
|
153
|
+
class CommonTypedDict(TypedDict):
    """Options shared by every vectordbbench command.

    Each field pairs the expected Python type with the click decorator that
    produces it (consumed by ``click_parameter_decorators_from_typed_dict``).
    """

    # Eager, non-exposed option: only triggers the YAML-defaults callback.
    # (Declared bool for the TypedDict, but the value is a click.Path string.)
    config_file: Annotated[
        bool,
        click.option('--config-file',
                     type=click.Path(),
                     callback=click_get_defaults_from_file,
                     is_eager=True,
                     expose_value=False,
                     help='Read configuration from yaml file'),
    ]
    drop_old: Annotated[
        bool,
        click.option(
            "--drop-old/--skip-drop-old",
            type=bool,
            default=True,
            help="Drop old or skip",
            show_default=True,
        ),
    ]
    load: Annotated[
        bool,
        click.option(
            "--load/--skip-load",
            type=bool,
            default=True,
            help="Load or skip",
            show_default=True,
        ),
    ]
    search_serial: Annotated[
        bool,
        click.option(
            "--search-serial/--skip-search-serial",
            type=bool,
            default=True,
            help="Search serial or skip",
            show_default=True,
        ),
    ]
    search_concurrent: Annotated[
        bool,
        click.option(
            "--search-concurrent/--skip-search-concurrent",
            type=bool,
            default=True,
            help="Search concurrent or skip",
            show_default=True,
        ),
    ]
    case_type: Annotated[
        str,
        click.option(
            "--case-type",
            type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]),
            default="Performance1536D50K",
            help="Case type",
        ),
    ]
    db_label: Annotated[
        str,
        click.option(
            "--db-label", type=str, help="Db label, default: date in ISO format",
            show_default=True,
            # Callable default: evaluated per invocation instead of once at
            # import time, so each run gets a fresh timestamp as promised by
            # the help text.
            default=lambda: datetime.now().isoformat()
        ),
    ]
    dry_run: Annotated[
        bool,
        click.option(
            "--dry-run",
            type=bool,
            default=False,
            is_flag=True,
            help="Print just the configuration and exit without running the tasks",
        ),
    ]
    k: Annotated[
        int,
        click.option(
            "--k",
            type=int,
            default=config.K_DEFAULT,
            show_default=True,
            help="K value for number of nearest neighbors to search",
        ),
    ]
    concurrency_duration: Annotated[
        int,
        click.option(
            "--concurrency-duration",
            type=int,
            default=config.CONCURRENCY_DURATION,
            show_default=True,
            help="Adjusts the duration in seconds of each concurrency search",
        ),
    ]
    # Declared List[str], but the callback converts the comma-separated input
    # to a list of ints before the command body sees it.
    num_concurrency: Annotated[
        List[str],
        click.option(
            "--num-concurrency",
            type=str,
            help="Comma-separated list of concurrency values to test during concurrent search",
            show_default=True,
            default=",".join(map(str, config.NUM_CONCURRENCY)),
            callback=lambda *args: list(map(int, click_arg_split(*args))),
        ),
    ]
|
261
|
+
|
262
|
+
|
263
|
+
class HNSWBaseTypedDict(TypedDict):
    """Optional HNSW index build parameters (m, ef_construction)."""
    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")]
    ef_construction: Annotated[
        Optional[int],
        click.option("--ef-construction", type=int, help="hnsw ef-construction"),
    ]
|
269
|
+
|
270
|
+
|
271
|
+
class HNSWBaseRequiredTypedDict(TypedDict):
    """HNSW index build parameters, required on the command line."""
    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)]
    ef_construction: Annotated[
        Optional[int],
        click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True),
    ]
|
277
|
+
|
278
|
+
|
279
|
+
class HNSWFlavor1(HNSWBaseTypedDict):
    """HNSW variant exposing the search-time knob as ``--ef-search`` (optional)."""
    ef_search: Annotated[
        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search")
    ]
|
283
|
+
|
284
|
+
|
285
|
+
class HNSWFlavor2(HNSWBaseTypedDict):
    """HNSW variant exposing the search-time knob as ``--ef-runtime`` (optional)."""
    ef_runtime: Annotated[
        Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime")
    ]
|
289
|
+
|
290
|
+
|
291
|
+
class HNSWFlavor3(HNSWBaseRequiredTypedDict):
    """HNSW variant where ``--ef-search`` is required, like its base params."""
    ef_search: Annotated[
        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True)
    ]
|
295
|
+
|
296
|
+
|
297
|
+
class IVFFlatTypedDict(TypedDict):
    """Optional IVF-Flat index parameters (lists, probes)."""
    lists: Annotated[
        Optional[int], click.option("--lists", type=int, help="ivfflat lists")
    ]
    probes: Annotated[
        Optional[int], click.option("--probes", type=int, help="ivfflat probes")
    ]
|
304
|
+
|
305
|
+
|
306
|
+
class IVFFlatTypedDictN(TypedDict):
    """Required IVF-Flat parameters; same ``--lists``/``--probes`` flags but
    mapped onto ``nlist``/``nprobe`` parameter names."""
    nlist: Annotated[
        Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True)
    ]
    nprobe: Annotated[
        Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True)
    ]
|
313
|
+
|
314
|
+
|
315
|
+
@click.group()
def cli():
    """Root click group; database-specific commands are attached elsewhere
    via ``cli.add_command(...)``."""
    ...
|
318
|
+
|
319
|
+
|
320
|
+
def run(
    db: DB,
    db_config: DBConfig,
    db_case_config: DBCaseConfig,
    **parameters: Unpack[CommonTypedDict],
):
    """Build a single VectorDBBench TaskConfig and run it to completion.

    Args:
        db (DB): target database enum member.
        db_config (DBConfig): connection/credential configuration.
        db_case_config (DBCaseConfig): index/search configuration.
        **parameters: keys from CommonTypedDict (case_type, k, stage flags, ...).
    """
    concurrency_cfg = ConcurrencySearchConfig(
        concurrency_duration=parameters["concurrency_duration"],
        num_concurrency=[int(s) for s in parameters["num_concurrency"]],
    )
    case_cfg = CaseConfig(
        case_id=CaseType[parameters["case_type"]],
        k=parameters["k"],
        concurrency_search_config=concurrency_cfg,
    )
    # Only drop old data if new data will be loaded.
    effective_drop_old = parameters["drop_old"] if parameters["load"] else False
    stages = parse_task_stages(
        effective_drop_old,
        parameters["load"],
        parameters["search_serial"],
        parameters["search_concurrent"],
    )
    task = TaskConfig(
        db=db,
        db_config=db_config,
        db_case_config=db_case_config,
        case_config=case_cfg,
        stages=stages,
    )

    log.info(f"Task:\n{pformat(task)}\n")
    if parameters["dry_run"]:
        return
    benchMarkRunner.run([task])
    time.sleep(5)
    if global_result_future:
        wait([global_result_future])
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from ..backend.clients.pgvector.cli import PgVectorHNSW
|
2
|
+
from ..backend.clients.redis.cli import Redis
|
3
|
+
from ..backend.clients.test.cli import Test
|
4
|
+
from ..backend.clients.weaviate_cloud.cli import Weaviate
|
5
|
+
from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
|
6
|
+
from ..backend.clients.milvus.cli import MilvusAutoIndex
|
7
|
+
from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
|
8
|
+
|
9
|
+
|
10
|
+
from .cli import cli
|
11
|
+
|
12
|
+
# Register each database-specific command on the shared CLI group so that
# `vectordbbench <command>` dispatches to the right client.
cli.add_command(PgVectorHNSW)
cli.add_command(Redis)
cli.add_command(Weaviate)
cli.add_command(Test)
cli.add_command(ZillizAutoIndex)
cli.add_command(MilvusAutoIndex)
cli.add_command(AWSOpenSearch)


if __name__ == "__main__":
    cli()
|
@@ -0,0 +1,17 @@
|
|
1
|
+
pgvectorhnsw:
|
2
|
+
db_label: pgConfigTest
|
3
|
+
user_name: vectordbbench
|
4
|
+
db_name: vectordbbench
|
5
|
+
host: localhost
|
6
|
+
m: 16
|
7
|
+
ef_construction: 128
|
8
|
+
ef_search: 128
|
9
|
+
milvushnsw:
|
10
|
+
skip_search_serial: True
|
11
|
+
case_type: Performance1536D50K
|
12
|
+
uri: http://localhost:19530
|
13
|
+
m: 16
|
14
|
+
ef_construction: 128
|
15
|
+
ef_search: 128
|
16
|
+
drop_old: False
|
17
|
+
load: False
|
@@ -0,0 +1,18 @@
|
|
1
|
+
[
|
2
|
+
{
|
3
|
+
"name": "My Dataset (Performance Case)",
|
4
|
+
"description": "this is a customized dataset.",
|
5
|
+
"load_timeout": 36000,
|
6
|
+
"optimize_timeout": 36000,
|
7
|
+
"dataset_config": {
|
8
|
+
"name": "My Dataset",
|
9
|
+
"dir": "/my_dataset_path",
|
10
|
+
"size": 1000000,
|
11
|
+
"dim": 1024,
|
12
|
+
"metric_type": "L2",
|
13
|
+
"file_count": 1,
|
14
|
+
"use_shuffled": false,
|
15
|
+
"with_gt": true
|
16
|
+
}
|
17
|
+
}
|
18
|
+
]
|
@@ -1,19 +1,19 @@
|
|
1
1
|
from vectordb_bench.backend.cases import Case
|
2
2
|
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
|
3
3
|
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
|
4
|
-
from vectordb_bench.frontend.
|
4
|
+
from vectordb_bench.frontend.config.styles import *
|
5
5
|
from vectordb_bench.models import ResultLabel
|
6
6
|
import plotly.express as px
|
7
7
|
|
8
8
|
|
9
|
-
def drawCharts(st, allData, failedTasks,
|
9
|
+
def drawCharts(st, allData, failedTasks, caseNames: list[str]):
|
10
10
|
initMainExpanderStyle(st)
|
11
|
-
for
|
12
|
-
chartContainer = st.expander(
|
13
|
-
data = [data for data in allData if data["case_name"] ==
|
11
|
+
for caseName in caseNames:
|
12
|
+
chartContainer = st.expander(caseName, True)
|
13
|
+
data = [data for data in allData if data["case_name"] == caseName]
|
14
14
|
drawChart(data, chartContainer)
|
15
15
|
|
16
|
-
errorDBs = failedTasks[
|
16
|
+
errorDBs = failedTasks[caseName]
|
17
17
|
showFailedDBs(chartContainer, errorDBs)
|
18
18
|
|
19
19
|
|
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
|
|
8
8
|
def getChartData(
|
9
9
|
tasks: list[CaseResult],
|
10
10
|
dbNames: list[str],
|
11
|
-
|
11
|
+
caseNames: list[str],
|
12
12
|
):
|
13
|
-
filterTasks = getFilterTasks(tasks, dbNames,
|
13
|
+
filterTasks = getFilterTasks(tasks, dbNames, caseNames)
|
14
14
|
mergedTasks, failedTasks = mergeTasks(filterTasks)
|
15
15
|
return mergedTasks, failedTasks
|
16
16
|
|
@@ -18,14 +18,13 @@ def getChartData(
|
|
18
18
|
def getFilterTasks(
|
19
19
|
tasks: list[CaseResult],
|
20
20
|
dbNames: list[str],
|
21
|
-
|
21
|
+
caseNames: list[str],
|
22
22
|
) -> list[CaseResult]:
|
23
|
-
case_ids = [case.case_id for case in cases]
|
24
23
|
filterTasks = [
|
25
24
|
task
|
26
25
|
for task in tasks
|
27
26
|
if task.task_config.db_name in dbNames
|
28
|
-
and task.task_config.case_config.case_id in
|
27
|
+
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
|
29
28
|
]
|
30
29
|
return filterTasks
|
31
30
|
|
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
36
35
|
db_name = task.task_config.db_name
|
37
36
|
db = task.task_config.db.value
|
38
37
|
db_label = task.task_config.db_config.db_label or ""
|
39
|
-
|
40
|
-
dbCaseMetricsMap[db_name][
|
38
|
+
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
|
39
|
+
dbCaseMetricsMap[db_name][case.name] = {
|
41
40
|
"db": db,
|
42
41
|
"db_label": db_label,
|
43
42
|
"metrics": mergeMetrics(
|
44
|
-
dbCaseMetricsMap[db_name][
|
43
|
+
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
|
45
44
|
asdict(task.metrics),
|
46
45
|
),
|
47
46
|
"label": getBetterLabel(
|
48
|
-
dbCaseMetricsMap[db_name][
|
47
|
+
dbCaseMetricsMap[db_name][case.name].get(
|
48
|
+
"label", ResultLabel.FAILED),
|
49
49
|
task.label,
|
50
50
|
),
|
51
51
|
}
|
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
53
53
|
mergedTasks = []
|
54
54
|
failedTasks = defaultdict(lambda: defaultdict(str))
|
55
55
|
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
|
56
|
-
for
|
56
|
+
for case_name, metricInfo in caseMetricsMap.items():
|
57
57
|
metrics = metricInfo["metrics"]
|
58
58
|
db = metricInfo["db"]
|
59
59
|
db_label = metricInfo["db_label"]
|
60
60
|
label = metricInfo["label"]
|
61
|
-
case_name = case_id.case_name
|
62
61
|
if label == ResultLabel.NORMAL:
|
63
62
|
mergedTasks.append(
|
64
63
|
{
|
@@ -80,22 +79,26 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
|
|
80
79
|
metrics = {**metrics_1}
|
81
80
|
for key, value in metrics_2.items():
|
82
81
|
metrics[key] = (
|
83
|
-
getBetterMetric(
|
82
|
+
getBetterMetric(
|
83
|
+
key, value, metrics[key]) if key in metrics else value
|
84
84
|
)
|
85
85
|
|
86
86
|
return metrics
|
87
87
|
|
88
88
|
|
89
89
|
def getBetterMetric(metric, value_1, value_2):
|
90
|
-
|
91
|
-
|
92
|
-
|
90
|
+
try:
|
91
|
+
if value_1 < 1e-7:
|
92
|
+
return value_2
|
93
|
+
if value_2 < 1e-7:
|
94
|
+
return value_1
|
95
|
+
return (
|
96
|
+
min(value_1, value_2)
|
97
|
+
if isLowerIsBetterMetric(metric)
|
98
|
+
else max(value_1, value_2)
|
99
|
+
)
|
100
|
+
except Exception:
|
93
101
|
return value_1
|
94
|
-
return (
|
95
|
-
min(value_1, value_2)
|
96
|
-
if isLowerIsBetterMetric(metric)
|
97
|
-
else max(value_1, value_2)
|
98
|
-
)
|
99
102
|
|
100
103
|
|
101
104
|
def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel):
|
@@ -1,7 +1,7 @@
|
|
1
1
|
def initMainExpanderStyle(st):
|
2
2
|
st.markdown(
|
3
3
|
"""<style>
|
4
|
-
.main
|
4
|
+
.main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
|
5
5
|
.main div[data-testid='stExpander'] {
|
6
6
|
background-color: #F6F8FA;
|
7
7
|
border: 1px solid #A9BDD140;
|
@@ -1,8 +1,8 @@
|
|
1
1
|
from vectordb_bench.backend.cases import Case
|
2
2
|
from vectordb_bench.frontend.components.check_results.data import getChartData
|
3
3
|
from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
|
4
|
-
from vectordb_bench.frontend.
|
5
|
-
from vectordb_bench.frontend.
|
4
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
|
5
|
+
from vectordb_bench.frontend.config.styles import *
|
6
6
|
import streamlit as st
|
7
7
|
|
8
8
|
from vectordb_bench.models import CaseResult, TestResult
|
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
|
|
18
18
|
st.header("Filters")
|
19
19
|
|
20
20
|
shownResults = getshownResults(results, st)
|
21
|
-
showDBNames,
|
21
|
+
showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
|
22
22
|
|
23
|
-
shownData, failedTasks = getChartData(
|
23
|
+
shownData, failedTasks = getChartData(
|
24
|
+
shownResults, showDBNames, showCaseNames)
|
24
25
|
|
25
|
-
return shownData, failedTasks,
|
26
|
+
return shownData, failedTasks, showCaseNames
|
26
27
|
|
27
28
|
|
28
29
|
def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
|
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
|
|
52
53
|
return selectedResult
|
53
54
|
|
54
55
|
|
55
|
-
def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[
|
56
|
+
def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
|
56
57
|
initSidebarExanderStyle(st)
|
57
58
|
allDbNames = list(set({res.task_config.db_name for res in result}))
|
58
59
|
allDbNames.sort()
|
59
|
-
|
60
|
-
|
60
|
+
allCases: list[Case] = [
|
61
|
+
res.task_config.case_config.case_id.case_cls(
|
62
|
+
res.task_config.case_config.custom_case)
|
63
|
+
for res in result
|
64
|
+
]
|
65
|
+
allCaseNameSet = set({case.name for case in allCases})
|
66
|
+
allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
|
67
|
+
[case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
|
61
68
|
|
62
69
|
# DB Filter
|
63
70
|
dbFilterContainer = st.container()
|
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca
|
|
70
77
|
|
71
78
|
# Case Filter
|
72
79
|
caseFilterContainer = st.container()
|
73
|
-
|
80
|
+
showCaseNames = filterView(
|
74
81
|
caseFilterContainer,
|
75
82
|
"Case Filter",
|
76
|
-
[
|
83
|
+
[caseName for caseName in allCaseNames],
|
77
84
|
col=1,
|
78
|
-
optionLables=[case.name for case in allCases],
|
79
85
|
)
|
80
86
|
|
81
|
-
return showDBNames,
|
87
|
+
return showDBNames, showCaseNames
|
82
88
|
|
83
89
|
|
84
90
|
def filterView(container, header, options, col, optionLables=None):
|
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
|
|
114
120
|
)
|
115
121
|
if optionLables is None:
|
116
122
|
optionLables = options
|
117
|
-
isActive = {option: st.session_state[selectAllState]
|
123
|
+
isActive = {option: st.session_state[selectAllState]
|
124
|
+
for option in optionLables}
|
118
125
|
for i, option in enumerate(optionLables):
|
119
126
|
isActive[option] = columns[i % col].checkbox(
|
120
127
|
optionLables[i],
|
@@ -3,7 +3,7 @@ import pandas as pd
|
|
3
3
|
from collections import defaultdict
|
4
4
|
import streamlit as st
|
5
5
|
|
6
|
-
from vectordb_bench.frontend.
|
6
|
+
from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
|
7
7
|
|
8
8
|
|
9
9
|
def priceTable(container, data):
|