vectordb-bench 0.0.1__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. vectordb_bench/__init__.py +30 -0
  2. vectordb_bench/__main__.py +39 -0
  3. vectordb_bench/backend/__init__.py +0 -0
  4. vectordb_bench/backend/assembler.py +57 -0
  5. vectordb_bench/backend/cases.py +124 -0
  6. vectordb_bench/backend/clients/__init__.py +57 -0
  7. vectordb_bench/backend/clients/api.py +179 -0
  8. vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
  9. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
  10. vectordb_bench/backend/clients/milvus/config.py +123 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +182 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +15 -0
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
  20. vectordb_bench/backend/dataset.py +393 -0
  21. vectordb_bench/backend/result_collector.py +15 -0
  22. vectordb_bench/backend/runner/__init__.py +12 -0
  23. vectordb_bench/backend/runner/mp_runner.py +124 -0
  24. vectordb_bench/backend/runner/serial_runner.py +164 -0
  25. vectordb_bench/backend/task_runner.py +290 -0
  26. vectordb_bench/backend/utils.py +85 -0
  27. vectordb_bench/base.py +6 -0
  28. vectordb_bench/frontend/components/check_results/charts.py +175 -0
  29. vectordb_bench/frontend/components/check_results/data.py +86 -0
  30. vectordb_bench/frontend/components/check_results/filters.py +97 -0
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
  32. vectordb_bench/frontend/components/check_results/nav.py +21 -0
  33. vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
  34. vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
  35. vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
  36. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
  37. vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
  38. vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
  41. vectordb_bench/frontend/const.py +391 -0
  42. vectordb_bench/frontend/pages/qps_with_price.py +60 -0
  43. vectordb_bench/frontend/pages/run_test.py +59 -0
  44. vectordb_bench/frontend/utils.py +6 -0
  45. vectordb_bench/frontend/vdb_benchmark.py +42 -0
  46. vectordb_bench/interface.py +239 -0
  47. vectordb_bench/log_util.py +103 -0
  48. vectordb_bench/metric.py +53 -0
  49. vectordb_bench/models.py +234 -0
  50. vectordb_bench/results/result_20230609_standard.json +3228 -0
  51. vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
  52. vectordb_bench-0.0.1.dist-info/METADATA +226 -0
  53. vectordb_bench-0.0.1.dist-info/RECORD +56 -0
  54. vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
  55. vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
  56. vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,175 @@
1
+ from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
2
+ from vectordb_bench.frontend.const import *
3
+ from vectordb_bench.models import ResultLabel
4
+ import plotly.express as px
5
+
6
+
7
def drawCharts(st, allData, failedTasks, cases):
    """Draw every selected case as an expander holding its metric charts.

    Args:
        st: Streamlit module or container used for output.
        allData: merged chart rows; each row is a dict with a ``"case"`` key.
        failedTasks: mapping ``case -> {db_name: label}`` for runs that did not
            finish normally.
        cases: case names to render, in display order.
    """
    # Page-level CSS for the per-case expanders (header font, box styling).
    st.markdown(
        "<style> .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;} </style>",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<style>
.main div[data-testid='stExpander'] {
background-color: #F6F8FA;
border: 1px solid #A9BDD140;
border-radius: 8px;
}
</style>""",
        unsafe_allow_html=True,
    )
    for caseName in cases:
        section = st.expander(caseName, True)
        caseRows = [row for row in allData if row["case"] == caseName]
        drawChart(caseRows, section)
        showFailedDBs(section, failedTasks[caseName])
29
+
30
+
31
def showFailedDBs(st, errorDBs):
    """List the databases that failed or timed out for one case.

    ``errorDBs`` maps a db name to its result label. Names labelled
    ``ResultLabel.FAILED`` are shown under "Failed" and names labelled
    ``ResultLabel.OUTOFRANGE`` under "Timeout".
    """
    sections = (
        (ResultLabel.FAILED, "Failed"),
        (ResultLabel.OUTOFRANGE, "Timeout"),
    )
    for wantedLabel, title in sections:
        names = [name for name, label in errorDBs.items() if label == wantedLabel]
        showFailedText(st, title, names)
39
+
40
+
41
def showFailedText(st, text, dbs):
    """Render ``"<text>: db1, db2, ..."`` as one bold HTML line.

    Does nothing when ``dbs`` is empty.
    """
    if not dbs:
        return
    joinedNames = ", ".join(dbs)
    html = f"<div style='margin: -16px 0 12px 8px; font-size: 16px; font-weight: 600;'>{text}: &nbsp;&nbsp;{joinedNames}</div>"
    st.markdown(html, unsafe_allow_html=True)
47
+
48
+
49
def drawChart(data, st):
    """Draw one bar chart per metric present in ``data``.

    Metrics are drawn in ``metricOrder`` order, restricted to those listed in
    at least one row's ``"metricsSet"``. Each chart gets its own container.
    """
    metricsSet = set()
    for d in data:
        # update() mutates in place instead of allocating a new set per row.
        metricsSet.update(d["metricsSet"])
    showMetrics = [metric for metric in metricOrder if metric in metricsSet]

    for metric in showMetrics:
        drawMetricChart(data, metric, st.container())
58
+
59
+
60
def getLabelToShapeMap(data):
    """Assign a pattern shape to every ``db_label`` found in ``data``.

    A label keeps a single shape across all databases. Within one database,
    labels receive pairwise-distinct shapes as long as ``PATTERN_SHAPES`` has
    enough entries; once every shape is taken, indexes wrap around.

    Returns:
        dict mapping db_label -> ``getPatternShape(index)``.
    """
    labelIndexMap = {}
    shapeCount = len(PATTERN_SHAPES)

    # Sorted iteration keeps the assignment stable across runs; iteration
    # order of a set of strings depends on per-process hash randomization.
    for db in sorted({d["db"] for d in data}):
        labelList = sorted({d["db_label"] for d in data if d["db"] == db})

        # Reserve the shapes of labels that were already assigned while
        # handling an earlier db *before* assigning any new label. Otherwise
        # a new label visited first could take a shape that a shared label
        # visited later already owns.
        usedShapes = {
            labelIndexMap[label] % shapeCount
            for label in labelList
            if label in labelIndexMap
        }

        i = 0
        for label in labelList:
            if label in labelIndexMap:
                continue
            loopCount = 0
            # Skip shapes in use; stop after one full cycle (all taken).
            while i % shapeCount in usedShapes and loopCount <= shapeCount:
                i += 1
                loopCount += 1
            labelIndexMap[label] = i
            usedShapes.add(i % shapeCount)
            i += 1

    return {label: getPatternShape(index) for label, index in labelIndexMap.items()}
87
+
88
+
89
def drawMetricChart(data, metric, st):
    """Draw a horizontal bar chart of ``metric`` across the rows of ``data``.

    Rows whose value for ``metric`` is missing or not greater than 1e-7 are
    dropped; if none remain, nothing is drawn. Bars are coloured by ``"db"``
    and labelled by ``"db_name"``. The y categories are ordered by total,
    descending for lower-is-better metrics and ascending otherwise.
    """
    rows = [row for row in data if row.get(metric, 0) > 1e-7]
    if not rows:
        return

    chart = st.container()
    lowerIsBetter = isLowerIsBetterMetric(metric)

    # 24px per bar plus 48px for the title area.
    height = len(rows) * 24 + 48
    # The x axis is hidden. Extend its range past the largest value by
    # 1.6/16 of that value (presumably room for the outside text labels).
    largest = max(row.get(metric, 0) for row in rows)
    xrange = [0, largest + largest / 16 * 1.6]
    unit = metricUnitMap.get(metric, "")
    categoryorder = "total descending" if lowerIsBetter else "total ascending"

    fig = px.bar(
        rows,
        x=metric,
        y="db_name",
        color="db",
        height=height,
        # NOTE(review): pattern_shape is not set, so this map likely has no
        # visible effect — confirm before relying on it.
        pattern_shape_map=getLabelToShapeMap(rows),
        orientation="h",
        hover_data={"db": False, "db_label": False, "db_name": True},
        color_discrete_map=COLOR_MAP,
        text_auto=True,
        title=f"{metric.capitalize()} ({'less' if lowerIsBetter else 'more'} is better)",
    )
    fig.update_xaxes(showticklabels=False, visible=False, range=xrange)
    # A 1px title font effectively hides the y-axis title.
    fig.update_yaxes(title=dict(font=dict(size=1)))
    fig.update_traces(
        textposition="outside",
        textfont=dict(color="#333", size=12),
        marker=dict(
            pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)
        ),
        texttemplate="%{x:,.4~r}" + unit,
    )
    fig.update_layout(
        margin=dict(l=0, r=0, t=48, b=12, pad=8),
        bargap=0.25,
        showlegend=False,
        legend=dict(
            orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
        ),
        yaxis={"categoryorder": categoryorder},
        title=dict(font=dict(size=16, color="#666"), pad=dict(l=16)),
    )

    chart.plotly_chart(fig, use_container_width=True)
@@ -0,0 +1,86 @@
1
+ from collections import defaultdict
2
+ from dataclasses import asdict
3
+ from vectordb_bench.metric import isLowerIsBetterMetric
4
+ from vectordb_bench.models import ResultLabel
5
+
6
+
7
def getChartData(tasks, dbNames, cases):
    """Keep the tasks for the selected dbs/cases and merge them for charting.

    Returns:
        ``(mergedTasks, failedTasks)`` as produced by ``mergeTasks``.
    """
    return mergeTasks(getFilterTasks(tasks, dbNames, cases))
11
+
12
+
13
def getFilterTasks(tasks, dbNames, cases):
    """Return the tasks whose db name is in ``dbNames`` and case id value is in ``cases``.

    The input order of ``tasks`` is preserved.
    """
    def isSelected(task):
        config = task.task_config
        return config.db_name in dbNames and config.case_config.case_id.value in cases

    return list(filter(isSelected, tasks))
21
+
22
+
23
def mergeTasks(tasks):
    """Merge task results that share a (db_name, case) pair.

    Metrics of tasks with the same pair are combined with ``mergeMetrics``.
    Labels are combined with ``getBetterLabel``, starting from
    ``ResultLabel.FAILED``.

    Returns:
        ``(mergedTasks, failedTasks)``:

        * ``mergedTasks`` is a list of flat row dicts (db_name, db, db_label,
          case, metricsSet and every metric) for pairs whose final label is
          ``ResultLabel.NORMAL``.
        * ``failedTasks`` maps ``case -> {db_name: label}`` for all other pairs.
    """
    dbCaseMetricsMap = defaultdict(lambda: defaultdict(dict))
    for task in tasks:
        config = task.task_config
        dbName = config.db_name
        caseName = config.case_config.case_id.value
        previous = dbCaseMetricsMap[dbName][caseName]
        dbCaseMetricsMap[dbName][caseName] = {
            "db": config.db.value,
            "db_label": config.db_config.db_label or "",
            "metrics": mergeMetrics(previous.get("metrics", {}), asdict(task.metrics)),
            "label": getBetterLabel(previous.get("label", ResultLabel.FAILED), task.label),
        }

    mergedTasks = []
    failedTasks = defaultdict(lambda: defaultdict(str))
    for dbName, caseMetricsMap in dbCaseMetricsMap.items():
        for caseName, info in caseMetricsMap.items():
            if info["label"] != ResultLabel.NORMAL:
                failedTasks[caseName][dbName] = info["label"]
                continue
            metrics = info["metrics"]
            mergedTasks.append(
                {
                    "db_name": dbName,
                    "db": info["db"],
                    "db_label": info["db_label"],
                    "case": caseName,
                    "metricsSet": set(metrics.keys()),
                    **metrics,
                }
            )

    return mergedTasks, failedTasks
62
+
63
+
64
def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
    """Return a new dict that combines two metric dicts key by key.

    A key present in only one input is copied as-is. A key present in both is
    resolved by ``getBetterMetric(key, metrics_2[key], metrics_1[key])``.
    Neither input is modified.
    """
    merged = dict(metrics_1)
    for name, incoming in metrics_2.items():
        if name in merged:
            incoming = getBetterMetric(name, incoming, merged[name])
        merged[name] = incoming
    return merged
72
+
73
+
74
def getBetterMetric(metric, value_1, value_2):
    """Pick the better of two values of ``metric``.

    A value below 1e-7 counts as missing. If ``value_1`` is missing,
    ``value_2`` is returned (even if it is missing too). Otherwise, if
    ``value_2`` is missing, ``value_1`` is returned. When both are present,
    the smaller wins for lower-is-better metrics and the larger otherwise.
    """
    missingBelow = 1e-7
    if value_1 < missingBelow:
        return value_2
    if value_2 < missingBelow:
        return value_1
    choose = min if isLowerIsBetterMetric(metric) else max
    return choose(value_1, value_2)
84
+
85
def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel):
    """Combine two result labels: a NORMAL ``label_1`` is kept, otherwise ``label_2`` wins."""
    if label_1 == ResultLabel.NORMAL:
        return label_1
    return label_2
@@ -0,0 +1,97 @@
1
+ from vectordb_bench.frontend.components.check_results.data import getChartData
2
+ from vectordb_bench.frontend.const import *
3
+
4
+
5
def getshownData(results, st):
    """Render the filter controls and return the data selected through them.

    Returns:
        ``(shownData, failedTasks, showCases)``: the merged chart rows, the
        failed runs per case, and the case names left enabled by the filter.
    """
    # Hide Streamlit's multipage navigation list (data-testid stSidebarNav).
    st.markdown(
        "<style> div[data-testid='stSidebarNav'] {display: none;} </style>",
        unsafe_allow_html=True,
    )
    st.header("Filters")

    selectedResults = getshownResults(results, st)
    dbNames, caseNames = getShowDbsAndCases(selectedResults, st)
    chartRows, failed = getChartData(selectedResults, dbNames, caseNames)
    return chartRows, failed, caseNames
20
+
21
+
22
def getshownResults(results, st):
    """Let the user pick which test results to analyze and flatten them.

    Each entry of ``results`` is shown by its ``task_label``. When the label
    equals the run id, the entry is shown as ``res-<first 4 chars of run_id>``
    instead.

    Returns:
        One list concatenating the ``.results`` of every selected entry, or
        ``[]`` (after writing a notice) when ``results`` is empty.
    """
    resultSelectOptions = [
        result.task_label
        if result.task_label != result.run_id
        else f"res-{result.run_id[:4]}"
        for result in results
    ]
    if len(resultSelectOptions) == 0:
        st.write(
            "There are no results to display. Please wait for the task to complete or run a new task."
        )
        return []

    # Select by position and only *display* the labels. Two results can share
    # a label, and mapping a label back with list.index() always resolves it
    # to the first match, which would hide the other result.
    optionIndexes = list(range(len(resultSelectOptions)))
    selectedIndexes = st.multiselect(
        "Select the task results you need to analyze.",
        optionIndexes,
        default=optionIndexes,
        format_func=lambda i: resultSelectOptions[i],
    )
    selectedResult = []
    for i in selectedIndexes:
        selectedResult += results[i].results

    return selectedResult
47
+
48
+
49
def getShowDbsAndCases(result, st):
    """Render the DB and case filter expanders and return the enabled options.

    Args:
        result: task results used to find the available db names and cases.
        st: Streamlit container to render into.

    Returns:
        ``(showDBNames, showCases)``: the enabled db names (sorted) and the
        enabled case names (in ``CASE_LIST`` order).
    """
    # Expander styling: smaller vertical gap inside sidebar expanders, white
    # expander background, and a 16px bold header font in the sidebar.
    expanderStyles = (
        "<style> section[data-testid='stSidebar'] div[data-testid='stExpander'] div[data-testid='stVerticalBlock'] { gap: 0.2rem; } </style>",
        "<style> div[data-testid='stExpander'] {background-color: #ffffff;} </style>",
        "<style> section[data-testid='stSidebar'] .streamlit-expanderHeader p {font-size: 16px; font-weight: 600;} </style>",
    )
    for css in expanderStyles:
        st.markdown(css, unsafe_allow_html=True)

    allDbNames = sorted({res.task_config.db_name for res in result})
    presentCaseIds = {res.task_config.case_config.case_id for res in result}
    allCases = [
        case["name"].value for case in CASE_LIST if case["name"] in presentCaseIds
    ]

    showDBNames = filterView(allDbNames, st.expander("DB Filter", True), col=1)
    caseFilterContainer = st.expander("Case Filter", True)
    showCases = filterView(
        allCases,
        caseFilterContainer,
        col=1,
        optionLables=list(allCases),
    )
    return showDBNames, showCases
82
+
83
+
84
def filterView(options, st, col, optionLables=None):
    """Render one checkbox per option, laid out round-robin over ``col`` columns.

    Every checkbox starts checked. If given, ``optionLables[i]`` is the label
    shown for ``options[i]``; otherwise the options themselves are the labels.

    Returns:
        The options whose checkbox is checked, in their original order.
    """
    columns = st.columns(col, gap="small")
    labels = options if optionLables is None else optionLables
    isActive = dict.fromkeys(options, True)
    for position, option in enumerate(options):
        column = columns[position % col]
        isActive[option] = column.checkbox(labels[position], value=isActive[option])
    return [option for option in options if isActive[option]]
@@ -0,0 +1,18 @@
1
def drawHeaderIcon(st):
    """Draw the page header banner.

    Injects an empty absolutely-positioned div whose CSS gives it a
    background image and a bottom border.
    """
    # The closing style tag was previously written as "</style" (missing ">"),
    # which leaves the HTML tag unterminated.
    st.markdown(
        """
<div class="headerIconContainer"></div>

<style>
.headerIconContainer {
position: absolute;
top: -50px;
height: 50px;
width: 100%;
border-bottom: 2px solid #E8EAEE;
background-image: url(https://assets.zilliz.com/vdb_benchmark_db790b5387.png);
background-repeat: no-repeat;
}
</style>
""",
        unsafe_allow_html=True,
    )
@@ -0,0 +1,21 @@
1
+ from streamlit_extras.switch_page_button import switch_page
2
+
3
+
4
def NavToRunTest(st):
    """Show the "Run your test" section; its button switches to the "run test" page."""
    st.header("Run your test")
    st.write("You can set the configs and run your own test.")
    if st.button("Run Your Test &nbsp;&nbsp;>"):
        switch_page("run test")
10
+
11
+
12
def NavToQPSWithPrice(st):
    """Show a button that switches to the "qps with price" page when clicked."""
    if st.button("QPS with Price &nbsp;&nbsp;>"):
        switch_page("qps with price")
16
+
17
+
18
def NavToResults(st):
    """Show a back button that switches to the "vdb benchmark" page when clicked."""
    if st.button("< &nbsp;&nbsp;Back to Results"):
        switch_page("vdb benchmark")
@@ -0,0 +1,48 @@
1
+ from vectordb_bench.backend.clients import DB
2
+ from vectordb_bench.frontend.const import DB_DBLABEL_TO_PRICE
3
+ import pandas as pd
4
+ from collections import defaultdict
5
+ import streamlit as st
6
+
7
+
8
def priceTable(container, data):
    """Show an editable per-hour price table and return the edited prices.

    The table has one row per distinct ``(db, db_label)`` pair in ``data``,
    excluding rows whose ``db`` is ``DB.Milvus.value``, sorted by that pair.
    Initial prices come from ``DB_DBLABEL_TO_PRICE`` (0 when absent).

    Returns:
        ``priceMap[db][db_label] -> price`` read from the edited table. The map
        is a nested ``defaultdict`` that yields 0.0 for pairs not in the table.
    """
    columnNames = ["DB", "Label", "Price per hour"]
    dbAndLabelList = sorted(
        {(d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value}
    )

    # Pass the column names explicitly so an empty selection still yields a
    # table with the columns that `disabled` and `column_config` refer to.
    table = pd.DataFrame(
        [
            {
                "DB": db,
                "Label": db_label,
                "Price per hour": DB_DBLABEL_TO_PRICE.get(db, {}).get(db_label, 0),
            }
            for db, db_label in dbAndLabelList
        ],
        columns=columnNames,
    )
    # 35px per row plus 38px for the header row.
    height = len(table) * 35 + 38

    expander = container.expander("You can edit the price.")
    editTable = expander.data_editor(
        table,
        use_container_width=True,
        hide_index=True,
        height=height,
        disabled=("DB", "Label"),
        column_config={
            "Price per hour": st.column_config.NumberColumn(
                min_value=0,
                format="$ %f",
            )
        },
    )

    priceMap = defaultdict(lambda: defaultdict(float))
    # Read cells by column name, not position, so a column reorder in the
    # table cannot silently swap db, label and price.
    for _, row in editTable.iterrows():
        priceMap[row["DB"]][row["Label"]] = row["Price per hour"]

    return priceMap
@@ -0,0 +1,10 @@
1
+ from streamlit_autorefresh import st_autorefresh
2
+ from vectordb_bench.frontend.const import *
3
+
4
+
5
def autoRefresh():
    """Register the page auto-refresh via ``st_autorefresh``.

    Uses ``MAX_AUTO_REFRESH_INTERVAL`` as the interval, ``MAX_AUTO_REFRESH_COUNT``
    as the limit, and a fixed widget key.

    Returns:
        The value returned by ``st_autorefresh``. It was previously assigned
        to an unused local and discarded; returning it lets callers read the
        refresh count, and existing callers that ignore the result are
        unaffected.
    """
    return st_autorefresh(
        interval=MAX_AUTO_REFRESH_INTERVAL,
        limit=MAX_AUTO_REFRESH_COUNT,
        key="streamlit-auto-refresh",
    )
@@ -0,0 +1,87 @@
1
+ from vectordb_bench.frontend.const import *
2
+
3
+
4
def caseSelector(st, activedDbList):
    """Render STEP 2 (case selection) and collect the per-db case configs.

    Returns:
        ``(activedCaseList, allCaseConfigs)``:

        * ``activedCaseList`` holds the checked case names in ``CASE_LIST`` order.
        * ``allCaseConfigs[db][case]`` is a dict for every db in ``DB_LIST`` and
          every case, filled in by the config widgets of checked cases.
    """
    st.markdown("<div style='height: 24px;'></div>", unsafe_allow_html=True)
    st.subheader("STEP 2: Choose the case(s)")
    st.markdown(
        "<div style='color: #647489; margin-bottom: 24px; margin-top: -12px;'>Choose at least one case you want to run the test for. </div>",
        unsafe_allow_html=True,
    )

    caseNames = [case["name"] for case in CASE_LIST]
    caseIsActived = dict.fromkeys(caseNames, False)
    allCaseConfigs = {db: {name: {} for name in caseNames} for db in DB_LIST}
    for case in CASE_LIST:
        itemContainer = st.container()
        caseIsActived[case["name"]] = caseItem(
            itemContainer, allCaseConfigs, case, activedDbList
        )
        if case.get("divider"):
            itemContainer.markdown(
                "<div style='border: 1px solid #cccccc60; margin-bottom: 24px;'></div>",
                unsafe_allow_html=True,
            )

    activedCaseList = [name for name in caseNames if caseIsActived[name]]
    return activedCaseList, allCaseConfigs
31
+
32
def caseItem(st, allCaseConfigs, case, activedDbList):
    """Render one case checkbox followed by its intro text.

    When the box is checked, the case's config widgets for every active db are
    rendered too; their values are written into ``allCaseConfigs``.

    Returns:
        Whether the checkbox is checked.
    """
    isChecked = st.checkbox(case["name"].value)
    introHtml = f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{case['intro']}</div>"
    st.markdown(introHtml, unsafe_allow_html=True)

    if isChecked:
        caseConfigSetting(st.container(), allCaseConfigs, case["name"], activedDbList)

    return isChecked
46
+
47
+
48
def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
    """Render the config inputs of ``case`` for each active db.

    Each db gets one row of ``1 + CASE_CONFIG_SETTING_COLUMNS`` columns:
    column 0 holds the db name, and the displayed config inputs are spread
    round-robin over the remaining columns. Input values are stored in
    ``allCaseConfigs[db][case]`` under the config's label. A db with no
    displayed input shows "Auto".
    """
    for db in activedDbList:
        columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
        # column 0 - title
        columns[0].markdown(
            f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>{db.name}</div>",
            unsafe_allow_html=True,
        )
        caseConfig = allCaseConfigs[db][case]
        shownCount = 0
        for config in CASE_CONFIG_MAP.get(db, {}).get(case, []):
            # isDisplayed receives the dict being filled in this loop, so
            # presumably earlier inputs can hide later ones.
            if not config.isDisplayed(caseConfig):
                continue
            column = columns[1 + shownCount % CASE_CONFIG_SETTING_COLUMNS]
            widgetLabel = config.label.value
            widgetKey = "%s-%s-%s" % (db, case, widgetLabel)
            if config.inputType == InputType.Text:
                caseConfig[config.label] = column.text_input(
                    widgetLabel,
                    key=widgetKey,
                    value=config.inputConfig["value"],
                )
            elif config.inputType == InputType.Option:
                caseConfig[config.label] = column.selectbox(
                    widgetLabel,
                    config.inputConfig["options"],
                    key=widgetKey,
                )
            elif config.inputType == InputType.Number:
                caseConfig[config.label] = column.number_input(
                    widgetLabel,
                    format="%d",
                    step=1,
                    min_value=config.inputConfig["min"],
                    max_value=config.inputConfig["max"],
                    key=widgetKey,
                    value=config.inputConfig["value"],
                )
            shownCount += 1
        if shownCount == 0:
            columns[1].write("Auto")
@@ -0,0 +1,47 @@
1
+ from vectordb_bench.frontend.const import *
2
+ from vectordb_bench.frontend.utils import inputIsPassword
3
+
4
+
5
def dbConfigSettings(st, activedDbList):
    """Render the DB configuration expander and build a config for each active db.

    Returns:
        dict mapping each db in ``activedDbList`` to the config object that
        ``dbConfigSettingItem`` builds from its inputs.
    """
    st.markdown(
        "<style> .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}</style>",
        unsafe_allow_html=True,
    )
    expander = st.expander("Configurations for the selected databases", True)

    return {
        activeDb: dbConfigSettingItem(expander.container(), activeDb)
        for activeDb in activedDbList
    }
19
+
20
def dbConfigSettingItem(st, activeDb):
    """Render text inputs for every field of ``activeDb``'s config class.

    Fields come from ``schema()["properties"]`` of the config class, with
    ``db_label`` moved last. Each input is prefilled with the field's schema
    ``default`` (or ""), and fields for which ``inputIsPassword`` is true use
    a password input.

    Returns:
        The config class instantiated with the entered values.
    """
    st.markdown(
        f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>",
        unsafe_allow_html=True,
    )
    columns = st.columns(DB_CONFIG_SETTING_COLUMNS)

    configCls = activeDb.init_cls.config_cls()
    fields = list(configCls.schema().get("properties").items())
    moveDBLabelToLast(fields)

    entered = {}
    for index, (fieldName, fieldSchema) in enumerate(fields):
        column = columns[index % DB_CONFIG_SETTING_COLUMNS]
        entered[fieldName] = column.text_input(
            fieldName,
            key="%s-%s" % (activeDb, fieldName),
            value=fieldSchema.get("default", ""),
            type="password" if inputIsPassword(fieldName) else "default",
        )
    return configCls(**entered)
43
+
44
+
45
def moveDBLabelToLast(propertiesItems):
    """Move any ``("db_label", ...)`` entry to the end of the list, in place.

    ``list.sort`` is stable, so all other entries keep their relative order.
    """
    propertiesItems.sort(key=lambda item: item[0] == "db_label")
47
+
@@ -0,0 +1,36 @@
1
+
2
+ from vectordb_bench.frontend.const import *
3
+
4
+
5
def dbSelector(st):
    """Render STEP 1 (database selection) and return the checked databases.

    Each db in ``DB_LIST`` gets a checkbox with its icon below it. The
    checkboxes are laid out round-robin over ``DB_SELECTOR_COLUMNS`` columns.

    Returns:
        The checked databases, in ``DB_LIST`` order.
    """
    st.markdown(
        "<div style='height: 12px;'></div>",
        unsafe_allow_html=True,
    )
    st.subheader("STEP 1: Select the database(s)")
    # The hint previously repeated the STEP 2 text ("Choose at least one
    # case ..."), which was wrong for the database step.
    st.markdown(
        "<div style='color: #647489; margin-bottom: 24px; margin-top: -12px;'>Choose at least one database you want to run the test on. </div>",
        unsafe_allow_html=True,
    )

    dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
    dbIsActived = {db: False for db in DB_LIST}

    # style - image; column gap; checkbox font;
    st.markdown(
        """
<style>
div[data-testid='stImage'] {margin: auto;}
div[data-testid='stHorizontalBlock'] {gap: 8px;}
.stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
</style>
""",
        unsafe_allow_html=True,
    )
    for i, db in enumerate(DB_LIST):
        column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
        dbIsActived[db] = column.checkbox(db.name)
        column.image(DB_TO_ICON.get(db, ""))

    return [db for db in DB_LIST if dbIsActived[db]]
@@ -0,0 +1,21 @@
1
+ from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
2
+
3
+
4
def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
    """Build one TaskConfig per (case, db) pair, iterating cases in the outer loop.

    For each pair, the db case-config class is chosen from the pair's
    ``CaseConfigParamType.IndexType`` setting (None when unset). It is then
    instantiated with the pair's settings, keyed by ``param.value``.
    """
    tasks = []
    for case in activedCaseList:
        for db in activedDbList:
            settings = allCaseConfigs[db][case]
            tasks.append(
                TaskConfig(
                    db=db.value,
                    db_config=dbConfigs[db],
                    case_config=CaseConfig(case_id=case.value, custom_case={}),
                    db_case_config=db.init_cls.case_config_cls(
                        settings.get(CaseConfigParamType.IndexType, None)
                    )(**{param.value: value for param, value in settings.items()}),
                )
            )
    return tasks
@@ -0,0 +1,10 @@
1
def hideSidebar(st):
    """Inject CSS that hides the ``collapsedControl`` element and caps ``.block-container`` at 1000px."""
    for css in (
        "<style> div[data-testid='collapsedControl'] {display: none;} </style>",
        "<style> .block-container { max-width: 1000px; } </style>",
    ):
        st.markdown(css, unsafe_allow_html=True)