vectordb-bench 0.0.29__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -27
- vectordb_bench/backend/assembler.py +19 -6
- vectordb_bench/backend/cases.py +186 -23
- vectordb_bench/backend/clients/__init__.py +32 -0
- vectordb_bench/backend/clients/api.py +22 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +249 -43
- vectordb_bench/backend/clients/aws_opensearch/cli.py +51 -21
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -16
- vectordb_bench/backend/clients/chroma/chroma.py +6 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
- vectordb_bench/backend/clients/lancedb/cli.py +62 -8
- vectordb_bench/backend/clients/lancedb/config.py +14 -1
- vectordb_bench/backend/clients/lancedb/lancedb.py +21 -9
- vectordb_bench/backend/clients/memorydb/memorydb.py +2 -2
- vectordb_bench/backend/clients/milvus/cli.py +30 -9
- vectordb_bench/backend/clients/milvus/config.py +3 -0
- vectordb_bench/backend/clients/milvus/milvus.py +81 -23
- vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
- vectordb_bench/backend/clients/oceanbase/config.py +125 -0
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
- vectordb_bench/backend/clients/qdrant_local/cli.py +60 -0
- vectordb_bench/backend/clients/qdrant_local/config.py +47 -0
- vectordb_bench/backend/clients/qdrant_local/qdrant_local.py +232 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +29 -3
- vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -0
- vectordb_bench/backend/dataset.py +143 -27
- vectordb_bench/backend/filter.py +76 -0
- vectordb_bench/backend/runner/__init__.py +3 -3
- vectordb_bench/backend/runner/mp_runner.py +52 -39
- vectordb_bench/backend/runner/rate_runner.py +68 -52
- vectordb_bench/backend/runner/read_write_runner.py +125 -68
- vectordb_bench/backend/runner/serial_runner.py +56 -23
- vectordb_bench/backend/task_runner.py +48 -20
- vectordb_bench/cli/batch_cli.py +121 -0
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +7 -0
- vectordb_bench/config-files/batch_sample_config.yml +17 -0
- vectordb_bench/frontend/components/check_results/data.py +16 -11
- vectordb_bench/frontend/components/check_results/filters.py +53 -25
- vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
- vectordb_bench/frontend/components/check_results/nav.py +20 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
- vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
- vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
- vectordb_bench/frontend/components/label_filter/charts.py +60 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
- vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
- vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
- vectordb_bench/frontend/components/streaming/charts.py +253 -0
- vectordb_bench/frontend/components/streaming/data.py +62 -0
- vectordb_bench/frontend/components/tables/data.py +1 -1
- vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
- vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +420 -41
- vectordb_bench/frontend/config/styles.py +32 -2
- vectordb_bench/frontend/pages/concurrent.py +5 -1
- vectordb_bench/frontend/pages/custom.py +4 -0
- vectordb_bench/frontend/pages/label_filter.py +56 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
- vectordb_bench/frontend/pages/results.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +3 -3
- vectordb_bench/frontend/pages/streaming.py +135 -0
- vectordb_bench/frontend/pages/tables.py +4 -0
- vectordb_bench/frontend/vdb_benchmark.py +16 -41
- vectordb_bench/interface.py +6 -2
- vectordb_bench/metric.py +15 -1
- vectordb_bench/models.py +38 -11
- vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
- vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
- vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
- vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
- vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
- vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
- vectordb_bench/results/dbPrices.json +12 -4
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +131 -32
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +87 -65
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +1 -1
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-0.0.29.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -20,3 +20,23 @@ def NavToResults(st, key="nav-to-results"):
|
|
20
20
|
navClick = st.button("< Back to Results", key=key)
|
21
21
|
if navClick:
|
22
22
|
switch_page("vdb benchmark")
|
23
|
+
|
24
|
+
|
25
|
+
def NavToPages(st):
    """Render a one-line row of links to the app pages, separated by grey pipes.

    Args:
        st: Streamlit module or container; only ``markdown`` is called on it.
    """
    # Display name and URL path segment of every linked page, in display order.
    options = [
        {"name": "Run Test", "link": "run_test"},
        {"name": "Results", "link": "results"},
        {"name": "Quries Per Dollar", "link": "quries_per_dollar"},
        {"name": "Concurrent", "link": "concurrent"},
        {"name": "Label Filter", "link": "label_filter"},
        {"name": "Streaming", "link": "streaming"},
        {"name": "Tables", "link": "tables"},
        {"name": "Custom Dataset", "link": "custom"},
    ]

    divider = '<span style="color: #888; margin: 0 5px;">|</span>'
    anchors = [
        f'<a href="/{page["link"]}" target="_self" style="text-decoration: none; padding: 0.1px 0.2px;">{page["name"]}</a>'
        for page in options
    ]
    # Joining puts a divider between neighbouring links only, never after the last one.
    st.markdown(divider.join(anchors), unsafe_allow_html=True)
|
@@ -12,7 +12,7 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
|
|
12
12
|
"Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
|
13
13
|
)
|
14
14
|
|
15
|
-
columns = st.columns(
|
15
|
+
columns = st.columns(3)
|
16
16
|
customCase.dataset_config.dim = columns[0].number_input(
|
17
17
|
"dim", key=f"{key}_dim", value=customCase.dataset_config.dim
|
18
18
|
)
|
@@ -22,16 +22,51 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
|
|
22
22
|
customCase.dataset_config.metric_type = columns[2].selectbox(
|
23
23
|
"metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
|
24
24
|
)
|
25
|
-
|
26
|
-
|
25
|
+
|
26
|
+
columns = st.columns(3)
|
27
|
+
customCase.dataset_config.train_name = columns[0].text_input(
|
28
|
+
"train file name",
|
29
|
+
key=f"{key}_train_name",
|
30
|
+
value=customCase.dataset_config.train_name,
|
31
|
+
)
|
32
|
+
customCase.dataset_config.test_name = columns[1].text_input(
|
33
|
+
"test file name", key=f"{key}_test_name", value=customCase.dataset_config.test_name
|
34
|
+
)
|
35
|
+
customCase.dataset_config.gt_name = columns[2].text_input(
|
36
|
+
"ground truth file name", key=f"{key}_gt_name", value=customCase.dataset_config.gt_name
|
37
|
+
)
|
38
|
+
|
39
|
+
columns = st.columns([1, 1, 2, 2])
|
40
|
+
customCase.dataset_config.train_id_name = columns[0].text_input(
|
41
|
+
"train id name", key=f"{key}_train_id_name", value=customCase.dataset_config.train_id_name
|
42
|
+
)
|
43
|
+
customCase.dataset_config.train_col_name = columns[1].text_input(
|
44
|
+
"train emb name", key=f"{key}_train_col_name", value=customCase.dataset_config.train_col_name
|
45
|
+
)
|
46
|
+
customCase.dataset_config.test_col_name = columns[2].text_input(
|
47
|
+
"test emb name", key=f"{key}_test_col_name", value=customCase.dataset_config.test_col_name
|
48
|
+
)
|
49
|
+
customCase.dataset_config.gt_col_name = columns[3].text_input(
|
50
|
+
"ground truth emb name", key=f"{key}_gt_col_name", value=customCase.dataset_config.gt_col_name
|
27
51
|
)
|
28
52
|
|
29
|
-
columns = st.columns(
|
30
|
-
customCase.dataset_config.
|
31
|
-
"
|
53
|
+
columns = st.columns(2)
|
54
|
+
customCase.dataset_config.scalar_labels_name = columns[0].text_input(
|
55
|
+
"scalar labels file name",
|
56
|
+
key=f"{key}_scalar_labels_file_name",
|
57
|
+
value=customCase.dataset_config.scalar_labels_name,
|
32
58
|
)
|
33
|
-
|
34
|
-
|
59
|
+
default_label_percentages = ",".join(map(str, customCase.dataset_config.with_label_percentages))
|
60
|
+
label_percentage_input = columns[1].text_input(
|
61
|
+
"label percentages",
|
62
|
+
key=f"{key}_label_percantages",
|
63
|
+
value=default_label_percentages,
|
35
64
|
)
|
65
|
+
try:
|
66
|
+
customCase.dataset_config.label_percentages = [
|
67
|
+
float(item.strip()) for item in label_percentage_input.split(",") if item.strip()
|
68
|
+
]
|
69
|
+
except ValueError as e:
|
70
|
+
st.write(f"<span style='color:red'>{e},please input correct number</span>", unsafe_allow_html=True)
|
36
71
|
|
37
72
|
customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
|
@@ -2,13 +2,18 @@ def displayParams(st):
|
|
2
2
|
st.markdown(
|
3
3
|
"""
|
4
4
|
- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
|
5
|
-
- Vectors data files: The file
|
6
|
-
- Query test vectors: The file
|
7
|
-
- Ground truth file: The file
|
5
|
+
- Vectors data files: The file should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The name of two columns could be defined on your own.
|
6
|
+
- Query test vectors: The file could be named on your own and should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The `id` column must be named as `id`, and `emb` column could be defined on your own.
|
7
|
+
- Ground truth file: The file could be named on your own and should have two kinds of columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. The `id` column must be named as `id`, and `neighbors_id` column could be defined on your own.
|
8
8
|
|
9
|
-
- `Train File
|
9
|
+
- `Train File Name` - If the number of train file is `more than one`, please input all your train file name and `split with ','` without the `.parquet` file extensionthe. For example, if there are two train file and the name of them are `train1.parquet` and `train2.parquet`, then input `train1,train2`.
|
10
|
+
|
11
|
+
- `Ground Truth Emb Name` - No matter whether filter file is applied or not, the `neighbors_id` column in ground truth file must have the same name.
|
12
|
+
|
13
|
+
- `Scalar Labels File Name ` - If there is a scalar labels file, please input the filename without the .parquet extension. The file should have two columns: `id` as an incrementing `int` and `labels` as an array of `string`. The `id` column must correspond one-to-one with the `id` column in train file..
|
14
|
+
|
15
|
+
- `Label percentages` - If you have filter file, please input label percentage you want to real run and `split with ','` when it's `more than one`. If you `don't have` filter file, than `keep the text vacant.`
|
10
16
|
|
11
|
-
- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
|
12
17
|
"""
|
13
18
|
)
|
14
19
|
st.caption(
|
@@ -14,6 +14,16 @@ class CustomDatasetConfig(BaseModel):
|
|
14
14
|
file_count: int = 1
|
15
15
|
use_shuffled: bool = False
|
16
16
|
with_gt: bool = True
|
17
|
+
train_name: str = "train"
|
18
|
+
test_name: str = "test"
|
19
|
+
gt_name: str = "neighbors"
|
20
|
+
train_id_name: str = "id"
|
21
|
+
train_col_name: str = "emb"
|
22
|
+
test_col_name: str = "emb"
|
23
|
+
gt_col_name: str = "neighbors_id"
|
24
|
+
scalar_labels_name: str = "scalar_labels"
|
25
|
+
label_percentages: list[str] = []
|
26
|
+
with_label_percentages: list[float] = [0.001, 0.02, 0.5]
|
17
27
|
|
18
28
|
|
19
29
|
class CustomCaseConfig(BaseModel):
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import plotly.express as px
|
2
|
+
from vectordb_bench.metric import metric_unit_map
|
3
|
+
|
4
|
+
|
5
|
+
def drawCharts(st, allData, **kwargs):
    """Draw one chart section per distinct dataset name, in sorted name order.

    Args:
        st: Streamlit module or container; a sub-container is created per dataset.
        allData: Rows indexable by ``"dataset_name"``; iterated once per dataset.
        **kwargs: Forwarded unchanged to ``drawChartByMetric``.
    """
    for dataset_name in sorted({row["dataset_name"] for row in allData}):
        section = st.container()
        section.subheader(dataset_name)
        rows = [row for row in allData if row["dataset_name"] == dataset_name]
        drawChartByMetric(section, rows, **kwargs)
|
13
|
+
|
14
|
+
|
15
|
+
def drawChartByMetric(st, data, metrics=("qps", "recall"), **kwargs):
    """Create one column per metric and draw that metric's chart inside it.

    Args:
        st: Streamlit container providing ``columns``.
        data: Rows passed unchanged to ``drawChart`` for every metric.
        metrics: Metric names; each gets a heading and a chart.
        **kwargs: Accepted for call compatibility; not used here.
    """
    slots = st.columns(len(metrics))
    for position, metric in enumerate(metrics):
        slot = slots[position]
        slot.markdown(f"#### {metric}")
        drawChart(slot, data, metric)
|
21
|
+
|
22
|
+
|
23
|
+
def getRange(metric, data, padding_multipliers):
    """Return a padded ``[low, high]`` axis range for ``metric`` over ``data``.

    Args:
        metric: Key looked up in every row; a missing key counts as 0.
        data: Non-empty sequence of dict-like rows (``min``/``max`` raise
            ``ValueError`` on an empty sequence).
        padding_multipliers: Two factors; the value span is multiplied by the
            first to pad below the minimum and by the second to pad above the
            maximum.

    Returns:
        A two-element list ``[min - span * m0, max + span * m1]``.
    """
    values = [row.get(metric, 0) for row in data]
    low, high = min(values), max(values)
    span = high - low
    below, above = padding_multipliers[0], padding_multipliers[1]
    return [low - span * below, high + span * above]
|
32
|
+
|
33
|
+
|
34
|
+
def drawChart(st, data: list[object], metric):
    """Draw a line chart of ``metric`` against ``filter_rate``, one line per database.

    Args:
        st: Streamlit container; only ``plotly_chart`` is called on it.
        data: Rows indexable by ``"filter_rate"``; also read via ``.get`` by
            ``getRange``. The list is sorted in place by ``filter_rate``.
        metric: Row key plotted on the y axis and shown as point text.
    """
    unit = metric_unit_map.get(metric, "")
    x_field = "filter_rate"
    y_field = metric

    # Axis ranges are padded so the outermost points do not sit on the frame.
    x_bounds = getRange(x_field, data, [0.05, 0.1])
    y_bounds = getRange(y_field, data, [0.2, 0.1])

    # In-place sort: the caller's list is reordered as a side effect.
    data.sort(key=lambda row: row[x_field])

    figure = px.line(
        data,
        x=x_field,
        y=y_field,
        color="db_name",
        line_group="db_name",
        text=metric,
        markers=True,
    )
    figure.update_xaxes(range=x_bounds)
    figure.update_yaxes(range=y_bounds)
    figure.update_traces(textposition="bottom right", texttemplate="%{y:,.4~r}" + unit)
    figure.update_layout(
        margin={"l": 0, "r": 0, "t": 40, "b": 0, "pad": 8},
        legend={"orientation": "h", "yanchor": "bottom", "y": 1, "xanchor": "right", "x": 1, "title": ""},
    )
    st.plotly_chart(figure, use_container_width=True)
|
@@ -1,8 +1,21 @@
|
|
1
|
-
from vectordb_bench.
|
2
|
-
from vectordb_bench.frontend.
|
1
|
+
from vectordb_bench.backend.clients import DB
|
2
|
+
from vectordb_bench.frontend.components.run_test.inputWidget import inputWidget
|
3
3
|
from collections import defaultdict
|
4
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import (
|
5
|
+
UI_CASE_CLUSTERS,
|
6
|
+
UICaseItem,
|
7
|
+
UICaseItemCluster,
|
8
|
+
get_case_config_inputs,
|
9
|
+
get_custom_case_cluter,
|
10
|
+
)
|
11
|
+
from vectordb_bench.frontend.config.styles import (
|
12
|
+
CASE_CONFIG_SETTING_COLUMNS,
|
13
|
+
CHECKBOX_INDENT,
|
14
|
+
DB_CASE_CONFIG_SETTING_COLUMNS,
|
15
|
+
)
|
4
16
|
|
5
17
|
from vectordb_bench.frontend.utils import addHorizontalLine
|
18
|
+
from vectordb_bench.models import CaseConfig
|
6
19
|
|
7
20
|
|
8
21
|
def caseSelector(st, activedDbList: list[DB]):
|
@@ -24,7 +37,7 @@ def caseSelector(st, activedDbList: list[DB]):
|
|
24
37
|
activedCaseList += caseClusterExpander(st, caseCluster, dbToCaseClusterConfigs, activedDbList)
|
25
38
|
for db in dbToCaseClusterConfigs:
|
26
39
|
for uiCaseItem in dbToCaseClusterConfigs[db]:
|
27
|
-
for case in uiCaseItem.
|
40
|
+
for case in uiCaseItem.get_cases():
|
28
41
|
dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
|
29
42
|
|
30
43
|
return activedCaseList, dbToCaseConfigs
|
@@ -48,15 +61,38 @@ def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, actived
|
|
48
61
|
unsafe_allow_html=True,
|
49
62
|
)
|
50
63
|
|
64
|
+
caseConfigSetting(st.container(), uiCaseItem)
|
65
|
+
|
51
66
|
if selected:
|
52
|
-
|
67
|
+
dbCaseConfigSetting(st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList)
|
68
|
+
|
69
|
+
return uiCaseItem.get_cases() if selected else []
|
70
|
+
|
71
|
+
|
72
|
+
def caseConfigSetting(st, uiCaseItem: UICaseItem):
    """Render the "Custom Config" row of input widgets for one case item.

    Does nothing when the item has no extra custom-config inputs. Otherwise
    every widget's current value is stored in ``uiCaseItem.tmp_custom_config``
    under the input's ``label.value``.

    Args:
        st: Streamlit container providing ``columns``.
        uiCaseItem: Case item whose ``extra_custom_case_config_inputs`` are shown.
    """
    extra_inputs = uiCaseItem.extra_custom_case_config_inputs
    if len(extra_inputs) == 0:
        return

    # First column holds the title; the rest share the total width of the
    # DB-config columns equally among CASE_CONFIG_SETTING_COLUMNS widget columns.
    widget_width = DB_CASE_CONFIG_SETTING_COLUMNS / CASE_CONFIG_SETTING_COLUMNS
    columns = st.columns([1] + [widget_width] * CASE_CONFIG_SETTING_COLUMNS)

    columns[0].markdown(
        f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>Custom Config</div>",
        unsafe_allow_html=True,
    )

    for position, config_input in enumerate(extra_inputs):
        # Wrap around the widget columns; column 0 is reserved for the title.
        target = columns[1 + position % CASE_CONFIG_SETTING_COLUMNS]
        label_value = config_input.label.value
        widget_key = f"custom-config-{uiCaseItem.label}-{label_value}"
        uiCaseItem.tmp_custom_config[label_value] = inputWidget(target, config=config_input, key=widget_key)
|
55
91
|
|
56
92
|
|
57
|
-
def
|
93
|
+
def dbCaseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
|
58
94
|
for db in activedDbList:
|
59
|
-
columns = st.columns(1 +
|
95
|
+
columns = st.columns(1 + DB_CASE_CONFIG_SETTING_COLUMNS)
|
60
96
|
# column 0 - title
|
61
97
|
dbColumn = columns[0]
|
62
98
|
dbColumn.markdown(
|
@@ -64,52 +100,12 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
|
|
64
100
|
unsafe_allow_html=True,
|
65
101
|
)
|
66
102
|
k = 0
|
67
|
-
|
68
|
-
for config in
|
69
|
-
if config.isDisplayed(
|
70
|
-
column = columns[1 + k %
|
103
|
+
dbCaseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
|
104
|
+
for config in get_case_config_inputs(db, uiCaseItem.caseLabel):
|
105
|
+
if config.isDisplayed(dbCaseConfig):
|
106
|
+
column = columns[1 + k % DB_CASE_CONFIG_SETTING_COLUMNS]
|
71
107
|
key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
|
72
|
-
|
73
|
-
caseConfig[config.label] = column.text_input(
|
74
|
-
config.displayLabel if config.displayLabel else config.label.value,
|
75
|
-
key=key,
|
76
|
-
help=config.inputHelp,
|
77
|
-
value=config.inputConfig["value"],
|
78
|
-
)
|
79
|
-
elif config.inputType == InputType.Option:
|
80
|
-
caseConfig[config.label] = column.selectbox(
|
81
|
-
config.displayLabel if config.displayLabel else config.label.value,
|
82
|
-
config.inputConfig["options"],
|
83
|
-
key=key,
|
84
|
-
help=config.inputHelp,
|
85
|
-
)
|
86
|
-
elif config.inputType == InputType.Number:
|
87
|
-
caseConfig[config.label] = column.number_input(
|
88
|
-
config.displayLabel if config.displayLabel else config.label.value,
|
89
|
-
# format="%d",
|
90
|
-
step=config.inputConfig.get("step", 1),
|
91
|
-
min_value=config.inputConfig["min"],
|
92
|
-
max_value=config.inputConfig["max"],
|
93
|
-
key=key,
|
94
|
-
value=config.inputConfig["value"],
|
95
|
-
help=config.inputHelp,
|
96
|
-
)
|
97
|
-
elif config.inputType == InputType.Float:
|
98
|
-
caseConfig[config.label] = column.number_input(
|
99
|
-
config.displayLabel if config.displayLabel else config.label.value,
|
100
|
-
step=config.inputConfig.get("step", 0.1),
|
101
|
-
min_value=config.inputConfig["min"],
|
102
|
-
max_value=config.inputConfig["max"],
|
103
|
-
key=key,
|
104
|
-
value=config.inputConfig["value"],
|
105
|
-
help=config.inputHelp,
|
106
|
-
)
|
107
|
-
elif config.inputType == InputType.Bool:
|
108
|
-
caseConfig[config.label] = column.checkbox(
|
109
|
-
config.displayLabel if config.displayLabel else config.label.value,
|
110
|
-
value=config.inputConfig["value"],
|
111
|
-
help=config.inputHelp,
|
112
|
-
)
|
108
|
+
dbCaseConfig[config.label] = inputWidget(column, config, key)
|
113
109
|
k += 1
|
114
110
|
if k == 0:
|
115
111
|
columns[1].write("Auto")
|
@@ -1,9 +1,10 @@
|
|
1
1
|
from streamlit.runtime.media_file_storage import MediaFileStorageError
|
2
2
|
from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON
|
3
3
|
from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
|
4
|
+
import streamlit as st
|
4
5
|
|
5
6
|
|
6
|
-
def dbSelector(st):
|
7
|
+
def dbSelector(st: st):
|
7
8
|
st.markdown(
|
8
9
|
"<div style='height: 12px;'></div>",
|
9
10
|
unsafe_allow_html=True,
|
@@ -20,11 +21,14 @@ def dbSelector(st):
|
|
20
21
|
for i, db in enumerate(DB_LIST):
|
21
22
|
column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
|
22
23
|
dbIsActived[db] = column.checkbox(db.name)
|
23
|
-
|
24
|
-
|
25
|
-
|
24
|
+
image_src = DB_TO_ICON.get(db, None)
|
25
|
+
if image_src:
|
26
|
+
column.markdown(
|
27
|
+
f'<img src="{image_src}" style="width:100px;height:100px;object-fit:contain;object-position:center;margin-bottom:10px;">',
|
28
|
+
unsafe_allow_html=True,
|
29
|
+
)
|
30
|
+
else:
|
26
31
|
column.warning(f"{db.name} image not available")
|
27
|
-
pass
|
28
32
|
activedDbList = [db for db in DB_LIST if dbIsActived[db]]
|
29
33
|
|
30
34
|
return activedDbList
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import CaseConfigInput, InputType
|
2
|
+
|
3
|
+
|
4
|
+
def inputWidget(st, config: CaseConfigInput, key: str):
    """Render the Streamlit widget matching ``config.inputType`` and return its value.

    Args:
        st: Streamlit module, container or column on which the widget is created.
        config: Input description; ``inputType``, ``displayLabel``, ``label``,
            ``inputHelp`` and ``inputConfig`` are read.
        key: Streamlit widget key.

    Returns:
        Whatever the created Streamlit widget returns.

    Raises:
        Exception: If ``config.inputType`` is none of Text, Option, Number,
            Float or Bool.
    """

    def widget_label():
        # Prefer the explicit display label; fall back to ``config.label.value``.
        return config.displayLabel if config.displayLabel else config.label.value

    kind = config.inputType

    if kind == InputType.Text:
        return st.text_input(
            widget_label(),
            key=key,
            help=config.inputHelp,
            value=config.inputConfig["value"],
        )

    if kind == InputType.Option:
        return st.selectbox(
            widget_label(),
            config.inputConfig["options"],
            key=key,
            help=config.inputHelp,
        )

    if kind == InputType.Number or kind == InputType.Float:
        # Both kinds use number_input; only the default step differs.
        fallback_step = 1 if kind == InputType.Number else 0.1
        return st.number_input(
            widget_label(),
            step=config.inputConfig.get("step", fallback_step),
            min_value=config.inputConfig["min"],
            max_value=config.inputConfig["max"],
            key=key,
            value=config.inputConfig["value"],
            help=config.inputHelp,
        )

    if kind == InputType.Bool:
        # Booleans are a True/False select box; preselect the configured value.
        preselected = 0 if config.inputConfig["value"] else 1
        return st.selectbox(
            widget_label(),
            options=[True, False],
            index=preselected,
            key=key,
            help=config.inputHelp,
        )

    raise Exception(f"Invalid InputType: {config.inputType}")
|
@@ -86,7 +86,9 @@ def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid):
|
|
86
86
|
currentTaskId = benchmark_runner.get_current_task_id()
|
87
87
|
tasksCount = benchmark_runner.get_tasks_count()
|
88
88
|
text = f":running: Running Task {currentTaskId} / {tasksCount}"
|
89
|
-
|
89
|
+
|
90
|
+
if tasksCount > 0:
|
91
|
+
st.progress(currentTaskId / tasksCount, text=text)
|
90
92
|
|
91
93
|
columns = st.columns(6)
|
92
94
|
columns[0].button(
|
@@ -0,0 +1,253 @@
|
|
1
|
+
import plotly.graph_objects as go
|
2
|
+
|
3
|
+
from vectordb_bench.frontend.components.streaming.data import (
|
4
|
+
DisplayedMetric,
|
5
|
+
StreamingData,
|
6
|
+
get_streaming_data,
|
7
|
+
)
|
8
|
+
from vectordb_bench.frontend.config.styles import (
|
9
|
+
COLORS_10,
|
10
|
+
COLORS_2,
|
11
|
+
SCATTER_LINE_WIDTH,
|
12
|
+
SCATTER_MAKER_SIZE,
|
13
|
+
STREAMING_CHART_COLUMNS,
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
def drawChartsByCase(
    st,
    allData,
    showCaseNames: list[str],
    **kwargs,
):
    """Draw the streaming charts of every requested case that has data.

    Rows with an empty ``"st_search_stage_list"`` are ignored, and cases left
    without rows are skipped entirely.

    Args:
        st: Streamlit module or container; a sub-container is created per case.
        allData: Rows indexable by ``"st_search_stage_list"`` and ``"case_name"``.
        showCaseNames: Case names to draw, in display order.
        **kwargs: Forwarded unchanged to ``drawChartByMetric``.
    """
    rows_with_stages = [row for row in allData if len(row["st_search_stage_list"]) > 0]
    for case_name in showCaseNames:
        case_rows = [row for row in rows_with_stages if row["case_name"] == case_name]
        if not case_rows:
            continue
        section = st.container()
        section.write("")  # blank line
        section.subheader(case_name)
        drawChartByMetric(section, case_rows, case_name=case_name, **kwargs)
        section.write("")  # blank line
|
33
|
+
|
34
|
+
|
35
|
+
def drawChartByMetric(
    st,
    case_data,
    case_name: str,
    line_chart_displayed_y_metrics: list[tuple[DisplayedMetric, str]],
    **kwargs,
):
    """Draw one line chart per y metric plus a duration bar chart for one case.

    Charts are placed round-robin into ``STREAMING_CHART_COLUMNS`` columns; the
    bar chart takes the slot following the last line chart.

    Args:
        st: Streamlit container providing ``columns``.
        case_data: Rows of one case; passed to ``get_streaming_data`` for the
            line charts and directly to ``drawBarChart``.
        case_name: Used to build unique chart keys.
        line_chart_displayed_y_metrics: ``(metric, note)`` pairs; the note is
            rendered under the metric heading.
        **kwargs: Forwarded to ``drawLineChart`` and ``drawBarChart``.
    """
    columns = st.columns(STREAMING_CHART_COLUMNS)
    streaming_data = get_streaming_data(case_data)

    # Line charts, one per requested y metric.
    for position, (metric, note) in enumerate(line_chart_displayed_y_metrics):
        slot = columns[position % STREAMING_CHART_COLUMNS]
        slot.markdown(f"#### {metric.value.capitalize()}")
        slot.markdown(f"{note}")
        drawLineChart(slot, streaming_data, metric=metric, key=f"{case_name}-{metric.value}", **kwargs)

    # Bar chart with insert / optimize durations.
    bar_slot = columns[len(line_chart_displayed_y_metrics) % STREAMING_CHART_COLUMNS]
    bar_slot.markdown("#### Duration")
    bar_slot.markdown(
        "insert more than ideal-insert-duration (dash-line) means exceeding the maximum processing capacity.",
        help="vectordb need more time to process accumulated insert requests.",
    )
    drawBarChart(bar_slot, case_data, key=f"{case_name}-duration", **kwargs)
|
65
|
+
|
66
|
+
|
67
|
+
def drawLineChart(
    st,
    streaming_data: list[StreamingData],
    metric: DisplayedMetric,
    key: str,
    with_last_optimized_data=True,
    **kwargs,
):
    """Draw ``metric`` over the streaming run, one line per database.

    Args:
        st: Streamlit container; only ``plotly_chart`` is called on it.
        streaming_data: Data points of every database of one case; ``db_name``,
            ``ideal_insert_duration`` and the attribute named by
            ``metric.value`` are read.
        metric: Metric plotted on the y axis.
        key: Streamlit element key of the chart.
        with_last_optimized_data: When true, the trace built by
            ``get_optimized_scatter`` is added for each database as well.
        **kwargs: May contain ``line_chart_displayed_x_metric`` (defaults to
            ``DisplayedMetric.search_stage``); forwarded to the scatter helpers.
    """
    x_metric = kwargs.get("line_chart_displayed_x_metric", DisplayedMetric.search_stage)
    # The scatter helpers take the x metric as a required argument, so pass the
    # resolved value explicitly; otherwise the default above would not reach
    # them and the call would fail with a TypeError.
    scatter_kwargs = {**kwargs, "line_chart_displayed_x_metric": x_metric}

    fig = go.Figure()
    # The reference line needs at least one data point (first element, min, max).
    if x_metric == DisplayedMetric.search_time and streaming_data:
        ideal_insert_duration = streaming_data[0].ideal_insert_duration
        y_values = [getattr(d, metric.value) for d in streaming_data]
        fig.add_shape(
            type="line",
            y0=min(y_values),
            y1=max(y_values),
            x0=ideal_insert_duration,
            x1=ideal_insert_duration,
            line=dict(color="#999", width=SCATTER_LINE_WIDTH, dash="dot"),
            showlegend=True,
            name="insert 100% standard time",
        )

    db_names = sorted({d.db_name for d in streaming_data})
    for i, db_name in enumerate(db_names):
        data = [d for d in streaming_data if d.db_name == db_name]
        # Cycle through the palette so having more databases than colors
        # reuses colors instead of raising IndexError.
        color = COLORS_10[i % len(COLORS_10)]
        if with_last_optimized_data:
            fig.add_trace(
                get_optimized_scatter(
                    data,
                    db_name=db_name,
                    metric=metric,
                    color=color,
                    **scatter_kwargs,
                )
            )
        fig.add_trace(
            get_normal_scatter(
                data,
                db_name=db_name,
                metric=metric,
                color=color,
                **scatter_kwargs,
            )
        )

    x_title = "Actual Time (s)" if x_metric == DisplayedMetric.search_time else "Search Stages (%)"
    fig.update_layout(
        margin=dict(l=0, r=0, t=40, b=0, pad=8),
        legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="left", x=0, title=""),
        xaxis_title=x_title,
    )
    st.plotly_chart(fig, use_container_width=True, key=key)
|
123
|
+
|
124
|
+
|
125
|
+
def get_normal_scatter(
    data: list[StreamingData],
    db_name: str,
    metric: DisplayedMetric,
    color: str,
    line_chart_displayed_x_metric: DisplayedMetric,
    **kwargs,
):
    """Build the solid line trace of one database from its non-optimized points.

    Args:
        data: Points of one database; sorted in place by the x metric.
        db_name: Trace name and legend group.
        metric: Metric plotted on the y axis.
        color: Marker and line color.
        line_chart_displayed_x_metric: Metric plotted on the x axis.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        A ``go.Scatter`` trace in ``markers+lines`` mode.
    """
    x_attr = line_chart_displayed_x_metric.value
    y_attr = metric.value
    # Metrics whose name contains "latency" get an "ms" suffix in the hover text.
    unit = "ms" if "latency" in y_attr else ""

    # In-place sort: the caller's list is reordered as a side effect.
    data.sort(key=lambda point: getattr(point, x_attr))
    points = [point for point in data if not point.optimized]

    if line_chart_displayed_x_metric == DisplayedMetric.search_time:
        hovertemplate = f"%{{text}}% data inserted.<br>actual_time=%{{x:.4g}}s<br>{metric.value}=%{{y:.4g}}{unit}"
    else:
        hovertemplate = f"%{{text}}% data inserted.<br>{metric.value}=%{{y:.4g}}{unit}"

    return go.Scatter(
        x=[getattr(point, x_attr) for point in points],
        y=[getattr(point, y_attr) for point in points],
        text=[point.search_stage for point in points],
        mode="markers+lines",
        name=db_name,
        marker={"color": color, "size": SCATTER_MAKER_SIZE},
        line={"dash": "solid", "width": SCATTER_LINE_WIDTH, "color": color},
        legendgroup=db_name,
        hovertemplate=hovertemplate,
    )
|
152
|
+
|
153
|
+
|
154
|
+
def get_optimized_scatter(
    data: list[StreamingData],
    db_name: str,
    metric: DisplayedMetric,
    color: str,
    line_chart_displayed_x_metric: DisplayedMetric,
    **kwargs,
):
    """Build the dashed segment leading to a database's final, optimized point.

    Args:
        data: Points of one database; sorted in place by ``search_stage``.
        db_name: Trace name and legend group.
        metric: Metric plotted on the y axis.
        color: Marker and line color.
        line_chart_displayed_x_metric: Metric plotted on the x axis.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        An empty ``go.Scatter`` when there are fewer than two points or the
        last point (by ``search_stage``) is not optimized; otherwise a dashed
        trace over the last two points, hidden from the legend.
    """
    unit = ""
    if "latency" in metric.value:
        unit = "ms"
    # In-place sort: the caller's list is reordered as a side effect.
    data.sort(key=lambda x: x.search_stage)
    # Check the length first so an empty list returns the empty trace instead
    # of raising IndexError on data[-1].
    if len(data) < 2 or not data[-1].optimized:
        return go.Scatter()
    data = data[-2:]
    hovertemplate = f"all data inserted and <b style='color: #333;'>optimized</b>.<br>{metric.value}=%{{y:.4g}}{unit}"
    if line_chart_displayed_x_metric == DisplayedMetric.search_time:
        hovertemplate = f"all data inserted and <b style='color: #333;'>optimized</b>.<br>actual_time=%{{x:.4g}}s<br>{metric.value}=%{{y:.4g}}{unit}"
    return go.Scatter(
        x=[getattr(d, line_chart_displayed_x_metric.value) for d in data],
        y=[getattr(d, metric.value) for d in data],
        text=[d.search_stage for d in data],
        mode="markers+lines",
        name=db_name,
        legendgroup=db_name,
        # Size 0 hides the marker of the first point; only the final point gets one.
        marker=dict(color=color, size=[0, SCATTER_MAKER_SIZE]),
        line=dict(dash="dash", width=SCATTER_LINE_WIDTH, color=color),
        hovertemplate=hovertemplate,
        showlegend=False,
    )
|
184
|
+
|
185
|
+
|
186
|
+
def drawBarChart(
    st,
    data,
    key: str,
    with_last_optimized_data=True,
    **kwargs,
):
    """Draw stacked horizontal duration bars, one bar per row of ``data``.

    A dotted vertical line marks the ideal insert duration, read from the
    first row. Nothing is drawn when ``data`` is empty.

    Args:
        st: Streamlit container; only ``plotly_chart`` is called on it.
        data: Rows indexable by ``"st_ideal_insert_duration"`` and by the keys
            that ``get_bar`` reads.
        key: Streamlit element key of the chart.
        with_last_optimized_data: When true, the optimize duration is stacked
            onto the insert duration.
        **kwargs: Forwarded to ``get_bar``.
    """
    if len(data) == 0:
        return

    figure = go.Figure()

    # Reference line spanning all bars (bars are centered on integer positions).
    ideal_insert_duration = data[0]["st_ideal_insert_duration"]
    figure.add_shape(
        type="line",
        y0=-0.5,
        y1=len(data) - 0.5,
        x0=ideal_insert_duration,
        x1=ideal_insert_duration,
        line={"color": "#999", "width": SCATTER_LINE_WIDTH, "dash": "dot"},
        showlegend=True,
        name="insert 100% standard time",
    )

    # Insert duration first, then (optionally) the optimize duration on top.
    figure.add_trace(get_bar(data, metric=DisplayedMetric.insert_duration, color=COLORS_2[0], **kwargs))
    if with_last_optimized_data:
        figure.add_trace(get_bar(data, metric=DisplayedMetric.optimize_duration, color=COLORS_2[1], **kwargs))

    figure.update_layout(
        margin={"l": 0, "r": 0, "t": 40, "b": 0, "pad": 8},
        legend={"orientation": "h", "yanchor": "bottom", "y": 1, "xanchor": "left", "x": 0, "title": ""},
        xaxis_title="time (s)",
        barmode="stack",
    )
    figure.update_traces(width=0.15)
    st.plotly_chart(figure, use_container_width=True, key=key)
|
238
|
+
|
239
|
+
|
240
|
+
def get_bar(
    data: list[dict],
    metric: DisplayedMetric,
    color: str,
    **kwargs,
):
    """Build a horizontal bar trace of ``metric`` with one bar per row.

    Args:
        data: Rows indexed by key (``row[metric.value]`` and ``row["db_name"]``),
            i.e. dict-like rows rather than attribute-style objects.
        metric: Metric whose value gives the bar length and the trace name.
        color: Bar color.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        A ``go.Bar`` trace with database names on the y axis.
    """
    return go.Bar(
        x=[d[metric.value] for d in data],
        y=[d["db_name"] for d in data],
        # Use the metric's string value, like the rest of this module, rather
        # than passing the enum member itself as the trace name.
        name=metric.value,
        marker_color=color,
        orientation="h",
        hovertemplate="%{y} %{x:.2f}s",
    )
|