vectordb-bench 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +1 -0
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +64 -18
- vectordb_bench/backend/clients/__init__.py +13 -0
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/cli/vectordbbench.py +2 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +12 -12
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +26 -29
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
- vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +138 -31
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +11 -18
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +2 -2
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/models.py +8 -4
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +36 -13
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/RECORD +48 -37
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
|
|
8
8
|
def getChartData(
|
9
9
|
tasks: list[CaseResult],
|
10
10
|
dbNames: list[str],
|
11
|
-
|
11
|
+
caseNames: list[str],
|
12
12
|
):
|
13
|
-
filterTasks = getFilterTasks(tasks, dbNames,
|
13
|
+
filterTasks = getFilterTasks(tasks, dbNames, caseNames)
|
14
14
|
mergedTasks, failedTasks = mergeTasks(filterTasks)
|
15
15
|
return mergedTasks, failedTasks
|
16
16
|
|
@@ -18,14 +18,13 @@ def getChartData(
|
|
18
18
|
def getFilterTasks(
|
19
19
|
tasks: list[CaseResult],
|
20
20
|
dbNames: list[str],
|
21
|
-
|
21
|
+
caseNames: list[str],
|
22
22
|
) -> list[CaseResult]:
|
23
|
-
case_ids = [case.case_id for case in cases]
|
24
23
|
filterTasks = [
|
25
24
|
task
|
26
25
|
for task in tasks
|
27
26
|
if task.task_config.db_name in dbNames
|
28
|
-
and task.task_config.case_config.case_id in
|
27
|
+
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
|
29
28
|
]
|
30
29
|
return filterTasks
|
31
30
|
|
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
36
35
|
db_name = task.task_config.db_name
|
37
36
|
db = task.task_config.db.value
|
38
37
|
db_label = task.task_config.db_config.db_label or ""
|
39
|
-
|
40
|
-
dbCaseMetricsMap[db_name][
|
38
|
+
case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
|
39
|
+
dbCaseMetricsMap[db_name][case.name] = {
|
41
40
|
"db": db,
|
42
41
|
"db_label": db_label,
|
43
42
|
"metrics": mergeMetrics(
|
44
|
-
dbCaseMetricsMap[db_name][
|
43
|
+
dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
|
45
44
|
asdict(task.metrics),
|
46
45
|
),
|
47
46
|
"label": getBetterLabel(
|
48
|
-
dbCaseMetricsMap[db_name][
|
47
|
+
dbCaseMetricsMap[db_name][case.name].get(
|
48
|
+
"label", ResultLabel.FAILED),
|
49
49
|
task.label,
|
50
50
|
),
|
51
51
|
}
|
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
53
53
|
mergedTasks = []
|
54
54
|
failedTasks = defaultdict(lambda: defaultdict(str))
|
55
55
|
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
|
56
|
-
for
|
56
|
+
for case_name, metricInfo in caseMetricsMap.items():
|
57
57
|
metrics = metricInfo["metrics"]
|
58
58
|
db = metricInfo["db"]
|
59
59
|
db_label = metricInfo["db_label"]
|
60
60
|
label = metricInfo["label"]
|
61
|
-
case_name = case_id.case_name
|
62
61
|
if label == ResultLabel.NORMAL:
|
63
62
|
mergedTasks.append(
|
64
63
|
{
|
@@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
|
|
80
79
|
metrics = {**metrics_1}
|
81
80
|
for key, value in metrics_2.items():
|
82
81
|
metrics[key] = (
|
83
|
-
getBetterMetric(
|
82
|
+
getBetterMetric(
|
83
|
+
key, value, metrics[key]) if key in metrics else value
|
84
84
|
)
|
85
85
|
|
86
86
|
return metrics
|
@@ -1,7 +1,7 @@
|
|
1
1
|
def initMainExpanderStyle(st):
|
2
2
|
st.markdown(
|
3
3
|
"""<style>
|
4
|
-
.main
|
4
|
+
.main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
|
5
5
|
.main div[data-testid='stExpander'] {
|
6
6
|
background-color: #F6F8FA;
|
7
7
|
border: 1px solid #A9BDD140;
|
@@ -1,8 +1,8 @@
|
|
1
1
|
from vectordb_bench.backend.cases import Case
|
2
2
|
from vectordb_bench.frontend.components.check_results.data import getChartData
|
3
3
|
from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
|
4
|
-
from vectordb_bench.frontend.
|
5
|
-
from vectordb_bench.frontend.
|
4
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
|
5
|
+
from vectordb_bench.frontend.config.styles import *
|
6
6
|
import streamlit as st
|
7
7
|
|
8
8
|
from vectordb_bench.models import CaseResult, TestResult
|
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
|
|
18
18
|
st.header("Filters")
|
19
19
|
|
20
20
|
shownResults = getshownResults(results, st)
|
21
|
-
showDBNames,
|
21
|
+
showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
|
22
22
|
|
23
|
-
shownData, failedTasks = getChartData(
|
23
|
+
shownData, failedTasks = getChartData(
|
24
|
+
shownResults, showDBNames, showCaseNames)
|
24
25
|
|
25
|
-
return shownData, failedTasks,
|
26
|
+
return shownData, failedTasks, showCaseNames
|
26
27
|
|
27
28
|
|
28
29
|
def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
|
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
|
|
52
53
|
return selectedResult
|
53
54
|
|
54
55
|
|
55
|
-
def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[
|
56
|
+
def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
|
56
57
|
initSidebarExanderStyle(st)
|
57
58
|
allDbNames = list(set({res.task_config.db_name for res in result}))
|
58
59
|
allDbNames.sort()
|
59
|
-
|
60
|
-
|
60
|
+
allCases: list[Case] = [
|
61
|
+
res.task_config.case_config.case_id.case_cls(
|
62
|
+
res.task_config.case_config.custom_case)
|
63
|
+
for res in result
|
64
|
+
]
|
65
|
+
allCaseNameSet = set({case.name for case in allCases})
|
66
|
+
allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
|
67
|
+
[case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
|
61
68
|
|
62
69
|
# DB Filter
|
63
70
|
dbFilterContainer = st.container()
|
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca
|
|
70
77
|
|
71
78
|
# Case Filter
|
72
79
|
caseFilterContainer = st.container()
|
73
|
-
|
80
|
+
showCaseNames = filterView(
|
74
81
|
caseFilterContainer,
|
75
82
|
"Case Filter",
|
76
|
-
[
|
83
|
+
[caseName for caseName in allCaseNames],
|
77
84
|
col=1,
|
78
|
-
optionLables=[case.name for case in allCases],
|
79
85
|
)
|
80
86
|
|
81
|
-
return showDBNames,
|
87
|
+
return showDBNames, showCaseNames
|
82
88
|
|
83
89
|
|
84
90
|
def filterView(container, header, options, col, optionLables=None):
|
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
|
|
114
120
|
)
|
115
121
|
if optionLables is None:
|
116
122
|
optionLables = options
|
117
|
-
isActive = {option: st.session_state[selectAllState]
|
123
|
+
isActive = {option: st.session_state[selectAllState]
|
124
|
+
for option in optionLables}
|
118
125
|
for i, option in enumerate(optionLables):
|
119
126
|
isActive[option] = columns[i % col].checkbox(
|
120
127
|
optionLables[i],
|
@@ -3,7 +3,7 @@ import pandas as pd
|
|
3
3
|
from collections import defaultdict
|
4
4
|
import streamlit as st
|
5
5
|
|
6
|
-
from vectordb_bench.frontend.
|
6
|
+
from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
|
7
7
|
|
8
8
|
|
9
9
|
def priceTable(container, data):
|
@@ -1,26 +1,27 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
|
1
|
+
from vectordb_bench.frontend.components.check_results.expanderStyle import (
|
2
|
+
initMainExpanderStyle,
|
3
|
+
)
|
5
4
|
import plotly.express as px
|
6
5
|
|
7
|
-
from vectordb_bench.frontend.
|
6
|
+
from vectordb_bench.frontend.config.styles import COLOR_MAP
|
8
7
|
|
9
8
|
|
10
|
-
def drawChartsByCase(allData,
|
9
|
+
def drawChartsByCase(allData, showCaseNames: list[str], st):
|
11
10
|
initMainExpanderStyle(st)
|
12
|
-
for
|
13
|
-
chartContainer = st.expander(
|
14
|
-
caseDataList = [
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
11
|
+
for caseName in showCaseNames:
|
12
|
+
chartContainer = st.expander(caseName, True)
|
13
|
+
caseDataList = [data for data in allData if data["case_name"] == caseName]
|
14
|
+
data = [
|
15
|
+
{
|
16
|
+
"conc_num": caseData["conc_num_list"][i],
|
17
|
+
"qps": caseData["conc_qps_list"][i],
|
18
|
+
"latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
|
19
|
+
"db_name": caseData["db_name"],
|
20
|
+
"db": caseData["db"],
|
21
|
+
}
|
22
|
+
for caseData in caseDataList
|
23
|
+
for i in range(len(caseData["conc_num_list"]))
|
24
|
+
]
|
24
25
|
drawChart(data, chartContainer)
|
25
26
|
|
26
27
|
|
@@ -38,7 +39,7 @@ def getRange(metric, data, padding_multipliers):
|
|
38
39
|
def drawChart(data, st):
|
39
40
|
if len(data) == 0:
|
40
41
|
return
|
41
|
-
|
42
|
+
|
42
43
|
x = "latency_p99"
|
43
44
|
xrange = getRange(x, data, [0.05, 0.1])
|
44
45
|
|
@@ -63,7 +64,6 @@ def drawChart(data, st):
|
|
63
64
|
line_group=line_group,
|
64
65
|
text=text,
|
65
66
|
markers=True,
|
66
|
-
# color_discrete_map=color_discrete_map,
|
67
67
|
hover_data={
|
68
68
|
"conc_num": True,
|
69
69
|
},
|
@@ -71,12 +71,9 @@ def drawChart(data, st):
|
|
71
71
|
)
|
72
72
|
fig.update_xaxes(range=xrange, title_text="Latency P99 (ms)")
|
73
73
|
fig.update_yaxes(range=yrange, title_text="QPS")
|
74
|
-
fig.update_traces(textposition="bottom right",
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
# ),
|
81
|
-
# )
|
82
|
-
st.plotly_chart(fig, use_container_width=True,)
|
74
|
+
fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}")
|
75
|
+
|
76
|
+
st.plotly_chart(
|
77
|
+
fig,
|
78
|
+
use_container_width=True,
|
79
|
+
)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
|
2
|
+
from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig
|
3
|
+
|
4
|
+
|
5
|
+
def displayCustomCase(customCase: CustomCaseConfig, st, key):
|
6
|
+
|
7
|
+
columns = st.columns([1, 2])
|
8
|
+
customCase.dataset_config.name = columns[0].text_input(
|
9
|
+
"Name", key=f"{key}_name", value=customCase.dataset_config.name)
|
10
|
+
customCase.name = f"{customCase.dataset_config.name} (Performace Case)"
|
11
|
+
customCase.dataset_config.dir = columns[1].text_input(
|
12
|
+
"Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir)
|
13
|
+
|
14
|
+
columns = st.columns(4)
|
15
|
+
customCase.dataset_config.dim = columns[0].number_input(
|
16
|
+
"dim", key=f"{key}_dim", value=customCase.dataset_config.dim)
|
17
|
+
customCase.dataset_config.size = columns[1].number_input(
|
18
|
+
"size", key=f"{key}_size", value=customCase.dataset_config.size)
|
19
|
+
customCase.dataset_config.metric_type = columns[2].selectbox(
|
20
|
+
"metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"])
|
21
|
+
customCase.dataset_config.file_count = columns[3].number_input(
|
22
|
+
"train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count)
|
23
|
+
|
24
|
+
columns = st.columns(4)
|
25
|
+
customCase.dataset_config.use_shuffled = columns[0].checkbox(
|
26
|
+
"use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled)
|
27
|
+
customCase.dataset_config.with_gt = columns[1].checkbox(
|
28
|
+
"with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt)
|
29
|
+
|
30
|
+
customCase.description = st.text_area(
|
31
|
+
"description", key=f"{key}_description", value=customCase.description)
|
@@ -0,0 +1,11 @@
|
|
1
|
+
def displayParams(st):
|
2
|
+
st.markdown("""
|
3
|
+
- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
|
4
|
+
- Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
|
5
|
+
- Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
|
6
|
+
- Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.
|
7
|
+
|
8
|
+
- `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
|
9
|
+
|
10
|
+
- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
|
11
|
+
""")
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import json
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from vectordb_bench import config
|
6
|
+
|
7
|
+
|
8
|
+
class CustomDatasetConfig(BaseModel):
|
9
|
+
name: str = "custom_dataset"
|
10
|
+
dir: str = ""
|
11
|
+
size: int = 0
|
12
|
+
dim: int = 0
|
13
|
+
metric_type: str = "L2"
|
14
|
+
file_count: int = 1
|
15
|
+
use_shuffled: bool = False
|
16
|
+
with_gt: bool = True
|
17
|
+
|
18
|
+
|
19
|
+
class CustomCaseConfig(BaseModel):
|
20
|
+
name: str = "custom_dataset (Performace Case)"
|
21
|
+
description: str = ""
|
22
|
+
load_timeout: int = 36000
|
23
|
+
optimize_timeout: int = 36000
|
24
|
+
dataset_config: CustomDatasetConfig = CustomDatasetConfig()
|
25
|
+
|
26
|
+
|
27
|
+
def get_custom_configs():
|
28
|
+
with open(config.CUSTOM_CONFIG_DIR, "r") as f:
|
29
|
+
custom_configs = json.load(f)
|
30
|
+
return [CustomCaseConfig(**custom_config) for custom_config in custom_configs]
|
31
|
+
|
32
|
+
|
33
|
+
def save_custom_configs(custom_configs: list[CustomDatasetConfig]):
|
34
|
+
with open(config.CUSTOM_CONFIG_DIR, "w") as f:
|
35
|
+
json.dump([custom_config.dict()
|
36
|
+
for custom_config in custom_configs], f, indent=4)
|
37
|
+
|
38
|
+
|
39
|
+
def generate_custom_case():
|
40
|
+
return CustomCaseConfig()
|
@@ -0,0 +1,15 @@
|
|
1
|
+
def initStyle(st):
|
2
|
+
st.markdown(
|
3
|
+
"""<style>
|
4
|
+
/* expander - header */
|
5
|
+
.main div[data-testid='stExpander'] summary p {font-size: 20px; font-weight: 600;}
|
6
|
+
/*
|
7
|
+
button {
|
8
|
+
height: auto;
|
9
|
+
padding-left: 8px !important;
|
10
|
+
padding-right: 6px !important;
|
11
|
+
}
|
12
|
+
*/
|
13
|
+
</style>""",
|
14
|
+
unsafe_allow_html=True,
|
15
|
+
)
|
@@ -1,9 +1,13 @@
|
|
1
|
-
|
1
|
+
|
2
|
+
from vectordb_bench.frontend.config.styles import *
|
2
3
|
from vectordb_bench.backend.cases import CaseType
|
3
|
-
from vectordb_bench.frontend.
|
4
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import *
|
5
|
+
from collections import defaultdict
|
6
|
+
|
7
|
+
from vectordb_bench.frontend.utils import addHorizontalLine
|
4
8
|
|
5
9
|
|
6
|
-
def caseSelector(st, activedDbList):
|
10
|
+
def caseSelector(st, activedDbList: list[DB]):
|
7
11
|
st.markdown(
|
8
12
|
"<div style='height: 24px;'></div>",
|
9
13
|
unsafe_allow_html=True,
|
@@ -14,41 +18,49 @@ def caseSelector(st, activedDbList):
|
|
14
18
|
unsafe_allow_html=True,
|
15
19
|
)
|
16
20
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
21
|
+
activedCaseList: list[CaseConfig] = []
|
22
|
+
dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict))
|
23
|
+
dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
|
24
|
+
caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
|
25
|
+
for caseCluster in caseClusters:
|
26
|
+
activedCaseList += caseClusterExpander(
|
27
|
+
st, caseCluster, dbToCaseClusterConfigs, activedDbList)
|
28
|
+
for db in dbToCaseClusterConfigs:
|
29
|
+
for uiCaseItem in dbToCaseClusterConfigs[db]:
|
30
|
+
for case in uiCaseItem.cases:
|
31
|
+
dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
|
32
|
+
|
33
|
+
return activedCaseList, dbToCaseConfigs
|
34
|
+
|
35
|
+
|
36
|
+
def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfigs, activedDbList: list[DB]):
|
37
|
+
expander = st.expander(caseCluster.label, False)
|
38
|
+
activedCases: list[CaseConfig] = []
|
39
|
+
for uiCaseItem in caseCluster.uiCaseItems:
|
40
|
+
if uiCaseItem.isLine:
|
41
|
+
addHorizontalLine(expander)
|
25
42
|
else:
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
caseItemContainer, allCaseConfigs, case, activedDbList
|
30
|
-
)
|
31
|
-
activedCaseList = [case for case in CASE_LIST if caseIsActived[case]]
|
32
|
-
return activedCaseList, allCaseConfigs
|
43
|
+
activedCases += caseItemCheckbox(expander,
|
44
|
+
dbToCaseClusterConfigs, uiCaseItem, activedDbList)
|
45
|
+
return activedCases
|
33
46
|
|
34
47
|
|
35
|
-
def
|
36
|
-
selected = st.checkbox(
|
48
|
+
def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
|
49
|
+
selected = st.checkbox(uiCaseItem.label)
|
37
50
|
st.markdown(
|
38
|
-
f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{
|
51
|
+
f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{uiCaseItem.description}</div>",
|
39
52
|
unsafe_allow_html=True,
|
40
53
|
)
|
41
54
|
|
42
55
|
if selected:
|
43
|
-
caseConfigSettingContainer = st.container()
|
44
56
|
caseConfigSetting(
|
45
|
-
|
57
|
+
st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList
|
46
58
|
)
|
47
59
|
|
48
|
-
return selected
|
60
|
+
return uiCaseItem.cases if selected else []
|
49
61
|
|
50
62
|
|
51
|
-
def caseConfigSetting(st,
|
63
|
+
def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
|
52
64
|
for db in activedDbList:
|
53
65
|
columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
|
54
66
|
# column 0 - title
|
@@ -57,12 +69,12 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
|
|
57
69
|
f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>{db.name}</div>",
|
58
70
|
unsafe_allow_html=True,
|
59
71
|
)
|
60
|
-
caseConfig = allCaseConfigs[db][case]
|
61
72
|
k = 0
|
62
|
-
|
73
|
+
caseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
|
74
|
+
for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []):
|
63
75
|
if config.isDisplayed(caseConfig):
|
64
76
|
column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS]
|
65
|
-
key = "%s-%s-%s" % (db,
|
77
|
+
key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
|
66
78
|
if config.inputType == InputType.Text:
|
67
79
|
caseConfig[config.label] = column.text_input(
|
68
80
|
config.displayLabel if config.displayLabel else config.label.value,
|
@@ -1,13 +1,9 @@
|
|
1
1
|
from pydantic import ValidationError
|
2
|
-
from vectordb_bench.frontend.
|
2
|
+
from vectordb_bench.frontend.config.styles import *
|
3
3
|
from vectordb_bench.frontend.utils import inputIsPassword
|
4
4
|
|
5
5
|
|
6
6
|
def dbConfigSettings(st, activedDbList):
|
7
|
-
st.markdown(
|
8
|
-
"<style> .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}</style>",
|
9
|
-
unsafe_allow_html=True,
|
10
|
-
)
|
11
7
|
expander = st.expander("Configurations for the selected databases", True)
|
12
8
|
|
13
9
|
dbConfigs = {}
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from streamlit.runtime.media_file_storage import MediaFileStorageError
|
2
|
-
|
3
|
-
from vectordb_bench.frontend.
|
4
|
-
from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST
|
2
|
+
from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON
|
3
|
+
from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
|
5
4
|
|
6
5
|
|
7
6
|
def dbSelector(st):
|
@@ -18,17 +17,6 @@ def dbSelector(st):
|
|
18
17
|
dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
|
19
18
|
dbIsActived = {db: False for db in DB_LIST}
|
20
19
|
|
21
|
-
# style - image; column gap; checkbox font;
|
22
|
-
st.markdown(
|
23
|
-
"""
|
24
|
-
<style>
|
25
|
-
div[data-testid='stImage'] {margin: auto;}
|
26
|
-
div[data-testid='stHorizontalBlock'] {gap: 8px;}
|
27
|
-
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
|
28
|
-
</style>
|
29
|
-
""",
|
30
|
-
unsafe_allow_html=True,
|
31
|
-
)
|
32
20
|
for i, db in enumerate(DB_LIST):
|
33
21
|
column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
|
34
22
|
dbIsActived[db] = column.checkbox(db.name)
|
@@ -1,17 +1,15 @@
|
|
1
|
+
from vectordb_bench.backend.clients import DB
|
1
2
|
from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
|
2
3
|
|
3
4
|
|
4
|
-
def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
|
5
|
+
def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs):
|
5
6
|
tasks = []
|
6
7
|
for db in activedDbList:
|
7
8
|
for case in activedCaseList:
|
8
9
|
task = TaskConfig(
|
9
10
|
db=db.value,
|
10
11
|
db_config=dbConfigs[db],
|
11
|
-
case_config=
|
12
|
-
case_id=case.value,
|
13
|
-
custom_case={},
|
14
|
-
),
|
12
|
+
case_config=case,
|
15
13
|
db_case_config=db.case_config_cls(
|
16
14
|
allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None)
|
17
15
|
)(**{key.value: value for key, value in allCaseConfigs[db][case].items()}),
|
@@ -0,0 +1,14 @@
|
|
1
|
+
def initStyle(st):
|
2
|
+
st.markdown(
|
3
|
+
"""<style>
|
4
|
+
/* expander - header */
|
5
|
+
.main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
|
6
|
+
/* db icon */
|
7
|
+
div[data-testid='stImage'] {margin: auto;}
|
8
|
+
/* db column gap */
|
9
|
+
div[data-testid='stHorizontalBlock'] {gap: 8px;}
|
10
|
+
/* check box */
|
11
|
+
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
|
12
|
+
</style>""",
|
13
|
+
unsafe_allow_html=True,
|
14
|
+
)
|