vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +32 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
  9. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  10. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  11. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  12. vectordb_bench/backend/clients/lancedb/cli.py +62 -8
  13. vectordb_bench/backend/clients/lancedb/config.py +14 -1
  14. vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
  15. vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
  16. vectordb_bench/backend/clients/milvus/cli.py +30 -9
  17. vectordb_bench/backend/clients/milvus/config.py +3 -0
  18. vectordb_bench/backend/clients/milvus/milvus.py +81 -23
  19. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  20. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  21. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  22. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  23. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  24. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  25. vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
  26. vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
  27. vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
  28. vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
  29. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
  30. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
  31. vectordb_bench/backend/dataset.py +143 -27
  32. vectordb_bench/backend/filter.py +76 -0
  33. vectordb_bench/backend/runner/__init__.py +3 -3
  34. vectordb_bench/backend/runner/mp_runner.py +52 -39
  35. vectordb_bench/backend/runner/rate_runner.py +68 -52
  36. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  37. vectordb_bench/backend/runner/serial_runner.py +56 -23
  38. vectordb_bench/backend/task_runner.py +48 -20
  39. vectordb_bench/cli/batch_cli.py +121 -0
  40. vectordb_bench/cli/cli.py +59 -1
  41. vectordb_bench/cli/vectordbbench.py +7 -0
  42. vectordb_bench/config-files/batch_sample_config.yml +17 -0
  43. vectordb_bench/frontend/components/check_results/data.py +16 -11
  44. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  45. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  46. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  47. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  48. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  49. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  50. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  51. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  52. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  53. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  54. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  55. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  56. vectordb_bench/frontend/components/streaming/data.py +62 -0
  57. vectordb_bench/frontend/components/tables/data.py +1 -1
  58. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  59. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  60. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  61. vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
  62. vectordb_bench/frontend/config/styles.py +32 -2
  63. vectordb_bench/frontend/pages/concurrent.py +5 -1
  64. vectordb_bench/frontend/pages/custom.py +4 -0
  65. vectordb_bench/frontend/pages/label_filter.py +56 -0
  66. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  67. vectordb_bench/frontend/pages/results.py +60 -0
  68. vectordb_bench/frontend/pages/run_test.py +3 -3
  69. vectordb_bench/frontend/pages/streaming.py +135 -0
  70. vectordb_bench/frontend/pages/tables.py +4 -0
  71. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  72. vectordb_bench/interface.py +6 -2
  73. vectordb_bench/metric.py +15 -1
  74. vectordb_bench/models.py +38 -11
  75. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  76. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  77. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  78. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  79. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  80. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  81. vectordb_bench/results/dbPrices.json +12 -4
  82. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
  83. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
  84. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
  85. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  86. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  87. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  88. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  89. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  90. {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -20,3 +20,23 @@ def NavToResults(st, key="nav-to-results"):
20
20
  navClick = st.button("< &nbsp;&nbsp;Back to Results", key=key)
21
21
  if navClick:
22
22
  switch_page("vdb benchmark")
23
+
24
+
25
def NavToPages(st):
    """Render a single-line, pipe-separated navigation bar linking every app page.

    Each entry becomes an ``<a href="/<link>" target="_self">`` element, and the
    whole bar is emitted with one ``st.markdown(..., unsafe_allow_html=True)``
    call.

    Args:
        st: Streamlit module or container exposing ``markdown``.
    """
    options = [
        {"name": "Run Test", "link": "run_test"},
        {"name": "Results", "link": "results"},
        # The page slug keeps the historical "quries" spelling (it is the page
        # module name); only the user-visible label is corrected.
        {"name": "Queries Per Dollar", "link": "quries_per_dollar"},
        {"name": "Concurrent", "link": "concurrent"},
        {"name": "Label Filter", "link": "label_filter"},
        {"name": "Streaming", "link": "streaming"},
        {"name": "Tables", "link": "tables"},
        {"name": "Custom Dataset", "link": "custom"},
    ]

    links = [
        f'<a href="/{option["link"]}" target="_self" style="text-decoration: none; padding: 0.1px 0.2px;">{option["name"]}</a>'
        for option in options
    ]
    separator = '<span style="color: #888; margin: 0 5px;">|</span>'
    # join() places the separator only between links, never after the last one.
    st.markdown(separator.join(links), unsafe_allow_html=True)
@@ -12,7 +12,7 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
12
12
  "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
13
13
  )
14
14
 
15
- columns = st.columns(4)
15
+ columns = st.columns(3)
16
16
  customCase.dataset_config.dim = columns[0].number_input(
17
17
  "dim", key=f"{key}_dim", value=customCase.dataset_config.dim
18
18
  )
@@ -22,16 +22,51 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
22
22
  customCase.dataset_config.metric_type = columns[2].selectbox(
23
23
  "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
24
24
  )
25
- customCase.dataset_config.file_count = columns[3].number_input(
26
- "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count
25
+
26
+ columns = st.columns(3)
27
+ customCase.dataset_config.train_name = columns[0].text_input(
28
+ "train file name",
29
+ key=f"{key}_train_name",
30
+ value=customCase.dataset_config.train_name,
31
+ )
32
+ customCase.dataset_config.test_name = columns[1].text_input(
33
+ "test file name", key=f"{key}_test_name", value=customCase.dataset_config.test_name
34
+ )
35
+ customCase.dataset_config.gt_name = columns[2].text_input(
36
+ "ground truth file name", key=f"{key}_gt_name", value=customCase.dataset_config.gt_name
37
+ )
38
+
39
+ columns = st.columns([1, 1, 2, 2])
40
+ customCase.dataset_config.train_id_name = columns[0].text_input(
41
+ "train id name", key=f"{key}_train_id_name", value=customCase.dataset_config.train_id_name
42
+ )
43
+ customCase.dataset_config.train_col_name = columns[1].text_input(
44
+ "train emb name", key=f"{key}_train_col_name", value=customCase.dataset_config.train_col_name
45
+ )
46
+ customCase.dataset_config.test_col_name = columns[2].text_input(
47
+ "test emb name", key=f"{key}_test_col_name", value=customCase.dataset_config.test_col_name
48
+ )
49
+ customCase.dataset_config.gt_col_name = columns[3].text_input(
50
+ "ground truth emb name", key=f"{key}_gt_col_name", value=customCase.dataset_config.gt_col_name
27
51
  )
28
52
 
29
- columns = st.columns(4)
30
- customCase.dataset_config.use_shuffled = columns[0].checkbox(
31
- "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled
53
+ columns = st.columns(2)
54
+ customCase.dataset_config.scalar_labels_name = columns[0].text_input(
55
+ "scalar labels file name",
56
+ key=f"{key}_scalar_labels_file_name",
57
+ value=customCase.dataset_config.scalar_labels_name,
32
58
  )
33
- customCase.dataset_config.with_gt = columns[1].checkbox(
34
- "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt
59
+ default_label_percentages = ",".join(map(str, customCase.dataset_config.with_label_percentages))
60
+ label_percentage_input = columns[1].text_input(
61
+ "label percentages",
62
+ key=f"{key}_label_percantages",
63
+ value=default_label_percentages,
35
64
  )
65
+ try:
66
+ customCase.dataset_config.label_percentages = [
67
+ float(item.strip()) for item in label_percentage_input.split(",") if item.strip()
68
+ ]
69
+ except ValueError as e:
70
+ st.write(f"<span style='color:red'>{e},please input correct number</span>", unsafe_allow_html=True)
36
71
 
37
72
  customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
@@ -2,13 +2,18 @@ def displayParams(st):
2
2
  st.markdown(
3
3
  """
4
4
  - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
5
- - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
6
- - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
7
- - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.
5
+ - Vectors data files: The file should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The name of two columns could be defined on your own.
6
+ - Query test vectors: The file could be named on your own and should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The `id` column must be named as `id`, and `emb` column could be defined on your own.
7
+ - Ground truth file: The file could be named on your own and should have two kinds of columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. The `id` column must be named as `id`, and `neighbors_id` column could be defined on your own.
8
8
 
9
- - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
9
+ - `Train File Name` - If the number of train file is `more than one`, please input all your train file name and `split with ','` without the `.parquet` file extensionthe. For example, if there are two train file and the name of them are `train1.parquet` and `train2.parquet`, then input `train1,train2`.
10
+
11
+ - `Ground Truth Emb Name` - No matter whether filter file is applied or not, the `neighbors_id` column in ground truth file must have the same name.
12
+
13
+ - `Scalar Labels File Name ` - If there is a scalar labels file, please input the filename without the .parquet extension. The file should have two columns: `id` as an incrementing `int` and `labels` as an array of `string`. The `id` column must correspond one-to-one with the `id` column in train file..
14
+
15
+ - `Label percentages` - If you have filter file, please input label percentage you want to real run and `split with ','` when it's `more than one`. If you `don't have` filter file, than `keep the text vacant.`
10
16
 
11
- - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
12
17
  """
13
18
  )
14
19
  st.caption(
@@ -14,6 +14,16 @@ class CustomDatasetConfig(BaseModel):
14
14
  file_count: int = 1
15
15
  use_shuffled: bool = False
16
16
  with_gt: bool = True
17
+ train_name: str = "train"
18
+ test_name: str = "test"
19
+ gt_name: str = "neighbors"
20
+ train_id_name: str = "id"
21
+ train_col_name: str = "emb"
22
+ test_col_name: str = "emb"
23
+ gt_col_name: str = "neighbors_id"
24
+ scalar_labels_name: str = "scalar_labels"
25
+ label_percentages: list[str] = []
26
+ with_label_percentages: list[float] = [0.001, 0.02, 0.5]
17
27
 
18
28
 
19
29
  class CustomCaseConfig(BaseModel):
@@ -0,0 +1,60 @@
1
+ import plotly.express as px
2
+ from vectordb_bench.metric import metric_unit_map
3
+
4
+
5
def drawCharts(st, allData, **kwargs):
    """Draw one titled section per dataset, in alphabetical dataset order.

    Records in ``allData`` are grouped by their ``"dataset_name"`` key; each
    group is handed to ``drawChartByMetric`` inside its own container, along
    with any extra keyword arguments.
    """
    for dataset_name in sorted({record["dataset_name"] for record in allData}):
        section = st.container()
        section.subheader(dataset_name)
        dataset_records = [record for record in allData if record["dataset_name"] == dataset_name]
        drawChartByMetric(section, dataset_records, **kwargs)
13
+
14
+
15
def drawChartByMetric(st, data, metrics=("qps", "recall"), **kwargs):
    """Lay out one column per metric and draw that metric's chart in it.

    Each column gets a ``#### <metric>`` heading followed by ``drawChart``.
    Extra keyword arguments are accepted but not used.
    """
    metric_columns = st.columns(len(metrics))
    for column, metric in zip(metric_columns, metrics):
        column.markdown(f"#### {metric}")
        drawChart(column, data, metric)
21
+
22
+
23
def getRange(metric, data, padding_multipliers):
    """Return a padded ``[low, high]`` axis range for ``metric`` over ``data``.

    Each record contributes ``record.get(metric, 0)`` (missing values count as
    0). The span ``max - min`` is multiplied by ``padding_multipliers[0]`` to
    pad below the minimum and by ``padding_multipliers[1]`` to pad above the
    maximum.

    When every value is identical the span is 0, which would collapse the axis
    to a single point; in that case the span falls back to ``abs(value)``, or
    to 1 when the value itself is 0, so the axis keeps a visible width.

    Raises:
        ValueError: if ``data`` is empty (``min``/``max`` of no values).
    """
    values = [d.get(metric, 0) for d in data]
    minV = min(values)
    maxV = max(values)
    span = maxV - minV
    if span == 0:
        span = abs(maxV) or 1
    return [
        minV - span * padding_multipliers[0],
        maxV + span * padding_multipliers[1],
    ]
32
+
33
+
34
def drawChart(st, data: list[object], metric):
    """Plot ``metric`` against ``filter_rate`` with one line per ``db_name``.

    Axis ranges come from ``getRange`` with asymmetric padding. ``data`` is
    sorted in place by ``filter_rate`` so the lines are drawn left to right.
    Point labels use the metric's unit from ``metric_unit_map`` (empty when
    the metric has no entry).
    """
    unit_suffix = metric_unit_map.get(metric, "")
    x_key = "filter_rate"
    x_range = getRange(x_key, data, [0.05, 0.1])
    y_range = getRange(metric, data, [0.2, 0.1])

    data.sort(key=lambda record: record[x_key])

    figure = px.line(
        data,
        x=x_key,
        y=metric,
        color="db_name",
        line_group="db_name",
        text=metric,
        markers=True,
    )
    figure.update_xaxes(range=x_range)
    figure.update_yaxes(range=y_range)
    figure.update_traces(textposition="bottom right", texttemplate="%{y:,.4~r}" + unit_suffix)
    figure.update_layout(
        margin=dict(l=0, r=0, t=40, b=0, pad=8),
        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""),
    )
    st.plotly_chart(figure, use_container_width=True)
@@ -1,8 +1,21 @@
1
- from vectordb_bench.frontend.config.styles import *
2
- from vectordb_bench.frontend.config.dbCaseConfigs import *
1
+ from vectordb_bench.backend.clients import DB
2
+ from vectordb_bench.frontend.components.run_test.inputWidget import inputWidget
3
3
  from collections import defaultdict
4
+ from vectordb_bench.frontend.config.dbCaseConfigs import (
5
+ UI_CASE_CLUSTERS,
6
+ UICaseItem,
7
+ UICaseItemCluster,
8
+ get_case_config_inputs,
9
+ get_custom_case_cluter,
10
+ )
11
+ from vectordb_bench.frontend.config.styles import (
12
+ CASE_CONFIG_SETTING_COLUMNS,
13
+ CHECKBOX_INDENT,
14
+ DB_CASE_CONFIG_SETTING_COLUMNS,
15
+ )
4
16
 
5
17
  from vectordb_bench.frontend.utils import addHorizontalLine
18
+ from vectordb_bench.models import CaseConfig
6
19
 
7
20
 
8
21
  def caseSelector(st, activedDbList: list[DB]):
@@ -24,7 +37,7 @@ def caseSelector(st, activedDbList: list[DB]):
24
37
  activedCaseList += caseClusterExpander(st, caseCluster, dbToCaseClusterConfigs, activedDbList)
25
38
  for db in dbToCaseClusterConfigs:
26
39
  for uiCaseItem in dbToCaseClusterConfigs[db]:
27
- for case in uiCaseItem.cases:
40
+ for case in uiCaseItem.get_cases():
28
41
  dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
29
42
 
30
43
  return activedCaseList, dbToCaseConfigs
@@ -48,15 +61,38 @@ def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, actived
48
61
  unsafe_allow_html=True,
49
62
  )
50
63
 
64
+ caseConfigSetting(st.container(), uiCaseItem)
65
+
51
66
  if selected:
52
- caseConfigSetting(st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList)
67
+ dbCaseConfigSetting(st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList)
68
+
69
+ return uiCaseItem.get_cases() if selected else []
70
+
71
+
72
def caseConfigSetting(st, uiCaseItem: UICaseItem):
    """Render the case-level "Custom Config" inputs for ``uiCaseItem``.

    Returns immediately when the item declares no
    ``extra_custom_case_config_inputs``. Otherwise it draws a title column
    followed by ``CASE_CONFIG_SETTING_COLUMNS`` input columns, fills them
    round-robin, and stores each widget's value in
    ``uiCaseItem.tmp_custom_config`` under the input's label value.
    """
    inputs = uiCaseItem.extra_custom_case_config_inputs
    if len(inputs) == 0:
        return

    # Input columns share the width budget of the per-DB config grid.
    input_width = DB_CASE_CONFIG_SETTING_COLUMNS / CASE_CONFIG_SETTING_COLUMNS
    columns = st.columns([1] + [input_width] * CASE_CONFIG_SETTING_COLUMNS)
    columns[0].markdown(
        f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>Custom Config</div>",
        unsafe_allow_html=True,
    )
    for index, config_input in enumerate(inputs):
        label_value = config_input.label.value
        target_column = columns[1 + index % CASE_CONFIG_SETTING_COLUMNS]
        widget_key = f"custom-config-{uiCaseItem.label}-{label_value}"
        uiCaseItem.tmp_custom_config[label_value] = inputWidget(target_column, config=config_input, key=widget_key)
55
91
 
56
92
 
57
- def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
93
+ def dbCaseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
58
94
  for db in activedDbList:
59
- columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
95
+ columns = st.columns(1 + DB_CASE_CONFIG_SETTING_COLUMNS)
60
96
  # column 0 - title
61
97
  dbColumn = columns[0]
62
98
  dbColumn.markdown(
@@ -64,52 +100,12 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
64
100
  unsafe_allow_html=True,
65
101
  )
66
102
  k = 0
67
- caseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
68
- for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []):
69
- if config.isDisplayed(caseConfig):
70
- column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS]
103
+ dbCaseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
104
+ for config in get_case_config_inputs(db, uiCaseItem.caseLabel):
105
+ if config.isDisplayed(dbCaseConfig):
106
+ column = columns[1 + k % DB_CASE_CONFIG_SETTING_COLUMNS]
71
107
  key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
72
- if config.inputType == InputType.Text:
73
- caseConfig[config.label] = column.text_input(
74
- config.displayLabel if config.displayLabel else config.label.value,
75
- key=key,
76
- help=config.inputHelp,
77
- value=config.inputConfig["value"],
78
- )
79
- elif config.inputType == InputType.Option:
80
- caseConfig[config.label] = column.selectbox(
81
- config.displayLabel if config.displayLabel else config.label.value,
82
- config.inputConfig["options"],
83
- key=key,
84
- help=config.inputHelp,
85
- )
86
- elif config.inputType == InputType.Number:
87
- caseConfig[config.label] = column.number_input(
88
- config.displayLabel if config.displayLabel else config.label.value,
89
- # format="%d",
90
- step=config.inputConfig.get("step", 1),
91
- min_value=config.inputConfig["min"],
92
- max_value=config.inputConfig["max"],
93
- key=key,
94
- value=config.inputConfig["value"],
95
- help=config.inputHelp,
96
- )
97
- elif config.inputType == InputType.Float:
98
- caseConfig[config.label] = column.number_input(
99
- config.displayLabel if config.displayLabel else config.label.value,
100
- step=config.inputConfig.get("step", 0.1),
101
- min_value=config.inputConfig["min"],
102
- max_value=config.inputConfig["max"],
103
- key=key,
104
- value=config.inputConfig["value"],
105
- help=config.inputHelp,
106
- )
107
- elif config.inputType == InputType.Bool:
108
- caseConfig[config.label] = column.checkbox(
109
- config.displayLabel if config.displayLabel else config.label.value,
110
- value=config.inputConfig["value"],
111
- help=config.inputHelp,
112
- )
108
+ dbCaseConfig[config.label] = inputWidget(column, config, key)
113
109
  k += 1
114
110
  if k == 0:
115
111
  columns[1].write("Auto")
@@ -1,9 +1,10 @@
1
1
  from streamlit.runtime.media_file_storage import MediaFileStorageError
2
2
  from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON
3
3
  from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
4
+ import streamlit as st
4
5
 
5
6
 
6
- def dbSelector(st):
7
+ def dbSelector(st: st):
7
8
  st.markdown(
8
9
  "<div style='height: 12px;'></div>",
9
10
  unsafe_allow_html=True,
@@ -20,11 +21,14 @@ def dbSelector(st):
20
21
  for i, db in enumerate(DB_LIST):
21
22
  column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
22
23
  dbIsActived[db] = column.checkbox(db.name)
23
- try:
24
- column.image(DB_TO_ICON.get(db, ""))
25
- except MediaFileStorageError:
24
+ image_src = DB_TO_ICON.get(db, None)
25
+ if image_src:
26
+ column.markdown(
27
+ f'<img src="{image_src}" style="width:100px;height:100px;object-fit:contain;object-position:center;margin-bottom:10px;">',
28
+ unsafe_allow_html=True,
29
+ )
30
+ else:
26
31
  column.warning(f"{db.name} image not available")
27
- pass
28
32
  activedDbList = [db for db in DB_LIST if dbIsActived[db]]
29
33
 
30
34
  return activedDbList
@@ -0,0 +1,48 @@
1
+ from vectordb_bench.frontend.config.dbCaseConfigs import CaseConfigInput, InputType
2
+
3
+
4
def inputWidget(st, config: CaseConfigInput, key: str):
    """Render the Streamlit widget matching ``config.inputType`` and return its value.

    The widget label is ``config.displayLabel`` when set, otherwise
    ``config.label.value``.

    Args:
        st: Streamlit module or container used to draw the widget.
        config: Input description (type, label, help text and ``inputConfig``).
        key: Unique Streamlit widget key.

    Returns:
        The widget's current value.
        - Text: a ``text_input`` value.
        - Option: the chosen entry of ``inputConfig["options"]``.
        - Number / Float: a ``number_input`` value.
        - Bool: ``True``/``False`` chosen from a selectbox, pre-selected from
          ``inputConfig["value"]``.

    Raises:
        ValueError: if ``config.inputType`` is not handled. ValueError
            subclasses Exception, so existing ``except Exception`` callers
            still catch it.
    """
    label = config.displayLabel if config.displayLabel else config.label.value
    input_config = config.inputConfig

    if config.inputType == InputType.Text:
        return st.text_input(
            label,
            key=key,
            help=config.inputHelp,
            value=input_config["value"],
        )
    if config.inputType == InputType.Option:
        return st.selectbox(
            label,
            input_config["options"],
            key=key,
            help=config.inputHelp,
        )
    if config.inputType == InputType.Number:
        return st.number_input(
            label,
            step=input_config.get("step", 1),
            min_value=input_config["min"],
            max_value=input_config["max"],
            key=key,
            value=input_config["value"],
            help=config.inputHelp,
        )
    if config.inputType == InputType.Float:
        return st.number_input(
            label,
            step=input_config.get("step", 0.1),
            min_value=input_config["min"],
            max_value=input_config["max"],
            key=key,
            value=input_config["value"],
            help=config.inputHelp,
        )
    if config.inputType == InputType.Bool:
        return st.selectbox(
            label,
            options=[True, False],
            # index 0 -> True, index 1 -> False; mirrors the configured default.
            index=0 if input_config["value"] else 1,
            key=key,
            help=config.inputHelp,
        )
    raise ValueError(f"Invalid InputType: {config.inputType}")
@@ -86,7 +86,9 @@ def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid):
86
86
  currentTaskId = benchmark_runner.get_current_task_id()
87
87
  tasksCount = benchmark_runner.get_tasks_count()
88
88
  text = f":running: Running Task {currentTaskId} / {tasksCount}"
89
- st.progress(currentTaskId / tasksCount, text=text)
89
+
90
+ if tasksCount > 0:
91
+ st.progress(currentTaskId / tasksCount, text=text)
90
92
 
91
93
  columns = st.columns(6)
92
94
  columns[0].button(
@@ -0,0 +1,253 @@
1
+ import plotly.graph_objects as go
2
+
3
+ from vectordb_bench.frontend.components.streaming.data import (
4
+ DisplayedMetric,
5
+ StreamingData,
6
+ get_streaming_data,
7
+ )
8
+ from vectordb_bench.frontend.config.styles import (
9
+ COLORS_10,
10
+ COLORS_2,
11
+ SCATTER_LINE_WIDTH,
12
+ SCATTER_MAKER_SIZE,
13
+ STREAMING_CHART_COLUMNS,
14
+ )
15
+
16
+
17
def drawChartsByCase(
    st,
    allData,
    showCaseNames: list[str],
    **kwargs,
):
    """Draw a titled streaming-chart section for each requested case name.

    Records whose ``st_search_stage_list`` is empty are dropped first. Cases
    with no remaining records are skipped entirely. Each section is padded
    with a blank line above and below.
    """
    results_with_stages = [record for record in allData if len(record["st_search_stage_list"]) > 0]
    for case_name in showCaseNames:
        case_records = [record for record in results_with_stages if record["case_name"] == case_name]
        if len(case_records) == 0:
            continue
        section = st.container()
        section.write("")  # blank line
        section.subheader(case_name)
        drawChartByMetric(section, case_records, case_name=case_name, **kwargs)
        section.write("")  # blank line
33
+
34
+
35
def drawChartByMetric(
    st,
    case_data,
    case_name: str,
    line_chart_displayed_y_metrics: list[tuple[DisplayedMetric, str]],
    **kwargs,
):
    """Lay out one case's streaming charts in a ``STREAMING_CHART_COLUMNS``-wide grid.

    Each ``(metric, note)`` pair gets a line chart, placed round-robin across
    the columns. The duration bar chart then goes into the grid slot after the
    last line chart. Widget keys are derived from ``case_name`` so that charts
    from different cases do not collide.
    """
    columns = st.columns(STREAMING_CHART_COLUMNS)
    streaming_data = get_streaming_data(case_data)

    # Line charts: one per requested y metric.
    for position, (metric, note) in enumerate(line_chart_displayed_y_metrics):
        cell = columns[position % STREAMING_CHART_COLUMNS]
        cell.markdown(f"#### {metric.value.capitalize()}")
        cell.markdown(f"{note}")
        drawLineChart(cell, streaming_data, metric=metric, key=f"{case_name}-{metric.value}", **kwargs)

    # Bar chart: insert/optimize durations, in the next free grid slot.
    bar_cell = columns[len(line_chart_displayed_y_metrics) % STREAMING_CHART_COLUMNS]
    bar_cell.markdown("#### Duration")
    bar_cell.markdown(
        "insert more than ideal-insert-duration (dash-line) means exceeding the maximum processing capacity.",
        help="vectordb need more time to process accumulated insert requests.",
    )
    drawBarChart(bar_cell, case_data, key=f"{case_name}-duration", **kwargs)
65
+
66
+
67
def drawLineChart(
    st,
    streaming_data: list[StreamingData],
    metric: DisplayedMetric,
    key: str,
    with_last_optimized_data=True,
    **kwargs,
):
    """Plot ``metric`` over the chosen x metric, one colored line per database.

    The x metric comes from the optional ``line_chart_displayed_x_metric``
    keyword and defaults to ``DisplayedMetric.search_stage``. That resolved
    value is passed explicitly to the scatter builders, which require it.
    Previously, omitting the keyword raised TypeError there.

    When the x axis is actual search time, a dotted vertical line marks the
    ideal insert duration, read from the first data point. When
    ``with_last_optimized_data`` is true, a dashed segment to the final
    optimized point is drawn for each database as well.
    """
    # Pop rather than get, so the explicit keyword below is never duplicated.
    x_metric = kwargs.pop("line_chart_displayed_x_metric", DisplayedMetric.search_stage)
    db_names = sorted({d.db_name for d in streaming_data})
    fig = go.Figure()

    if x_metric == DisplayedMetric.search_time and streaming_data:
        ideal_insert_duration = streaming_data[0].ideal_insert_duration
        metric_values = [getattr(d, metric.value) for d in streaming_data]
        fig.add_shape(
            type="line",
            y0=min(metric_values),
            y1=max(metric_values),
            x0=ideal_insert_duration,
            x1=ideal_insert_duration,
            line=dict(color="#999", width=SCATTER_LINE_WIDTH, dash="dot"),
            showlegend=True,
            name="insert 100% standard time",
        )

    for i, db_name in enumerate(db_names):
        db_data = [d for d in streaming_data if d.db_name == db_name]
        # Cycle the palette so more than len(COLORS_10) databases do not raise IndexError.
        color = COLORS_10[i % len(COLORS_10)]
        if with_last_optimized_data:
            fig.add_trace(
                get_optimized_scatter(
                    db_data,
                    db_name=db_name,
                    metric=metric,
                    color=color,
                    line_chart_displayed_x_metric=x_metric,
                    **kwargs,
                )
            )
        fig.add_trace(
            get_normal_scatter(
                db_data,
                db_name=db_name,
                metric=metric,
                color=color,
                line_chart_displayed_x_metric=x_metric,
                **kwargs,
            )
        )

    x_title = "Actual Time (s)" if x_metric == DisplayedMetric.search_time else "Search Stages (%)"
    fig.update_layout(
        margin=dict(l=0, r=0, t=40, b=0, pad=8),
        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="left", x=0, title=""),
        xaxis_title=x_title,
    )
    st.plotly_chart(fig, use_container_width=True, key=key)
123
+
124
+
125
def get_normal_scatter(
    data: list[StreamingData],
    db_name: str,
    metric: DisplayedMetric,
    color: str,
    line_chart_displayed_x_metric: DisplayedMetric,
    **kwargs,
):
    """Build the solid markers+lines trace for one database's non-optimized points.

    ``data`` is sorted in place by the x metric, and points flagged
    ``optimized`` are excluded. Latency metrics get an ``ms`` suffix in the
    hover text. When x is actual search time, the hover also shows that time.
    """
    x_attr = line_chart_displayed_x_metric.value
    unit = "ms" if "latency" in metric.value else ""

    data.sort(key=lambda point: getattr(point, x_attr))
    points = [point for point in data if not point.optimized]

    if line_chart_displayed_x_metric == DisplayedMetric.search_time:
        hovertemplate = f"%{{text}}% data inserted.<br>actual_time=%{{x:.4g}}s<br>{metric.value}=%{{y:.4g}}{unit}"
    else:
        hovertemplate = f"%{{text}}% data inserted.<br>{metric.value}=%{{y:.4g}}{unit}"

    return go.Scatter(
        x=[getattr(point, x_attr) for point in points],
        y=[getattr(point, metric.value) for point in points],
        text=[point.search_stage for point in points],
        mode="markers+lines",
        name=db_name,
        marker=dict(color=color, size=SCATTER_MAKER_SIZE),
        line=dict(dash="solid", width=SCATTER_LINE_WIDTH, color=color),
        legendgroup=db_name,
        hovertemplate=hovertemplate,
    )
152
+
153
+
154
def get_optimized_scatter(
    data: list[StreamingData],
    db_name: str,
    metric: DisplayedMetric,
    color: str,
    line_chart_displayed_x_metric: DisplayedMetric,
    **kwargs,
):
    """Build the dashed segment from the last stage to the final optimized point.

    ``data`` is sorted in place by ``search_stage``. An empty ``go.Scatter()``
    is returned when there are fewer than two points or when the last point is
    not flagged ``optimized``. The length check now comes first, so empty
    input no longer raises IndexError. Otherwise the trace covers the last two
    points. Its first marker has size 0, so only the optimized end point is
    shown. The trace shares the database's legend group and is hidden from the
    legend.
    """
    unit = "ms" if "latency" in metric.value else ""
    data.sort(key=lambda x: x.search_stage)
    if len(data) < 2 or not data[-1].optimized:
        return go.Scatter()
    data = data[-2:]

    hovertemplate = f"all data inserted and <b style='color: #333;'>optimized</b>.<br>{metric.value}=%{{y:.4g}}{unit}"
    if line_chart_displayed_x_metric == DisplayedMetric.search_time:
        hovertemplate = f"all data inserted and <b style='color: #333;'>optimized</b>.<br>actual_time=%{{x:.4g}}s<br>{metric.value}=%{{y:.4g}}{unit}"

    return go.Scatter(
        x=[getattr(d, line_chart_displayed_x_metric.value) for d in data],
        y=[getattr(d, metric.value) for d in data],
        text=[d.search_stage for d in data],
        mode="markers+lines",
        name=db_name,
        legendgroup=db_name,
        marker=dict(color=color, size=[0, SCATTER_MAKER_SIZE]),
        line=dict(dash="dash", width=SCATTER_LINE_WIDTH, color=color),
        hovertemplate=hovertemplate,
        showlegend=False,
    )
184
+
185
+
186
def drawBarChart(
    st,
    data,
    key: str,
    with_last_optimized_data=True,
    **kwargs,
):
    """Draw stacked horizontal duration bars (insert, optionally optimize) per database.

    Nothing is drawn when ``data`` is empty. A dotted vertical line at the
    first record's ``st_ideal_insert_duration`` spans every bar row.
    """
    if len(data) == 0:
        return

    ideal_duration = data[0]["st_ideal_insert_duration"]
    figure = go.Figure()
    figure.add_shape(
        type="line",
        x0=ideal_duration,
        x1=ideal_duration,
        y0=-0.5,
        y1=len(data) - 0.5,
        line=dict(color="#999", width=SCATTER_LINE_WIDTH, dash="dot"),
        showlegend=True,
        name="insert 100% standard time",
    )

    # Insert duration always; optimize duration stacked after it when requested.
    bar_specs = [(DisplayedMetric.insert_duration, COLORS_2[0])]
    if with_last_optimized_data:
        bar_specs.append((DisplayedMetric.optimize_duration, COLORS_2[1]))
    for bar_metric, bar_color in bar_specs:
        figure.add_trace(get_bar(data, metric=bar_metric, color=bar_color, **kwargs))

    figure.update_layout(
        margin=dict(l=0, r=0, t=40, b=0, pad=8),
        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="left", x=0, title=""),
        xaxis_title="time (s)",
        barmode="stack",
    )
    figure.update_traces(width=0.15)
    st.plotly_chart(figure, use_container_width=True, key=key)
238
+
239
+
240
def get_bar(
    data: list[dict],
    metric: DisplayedMetric,
    color: str,
    **kwargs,
):
    """Build one horizontal bar trace showing ``metric`` (in seconds) per database.

    ``data`` holds raw per-database result dicts: each must contain the
    ``metric.value`` key and ``"db_name"``. The trace name is
    ``metric.value``, a plain string, rather than the Enum member itself, so
    the legend label does not depend on how the Enum serializes.
    """
    return go.Bar(
        x=[d[metric.value] for d in data],
        y=[d["db_name"] for d in data],
        name=metric.value,
        marker_color=color,
        orientation="h",
        hovertemplate="%{y} %{x:.2f}s",
    )