vectordb-bench 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +30 -0
- vectordb_bench/__main__.py +39 -0
- vectordb_bench/backend/__init__.py +0 -0
- vectordb_bench/backend/assembler.py +57 -0
- vectordb_bench/backend/cases.py +124 -0
- vectordb_bench/backend/clients/__init__.py +57 -0
- vectordb_bench/backend/clients/api.py +179 -0
- vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
- vectordb_bench/backend/clients/milvus/config.py +123 -0
- vectordb_bench/backend/clients/milvus/milvus.py +182 -0
- vectordb_bench/backend/clients/pinecone/config.py +15 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
- vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
- vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
- vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
- vectordb_bench/backend/dataset.py +393 -0
- vectordb_bench/backend/result_collector.py +15 -0
- vectordb_bench/backend/runner/__init__.py +12 -0
- vectordb_bench/backend/runner/mp_runner.py +124 -0
- vectordb_bench/backend/runner/serial_runner.py +164 -0
- vectordb_bench/backend/task_runner.py +290 -0
- vectordb_bench/backend/utils.py +85 -0
- vectordb_bench/base.py +6 -0
- vectordb_bench/frontend/components/check_results/charts.py +175 -0
- vectordb_bench/frontend/components/check_results/data.py +86 -0
- vectordb_bench/frontend/components/check_results/filters.py +97 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
- vectordb_bench/frontend/components/check_results/nav.py +21 -0
- vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
- vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
- vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
- vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
- vectordb_bench/frontend/const.py +391 -0
- vectordb_bench/frontend/pages/qps_with_price.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +59 -0
- vectordb_bench/frontend/utils.py +6 -0
- vectordb_bench/frontend/vdb_benchmark.py +42 -0
- vectordb_bench/interface.py +239 -0
- vectordb_bench/log_util.py +103 -0
- vectordb_bench/metric.py +53 -0
- vectordb_bench/models.py +234 -0
- vectordb_bench/results/result_20230609_standard.json +3228 -0
- vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
- vectordb_bench-0.0.1.dist-info/METADATA +226 -0
- vectordb_bench-0.0.1.dist-info/RECORD +56 -0
- vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
- vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
- vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,175 @@
|
|
1
|
+
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
|
2
|
+
from vectordb_bench.frontend.const import *
|
3
|
+
from vectordb_bench.models import ResultLabel
|
4
|
+
import plotly.express as px
|
5
|
+
|
6
|
+
|
7
|
+
def drawCharts(st, allData, failedTasks, cases):
    """Draw one expanded section per case: its charts, then its failure notes.

    Args:
        st: Object exposing ``markdown`` and ``expander`` (e.g. a Streamlit
            module or container).
        allData: Iterable of result dicts, each carrying a ``"case"`` key.
        failedTasks: Mapping indexed by case; each value maps a database
            name to its result label.
        cases: Cases to render, in display order.
    """
    # Inject the expander header font and the expander box styling.
    headerCss = "<style> .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;} </style>"
    st.markdown(headerCss, unsafe_allow_html=True)

    boxCss = """<style>
            .main div[data-testid='stExpander'] {
                background-color: #F6F8FA;
                border: 1px solid #A9BDD140;
                border-radius: 8px;
            }
            </style>"""
    st.markdown(boxCss, unsafe_allow_html=True)

    for case in cases:
        chartContainer = st.expander(case, True)
        caseRows = [row for row in allData if row["case"] == case]
        drawChart(caseRows, chartContainer)

        # Failed / timed-out databases are listed under the charts.
        showFailedDBs(chartContainer, failedTasks[case])
|
29
|
+
|
30
|
+
|
31
|
+
def showFailedDBs(st, errorDBs):
    """List the databases labelled FAILED and those labelled OUTOFRANGE.

    Args:
        st: Container the two notes are written into.
        errorDBs: Mapping of database name -> ``ResultLabel``.
    """
    failedDBs = []
    timeoutDBs = []
    for db, label in errorDBs.items():
        if label == ResultLabel.FAILED:
            failedDBs.append(db)
        # Independent check (not elif), matching two separate filters.
        if label == ResultLabel.OUTOFRANGE:
            timeoutDBs.append(db)

    showFailedText(st, "Failed", failedDBs)
    showFailedText(st, "Timeout", timeoutDBs)
|
39
|
+
|
40
|
+
|
41
|
+
def showFailedText(st, text, dbs):
    """Write a bold ``"<text>: db1, db2"`` line; write nothing when ``dbs`` is empty.

    Args:
        st: Object exposing ``markdown``.
        text: Heading placed before the colon.
        dbs: Sequence of database names (strings) to join with ``", "``.
    """
    if len(dbs) == 0:
        return

    joined = ", ".join(dbs)
    html = f"<div style='margin: -16px 0 12px 8px; font-size: 16px; font-weight: 600;'>{text}: {joined}</div>"
    st.markdown(html, unsafe_allow_html=True)
|
47
|
+
|
48
|
+
|
49
|
+
def drawChart(data, st):
    """Draw one chart per metric present in ``data``, following ``metricOrder``.

    Args:
        data: Result dicts for a single case; each has a ``"metricsSet"``
            collection naming the metrics it carries.
        st: Container that receives one sub-container per metric.
    """
    # Union of every metric name reported by any row.
    available = set()
    for row in data:
        available |= set(row["metricsSet"])

    showMetrics = [metric for metric in metricOrder if metric in available]

    for metric in showMetrics:
        drawMetricChart(data, metric, st.container())
|
58
|
+
|
59
|
+
|
60
|
+
def getLabelToShapeMap(data):
    """Assign a pattern shape to every ``db_label`` found in ``data``.

    Labels are processed database by database. Within one database, a label
    that already has an index (from an earlier database) reserves its shape
    slot, and new labels skip slots reserved so far.

    Databases and labels are visited in sorted order. Iterating the raw sets
    would make the assignment depend on string hash order, which can differ
    between interpreter processes, so the same data could be drawn with
    different shapes after a restart.

    Args:
        data: Iterable of dicts with ``"db"`` and ``"db_label"`` keys.

    Returns:
        dict mapping each label to ``getPatternShape(index)``.
    """
    # PATTERN_SHAPES and getPatternShape are expected in module scope
    # (presumably via the frontend.const star import - confirm).
    shapeCount = len(PATTERN_SHAPES)
    labelIndexMap = {}

    for db in sorted({d["db"] for d in data}, key=str):
        labelList = sorted({d["db_label"] for d in data if d["db"] == db}, key=str)

        usedShapes = set()
        i = 0
        for label in labelList:
            if label in labelIndexMap:
                # Already assigned for another database: reserve its slot here.
                usedShapes.add(labelIndexMap[label] % shapeCount)
                continue

            # Advance to a free slot, giving up after one full cycle so the
            # scan terminates even when every slot is reserved.
            loopCount = 0
            while i % shapeCount in usedShapes:
                i += 1
                loopCount += 1
                if loopCount > shapeCount:
                    break
            labelIndexMap[label] = i
            i += 1

    return {label: getPatternShape(index) for label, index in labelIndexMap.items()}
|
87
|
+
|
88
|
+
|
89
|
+
def drawMetricChart(data, metric, st):
    """Draw a horizontal bar chart of ``metric`` for every row that reports it.

    Rows whose value for ``metric`` is missing or not above ``1e-7`` are left
    out; when no row remains nothing is drawn.

    Args:
        data: Result dicts with ``"db"``, ``"db_label"``, ``"db_name"`` keys
            and one key per metric.
        metric: Name of the metric column to plot.
        st: Container the chart is placed into.
    """
    rows = [row for row in data if row.get(metric, 0) > 1e-7]
    if len(rows) == 0:
        return

    chart = st.container()

    # Geometry: fixed height per bar plus header room, and some right-hand
    # padding on the x axis so the outside value labels fit.
    height = len(rows) * 24 + 48
    xmin = 0
    xmax = max(row.get(metric, 0) for row in rows)
    xpadding = (xmax - xmin) / 16
    xpadding_multiplier = 1.6
    xrange = [xmin, xmax + xpadding * xpadding_multiplier]

    unit = metricUnitMap.get(metric, "")
    labelToShapeMap = getLabelToShapeMap(rows)

    lowerIsBetter = isLowerIsBetterMetric(metric)
    categoryorder = "total descending" if lowerIsBetter else "total ascending"
    direction = "less" if lowerIsBetter else "more"
    chartTitle = f"{metric.capitalize()} ({direction} is better)"

    hoverFields = {
        "db": False,
        "db_label": False,
        "db_name": True,
    }
    # NOTE(review): pattern_shape is not passed, so pattern_shape_map likely
    # has no visible effect - confirm against the plotly express docs.
    fig = px.bar(
        rows,
        x=metric,
        y="db_name",
        color="db",
        height=height,
        pattern_shape_map=labelToShapeMap,
        orientation="h",
        hover_data=hoverFields,
        color_discrete_map=COLOR_MAP,
        text_auto=True,
        title=chartTitle,
    )

    # The x axis is hidden; values are shown as text next to each bar.
    fig.update_xaxes(showticklabels=False, visible=False, range=xrange)
    fig.update_yaxes(title=dict(font=dict(size=1)))

    barText = dict(color="#333", size=12)
    barMarker = dict(
        pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)
    )
    fig.update_traces(
        textposition="outside",
        textfont=barText,
        marker=barMarker,
        texttemplate="%{x:,.4~r}" + unit,
    )

    legendLayout = dict(
        orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
    )
    titleLayout = dict(
        font=dict(size=16, color="#666"),
        pad=dict(l=16),
    )
    fig.update_layout(
        margin=dict(l=0, r=0, t=48, b=12, pad=8),
        bargap=0.25,
        showlegend=False,
        legend=legendLayout,
        yaxis={"categoryorder": categoryorder},
        title=titleLayout,
    )

    chart.plotly_chart(fig, use_container_width=True)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from dataclasses import asdict
|
3
|
+
from vectordb_bench.metric import isLowerIsBetterMetric
|
4
|
+
from vectordb_bench.models import ResultLabel
|
5
|
+
|
6
|
+
|
7
|
+
def getChartData(tasks, dbNames, cases):
    """Filter ``tasks`` by database name and case, then merge them.

    Returns:
        The ``(mergedTasks, failedTasks)`` pair produced by ``mergeTasks``.
    """
    selected = getFilterTasks(tasks, dbNames, cases)
    return mergeTasks(selected)
|
11
|
+
|
12
|
+
|
13
|
+
def getFilterTasks(tasks, dbNames, cases):
    """Keep the tasks whose db name is in ``dbNames`` and case id value in ``cases``.

    Args:
        tasks: Iterable of objects exposing ``task_config.db_name`` and
            ``task_config.case_config.case_id.value``.
        dbNames: Container of accepted database names.
        cases: Container of accepted case id values.

    Returns:
        list of matching tasks, in input order.
    """
    filterTasks = []
    for task in tasks:
        config = task.task_config
        if config.db_name not in dbNames:
            continue
        if config.case_config.case_id.value not in cases:
            continue
        filterTasks.append(task)
    return filterTasks
|
21
|
+
|
22
|
+
|
23
|
+
def mergeTasks(tasks):
    """Collapse tasks to one entry per (db_name, case) and split by outcome.

    Metrics of tasks sharing a (db_name, case) pair are combined with
    ``mergeMetrics`` and their labels with ``getBetterLabel``; ``db`` and
    ``db_label`` come from the last task seen for the pair.

    Returns:
        Tuple ``(mergedTasks, failedTasks)``. ``mergedTasks`` is a list of
        flat dicts for pairs labelled ``ResultLabel.NORMAL``; ``failedTasks``
        maps case -> db_name -> label for every other pair.
    """
    dbCaseMetricsMap = defaultdict(lambda: defaultdict(dict))
    for task in tasks:
        config = task.task_config
        db_name = config.db_name
        db = config.db.value
        db_label = config.db_config.db_label or ""
        case = config.case_config.case_id.value

        # Entry accumulated so far for this pair (empty dict on first sight).
        previous = dbCaseMetricsMap[db_name][case]
        merged_metrics = mergeMetrics(previous.get("metrics", {}), asdict(task.metrics))
        merged_label = getBetterLabel(previous.get("label", ResultLabel.FAILED), task.label)
        dbCaseMetricsMap[db_name][case] = {
            "db": db,
            "db_label": db_label,
            "metrics": merged_metrics,
            "label": merged_label,
        }

    mergedTasks = []
    failedTasks = defaultdict(lambda: defaultdict(str))
    for db_name, caseMetricsMap in dbCaseMetricsMap.items():
        for case, metricInfo in caseMetricsMap.items():
            label = metricInfo["label"]
            if label != ResultLabel.NORMAL:
                failedTasks[case][db_name] = label
                continue

            metrics = metricInfo["metrics"]
            row = {
                "db_name": db_name,
                "db": metricInfo["db"],
                "db_label": metricInfo["db_label"],
                "case": case,
                "metricsSet": set(metrics.keys()),
            }
            # Metric values are flattened into the row last.
            row.update(metrics)
            mergedTasks.append(row)

    return mergedTasks, failedTasks
|
62
|
+
|
63
|
+
|
64
|
+
def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
    """Return a new dict combining both metric dicts.

    Keys found in only one input are copied as-is; for keys found in both,
    ``getBetterMetric`` picks the value. Neither input is modified.
    """
    merged = dict(metrics_1)
    for name, candidate in metrics_2.items():
        if name in merged:
            merged[name] = getBetterMetric(name, candidate, merged[name])
        else:
            merged[name] = candidate
    return merged
|
72
|
+
|
73
|
+
|
74
|
+
def getBetterMetric(metric, value_1, value_2):
    """Pick the better of two values for ``metric``.

    A value below ``1e-7`` is treated as absent, in which case the other
    value is returned. Otherwise the smaller value wins when
    ``isLowerIsBetterMetric(metric)`` is true, else the larger one.
    """
    threshold = 1e-7
    if value_1 < threshold:
        return value_2
    if value_2 < threshold:
        return value_1

    if isLowerIsBetterMetric(metric):
        return min(value_1, value_2)
    return max(value_1, value_2)
|
84
|
+
|
85
|
+
def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel):
    """Return ``label_1`` if it is ``ResultLabel.NORMAL``, otherwise ``label_2``."""
    if label_1 != ResultLabel.NORMAL:
        return label_2
    return label_1
|
@@ -0,0 +1,97 @@
|
|
1
|
+
from vectordb_bench.frontend.components.check_results.data import getChartData
|
2
|
+
from vectordb_bench.frontend.const import *
|
3
|
+
|
4
|
+
|
5
|
+
def getshownData(results, st):
    """Render the filter widgets and return the data selected by them.

    Args:
        results: Sequence of stored test results offered for selection.
        st: Object exposing ``markdown`` and ``header`` that the filter
            widgets are drawn into.

    Returns:
        Tuple ``(shownData, failedTasks, showCases)``.
    """
    # Hide the built-in page navigation.
    hideNavCss = "<style> div[data-testid='stSidebarNav'] {display: none;} </style>"
    st.markdown(hideNavCss, unsafe_allow_html=True)

    st.header("Filters")

    shownResults = getshownResults(results, st)
    showDBNames, showCases = getShowDbsAndCases(shownResults, st)
    shownData, failedTasks = getChartData(shownResults, showDBNames, showCases)

    return shownData, failedTasks, showCases
|
20
|
+
|
21
|
+
|
22
|
+
def getshownResults(results, st):
    """Let the user pick stored results and return their task results, concatenated.

    Each result is shown under its ``task_label``, or ``"res-<first 4 chars
    of run_id>"`` when the label equals the run id. All results are selected
    by default.

    The widget selects by position and only uses the label for display.
    Mapping a chosen label back with ``list.index`` resolved every duplicate
    label to the first match, so that result was counted twice and the
    others were dropped.

    Args:
        results: Sequence of objects exposing ``task_label``, ``run_id`` and
            ``results`` (a list).
        st: Object exposing ``write`` and ``multiselect``.

    Returns:
        A new list holding the ``results`` entries of every selected result,
        in selection order; ``[]`` when there is nothing to select.
    """
    resultSelectOptions = [
        result.task_label
        if result.task_label != result.run_id
        else f"res-{result.run_id[:4]}"
        for result in results
    ]
    if len(resultSelectOptions) == 0:
        st.write(
            "There are no results to display. Please wait for the task to complete or run a new task."
        )
        return []

    positions = list(range(len(resultSelectOptions)))
    selectedPositions = st.multiselect(
        "Select the task results you need to analyze.",
        positions,
        default=positions,
        # Display text only; the returned values stay unique positions.
        format_func=lambda position: resultSelectOptions[position],
    )

    selectedResult = []
    for position in selectedPositions:
        selectedResult.extend(results[position].results)

    return selectedResult
|
47
|
+
|
48
|
+
|
49
|
+
def getShowDbsAndCases(result, st):
    """Render the "DB Filter" and "Case Filter" checkbox groups.

    Args:
        result: Iterable of task results exposing ``task_config.db_name``
            and ``task_config.case_config.case_id``.
        st: Object exposing ``markdown`` and ``expander``.

    Returns:
        Tuple ``(showDBNames, showCases)`` with the ticked database names and
        the ticked case values.
    """
    # Expander styling: tighter gaps, white background, header font.
    expanderStyles = (
        "<style> section[data-testid='stSidebar'] div[data-testid='stExpander'] div[data-testid='stVerticalBlock'] { gap: 0.2rem; } </style>",
        "<style> div[data-testid='stExpander'] {background-color: #ffffff;} </style>",
        "<style> section[data-testid='stSidebar'] .streamlit-expanderHeader p {font-size: 16px; font-weight: 600;} </style>",
    )
    for css in expanderStyles:
        st.markdown(css, unsafe_allow_html=True)

    allDbNames = sorted({res.task_config.db_name for res in result})
    allCasesSet = {res.task_config.case_config.case_id for res in result}
    # Cases follow CASE_LIST order, restricted to those present in the results.
    allCases = [
        entry["name"].value for entry in CASE_LIST if entry["name"] in allCasesSet
    ]

    dbFilterContainer = st.expander("DB Filter", True)
    showDBNames = filterView(allDbNames, dbFilterContainer, col=1)

    caseFilterContainer = st.expander("Case Filter", True)
    showCases = filterView(
        allCases,
        caseFilterContainer,
        col=1,
        optionLables=list(allCases),
    )

    return showDBNames, showCases
|
82
|
+
|
83
|
+
|
84
|
+
def filterView(options, st, col, optionLables=None):
    """Show one checkbox per option, spread over ``col`` columns.

    Args:
        options: Sequence of hashable option values.
        st: Object exposing ``columns``.
        col: Number of columns to lay the checkboxes out in.
        optionLables: Checkbox captions, parallel to ``options``; defaults
            to the options themselves.

    Returns:
        list of the options whose checkbox is ticked, in input order.
    """
    columns = st.columns(col, gap="small")
    captions = options if optionLables is None else optionLables

    # Every option starts ticked.
    isActive = dict.fromkeys(options, True)
    for position, option in enumerate(options):
        target = columns[position % col]
        isActive[option] = target.checkbox(captions[position], value=isActive[option])

    return [option for option in options if isActive[option]]
|
@@ -0,0 +1,18 @@
|
|
1
|
+
def drawHeaderIcon(st):
    """Inject the header banner: an empty ``div`` plus the CSS that styles it.

    Args:
        st: Object exposing ``markdown`` (e.g. the Streamlit module).
    """
    # The tags start at column 0 so markdown does not read them as an
    # indented code block. The closing tag was previously written as
    # "</style" (missing ">"), which left the style element unterminated.
    st.markdown("""
<div class="headerIconContainer"></div>

<style>
.headerIconContainer {
    position: absolute;
    top: -50px;
    height: 50px;
    width: 100%;
    border-bottom: 2px solid #E8EAEE;
    background-image: url(https://assets.zilliz.com/vdb_benchmark_db790b5387.png);
    background-repeat: no-repeat;
}
</style>
""",
        unsafe_allow_html=True,
    )
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from streamlit_extras.switch_page_button import switch_page
|
2
|
+
|
3
|
+
|
4
|
+
def NavToRunTest(st):
    """Show the "Run your test" teaser and switch to the run-test page on click.

    Args:
        st: Object exposing ``header``, ``write`` and ``button``.
    """
    st.header("Run your test")
    st.write("You can set the configs and run your own test.")
    if st.button("Run Your Test >"):
        switch_page("run test")
|
10
|
+
|
11
|
+
|
12
|
+
def NavToQPSWithPrice(st):
    """Show a button that switches to the "qps with price" page when clicked.

    Args:
        st: Object exposing ``button``.
    """
    if st.button("QPS with Price >"):
        switch_page("qps with price")
|
16
|
+
|
17
|
+
|
18
|
+
def NavToResults(st):
    """Show a button that switches back to the "vdb benchmark" page when clicked.

    Args:
        st: Object exposing ``button``.
    """
    if st.button("< Back to Results"):
        switch_page("vdb benchmark")
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from vectordb_bench.backend.clients import DB
|
2
|
+
from vectordb_bench.frontend.const import DB_DBLABEL_TO_PRICE
|
3
|
+
import pandas as pd
|
4
|
+
from collections import defaultdict
|
5
|
+
import streamlit as st
|
6
|
+
|
7
|
+
|
8
|
+
def priceTable(container, data):
    """Show an editable hourly-price table and return the edited prices.

    One row is listed per distinct ``(db, db_label)`` pair in ``data``,
    skipping rows whose ``db`` equals ``DB.Milvus.value``. Initial prices come
    from ``DB_DBLABEL_TO_PRICE``, falling back to 0.

    Args:
        container: Object exposing ``expander``.
        data: Iterable of dicts with ``"db"`` and ``"db_label"`` keys.

    Returns:
        Nested defaultdict ``priceMap[db][db_label] -> price`` built from the
        edited table.
    """
    pairs = sorted(
        {(row["db"], row["db_label"]) for row in data if row["db"] != DB.Milvus.value}
    )

    records = []
    for db, db_label in pairs:
        records.append(
            {
                "DB": db,
                "Label": db_label,
                "Price per hour": DB_DBLABEL_TO_PRICE.get(db, {}).get(db_label, 0),
            }
        )
    table = pd.DataFrame(records)

    # Size the editor to its row count plus a fixed allowance.
    height = len(table) * 35 + 38

    expander = container.expander("You can edit the price.")
    priceColumn = st.column_config.NumberColumn(
        min_value=0,
        format="$ %f",
    )
    editTable = expander.data_editor(
        table,
        use_container_width=True,
        hide_index=True,
        height=height,
        disabled=("DB", "Label"),
        column_config={"Price per hour": priceColumn},
    )

    priceMap = defaultdict(lambda: defaultdict(float))
    for _, (db, db_label, price) in editTable.iterrows():
        priceMap[db][db_label] = price

    return priceMap
|
@@ -0,0 +1,10 @@
|
|
1
|
+
from streamlit_autorefresh import st_autorefresh
|
2
|
+
from vectordb_bench.frontend.const import *
|
3
|
+
|
4
|
+
|
5
|
+
def autoRefresh():
    """Register the page auto-refresh component.

    Interval and refresh limit come from ``MAX_AUTO_REFRESH_INTERVAL`` and
    ``MAX_AUTO_REFRESH_COUNT``; the component's return value is not used.
    """
    refreshOptions = {
        "interval": MAX_AUTO_REFRESH_INTERVAL,
        "limit": MAX_AUTO_REFRESH_COUNT,
        "key": "streamlit-auto-refresh",
    }
    st_autorefresh(**refreshOptions)
|
@@ -0,0 +1,87 @@
|
|
1
|
+
from vectordb_bench.frontend.const import *
|
2
|
+
|
3
|
+
|
4
|
+
def caseSelector(st, activedDbList):
    """Render "STEP 2": one selectable item per entry of ``CASE_LIST``.

    Args:
        st: Object exposing ``markdown``, ``subheader`` and ``container``.
        activedDbList: Databases chosen in the previous step; forwarded to
            each case item.

    Returns:
        Tuple ``(activedCaseList, allCaseConfigs)``: the ticked case names,
        and a ``db -> case name -> dict`` mapping that the case items fill in.
    """
    spacerHtml = "<div style='height: 24px;'></div>"
    hintHtml = "<div style='color: #647489; margin-bottom: 24px; margin-top: -12px;'>Choose at least one case you want to run the test for. </div>"
    dividerHtml = "<div style='border: 1px solid #cccccc60; margin-bottom: 24px;'></div>"

    st.markdown(spacerHtml, unsafe_allow_html=True)
    st.subheader("STEP 2: Choose the case(s)")
    st.markdown(hintHtml, unsafe_allow_html=True)

    caseIsActived = {entry["name"]: False for entry in CASE_LIST}
    allCaseConfigs = {
        db: {entry["name"]: {} for entry in CASE_LIST} for db in DB_LIST
    }

    for entry in CASE_LIST:
        caseItemContainer = st.container()
        caseIsActived[entry["name"]] = caseItem(
            caseItemContainer, allCaseConfigs, entry, activedDbList
        )
        # Entries flagged with "divider" get a separator line below them.
        if entry.get("divider"):
            caseItemContainer.markdown(dividerHtml, unsafe_allow_html=True)

    activedCaseList = [
        entry["name"] for entry in CASE_LIST if caseIsActived[entry["name"]]
    ]
    return activedCaseList, allCaseConfigs
|
31
|
+
|
32
|
+
def caseItem(st, allCaseConfigs, case, activedDbList):
    """Render one case: checkbox, intro text and, when ticked, its settings.

    Args:
        st: Object exposing ``checkbox``, ``markdown`` and ``container``.
        allCaseConfigs: ``db -> case name -> dict`` mapping, passed through
            to ``caseConfigSetting``.
        case: ``CASE_LIST`` entry; reads its ``"name"`` and ``"intro"`` keys.
        activedDbList: Databases to show settings for.

    Returns:
        The checkbox state.
    """
    selected = st.checkbox(case["name"].value)

    introHtml = f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{case['intro']}</div>"
    st.markdown(introHtml, unsafe_allow_html=True)

    if not selected:
        return selected

    caseConfigSetting(st.container(), allCaseConfigs, case["name"], activedDbList)
    return selected
|
46
|
+
|
47
|
+
|
48
|
+
def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
    """Render the per-database inputs for ``case`` and store the entered values.

    For every database one row of columns is drawn: the first holds the
    database name, the rest hold one widget per displayed config from
    ``CASE_CONFIG_MAP``. Widget values are written into
    ``allCaseConfigs[db][case]``, keyed by the config label. A database with
    no displayed config shows "Auto".

    Args:
        st: Object exposing ``columns``.
        allCaseConfigs: ``db -> case -> dict`` mapping, updated in place.
        case: Case key used in ``allCaseConfigs`` and ``CASE_CONFIG_MAP``.
        activedDbList: Databases to render a row for.
    """
    for db in activedDbList:
        columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)

        # Column 0: database title.
        titleHtml = f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>{db.name}</div>"
        columns[0].markdown(titleHtml, unsafe_allow_html=True)

        caseConfig = allCaseConfigs[db][case]
        shown = 0
        for config in CASE_CONFIG_MAP.get(db, {}).get(case, []):
            if not config.isDisplayed(caseConfig):
                continue

            # Displayed configs wrap around the remaining columns.
            column = columns[1 + shown % CASE_CONFIG_SETTING_COLUMNS]
            widgetKey = "-".join((str(db), str(case), str(config.label.value)))
            inputType = config.inputType
            if inputType == InputType.Text:
                caseConfig[config.label] = column.text_input(
                    config.label.value,
                    key=widgetKey,
                    value=config.inputConfig["value"],
                )
            elif inputType == InputType.Option:
                caseConfig[config.label] = column.selectbox(
                    config.label.value,
                    config.inputConfig["options"],
                    key=widgetKey,
                )
            elif inputType == InputType.Number:
                caseConfig[config.label] = column.number_input(
                    config.label.value,
                    format="%d",
                    step=1,
                    min_value=config.inputConfig["min"],
                    max_value=config.inputConfig["max"],
                    key=widgetKey,
                    value=config.inputConfig["value"],
                )
            shown += 1

        if shown == 0:
            columns[1].write("Auto")
|
@@ -0,0 +1,47 @@
|
|
1
|
+
from vectordb_bench.frontend.const import *
|
2
|
+
from vectordb_bench.frontend.utils import inputIsPassword
|
3
|
+
|
4
|
+
|
5
|
+
def dbConfigSettings(st, activedDbList):
    """Render the configuration form of every selected database.

    Args:
        st: Object exposing ``markdown`` and ``expander``.
        activedDbList: Selected databases.

    Returns:
        dict mapping each database to the config object built by
        ``dbConfigSettingItem``.
    """
    headerCss = "<style> .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}</style>"
    st.markdown(headerCss, unsafe_allow_html=True)
    expander = st.expander("Configurations for the selected databases", True)

    # Each database gets its own container inside the shared expander.
    return {
        activeDb: dbConfigSettingItem(expander.container(), activeDb)
        for activeDb in activedDbList
    }
|
19
|
+
|
20
|
+
def dbConfigSettingItem(st, activeDb):
    """Render one text input per config field of ``activeDb`` and build its config.

    The fields are taken from the ``"properties"`` of the config class
    schema, with ``db_label`` moved to the end. Fields recognised by
    ``inputIsPassword`` are rendered as password inputs.

    Args:
        st: Object exposing ``markdown`` and ``columns``.
        activeDb: Database entry exposing ``value`` and ``init_cls``.

    Returns:
        An instance of the database's config class built from the inputs.
    """
    titleHtml = f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>"
    st.markdown(titleHtml, unsafe_allow_html=True)
    columns = st.columns(DB_CONFIG_SETTING_COLUMNS)

    dbConfigClass = activeDb.init_cls.config_cls()
    propertiesItems = list(dbConfigClass.schema().get("properties").items())
    moveDBLabelToLast(propertiesItems)

    dbConfig = {}
    for position, (fieldName, fieldSchema) in enumerate(propertiesItems):
        inputKind = "password" if inputIsPassword(fieldName) else "default"
        dbConfig[fieldName] = columns[position % DB_CONFIG_SETTING_COLUMNS].text_input(
            fieldName,
            key="-".join((str(activeDb), str(fieldName))),
            value=fieldSchema.get("default", ""),
            type=inputKind,
        )
    return dbConfigClass(**dbConfig)
|
43
|
+
|
44
|
+
|
45
|
+
def moveDBLabelToLast(propertiesItems):
    """Reorder ``propertiesItems`` in place so ``db_label`` entries come last.

    Args:
        propertiesItems: list of ``(name, schema)`` pairs; the relative order
            of all other entries is kept (the sort is stable).
    """
    def isDbLabel(item):
        return item[0] == "db_label"

    propertiesItems.sort(key=isDbLabel)
|
47
|
+
|
@@ -0,0 +1,36 @@
|
|
1
|
+
|
2
|
+
from vectordb_bench.frontend.const import *
|
3
|
+
|
4
|
+
|
5
|
+
def dbSelector(st):
    """Render "STEP 1": a checkbox and icon for every database in ``DB_LIST``.

    Args:
        st: Object exposing ``markdown``, ``subheader`` and ``columns``.

    Returns:
        list of the ticked databases, in ``DB_LIST`` order.
    """
    st.markdown(
        "<div style='height: 12px;'></div>",
        unsafe_allow_html=True,
    )
    st.subheader("STEP 1: Select the database(s)")
    # The helper text previously said "case" (copied from the case selector);
    # this step is about choosing databases.
    st.markdown(
        "<div style='color: #647489; margin-bottom: 24px; margin-top: -12px;'>Choose at least one database you want to run the test for. </div>",
        unsafe_allow_html=True,
    )

    dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
    dbIsActived = {db: False for db in DB_LIST}

    # style - image; column gap; checkbox font;
    st.markdown(
        """
<style>
div[data-testid='stImage'] {margin: auto;}
div[data-testid='stHorizontalBlock'] {gap: 8px;}
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
</style>
""",
        unsafe_allow_html=True,
    )

    # Databases wrap around the available columns.
    for i, db in enumerate(DB_LIST):
        column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
        dbIsActived[db] = column.checkbox(db.name)
        column.image(DB_TO_ICON.get(db, ""))

    activedDbList = [db for db in DB_LIST if dbIsActived[db]]

    return activedDbList
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
|
2
|
+
|
3
|
+
|
4
|
+
def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
    """Build one ``TaskConfig`` per (case, database) combination.

    Cases form the outer loop and databases the inner one, so tasks are
    grouped by case in the returned list.

    Args:
        activedDbList: Selected databases; each exposes ``value`` and
            ``init_cls``.
        dbConfigs: Mapping of database -> its config object.
        activedCaseList: Selected cases; each exposes ``value``.
        allCaseConfigs: ``db -> case -> {param type: value}`` mapping of the
            case-specific settings entered by the user.

    Returns:
        list of ``TaskConfig`` objects.
    """
    tasks = []
    for case in activedCaseList:
        for db in activedDbList:
            db_value = db.value
            db_config = dbConfigs[db]
            case_config = CaseConfig(
                case_id=case.value,
                custom_case={},
            )

            params = allCaseConfigs[db][case]
            # The case-config class is looked up by the chosen index type
            # (None when none was set), then filled with all entered params.
            case_config_cls = db.init_cls.case_config_cls(
                params.get(CaseConfigParamType.IndexType, None)
            )
            db_case_config = case_config_cls(
                **{key.value: value for key, value in params.items()}
            )

            tasks.append(
                TaskConfig(
                    db=db_value,
                    db_config=db_config,
                    case_config=case_config,
                    db_case_config=db_case_config,
                )
            )

    return tasks
|