vectordb-bench: vectordb_bench-0.0.11-py3-none-any.whl → vectordb_bench-0.0.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. vectordb_bench/__init__.py +1 -0
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +64 -18
  4. vectordb_bench/backend/clients/__init__.py +13 -0
  5. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  6. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  7. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  8. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  9. vectordb_bench/backend/dataset.py +27 -5
  10. vectordb_bench/cli/vectordbbench.py +2 -0
  11. vectordb_bench/custom/custom_case.json +18 -0
  12. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  13. vectordb_bench/frontend/components/check_results/data.py +12 -12
  14. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  15. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  16. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  17. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  18. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  19. vectordb_bench/frontend/components/concurrent/charts.py +26 -29
  20. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  21. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  22. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  23. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  24. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  25. vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
  26. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
  27. vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
  28. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  29. vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
  30. vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
  31. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +138 -31
  32. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  33. vectordb_bench/frontend/pages/concurrent.py +11 -18
  34. vectordb_bench/frontend/pages/custom.py +64 -0
  35. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  36. vectordb_bench/frontend/pages/run_test.py +4 -0
  37. vectordb_bench/frontend/pages/tables.py +2 -2
  38. vectordb_bench/frontend/utils.py +17 -1
  39. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  40. vectordb_bench/models.py +8 -4
  41. vectordb_bench/results/getLeaderboardData.py +1 -1
  42. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +36 -13
  43. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/RECORD +48 -37
  44. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
  45. vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  46. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
  47. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +0 -0
  48. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
8
8
  def getChartData(
9
9
  tasks: list[CaseResult],
10
10
  dbNames: list[str],
11
- cases: list[Case],
11
+ caseNames: list[str],
12
12
  ):
13
- filterTasks = getFilterTasks(tasks, dbNames, cases)
13
+ filterTasks = getFilterTasks(tasks, dbNames, caseNames)
14
14
  mergedTasks, failedTasks = mergeTasks(filterTasks)
15
15
  return mergedTasks, failedTasks
16
16
 
@@ -18,14 +18,13 @@ def getChartData(
18
18
  def getFilterTasks(
19
19
  tasks: list[CaseResult],
20
20
  dbNames: list[str],
21
- cases: list[Case],
21
+ caseNames: list[str],
22
22
  ) -> list[CaseResult]:
23
- case_ids = [case.case_id for case in cases]
24
23
  filterTasks = [
25
24
  task
26
25
  for task in tasks
27
26
  if task.task_config.db_name in dbNames
28
- and task.task_config.case_config.case_id in case_ids
27
+ and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
29
28
  ]
30
29
  return filterTasks
31
30
 
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
36
35
  db_name = task.task_config.db_name
37
36
  db = task.task_config.db.value
38
37
  db_label = task.task_config.db_config.db_label or ""
39
- case_id = task.task_config.case_config.case_id
40
- dbCaseMetricsMap[db_name][case_id] = {
38
+ case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
39
+ dbCaseMetricsMap[db_name][case.name] = {
41
40
  "db": db,
42
41
  "db_label": db_label,
43
42
  "metrics": mergeMetrics(
44
- dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
43
+ dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
45
44
  asdict(task.metrics),
46
45
  ),
47
46
  "label": getBetterLabel(
48
- dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
47
+ dbCaseMetricsMap[db_name][case.name].get(
48
+ "label", ResultLabel.FAILED),
49
49
  task.label,
50
50
  ),
51
51
  }
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
53
53
  mergedTasks = []
54
54
  failedTasks = defaultdict(lambda: defaultdict(str))
55
55
  for db_name, caseMetricsMap in dbCaseMetricsMap.items():
56
- for case_id, metricInfo in caseMetricsMap.items():
56
+ for case_name, metricInfo in caseMetricsMap.items():
57
57
  metrics = metricInfo["metrics"]
58
58
  db = metricInfo["db"]
59
59
  db_label = metricInfo["db_label"]
60
60
  label = metricInfo["label"]
61
- case_name = case_id.case_name
62
61
  if label == ResultLabel.NORMAL:
63
62
  mergedTasks.append(
64
63
  {
@@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
80
79
  metrics = {**metrics_1}
81
80
  for key, value in metrics_2.items():
82
81
  metrics[key] = (
83
- getBetterMetric(key, value, metrics[key]) if key in metrics else value
82
+ getBetterMetric(
83
+ key, value, metrics[key]) if key in metrics else value
84
84
  )
85
85
 
86
86
  return metrics
@@ -1,7 +1,7 @@
1
1
  def initMainExpanderStyle(st):
2
2
  st.markdown(
3
3
  """<style>
4
- .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
4
+ .main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
5
5
  .main div[data-testid='stExpander'] {
6
6
  background-color: #F6F8FA;
7
7
  border: 1px solid #A9BDD140;
@@ -1,8 +1,8 @@
1
1
  from vectordb_bench.backend.cases import Case
2
2
  from vectordb_bench.frontend.components.check_results.data import getChartData
3
3
  from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
4
- from vectordb_bench.frontend.const.dbCaseConfigs import CASE_LIST
5
- from vectordb_bench.frontend.const.styles import *
4
+ from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
5
+ from vectordb_bench.frontend.config.styles import *
6
6
  import streamlit as st
7
7
 
8
8
  from vectordb_bench.models import CaseResult, TestResult
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
18
18
  st.header("Filters")
19
19
 
20
20
  shownResults = getshownResults(results, st)
21
- showDBNames, showCases = getShowDbsAndCases(shownResults, st)
21
+ showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
22
22
 
23
- shownData, failedTasks = getChartData(shownResults, showDBNames, showCases)
23
+ shownData, failedTasks = getChartData(
24
+ shownResults, showDBNames, showCaseNames)
24
25
 
25
- return shownData, failedTasks, showCases
26
+ return shownData, failedTasks, showCaseNames
26
27
 
27
28
 
28
29
  def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
52
53
  return selectedResult
53
54
 
54
55
 
55
- def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Case]]:
56
+ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
56
57
  initSidebarExanderStyle(st)
57
58
  allDbNames = list(set({res.task_config.db_name for res in result}))
58
59
  allDbNames.sort()
59
- allCasesSet = set({res.task_config.case_config.case_id for res in result})
60
- allCases: list[Case] = [case.case_cls() for case in CASE_LIST if case in allCasesSet]
60
+ allCases: list[Case] = [
61
+ res.task_config.case_config.case_id.case_cls(
62
+ res.task_config.case_config.custom_case)
63
+ for res in result
64
+ ]
65
+ allCaseNameSet = set({case.name for case in allCases})
66
+ allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
67
+ [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
61
68
 
62
69
  # DB Filter
63
70
  dbFilterContainer = st.container()
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca
70
77
 
71
78
  # Case Filter
72
79
  caseFilterContainer = st.container()
73
- showCases = filterView(
80
+ showCaseNames = filterView(
74
81
  caseFilterContainer,
75
82
  "Case Filter",
76
- [case for case in allCases],
83
+ [caseName for caseName in allCaseNames],
77
84
  col=1,
78
- optionLables=[case.name for case in allCases],
79
85
  )
80
86
 
81
- return showDBNames, showCases
87
+ return showDBNames, showCaseNames
82
88
 
83
89
 
84
90
  def filterView(container, header, options, col, optionLables=None):
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
114
120
  )
115
121
  if optionLables is None:
116
122
  optionLables = options
117
- isActive = {option: st.session_state[selectAllState] for option in optionLables}
123
+ isActive = {option: st.session_state[selectAllState]
124
+ for option in optionLables}
118
125
  for i, option in enumerate(optionLables):
119
126
  isActive[option] = columns[i % col].checkbox(
120
127
  optionLables[i],
@@ -1,4 +1,4 @@
1
- from vectordb_bench.frontend.const.styles import HEADER_ICON
1
+ from vectordb_bench.frontend.config.styles import HEADER_ICON
2
2
 
3
3
 
4
4
  def drawHeaderIcon(st):
@@ -3,7 +3,7 @@ import pandas as pd
3
3
  from collections import defaultdict
4
4
  import streamlit as st
5
5
 
6
- from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
6
+ from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
7
7
 
8
8
 
9
9
  def priceTable(container, data):
@@ -1,4 +1,4 @@
1
- from vectordb_bench.frontend.const.styles import *
1
+ from vectordb_bench.frontend.config.styles import *
2
2
 
3
3
 
4
4
  def initResultsPageConfig(st):
@@ -1,26 +1,27 @@
1
-
2
-
3
- from vectordb_bench.backend.cases import Case
4
- from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
1
+ from vectordb_bench.frontend.components.check_results.expanderStyle import (
2
+ initMainExpanderStyle,
3
+ )
5
4
  import plotly.express as px
6
5
 
7
- from vectordb_bench.frontend.const.styles import COLOR_MAP
6
+ from vectordb_bench.frontend.config.styles import COLOR_MAP
8
7
 
9
8
 
10
- def drawChartsByCase(allData, cases: list[Case], st):
9
+ def drawChartsByCase(allData, showCaseNames: list[str], st):
11
10
  initMainExpanderStyle(st)
12
- for case in cases:
13
- chartContainer = st.expander(case.name, True)
14
- caseDataList = [
15
- data for data in allData if data["case_name"] == case.name]
16
- data = [{
17
- "conc_num": caseData["conc_num_list"][i],
18
- "qps": caseData["conc_qps_list"][i],
19
- "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
20
- "db_name": caseData["db_name"],
21
- "db": caseData["db"]
22
-
23
- } for caseData in caseDataList for i in range(len(caseData["conc_num_list"]))]
11
+ for caseName in showCaseNames:
12
+ chartContainer = st.expander(caseName, True)
13
+ caseDataList = [data for data in allData if data["case_name"] == caseName]
14
+ data = [
15
+ {
16
+ "conc_num": caseData["conc_num_list"][i],
17
+ "qps": caseData["conc_qps_list"][i],
18
+ "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
19
+ "db_name": caseData["db_name"],
20
+ "db": caseData["db"],
21
+ }
22
+ for caseData in caseDataList
23
+ for i in range(len(caseData["conc_num_list"]))
24
+ ]
24
25
  drawChart(data, chartContainer)
25
26
 
26
27
 
@@ -38,7 +39,7 @@ def getRange(metric, data, padding_multipliers):
38
39
  def drawChart(data, st):
39
40
  if len(data) == 0:
40
41
  return
41
-
42
+
42
43
  x = "latency_p99"
43
44
  xrange = getRange(x, data, [0.05, 0.1])
44
45
 
@@ -63,7 +64,6 @@ def drawChart(data, st):
63
64
  line_group=line_group,
64
65
  text=text,
65
66
  markers=True,
66
- # color_discrete_map=color_discrete_map,
67
67
  hover_data={
68
68
  "conc_num": True,
69
69
  },
@@ -71,12 +71,9 @@ def drawChart(data, st):
71
71
  )
72
72
  fig.update_xaxes(range=xrange, title_text="Latency P99 (ms)")
73
73
  fig.update_yaxes(range=yrange, title_text="QPS")
74
- fig.update_traces(textposition="bottom right",
75
- texttemplate="conc-%{text:,.4~r}")
76
- # fig.update_layout(
77
- # margin=dict(l=0, r=0, t=40, b=0, pad=8),
78
- # legend=dict(
79
- # orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
80
- # ),
81
- # )
82
- st.plotly_chart(fig, use_container_width=True,)
74
+ fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}")
75
+
76
+ st.plotly_chart(
77
+ fig,
78
+ use_container_width=True,
79
+ )
@@ -0,0 +1,31 @@
1
+
2
+ from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig
3
+
4
+
5
+ def displayCustomCase(customCase: CustomCaseConfig, st, key):
6
+
7
+ columns = st.columns([1, 2])
8
+ customCase.dataset_config.name = columns[0].text_input(
9
+ "Name", key=f"{key}_name", value=customCase.dataset_config.name)
10
+ customCase.name = f"{customCase.dataset_config.name} (Performace Case)"
11
+ customCase.dataset_config.dir = columns[1].text_input(
12
+ "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir)
13
+
14
+ columns = st.columns(4)
15
+ customCase.dataset_config.dim = columns[0].number_input(
16
+ "dim", key=f"{key}_dim", value=customCase.dataset_config.dim)
17
+ customCase.dataset_config.size = columns[1].number_input(
18
+ "size", key=f"{key}_size", value=customCase.dataset_config.size)
19
+ customCase.dataset_config.metric_type = columns[2].selectbox(
20
+ "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"])
21
+ customCase.dataset_config.file_count = columns[3].number_input(
22
+ "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count)
23
+
24
+ columns = st.columns(4)
25
+ customCase.dataset_config.use_shuffled = columns[0].checkbox(
26
+ "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled)
27
+ customCase.dataset_config.with_gt = columns[1].checkbox(
28
+ "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt)
29
+
30
+ customCase.description = st.text_area(
31
+ "description", key=f"{key}_description", value=customCase.description)
@@ -0,0 +1,11 @@
1
+ def displayParams(st):
2
+ st.markdown("""
3
+ - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
4
+ - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
5
+ - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
6
+ - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.
7
+
8
+ - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
9
+
10
+ - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
11
+ """)
@@ -0,0 +1,40 @@
1
+ import json
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from vectordb_bench import config
6
+
7
+
8
+ class CustomDatasetConfig(BaseModel):
9
+ name: str = "custom_dataset"
10
+ dir: str = ""
11
+ size: int = 0
12
+ dim: int = 0
13
+ metric_type: str = "L2"
14
+ file_count: int = 1
15
+ use_shuffled: bool = False
16
+ with_gt: bool = True
17
+
18
+
19
+ class CustomCaseConfig(BaseModel):
20
+ name: str = "custom_dataset (Performace Case)"
21
+ description: str = ""
22
+ load_timeout: int = 36000
23
+ optimize_timeout: int = 36000
24
+ dataset_config: CustomDatasetConfig = CustomDatasetConfig()
25
+
26
+
27
+ def get_custom_configs():
28
+ with open(config.CUSTOM_CONFIG_DIR, "r") as f:
29
+ custom_configs = json.load(f)
30
+ return [CustomCaseConfig(**custom_config) for custom_config in custom_configs]
31
+
32
+
33
+ def save_custom_configs(custom_configs: list[CustomDatasetConfig]):
34
+ with open(config.CUSTOM_CONFIG_DIR, "w") as f:
35
+ json.dump([custom_config.dict()
36
+ for custom_config in custom_configs], f, indent=4)
37
+
38
+
39
+ def generate_custom_case():
40
+ return CustomCaseConfig()
@@ -0,0 +1,15 @@
1
+ def initStyle(st):
2
+ st.markdown(
3
+ """<style>
4
+ /* expander - header */
5
+ .main div[data-testid='stExpander'] summary p {font-size: 20px; font-weight: 600;}
6
+ /*
7
+ button {
8
+ height: auto;
9
+ padding-left: 8px !important;
10
+ padding-right: 6px !important;
11
+ }
12
+ */
13
+ </style>""",
14
+ unsafe_allow_html=True,
15
+ )
@@ -1,5 +1,5 @@
1
1
  from streamlit_autorefresh import st_autorefresh
2
- from vectordb_bench.frontend.const.styles import *
2
+ from vectordb_bench.frontend.config.styles import *
3
3
 
4
4
 
5
5
  def autoRefresh():
@@ -1,9 +1,13 @@
1
- from vectordb_bench.frontend.const.styles import *
1
+
2
+ from vectordb_bench.frontend.config.styles import *
2
3
  from vectordb_bench.backend.cases import CaseType
3
- from vectordb_bench.frontend.const.dbCaseConfigs import *
4
+ from vectordb_bench.frontend.config.dbCaseConfigs import *
5
+ from collections import defaultdict
6
+
7
+ from vectordb_bench.frontend.utils import addHorizontalLine
4
8
 
5
9
 
6
- def caseSelector(st, activedDbList):
10
+ def caseSelector(st, activedDbList: list[DB]):
7
11
  st.markdown(
8
12
  "<div style='height: 24px;'></div>",
9
13
  unsafe_allow_html=True,
@@ -14,41 +18,49 @@ def caseSelector(st, activedDbList):
14
18
  unsafe_allow_html=True,
15
19
  )
16
20
 
17
- caseIsActived = {case: False for case in CASE_LIST}
18
- allCaseConfigs = {db: {case: {} for case in CASE_LIST} for db in DB_LIST}
19
- for caseOrDivider in CASE_LIST_WITH_DIVIDER:
20
- if caseOrDivider == DIVIDER:
21
- caseItemContainer.markdown(
22
- "<div style='border: 1px solid #cccccc60; margin-bottom: 24px;'></div>",
23
- unsafe_allow_html=True,
24
- )
21
+ activedCaseList: list[CaseConfig] = []
22
+ dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict))
23
+ dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
24
+ caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
25
+ for caseCluster in caseClusters:
26
+ activedCaseList += caseClusterExpander(
27
+ st, caseCluster, dbToCaseClusterConfigs, activedDbList)
28
+ for db in dbToCaseClusterConfigs:
29
+ for uiCaseItem in dbToCaseClusterConfigs[db]:
30
+ for case in uiCaseItem.cases:
31
+ dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
32
+
33
+ return activedCaseList, dbToCaseConfigs
34
+
35
+
36
+ def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfigs, activedDbList: list[DB]):
37
+ expander = st.expander(caseCluster.label, False)
38
+ activedCases: list[CaseConfig] = []
39
+ for uiCaseItem in caseCluster.uiCaseItems:
40
+ if uiCaseItem.isLine:
41
+ addHorizontalLine(expander)
25
42
  else:
26
- case = caseOrDivider
27
- caseItemContainer = st.container()
28
- caseIsActived[case] = caseItem(
29
- caseItemContainer, allCaseConfigs, case, activedDbList
30
- )
31
- activedCaseList = [case for case in CASE_LIST if caseIsActived[case]]
32
- return activedCaseList, allCaseConfigs
43
+ activedCases += caseItemCheckbox(expander,
44
+ dbToCaseClusterConfigs, uiCaseItem, activedDbList)
45
+ return activedCases
33
46
 
34
47
 
35
- def caseItem(st, allCaseConfigs, case: CaseType, activedDbList):
36
- selected = st.checkbox(case.case_name)
48
+ def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
49
+ selected = st.checkbox(uiCaseItem.label)
37
50
  st.markdown(
38
- f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{case.case_description}</div>",
51
+ f"<div style='color: #1D2939; margin: -8px 0 20px {CHECKBOX_INDENT}px; font-size: 14px;'>{uiCaseItem.description}</div>",
39
52
  unsafe_allow_html=True,
40
53
  )
41
54
 
42
55
  if selected:
43
- caseConfigSettingContainer = st.container()
44
56
  caseConfigSetting(
45
- caseConfigSettingContainer, allCaseConfigs, case, activedDbList
57
+ st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList
46
58
  )
47
59
 
48
- return selected
60
+ return uiCaseItem.cases if selected else []
49
61
 
50
62
 
51
- def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
63
+ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
52
64
  for db in activedDbList:
53
65
  columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
54
66
  # column 0 - title
@@ -57,12 +69,12 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
57
69
  f"<div style='margin: 0 0 24px {CHECKBOX_INDENT}px; font-size: 18px; font-weight: 600;'>{db.name}</div>",
58
70
  unsafe_allow_html=True,
59
71
  )
60
- caseConfig = allCaseConfigs[db][case]
61
72
  k = 0
62
- for config in CASE_CONFIG_MAP.get(db, {}).get(case.case_cls().label, []):
73
+ caseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
74
+ for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []):
63
75
  if config.isDisplayed(caseConfig):
64
76
  column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS]
65
- key = "%s-%s-%s" % (db, case, config.label.value)
77
+ key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
66
78
  if config.inputType == InputType.Text:
67
79
  caseConfig[config.label] = column.text_input(
68
80
  config.displayLabel if config.displayLabel else config.label.value,
@@ -1,13 +1,9 @@
1
1
  from pydantic import ValidationError
2
- from vectordb_bench.frontend.const.styles import *
2
+ from vectordb_bench.frontend.config.styles import *
3
3
  from vectordb_bench.frontend.utils import inputIsPassword
4
4
 
5
5
 
6
6
  def dbConfigSettings(st, activedDbList):
7
- st.markdown(
8
- "<style> .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}</style>",
9
- unsafe_allow_html=True,
10
- )
11
7
  expander = st.expander("Configurations for the selected databases", True)
12
8
 
13
9
  dbConfigs = {}
@@ -1,7 +1,6 @@
1
1
  from streamlit.runtime.media_file_storage import MediaFileStorageError
2
-
3
- from vectordb_bench.frontend.const.styles import *
4
- from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST
2
+ from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON
3
+ from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
5
4
 
6
5
 
7
6
  def dbSelector(st):
@@ -18,17 +17,6 @@ def dbSelector(st):
18
17
  dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
19
18
  dbIsActived = {db: False for db in DB_LIST}
20
19
 
21
- # style - image; column gap; checkbox font;
22
- st.markdown(
23
- """
24
- <style>
25
- div[data-testid='stImage'] {margin: auto;}
26
- div[data-testid='stHorizontalBlock'] {gap: 8px;}
27
- .stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
28
- </style>
29
- """,
30
- unsafe_allow_html=True,
31
- )
32
20
  for i, db in enumerate(DB_LIST):
33
21
  column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
34
22
  dbIsActived[db] = column.checkbox(db.name)
@@ -1,17 +1,15 @@
1
+ from vectordb_bench.backend.clients import DB
1
2
  from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
2
3
 
3
4
 
4
- def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
5
+ def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs):
5
6
  tasks = []
6
7
  for db in activedDbList:
7
8
  for case in activedCaseList:
8
9
  task = TaskConfig(
9
10
  db=db.value,
10
11
  db_config=dbConfigs[db],
11
- case_config=CaseConfig(
12
- case_id=case.value,
13
- custom_case={},
14
- ),
12
+ case_config=case,
15
13
  db_case_config=db.case_config_cls(
16
14
  allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None)
17
15
  )(**{key.value: value for key, value in allCaseConfigs[db][case].items()}),
@@ -0,0 +1,14 @@
1
+ def initStyle(st):
2
+ st.markdown(
3
+ """<style>
4
+ /* expander - header */
5
+ .main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
6
+ /* db icon */
7
+ div[data-testid='stImage'] {margin: auto;}
8
+ /* db column gap */
9
+ div[data-testid='stHorizontalBlock'] {gap: 8px;}
10
+ /* check box */
11
+ .stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
12
+ </style>""",
13
+ unsafe_allow_html=True,
14
+ )
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from vectordb_bench.frontend.const.styles import *
2
+ from vectordb_bench.frontend.config.styles import *
3
3
  from vectordb_bench.interface import benchMarkRunner
4
4
 
5
5