vectordb-bench 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to their respective public registries. It is provided for informational purposes only.
Files changed (57)
  1. vectordb_bench/__init__.py +1 -0
  2. vectordb_bench/backend/assembler.py +1 -1
  3. vectordb_bench/backend/cases.py +64 -18
  4. vectordb_bench/backend/clients/__init__.py +35 -0
  5. vectordb_bench/backend/clients/api.py +21 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
  7. vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
  8. vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
  9. vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
  10. vectordb_bench/backend/clients/memorydb/cli.py +88 -0
  11. vectordb_bench/backend/clients/memorydb/config.py +54 -0
  12. vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
  13. vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
  14. vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
  15. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
  16. vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
  17. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +272 -0
  18. vectordb_bench/backend/dataset.py +27 -5
  19. vectordb_bench/cli/vectordbbench.py +7 -0
  20. vectordb_bench/custom/custom_case.json +18 -0
  21. vectordb_bench/frontend/components/check_results/charts.py +6 -6
  22. vectordb_bench/frontend/components/check_results/data.py +18 -11
  23. vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
  24. vectordb_bench/frontend/components/check_results/filters.py +20 -13
  25. vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
  26. vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
  27. vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
  28. vectordb_bench/frontend/components/concurrent/charts.py +26 -29
  29. vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
  30. vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
  31. vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
  32. vectordb_bench/frontend/components/custom/initStyle.py +15 -0
  33. vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
  34. vectordb_bench/frontend/components/run_test/caseSelector.py +50 -28
  35. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -19
  36. vectordb_bench/frontend/components/run_test/dbSelector.py +2 -14
  37. vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
  38. vectordb_bench/frontend/components/run_test/initStyle.py +16 -0
  39. vectordb_bench/frontend/components/run_test/submitTask.py +1 -1
  40. vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +311 -40
  41. vectordb_bench/frontend/{const → config}/styles.py +2 -0
  42. vectordb_bench/frontend/pages/concurrent.py +11 -18
  43. vectordb_bench/frontend/pages/custom.py +64 -0
  44. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
  45. vectordb_bench/frontend/pages/run_test.py +4 -0
  46. vectordb_bench/frontend/pages/tables.py +2 -2
  47. vectordb_bench/frontend/utils.py +17 -1
  48. vectordb_bench/frontend/vdb_benchmark.py +3 -3
  49. vectordb_bench/models.py +26 -10
  50. vectordb_bench/results/getLeaderboardData.py +1 -1
  51. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/METADATA +46 -15
  52. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/RECORD +57 -40
  53. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/WHEEL +1 -1
  54. /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
  55. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/LICENSE +0 -0
  56. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/entry_points.txt +0 -0
  57. {vectordb_bench-0.0.11.dist-info → vectordb_bench-0.0.13.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py
@@ -0,0 +1,272 @@
+"""Wrapper around the Pgvectorscale vector database over VectorDB"""
+
+import logging
+import pprint
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+import numpy as np
+import psycopg
+from pgvector.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
+
+from ..api import VectorDB
+from .config import PgVectorScaleConfigDict, PgVectorScaleIndexConfig
+
+log = logging.getLogger(__name__)
+
+
+class PgVectorScale(VectorDB):
+    """Use psycopg instructions"""
+
+    conn: psycopg.Connection[Any] | None = None
+    cursor: psycopg.Cursor[Any] | None = None
+
+    def __init__(
+        self,
+        dim: int,
+        db_config: PgVectorScaleConfigDict,
+        db_case_config: PgVectorScaleIndexConfig,
+        collection_name: str = "pg_vectorscale_collection",
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.name = "PgVectorScale"
+        self.db_config = db_config
+        self.case_config = db_case_config
+        self.table_name = collection_name
+        self.dim = dim
+
+        self._index_name = "pgvectorscale_index"
+        self._primary_field = "id"
+        self._vector_field = "embedding"
+
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
+        if not any(
+            (
+                self.case_config.create_index_before_load,
+                self.case_config.create_index_after_load,
+            )
+        ):
+            err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+            log.error(err)
+            raise RuntimeError(
+                f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+            )
+
+        if drop_old:
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            if self.case_config.create_index_before_load:
+                self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None
+
+    @staticmethod
+    def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+        conn = psycopg.connect(**kwargs)
+        conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
+        conn.commit()
+        register_vector(conn)
+        conn.autocommit = False
+        cursor = conn.cursor()
+
+        assert conn is not None, "Connection is not initialized"
+        assert cursor is not None, "Cursor is not initialized"
+
+        return conn, cursor
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        self.conn, self.cursor = self._create_connection(**self.db_config)
+
+        # index configuration may have commands defined that we should set during each client session
+        session_options: dict[str, Any] = self.case_config.session_param()
+
+        if len(session_options) > 0:
+            for setting_name, setting_val in session_options.items():
+                command = sql.SQL("SET {setting_name} " + "= {setting_val};").format(
+                    setting_name=sql.Identifier(setting_name),
+                    setting_val=sql.Identifier(str(setting_val)),
+                )
+                log.debug(command.as_string(self.cursor))
+                self.cursor.execute(command)
+            self.conn.commit()
+
+        self._unfiltered_search = sql.Composed(
+            [
+                sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
+                    sql.Identifier(self.table_name)
+                ),
+                sql.SQL(self.case_config.search_param()["metric_fun_op"]),
+                sql.SQL(" %s::vector LIMIT %s::int"),
+            ]
+        )
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop table : {self.table_name}")
+
+        self.cursor.execute(
+            sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
+                table_name=sql.Identifier(self.table_name)
+            )
+        )
+        self.conn.commit()
+
+    def ready_to_load(self):
+        pass
+
+    def optimize(self):
+        self._post_insert()
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        if self.case_config.create_index_after_load:
+            self._drop_index()
+            self._create_index()
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client drop index : {self._index_name}")
+
+        drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
+            index_name=sql.Identifier(self._index_name)
+        )
+        log.debug(drop_index_sql.as_string(self.cursor))
+        self.cursor.execute(drop_index_sql)
+        self.conn.commit()
+
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+        log.info(f"{self.name} client create index : {self._index_name}")
+
+        index_param: dict[str, Any] = self.case_config.index_param()
+
+        options = []
+        for option_name, option_val in index_param["options"].items():
+            if option_val is not None:
+                options.append(
+                    sql.SQL("{option_name} = {val}").format(
+                        option_name=sql.Identifier(option_name),
+                        val=sql.Identifier(str(option_val)),
+                    )
+                )
+
+        num_bits_per_dimension = "2" if self.dim < 900 else "1"
+        options.append(
+            sql.SQL("{option_name} = {val}").format(
+                option_name=sql.Identifier("num_bits_per_dimension"),
+                val=sql.Identifier(num_bits_per_dimension),
+            )
+        )
+
+        if any(options):
+            with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
+        else:
+            with_clause = sql.Composed(())
+
+        index_create_sql = sql.SQL(
+            """
+            CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+            USING {index_type} (embedding {embedding_metric})
+            """
+        ).format(
+            index_name=sql.Identifier(self._index_name),
+            table_name=sql.Identifier(self.table_name),
+            index_type=sql.Identifier(index_param["index_type"].lower()),
+            embedding_metric=sql.Identifier(index_param["metric"]),
+        )
+        index_create_sql_with_with_clause = (
+            index_create_sql + with_clause
+        ).join(" ")
+        log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
+        self.cursor.execute(index_create_sql_with_with_clause)
+        self.conn.commit()
+
+    def _create_table(self, dim: int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            log.info(f"{self.name} client create table : {self.table_name}")
+
+            self.cursor.execute(
+                sql.SQL(
+                    "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
+                ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+            )
+            self.conn.commit()
+        except Exception as e:
+            log.warning(
+                f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
+            )
+            raise e from None
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        try:
+            metadata_arr = np.array(metadata)
+            embeddings_arr = np.array(embeddings)
+
+            with self.cursor.copy(
+                sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
+                    table_name=sql.Identifier(self.table_name)
+                )
+            ) as copy:
+                copy.set_types(["bigint", "vector"])
+                for i, row in enumerate(metadata_arr):
+                    copy.write_row((row, embeddings_arr[i]))
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
+            return len(metadata), None
+        except Exception as e:
+            log.warning(
+                f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
+            )
+            return 0, e
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+    ) -> list[int]:
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        q = np.asarray(query)
+        # TODO add filters support
+        result = self.cursor.execute(
+            self._unfiltered_search, (q, k), prepare=True, binary=True
+        )
+
+        return [int(i[0]) for i in result.fetchall()]
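
The new client follows VectorDBBench's client lifecycle: the constructor sets up the table and (optionally) the index, init() opens a per-session connection, insert_embeddings bulk-loads via binary COPY, and optimize() rebuilds the index after loading; filtered search is still a TODO. A minimal sketch of how a runner might drive it, where the connection parameters and the index_config object are placeholders rather than API documented by this release:

from vectordb_bench.backend.clients.pgvectorscale.pgvectorscale import PgVectorScale

# index_config is assumed to be a PgVectorScaleIndexConfig with
# create_index_after_load=True; its actual constructor lives in config.py.
client = PgVectorScale(
    dim=768,
    db_config={"host": "localhost", "user": "postgres", "dbname": "vdb"},  # placeholder DSN parts
    db_case_config=index_config,
    drop_old=True,
)

with client.init():  # opens a session connection; closed on exit
    client.insert_embeddings(embeddings=[[0.0] * 768], metadata=[1])
    client.optimize()  # drops and rebuilds the index when create_index_after_load is set
    ids = client.search_embedding(query=[0.0] * 768, k=10)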
--- a/vectordb_bench/backend/dataset.py
+++ b/vectordb_bench/backend/dataset.py
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
     use_shuffled: bool
     with_gt: bool = False
     _size_label: dict[int, SizeLabel] = PrivateAttr()
+    isCustom: bool = False

     @validator("size")
     def verify_size(cls, v):
@@ -52,7 +53,27 @@ class BaseDataset(BaseModel):
     def file_count(self) -> int:
         return self._size_label.get(self.size).file_count

+
+class CustomDataset(BaseDataset):
+    dir: str
+    file_num: int
+    isCustom: bool = True
+
+    @validator("size")
+    def verify_size(cls, v):
+        return v
+
+    @property
+    def label(self) -> str:
+        return "Custom"

+    @property
+    def dir_name(self) -> str:
+        return self.dir
+
+    @property
+    def file_count(self) -> int:
+        return self.file_num
+

 class LAION(BaseDataset):
     name: str = "LAION"
     dim: int = 768
@@ -186,11 +207,12 @@ class DatasetManager(BaseModel):
             gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
             all_files.extend([gt_file, test_file])

-        source.reader().read(
-            dataset=self.data.dir_name.lower(),
-            files=all_files,
-            local_ds_root=self.data_dir,
-        )
+        if not self.data.isCustom:
+            source.reader().read(
+                dataset=self.data.dir_name.lower(),
+                files=all_files,
+                local_ds_root=self.data_dir,
+            )

         if gt_file is not None and test_file is not None:
             self.test_data = self._read_file(test_file)
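
The isCustom flag is what lets DatasetManager skip the remote download for user-supplied data: CustomDataset resolves dir_name to a local directory and reports a fixed file_num instead of a size-derived file count. A minimal sketch of constructing one, with placeholder field values mirroring the custom_case.json template shipped in this release:

from vectordb_bench.backend.dataset import CustomDataset

# Field values are illustrative only; name/size/dim/metric_type mirror the
# dataset_config block of custom_case.json (shown further below).
ds = CustomDataset(
    name="My Dataset",
    size=1_000_000,
    dim=1024,
    metric_type="L2",
    use_shuffled=False,
    with_gt=True,
    dir="/my_dataset_path",  # local folder containing the parquet files
    file_num=1,
)
assert ds.isCustom and ds.dir_name == "/my_dataset_path" and ds.file_count == 1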
--- a/vectordb_bench/cli/vectordbbench.py
+++ b/vectordb_bench/cli/vectordbbench.py
@@ -1,19 +1,26 @@
 from ..backend.clients.pgvector.cli import PgVectorHNSW
+from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
 from ..backend.clients.redis.cli import Redis
+from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
 from ..backend.clients.milvus.cli import MilvusAutoIndex
+from ..backend.clients.aws_opensearch.cli import AWSOpenSearch


 from .cli import cli

 cli.add_command(PgVectorHNSW)
+cli.add_command(PgVectoRSHNSW)
+cli.add_command(PgVectoRSIVFFlat)
 cli.add_command(Redis)
+cli.add_command(MemoryDB)
 cli.add_command(Weaviate)
 cli.add_command(Test)
 cli.add_command(ZillizAutoIndex)
 cli.add_command(MilvusAutoIndex)
+cli.add_command(AWSOpenSearch)


 if __name__ == "__main__":
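
Each new backend follows the same pattern: its cli.py module defines a click command, and this file registers it on the shared cli group. A self-contained toy version of that wiring, with invented names for illustration:

import click

cli = click.Group()  # stands in for the shared group in vectordb_bench/cli/cli.py

@cli.command(name="mybackend")
@click.option("--host", default="localhost", help="example option")
def my_backend(host: str):
    """Hypothetical backend command, registered the same way as MemoryDB or AWSOpenSearch."""
    click.echo(f"would benchmark {host}")

if __name__ == "__main__":
    cli()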
--- /dev/null
+++ b/vectordb_bench/custom/custom_case.json
@@ -0,0 +1,18 @@
+[
+    {
+        "name": "My Dataset (Performace Case)",
+        "description": "this is a customized dataset.",
+        "load_timeout": 36000,
+        "optimize_timeout": 36000,
+        "dataset_config": {
+            "name": "My Dataset",
+            "dir": "/my_dataset_path",
+            "size": 1000000,
+            "dim": 1024,
+            "metric_type": "L2",
+            "file_count": 1,
+            "use_shuffled": false,
+            "with_gt": true
+        }
+    }
+]
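
The template is a JSON array, so several custom cases can be declared side by side; per the file list above, the new getCustomConfig.py component reads it for the custom-case page. A minimal sketch of parsing it with the standard library (the path and field access are assumptions based on the template itself):

import json
from pathlib import Path

cases = json.loads(Path("vectordb_bench/custom/custom_case.json").read_text())
for case in cases:
    cfg = case["dataset_config"]
    print(case["name"], cfg["dir"], cfg["dim"], cfg["metric_type"])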
--- a/vectordb_bench/frontend/components/check_results/charts.py
+++ b/vectordb_bench/frontend/components/check_results/charts.py
@@ -1,19 +1,19 @@
 from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
 from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
 from vectordb_bench.models import ResultLabel
 import plotly.express as px


-def drawCharts(st, allData, failedTasks, cases: list[Case]):
+def drawCharts(st, allData, failedTasks, caseNames: list[str]):
     initMainExpanderStyle(st)
-    for case in cases:
-        chartContainer = st.expander(case.name, True)
-        data = [data for data in allData if data["case_name"] == case.name]
+    for caseName in caseNames:
+        chartContainer = st.expander(caseName, True)
+        data = [data for data in allData if data["case_name"] == caseName]
         drawChart(data, chartContainer)

-        errorDBs = failedTasks[case.name]
+        errorDBs = failedTasks[caseName]
         showFailedDBs(chartContainer, errorDBs)

--- a/vectordb_bench/frontend/components/check_results/data.py
+++ b/vectordb_bench/frontend/components/check_results/data.py
@@ -8,9 +8,9 @@ from vectordb_bench.models import CaseResult, ResultLabel
 def getChartData(
     tasks: list[CaseResult],
     dbNames: list[str],
-    cases: list[Case],
+    caseNames: list[str],
 ):
-    filterTasks = getFilterTasks(tasks, dbNames, cases)
+    filterTasks = getFilterTasks(tasks, dbNames, caseNames)
     mergedTasks, failedTasks = mergeTasks(filterTasks)
     return mergedTasks, failedTasks

@@ -18,14 +18,16 @@ def getChartData(
 def getFilterTasks(
     tasks: list[CaseResult],
     dbNames: list[str],
-    cases: list[Case],
+    caseNames: list[str],
 ) -> list[CaseResult]:
-    case_ids = [case.case_id for case in cases]
     filterTasks = [
         task
         for task in tasks
         if task.task_config.db_name in dbNames
-        and task.task_config.case_config.case_id in case_ids
+        and task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        ).name
+        in caseNames
     ]
     return filterTasks

@@ -36,16 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
         db_name = task.task_config.db_name
         db = task.task_config.db.value
         db_label = task.task_config.db_config.db_label or ""
-        case_id = task.task_config.case_config.case_id
-        dbCaseMetricsMap[db_name][case_id] = {
+        version = task.task_config.db_config.version or ""
+        case = task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        )
+        dbCaseMetricsMap[db_name][case.name] = {
             "db": db,
             "db_label": db_label,
+            "version": version,
             "metrics": mergeMetrics(
-                dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
+                dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
                 asdict(task.metrics),
             ),
             "label": getBetterLabel(
-                dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
+                dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
                 task.label,
             ),
         }
@@ -53,18 +59,19 @@ def mergeTasks(tasks: list[CaseResult]):
     mergedTasks = []
     failedTasks = defaultdict(lambda: defaultdict(str))
     for db_name, caseMetricsMap in dbCaseMetricsMap.items():
-        for case_id, metricInfo in caseMetricsMap.items():
+        for case_name, metricInfo in caseMetricsMap.items():
             metrics = metricInfo["metrics"]
             db = metricInfo["db"]
             db_label = metricInfo["db_label"]
+            version = metricInfo["version"]
             label = metricInfo["label"]
-            case_name = case_id.case_name
             if label == ResultLabel.NORMAL:
                 mergedTasks.append(
                     {
                         "db_name": db_name,
                         "db": db,
                         "db_label": db_label,
+                        "version": version,
                         "case_name": case_name,
                         "metricsSet": set(metrics.keys()),
                         **metrics,
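
The practical effect of keying dbCaseMetricsMap by case.name rather than case_id is that partial results for the same db/case pair, including the new custom cases that share one case_id but differ by display name, merge into a single row. A toy illustration with fabricated records:

from collections import defaultdict

runs = [  # fabricated partial results for the same db/case pair
    {"db_name": "Milvus-standalone", "case_name": "My Custom Case", "qps": 1200.0},
    {"db_name": "Milvus-standalone", "case_name": "My Custom Case", "recall": 0.98},
]
merged = defaultdict(dict)
for r in runs:
    merged[(r["db_name"], r["case_name"])].update(
        {k: v for k, v in r.items() if k not in ("db_name", "case_name")}
    )
print(dict(merged))  # one merged entry with both qps and recall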
--- a/vectordb_bench/frontend/components/check_results/expanderStyle.py
+++ b/vectordb_bench/frontend/components/check_results/expanderStyle.py
@@ -1,7 +1,7 @@
 def initMainExpanderStyle(st):
     st.markdown(
         """<style>
-            .main .streamlit-expanderHeader p {font-size: 20px; font-weight: 600;}
+            .main div[data-testid='stExpander'] p {font-size: 18px; font-weight: 600;}
             .main div[data-testid='stExpander'] {
                 background-color: #F6F8FA;
                 border: 1px solid #A9BDD140;
--- a/vectordb_bench/frontend/components/check_results/filters.py
+++ b/vectordb_bench/frontend/components/check_results/filters.py
@@ -1,8 +1,8 @@
 from vectordb_bench.backend.cases import Case
 from vectordb_bench.frontend.components.check_results.data import getChartData
 from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
-from vectordb_bench.frontend.const.dbCaseConfigs import CASE_LIST
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
+from vectordb_bench.frontend.config.styles import *
 import streamlit as st

 from vectordb_bench.models import CaseResult, TestResult
@@ -18,11 +18,12 @@ def getshownData(results: list[TestResult], st):
     st.header("Filters")

     shownResults = getshownResults(results, st)
-    showDBNames, showCases = getShowDbsAndCases(shownResults, st)
+    showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)

-    shownData, failedTasks = getChartData(shownResults, showDBNames, showCases)
+    shownData, failedTasks = getChartData(
+        shownResults, showDBNames, showCaseNames)

-    return shownData, failedTasks, showCases
+    return shownData, failedTasks, showCaseNames


 def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
@@ -52,12 +53,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
     return selectedResult


-def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Case]]:
+def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
     initSidebarExanderStyle(st)
     allDbNames = list(set({res.task_config.db_name for res in result}))
     allDbNames.sort()
-    allCasesSet = set({res.task_config.case_config.case_id for res in result})
-    allCases: list[Case] = [case.case_cls() for case in CASE_LIST if case in allCasesSet]
+    allCases: list[Case] = [
+        res.task_config.case_config.case_id.case_cls(
+            res.task_config.case_config.custom_case)
+        for res in result
+    ]
+    allCaseNameSet = set({case.name for case in allCases})
+    allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
+        [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]

     # DB Filter
     dbFilterContainer = st.container()
@@ -70,15 +77,14 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[Ca

     # Case Filter
     caseFilterContainer = st.container()
-    showCases = filterView(
+    showCaseNames = filterView(
         caseFilterContainer,
         "Case Filter",
-        [case for case in allCases],
+        [caseName for caseName in allCaseNames],
         col=1,
-        optionLables=[case.name for case in allCases],
     )

-    return showDBNames, showCases
+    return showDBNames, showCaseNames


 def filterView(container, header, options, col, optionLables=None):
@@ -114,7 +120,8 @@ def filterView(container, header, options, col, optionLables=None):
     )
     if optionLables is None:
         optionLables = options
-    isActive = {option: st.session_state[selectAllState] for option in optionLables}
+    isActive = {option: st.session_state[selectAllState]
+                for option in optionLables}
     for i, option in enumerate(optionLables):
         isActive[option] = columns[i % col].checkbox(
             optionLables[i],
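
The ordering expression in getShowDbsAndCases keeps built-in cases in their canonical CASE_NAME_ORDER position and appends unknown (custom) case names at the end. A toy reduction of that expression, with fabricated case names:

CASE_NAME_ORDER = ["Search Performance (1M)", "Search Performance (10M)"]  # fabricated
allCaseNameSet = {"Search Performance (10M)", "My Custom Case"}

allCaseNames = [name for name in CASE_NAME_ORDER if name in allCaseNameSet] + \
    [name for name in allCaseNameSet if name not in CASE_NAME_ORDER]
print(allCaseNames)  # ['Search Performance (10M)', 'My Custom Case']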
--- a/vectordb_bench/frontend/components/check_results/headerIcon.py
+++ b/vectordb_bench/frontend/components/check_results/headerIcon.py
@@ -1,4 +1,4 @@
-from vectordb_bench.frontend.const.styles import HEADER_ICON
+from vectordb_bench.frontend.config.styles import HEADER_ICON


 def drawHeaderIcon(st):
--- a/vectordb_bench/frontend/components/check_results/priceTable.py
+++ b/vectordb_bench/frontend/components/check_results/priceTable.py
@@ -3,7 +3,7 @@ import pandas as pd
 from collections import defaultdict
 import streamlit as st

-from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
+from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE


 def priceTable(container, data):
--- a/vectordb_bench/frontend/components/check_results/stPageConfig.py
+++ b/vectordb_bench/frontend/components/check_results/stPageConfig.py
@@ -1,4 +1,4 @@
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *


 def initResultsPageConfig(st):
--- a/vectordb_bench/frontend/components/concurrent/charts.py
+++ b/vectordb_bench/frontend/components/concurrent/charts.py
@@ -1,26 +1,27 @@
-
-
-from vectordb_bench.backend.cases import Case
-from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
+from vectordb_bench.frontend.components.check_results.expanderStyle import (
+    initMainExpanderStyle,
+)
 import plotly.express as px

-from vectordb_bench.frontend.const.styles import COLOR_MAP
+from vectordb_bench.frontend.config.styles import COLOR_MAP


-def drawChartsByCase(allData, cases: list[Case], st):
+def drawChartsByCase(allData, showCaseNames: list[str], st):
     initMainExpanderStyle(st)
-    for case in cases:
-        chartContainer = st.expander(case.name, True)
-        caseDataList = [
-            data for data in allData if data["case_name"] == case.name]
-        data = [{
-            "conc_num": caseData["conc_num_list"][i],
-            "qps": caseData["conc_qps_list"][i],
-            "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
-            "db_name": caseData["db_name"],
-            "db": caseData["db"]
-
-        } for caseData in caseDataList for i in range(len(caseData["conc_num_list"]))]
+    for caseName in showCaseNames:
+        chartContainer = st.expander(caseName, True)
+        caseDataList = [data for data in allData if data["case_name"] == caseName]
+        data = [
+            {
+                "conc_num": caseData["conc_num_list"][i],
+                "qps": caseData["conc_qps_list"][i],
+                "latency_p99": caseData["conc_latency_p99_list"][i] * 1000,
+                "db_name": caseData["db_name"],
+                "db": caseData["db"],
+            }
+            for caseData in caseDataList
+            for i in range(len(caseData["conc_num_list"]))
+        ]
         drawChart(data, chartContainer)

@@ -38,7 +39,7 @@ def getRange(metric, data, padding_multipliers):
 def drawChart(data, st):
     if len(data) == 0:
         return
-
+
     x = "latency_p99"
     xrange = getRange(x, data, [0.05, 0.1])

@@ -63,7 +64,6 @@ def drawChart(data, st):
         line_group=line_group,
         text=text,
         markers=True,
-        # color_discrete_map=color_discrete_map,
         hover_data={
             "conc_num": True,
         },
@@ -71,12 +71,9 @@ def drawChart(data, st):
     )
     fig.update_xaxes(range=xrange, title_text="Latency P99 (ms)")
     fig.update_yaxes(range=yrange, title_text="QPS")
-    fig.update_traces(textposition="bottom right",
-                      texttemplate="conc-%{text:,.4~r}")
-    # fig.update_layout(
-    #     margin=dict(l=0, r=0, t=40, b=0, pad=8),
-    #     legend=dict(
-    #         orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
-    #     ),
-    # )
-    st.plotly_chart(fig, use_container_width=True,)
+    fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}")
+
+    st.plotly_chart(
+        fig,
+        use_container_width=True,
+    )