vectordb-bench 0.0.30__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. vectordb_bench/__init__.py +14 -27
  2. vectordb_bench/backend/assembler.py +19 -6
  3. vectordb_bench/backend/cases.py +186 -23
  4. vectordb_bench/backend/clients/__init__.py +16 -0
  5. vectordb_bench/backend/clients/api.py +22 -1
  6. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +82 -41
  7. vectordb_bench/backend/clients/aws_opensearch/config.py +23 -4
  8. vectordb_bench/backend/clients/chroma/chroma.py +6 -2
  9. vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
  10. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
  11. vectordb_bench/backend/clients/milvus/config.py +1 -0
  12. vectordb_bench/backend/clients/milvus/milvus.py +74 -22
  13. vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
  14. vectordb_bench/backend/clients/oceanbase/config.py +125 -0
  15. vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
  16. vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
  17. vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
  18. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
  19. vectordb_bench/backend/dataset.py +143 -27
  20. vectordb_bench/backend/filter.py +76 -0
  21. vectordb_bench/backend/runner/__init__.py +3 -3
  22. vectordb_bench/backend/runner/mp_runner.py +52 -39
  23. vectordb_bench/backend/runner/rate_runner.py +68 -52
  24. vectordb_bench/backend/runner/read_write_runner.py +125 -68
  25. vectordb_bench/backend/runner/serial_runner.py +56 -23
  26. vectordb_bench/backend/task_runner.py +48 -20
  27. vectordb_bench/cli/cli.py +59 -1
  28. vectordb_bench/cli/vectordbbench.py +3 -0
  29. vectordb_bench/frontend/components/check_results/data.py +16 -11
  30. vectordb_bench/frontend/components/check_results/filters.py +53 -25
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
  32. vectordb_bench/frontend/components/check_results/nav.py +20 -0
  33. vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
  34. vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
  35. vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
  36. vectordb_bench/frontend/components/label_filter/charts.py +60 -0
  37. vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
  38. vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
  39. vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
  41. vectordb_bench/frontend/components/streaming/charts.py +253 -0
  42. vectordb_bench/frontend/components/streaming/data.py +62 -0
  43. vectordb_bench/frontend/components/tables/data.py +1 -1
  44. vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
  45. vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
  46. vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
  47. vectordb_bench/frontend/config/dbCaseConfigs.py +307 -40
  48. vectordb_bench/frontend/config/styles.py +32 -2
  49. vectordb_bench/frontend/pages/concurrent.py +5 -1
  50. vectordb_bench/frontend/pages/custom.py +4 -0
  51. vectordb_bench/frontend/pages/label_filter.py +56 -0
  52. vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
  53. vectordb_bench/frontend/pages/results.py +60 -0
  54. vectordb_bench/frontend/pages/run_test.py +3 -3
  55. vectordb_bench/frontend/pages/streaming.py +135 -0
  56. vectordb_bench/frontend/pages/tables.py +4 -0
  57. vectordb_bench/frontend/vdb_benchmark.py +16 -41
  58. vectordb_bench/interface.py +6 -2
  59. vectordb_bench/metric.py +15 -1
  60. vectordb_bench/models.py +31 -11
  61. vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
  62. vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
  63. vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
  64. vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
  65. vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
  66. vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
  67. vectordb_bench/results/dbPrices.json +12 -4
  68. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +85 -32
  69. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +73 -56
  70. vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
  71. vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
  72. vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
  73. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +0 -0
  74. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
  75. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
  76. {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/task_runner.py CHANGED
@@ -6,15 +6,14 @@ from enum import Enum, auto
  import numpy as np
  import psutil
 
- from vectordb_bench.base import BaseModel
- from vectordb_bench.metric import Metric
- from vectordb_bench.models import PerformanceTimeoutError, TaskConfig, TaskStage
-
+ from ..base import BaseModel
+ from ..metric import Metric
+ from ..models import PerformanceTimeoutError, TaskConfig, TaskStage
  from . import utils
- from .cases import Case, CaseLabel
+ from .cases import Case, CaseLabel, StreamingPerformanceCase
  from .clients import MetricType, api
  from .data_source import DatasetSource
- from .runner import MultiProcessingSearchRunner, SerialInsertRunner, SerialSearchRunner
+ from .runner import MultiProcessingSearchRunner, ReadWriteRunner, SerialInsertRunner, SerialSearchRunner
 
  log = logging.getLogger(__name__)
 
@@ -48,6 +47,7 @@ class CaseRunner(BaseModel):
  serial_search_runner: SerialSearchRunner | None = None
  search_runner: MultiProcessingSearchRunner | None = None
  final_search_runner: MultiProcessingSearchRunner | None = None
+ read_write_runner: ReadWriteRunner | None = None
 
  def __eq__(self, obj: any):
  if isinstance(obj, CaseRunner):
@@ -63,6 +63,7 @@ class CaseRunner(BaseModel):
  c_dict = self.ca.dict(
  include={
  "label": True,
+ "name": True,
  "filters": True,
  "dataset": {
  "data": {
@@ -91,12 +92,13 @@ class CaseRunner(BaseModel):
  db_config=self.config.db_config.to_dict(),
  db_case_config=self.config.db_case_config,
  drop_old=drop_old,
+ with_scalar_labels=self.ca.with_scalar_labels,
  )
 
  def _pre_run(self, drop_old: bool = True):
  try:
  self.init_db(drop_old)
- self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate)
+ self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filters)
  except ModuleNotFoundError as e:
  log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}")
  raise e from None
@@ -110,6 +112,8 @@ class CaseRunner(BaseModel):
  return self._run_capacity_case()
  if self.ca.label == CaseLabel.Performance:
  return self._run_perf_case(drop_old)
+ if self.ca.label == CaseLabel.Streaming:
+ return self._run_streaming_case()
  msg = f"unknown case type: {self.ca.label}"
  log.warning(msg)
  raise ValueError(msg)
@@ -127,6 +131,7 @@ class CaseRunner(BaseModel):
  self.db,
  self.ca.dataset,
  self.normalize,
+ self.ca.filters,
  self.ca.load_timeout,
  )
  count = runner.run_endlessness()
@@ -151,6 +156,8 @@ class CaseRunner(BaseModel):
  if TaskStage.LOAD in self.config.stages:
  _, load_dur = self._load_train_data()
  build_dur = self._optimize()
+ m.insert_duration = round(load_dur, 4)
+ m.optimize_duration = round(build_dur, 4)
  m.load_duration = round(load_dur + build_dur, 4)
  log.info(
  f"Finish loading the entire dataset into VectorDB,"
@@ -172,10 +179,6 @@ class CaseRunner(BaseModel):
  ) = search_results
  if TaskStage.SEARCH_SERIAL in self.config.stages:
  search_results = self._serial_search()
- """
- m.recall = search_results.recall
- m.serial_latencies = search_results.serial_latencies
- """
  m.recall, m.ndcg, m.serial_latency_p99 = search_results
 
  except Exception as e:
@@ -186,6 +189,19 @@ class CaseRunner(BaseModel):
  log.info(f"Performance case got result: {m}")
  return m
 
+ def _run_streaming_case(self) -> Metric:
+ log.info("Start streaming case")
+ try:
+ self._init_read_write_runner()
+ m = self.read_write_runner.run_read_write()
+ except Exception as e:
+ log.warning(f"Failed to run streaming case, reason = {e}")
+ traceback.print_exc()
+ raise e from None
+ else:
+ log.info(f"Streaming case got result: {m}")
+ return m
+
  @utils.time_it
  def _load_train_data(self):
  """Insert train data and get the insert_duration"""
@@ -194,6 +210,7 @@ class CaseRunner(BaseModel):
  self.db,
  self.ca.dataset,
  self.normalize,
+ self.ca.filters,
  self.ca.load_timeout,
  )
  runner.run()
@@ -207,7 +224,7 @@ class CaseRunner(BaseModel):
  calculate the recall, serial_latency_p99
 
  Returns:
- tuple[float, float]: recall, serial_latency_p99
+ tuple[float, float, float]: recall, ndcg, serial_latency_p99
  """
  try:
  results, _ = self.serial_search_runner.run()
@@ -253,10 +270,12 @@ class CaseRunner(BaseModel):
  raise e from None
 
  def _init_search_runner(self):
- test_emb = np.stack(self.ca.dataset.test_data["emb"])
  if self.normalize:
+ test_emb = np.stack(self.ca.dataset.test_data)
  test_emb = test_emb / np.linalg.norm(test_emb, axis=1)[:, np.newaxis]
- self.test_emb = test_emb.tolist()
+ self.test_emb = test_emb.tolist()
+ else:
+ self.test_emb = self.ca.dataset.test_data
 
  gt_df = self.ca.dataset.gt_data
 
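The reworked `_init_search_runner` above only stacks and normalizes the query vectors when `normalize` is set; the normalization itself is plain row-wise L2 scaling. A minimal standalone sketch of that one-liner, on synthetic vectors rather than the benchmark's test data:

import numpy as np

# Stand-ins for query embeddings: three 4-dimensional float32 vectors.
test_emb = np.random.rand(3, 4).astype(np.float32)

# Divide each row by its L2 norm, as in the hunk above; afterwards every
# vector has unit length, which IP/Cosine-style metrics expect.
test_emb = test_emb / np.linalg.norm(test_emb, axis=1)[:, np.newaxis]
assert np.allclose(np.linalg.norm(test_emb, axis=1), 1.0, atol=1e-5)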
@@ -279,6 +298,20 @@ class CaseRunner(BaseModel):
  k=self.config.case_config.k,
  )
 
+ def _init_read_write_runner(self):
+ ca: StreamingPerformanceCase = self.ca
+ self.read_write_runner = ReadWriteRunner(
+ db=self.db,
+ dataset=ca.dataset,
+ insert_rate=ca.insert_rate,
+ search_stages=ca.search_stages,
+ optimize_after_write=ca.optimize_after_write,
+ read_dur_after_write=ca.read_dur_after_write,
+ concurrencies=ca.concurrencies,
+ k=self.config.case_config.k,
+ normalize=self.normalize,
+ )
+
  def stop(self):
  if self.search_runner:
  self.search_runner.stop()
@@ -316,12 +349,7 @@ class TaskRunner(BaseModel):
  fmt.append(DATA_FORMAT % ("-" * 11, "-" * 12, "-" * 20, "-" * 7, "-" * 7))
 
  for f in self.case_runners:
- if f.ca.filter_rate != 0.0:
- filters = f.ca.filter_rate
- elif f.ca.filter_size != 0:
- filters = f.ca.filter_size
- else:
- filters = "None"
+ filters = f.ca.filters.filter_rate
 
  ds_str = f"{f.ca.dataset.data.name}-{f.ca.dataset.data.label}-{utils.numerize(f.ca.dataset.data.size)}"
  fmt.append(
vectordb_bench/cli/cli.py CHANGED
@@ -110,7 +110,7 @@ def click_parameter_decorators_from_typed_dict(
  return deco
 
 
- def click_arg_split(ctx: click.Context, param: click.core.Option, value): # noqa: ANN001, ARG001
+ def click_arg_split(ctx: click.Context, param: click.core.Option, value: any): # noqa: ARG001
  """Will split a comma-separated list input into an actual list.
 
  Args:
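The hunk only touches the signature; the function body is not shown in this diff. A hypothetical sketch of the comma-split behavior the docstring describes (the body below is an assumption, not the package's exact code):

def click_arg_split(ctx, param, value: any):  # noqa: ARG001
    # Assumed behavior per the docstring: "1,2,3" -> ["1", "2", "3"],
    # empty or missing input -> [].
    if not value:
        return []
    return [s.strip() for s in str(value).split(",") if s.strip()]

# e.g. click_arg_split(None, None, "a, b") -> ["a", "b"]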
@@ -455,6 +455,22 @@ class HNSWFlavor3(HNSWBaseRequiredTypedDict):
  ]
 
 
+ class HNSWFlavor4(HNSWBaseRequiredTypedDict):
+ ef_search: Annotated[
+ int | None,
+ click.option("--ef-search", type=int, help="hnsw ef-search", required=True),
+ ]
+ index_type: Annotated[
+ str | None,
+ click.option(
+ "--index-type",
+ type=click.Choice(["HNSW", "HNSW_SQ", "HNSW_BQ"], case_sensitive=False),
+ help="Type of index to use. Supported values: HNSW, HNSW_SQ, HNSW_BQ",
+ required=True,
+ ),
+ ]
+
+
  class IVFFlatTypedDict(TypedDict):
  lists: Annotated[int | None, click.option("--lists", type=int, help="ivfflat lists")]
  probes: Annotated[int | None, click.option("--probes", type=int, help="ivfflat probes")]
@@ -471,6 +487,48 @@ class IVFFlatTypedDictN(TypedDict):
  ]
 
 
+ class OceanBaseIVFTypedDict(TypedDict):
+ index_type: Annotated[
+ str | None,
+ click.option(
+ "--index-type",
+ type=click.Choice(["IVF_FLAT", "IVF_SQ8", "IVF_PQ"], case_sensitive=False),
+ help="Type of index to use. Supported values: IVF_FLAT, IVF_SQ8, IVF_PQ",
+ required=True,
+ ),
+ ]
+ nlist: Annotated[
+ int | None,
+ click.option("--nlist", "nlist", type=int, help="Number of cluster centers", required=True),
+ ]
+ sample_per_nlist: Annotated[
+ int | None,
+ click.option(
+ "--sample_per_nlist",
+ "sample_per_nlist",
+ type=int,
+ help="The cluster centers are calculated by total sampling sample_per_nlist * nlist vectors",
+ required=True,
+ ),
+ ]
+ ivf_nprobes: Annotated[
+ int | None,
+ click.option(
+ "--ivf_nprobes",
+ "ivf_nprobes",
+ type=str,
+ help="How many clustering centers to search during the query",
+ required=True,
+ ),
+ ]
+ m: Annotated[
+ int | None,
+ click.option(
+ "--m", "m", type=int, help="The number of sub-vectors that each data vector is divided into during IVF-PQ"
+ ),
+ ]
+
+
  @click.group()
  def cli(): ...
 
vectordb_bench/cli/vectordbbench.py CHANGED
@@ -5,6 +5,7 @@ from ..backend.clients.lancedb.cli import LanceDB
  from ..backend.clients.mariadb.cli import MariaDBHNSW
  from ..backend.clients.memorydb.cli import MemoryDB
  from ..backend.clients.milvus.cli import MilvusAutoIndex
+ from ..backend.clients.oceanbase.cli import OceanBaseHNSW, OceanBaseIVF
  from ..backend.clients.pgdiskann.cli import PgDiskAnn
  from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
  from ..backend.clients.pgvector.cli import PgVectorHNSW
@@ -33,6 +34,8 @@ cli.add_command(AWSOpenSearch)
  cli.add_command(PgVectorScaleDiskAnn)
  cli.add_command(PgDiskAnn)
  cli.add_command(AlloyDBScaNN)
+ cli.add_command(OceanBaseHNSW)
+ cli.add_command(OceanBaseIVF)
  cli.add_command(MariaDBHNSW)
  cli.add_command(TiDB)
  cli.add_command(Clickhouse)
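With the OceanBase commands registered on the `cli` group, runs are driven through the `vectordbbench` entry point. A hypothetical invocation via click's test runner, assuming the command name follows the package's lowercase-class convention (`oceanbaseivf` is a guess) and showing only the IVF options from `OceanBaseIVFTypedDict` above; a real run also needs the command's connection and case options:

from click.testing import CliRunner
from vectordb_bench.cli.vectordbbench import cli

runner = CliRunner()
result = runner.invoke(cli, [
    "oceanbaseivf",               # assumed command name
    "--index-type", "IVF_PQ",
    "--nlist", "1024",
    "--sample_per_nlist", "256",
    "--ivf_nprobes", "64",
    "--m", "16",                  # sub-vector count, IVF_PQ only
])
# Exits with a usage error unless the remaining required options are supplied.
print(result.output)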
vectordb_bench/frontend/components/check_results/data.py CHANGED
@@ -1,6 +1,6 @@
  from collections import defaultdict
  from dataclasses import asdict
- from vectordb_bench.metric import isLowerIsBetterMetric
+ from vectordb_bench.metric import QPS_METRIC, isLowerIsBetterMetric
  from vectordb_bench.models import CaseResult, ResultLabel
 
 
@@ -22,8 +22,7 @@ def getFilterTasks(
  filterTasks = [
  task
  for task in tasks
- if task.task_config.db_name in dbNames
- and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
+ if task.task_config.db_name in dbNames and task.task_config.case_config.case_name in caseNames
  ]
  return filterTasks
 
@@ -35,17 +34,22 @@ def mergeTasks(tasks: list[CaseResult]):
  db = task.task_config.db.value
  db_label = task.task_config.db_config.db_label or ""
  version = task.task_config.db_config.version or ""
- case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
+ case = task.task_config.case_config.case
+ case_name = case.name
+ dataset_name = case.dataset.data.full_name
+ filter_rate = case.filter_rate
  dbCaseMetricsMap[db_name][case.name] = {
  "db": db,
  "db_label": db_label,
  "version": version,
+ "dataset_name": dataset_name,
+ "filter_rate": filter_rate,
  "metrics": mergeMetrics(
- dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
+ dbCaseMetricsMap[db_name][case_name].get("metrics", {}),
  asdict(task.metrics),
  ),
  "label": getBetterLabel(
- dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
+ dbCaseMetricsMap[db_name][case_name].get("label", ResultLabel.FAILED),
  task.label,
  ),
  }
@@ -59,12 +63,16 @@ def mergeTasks(tasks: list[CaseResult]):
  db_label = metricInfo["db_label"]
  version = metricInfo["version"]
  label = metricInfo["label"]
+ dataset_name = metricInfo["dataset_name"]
+ filter_rate = metricInfo["filter_rate"]
  if label == ResultLabel.NORMAL:
  mergedTasks.append(
  {
  "db_name": db_name,
  "db": db,
  "db_label": db_label,
+ "dataset_name": dataset_name,
+ "filter_rate": filter_rate,
  "version": version,
  "case_name": case_name,
  "metricsSet": set(metrics.keys()),
@@ -77,12 +85,9 @@ def mergeTasks(tasks: list[CaseResult]):
  return mergedTasks, failedTasks
 
 
+ # for same db-label, we use the results with the highest qps
  def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
- metrics = {**metrics_1}
- for key, value in metrics_2.items():
- metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value
-
- return metrics
+ return metrics_1 if metrics_1.get(QPS_METRIC, 0) > metrics_2.get(QPS_METRIC, 0) else metrics_2
 
 
  def getBetterMetric(metric, value_1, value_2):
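The new `mergeMetrics` is winner-takes-all: whichever duplicate run has the higher QPS keeps all of its metrics, replacing the old per-metric `getBetterMetric` merge. A small sketch of the semantics (treating `QPS_METRIC` as the string key "qps", which is an assumption about the constant imported from `vectordb_bench.metric`):

QPS_METRIC = "qps"  # assumed value of the imported constant

def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
    # Keep whichever run achieved the higher QPS, wholesale.
    return metrics_1 if metrics_1.get(QPS_METRIC, 0) > metrics_2.get(QPS_METRIC, 0) else metrics_2

run_a = {"qps": 1200.0, "recall": 0.97}
run_b = {"qps": 1500.0, "recall": 0.95}
# run_b wins outright; its lower recall comes along with it rather than
# being cherry-picked per metric as before.
assert mergeMetrics(run_a, run_b) is run_b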
vectordb_bench/frontend/components/check_results/filters.py CHANGED
@@ -1,14 +1,19 @@
  from vectordb_bench.backend.cases import Case
+ from vectordb_bench.backend.dataset import DatasetWithSizeType
+ from vectordb_bench.backend.filter import FilterOp
  from vectordb_bench.frontend.components.check_results.data import getChartData
- from vectordb_bench.frontend.components.check_results.expanderStyle import initSidebarExanderStyle
+ from vectordb_bench.frontend.components.check_results.expanderStyle import (
+ initSidebarExanderStyle,
+ )
  from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
- from vectordb_bench.frontend.config.styles import *
+ from vectordb_bench.frontend.config.styles import SIDEBAR_CONTROL_COLUMNS
  import streamlit as st
+ from typing import Callable
 
  from vectordb_bench.models import CaseResult, TestResult
 
 
- def getshownData(results: list[TestResult], st):
+ def getshownData(st, results: list[TestResult], filter_type: FilterOp = FilterOp.NonFilter, **kwargs):
  # hide the nav
  st.markdown(
  "<style> div[data-testid='stSidebarNav'] {display: none;} </style>",
@@ -17,15 +22,20 @@ def getshownData(results: list[TestResult], st):
 
  st.header("Filters")
 
- shownResults = getshownResults(results, st)
- showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
+ shownResults = getshownResults(st, results, **kwargs)
+ showDBNames, showCaseNames = getShowDbsAndCases(st, shownResults, filter_type)
 
  shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames)
 
  return shownData, failedTasks, showCaseNames
 
 
- def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
+ def getshownResults(
+ st,
+ results: list[TestResult],
+ case_results_filter: Callable[[CaseResult], bool] = lambda x: True,
+ **kwargs,
+ ) -> list[CaseResult]:
  resultSelectOptions = [
  result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results
  ]
@@ -41,23 +51,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
  )
  selectedResult: list[CaseResult] = []
  for option in selectedResultSelectedOptions:
- result = results[resultSelectOptions.index(option)].results
- selectedResult += result
+ case_results = results[resultSelectOptions.index(option)].results
+ selectedResult += [r for r in case_results if case_results_filter(r)]
 
  return selectedResult
 
 
- def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
+ def getShowDbsAndCases(st, result: list[CaseResult], filter_type: FilterOp) -> tuple[list[str], list[str]]:
  initSidebarExanderStyle(st)
- allDbNames = list(set({res.task_config.db_name for res in result}))
+ case_results = [res for res in result if res.task_config.case_config.case.filters.type == filter_type]
+ allDbNames = list(set({res.task_config.db_name for res in case_results}))
  allDbNames.sort()
- allCases: list[Case] = [
- res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result
- ]
- allCaseNameSet = set({case.name for case in allCases})
- allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
- case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
- ]
+ allCases: list[Case] = [res.task_config.case_config.case for res in case_results]
 
  # DB Filter
  dbFilterContainer = st.container()
@@ -67,15 +72,38 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]:
  allDbNames,
  col=1,
  )
+ showCaseNames = []
+
+ if filter_type == FilterOp.NonFilter:
+ allCaseNameSet = set({case.name for case in allCases})
+ allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
+ case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
+ ]
+
+ # Case Filter
+ caseFilterContainer = st.container()
+ showCaseNames = filterView(
+ caseFilterContainer,
+ "Case Filter",
+ [caseName for caseName in allCaseNames],
+ col=1,
+ )
 
- # Case Filter
- caseFilterContainer = st.container()
- showCaseNames = filterView(
- caseFilterContainer,
- "Case Filter",
- [caseName for caseName in allCaseNames],
- col=1,
- )
+ if filter_type == FilterOp.StrEqual:
+ container = st.container()
+ datasetWithSizeTypes = [dataset_with_size_type for dataset_with_size_type in DatasetWithSizeType]
+ showDatasetWithSizeTypes = filterView(
+ container,
+ "Case Filter",
+ datasetWithSizeTypes,
+ col=1,
+ optionLables=[v.value for v in datasetWithSizeTypes],
+ )
+ datasets = [dataset_with_size_type.get_manager() for dataset_with_size_type in showDatasetWithSizeTypes]
+ showCaseNames = list(set([case.name for case in allCases if case.dataset in datasets]))
+
+ if filter_type == FilterOp.NumGE:
+ raise NotImplementedError
 
  return showDBNames, showCaseNames
 
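The three branches above key off the new `FilterOp` enum from `vectordb_bench/backend/filter.py` (a 76-line file added in this release but not included in this diff). A minimal sketch of the enum, inferred only from the member names used here:

from enum import Enum, auto

class FilterOp(Enum):
    # Member names taken from the usages above; the real filter.py also
    # carries per-case data such as filter_rate and is not shown here.
    NonFilter = auto()  # ordinary ANN search, no predicate
    StrEqual = auto()   # scalar-label filter, grouped by dataset size above
    NumGE = auto()      # numeric >= filter; this page raises NotImplementedError for it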
vectordb_bench/frontend/components/check_results/headerIcon.py CHANGED
@@ -4,19 +4,22 @@ from vectordb_bench.frontend.config.styles import HEADER_ICON
  def drawHeaderIcon(st):
  st.markdown(
  f"""
- <div class="headerIconContainer"></div>
+ <a href="/vdb_benchmark" target="_self">
+ <div class="headerIconContainer"></div>
+ </a>
 
- <style>
- .headerIconContainer {{
- position: absolute;
- top: -50px;
- height: 50px;
- width: 100%;
- border-bottom: 2px solid #E8EAEE;
- background-image: url({HEADER_ICON});
- background-repeat: no-repeat;
- }}
- </style
- """,
+ <style>
+ .headerIconContainer {{
+ position: relative;
+ top: 0px;
+ height: 50px;
+ width: 100%;
+ border-bottom: 2px solid #E8EAEE;
+ background-image: url({HEADER_ICON});
+ background-repeat: no-repeat;
+ cursor: pointer;
+ }}
+ </style>
+ """,
  unsafe_allow_html=True,
  )
vectordb_bench/frontend/components/check_results/nav.py CHANGED
@@ -20,3 +20,23 @@ def NavToResults(st, key="nav-to-results"):
  navClick = st.button("< &nbsp;&nbsp;Back to Results", key=key)
  if navClick:
  switch_page("vdb benchmark")
+
+
+ def NavToPages(st):
+ options = [
+ {"name": "Run Test", "link": "run_test"},
+ {"name": "Results", "link": "results"},
+ {"name": "Quries Per Dollar", "link": "quries_per_dollar"},
+ {"name": "Concurrent", "link": "concurrent"},
+ {"name": "Label Filter", "link": "label_filter"},
+ {"name": "Streaming", "link": "streaming"},
+ {"name": "Tables", "link": "tables"},
+ {"name": "Custom Dataset", "link": "custom"},
+ ]
+
+ html = ""
+ for i, option in enumerate(options):
+ html += f'<a href="/{option["link"]}" target="_self" style="text-decoration: none; padding: 0.1px 0.2px;">{option["name"]}</a>'
+ if i < len(options) - 1:
+ html += '<span style="color: #888; margin: 0 5px;">|</span>'
+ st.markdown(html, unsafe_allow_html=True)
vectordb_bench/frontend/components/custom/displayCustomCase.py CHANGED
@@ -12,7 +12,7 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
  "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
  )
 
- columns = st.columns(4)
+ columns = st.columns(3)
  customCase.dataset_config.dim = columns[0].number_input(
  "dim", key=f"{key}_dim", value=customCase.dataset_config.dim
  )
@@ -22,16 +22,51 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
  customCase.dataset_config.metric_type = columns[2].selectbox(
  "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
  )
- customCase.dataset_config.file_count = columns[3].number_input(
- "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count
+
+ columns = st.columns(3)
+ customCase.dataset_config.train_name = columns[0].text_input(
+ "train file name",
+ key=f"{key}_train_name",
+ value=customCase.dataset_config.train_name,
+ )
+ customCase.dataset_config.test_name = columns[1].text_input(
+ "test file name", key=f"{key}_test_name", value=customCase.dataset_config.test_name
+ )
+ customCase.dataset_config.gt_name = columns[2].text_input(
+ "ground truth file name", key=f"{key}_gt_name", value=customCase.dataset_config.gt_name
+ )
+
+ columns = st.columns([1, 1, 2, 2])
+ customCase.dataset_config.train_id_name = columns[0].text_input(
+ "train id name", key=f"{key}_train_id_name", value=customCase.dataset_config.train_id_name
+ )
+ customCase.dataset_config.train_col_name = columns[1].text_input(
+ "train emb name", key=f"{key}_train_col_name", value=customCase.dataset_config.train_col_name
+ )
+ customCase.dataset_config.test_col_name = columns[2].text_input(
+ "test emb name", key=f"{key}_test_col_name", value=customCase.dataset_config.test_col_name
+ )
+ customCase.dataset_config.gt_col_name = columns[3].text_input(
+ "ground truth emb name", key=f"{key}_gt_col_name", value=customCase.dataset_config.gt_col_name
  )
 
- columns = st.columns(4)
- customCase.dataset_config.use_shuffled = columns[0].checkbox(
- "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled
+ columns = st.columns(2)
+ customCase.dataset_config.scalar_labels_name = columns[0].text_input(
+ "scalar labels file name",
+ key=f"{key}_scalar_labels_file_name",
+ value=customCase.dataset_config.scalar_labels_name,
  )
- customCase.dataset_config.with_gt = columns[1].checkbox(
- "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt
+ default_label_percentages = ",".join(map(str, customCase.dataset_config.with_label_percentages))
+ label_percentage_input = columns[1].text_input(
+ "label percentages",
+ key=f"{key}_label_percantages",
+ value=default_label_percentages,
  )
+ try:
+ customCase.dataset_config.label_percentages = [
+ float(item.strip()) for item in label_percentage_input.split(",") if item.strip()
+ ]
+ except ValueError as e:
+ st.write(f"<span style='color:red'>{e},please input correct number</span>", unsafe_allow_html=True)
 
  customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
vectordb_bench/frontend/components/custom/displaypPrams.py CHANGED
@@ -2,13 +2,18 @@ def displayParams(st):
  st.markdown(
  """
  - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
- - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
- - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
- - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.
+ - Vectors data files: The file should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The name of two columns could be defined on your own.
+ - Query test vectors: The file could be named on your own and should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The `id` column must be named as `id`, and `emb` column could be defined on your own.
+ - Ground truth file: The file could be named on your own and should have two kinds of columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. The `id` column must be named as `id`, and `neighbors_id` column could be defined on your own.
 
- - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
+ - `Train File Name` - If the number of train file is `more than one`, please input all your train file name and `split with ','` without the `.parquet` file extensionthe. For example, if there are two train file and the name of them are `train1.parquet` and `train2.parquet`, then input `train1,train2`.
+
+ - `Ground Truth Emb Name` - No matter whether filter file is applied or not, the `neighbors_id` column in ground truth file must have the same name.
+
+ - `Scalar Labels File Name ` - If there is a scalar labels file, please input the filename without the .parquet extension. The file should have two columns: `id` as an incrementing `int` and `labels` as an array of `string`. The `id` column must correspond one-to-one with the `id` column in train file..
+
+ - `Label percentages` - If you have filter file, please input label percentage you want to real run and `split with ','` when it's `more than one`. If you `don't have` filter file, than `keep the text vacant.`
 
- - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
  """
  )
  st.caption(
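For readers preparing a custom dataset, a minimal sketch that writes files matching the layout described above, using the default file and column names from `CustomDatasetConfig` in the next file; sizes are arbitrary, pyarrow must be installed, and the random ground truth is illustration only (real runs need exact nearest neighbors):

import numpy as np
import pandas as pd

dim, n_train, n_test, k = 8, 1000, 10, 5
rng = np.random.default_rng(0)

# train.parquet: incrementing int "id" plus float32 vectors in "emb".
pd.DataFrame({
    "id": np.arange(n_train),
    "emb": list(rng.random((n_train, dim), dtype=np.float32)),
}).to_parquet("train.parquet")

# test.parquet: the query vectors; the id column must be named "id".
pd.DataFrame({
    "id": np.arange(n_test),
    "emb": list(rng.random((n_test, dim), dtype=np.float32)),
}).to_parquet("test.parquet")

# neighbors.parquet: per query id, an int array of ground-truth neighbor ids.
pd.DataFrame({
    "id": np.arange(n_test),
    "neighbors_id": [rng.choice(n_train, k, replace=False) for _ in range(n_test)],
}).to_parquet("neighbors.parquet")

# scalar_labels.parquet (optional): one string-array label set per train id,
# used by label-filter cases.
pd.DataFrame({
    "id": np.arange(n_train),
    "labels": [["label_1"] for _ in range(n_train)],
}).to_parquet("scalar_labels.parquet")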
vectordb_bench/frontend/components/custom/getCustomConfig.py CHANGED
@@ -14,6 +14,16 @@ class CustomDatasetConfig(BaseModel):
  file_count: int = 1
  use_shuffled: bool = False
  with_gt: bool = True
+ train_name: str = "train"
+ test_name: str = "test"
+ gt_name: str = "neighbors"
+ train_id_name: str = "id"
+ train_col_name: str = "emb"
+ test_col_name: str = "emb"
+ gt_col_name: str = "neighbors_id"
+ scalar_labels_name: str = "scalar_labels"
+ label_percentages: list[str] = []
+ with_label_percentages: list[float] = [0.001, 0.02, 0.5]
 
 
  class CustomCaseConfig(BaseModel):
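A short sketch of what the new defaults imply, assuming the fields declared before line 14 of this model (not visible in the hunk) also have defaults:

from vectordb_bench.frontend.components.custom.getCustomConfig import CustomDatasetConfig

cfg = CustomDatasetConfig()
# With the defaults above, VectorDBBench resolves the folder contents as:
#   train.parquet          train_name + ".parquet", ids in "id", vectors in "emb"
#   test.parquet           test_name + ".parquet", query vectors in "emb"
#   neighbors.parquet      gt_name + ".parquet", ground truth in "neighbors_id"
#   scalar_labels.parquet  scalar_labels_name + ".parquet", for label-filter cases
# Note label_percentages is annotated list[str], but the UI above stores parsed
# floats into it; with_label_percentages defaults to [0.001, 0.02, 0.5].
print(cfg.gt_col_name)  # "neighbors_id"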