vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. vectordb_bench/__init__.py +19 -5
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +93 -27
  4. vectordb_bench/backend/clients/__init__.py +14 -0
  5. vectordb_bench/backend/clients/api.py +1 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  9. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  10. vectordb_bench/backend/clients/milvus/cli.py +291 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +13 -6
  12. vectordb_bench/backend/clients/pgvector/cli.py +116 -0
  13. vectordb_bench/backend/clients/pgvector/config.py +1 -1
  14. vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
  15. vectordb_bench/backend/clients/redis/cli.py +74 -0
  16. vectordb_bench/backend/clients/test/cli.py +25 -0
  17. vectordb_bench/backend/clients/test/config.py +18 -0
  18. vectordb_bench/backend/clients/test/test.py +62 -0
  19. vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
  20. vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
  21. vectordb_bench/backend/dataset.py +27 -5
  22. vectordb_bench/backend/runner/mp_runner.py +14 -3
  23. vectordb_bench/backend/runner/serial_runner.py +7 -3
  24. vectordb_bench/backend/task_runner.py +76 -26
  25. vectordb_bench/cli/__init__.py +0 -0
  26. vectordb_bench/cli/cli.py +362 -0
  27. vectordb_bench/cli/vectordbbench.py +22 -0
  28. vectordb_bench/config-files/sample_config.yml +17 -0
  29. vectordb_bench/custom/custom_case.json +18 -0
  30. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  31. vectordb_bench/frontend/components/check_results/data.py +23 -20
  32. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  33. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  34. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  35. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  36. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  37. vectordb_bench/frontend/components/concurrent/charts.py +79 -0
  38. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  39. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  40. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  41. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  42. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  43. vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
  44. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
  45. vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
  46. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  47. vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
  48. vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
  49. vectordb_bench/frontend/components/tables/data.py +44 -0
  50. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
  51. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  52. vectordb_bench/frontend/pages/concurrent.py +65 -0
  53. vectordb_bench/frontend/pages/custom.py +64 -0
  54. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  55. vectordb_bench/frontend/pages/run_test.py +4 -0
  56. vectordb_bench/frontend/pages/tables.py +24 -0
  57. vectordb_bench/frontend/utils.py +17 -1
  58. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  59. vectordb_bench/interface.py +21 -25
  60. vectordb_bench/metric.py +23 -1
  61. vectordb_bench/models.py +45 -1
  62. vectordb_bench/results/getLeaderboardData.py +1 -1
  63. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
  64. vectordb_bench-0.0.12.dist-info/RECORD +115 -0
  65. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
  66. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
  67. vectordb_bench-0.0.10.dist-info/RECORD +0 -88
  68. /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  69. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
  70. {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,362 @@
1
+ import logging
2
+ import time
3
+ from concurrent.futures import wait
4
+ from datetime import datetime
5
+ from pprint import pformat
6
+ from typing import (
7
+ Annotated,
8
+ Callable,
9
+ List,
10
+ Optional,
11
+ Type,
12
+ TypedDict,
13
+ Unpack,
14
+ get_origin,
15
+ get_type_hints,
16
+ Dict,
17
+ Any,
18
+ )
19
+ import click
20
+ from .. import config
21
+ from ..backend.clients import DB
22
+ from ..interface import benchMarkRunner, global_result_future
23
+ from ..models import (
24
+ CaseConfig,
25
+ CaseType,
26
+ ConcurrencySearchConfig,
27
+ DBCaseConfig,
28
+ DBConfig,
29
+ TaskConfig,
30
+ TaskStage,
31
+ )
32
+ import os
33
+ from yaml import load
34
+ try:
35
+ from yaml import CLoader as Loader
36
+ except ImportError:
37
+ from yaml import Loader
38
+
39
+
40
def click_get_defaults_from_file(ctx, param, value):
    """Click eager callback: seed ``ctx.default_map`` from a YAML config file.

    The value is used as a path; if it does not exist as given, it is
    resolved relative to ``config.CONFIG_LOCAL_DIR``.  The YAML document is
    expected to map command names to ``{option: value}`` dicts; only the
    section matching the current command is applied.

    Args:
        ctx: click context (its ``default_map`` is mutated).
        param: unused click parameter object.
        value: path to the YAML file, or falsy to do nothing.

    Returns:
        the unmodified ``value`` so click stores it as given.

    Raises:
        click.BadParameter: if the file cannot be read or parsed.
    """
    if value:
        if os.path.exists(value):
            input_file = value
        else:
            input_file = os.path.join(config.CONFIG_LOCAL_DIR, value)
        try:
            with open(input_file, 'r') as f:
                _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader)
            # An empty YAML file parses to None; fall back to an empty mapping
            # instead of letting `.get` raise AttributeError below.
            ctx.default_map = (_config or {}).get(ctx.command.name, {})
        except Exception as e:
            raise click.BadParameter(f"Failed to load config file: {e}")
    return value
53
+
54
+
55
def click_parameter_decorators_from_typed_dict(
    typed_dict: Type,
) -> Callable[[click.decorators.FC], click.decorators.FC]:
    """A convenience method decorator that will read in a TypedDict with parameters defined by Annotated types.

    The click.options will be collected and re-composed as a single decorator to apply to the click.command.

    Args:
        typed_dict (TypedDict) with Annotated[..., click.option()] keys

    Returns:
        a fully decorated method

    For clarity, the key names of the TypedDict will be used to determine the type hints for the input parameters.
    The actual function parameters are controlled by the click.option definitions. You must manually ensure these are aligned in a sensible way!

    Example:
    ```
    class CommonTypedDict(TypedDict):
        z: Annotated[int, click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True)]
        name: Annotated[str, click.argument("name", required=False, default="Jeff")]

    class FooTypedDict(CommonTypedDict):
        x: Annotated[int, click.option("--x", type=int, help="help x", default=1, show_default=True)]
        y: Annotated[str, click.option("--y", type=str, help="help y", default="foo", show_default=True)]

    @cli.command()
    @click_parameter_decorators_from_typed_dict(FooTypedDict)
    def foo(**parameters: Unpack[FooTypedDict]):
        "Foo docstring"
        print(f"input parameters: {parameters['x']}")
    ```
    """
    decorators = []
    for _, t in get_type_hints(typed_dict, include_extras=True).items():
        assert get_origin(t) is Annotated
        if (
            len(t.__metadata__) == 1
            and t.__metadata__[0].__module__ == "click.decorators"
        ):
            # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1)
            decorators.append(t.__metadata__[0])
        else:
            raise RuntimeError(
                "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring"
            )

    def deco(f):
        # Apply in reverse so the first TypedDict key ends up as the
        # outermost decorator, matching hand-written decorator order.
        for dec in reversed(decorators):
            f = dec(f)
        return f

    return deco
109
+
110
+
111
def click_arg_split(ctx: click.Context, param: click.core.Option, value):
    """Turn a comma-separated option string into a list of tokens.

    Args:
        ctx (...): unused click arg
        param (...): unused click arg
        value (str): input comma-separated list

    Returns:
        value (List[str]): list of original
    """
    # None (option not given) maps to an empty list.
    if value is None:
        return []
    # Trim whitespace around each piece and discard empties.
    pieces = (piece.strip() for piece in value.split(","))
    return [piece for piece in pieces if piece]
126
+
127
+
128
def parse_task_stages(
    drop_old: bool,
    load: bool,
    search_serial: bool,
    search_concurrent: bool,
) -> List[TaskStage]:
    """Translate the four stage flags into an ordered list of TaskStages.

    Raises:
        RuntimeError: when drop_old and load disagree -- dropping old data
            and loading new data must be enabled (or skipped) together.
    """
    if load and not drop_old:
        raise RuntimeError("Dropping old data cannot be skipped if loading data")
    elif drop_old and not load:
        raise RuntimeError("Load cannot be skipped if dropping old data")
    # Fixed execution order: drop -> load -> serial search -> concurrent search.
    flag_stage_pairs = (
        (drop_old, TaskStage.DROP_OLD),
        (load, TaskStage.LOAD),
        (search_serial, TaskStage.SEARCH_SERIAL),
        (search_concurrent, TaskStage.SEARCH_CONCURRENT),
    )
    return [stage for enabled, stage in flag_stage_pairs if enabled]
148
+
149
+
150
+ log = logging.getLogger(__name__)
151
+
152
+
153
class CommonTypedDict(TypedDict):
    """Shared CLI options for every vectordbbench command.

    Each key is an ``Annotated[type, click.option(...)]`` entry consumed by
    ``click_parameter_decorators_from_typed_dict``.
    """
    # Eager, non-exposed option: pre-populates ctx.default_map from YAML
    # before other options are parsed.
    config_file: Annotated[
        bool,
        click.option('--config-file',
                     type=click.Path(),
                     callback=click_get_defaults_from_file,
                     is_eager=True,
                     expose_value=False,
                     help='Read configuration from yaml file'),
    ]
    drop_old: Annotated[
        bool,
        click.option(
            "--drop-old/--skip-drop-old",
            type=bool,
            default=True,
            help="Drop old or skip",
            show_default=True,
        ),
    ]
    load: Annotated[
        bool,
        click.option(
            "--load/--skip-load",
            type=bool,
            default=True,
            help="Load or skip",
            show_default=True,
        ),
    ]
    search_serial: Annotated[
        bool,
        click.option(
            "--search-serial/--skip-search-serial",
            type=bool,
            default=True,
            help="Search serial or skip",
            show_default=True,
        ),
    ]
    search_concurrent: Annotated[
        bool,
        click.option(
            "--search-concurrent/--skip-search-concurrent",
            type=bool,
            default=True,
            help="Search concurrent or skip",
            show_default=True,
        ),
    ]
    # The "Custom" case is excluded from the CLI choices.
    case_type: Annotated[
        str,
        click.option(
            "--case-type",
            type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]),
            default="Performance1536D50K",
            help="Case type",
        ),
    ]
    # NOTE(review): the default is evaluated once at import time, so every
    # invocation in the same process shares the same timestamp label.
    db_label: Annotated[
        str,
        click.option(
            "--db-label", type=str, help="Db label, default: date in ISO format",
            show_default=True,
            default=datetime.now().isoformat()
        ),
    ]
    dry_run: Annotated[
        bool,
        click.option(
            "--dry-run",
            type=bool,
            default=False,
            is_flag=True,
            help="Print just the configuration and exit without running the tasks",
        ),
    ]
    k: Annotated[
        int,
        click.option(
            "--k",
            type=int,
            default=config.K_DEFAULT,
            show_default=True,
            help="K value for number of nearest neighbors to search",
        ),
    ]
    concurrency_duration: Annotated[
        int,
        click.option(
            "--concurrency-duration",
            type=int,
            default=config.CONCURRENCY_DURATION,
            show_default=True,
            help="Adjusts the duration in seconds of each concurrency search",
        ),
    ]
    # Declared as List[str] but the callback converts to List[int].
    num_concurrency: Annotated[
        List[str],
        click.option(
            "--num-concurrency",
            type=str,
            help="Comma-separated list of concurrency values to test during concurrent search",
            show_default=True,
            default=",".join(map(str, config.NUM_CONCURRENCY)),
            callback=lambda *args: list(map(int, click_arg_split(*args))),
        ),
    ]
261
+
262
+
263
class HNSWBaseTypedDict(TypedDict):
    # Optional HNSW index build parameters shared by all HNSW option flavors.
    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")]
    ef_construction: Annotated[
        Optional[int],
        click.option("--ef-construction", type=int, help="hnsw ef-construction"),
    ]
269
+
270
+
271
class HNSWBaseRequiredTypedDict(TypedDict):
    # Same flags as HNSWBaseTypedDict, but both options are mandatory.
    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)]
    ef_construction: Annotated[
        Optional[int],
        click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True),
    ]
277
+
278
+
279
class HNSWFlavor1(HNSWBaseTypedDict):
    # HNSW options plus an optional --ef-search query-time parameter.
    ef_search: Annotated[
        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search")
    ]
283
+
284
+
285
class HNSWFlavor2(HNSWBaseTypedDict):
    # HNSW options plus an optional --ef-runtime query-time parameter
    # (for engines that name it "ef_runtime" rather than "ef_search").
    ef_runtime: Annotated[
        Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime")
    ]
289
+
290
+
291
class HNSWFlavor3(HNSWBaseRequiredTypedDict):
    # Fully-required HNSW option set: --m, --ef-construction and --ef-search
    # must all be supplied.
    ef_search: Annotated[
        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True)
    ]
295
+
296
+
297
class IVFFlatTypedDict(TypedDict):
    # Optional IVF-Flat index parameters.
    lists: Annotated[
        Optional[int], click.option("--lists", type=int, help="ivfflat lists")
    ]
    probes: Annotated[
        Optional[int], click.option("--probes", type=int, help="ivfflat probes")
    ]
304
+
305
+
306
class IVFFlatTypedDictN(TypedDict):
    # Required IVF-Flat parameters; the CLI flags stay --lists/--probes but
    # the parameter names are remapped to nlist/nprobe.
    nlist: Annotated[
        Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True)
    ]
    nprobe: Annotated[
        Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True)
    ]
313
+
314
+
315
@click.group()
def cli():
    # Root click group; database-specific subcommands are attached to it
    # elsewhere (see vectordbbench.py).  Intentionally no docstring: a
    # docstring would become the --help text.
    ...
318
+
319
+
320
def run(
    db: DB,
    db_config: DBConfig,
    db_case_config: DBCaseConfig,
    **parameters: Unpack[CommonTypedDict],
):
    """Builds a single VectorDBBench Task and runs it, awaiting the task until finished.

    Args:
        db (DB)
        db_config (DBConfig)
        db_case_config (DBCaseConfig)
        **parameters: expects keys from CommonTypedDict
    """

    task = TaskConfig(
        db=db,
        db_config=db_config,
        db_case_config=db_case_config,
        case_config=CaseConfig(
            # case_type is a CaseType member name coming from click.Choice.
            case_id=CaseType[parameters["case_type"]],
            k=parameters["k"],
            concurrency_search_config=ConcurrencySearchConfig(
                concurrency_duration=parameters["concurrency_duration"],
                # NOTE(review): the --num-concurrency callback already maps
                # the values to int, so this int() is a defensive no-op.
                num_concurrency=[int(s) for s in parameters["num_concurrency"]],
            ),
        ),
        stages=parse_task_stages(
            (
                False if not parameters["load"] else parameters["drop_old"]
            ),  # only drop old data if loading new data
            parameters["load"],
            parameters["search_serial"],
            parameters["search_concurrent"],
        ),
    )

    log.info(f"Task:\n{pformat(task)}\n")
    if not parameters["dry_run"]:
        benchMarkRunner.run([task])
        # Give the runner a moment to start and publish its result future
        # before blocking on it.  NOTE(review): timing-based handshake.
        time.sleep(5)
        if global_result_future:
            wait([global_result_future])
@@ -0,0 +1,22 @@
1
from ..backend.clients.pgvector.cli import PgVectorHNSW
from ..backend.clients.redis.cli import Redis
from ..backend.clients.test.cli import Test
from ..backend.clients.weaviate_cloud.cli import Weaviate
from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
from ..backend.clients.milvus.cli import MilvusAutoIndex
from ..backend.clients.aws_opensearch.cli import AWSOpenSearch


from .cli import cli

# Attach one subcommand per supported database client to the root group.
cli.add_command(PgVectorHNSW)
cli.add_command(Redis)
cli.add_command(Weaviate)
cli.add_command(Test)
cli.add_command(ZillizAutoIndex)
cli.add_command(MilvusAutoIndex)
cli.add_command(AWSOpenSearch)


# Allow running this module directly as well as via the console script.
if __name__ == "__main__":
    cli()
@@ -0,0 +1,17 @@
1
# Sample vectordbbench config file: one top-level key per CLI command name,
# each mapping option names (underscored form) to their default values.
# Loaded into click's ctx.default_map by --config-file.
pgvectorhnsw:
  db_label: pgConfigTest
  user_name: vectordbbench
  db_name: vectordbbench
  host: localhost
  m: 16
  ef_construction: 128
  ef_search: 128
milvushnsw:
  skip_search_serial: True
  case_type: Performance1536D50K
  uri: http://localhost:19530
  m: 16
  ef_construction: 128
  ef_search: 128
  drop_old: False
  load: False
@@ -0,0 +1,18 @@
1
+ [
2
+ {
3
+ "name": "My Dataset (Performace Case)",
4
+ "description": "this is a customized dataset.",
5
+ "load_timeout": 36000,
6
+ "optimize_timeout": 36000,
7
+ "dataset_config": {
8
+ "name": "My Dataset",
9
+ "dir": "/my_dataset_path",
10
+ "size": 1000000,
11
+ "dim": 1024,
12
+ "metric_type": "L2",
13
+ "file_count": 1,
14
+ "use_shuffled": false,
15
+ "with_gt": true
16
+ }
17
+ }
18
+ ]
@@ -1,19 +1,19 @@
1
1
  from vectordb_bench.backend.cases import Case
2
2
  from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
3
3
  from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
4
- from vectordb_bench.frontend.const.styles import *
4
+ from vectordb_bench.frontend.config.styles import *
5
5
  from vectordb_bench.models import ResultLabel
6
6
  import plotly.express as px
7
7
 
8
8
 
9
- def drawCharts(st, allData, failedTasks, cases: list[Case]):
9
+ def drawCharts(st, allData, failedTasks, caseNames: list[str]):
10
10
  initMainExpanderStyle(st)
11
- for case in cases:
12
- chartContainer = st.expander(case.name, True)
13
- data = [data for data in allData if data["case_name"] == case.name]
11
+ for caseName in caseNames:
12
+ chartContainer = st.expander(caseName, True)
13
+ data = [data for data in allData if data["case_name"] == caseName]
14
14
  drawChart(data, chartContainer)
15
15
 
16
- errorDBs = failedTasks[case.name]
16
+ errorDBs = failedTasks[caseName]
17
17
  showFailedDBs(chartContainer, errorDBs)
18
18
 
19
19
 
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
8
8
  def getChartData(
9
9
  tasks: list[CaseResult],
10
10
  dbNames: list[str],
11
- cases: list[Case],
11
+ caseNames: list[str],
12
12
  ):
13
- filterTasks = getFilterTasks(tasks, dbNames, cases)
13
+ filterTasks = getFilterTasks(tasks, dbNames, caseNames)
14
14
  mergedTasks, failedTasks = mergeTasks(filterTasks)
15
15
  return mergedTasks, failedTasks
16
16
 
@@ -18,14 +18,13 @@ def getChartData(
18
18
  def getFilterTasks(
19
19
  tasks: list[CaseResult],
20
20
  dbNames: list[str],
21
- cases: list[Case],
21
+ caseNames: list[str],
22
22
  ) -> list[CaseResult]:
23
- case_ids = [case.case_id for case in cases]
24
23
  filterTasks = [
25
24
  task
26
25
  for task in tasks
27
26
  if task.task_config.db_name in dbNames
28
- and task.task_config.case_config.case_id in case_ids
27
+ and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
29
28
  ]
30
29
  return filterTasks
31
30
 
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
36
35
  db_name = task.task_config.db_name
37
36
  db = task.task_config.db.value
38
37
  db_label = task.task_config.db_config.db_label or ""
39
- case_id = task.task_config.case_config.case_id
40
- dbCaseMetricsMap[db_name][case_id] = {
38
+ case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
39
+ dbCaseMetricsMap[db_name][case.name] = {
41
40
  "db": db,
42
41
  "db_label": db_label,
43
42
  "metrics": mergeMetrics(
44
- dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
43
+ dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
45
44
  asdict(task.metrics),
46
45
  ),
47
46
  "label": getBetterLabel(
48
- dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
47
+ dbCaseMetricsMap[db_name][case.name].get(
48
+ "label", ResultLabel.FAILED),
49
49
  task.label,
50
50
  ),
51
51
  }
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
53
53
  mergedTasks = []
54
54
  failedTasks = defaultdict(lambda: defaultdict(str))
55
55
  for db_name, caseMetricsMap in dbCaseMetricsMap.items():
56
- for case_id, metricInfo in caseMetricsMap.items():
56
+ for case_name, metricInfo in caseMetricsMap.items():
57
57
  metrics = metricInfo["metrics"]
58
58
  db = metricInfo["db"]
59
59
  db_label = metricInfo["db_label"]
60
60
  label = metricInfo["label"]
61
- case_name = case_id.case_name
62
61
  if label == ResultLabel.NORMAL:
63
62
  mergedTasks.append(
64
63
  {
@@ -80,22 +79,26 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
80
79
  metrics = {**metrics_1}
81
80
  for key, value in metrics_2.items():
82
81
  metrics[key] = (
83
- getBetterMetric(key, value, metrics[key]) if key in metrics else value
82
+ getBetterMetric(
83
+ key, value, metrics[key]) if key in metrics else value
84
84
  )
85
85
 
86
86
  return metrics
87
87
 
88
88
 
89
89
  def getBetterMetric(metric, value_1, value_2):
90
- if value_1 < 1e-7:
91
- return value_2
92
- if value_2 < 1e-7:
90
+ try:
91
+ if value_1 < 1e-7:
92
+ return value_2
93
+ if value_2 < 1e-7:
94
+ return value_1
95
+ return (
96
+ min(value_1, value_2)
97
+ if isLowerIsBetterMetric(metric)
98
+ else max(value_1, value_2)
99
+ )
100
+ except Exception:
93
101
  return value_1
94
- return (
95
- min(value_1, value_2)
96
- if isLowerIsBetterMetric(metric)
97
- else max(value_1, value_2)
98
- )
99
102
 
100
103
 
101
104
  def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel):
@@ -1,7 +1,7 @@
1
1
  def initMainExpanderStyle(st):
2
2
  st.markdown(
3
3
  """<style>
4
- .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
4
+ .main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
5
5
  .main div[data-testid='stExpander'] {
6
6
  background-color: #F6F8FA;
7
7
  border: 1px solid #A9BDD140;
@@ -1,8 +1,8 @@
1
1
  from vectordb_bench.backend.cases import Case
2
2
  from vectordb_bench.frontend.components.check_results.data import getChartData
3
3
  from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
4
- from vectordb_bench.frontend.const.dbCaseConfigs import CASE_LIST
5
- from vectordb_bench.frontend.const.styles import *
4
+ from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
5
+ from vectordb_bench.frontend.config.styles import *
6
6
  import streamlit as st
7
7
 
8
8
  from vectordb_bench.models import CaseResult, TestResult
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
18
18
  st.header("Filters")
19
19
 
20
20
  shownResults = getshownResults(results, st)
21
- showDBNames, showCases = getShowDbsAndCases(shownResults, st)
21
+ showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
22
22
 
23
- shownData, failedTasks = getChartData(shownResults, showDBNames, showCases)
23
+ shownData, failedTasks = getChartData(
24
+ shownResults, showDBNames, showCaseNames)
24
25
 
25
- return shownData, failedTasks, showCases
26
+ return shownData, failedTasks, showCaseNames
26
27
 
27
28
 
28
29
  def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
52
53
  return selectedResult
53
54
 
54
55
 
55
- def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Case]]:
56
+ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
56
57
  initSidebarExanderStyle(st)
57
58
  allDbNames = list(set({res.task_config.db_name for res in result}))
58
59
  allDbNames.sort()
59
- allCasesSet = set({res.task_config.case_config.case_id for res in result})
60
- allCases: list[Case] = [case.case_cls() for case in CASE_LIST if case in allCasesSet]
60
+ allCases: list[Case] = [
61
+ res.task_config.case_config.case_id.case_cls(
62
+ res.task_config.case_config.custom_case)
63
+ for res in result
64
+ ]
65
+ allCaseNameSet = set({case.name for case in allCases})
66
+ allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
67
+ [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
61
68
 
62
69
  # DB Filter
63
70
  dbFilterContainer = st.container()
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca
70
77
 
71
78
  # Case Filter
72
79
  caseFilterContainer = st.container()
73
- showCases = filterView(
80
+ showCaseNames = filterView(
74
81
  caseFilterContainer,
75
82
  "Case Filter",
76
- [case for case in allCases],
83
+ [caseName for caseName in allCaseNames],
77
84
  col=1,
78
- optionLables=[case.name for case in allCases],
79
85
  )
80
86
 
81
- return showDBNames, showCases
87
+ return showDBNames, showCaseNames
82
88
 
83
89
 
84
90
  def filterView(container, header, options, col, optionLables=None):
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
114
120
  )
115
121
  if optionLables is None:
116
122
  optionLables = options
117
- isActive = {option: st.session_state[selectAllState] for option in optionLables}
123
+ isActive = {option: st.session_state[selectAllState]
124
+ for option in optionLables}
118
125
  for i, option in enumerate(optionLables):
119
126
  isActive[option] = columns[i % col].checkbox(
120
127
  optionLables[i],
@@ -1,4 +1,4 @@
1
- from vectordb_bench.frontend.const.styles import HEADER_ICON
1
+ from vectordb_bench.frontend.config.styles import HEADER_ICON
2
2
 
3
3
 
4
4
  def drawHeaderIcon(st):
@@ -3,7 +3,7 @@ import pandas as pd
3
3
  from collections import defaultdict
4
4
  import streamlit as st
5
5
 
6
- from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
6
+ from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
7
7
 
8
8
 
9
9
  def priceTable(container, data):
@@ -1,4 +1,4 @@
1
- from vectordb_bench.frontend.const.styles import *
1
+ from vectordb_bench.frontend.config.styles import *
2
2
 
3
3
 
4
4
  def initResultsPageConfig(st):