vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
vectordb_bench/cli/cli.py CHANGED
@@ -1,27 +1,27 @@
 import logging
+import os
 import time
+from collections.abc import Callable
 from concurrent.futures import wait
 from datetime import datetime
 from pprint import pformat
 from typing import (
     Annotated,
-    Callable,
-    List,
-    Optional,
-    Type,
+    Any,
     TypedDict,
     Unpack,
     get_origin,
     get_type_hints,
-    Dict,
-    Any,
 )
+
 import click
+from yaml import load
 
 from vectordb_bench.backend.clients.api import MetricType
+
 from .. import config
 from ..backend.clients import DB
-from ..interface import benchMarkRunner, global_result_future
+from ..interface import benchmark_runner, global_result_future
 from ..models import (
     CaseConfig,
     CaseType,
@@ -31,8 +31,7 @@ from ..models import (
     TaskConfig,
     TaskStage,
 )
-import os
-from yaml import load
+
 try:
     from yaml import CLoader as Loader
 except ImportError:
@@ -46,8 +45,8 @@ def click_get_defaults_from_file(ctx, param, value):
     else:
         input_file = os.path.join(config.CONFIG_LOCAL_DIR, value)
     try:
-        with open(input_file, 'r') as f:
-            _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader)
+        with open(input_file) as f:
+            _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader)
         ctx.default_map = _config.get(ctx.command.name, {})
     except Exception as e:
         raise click.BadParameter(f"Failed to load config file: {e}")
@@ -55,7 +54,7 @@ def click_get_defaults_from_file(ctx, param, value):
 
 
 def click_parameter_decorators_from_typed_dict(
-    typed_dict: Type,
+    typed_dict: type,
 ) -> Callable[[click.decorators.FC], click.decorators.FC]:
     """A convenience method decorator that will read in a TypedDict with parameters defined by Annotated types.
     from .models import CaseConfig, CaseType, DBCaseConfig, DBConfig, TaskConfig, TaskStage
@@ -91,15 +90,12 @@ def click_parameter_decorators_from_typed_dict(
     decorators = []
     for _, t in get_type_hints(typed_dict, include_extras=True).items():
         assert get_origin(t) is Annotated
-        if (
-            len(t.__metadata__) == 1
-            and t.__metadata__[0].__module__ == "click.decorators"
-        ):
+        if len(t.__metadata__) == 1 and t.__metadata__[0].__module__ == "click.decorators":
             # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1)
             decorators.append(t.__metadata__[0])
         else:
             raise RuntimeError(
-                "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring"
+                "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring",
             )
 
     def deco(f):
@@ -132,11 +128,11 @@ def parse_task_stages(
     load: bool,
     search_serial: bool,
     search_concurrent: bool,
-) -> List[TaskStage]:
+) -> list[TaskStage]:
     stages = []
     if load and not drop_old:
         raise RuntimeError("Dropping old data cannot be skipped if loading data")
-    elif drop_old and not load:
+    if drop_old and not load:
         raise RuntimeError("Load cannot be skipped if dropping old data")
     if drop_old:
         stages.append(TaskStage.DROP_OLD)
@@ -149,12 +145,19 @@
     return stages
 
 
-def check_custom_case_parameters(ctx, param, value):
-    if ctx.params.get("case_type") == "PerformanceCustomDataset":
-        if value is None:
-            raise click.BadParameter("Custom case parameters\
-                \n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \
-                \n--custom-dataset-dim\n--custom-dataset-file-count\n are required")
+# ruff: noqa
+def check_custom_case_parameters(ctx: any, param: any, value: any):
+    if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None:
+        raise click.BadParameter(
+            """ Custom case parameters
+            --custom-case-name
+            --custom-dataset-name
+            --custom-dataset-dir
+            --custom-dataset-sizes
+            --custom-dataset-dim
+            --custom-dataset-file-count
+            are required """,
+        )
     return value
 
 
@@ -175,7 +178,7 @@ def get_custom_case_config(parameters: dict) -> dict:
             "file_count": parameters["custom_dataset_file_count"],
            "use_shuffled": parameters["custom_dataset_use_shuffled"],
            "with_gt": parameters["custom_dataset_with_gt"],
-        }
+        },
     }
     return custom_case_config
 
@@ -186,12 +189,14 @@ log = logging.getLogger(__name__)
 class CommonTypedDict(TypedDict):
     config_file: Annotated[
         bool,
-        click.option('--config-file',
-                     type=click.Path(),
-                     callback=click_get_defaults_from_file,
-                     is_eager=True,
-                     expose_value=False,
-                     help='Read configuration from yaml file'),
+        click.option(
+            "--config-file",
+            type=click.Path(),
+            callback=click_get_defaults_from_file,
+            is_eager=True,
+            expose_value=False,
+            help="Read configuration from yaml file",
+        ),
     ]
     drop_old: Annotated[
         bool,
@@ -246,9 +251,11 @@ class CommonTypedDict(TypedDict):
     db_label: Annotated[
         str,
         click.option(
-            "--db-label", type=str, help="Db label, default: date in ISO format",
+            "--db-label",
+            type=str,
+            help="Db label, default: date in ISO format",
             show_default=True,
-            default=datetime.now().isoformat()
+            default=datetime.now().isoformat(),
         ),
     ]
     dry_run: Annotated[
@@ -282,7 +289,7 @@ class CommonTypedDict(TypedDict):
         ),
     ]
     num_concurrency: Annotated[
-        List[str],
+        list[str],
         click.option(
             "--num-concurrency",
             type=str,
@@ -298,7 +305,7 @@ class CommonTypedDict(TypedDict):
             "--custom-case-name",
             help="Custom dataset case name",
             callback=check_custom_case_parameters,
-        )
+        ),
     ]
     custom_case_description: Annotated[
         str,
@@ -307,7 +314,7 @@ class CommonTypedDict(TypedDict):
             help="Custom dataset case description",
             default="This is a customized dataset.",
             show_default=True,
-        )
+        ),
     ]
     custom_case_load_timeout: Annotated[
         int,
@@ -316,7 +323,7 @@ class CommonTypedDict(TypedDict):
             help="Custom dataset case load timeout",
             default=36000,
             show_default=True,
-        )
+        ),
    ]
     custom_case_optimize_timeout: Annotated[
         int,
@@ -325,7 +332,7 @@ class CommonTypedDict(TypedDict):
             help="Custom dataset case optimize timeout",
             default=36000,
             show_default=True,
-        )
+        ),
     ]
     custom_dataset_name: Annotated[
         str,
@@ -397,60 +404,60 @@ class CommonTypedDict(TypedDict):
 
 
 class HNSWBaseTypedDict(TypedDict):
-    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")]
+    m: Annotated[int | None, click.option("--m", type=int, help="hnsw m")]
     ef_construction: Annotated[
-        Optional[int],
+        int | None,
         click.option("--ef-construction", type=int, help="hnsw ef-construction"),
     ]
 
 
 class HNSWBaseRequiredTypedDict(TypedDict):
-    m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)]
+    m: Annotated[int | None, click.option("--m", type=int, help="hnsw m", required=True)]
     ef_construction: Annotated[
-        Optional[int],
+        int | None,
         click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True),
     ]
 
 
 class HNSWFlavor1(HNSWBaseTypedDict):
     ef_search: Annotated[
-        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True)
+        int | None,
+        click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True),
     ]
 
 
 class HNSWFlavor2(HNSWBaseTypedDict):
     ef_runtime: Annotated[
-        Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime")
+        int | None,
+        click.option("--ef-runtime", type=int, help="hnsw ef-runtime"),
     ]
 
 
 class HNSWFlavor3(HNSWBaseRequiredTypedDict):
     ef_search: Annotated[
-        Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True)
+        int | None,
+        click.option("--ef-search", type=int, help="hnsw ef-search", required=True),
     ]
 
 
 class IVFFlatTypedDict(TypedDict):
-    lists: Annotated[
-        Optional[int], click.option("--lists", type=int, help="ivfflat lists")
-    ]
-    probes: Annotated[
-        Optional[int], click.option("--probes", type=int, help="ivfflat probes")
-    ]
+    lists: Annotated[int | None, click.option("--lists", type=int, help="ivfflat lists")]
+    probes: Annotated[int | None, click.option("--probes", type=int, help="ivfflat probes")]
 
 
 class IVFFlatTypedDictN(TypedDict):
     nlist: Annotated[
-        Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True)
+        int | None,
+        click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True),
    ]
     nprobe: Annotated[
-        Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True)
+        int | None,
+        click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True),
     ]
 
 
 @click.group()
-def cli():
-    ...
+def cli(): ...
 
 
 def run(
@@ -482,9 +489,7 @@ def run(
             custom_case=get_custom_case_config(parameters),
         ),
         stages=parse_task_stages(
-            (
-                False if not parameters["load"] else parameters["drop_old"]
-            ),  # only drop old data if loading new data
+            (False if not parameters["load"] else parameters["drop_old"]),  # only drop old data if loading new data
            parameters["load"],
            parameters["search_serial"],
            parameters["search_concurrent"],
@@ -493,7 +498,7 @@ def run(
 
     log.info(f"Task:\n{pformat(task)}\n")
     if not parameters["dry_run"]:
-        benchMarkRunner.run([task])
+        benchmark_runner.run([task])
         time.sleep(5)
         if global_result_future:
             wait([global_result_future])
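For readers unfamiliar with the pattern being refactored above: each `TypedDict` field carries a `click.option(...)` decorator inside its `Annotated[...]` metadata, and `click_parameter_decorators_from_typed_dict` collects those decorators and applies them to a command. A minimal, hypothetical sketch of how such a dict is consumed (the `ExampleTypedDict` name, `ExampleCommand`, and `--batch-size` option are invented for illustration and are not part of the package):

from typing import Annotated, TypedDict, Unpack

import click


class ExampleTypedDict(TypedDict):
    # Each field pairs a plain type with exactly one click decorator in its metadata.
    batch_size: Annotated[int, click.option("--batch-size", type=int, default=100, help="hypothetical option")]


@cli.command()
@click_parameter_decorators_from_typed_dict(ExampleTypedDict)
def ExampleCommand(**parameters: Unpack[ExampleTypedDict]):
    # The collected click options arrive as keyword arguments.
    click.echo(parameters["batch_size"])
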
vectordb_bench/cli/vectordbbench.py CHANGED
@@ -1,16 +1,15 @@
-from ..backend.clients.pgvector.cli import PgVectorHNSW
+from ..backend.clients.alloydb.cli import AlloyDBScaNN
+from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
+from ..backend.clients.memorydb.cli import MemoryDB
+from ..backend.clients.milvus.cli import MilvusAutoIndex
+from ..backend.clients.pgdiskann.cli import PgDiskAnn
 from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
+from ..backend.clients.pgvector.cli import PgVectorHNSW
 from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn
-from ..backend.clients.pgdiskann.cli import PgDiskAnn
 from ..backend.clients.redis.cli import Redis
-from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
-from ..backend.clients.milvus.cli import MilvusAutoIndex
-from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
-from ..backend.clients.alloydb.cli import AlloyDBScaNN
-
 from .cli import cli
 
 cli.add_command(PgVectorHNSW)
vectordb_bench/frontend/components/check_results/charts.py CHANGED
@@ -1,8 +1,7 @@
-from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.expanderStyle import (
     initMainExpanderStyle,
 )
-from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
+from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map
 from vectordb_bench.frontend.config.styles import *
 from vectordb_bench.models import ResultLabel
 import plotly.express as px
@@ -21,9 +20,7 @@ def drawCharts(st, allData, failedTasks, caseNames: list[str]):
 
 def showFailedDBs(st, errorDBs):
     failedDBs = [db for db, label in errorDBs.items() if label == ResultLabel.FAILED]
-    timeoutDBs = [
-        db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE
-    ]
+    timeoutDBs = [db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE]
 
     showFailedText(st, "Failed", failedDBs)
     showFailedText(st, "Timeout", timeoutDBs)
@@ -41,7 +38,7 @@ def drawChart(data, st, key_prefix: str):
     metricsSet = set()
     for d in data:
         metricsSet = metricsSet.union(d["metricsSet"])
-    showMetrics = [metric for metric in metricOrder if metric in metricsSet]
+    showMetrics = [metric for metric in metric_order if metric in metricsSet]
 
     for i, metric in enumerate(showMetrics):
         container = st.container()
@@ -72,9 +69,7 @@ def getLabelToShapeMap(data):
         else:
             usedShapes.add(labelIndexMap[label] % len(PATTERN_SHAPES))
 
-    labelToShapeMap = {
-        label: getPatternShape(index) for label, index in labelIndexMap.items()
-    }
+    labelToShapeMap = {label: getPatternShape(index) for label, index in labelIndexMap.items()}
     return labelToShapeMap
 
 
@@ -96,11 +91,9 @@ def drawMetricChart(data, metric, st, key: str):
     xpadding = (xmax - xmin) / 16
     xpadding_multiplier = 1.8
     xrange = [xmin, xmax + xpadding * xpadding_multiplier]
-    unit = metricUnitMap.get(metric, "")
+    unit = metric_unit_map.get(metric, "")
     labelToShapeMap = getLabelToShapeMap(dataWithMetric)
-    categoryorder = (
-        "total descending" if isLowerIsBetterMetric(metric) else "total ascending"
-    )
+    categoryorder = "total descending" if isLowerIsBetterMetric(metric) else "total ascending"
     fig = px.bar(
         dataWithMetric,
         x=metric,
@@ -137,18 +130,14 @@ def drawMetricChart(data, metric, st, key: str):
             color="#333",
             size=12,
         ),
-        marker=dict(
-            pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)
-        ),
+        marker=dict(pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)),
         texttemplate="%{x:,.4~r}" + unit,
     )
     fig.update_layout(
         margin=dict(l=0, r=0, t=48, b=12, pad=8),
         bargap=0.25,
         showlegend=False,
-        legend=dict(
-            orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
-        ),
+        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""),
         # legend=dict(orientation="v", title=""),
         yaxis={"categoryorder": categoryorder},
         title=dict(
vectordb_bench/frontend/components/check_results/data.py CHANGED
@@ -1,6 +1,5 @@
 from collections import defaultdict
 from dataclasses import asdict
-from vectordb_bench.backend.cases import Case
 from vectordb_bench.metric import isLowerIsBetterMetric
 from vectordb_bench.models import CaseResult, ResultLabel
 
@@ -24,10 +23,7 @@ def getFilterTasks(
         task
         for task in tasks
         if task.task_config.db_name in dbNames
-        and task.task_config.case_config.case_id.case_cls(
-            task.task_config.case_config.custom_case
-        ).name
-        in caseNames
+        and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
     ]
     return filterTasks
 
@@ -39,9 +35,7 @@ def mergeTasks(tasks: list[CaseResult]):
         db = task.task_config.db.value
         db_label = task.task_config.db_config.db_label or ""
         version = task.task_config.db_config.version or ""
-        case = task.task_config.case_config.case_id.case_cls(
-            task.task_config.case_config.custom_case
-        )
+        case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
         dbCaseMetricsMap[db_name][case.name] = {
             "db": db,
             "db_label": db_label,
@@ -86,9 +80,7 @@ def mergeTasks(tasks: list[CaseResult]):
 def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
     metrics = {**metrics_1}
     for key, value in metrics_2.items():
-        metrics[key] = (
-            getBetterMetric(key, value, metrics[key]) if key in metrics else value
-        )
+        metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value
 
     return metrics
 
@@ -99,11 +91,7 @@ def getBetterMetric(metric, value_1, value_2):
             return value_2
         if value_2 < 1e-7:
             return value_1
-        return (
-            min(value_1, value_2)
-            if isLowerIsBetterMetric(metric)
-            else max(value_1, value_2)
-        )
+        return min(value_1, value_2) if isLowerIsBetterMetric(metric) else max(value_1, value_2)
     except Exception:
         return value_1
 
vectordb_bench/frontend/components/check_results/filters.py CHANGED
@@ -20,23 +20,17 @@ def getshownData(results: list[TestResult], st):
     shownResults = getshownResults(results, st)
     showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
 
-    shownData, failedTasks = getChartData(
-        shownResults, showDBNames, showCaseNames)
+    shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames)
 
     return shownData, failedTasks, showCaseNames
 
 
 def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
     resultSelectOptions = [
-        result.task_label
-        if result.task_label != result.run_id
-        else f"res-{result.run_id[:4]}"
-        for result in results
+        result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results
     ]
     if len(resultSelectOptions) == 0:
-        st.write(
-            "There are no results to display. Please wait for the task to complete or run a new task."
-        )
+        st.write("There are no results to display. Please wait for the task to complete or run a new task.")
         return []
 
     selectedResultSelectedOptions = st.multiselect(
@@ -58,13 +52,12 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[st
     allDbNames = list(set({res.task_config.db_name for res in result}))
     allDbNames.sort()
     allCases: list[Case] = [
-        res.task_config.case_config.case_id.case_cls(
-            res.task_config.case_config.custom_case)
-        for res in result
+        res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result
     ]
     allCaseNameSet = set({case.name for case in allCases})
-    allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
-        [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
+    allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
+        case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
+    ]
 
     # DB Filter
     dbFilterContainer = st.container()
@@ -120,8 +113,7 @@ def filterView(container, header, options, col, optionLables=None):
     )
     if optionLables is None:
         optionLables = options
-    isActive = {option: st.session_state[selectAllState]
-                for option in optionLables}
+    isActive = {option: st.session_state[selectAllState] for option in optionLables}
     for i, option in enumerate(optionLables):
         isActive[option] = columns[i % col].checkbox(
             optionLables[i],
vectordb_bench/frontend/components/check_results/nav.py CHANGED
@@ -7,15 +7,15 @@ def NavToRunTest(st):
     navClick = st.button("Run Your Test &nbsp;&nbsp;>")
     if navClick:
         switch_page("run test")
-
-
+
+
 def NavToQuriesPerDollar(st):
     st.subheader("Compare qps with price.")
     navClick = st.button("QP$ (Quries per Dollar) &nbsp;&nbsp;>")
     if navClick:
         switch_page("quries_per_dollar")
-
-
+
+
 def NavToResults(st, key="nav-to-results"):
     navClick = st.button("< &nbsp;&nbsp;Back to Results", key=key)
     if navClick:
vectordb_bench/frontend/components/check_results/priceTable.py CHANGED
@@ -7,9 +7,7 @@ from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
 
 
 def priceTable(container, data):
-    dbAndLabelSet = {
-        (d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value
-    }
+    dbAndLabelSet = {(d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value}
 
     dbAndLabelList = list(dbAndLabelSet)
     dbAndLabelList.sort()
vectordb_bench/frontend/components/check_results/stPageConfig.py CHANGED
@@ -9,10 +9,11 @@ def initResultsPageConfig(st):
         # initial_sidebar_state="collapsed",
     )
 
+
 def initRunTestPageConfig(st):
     st.set_page_config(
         page_title=PAGE_TITLE,
         page_icon=FAVICON,
         # layout="wide",
         initial_sidebar_state="collapsed",
-    )
+    )
vectordb_bench/frontend/components/concurrent/charts.py CHANGED
@@ -14,24 +14,24 @@ def drawChartsByCase(allData, showCaseNames: list[str], st, latency_type: str):
         data = [
             {
                 "conc_num": caseData["conc_num_list"][i],
-                "qps": caseData["conc_qps_list"][i]
-                if 0 <= i < len(caseData["conc_qps_list"])
-                else 0,
-                "latency_p99": caseData["conc_latency_p99_list"][i] * 1000
-                if 0 <= i < len(caseData["conc_latency_p99_list"])
-                else 0,
-                "latency_avg": caseData["conc_latency_avg_list"][i] * 1000
-                if 0 <= i < len(caseData["conc_latency_avg_list"])
-                else 0,
+                "qps": (caseData["conc_qps_list"][i] if 0 <= i < len(caseData["conc_qps_list"]) else 0),
+                "latency_p99": (
+                    caseData["conc_latency_p99_list"][i] * 1000
+                    if 0 <= i < len(caseData["conc_latency_p99_list"])
+                    else 0
+                ),
+                "latency_avg": (
+                    caseData["conc_latency_avg_list"][i] * 1000
+                    if 0 <= i < len(caseData["conc_latency_avg_list"])
+                    else 0
+                ),
                 "db_name": caseData["db_name"],
                 "db": caseData["db"],
             }
             for caseData in caseDataList
             for i in range(len(caseData["conc_num_list"]))
         ]
-        drawChart(
-            data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type
-        )
+        drawChart(data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type)
 
 
 def getRange(metric, data, padding_multipliers):
vectordb_bench/frontend/components/custom/displayCustomCase.py CHANGED
@@ -1,4 +1,3 @@
-
 from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig
 
 
@@ -6,26 +5,33 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
 
     columns = st.columns([1, 2])
     customCase.dataset_config.name = columns[0].text_input(
-        "Name", key=f"{key}_name", value=customCase.dataset_config.name)
+        "Name", key=f"{key}_name", value=customCase.dataset_config.name
+    )
     customCase.name = f"{customCase.dataset_config.name} (Performace Case)"
     customCase.dataset_config.dir = columns[1].text_input(
-        "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir)
+        "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
+    )
 
     columns = st.columns(4)
     customCase.dataset_config.dim = columns[0].number_input(
-        "dim", key=f"{key}_dim", value=customCase.dataset_config.dim)
+        "dim", key=f"{key}_dim", value=customCase.dataset_config.dim
+    )
     customCase.dataset_config.size = columns[1].number_input(
-        "size", key=f"{key}_size", value=customCase.dataset_config.size)
+        "size", key=f"{key}_size", value=customCase.dataset_config.size
+    )
     customCase.dataset_config.metric_type = columns[2].selectbox(
-        "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"])
+        "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
+    )
     customCase.dataset_config.file_count = columns[3].number_input(
-        "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count)
+        "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count
+    )
 
     columns = st.columns(4)
     customCase.dataset_config.use_shuffled = columns[0].checkbox(
-        "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled)
+        "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled
+    )
     customCase.dataset_config.with_gt = columns[1].checkbox(
-        "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt)
+        "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt
+    )
 
-    customCase.description = st.text_area(
-        "description", key=f"{key}_description", value=customCase.description)
+    customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
vectordb_bench/frontend/components/custom/displaypPrams.py CHANGED
@@ -1,5 +1,6 @@
 def displayParams(st):
-    st.markdown("""
+    st.markdown(
+        """
     - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
       - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
       - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
@@ -8,4 +9,5 @@ def displayParams(st):
     - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
 
     - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
-    """)
+    """
+    )
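To make the custom-dataset layout described in that markdown concrete, here is a small, hypothetical sketch (not part of the diff) that writes a `train.parquet` and `test.parquet` with the `id`/`emb` schema mentioned above. The output path, sizes, and dimension are arbitrary examples, and pandas, numpy, and a Parquet engine such as pyarrow are assumed to be installed:

from pathlib import Path

import numpy as np
import pandas as pd

out_dir = Path("/tmp/my_custom_dataset")  # arbitrary example folder
out_dir.mkdir(parents=True, exist_ok=True)

dim, train_size, test_size = 128, 10_000, 100  # arbitrary example sizes
rng = np.random.default_rng(0)

# train.parquet: `id` as an incrementing int, `emb` as one float32 vector per row.
train = pd.DataFrame(
    {
        "id": np.arange(train_size),
        "emb": list(rng.random((train_size, dim), dtype=np.float32)),
    }
)
train.to_parquet(out_dir / "train.parquet")

# test.parquet: the query vectors, with the same two-column schema.
test = pd.DataFrame(
    {
        "id": np.arange(test_size),
        "emb": list(rng.random((test_size, dim), dtype=np.float32)),
    }
)
test.to_parquet(out_dir / "test.parquet")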