vectordb-bench 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +22 -0
- vectordb_bench/backend/clients/api.py +21 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +47 -6
- vectordb_bench/backend/clients/aws_opensearch/config.py +12 -6
- vectordb_bench/backend/clients/aws_opensearch/run.py +34 -3
- vectordb_bench/backend/clients/memorydb/cli.py +88 -0
- vectordb_bench/backend/clients/memorydb/config.py +54 -0
- vectordb_bench/backend/clients/memorydb/memorydb.py +254 -0
- vectordb_bench/backend/clients/pgvecto_rs/cli.py +154 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +108 -73
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +159 -59
- vectordb_bench/backend/clients/pgvector/cli.py +17 -2
- vectordb_bench/backend/clients/pgvector/config.py +20 -5
- vectordb_bench/backend/clients/pgvector/pgvector.py +95 -25
- vectordb_bench/backend/clients/pgvectorscale/cli.py +108 -0
- vectordb_bench/backend/clients/pgvectorscale/config.py +111 -0
- vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +290 -0
- vectordb_bench/backend/clients/pinecone/config.py +0 -2
- vectordb_bench/backend/clients/pinecone/pinecone.py +34 -36
- vectordb_bench/backend/clients/redis/cli.py +8 -0
- vectordb_bench/backend/clients/redis/config.py +37 -6
- vectordb_bench/backend/runner/mp_runner.py +2 -1
- vectordb_bench/cli/cli.py +137 -0
- vectordb_bench/cli/vectordbbench.py +7 -1
- vectordb_bench/frontend/components/check_results/charts.py +9 -6
- vectordb_bench/frontend/components/check_results/data.py +13 -6
- vectordb_bench/frontend/components/concurrent/charts.py +3 -6
- vectordb_bench/frontend/components/run_test/caseSelector.py +10 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +37 -15
- vectordb_bench/frontend/components/run_test/initStyle.py +3 -1
- vectordb_bench/frontend/config/dbCaseConfigs.py +230 -9
- vectordb_bench/frontend/pages/quries_per_dollar.py +13 -5
- vectordb_bench/frontend/vdb_benchmark.py +11 -3
- vectordb_bench/models.py +25 -9
- vectordb_bench/results/Milvus/result_20230727_standard_milvus.json +53 -1
- vectordb_bench/results/Milvus/result_20230808_standard_milvus.json +48 -0
- vectordb_bench/results/ZillizCloud/result_20230727_standard_zillizcloud.json +29 -1
- vectordb_bench/results/ZillizCloud/result_20230808_standard_zillizcloud.json +24 -0
- vectordb_bench/results/ZillizCloud/result_20240105_standard_202401_zillizcloud.json +98 -49
- vectordb_bench/results/getLeaderboardData.py +17 -7
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/METADATA +64 -31
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/RECORD +47 -40
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.12.dist-info → vectordb_bench-0.0.14.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/redis/config.py
CHANGED
@@ -1,14 +1,45 @@
-from pydantic import SecretStr
-from ..api import DBConfig
+from pydantic import SecretStr, BaseModel
+from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
 
 class RedisConfig(DBConfig):
-    password: SecretStr
+    password: SecretStr | None = None
     host: SecretStr
-    port: int = None
+    port: int | None = None
 
     def to_dict(self) -> dict:
         return {
             "host": self.host.get_secret_value(),
             "port": self.port,
-            "password": self.password.get_secret_value(),
-        }
+            "password": self.password.get_secret_value() if self.password is not None else None,
+        }
+
+
+
+class RedisIndexConfig(BaseModel):
+    """Base config for milvus"""
+
+    metric_type: MetricType | None = None
+
+    def parse_metric(self) -> str:
+        if not self.metric_type:
+            return ""
+        return self.metric_type.value
+
+class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
+    M: int
+    efConstruction: int
+    ef: int | None = None
+    index: IndexType = IndexType.HNSW
+
+    def index_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "index_type": self.index.value,
+            "params": {"M": self.M, "efConstruction": self.efConstruction},
+        }
+
+    def search_param(self) -> dict:
+        return {
+            "metric_type": self.parse_metric(),
+            "params": {"ef": self.ef},
+        }
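The new RedisHNSWConfig follows the same index_param/search_param split used by the other clients. A minimal sketch of how it could be exercised, assuming vectordb-bench 0.0.14 is installed; the parameter values are arbitrary and the commented output only shows the expected shape:

    # Sketch only: exercises the Redis case config classes added above.
    from vectordb_bench.backend.clients.api import MetricType
    from vectordb_bench.backend.clients.redis.config import RedisHNSWConfig

    case_config = RedisHNSWConfig(M=16, efConstruction=64, ef=100, metric_type=MetricType.COSINE)
    print(case_config.index_param())   # e.g. {"metric_type": ..., "index_type": "HNSW", "params": {"M": 16, "efConstruction": 64}}
    print(case_config.search_param())  # e.g. {"metric_type": ..., "params": {"ef": 100}}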
vectordb_bench/backend/runner/mp_runner.py
CHANGED
@@ -2,6 +2,7 @@ import time
 import traceback
 import concurrent
 import multiprocessing as mp
+import random
 import logging
 from typing import Iterable
 import numpy as np
@@ -46,7 +47,7 @@ class MultiProcessingSearchRunner:
             cond.wait()
 
         with self.db.init():
-            num, idx = len(test_data), 0
+            num, idx = len(test_data), random.randint(0, len(test_data) - 1)
 
             start_time = time.perf_counter()
             count = 0
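The only functional change here is that each search worker now starts from a random offset into the shared query set instead of index 0, so concurrent workers stop replaying identical query sequences. An illustrative standalone sketch; the wrap-around step is an assumption about the surrounding loop, which this hunk does not show:

    # Illustrative only: random starting offset into a shared query list.
    import random

    test_data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # made-up queries
    num, idx = len(test_data), random.randint(0, len(test_data) - 1)
    for _ in range(num):
        query = test_data[idx]
        idx = idx + 1 if idx < num - 1 else 0  # assumed wrap-around in the real loop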
vectordb_bench/cli/cli.py
CHANGED
@@ -17,6 +17,8 @@ from typing import (
     Any,
 )
 import click
+
+from vectordb_bench.backend.clients.api import MetricType
 from .. import config
 from ..backend.clients import DB
 from ..interface import benchMarkRunner, global_result_future
@@ -147,6 +149,37 @@ def parse_task_stages(
     return stages
 
 
+def check_custom_case_parameters(ctx, param, value):
+    if ctx.params.get("case_type") == "PerformanceCustomDataset":
+        if value is None:
+            raise click.BadParameter("Custom case parameters\
+\n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \
+\n--custom-dataset-dim\n--custom-dataset-file-count\n are required")
+    return value
+
+
+def get_custom_case_config(parameters: dict) -> dict:
+    custom_case_config = {}
+    if parameters["case_type"] == "PerformanceCustomDataset":
+        custom_case_config = {
+            "name": parameters["custom_case_name"],
+            "description": parameters["custom_case_description"],
+            "load_timeout": parameters["custom_case_load_timeout"],
+            "optimize_timeout": parameters["custom_case_optimize_timeout"],
+            "dataset_config": {
+                "name": parameters["custom_dataset_name"],
+                "dir": parameters["custom_dataset_dir"],
+                "size": parameters["custom_dataset_size"],
+                "dim": parameters["custom_dataset_dim"],
+                "metric_type": parameters["custom_dataset_metric_type"],
+                "file_count": parameters["custom_dataset_file_count"],
+                "use_shuffled": parameters["custom_dataset_use_shuffled"],
+                "with_gt": parameters["custom_dataset_with_gt"],
+            },
+        }
+    return custom_case_config
+
+
 log = logging.getLogger(__name__)
 
 
@@ -205,6 +238,7 @@ class CommonTypedDict(TypedDict):
         click.option(
             "--case-type",
             type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]),
+            is_eager=True,
             default="Performance1536D50K",
             help="Case type",
         ),
@@ -258,6 +292,108 @@ class CommonTypedDict(TypedDict):
             callback=lambda *args: list(map(int, click_arg_split(*args))),
         ),
     ]
+    custom_case_name: Annotated[
+        str,
+        click.option(
+            "--custom-case-name",
+            help="Custom dataset case name",
+            callback=check_custom_case_parameters,
+        )
+    ]
+    custom_case_description: Annotated[
+        str,
+        click.option(
+            "--custom-case-description",
+            help="Custom dataset case description",
+            default="This is a customized dataset.",
+            show_default=True,
+        )
+    ]
+    custom_case_load_timeout: Annotated[
+        int,
+        click.option(
+            "--custom-case-load-timeout",
+            help="Custom dataset case load timeout",
+            default=36000,
+            show_default=True,
+        )
+    ]
+    custom_case_optimize_timeout: Annotated[
+        int,
+        click.option(
+            "--custom-case-optimize-timeout",
+            help="Custom dataset case optimize timeout",
+            default=36000,
+            show_default=True,
+        )
+    ]
+    custom_dataset_name: Annotated[
+        str,
+        click.option(
+            "--custom-dataset-name",
+            help="Custom dataset name",
+            callback=check_custom_case_parameters,
+        ),
+    ]
+    custom_dataset_dir: Annotated[
+        str,
+        click.option(
+            "--custom-dataset-dir",
+            help="Custom dataset directory",
+            callback=check_custom_case_parameters,
+        ),
+    ]
+    custom_dataset_size: Annotated[
+        int,
+        click.option(
+            "--custom-dataset-size",
+            help="Custom dataset size",
+            callback=check_custom_case_parameters,
+        ),
+    ]
+    custom_dataset_dim: Annotated[
+        int,
+        click.option(
+            "--custom-dataset-dim",
+            help="Custom dataset dimension",
+            callback=check_custom_case_parameters,
+        ),
+    ]
+    custom_dataset_metric_type: Annotated[
+        str,
+        click.option(
+            "--custom-dataset-metric-type",
+            help="Custom dataset metric type",
+            default=MetricType.COSINE.name,
+            show_default=True,
+        ),
+    ]
+    custom_dataset_file_count: Annotated[
+        int,
+        click.option(
+            "--custom-dataset-file-count",
+            help="Custom dataset file count",
+            callback=check_custom_case_parameters,
+        ),
+    ]
+    custom_dataset_use_shuffled: Annotated[
+        bool,
+        click.option(
+            "--custom-dataset-use-shuffled/--skip-custom-dataset-use-shuffled",
+            help="Custom dataset use shuffled",
+            default=False,
+            show_default=True,
+        ),
+    ]
+    custom_dataset_with_gt: Annotated[
+        bool,
+        click.option(
+            "--custom-dataset-with-gt/--skip-custom-dataset-with-gt",
+            help="Custom dataset with ground truth",
+            default=True,
+            show_default=True,
+        ),
+    ]
 
 
 class HNSWBaseTypedDict(TypedDict):
@@ -343,6 +479,7 @@ def run(
                 concurrency_duration=parameters["concurrency_duration"],
                 num_concurrency=[int(s) for s in parameters["num_concurrency"]],
             ),
+            custom_case=parameters.get("custom_case", {}),
         ),
         stages=parse_task_stages(
            (
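get_custom_case_config simply repacks the new --custom-* options into the nested dict that the custom dataset case expects. A rough usage sketch, assuming vectordb-bench 0.0.14 is installed; every parameter value below is invented for illustration:

    # Sketch: feed parsed CLI parameters into the helper added above.
    from vectordb_bench.cli.cli import get_custom_case_config

    params = {
        "case_type": "PerformanceCustomDataset",
        "custom_case_name": "my-case",                       # hypothetical
        "custom_case_description": "This is a customized dataset.",
        "custom_case_load_timeout": 36000,
        "custom_case_optimize_timeout": 36000,
        "custom_dataset_name": "my-dataset",                  # hypothetical
        "custom_dataset_dir": "/data/my-dataset",             # hypothetical
        "custom_dataset_size": 100_000,
        "custom_dataset_dim": 768,
        "custom_dataset_metric_type": "COSINE",
        "custom_dataset_file_count": 1,
        "custom_dataset_use_shuffled": False,
        "custom_dataset_with_gt": True,
    }
    print(get_custom_case_config(params)["dataset_config"]["dim"])  # 768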
vectordb_bench/cli/vectordbbench.py
CHANGED
@@ -1,21 +1,27 @@
 from ..backend.clients.pgvector.cli import PgVectorHNSW
+from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
+from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn
 from ..backend.clients.redis.cli import Redis
+from ..backend.clients.memorydb.cli import MemoryDB
 from ..backend.clients.test.cli import Test
 from ..backend.clients.weaviate_cloud.cli import Weaviate
 from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
 from ..backend.clients.milvus.cli import MilvusAutoIndex
 from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
 
-
 from .cli import cli
 
 cli.add_command(PgVectorHNSW)
+cli.add_command(PgVectoRSHNSW)
+cli.add_command(PgVectoRSIVFFlat)
 cli.add_command(Redis)
+cli.add_command(MemoryDB)
 cli.add_command(Weaviate)
 cli.add_command(Test)
 cli.add_command(ZillizAutoIndex)
 cli.add_command(MilvusAutoIndex)
 cli.add_command(AWSOpenSearch)
+cli.add_command(PgVectorScaleDiskAnn)
 
 
 if __name__ == "__main__":
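The new subcommands (MemoryDB, PgVectoRSHNSW, PgVectoRSIVFFlat, PgVectorScaleDiskAnn) are plain click commands attached to the shared group, so one way to sanity-check the registration is to list what ended up on the group. A sketch, assuming 0.0.14 and its optional client dependencies are installed; the printed names are whatever each client's cli module declares:

    # Importing the module runs the cli.add_command(...) calls above.
    from vectordb_bench.cli.vectordbbench import cli

    print(sorted(cli.commands))  # click.Group keeps registered subcommands in .commands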
vectordb_bench/frontend/components/check_results/charts.py
CHANGED
@@ -1,5 +1,7 @@
 from vectordb_bench.backend.cases import Case
-from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
+from vectordb_bench.frontend.components.check_results.expanderStyle import (
+    initMainExpanderStyle,
+)
 from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
 from vectordb_bench.frontend.config.styles import *
 from vectordb_bench.models import ResultLabel
@@ -11,7 +13,7 @@ def drawCharts(st, allData, failedTasks, caseNames: list[str]):
     for caseName in caseNames:
         chartContainer = st.expander(caseName, True)
         data = [data for data in allData if data["case_name"] == caseName]
-        drawChart(data, chartContainer)
+        drawChart(data, chartContainer, key_prefix=caseName)
 
         errorDBs = failedTasks[caseName]
         showFailedDBs(chartContainer, errorDBs)
@@ -35,7 +37,7 @@ def showFailedText(st, text, dbs):
     )
 
 
-def drawChart(data, st):
+def drawChart(data, st, key_prefix: str):
     metricsSet = set()
     for d in data:
         metricsSet = metricsSet.union(d["metricsSet"])
@@ -43,7 +45,8 @@ def drawChart(data, st):
 
     for i, metric in enumerate(showMetrics):
         container = st.container()
-        drawMetricChart(data, metric, container)
+        key = f"{key_prefix}-{metric}"
+        drawMetricChart(data, metric, container, key=key)
 
 
 def getLabelToShapeMap(data):
@@ -75,7 +78,7 @@ def getLabelToShapeMap(data):
     return labelToShapeMap
 
 
-def drawMetricChart(data, metric, st):
+def drawMetricChart(data, metric, st, key: str):
     dataWithMetric = [d for d in data if d.get(metric, 0) > 1e-7]
     # dataWithMetric = data
     if len(dataWithMetric) == 0:
@@ -161,4 +164,4 @@ def drawMetricChart(data, metric, st):
         ),
     )
 
-    chart.plotly_chart(fig, use_container_width=True)
+    chart.plotly_chart(fig, use_container_width=True, key=key)
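The charts now thread an explicit Streamlit element key (case name plus metric) all the way down to st.plotly_chart, which keeps repeated charts from colliding on auto-generated element IDs. The pattern in isolation, with made-up data; assumes a Streamlit version that accepts key= on plotly_chart, as the diff itself relies on:

    # Hypothetical standalone example of the keying pattern used above.
    import streamlit as st
    import plotly.express as px

    for case_name in ["case-a", "case-b"]:
        for metric in ["qps", "recall"]:
            fig = px.bar(x=["db1", "db2"], y=[1, 2], title=f"{case_name} {metric}")
            st.plotly_chart(fig, use_container_width=True, key=f"{case_name}-{metric}")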
vectordb_bench/frontend/components/check_results/data.py
CHANGED
@@ -24,7 +24,10 @@ def getFilterTasks(
         task
         for task in tasks
         if task.task_config.db_name in dbNames
-        and task.task_config.case_config.case_id.case_cls().name in caseNames
+        and task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        ).name
+        in caseNames
     ]
     return filterTasks
 
@@ -35,17 +38,20 @@ def mergeTasks(tasks: list[CaseResult]):
         db_name = task.task_config.db_name
         db = task.task_config.db.value
         db_label = task.task_config.db_config.db_label or ""
-        case = task.task_config.case_config.case_id.case_cls()
+        version = task.task_config.db_config.version or ""
+        case = task.task_config.case_config.case_id.case_cls(
+            task.task_config.case_config.custom_case
+        )
         dbCaseMetricsMap[db_name][case.name] = {
             "db": db,
             "db_label": db_label,
+            "version": version,
             "metrics": mergeMetrics(
                 dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
                 asdict(task.metrics),
             ),
             "label": getBetterLabel(
-                dbCaseMetricsMap[db_name][case.name].get(
-                    "label", ResultLabel.FAILED),
+                dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED),
                 task.label,
             ),
         }
@@ -57,6 +63,7 @@ def mergeTasks(tasks: list[CaseResult]):
         metrics = metricInfo["metrics"]
         db = metricInfo["db"]
         db_label = metricInfo["db_label"]
+        version = metricInfo["version"]
         label = metricInfo["label"]
         if label == ResultLabel.NORMAL:
             mergedTasks.append(
@@ -64,6 +71,7 @@ def mergeTasks(tasks: list[CaseResult]):
                     "db_name": db_name,
                     "db": db,
                     "db_label": db_label,
+                    "version": version,
                     "case_name": case_name,
                     "metricsSet": set(metrics.keys()),
                     **metrics,
@@ -79,8 +87,7 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
     metrics = {**metrics_1}
     for key, value in metrics_2.items():
         metrics[key] = (
-            getBetterMetric(
-                key, value, metrics[key]) if key in metrics else value
+            getBetterMetric(key, value, metrics[key]) if key in metrics else value
         )
 
     return metrics
vectordb_bench/frontend/components/concurrent/charts.py
CHANGED
@@ -22,7 +22,7 @@ def drawChartsByCase(allData, showCaseNames: list[str], st):
             for caseData in caseDataList
             for i in range(len(caseData["conc_num_list"]))
         ]
-        drawChart(data, chartContainer)
+        drawChart(data, chartContainer, key=f"{caseName}-qps-p99")
 
 
 def getRange(metric, data, padding_multipliers):
@@ -36,7 +36,7 @@ def getRange(metric, data, padding_multipliers):
     return rangeV
 
 
-def drawChart(data, st):
+def drawChart(data, st, key: str):
     if len(data) == 0:
         return
 
@@ -73,7 +73,4 @@ def drawChart(data, st):
     fig.update_yaxes(range=yrange, title_text="QPS")
     fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}")
 
-    st.plotly_chart(
-        fig,
-        use_container_width=True,
-    )
+    st.plotly_chart(fig, use_container_width=True, key=key)
vectordb_bench/frontend/components/run_test/caseSelector.py
CHANGED
@@ -100,6 +100,16 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
                 value=config.inputConfig["value"],
                 help=config.inputHelp,
             )
+        elif config.inputType == InputType.Float:
+            caseConfig[config.label] = column.number_input(
+                config.displayLabel if config.displayLabel else config.label.value,
+                step=config.inputConfig.get("step", 0.1),
+                min_value=config.inputConfig["min"],
+                max_value=config.inputConfig["max"],
+                key=key,
+                value=config.inputConfig["value"],
+                help=config.inputHelp,
+            )
         k += 1
     if k == 0:
         columns[1].write("Auto")
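The new InputType.Float branch maps a float-valued case parameter onto st.number_input. Stripped of the surrounding case-config plumbing, the call looks roughly like this; the field name and values are invented:

    # Sketch of the Float branch above, outside the caseSelector machinery.
    import streamlit as st

    input_config = {"min": 0.0, "max": 1.0, "value": 0.5, "step": 0.1}  # invented values
    epsilon = st.number_input(
        "epsilon",  # hypothetical parameter label
        step=input_config.get("step", 0.1),
        min_value=input_config["min"],
        max_value=input_config["max"],
        value=input_config["value"],
    )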
vectordb_bench/frontend/components/run_test/dbConfigSetting.py
CHANGED
@@ -1,9 +1,10 @@
 from pydantic import ValidationError
-from vectordb_bench.
+from vectordb_bench.backend.clients import DB
+from vectordb_bench.frontend.config.styles import DB_CONFIG_SETTING_COLUMNS
 from vectordb_bench.frontend.utils import inputIsPassword
 
 
-def dbConfigSettings(st, activedDbList):
+def dbConfigSettings(st, activedDbList: list[DB]):
     expander = st.expander("Configurations for the selected databases", True)
 
     dbConfigs = {}
@@ -27,7 +28,7 @@ def dbConfigSettings(st, activedDbList):
     return dbConfigs, isAllValid
 
 
-def dbConfigSettingItem(st, activeDb):
+def dbConfigSettingItem(st, activeDb: DB):
     st.markdown(
         f"<div style='font-weight: 600; font-size: 20px; margin-top: 16px;'>{activeDb.value}</div>",
         unsafe_allow_html=True,
@@ -36,20 +37,41 @@ def dbConfigSettingItem(st, activeDb):
 
     dbConfigClass = activeDb.config_cls
     properties = dbConfigClass.schema().get("properties")
-    propertiesItems = list(properties.items())
-    moveDBLabelToLast(propertiesItems)
     dbConfig = {}
-
-
-
+    idx = 0
+
+    # db config (unique)
+    for key, property in properties.items():
+        if (
+            key not in dbConfigClass.common_short_configs()
+            and key not in dbConfigClass.common_long_configs()
+        ):
+            column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
+            idx += 1
+            dbConfig[key] = column.text_input(
+                key,
+                key="%s-%s" % (activeDb.name, key),
+                value=property.get("default", ""),
+                type="password" if inputIsPassword(key) else "default",
+            )
+    # db config (common short labels)
+    for key in dbConfigClass.common_short_configs():
+        column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
+        idx += 1
         dbConfig[key] = column.text_input(
             key,
-            key="%s-%s" % (activeDb, key),
-            value=
-            type="
+            key="%s-%s" % (activeDb.name, key),
+            value="",
+            type="default",
+            placeholder="optional, for labeling results",
        )
-    return dbConfig
-
 
-
-
+    # db config (common long text_input)
+    for key in dbConfigClass.common_long_configs():
+        dbConfig[key] = st.text_area(
+            key,
+            key="%s-%s" % (activeDb.name, key),
+            value="",
+            placeholder="optional",
+        )
+    return dbConfig
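dbConfigSettingItem now renders three groups of fields: client-specific properties, short shared labels (common_short_configs), and long free-text fields (common_long_configs). Both classmethods are consulted on the config class in the loop above, so a rough way to inspect how a given client's fields are grouped is shown below; this assumes 0.0.14 and that DB.Redis is a valid member of the DB enum:

    # Sketch: inspect the field grouping the new form layout relies on.
    from vectordb_bench.backend.clients import DB

    config_cls = DB.Redis.config_cls          # same attribute the frontend uses above
    print(config_cls.common_short_configs())  # rendered as short optional text inputs
    print(config_cls.common_long_configs())   # rendered as multi-line text areas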
vectordb_bench/frontend/components/run_test/initStyle.py
CHANGED
@@ -9,6 +9,8 @@ def initStyle(st):
            div[data-testid='stHorizontalBlock'] {gap: 8px;}
            /* check box */
            .stCheckbox p { color: #000; font-size: 18px; font-weight: 600; }
+            /* db selector - db_name should not wrap */
+            div[data-testid="stVerticalBlockBorderWrapper"] div[data-testid="stCheckbox"] div[data-testid="stWidgetLabel"] p { white-space: nowrap; }
         </style>""",
         unsafe_allow_html=True,
-    )
+    )