vectordb-bench 0.0.30__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -27
- vectordb_bench/backend/assembler.py +19 -6
- vectordb_bench/backend/cases.py +186 -23
- vectordb_bench/backend/clients/__init__.py +16 -0
- vectordb_bench/backend/clients/api.py +22 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +82 -41
- vectordb_bench/backend/clients/aws_opensearch/config.py +23 -4
- vectordb_bench/backend/clients/chroma/chroma.py +6 -2
- vectordb_bench/backend/clients/elastic_cloud/config.py +19 -1
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +133 -45
- vectordb_bench/backend/clients/milvus/config.py +1 -0
- vectordb_bench/backend/clients/milvus/milvus.py +74 -22
- vectordb_bench/backend/clients/oceanbase/cli.py +100 -0
- vectordb_bench/backend/clients/oceanbase/config.py +125 -0
- vectordb_bench/backend/clients/oceanbase/oceanbase.py +215 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +39 -25
- vectordb_bench/backend/clients/qdrant_cloud/config.py +59 -3
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +100 -33
- vectordb_bench/backend/dataset.py +143 -27
- vectordb_bench/backend/filter.py +76 -0
- vectordb_bench/backend/runner/__init__.py +3 -3
- vectordb_bench/backend/runner/mp_runner.py +52 -39
- vectordb_bench/backend/runner/rate_runner.py +68 -52
- vectordb_bench/backend/runner/read_write_runner.py +125 -68
- vectordb_bench/backend/runner/serial_runner.py +56 -23
- vectordb_bench/backend/task_runner.py +48 -20
- vectordb_bench/cli/cli.py +59 -1
- vectordb_bench/cli/vectordbbench.py +3 -0
- vectordb_bench/frontend/components/check_results/data.py +16 -11
- vectordb_bench/frontend/components/check_results/filters.py +53 -25
- vectordb_bench/frontend/components/check_results/headerIcon.py +16 -13
- vectordb_bench/frontend/components/check_results/nav.py +20 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +43 -8
- vectordb_bench/frontend/components/custom/displaypPrams.py +10 -5
- vectordb_bench/frontend/components/custom/getCustomConfig.py +10 -0
- vectordb_bench/frontend/components/label_filter/charts.py +60 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +48 -52
- vectordb_bench/frontend/components/run_test/dbSelector.py +9 -5
- vectordb_bench/frontend/components/run_test/inputWidget.py +48 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +3 -1
- vectordb_bench/frontend/components/streaming/charts.py +253 -0
- vectordb_bench/frontend/components/streaming/data.py +62 -0
- vectordb_bench/frontend/components/tables/data.py +1 -1
- vectordb_bench/frontend/components/welcome/explainPrams.py +66 -0
- vectordb_bench/frontend/components/welcome/pagestyle.py +106 -0
- vectordb_bench/frontend/components/welcome/welcomePrams.py +147 -0
- vectordb_bench/frontend/config/dbCaseConfigs.py +307 -40
- vectordb_bench/frontend/config/styles.py +32 -2
- vectordb_bench/frontend/pages/concurrent.py +5 -1
- vectordb_bench/frontend/pages/custom.py +4 -0
- vectordb_bench/frontend/pages/label_filter.py +56 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -1
- vectordb_bench/frontend/pages/results.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +3 -3
- vectordb_bench/frontend/pages/streaming.py +135 -0
- vectordb_bench/frontend/pages/tables.py +4 -0
- vectordb_bench/frontend/vdb_benchmark.py +16 -41
- vectordb_bench/interface.py +6 -2
- vectordb_bench/metric.py +15 -1
- vectordb_bench/models.py +31 -11
- vectordb_bench/results/ElasticCloud/result_20250318_standard_elasticcloud.json +5890 -0
- vectordb_bench/results/Milvus/result_20250509_standard_milvus.json +6138 -0
- vectordb_bench/results/OpenSearch/result_20250224_standard_opensearch.json +7319 -0
- vectordb_bench/results/Pinecone/result_20250124_standard_pinecone.json +2365 -0
- vectordb_bench/results/QdrantCloud/result_20250602_standard_qdrantcloud.json +3556 -0
- vectordb_bench/results/ZillizCloud/result_20250613_standard_zillizcloud.json +6290 -0
- vectordb_bench/results/dbPrices.json +12 -4
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/METADATA +85 -32
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/RECORD +73 -56
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +0 -791
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +0 -679
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +0 -1352
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {vectordb_bench-0.0.30.dist-info → vectordb_bench-1.0.0.dist-info}/top_level.txt +0 -0
@@ -6,15 +6,14 @@ from enum import Enum, auto
|
|
6
6
|
import numpy as np
|
7
7
|
import psutil
|
8
8
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
|
9
|
+
from ..base import BaseModel
|
10
|
+
from ..metric import Metric
|
11
|
+
from ..models import PerformanceTimeoutError, TaskConfig, TaskStage
|
13
12
|
from . import utils
|
14
|
-
from .cases import Case, CaseLabel
|
13
|
+
from .cases import Case, CaseLabel, StreamingPerformanceCase
|
15
14
|
from .clients import MetricType, api
|
16
15
|
from .data_source import DatasetSource
|
17
|
-
from .runner import MultiProcessingSearchRunner, SerialInsertRunner, SerialSearchRunner
|
16
|
+
from .runner import MultiProcessingSearchRunner, ReadWriteRunner, SerialInsertRunner, SerialSearchRunner
|
18
17
|
|
19
18
|
log = logging.getLogger(__name__)
|
20
19
|
|
@@ -48,6 +47,7 @@ class CaseRunner(BaseModel):
|
|
48
47
|
serial_search_runner: SerialSearchRunner | None = None
|
49
48
|
search_runner: MultiProcessingSearchRunner | None = None
|
50
49
|
final_search_runner: MultiProcessingSearchRunner | None = None
|
50
|
+
read_write_runner: ReadWriteRunner | None = None
|
51
51
|
|
52
52
|
def __eq__(self, obj: any):
|
53
53
|
if isinstance(obj, CaseRunner):
|
@@ -63,6 +63,7 @@ class CaseRunner(BaseModel):
|
|
63
63
|
c_dict = self.ca.dict(
|
64
64
|
include={
|
65
65
|
"label": True,
|
66
|
+
"name": True,
|
66
67
|
"filters": True,
|
67
68
|
"dataset": {
|
68
69
|
"data": {
|
@@ -91,12 +92,13 @@ class CaseRunner(BaseModel):
|
|
91
92
|
db_config=self.config.db_config.to_dict(),
|
92
93
|
db_case_config=self.config.db_case_config,
|
93
94
|
drop_old=drop_old,
|
95
|
+
with_scalar_labels=self.ca.with_scalar_labels,
|
94
96
|
)
|
95
97
|
|
96
98
|
def _pre_run(self, drop_old: bool = True):
|
97
99
|
try:
|
98
100
|
self.init_db(drop_old)
|
99
|
-
self.ca.dataset.prepare(self.dataset_source, filters=self.ca.
|
101
|
+
self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filters)
|
100
102
|
except ModuleNotFoundError as e:
|
101
103
|
log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}")
|
102
104
|
raise e from None
|
@@ -110,6 +112,8 @@ class CaseRunner(BaseModel):
|
|
110
112
|
return self._run_capacity_case()
|
111
113
|
if self.ca.label == CaseLabel.Performance:
|
112
114
|
return self._run_perf_case(drop_old)
|
115
|
+
if self.ca.label == CaseLabel.Streaming:
|
116
|
+
return self._run_streaming_case()
|
113
117
|
msg = f"unknown case type: {self.ca.label}"
|
114
118
|
log.warning(msg)
|
115
119
|
raise ValueError(msg)
|
@@ -127,6 +131,7 @@ class CaseRunner(BaseModel):
|
|
127
131
|
self.db,
|
128
132
|
self.ca.dataset,
|
129
133
|
self.normalize,
|
134
|
+
self.ca.filters,
|
130
135
|
self.ca.load_timeout,
|
131
136
|
)
|
132
137
|
count = runner.run_endlessness()
|
@@ -151,6 +156,8 @@ class CaseRunner(BaseModel):
|
|
151
156
|
if TaskStage.LOAD in self.config.stages:
|
152
157
|
_, load_dur = self._load_train_data()
|
153
158
|
build_dur = self._optimize()
|
159
|
+
m.insert_duration = round(load_dur, 4)
|
160
|
+
m.optimize_duration = round(build_dur, 4)
|
154
161
|
m.load_duration = round(load_dur + build_dur, 4)
|
155
162
|
log.info(
|
156
163
|
f"Finish loading the entire dataset into VectorDB,"
|
@@ -172,10 +179,6 @@ class CaseRunner(BaseModel):
|
|
172
179
|
) = search_results
|
173
180
|
if TaskStage.SEARCH_SERIAL in self.config.stages:
|
174
181
|
search_results = self._serial_search()
|
175
|
-
"""
|
176
|
-
m.recall = search_results.recall
|
177
|
-
m.serial_latencies = search_results.serial_latencies
|
178
|
-
"""
|
179
182
|
m.recall, m.ndcg, m.serial_latency_p99 = search_results
|
180
183
|
|
181
184
|
except Exception as e:
|
@@ -186,6 +189,19 @@ class CaseRunner(BaseModel):
|
|
186
189
|
log.info(f"Performance case got result: {m}")
|
187
190
|
return m
|
188
191
|
|
192
|
+
def _run_streaming_case(self) -> Metric:
|
193
|
+
log.info("Start streaming case")
|
194
|
+
try:
|
195
|
+
self._init_read_write_runner()
|
196
|
+
m = self.read_write_runner.run_read_write()
|
197
|
+
except Exception as e:
|
198
|
+
log.warning(f"Failed to run streaming case, reason = {e}")
|
199
|
+
traceback.print_exc()
|
200
|
+
raise e from None
|
201
|
+
else:
|
202
|
+
log.info(f"Streaming case got result: {m}")
|
203
|
+
return m
|
204
|
+
|
189
205
|
@utils.time_it
|
190
206
|
def _load_train_data(self):
|
191
207
|
"""Insert train data and get the insert_duration"""
|
@@ -194,6 +210,7 @@ class CaseRunner(BaseModel):
|
|
194
210
|
self.db,
|
195
211
|
self.ca.dataset,
|
196
212
|
self.normalize,
|
213
|
+
self.ca.filters,
|
197
214
|
self.ca.load_timeout,
|
198
215
|
)
|
199
216
|
runner.run()
|
@@ -207,7 +224,7 @@ class CaseRunner(BaseModel):
|
|
207
224
|
calculate the recall, serial_latency_p99
|
208
225
|
|
209
226
|
Returns:
|
210
|
-
tuple[float, float]: recall, serial_latency_p99
|
227
|
+
tuple[float, float, float]: recall, ndcg, serial_latency_p99
|
211
228
|
"""
|
212
229
|
try:
|
213
230
|
results, _ = self.serial_search_runner.run()
|
@@ -253,10 +270,12 @@ class CaseRunner(BaseModel):
|
|
253
270
|
raise e from None
|
254
271
|
|
255
272
|
def _init_search_runner(self):
|
256
|
-
test_emb = np.stack(self.ca.dataset.test_data["emb"])
|
257
273
|
if self.normalize:
|
274
|
+
test_emb = np.stack(self.ca.dataset.test_data)
|
258
275
|
test_emb = test_emb / np.linalg.norm(test_emb, axis=1)[:, np.newaxis]
|
259
|
-
|
276
|
+
self.test_emb = test_emb.tolist()
|
277
|
+
else:
|
278
|
+
self.test_emb = self.ca.dataset.test_data
|
260
279
|
|
261
280
|
gt_df = self.ca.dataset.gt_data
|
262
281
|
|
@@ -279,6 +298,20 @@ class CaseRunner(BaseModel):
|
|
279
298
|
k=self.config.case_config.k,
|
280
299
|
)
|
281
300
|
|
301
|
+
def _init_read_write_runner(self):
|
302
|
+
ca: StreamingPerformanceCase = self.ca
|
303
|
+
self.read_write_runner = ReadWriteRunner(
|
304
|
+
db=self.db,
|
305
|
+
dataset=ca.dataset,
|
306
|
+
insert_rate=ca.insert_rate,
|
307
|
+
search_stages=ca.search_stages,
|
308
|
+
optimize_after_write=ca.optimize_after_write,
|
309
|
+
read_dur_after_write=ca.read_dur_after_write,
|
310
|
+
concurrencies=ca.concurrencies,
|
311
|
+
k=self.config.case_config.k,
|
312
|
+
normalize=self.normalize,
|
313
|
+
)
|
314
|
+
|
282
315
|
def stop(self):
|
283
316
|
if self.search_runner:
|
284
317
|
self.search_runner.stop()
|
@@ -316,12 +349,7 @@ class TaskRunner(BaseModel):
|
|
316
349
|
fmt.append(DATA_FORMAT % ("-" * 11, "-" * 12, "-" * 20, "-" * 7, "-" * 7))
|
317
350
|
|
318
351
|
for f in self.case_runners:
|
319
|
-
|
320
|
-
filters = f.ca.filter_rate
|
321
|
-
elif f.ca.filter_size != 0:
|
322
|
-
filters = f.ca.filter_size
|
323
|
-
else:
|
324
|
-
filters = "None"
|
352
|
+
filters = f.ca.filters.filter_rate
|
325
353
|
|
326
354
|
ds_str = f"{f.ca.dataset.data.name}-{f.ca.dataset.data.label}-{utils.numerize(f.ca.dataset.data.size)}"
|
327
355
|
fmt.append(
|
vectordb_bench/cli/cli.py
CHANGED
@@ -110,7 +110,7 @@ def click_parameter_decorators_from_typed_dict(
|
|
110
110
|
return deco
|
111
111
|
|
112
112
|
|
113
|
-
def click_arg_split(ctx: click.Context, param: click.core.Option, value): # noqa:
|
113
|
+
def click_arg_split(ctx: click.Context, param: click.core.Option, value: any): # noqa: ARG001
|
114
114
|
"""Will split a comma-separated list input into an actual list.
|
115
115
|
|
116
116
|
Args:
|
@@ -455,6 +455,22 @@ class HNSWFlavor3(HNSWBaseRequiredTypedDict):
|
|
455
455
|
]
|
456
456
|
|
457
457
|
|
458
|
+
class HNSWFlavor4(HNSWBaseRequiredTypedDict):
|
459
|
+
ef_search: Annotated[
|
460
|
+
int | None,
|
461
|
+
click.option("--ef-search", type=int, help="hnsw ef-search", required=True),
|
462
|
+
]
|
463
|
+
index_type: Annotated[
|
464
|
+
str | None,
|
465
|
+
click.option(
|
466
|
+
"--index-type",
|
467
|
+
type=click.Choice(["HNSW", "HNSW_SQ", "HNSW_BQ"], case_sensitive=False),
|
468
|
+
help="Type of index to use. Supported values: HNSW, HNSW_SQ, HNSW_BQ",
|
469
|
+
required=True,
|
470
|
+
),
|
471
|
+
]
|
472
|
+
|
473
|
+
|
458
474
|
class IVFFlatTypedDict(TypedDict):
|
459
475
|
lists: Annotated[int | None, click.option("--lists", type=int, help="ivfflat lists")]
|
460
476
|
probes: Annotated[int | None, click.option("--probes", type=int, help="ivfflat probes")]
|
@@ -471,6 +487,48 @@ class IVFFlatTypedDictN(TypedDict):
|
|
471
487
|
]
|
472
488
|
|
473
489
|
|
490
|
+
class OceanBaseIVFTypedDict(TypedDict):
|
491
|
+
index_type: Annotated[
|
492
|
+
str | None,
|
493
|
+
click.option(
|
494
|
+
"--index-type",
|
495
|
+
type=click.Choice(["IVF_FLAT", "IVF_SQ8", "IVF_PQ"], case_sensitive=False),
|
496
|
+
help="Type of index to use. Supported values: IVF_FLAT, IVF_SQ8, IVF_PQ",
|
497
|
+
required=True,
|
498
|
+
),
|
499
|
+
]
|
500
|
+
nlist: Annotated[
|
501
|
+
int | None,
|
502
|
+
click.option("--nlist", "nlist", type=int, help="Number of cluster centers", required=True),
|
503
|
+
]
|
504
|
+
sample_per_nlist: Annotated[
|
505
|
+
int | None,
|
506
|
+
click.option(
|
507
|
+
"--sample_per_nlist",
|
508
|
+
"sample_per_nlist",
|
509
|
+
type=int,
|
510
|
+
help="The cluster centers are calculated by total sampling sample_per_nlist * nlist vectors",
|
511
|
+
required=True,
|
512
|
+
),
|
513
|
+
]
|
514
|
+
ivf_nprobes: Annotated[
|
515
|
+
int | None,
|
516
|
+
click.option(
|
517
|
+
"--ivf_nprobes",
|
518
|
+
"ivf_nprobes",
|
519
|
+
type=str,
|
520
|
+
help="How many clustering centers to search during the query",
|
521
|
+
required=True,
|
522
|
+
),
|
523
|
+
]
|
524
|
+
m: Annotated[
|
525
|
+
int | None,
|
526
|
+
click.option(
|
527
|
+
"--m", "m", type=int, help="The number of sub-vectors that each data vector is divided into during IVF-PQ"
|
528
|
+
),
|
529
|
+
]
|
530
|
+
|
531
|
+
|
474
532
|
@click.group()
|
475
533
|
def cli(): ...
|
476
534
|
|
@@ -5,6 +5,7 @@ from ..backend.clients.lancedb.cli import LanceDB
|
|
5
5
|
from ..backend.clients.mariadb.cli import MariaDBHNSW
|
6
6
|
from ..backend.clients.memorydb.cli import MemoryDB
|
7
7
|
from ..backend.clients.milvus.cli import MilvusAutoIndex
|
8
|
+
from ..backend.clients.oceanbase.cli import OceanBaseHNSW, OceanBaseIVF
|
8
9
|
from ..backend.clients.pgdiskann.cli import PgDiskAnn
|
9
10
|
from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
|
10
11
|
from ..backend.clients.pgvector.cli import PgVectorHNSW
|
@@ -33,6 +34,8 @@ cli.add_command(AWSOpenSearch)
|
|
33
34
|
cli.add_command(PgVectorScaleDiskAnn)
|
34
35
|
cli.add_command(PgDiskAnn)
|
35
36
|
cli.add_command(AlloyDBScaNN)
|
37
|
+
cli.add_command(OceanBaseHNSW)
|
38
|
+
cli.add_command(OceanBaseIVF)
|
36
39
|
cli.add_command(MariaDBHNSW)
|
37
40
|
cli.add_command(TiDB)
|
38
41
|
cli.add_command(Clickhouse)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from collections import defaultdict
|
2
2
|
from dataclasses import asdict
|
3
|
-
from vectordb_bench.metric import isLowerIsBetterMetric
|
3
|
+
from vectordb_bench.metric import QPS_METRIC, isLowerIsBetterMetric
|
4
4
|
from vectordb_bench.models import CaseResult, ResultLabel
|
5
5
|
|
6
6
|
|
@@ -22,8 +22,7 @@ def getFilterTasks(
|
|
22
22
|
filterTasks = [
|
23
23
|
task
|
24
24
|
for task in tasks
|
25
|
-
if task.task_config.db_name in dbNames
|
26
|
-
and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
|
25
|
+
if task.task_config.db_name in dbNames and task.task_config.case_config.case_name in caseNames
|
27
26
|
]
|
28
27
|
return filterTasks
|
29
28
|
|
@@ -35,17 +34,22 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
35
34
|
db = task.task_config.db.value
|
36
35
|
db_label = task.task_config.db_config.db_label or ""
|
37
36
|
version = task.task_config.db_config.version or ""
|
38
|
-
case = task.task_config.case_config.
|
37
|
+
case = task.task_config.case_config.case
|
38
|
+
case_name = case.name
|
39
|
+
dataset_name = case.dataset.data.full_name
|
40
|
+
filter_rate = case.filter_rate
|
39
41
|
dbCaseMetricsMap[db_name][case.name] = {
|
40
42
|
"db": db,
|
41
43
|
"db_label": db_label,
|
42
44
|
"version": version,
|
45
|
+
"dataset_name": dataset_name,
|
46
|
+
"filter_rate": filter_rate,
|
43
47
|
"metrics": mergeMetrics(
|
44
|
-
dbCaseMetricsMap[db_name][
|
48
|
+
dbCaseMetricsMap[db_name][case_name].get("metrics", {}),
|
45
49
|
asdict(task.metrics),
|
46
50
|
),
|
47
51
|
"label": getBetterLabel(
|
48
|
-
dbCaseMetricsMap[db_name][
|
52
|
+
dbCaseMetricsMap[db_name][case_name].get("label", ResultLabel.FAILED),
|
49
53
|
task.label,
|
50
54
|
),
|
51
55
|
}
|
@@ -59,12 +63,16 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
59
63
|
db_label = metricInfo["db_label"]
|
60
64
|
version = metricInfo["version"]
|
61
65
|
label = metricInfo["label"]
|
66
|
+
dataset_name = metricInfo["dataset_name"]
|
67
|
+
filter_rate = metricInfo["filter_rate"]
|
62
68
|
if label == ResultLabel.NORMAL:
|
63
69
|
mergedTasks.append(
|
64
70
|
{
|
65
71
|
"db_name": db_name,
|
66
72
|
"db": db,
|
67
73
|
"db_label": db_label,
|
74
|
+
"dataset_name": dataset_name,
|
75
|
+
"filter_rate": filter_rate,
|
68
76
|
"version": version,
|
69
77
|
"case_name": case_name,
|
70
78
|
"metricsSet": set(metrics.keys()),
|
@@ -77,12 +85,9 @@ def mergeTasks(tasks: list[CaseResult]):
|
|
77
85
|
return mergedTasks, failedTasks
|
78
86
|
|
79
87
|
|
88
|
+
# for same db-label, we use the results with the highest qps
|
80
89
|
def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
|
81
|
-
|
82
|
-
for key, value in metrics_2.items():
|
83
|
-
metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value
|
84
|
-
|
85
|
-
return metrics
|
90
|
+
return metrics_1 if metrics_1.get(QPS_METRIC, 0) > metrics_2.get(QPS_METRIC, 0) else metrics_2
|
86
91
|
|
87
92
|
|
88
93
|
def getBetterMetric(metric, value_1, value_2):
|
@@ -1,14 +1,19 @@
|
|
1
1
|
from vectordb_bench.backend.cases import Case
|
2
|
+
from vectordb_bench.backend.dataset import DatasetWithSizeType
|
3
|
+
from vectordb_bench.backend.filter import FilterOp
|
2
4
|
from vectordb_bench.frontend.components.check_results.data import getChartData
|
3
|
-
from vectordb_bench.frontend.components.check_results.expanderStyle import
|
5
|
+
from vectordb_bench.frontend.components.check_results.expanderStyle import (
|
6
|
+
initSidebarExanderStyle,
|
7
|
+
)
|
4
8
|
from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER
|
5
|
-
from vectordb_bench.frontend.config.styles import
|
9
|
+
from vectordb_bench.frontend.config.styles import SIDEBAR_CONTROL_COLUMNS
|
6
10
|
import streamlit as st
|
11
|
+
from typing import Callable
|
7
12
|
|
8
13
|
from vectordb_bench.models import CaseResult, TestResult
|
9
14
|
|
10
15
|
|
11
|
-
def getshownData(results: list[TestResult],
|
16
|
+
def getshownData(st, results: list[TestResult], filter_type: FilterOp = FilterOp.NonFilter, **kwargs):
|
12
17
|
# hide the nav
|
13
18
|
st.markdown(
|
14
19
|
"<style> div[data-testid='stSidebarNav'] {display: none;} </style>",
|
@@ -17,15 +22,20 @@ def getshownData(results: list[TestResult], st):
|
|
17
22
|
|
18
23
|
st.header("Filters")
|
19
24
|
|
20
|
-
shownResults = getshownResults(results,
|
21
|
-
showDBNames, showCaseNames = getShowDbsAndCases(shownResults,
|
25
|
+
shownResults = getshownResults(st, results, **kwargs)
|
26
|
+
showDBNames, showCaseNames = getShowDbsAndCases(st, shownResults, filter_type)
|
22
27
|
|
23
28
|
shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames)
|
24
29
|
|
25
30
|
return shownData, failedTasks, showCaseNames
|
26
31
|
|
27
32
|
|
28
|
-
def getshownResults(
|
33
|
+
def getshownResults(
|
34
|
+
st,
|
35
|
+
results: list[TestResult],
|
36
|
+
case_results_filter: Callable[[CaseResult], bool] = lambda x: True,
|
37
|
+
**kwargs,
|
38
|
+
) -> list[CaseResult]:
|
29
39
|
resultSelectOptions = [
|
30
40
|
result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results
|
31
41
|
]
|
@@ -41,23 +51,18 @@ def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
|
|
41
51
|
)
|
42
52
|
selectedResult: list[CaseResult] = []
|
43
53
|
for option in selectedResultSelectedOptions:
|
44
|
-
|
45
|
-
selectedResult +=
|
54
|
+
case_results = results[resultSelectOptions.index(option)].results
|
55
|
+
selectedResult += [r for r in case_results if case_results_filter(r)]
|
46
56
|
|
47
57
|
return selectedResult
|
48
58
|
|
49
59
|
|
50
|
-
def getShowDbsAndCases(result: list[CaseResult],
|
60
|
+
def getShowDbsAndCases(st, result: list[CaseResult], filter_type: FilterOp) -> tuple[list[str], list[str]]:
|
51
61
|
initSidebarExanderStyle(st)
|
52
|
-
|
62
|
+
case_results = [res for res in result if res.task_config.case_config.case.filters.type == filter_type]
|
63
|
+
allDbNames = list(set({res.task_config.db_name for res in case_results}))
|
53
64
|
allDbNames.sort()
|
54
|
-
allCases: list[Case] = [
|
55
|
-
res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result
|
56
|
-
]
|
57
|
-
allCaseNameSet = set({case.name for case in allCases})
|
58
|
-
allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
|
59
|
-
case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
|
60
|
-
]
|
65
|
+
allCases: list[Case] = [res.task_config.case_config.case for res in case_results]
|
61
66
|
|
62
67
|
# DB Filter
|
63
68
|
dbFilterContainer = st.container()
|
@@ -67,15 +72,38 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[st
|
|
67
72
|
allDbNames,
|
68
73
|
col=1,
|
69
74
|
)
|
75
|
+
showCaseNames = []
|
76
|
+
|
77
|
+
if filter_type == FilterOp.NonFilter:
|
78
|
+
allCaseNameSet = set({case.name for case in allCases})
|
79
|
+
allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
|
80
|
+
case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
|
81
|
+
]
|
82
|
+
|
83
|
+
# Case Filter
|
84
|
+
caseFilterContainer = st.container()
|
85
|
+
showCaseNames = filterView(
|
86
|
+
caseFilterContainer,
|
87
|
+
"Case Filter",
|
88
|
+
[caseName for caseName in allCaseNames],
|
89
|
+
col=1,
|
90
|
+
)
|
70
91
|
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
92
|
+
if filter_type == FilterOp.StrEqual:
|
93
|
+
container = st.container()
|
94
|
+
datasetWithSizeTypes = [dataset_with_size_type for dataset_with_size_type in DatasetWithSizeType]
|
95
|
+
showDatasetWithSizeTypes = filterView(
|
96
|
+
container,
|
97
|
+
"Case Filter",
|
98
|
+
datasetWithSizeTypes,
|
99
|
+
col=1,
|
100
|
+
optionLables=[v.value for v in datasetWithSizeTypes],
|
101
|
+
)
|
102
|
+
datasets = [dataset_with_size_type.get_manager() for dataset_with_size_type in showDatasetWithSizeTypes]
|
103
|
+
showCaseNames = list(set([case.name for case in allCases if case.dataset in datasets]))
|
104
|
+
|
105
|
+
if filter_type == FilterOp.NumGE:
|
106
|
+
raise NotImplementedError
|
79
107
|
|
80
108
|
return showDBNames, showCaseNames
|
81
109
|
|
@@ -4,19 +4,22 @@ from vectordb_bench.frontend.config.styles import HEADER_ICON
|
|
4
4
|
def drawHeaderIcon(st):
|
5
5
|
st.markdown(
|
6
6
|
f"""
|
7
|
-
<
|
7
|
+
<a href="/vdb_benchmark" target="_self">
|
8
|
+
<div class="headerIconContainer"></div>
|
9
|
+
</a>
|
8
10
|
|
9
|
-
<style>
|
10
|
-
.headerIconContainer {{
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
11
|
+
<style>
|
12
|
+
.headerIconContainer {{
|
13
|
+
position: relative;
|
14
|
+
top: 0px;
|
15
|
+
height: 50px;
|
16
|
+
width: 100%;
|
17
|
+
border-bottom: 2px solid #E8EAEE;
|
18
|
+
background-image: url({HEADER_ICON});
|
19
|
+
background-repeat: no-repeat;
|
20
|
+
cursor: pointer;
|
21
|
+
}}
|
22
|
+
</style>
|
23
|
+
""",
|
21
24
|
unsafe_allow_html=True,
|
22
25
|
)
|
@@ -20,3 +20,23 @@ def NavToResults(st, key="nav-to-results"):
|
|
20
20
|
navClick = st.button("< Back to Results", key=key)
|
21
21
|
if navClick:
|
22
22
|
switch_page("vdb benchmark")
|
23
|
+
|
24
|
+
|
25
|
+
def NavToPages(st):
|
26
|
+
options = [
|
27
|
+
{"name": "Run Test", "link": "run_test"},
|
28
|
+
{"name": "Results", "link": "results"},
|
29
|
+
{"name": "Quries Per Dollar", "link": "quries_per_dollar"},
|
30
|
+
{"name": "Concurrent", "link": "concurrent"},
|
31
|
+
{"name": "Label Filter", "link": "label_filter"},
|
32
|
+
{"name": "Streaming", "link": "streaming"},
|
33
|
+
{"name": "Tables", "link": "tables"},
|
34
|
+
{"name": "Custom Dataset", "link": "custom"},
|
35
|
+
]
|
36
|
+
|
37
|
+
html = ""
|
38
|
+
for i, option in enumerate(options):
|
39
|
+
html += f'<a href="/{option["link"]}" target="_self" style="text-decoration: none; padding: 0.1px 0.2px;">{option["name"]}</a>'
|
40
|
+
if i < len(options) - 1:
|
41
|
+
html += '<span style="color: #888; margin: 0 5px;">|</span>'
|
42
|
+
st.markdown(html, unsafe_allow_html=True)
|
@@ -12,7 +12,7 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
|
|
12
12
|
"Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
|
13
13
|
)
|
14
14
|
|
15
|
-
columns = st.columns(
|
15
|
+
columns = st.columns(3)
|
16
16
|
customCase.dataset_config.dim = columns[0].number_input(
|
17
17
|
"dim", key=f"{key}_dim", value=customCase.dataset_config.dim
|
18
18
|
)
|
@@ -22,16 +22,51 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
|
|
22
22
|
customCase.dataset_config.metric_type = columns[2].selectbox(
|
23
23
|
"metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
|
24
24
|
)
|
25
|
-
|
26
|
-
|
25
|
+
|
26
|
+
columns = st.columns(3)
|
27
|
+
customCase.dataset_config.train_name = columns[0].text_input(
|
28
|
+
"train file name",
|
29
|
+
key=f"{key}_train_name",
|
30
|
+
value=customCase.dataset_config.train_name,
|
31
|
+
)
|
32
|
+
customCase.dataset_config.test_name = columns[1].text_input(
|
33
|
+
"test file name", key=f"{key}_test_name", value=customCase.dataset_config.test_name
|
34
|
+
)
|
35
|
+
customCase.dataset_config.gt_name = columns[2].text_input(
|
36
|
+
"ground truth file name", key=f"{key}_gt_name", value=customCase.dataset_config.gt_name
|
37
|
+
)
|
38
|
+
|
39
|
+
columns = st.columns([1, 1, 2, 2])
|
40
|
+
customCase.dataset_config.train_id_name = columns[0].text_input(
|
41
|
+
"train id name", key=f"{key}_train_id_name", value=customCase.dataset_config.train_id_name
|
42
|
+
)
|
43
|
+
customCase.dataset_config.train_col_name = columns[1].text_input(
|
44
|
+
"train emb name", key=f"{key}_train_col_name", value=customCase.dataset_config.train_col_name
|
45
|
+
)
|
46
|
+
customCase.dataset_config.test_col_name = columns[2].text_input(
|
47
|
+
"test emb name", key=f"{key}_test_col_name", value=customCase.dataset_config.test_col_name
|
48
|
+
)
|
49
|
+
customCase.dataset_config.gt_col_name = columns[3].text_input(
|
50
|
+
"ground truth emb name", key=f"{key}_gt_col_name", value=customCase.dataset_config.gt_col_name
|
27
51
|
)
|
28
52
|
|
29
|
-
columns = st.columns(
|
30
|
-
customCase.dataset_config.
|
31
|
-
"
|
53
|
+
columns = st.columns(2)
|
54
|
+
customCase.dataset_config.scalar_labels_name = columns[0].text_input(
|
55
|
+
"scalar labels file name",
|
56
|
+
key=f"{key}_scalar_labels_file_name",
|
57
|
+
value=customCase.dataset_config.scalar_labels_name,
|
32
58
|
)
|
33
|
-
|
34
|
-
|
59
|
+
default_label_percentages = ",".join(map(str, customCase.dataset_config.with_label_percentages))
|
60
|
+
label_percentage_input = columns[1].text_input(
|
61
|
+
"label percentages",
|
62
|
+
key=f"{key}_label_percantages",
|
63
|
+
value=default_label_percentages,
|
35
64
|
)
|
65
|
+
try:
|
66
|
+
customCase.dataset_config.label_percentages = [
|
67
|
+
float(item.strip()) for item in label_percentage_input.split(",") if item.strip()
|
68
|
+
]
|
69
|
+
except ValueError as e:
|
70
|
+
st.write(f"<span style='color:red'>{e},please input correct number</span>", unsafe_allow_html=True)
|
36
71
|
|
37
72
|
customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
|
@@ -2,13 +2,18 @@ def displayParams(st):
|
|
2
2
|
st.markdown(
|
3
3
|
"""
|
4
4
|
- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
|
5
|
-
- Vectors data files: The file
|
6
|
-
- Query test vectors: The file
|
7
|
-
- Ground truth file: The file
|
5
|
+
- Vectors data files: The file should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The name of two columns could be defined on your own.
|
6
|
+
- Query test vectors: The file could be named on your own and should have two kinds of columns: `id` as an incrementing `int` and `emb` as an array of `float32`. The `id` column must be named as `id`, and `emb` column could be defined on your own.
|
7
|
+
- Ground truth file: The file could be named on your own and should have two kinds of columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. The `id` column must be named as `id`, and `neighbors_id` column could be defined on your own.
|
8
8
|
|
9
|
-
- `Train File
|
9
|
+
- `Train File Name` - If the number of train file is `more than one`, please input all your train file name and `split with ','` without the `.parquet` file extensionthe. For example, if there are two train file and the name of them are `train1.parquet` and `train2.parquet`, then input `train1,train2`.
|
10
|
+
|
11
|
+
- `Ground Truth Emb Name` - No matter whether filter file is applied or not, the `neighbors_id` column in ground truth file must have the same name.
|
12
|
+
|
13
|
+
- `Scalar Labels File Name ` - If there is a scalar labels file, please input the filename without the .parquet extension. The file should have two columns: `id` as an incrementing `int` and `labels` as an array of `string`. The `id` column must correspond one-to-one with the `id` column in train file..
|
14
|
+
|
15
|
+
- `Label percentages` - If you have filter file, please input label percentage you want to real run and `split with ','` when it's `more than one`. If you `don't have` filter file, than `keep the text vacant.`
|
10
16
|
|
11
|
-
- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
|
12
17
|
"""
|
13
18
|
)
|
14
19
|
st.caption(
|
@@ -14,6 +14,16 @@ class CustomDatasetConfig(BaseModel):
|
|
14
14
|
file_count: int = 1
|
15
15
|
use_shuffled: bool = False
|
16
16
|
with_gt: bool = True
|
17
|
+
train_name: str = "train"
|
18
|
+
test_name: str = "test"
|
19
|
+
gt_name: str = "neighbors"
|
20
|
+
train_id_name: str = "id"
|
21
|
+
train_col_name: str = "emb"
|
22
|
+
test_col_name: str = "emb"
|
23
|
+
gt_col_name: str = "neighbors_id"
|
24
|
+
scalar_labels_name: str = "scalar_labels"
|
25
|
+
label_percentages: list[str] = []
|
26
|
+
with_label_percentages: list[float] = [0.001, 0.02, 0.5]
|
17
27
|
|
18
28
|
|
19
29
|
class CustomCaseConfig(BaseModel):
|