vectordb-bench 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. vectordb_bench/__init__.py +14 -3
  2. vectordb_bench/backend/assembler.py +2 -2
  3. vectordb_bench/backend/cases.py +146 -57
  4. vectordb_bench/backend/clients/__init__.py +6 -1
  5. vectordb_bench/backend/clients/api.py +23 -11
  6. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  7. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +11 -9
  8. vectordb_bench/backend/clients/milvus/config.py +2 -3
  9. vectordb_bench/backend/clients/milvus/milvus.py +32 -19
  10. vectordb_bench/backend/clients/pgvector/config.py +49 -0
  11. vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +3 -3
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +19 -13
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +23 -6
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +12 -13
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +3 -3
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +9 -8
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +5 -4
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
  20. vectordb_bench/backend/dataset.py +100 -162
  21. vectordb_bench/backend/result_collector.py +2 -2
  22. vectordb_bench/backend/runner/mp_runner.py +29 -13
  23. vectordb_bench/backend/runner/serial_runner.py +98 -36
  24. vectordb_bench/backend/task_runner.py +43 -48
  25. vectordb_bench/frontend/components/check_results/charts.py +10 -21
  26. vectordb_bench/frontend/components/check_results/data.py +31 -15
  27. vectordb_bench/frontend/components/check_results/expanderStyle.py +37 -0
  28. vectordb_bench/frontend/components/check_results/filters.py +61 -33
  29. vectordb_bench/frontend/components/check_results/footer.py +8 -0
  30. vectordb_bench/frontend/components/check_results/headerIcon.py +8 -4
  31. vectordb_bench/frontend/components/check_results/nav.py +7 -6
  32. vectordb_bench/frontend/components/check_results/priceTable.py +3 -2
  33. vectordb_bench/frontend/components/check_results/stPageConfig.py +18 -0
  34. vectordb_bench/frontend/components/get_results/saveAsImage.py +50 -0
  35. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  36. vectordb_bench/frontend/components/run_test/caseSelector.py +19 -16
  37. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +20 -7
  38. vectordb_bench/frontend/components/run_test/dbSelector.py +5 -5
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +4 -6
  40. vectordb_bench/frontend/components/run_test/submitTask.py +16 -10
  41. vectordb_bench/frontend/const/dbCaseConfigs.py +291 -0
  42. vectordb_bench/frontend/const/dbPrices.py +6 -0
  43. vectordb_bench/frontend/const/styles.py +58 -0
  44. vectordb_bench/frontend/pages/{qps_with_price.py → quries_per_dollar.py} +24 -17
  45. vectordb_bench/frontend/pages/run_test.py +17 -11
  46. vectordb_bench/frontend/vdb_benchmark.py +19 -12
  47. vectordb_bench/metric.py +19 -10
  48. vectordb_bench/models.py +14 -40
  49. vectordb_bench/results/dbPrices.json +32 -0
  50. vectordb_bench/results/getLeaderboardData.py +52 -0
  51. vectordb_bench/results/leaderboard.json +1 -0
  52. vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +1910 -897
  53. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +107 -27
  54. vectordb_bench-0.0.3.dist-info/RECORD +67 -0
  55. vectordb_bench/frontend/const.py +0 -391
  56. vectordb_bench-0.0.1.dist-info/RECORD +0 -56
  57. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
  58. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
  59. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
  60. {vectordb_bench-0.0.1.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import logging
2
+ import psutil
2
3
  import traceback
3
4
  import concurrent
4
5
  import numpy as np
@@ -7,7 +8,7 @@ from enum import Enum, auto
7
8
  from . import utils
8
9
  from .cases import Case, CaseLabel
9
10
  from ..base import BaseModel
10
- from ..models import TaskConfig
11
+ from ..models import TaskConfig, PerformanceTimeoutError
11
12
 
12
13
  from .clients import (
13
14
  api,
@@ -92,80 +93,70 @@ class CaseRunner(BaseModel):
92
93
  self._pre_run(drop_old)
93
94
 
94
95
  if self.ca.label == CaseLabel.Load:
95
- return self._run_load_case()
96
+ return self._run_capacity_case()
96
97
  elif self.ca.label == CaseLabel.Performance:
97
98
  return self._run_perf_case(drop_old)
98
99
  else:
99
- log.warning(f"unknown case type: {self.ca.label}")
100
- raise ValueError(f"Unknown case type: {self.ca.label}")
100
+ msg = f"unknown case type: {self.ca.label}"
101
+ log.warning(msg)
102
+ raise ValueError(msg)
101
103
 
102
-
103
- def _run_load_case(self) -> Metric:
104
- """ run load cases
104
+ def _run_capacity_case(self) -> Metric:
105
+ """ run capacity cases
105
106
 
106
107
  Returns:
107
108
  Metric: the max load count
108
109
  """
109
110
  log.info("Start capacity case")
110
- # datasets for load tests are quite small, can fit into memory
111
- # only 1 file
112
- data_df = [data_df for data_df in self.ca.dataset][0]
113
-
114
- all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist()
115
- runner = SerialInsertRunner(self.db, all_embeddings, all_metadata)
116
111
  try:
112
+ runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout)
117
113
  count = runner.run_endlessness()
118
- log.info(f"load reach limit: insertion counts={count}")
119
- return Metric(max_load_count=count)
120
114
  except Exception as e:
121
- log.warning(f"run capacity case error: {e}")
115
+ log.warning(f"Failed to run capacity case, reason = {e}")
122
116
  raise e from None
123
- log.info("End capacity case")
124
-
117
+ else:
118
+ log.info(f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}")
119
+ return Metric(max_load_count=count)
125
120
 
126
121
  def _run_perf_case(self, drop_old: bool = True) -> Metric:
122
+ """ run performance cases
123
+
124
+ Returns:
125
+ Metric: load_duration, recall, serial_latency_p99, and, qps
126
+ """
127
127
  try:
128
128
  m = Metric()
129
129
  if drop_old:
130
130
  _, load_dur = self._load_train_data()
131
131
  build_dur = self._optimize()
132
132
  m.load_duration = round(load_dur+build_dur, 4)
133
+ log.info(
134
+ f"Finish loading the entire dataset into VectorDB,"
135
+ f" insert_duration={load_dur}, optimize_duration={build_dur}"
136
+ f" load_duration(insert + optimize) = {m.load_duration}"
137
+ )
133
138
 
134
139
  self._init_search_runner()
135
140
  m.recall, m.serial_latency_p99 = self._serial_search()
136
141
  m.qps = self._conc_search()
137
-
138
- log.info(f"got results: {m}")
139
- return m
140
142
  except Exception as e:
141
- log.warning(f"performance case run error: {e}")
143
+ log.warning(f"Failed to run performance case, reason = {e}")
142
144
  traceback.print_exc()
143
- raise e
145
+ raise e from None
146
+ else:
147
+ log.info(f"Performance case got result: {m}")
148
+ return m
144
149
 
145
150
  @utils.time_it
146
151
  def _load_train_data(self):
147
152
  """Insert train data and get the insert_duration"""
148
- for data_df in self.ca.dataset:
149
- try:
150
- all_metadata = data_df['id'].tolist()
151
-
152
- emb_np = np.stack(data_df['emb'])
153
- if self.normalize:
154
- log.debug("normalize the 100k train data")
155
- all_embeddings = emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis].tolist()
156
- else:
157
- all_embeddings = emb_np.tolist()
158
-
159
- del(emb_np)
160
- log.debug(f"normalized size: {len(all_embeddings)}, {len(all_metadata)}")
161
-
162
- runner = SerialInsertRunner(self.db, all_embeddings, all_metadata)
163
- runner.run()
164
- except Exception as e:
165
- raise e from None
166
- finally:
167
- runner = None
168
-
153
+ try:
154
+ runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout)
155
+ runner.run()
156
+ except Exception as e:
157
+ raise e from None
158
+ finally:
159
+ runner = None
169
160
 
170
161
  def _serial_search(self) -> tuple[float, float]:
171
162
  """Performance serial tests, search the entire test data once,
@@ -198,17 +189,21 @@ class CaseRunner(BaseModel):
198
189
 
199
190
  @utils.time_it
200
191
  def _task(self) -> None:
201
- """"""
202
192
  with self.db.init():
203
- self.db.ready_to_search()
193
+ self.db.optimize()
204
194
 
205
195
  def _optimize(self) -> float:
206
196
  with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
207
197
  future = executor.submit(self._task)
208
198
  try:
209
- return future.result()[1]
199
+ return future.result(timeout=self.ca.optimize_timeout)[1]
200
+ except TimeoutError as e:
201
+ log.warning(f"VectorDB optimize timeout in {self.ca.optimize_timeout}")
202
+ for pid, _ in executor._processes.items():
203
+ psutil.Process(pid).kill()
204
+ raise PerformanceTimeoutError("Performance case optimize timeout") from e
210
205
  except Exception as e:
211
- log.warning(f"VectorDB ready_to_search error: {e}")
206
+ log.warning(f"VectorDB optimize error: {e}")
212
207
  raise e from None
213
208
 
214
209
  def _init_search_runner(self):
@@ -1,30 +1,19 @@
1
+ from vectordb_bench.backend.cases import Case
2
+ from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
1
3
  from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
2
- from vectordb_bench.frontend.const import *
4
+ from vectordb_bench.frontend.const.styles import *
3
5
  from vectordb_bench.models import ResultLabel
4
6
  import plotly.express as px
5
7
 
6
8
 
7
- def drawCharts(st, allData, failedTasks, cases):
8
- st.markdown(
9
- "<style> .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;} </style>",
10
- unsafe_allow_html=True,
11
- )
12
- st.markdown(
13
- """<style>
14
- .main div[data-testid='stExpander'] {
15
- background-color: #F6F8FA;
16
- border: 1px solid #A9BDD140;
17
- border-radius: 8px;
18
- }
19
- </style>""",
20
- unsafe_allow_html=True,
21
- )
9
+ def drawCharts(st, allData, failedTasks, cases: list[Case]):
10
+ initMainExpanderStyle(st)
22
11
  for case in cases:
23
- chartContainer = st.expander(case, True)
24
- data = [data for data in allData if data["case"] == case]
12
+ chartContainer = st.expander(case.name, True)
13
+ data = [data for data in allData if data["case_name"] == case.name]
25
14
  drawChart(data, chartContainer)
26
15
 
27
- errorDBs = failedTasks[case]
16
+ errorDBs = failedTasks[case.name]
28
17
  showFailedDBs(chartContainer, errorDBs)
29
18
 
30
19
 
@@ -102,7 +91,7 @@ def drawMetricChart(data, metric, st):
102
91
  xmin = 0
103
92
  xmax = max([d.get(metric, 0) for d in dataWithMetric])
104
93
  xpadding = (xmax - xmin) / 16
105
- xpadding_multiplier = 1.6
94
+ xpadding_multiplier = 1.8
106
95
  xrange = [xmin, xmax + xpadding * xpadding_multiplier]
107
96
  unit = metricUnitMap.get(metric, "")
108
97
  labelToShapeMap = getLabelToShapeMap(dataWithMetric)
@@ -136,7 +125,7 @@ def drawMetricChart(data, metric, st):
136
125
  font=dict(
137
126
  size=1,
138
127
  ),
139
- # text="",
128
+ text="",
140
129
  )
141
130
  )
142
131
  fig.update_traces(
@@ -1,63 +1,78 @@
1
1
  from collections import defaultdict
2
2
  from dataclasses import asdict
3
+ from vectordb_bench.backend.cases import Case
3
4
  from vectordb_bench.metric import isLowerIsBetterMetric
4
- from vectordb_bench.models import ResultLabel
5
+ from vectordb_bench.models import CaseResult, ResultLabel
5
6
 
6
7
 
7
- def getChartData(tasks, dbNames, cases):
8
+ def getChartData(
9
+ tasks: list[CaseResult],
10
+ dbNames: list[str],
11
+ cases: list[Case],
12
+ ):
8
13
  filterTasks = getFilterTasks(tasks, dbNames, cases)
9
14
  mergedTasks, failedTasks = mergeTasks(filterTasks)
10
15
  return mergedTasks, failedTasks
11
16
 
12
17
 
13
- def getFilterTasks(tasks, dbNames, cases):
18
+ def getFilterTasks(
19
+ tasks: list[CaseResult],
20
+ dbNames: list[str],
21
+ cases: list[Case],
22
+ ) -> list[CaseResult]:
23
+ case_ids = [case.case_id for case in cases]
14
24
  filterTasks = [
15
25
  task
16
26
  for task in tasks
17
27
  if task.task_config.db_name in dbNames
18
- and task.task_config.case_config.case_id.value in cases
28
+ and task.task_config.case_config.case_id in case_ids
19
29
  ]
20
30
  return filterTasks
21
31
 
22
32
 
23
- def mergeTasks(tasks):
33
+ def mergeTasks(tasks: list[CaseResult]):
24
34
  dbCaseMetricsMap = defaultdict(lambda: defaultdict(dict))
25
35
  for task in tasks:
26
36
  db_name = task.task_config.db_name
27
37
  db = task.task_config.db.value
28
38
  db_label = task.task_config.db_config.db_label or ""
29
- case = task.task_config.case_config.case_id.value
30
- dbCaseMetricsMap[db_name][case] = {
39
+ case_id = task.task_config.case_config.case_id
40
+ dbCaseMetricsMap[db_name][case_id] = {
31
41
  "db": db,
32
42
  "db_label": db_label,
33
43
  "metrics": mergeMetrics(
34
- dbCaseMetricsMap[db_name][case].get("metrics", {}), asdict(task.metrics)
44
+ dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
45
+ asdict(task.metrics),
46
+ ),
47
+ "label": getBetterLabel(
48
+ dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
49
+ task.label,
35
50
  ),
36
- "label": getBetterLabel(dbCaseMetricsMap[db_name][case].get("label", ResultLabel.FAILED), task.label)
37
51
  }
38
52
 
39
53
  mergedTasks = []
40
54
  failedTasks = defaultdict(lambda: defaultdict(str))
41
55
  for db_name, caseMetricsMap in dbCaseMetricsMap.items():
42
- for case, metricInfo in caseMetricsMap.items():
56
+ for case_id, metricInfo in caseMetricsMap.items():
43
57
  metrics = metricInfo["metrics"]
44
58
  db = metricInfo["db"]
45
59
  db_label = metricInfo["db_label"]
46
60
  label = metricInfo["label"]
61
+ case_name = case_id.case_name
47
62
  if label == ResultLabel.NORMAL:
48
63
  mergedTasks.append(
49
64
  {
50
65
  "db_name": db_name,
51
66
  "db": db,
52
67
  "db_label": db_label,
53
- "case": case,
68
+ "case_name": case_name,
54
69
  "metricsSet": set(metrics.keys()),
55
70
  **metrics,
56
71
  }
57
72
  )
58
- else:
59
- failedTasks[case][db_name] = label
60
-
73
+ else:
74
+ failedTasks[case_name][db_name] = label
75
+
61
76
  return mergedTasks, failedTasks
62
77
 
63
78
 
@@ -81,6 +96,7 @@ def getBetterMetric(metric, value_1, value_2):
81
96
  if isLowerIsBetterMetric(metric)
82
97
  else max(value_1, value_2)
83
98
  )
84
-
99
+
100
+
85
101
  def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel):
86
102
  return label_2 if label_1 != ResultLabel.NORMAL else label_1
@@ -0,0 +1,37 @@
1
+ def initMainExpanderStyle(st):
2
+ st.markdown(
3
+ """<style>
4
+ .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
5
+ .main div[data-testid='stExpander'] {
6
+ background-color: #F6F8FA;
7
+ border: 1px solid #A9BDD140;
8
+ border-radius: 8px;
9
+ }
10
+ </style>""",
11
+ unsafe_allow_html=True,
12
+ )
13
+
14
+
15
+ def initSidebarExanderStyle(st):
16
+ st.markdown(
17
+ """<style>
18
+ section[data-testid='stSidebar']
19
+ div[data-testid='stExpander']
20
+ div[data-testid='stVerticalBlock']
21
+ { gap: 0.2rem; }
22
+ div[data-testid='stExpander']
23
+ { background-color: #ffffff; }
24
+ section[data-testid='stSidebar']
25
+ .streamlit-expanderHeader
26
+ p { font-size: 16px; font-weight: 600; }
27
+ section[data-testid='stSidebar']
28
+ div[data-testid='stExpander']
29
+ div[data-testid='stVerticalBlock']
30
+ button {
31
+ padding: 0 0.5rem;
32
+ margin-bottom: 8px;
33
+ float: right;
34
+ }
35
+ <style>""",
36
+ unsafe_allow_html=True,
37
+ )
@@ -1,8 +1,14 @@
1
+ from vectordb_bench.backend.cases import Case
1
2
  from vectordb_bench.frontend.components.check_results.data import getChartData
2
- from vectordb_bench.frontend.const import *
3
+ from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
4
+ from vectordb_bench.frontend.const.dbCaseConfigs import CASE_LIST
5
+ from vectordb_bench.frontend.const.styles import *
6
+ import streamlit as st
3
7
 
8
+ from vectordb_bench.models import CaseResult, TestResult
4
9
 
5
- def getshownData(results, st):
10
+
11
+ def getshownData(results: list[TestResult], st):
6
12
  # hide the nav
7
13
  st.markdown(
8
14
  "<style> div[data-testid='stSidebarNav'] {display: none;} </style>",
@@ -19,7 +25,7 @@ def getshownData(results, st):
19
25
  return shownData, failedTasks, showCases
20
26
 
21
27
 
22
- def getshownResults(results, st):
28
+ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
23
29
  resultSelectOptions = [
24
30
  result.task_label
25
31
  if result.task_label != result.run_id
@@ -38,7 +44,7 @@ def getshownResults(results, st):
38
44
  # label_visibility="hidden",
39
45
  default=resultSelectOptions,
40
46
  )
41
- selectedResult = []
47
+ selectedResult: list[CaseResult] = []
42
48
  for option in selectedResultSelectedOptions:
43
49
  result = results[resultSelectOptions.index(option)].results
44
50
  selectedResult += result
@@ -46,52 +52,74 @@ def getshownResults(results, st):
46
52
  return selectedResult
47
53
 
48
54
 
49
- def getShowDbsAndCases(result, st):
50
- # expanderStyles
51
- st.markdown("<style> section[data-testid='stSidebar'] div[data-testid='stExpander'] div[data-testid='stVerticalBlock'] { gap: 0.2rem; } </style>", unsafe_allow_html=True,)
52
- st.markdown(
53
- "<style> div[data-testid='stExpander'] {background-color: #ffffff;} </style>",
54
- unsafe_allow_html=True,
55
- )
56
- st.markdown(
57
- "<style> section[data-testid='stSidebar'] .streamlit-expanderHeader p {font-size: 16px; font-weight: 600;} </style>",
58
- unsafe_allow_html=True,
59
- )
60
-
55
+ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Case]]:
56
+ initSidebarExanderStyle(st)
61
57
  allDbNames = list(set({res.task_config.db_name for res in result}))
62
58
  allDbNames.sort()
63
59
  allCasesSet = set({res.task_config.case_config.case_id for res in result})
64
- allCases = [case["name"].value for case in CASE_LIST if case["name"] in allCasesSet]
65
-
66
- # dbFilterContainer = st.container()
67
- # dbFilterContainer.subheader("DB Filter")
68
- dbFilterContainer = st.expander("DB Filter", True)
69
- showDBNames = filterView(allDbNames, dbFilterContainer, col=1)
60
+ allCases: list[Case] = [case.case_cls() for case in CASE_LIST if case in allCasesSet]
61
+
62
+ # DB Filter
63
+ dbFilterContainer = st.container()
64
+ showDBNames = filterView(
65
+ dbFilterContainer,
66
+ "DB Filter",
67
+ allDbNames,
68
+ col=1,
69
+ )
70
70
 
71
- # caseFilterContainer = st.container()
72
- # caseFilterContainer.subheader("Case Filter")
73
- caseFilterContainer = st.expander("Case Filter", True)
71
+ # Case Filter
72
+ caseFilterContainer = st.container()
74
73
  showCases = filterView(
75
- allCases,
76
74
  caseFilterContainer,
75
+ "Case Filter",
76
+ [case for case in allCases],
77
77
  col=1,
78
- optionLables=[case for case in allCases],
78
+ optionLables=[case.name for case in allCases],
79
79
  )
80
80
 
81
81
  return showDBNames, showCases
82
82
 
83
83
 
84
- def filterView(options, st, col, optionLables=None):
85
- columns = st.columns(
84
+ def filterView(container, header, options, col, optionLables=None):
85
+ selectAllState = f"{header}-select-all-state"
86
+ if selectAllState not in st.session_state:
87
+ st.session_state[selectAllState] = True
88
+
89
+ countKeyState = f"{header}-select-all-count-key"
90
+ if countKeyState not in st.session_state:
91
+ st.session_state[countKeyState] = 0
92
+
93
+ expander = container.expander(header, True)
94
+ selectAllColumns = expander.columns(SIDEBAR_CONTROL_COLUMNS, gap="small")
95
+ selectAllButton = selectAllColumns[SIDEBAR_CONTROL_COLUMNS - 2].button(
96
+ "select all",
97
+ key=f"{header}-select-all-button",
98
+ # type="primary",
99
+ )
100
+ clearAllButton = selectAllColumns[SIDEBAR_CONTROL_COLUMNS - 1].button(
101
+ "clear all",
102
+ key=f"{header}-clear-all-button",
103
+ # type="primary",
104
+ )
105
+ if selectAllButton:
106
+ st.session_state[selectAllState] = True
107
+ st.session_state[countKeyState] += 1
108
+ if clearAllButton:
109
+ st.session_state[selectAllState] = False
110
+ st.session_state[countKeyState] += 1
111
+ columns = expander.columns(
86
112
  col,
87
113
  gap="small",
88
114
  )
89
- isActive = {option: True for option in options}
90
115
  if optionLables is None:
91
116
  optionLables = options
92
- for i, option in enumerate(options):
117
+ isActive = {option: st.session_state[selectAllState] for option in optionLables}
118
+ for i, option in enumerate(optionLables):
93
119
  isActive[option] = columns[i % col].checkbox(
94
- optionLables[i], value=isActive[option]
120
+ optionLables[i],
121
+ value=isActive[option],
122
+ key=f"{optionLables[i]}-{st.session_state[countKeyState]}",
95
123
  )
96
124
 
97
- return [option for option in options if isActive[option]]
125
+ return [options[i] for i, option in enumerate(optionLables) if isActive[option]]
@@ -0,0 +1,8 @@
1
+ def footer(st):
2
+ text = "* All test results are from community contributors. If there is any ambiguity, feel free to raise an issue or make amendments on our <a href='https://github.com/zilliztech/VectorDBBench'>GitHub page</a>."
3
+ st.markdown(
4
+ f"""
5
+ <div style="margin-top: 16px; color: #aaa; font-size: 14px;">{text}</div
6
+ """,
7
+ unsafe_allow_html=True,
8
+ )
@@ -1,17 +1,21 @@
1
+ from vectordb_bench.frontend.const.styles import HEADER_ICON
2
+
3
+
1
4
  def drawHeaderIcon(st):
2
- st.markdown("""
5
+ st.markdown(
6
+ f"""
3
7
  <div class="headerIconContainer"></div>
4
8
 
5
9
  <style>
6
- .headerIconContainer {
10
+ .headerIconContainer {{
7
11
  position: absolute;
8
12
  top: -50px;
9
13
  height: 50px;
10
14
  width: 100%;
11
15
  border-bottom: 2px solid #E8EAEE;
12
- background-image: url(https://assets.zilliz.com/vdb_benchmark_db790b5387.png);
16
+ background-image: url({HEADER_ICON});
13
17
  background-repeat: no-repeat;
14
- }
18
+ }}
15
19
  </style
16
20
  """,
17
21
  unsafe_allow_html=True,
@@ -2,20 +2,21 @@ from streamlit_extras.switch_page_button import switch_page
2
2
 
3
3
 
4
4
  def NavToRunTest(st):
5
- st.header("Run your test")
5
+ st.subheader("Run your test")
6
6
  st.write("You can set the configs and run your own test.")
7
7
  navClick = st.button("Run Your Test &nbsp;&nbsp;>")
8
8
  if navClick:
9
9
  switch_page("run test")
10
10
 
11
11
 
12
- def NavToQPSWithPrice(st):
13
- navClick = st.button("QPS with Price &nbsp;&nbsp;>")
12
+ def NavToQuriesPerDollar(st):
13
+ st.subheader("Compare qps with price.")
14
+ navClick = st.button("QP$ (Quries per Dollar) &nbsp;&nbsp;>")
14
15
  if navClick:
15
- switch_page("qps with price")
16
+ switch_page("quries_per_dollar")
16
17
 
17
18
 
18
- def NavToResults(st):
19
- navClick = st.button("< &nbsp;&nbsp;Back to Results")
19
+ def NavToResults(st, key="nav-to-results"):
20
+ navClick = st.button("< &nbsp;&nbsp;Back to Results", key=key)
20
21
  if navClick:
21
22
  switch_page("vdb benchmark")
@@ -1,9 +1,10 @@
1
1
  from vectordb_bench.backend.clients import DB
2
- from vectordb_bench.frontend.const import DB_DBLABEL_TO_PRICE
3
2
  import pandas as pd
4
3
  from collections import defaultdict
5
4
  import streamlit as st
6
5
 
6
+ from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
7
+
7
8
 
8
9
  def priceTable(container, data):
9
10
  dbAndLabelSet = {
@@ -25,7 +26,7 @@ def priceTable(container, data):
25
26
  )
26
27
  height = len(table) * 35 + 38
27
28
 
28
- expander = container.expander("You can edit the price.")
29
+ expander = container.expander("Price List (Editable).")
29
30
  editTable = expander.data_editor(
30
31
  table,
31
32
  use_container_width=True,
@@ -0,0 +1,18 @@
1
+ from vectordb_bench.frontend.const.styles import *
2
+
3
+
4
+ def initResultsPageConfig(st):
5
+ st.set_page_config(
6
+ page_title=PAGE_TITLE,
7
+ page_icon=FAVICON,
8
+ # layout="wide",
9
+ # initial_sidebar_state="collapsed",
10
+ )
11
+
12
+ def initRunTestPageConfig(st):
13
+ st.set_page_config(
14
+ page_title=PAGE_TITLE,
15
+ page_icon=FAVICON,
16
+ # layout="wide",
17
+ initial_sidebar_state="collapsed",
18
+ )
@@ -0,0 +1,50 @@
1
+ import requests
2
+ import streamlit as st
3
+ import streamlit.components.v1 as components
4
+
5
+ HTML_2_CANVAS_URL = "https://unpkg.com/html2canvas@1.4.1/dist/html2canvas.js"
6
+
7
+
8
+ @st.cache_data
9
+ def load_unpkg(src: str) -> str:
10
+ return requests.get(src).text
11
+
12
+ def getResults(container, pageName="vectordb_bench"):
13
+ container.subheader("Get results")
14
+ saveAsImage(container, pageName)
15
+
16
+ def saveAsImage(container, pageName):
17
+ html2canvasJS = load_unpkg(HTML_2_CANVAS_URL)
18
+ container.write()
19
+ buttonText = "Save as Image"
20
+ savePDFButton = container.button(buttonText)
21
+ if savePDFButton:
22
+ components.html(
23
+ f"""
24
+ <script>{html2canvasJS}</script>
25
+
26
+ <script>
27
+ const html2canvas = window.html2canvas
28
+
29
+ const streamlitDoc = window.parent.document;
30
+ const stApp = streamlitDoc.querySelector('.main > .block-container');
31
+
32
+ const buttons = Array.from(streamlitDoc.querySelectorAll('.stButton > button'));
33
+ const imgButton = buttons.find(el => el.innerText === '{buttonText}');
34
+
35
+ if (imgButton)
36
+ imgButton.innerText = 'Creating Image...';
37
+
38
+ html2canvas(stApp, {{ allowTaint: false, useCORS: true }}).then(function (canvas) {{
39
+ a = document.createElement('a');
40
+ a.href = canvas.toDataURL("image/jpeg", 1.0).replace("image/jpeg", "image/octet-stream");
41
+ a.download = '{pageName}.png';
42
+ a.click();
43
+
44
+ if (imgButton)
45
+ imgButton.innerText = '{buttonText}';
46
+ }})
47
+ </script>""",
48
+ height=0,
49
+ width=0,
50
+ )
@@ -1,5 +1,5 @@
1
1
  from streamlit_autorefresh import st_autorefresh
2
- from vectordb_bench.frontend.const import *
2
+ from vectordb_bench.frontend.const.styles import *
3
3
 
4
4
 
5
5
  def autoRefresh():