vectordb-bench 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- vectordb_bench/__init__.py +19 -5
- vectordb_bench/backend/assembler.py +1 -1
- vectordb_bench/backend/cases.py +93 -27
- vectordb_bench/backend/clients/__init__.py +14 -0
- vectordb_bench/backend/clients/api.py +1 -1
- vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +159 -0
- vectordb_bench/backend/clients/aws_opensearch/cli.py +44 -0
- vectordb_bench/backend/clients/aws_opensearch/config.py +58 -0
- vectordb_bench/backend/clients/aws_opensearch/run.py +125 -0
- vectordb_bench/backend/clients/milvus/cli.py +291 -0
- vectordb_bench/backend/clients/milvus/milvus.py +13 -6
- vectordb_bench/backend/clients/pgvector/cli.py +116 -0
- vectordb_bench/backend/clients/pgvector/config.py +1 -1
- vectordb_bench/backend/clients/pgvector/pgvector.py +7 -4
- vectordb_bench/backend/clients/redis/cli.py +74 -0
- vectordb_bench/backend/clients/test/cli.py +25 -0
- vectordb_bench/backend/clients/test/config.py +18 -0
- vectordb_bench/backend/clients/test/test.py +62 -0
- vectordb_bench/backend/clients/weaviate_cloud/cli.py +41 -0
- vectordb_bench/backend/clients/zilliz_cloud/cli.py +55 -0
- vectordb_bench/backend/dataset.py +27 -5
- vectordb_bench/backend/runner/mp_runner.py +14 -3
- vectordb_bench/backend/runner/serial_runner.py +7 -3
- vectordb_bench/backend/task_runner.py +76 -26
- vectordb_bench/cli/__init__.py +0 -0
- vectordb_bench/cli/cli.py +362 -0
- vectordb_bench/cli/vectordbbench.py +22 -0
- vectordb_bench/config-files/sample_config.yml +17 -0
- vectordb_bench/custom/custom_case.json +18 -0
- vectordb_bench/frontend/components/check_results/charts.py +6 -6
- vectordb_bench/frontend/components/check_results/data.py +23 -20
- vectordb_bench/frontend/components/check_results/expanderStyle.py +1 -1
- vectordb_bench/frontend/components/check_results/filters.py +20 -13
- vectordb_bench/frontend/components/check_results/headerIcon.py +1 -1
- vectordb_bench/frontend/components/check_results/priceTable.py +1 -1
- vectordb_bench/frontend/components/check_results/stPageConfig.py +1 -1
- vectordb_bench/frontend/components/concurrent/charts.py +79 -0
- vectordb_bench/frontend/components/custom/displayCustomCase.py +31 -0
- vectordb_bench/frontend/components/custom/displaypPrams.py +11 -0
- vectordb_bench/frontend/components/custom/getCustomConfig.py +40 -0
- vectordb_bench/frontend/components/custom/initStyle.py +15 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +1 -1
- vectordb_bench/frontend/components/run_test/caseSelector.py +40 -28
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -5
- vectordb_bench/frontend/components/run_test/dbSelector.py +8 -14
- vectordb_bench/frontend/components/run_test/generateTasks.py +3 -5
- vectordb_bench/frontend/components/run_test/initStyle.py +14 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +13 -5
- vectordb_bench/frontend/components/tables/data.py +44 -0
- vectordb_bench/frontend/{const → config}/dbCaseConfigs.py +140 -32
- vectordb_bench/frontend/{const → config}/styles.py +2 -0
- vectordb_bench/frontend/pages/concurrent.py +65 -0
- vectordb_bench/frontend/pages/custom.py +64 -0
- vectordb_bench/frontend/pages/quries_per_dollar.py +5 -5
- vectordb_bench/frontend/pages/run_test.py +4 -0
- vectordb_bench/frontend/pages/tables.py +24 -0
- vectordb_bench/frontend/utils.py +17 -1
- vectordb_bench/frontend/vdb_benchmark.py +3 -3
- vectordb_bench/interface.py +21 -25
- vectordb_bench/metric.py +23 -1
- vectordb_bench/models.py +45 -1
- vectordb_bench/results/getLeaderboardData.py +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/METADATA +228 -14
- vectordb_bench-0.0.12.dist-info/RECORD +115 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/entry_points.txt +1 -0
- vectordb_bench-0.0.10.dist-info/RECORD +0 -88
- /vectordb_bench/frontend/{const → config}/dbPrices.py +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.10.dist-info → vectordb_bench-0.0.12.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/test/cli.py (new file)
@@ -0,0 +1,25 @@
+from typing import Unpack
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+from ..test.config import TestConfig, TestIndexConfig
+
+
+class TestTypedDict(CommonTypedDict):
+    ...
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(TestTypedDict)
+def Test(**parameters: Unpack[TestTypedDict]):
+    run(
+        db=DB.NewClient,
+        db_config=TestConfig(db_label=parameters["db_label"]),
+        db_case_config=TestIndexConfig(),
+        **parameters,
+    )
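The `click_parameter_decorators_from_typed_dict` helper used above (and in the Weaviate and Zilliz commands below) lives in the new `vectordb_bench/cli/cli.py`, whose body is not shown in this diff. As a rough sketch of the mechanism (hypothetical code, not the actual vectordb_bench implementation), such a decorator can collect the `click.option` decorators embedded in the TypedDict's `Annotated` hints and apply them to the command function:

import typing

import click

def click_parameter_decorators_from_typed_dict(td: type) -> typing.Callable:
    # Sketch: pull each click.option decorator out of the Annotated[...] metadata.
    decorators = []
    for hint in typing.get_type_hints(td, include_extras=True).values():
        if typing.get_origin(hint) is typing.Annotated:
            decorators.extend(typing.get_args(hint)[1:])  # skip the plain type

    def wrap(func):
        # click decorators are applied bottom-up, like stacked @-decorators
        for decorator in reversed(decorators):
            func = decorator(func)
        return func

    return wrap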
vectordb_bench/backend/clients/test/config.py (new file)
@@ -0,0 +1,18 @@
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
+
+class TestConfig(DBConfig):
+    def to_dict(self) -> dict:
+        return {"db_label": self.db_label}
+
+
+class TestIndexConfig(BaseModel, DBCaseConfig):
+    metric_type: MetricType | None = None
+
+    def index_param(self) -> dict:
+        return {}
+
+    def search_param(self) -> dict:
+        return {}
vectordb_bench/backend/clients/test/test.py (new file)
@@ -0,0 +1,62 @@
+import logging
+from contextlib import contextmanager
+from typing import Any, Generator, Optional, Tuple
+
+from ..api import DBCaseConfig, VectorDB
+
+log = logging.getLogger(__name__)
+
+
+class Test(VectorDB):
+    def __init__(
+        self,
+        dim: int,
+        db_config: dict,
+        db_case_config: DBCaseConfig,
+        drop_old: bool = False,
+        **kwargs,
+    ):
+        self.db_config = db_config
+        self.case_config = db_case_config
+
+        log.info("Starting Test DB")
+
+    @contextmanager
+    def init(self) -> Generator[None, None, None]:
+        """create and destroy connections to database.
+
+        Examples:
+            >>> with self.init():
+            >>>     self.insert_embeddings()
+        """
+
+        yield
+
+    def ready_to_load(self) -> bool:
+        return True
+
+    def optimize(self) -> None:
+        pass
+
+    def insert_embeddings(
+        self,
+        embeddings: list[list[float]],
+        metadata: list[int],
+        **kwargs: Any,
+    ) -> Tuple[int, Optional[Exception]]:
+        """Insert embeddings into the database.
+        Should call self.init() first.
+        """
+        raise RuntimeError("Not implemented")
+        return len(metadata), None
+
+    def search_embedding(
+        self,
+        query: list[float],
+        k: int = 100,
+        filters: dict | None = None,
+        timeout: int | None = None,
+        **kwargs: Any,
+    ) -> list[int]:
+        raise NotImplementedError
+        return [i for i in range(k)]
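The `Test` client above is a deliberately empty reference implementation of the `VectorDB` interface. For orientation, the runners drive any client through roughly the lifecycle below (a minimal sketch; `Test` itself raises on insert and search, so picture a real client in its place):

# Minimal sketch of the lifecycle the runners drive a client through.
# Assumes the Test and TestIndexConfig classes defined in the hunks above.
client = Test(dim=128, db_config={}, db_case_config=TestIndexConfig())

with client.init():                       # per-process connection scope
    client.insert_embeddings(
        embeddings=[[0.1] * 128],         # one 128-dim vector
        metadata=[0],                     # its integer id
    )
    ids = client.search_embedding(query=[0.1] * 128, k=10)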
vectordb_bench/backend/clients/weaviate_cloud/cli.py (new file)
@@ -0,0 +1,41 @@
+from typing import Annotated, Unpack
+
+import click
+from pydantic import SecretStr
+
+from ....cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from .. import DB
+
+
+class WeaviateTypedDict(CommonTypedDict):
+    api_key: Annotated[
+        str, click.option("--api-key", type=str, help="Weaviate api key", required=True)
+    ]
+    url: Annotated[
+        str,
+        click.option("--url", type=str, help="Weaviate url", required=True),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(WeaviateTypedDict)
+def Weaviate(**parameters: Unpack[WeaviateTypedDict]):
+    from .config import WeaviateConfig, WeaviateIndexConfig
+
+    run(
+        db=DB.WeaviateCloud,
+        db_config=WeaviateConfig(
+            db_label=parameters["db_label"],
+            api_key=SecretStr(parameters["api_key"]),
+            url=SecretStr(parameters["url"]),
+        ),
+        db_case_config=WeaviateIndexConfig(
+            ef=256, efConstruction=256, maxConnections=16
+        ),
+        **parameters,
+    )
vectordb_bench/backend/clients/zilliz_cloud/cli.py (new file)
@@ -0,0 +1,55 @@
+from typing import Annotated, Unpack
+
+import click
+import os
+from pydantic import SecretStr
+
+from vectordb_bench.cli.cli import (
+    CommonTypedDict,
+    cli,
+    click_parameter_decorators_from_typed_dict,
+    run,
+)
+from vectordb_bench.backend.clients import DB
+
+
+class ZillizTypedDict(CommonTypedDict):
+    uri: Annotated[
+        str, click.option("--uri", type=str, help="uri connection string", required=True)
+    ]
+    user_name: Annotated[
+        str, click.option("--user-name", type=str, help="Db username", required=True)
+    ]
+    password: Annotated[
+        str,
+        click.option("--password",
+                     type=str,
+                     help="Zilliz password",
+                     default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""),
+                     show_default="$ZILLIZ_PASSWORD",
+                     ),
+    ]
+    level: Annotated[
+        str,
+        click.option("--level", type=str, help="Zilliz index level", required=False),
+    ]
+
+
+@cli.command()
+@click_parameter_decorators_from_typed_dict(ZillizTypedDict)
+def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]):
+    from .config import ZillizCloudConfig, AutoIndexConfig
+
+    run(
+        db=DB.ZillizCloud,
+        db_config=ZillizCloudConfig(
+            db_label=parameters["db_label"],
+            uri=SecretStr(parameters["uri"]),
+            user=parameters["user_name"],
+            password=SecretStr(parameters["password"]),
+        ),
+        db_case_config=AutoIndexConfig(
+            params={parameters["level"]},
+        ),
+        **parameters,
+    )
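Note the `--password` option: its default is a lazy callable that reads `ZILLIZ_PASSWORD`, so the flag can be omitted whenever the environment variable is set, and `show_default="$ZILLIZ_PASSWORD"` puts that string (not the secret value) into the help text. Both are standard click features; a small self-contained illustration:

import os

import click

@click.command()
@click.option(
    "--password",
    type=str,
    default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""),  # evaluated at parse time
    show_default="$ZILLIZ_PASSWORD",  # shown in --help instead of the real value
)
def show(password: str):
    click.echo(f"password provided: {bool(password)}")

if __name__ == "__main__":
    show()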
vectordb_bench/backend/dataset.py
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
     use_shuffled: bool
     with_gt: bool = False
     _size_label: dict[int, SizeLabel] = PrivateAttr()
+    isCustom: bool = False
 
     @validator("size")
     def verify_size(cls, v):
@@ -52,7 +53,27 @@ class BaseDataset(BaseModel):
     def file_count(self) -> int:
         return self._size_label.get(self.size).file_count
 
+class CustomDataset(BaseDataset):
+    dir: str
+    file_num: int
+    isCustom: bool = True
+
+    @validator("size")
+    def verify_size(cls, v):
+        return v
+
+    @property
+    def label(self) -> str:
+        return "Custom"
 
+    @property
+    def dir_name(self) -> str:
+        return self.dir
+
+    @property
+    def file_count(self) -> int:
+        return self.file_num
+
 class LAION(BaseDataset):
     name: str = "LAION"
     dim: int = 768
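A hypothetical declaration of such a dataset, assuming the other `BaseDataset` fields visible in this file (`name`, `dim`, `metric_type`, `size`, `use_shuffled`, `with_gt`); every value below is invented for illustration:

custom = CustomDataset(
    name="my_corpus",
    size=1_000_000,          # accepted as-is: verify_size is a no-op here
    dim=768,
    metric_type=MetricType.COSINE,
    use_shuffled=False,
    with_gt=True,
    dir="/data/my_corpus",   # read locally; prepare() skips the remote reader (next hunk)
    file_num=10,
)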
@@ -186,11 +207,12 @@ class DatasetManager(BaseModel):
         gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
         all_files.extend([gt_file, test_file])
 
-
-
-
-
-
+        if not self.data.isCustom:
+            source.reader().read(
+                dataset=self.data.dir_name.lower(),
+                files=all_files,
+                local_ds_root=self.data_dir,
+            )
 
         if gt_file is not None and test_file is not None:
             self.test_data = self._read_file(test_file)
vectordb_bench/backend/runner/mp_runner.py
@@ -4,6 +4,7 @@ import concurrent
 import multiprocessing as mp
 import logging
 from typing import Iterable
+import numpy as np
 from ..clients import api
 from ... import config
 
@@ -49,6 +50,7 @@ class MultiProcessingSearchRunner:
 
         start_time = time.perf_counter()
         count = 0
+        latencies = []
         while time.perf_counter() < start_time + self.duration:
             s = time.perf_counter()
             try:
@@ -61,7 +63,8 @@ class MultiProcessingSearchRunner:
                 log.warning(f"VectorDB search_embedding error: {e}")
                 traceback.print_exc(chain=True)
                 raise e from None
-
+
+            latencies.append(time.perf_counter() - s)
             count += 1
             # loop through the test data
             idx = idx + 1 if idx < num - 1 else 0
@@ -75,7 +78,7 @@ class MultiProcessingSearchRunner:
             f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
         )
 
-        return (count, total_dur)
+        return (count, total_dur, latencies)
 
     @staticmethod
     def get_mp_context():
@@ -85,6 +88,9 @@ class MultiProcessingSearchRunner:
 
     def _run_all_concurrencies_mem_efficient(self) -> float:
         max_qps = 0
+        conc_num_list = []
+        conc_qps_list = []
+        conc_latency_p99_list = []
         try:
             for conc in self.concurrencies:
                 with mp.Manager() as m:
@@ -103,9 +109,14 @@ class MultiProcessingSearchRunner:
 
                     start = time.perf_counter()
                     all_count = sum([r.result()[0] for r in future_iter])
+                    latencies = sum([r.result()[2] for r in future_iter], start=[])
+                    latency_p99 = np.percentile(latencies, 0.99)
                     cost = time.perf_counter() - start
 
                     qps = round(all_count / cost, 4)
+                    conc_num_list.append(conc)
+                    conc_qps_list.append(qps)
+                    conc_latency_p99_list.append(latency_p99)
                     log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
 
                     if qps > max_qps:
@@ -122,7 +133,7 @@ class MultiProcessingSearchRunner:
         finally:
             self.stop()
 
-        return max_qps
+        return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list
 
     def run(self) -> float:
         """
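Two things are worth noting in the hunks above. First, `_run_all_concurrencies_mem_efficient` now returns a 4-tuple (`max_qps` plus the per-concurrency lists) even though its annotation still reads `-> float`. Second, `np.percentile` takes the percentile on a 0 to 100 scale, so `np.percentile(latencies, 0.99)` as released computes the 0.99th percentile, which is close to the fastest latency rather than p99; the conventional p99 computation is:

import numpy as np

latencies = [0.008, 0.009, 0.010, 0.012, 0.030]
p99 = np.percentile(latencies, 99)  # 99th percentile on numpy's 0-100 scale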
vectordb_bench/backend/runner/serial_runner.py
@@ -10,7 +10,7 @@ import numpy as np
 import pandas as pd
 
 from ..clients import api
-from ...metric import calc_recall
+from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
 from ...models import LoadTimeoutError, PerformanceTimeoutError
 from .. import utils
 from ... import config
@@ -171,11 +171,12 @@ class SerialSearchRunner:
         log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
         with self.db.init():
             test_data, ground_truth = args
+            ideal_dcg = get_ideal_dcg(self.k)
 
             log.debug(f"test dataset size: {len(test_data)}")
             log.debug(f"ground truth size: {ground_truth.columns}, shape: {ground_truth.shape}")
 
-            latencies, recalls = [], []
+            latencies, recalls, ndcgs = [], [], []
             for idx, emb in enumerate(test_data):
                 s = time.perf_counter()
                 try:
@@ -194,6 +195,7 @@ class SerialSearchRunner:
 
                 gt = ground_truth['neighbors_id'][idx]
                 recalls.append(calc_recall(self.k, gt[:self.k], results))
+                ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg))
 
 
                 if len(latencies) % 100 == 0:
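`get_ideal_dcg` and `calc_ndcg` come from `vectordb_bench/metric.py`, which this release extends (+23 lines in the file list above) but whose body is not shown here. Conceptually, with binary relevance against the ground-truth ids, they correspond to the standard NDCG@k computation; the sketch below follows that definition and is not the exact metric.py code:

import numpy as np

def get_ideal_dcg(k: int) -> float:
    # DCG@k when all top-k results are relevant: sum of 1/log2(rank + 2)
    return float(sum(1.0 / np.log2(rank + 2) for rank in range(k)))

def calc_ndcg(ground_truth: list[int], results: list[int], ideal_dcg: float) -> float:
    # Binary relevance: a ground-truth hit at 0-based rank r adds 1/log2(r + 2).
    relevant = set(ground_truth)
    dcg = sum(1.0 / np.log2(rank + 2) for rank, rid in enumerate(results) if rid in relevant)
    return dcg / ideal_dcg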
@@ -201,6 +203,7 @@ class SerialSearchRunner:
 
         avg_latency = round(np.mean(latencies), 4)
         avg_recall = round(np.mean(recalls), 4)
+        avg_ndcg = round(np.mean(ndcgs), 4)
         cost = round(np.sum(latencies), 4)
         p99 = round(np.percentile(latencies, 99), 4)
         log.info(
@@ -208,10 +211,11 @@ class SerialSearchRunner:
             f"cost={cost}s, "
             f"queries={len(latencies)}, "
             f"avg_recall={avg_recall}, "
+            f"avg_ndcg={avg_ndcg},"
             f"avg_latency={avg_latency}, "
             f"p99={p99}"
         )
-        return (avg_recall, p99)
+        return (avg_recall, avg_ndcg, p99)
 
 
     def _run_in_subprocess(self) -> tuple[float, float]:
vectordb_bench/backend/task_runner.py
@@ -8,7 +8,7 @@ from enum import Enum, auto
 from . import utils
 from .cases import Case, CaseLabel
 from ..base import BaseModel
-from ..models import TaskConfig, PerformanceTimeoutError
+from ..models import TaskConfig, PerformanceTimeoutError, TaskStage
 
 from .clients import (
     api,
|
|
29
29
|
|
30
30
|
|
31
31
|
class CaseRunner(BaseModel):
|
32
|
-
"""
|
32
|
+
"""DataSet, filter_rate, db_class with db config
|
33
33
|
|
34
34
|
Fields:
|
35
35
|
run_id(str): run_id of this case runner,
|
@@ -49,8 +49,9 @@ class CaseRunner(BaseModel):
|
|
49
49
|
|
50
50
|
db: api.VectorDB | None = None
|
51
51
|
test_emb: list[list[float]] | None = None
|
52
|
-
search_runner: MultiProcessingSearchRunner | None = None
|
53
52
|
serial_search_runner: SerialSearchRunner | None = None
|
53
|
+
search_runner: MultiProcessingSearchRunner | None = None
|
54
|
+
final_search_runner: MultiProcessingSearchRunner | None = None
|
54
55
|
|
55
56
|
def __eq__(self, obj):
|
56
57
|
if isinstance(obj, CaseRunner):
|
@@ -58,7 +59,7 @@ class CaseRunner(BaseModel):
|
|
58
59
|
self.config.db == obj.config.db and \
|
59
60
|
self.config.db_case_config == obj.config.db_case_config and \
|
60
61
|
self.ca.dataset == obj.ca.dataset
|
61
|
-
|
62
|
+
return False
|
62
63
|
|
63
64
|
def display(self) -> dict:
|
64
65
|
c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} })
|
@@ -79,20 +80,25 @@ class CaseRunner(BaseModel):
             db_config=self.config.db_config.to_dict(),
             db_case_config=self.config.db_case_config,
             drop_old=drop_old,
-        )
+        ) # type:ignore
+
 
     def _pre_run(self, drop_old: bool = True):
         try:
             self.init_db(drop_old)
             self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate)
         except ModuleNotFoundError as e:
-            log.warning(
+            log.warning(
+                f"pre run case error: please install client for db: {self.config.db}, error={e}"
+            )
             raise e from None
         except Exception as e:
             log.warning(f"pre run case error: {e}")
             raise e from None
 
     def run(self, drop_old: bool = True) -> Metric:
+        log.info("Starting run")
+
        self._pre_run(drop_old)
 
         if self.ca.label == CaseLabel.Load:
@@ -105,31 +111,35 @@ class CaseRunner(BaseModel):
             raise ValueError(msg)
 
     def _run_capacity_case(self) -> Metric:
-        """
+        """run capacity cases
 
         Returns:
             Metric: the max load count
         """
+        assert self.db is not None
         log.info("Start capacity case")
         try:
-            runner = SerialInsertRunner(
+            runner = SerialInsertRunner(
+                self.db, self.ca.dataset, self.normalize, self.ca.load_timeout
+            )
             count = runner.run_endlessness()
         except Exception as e:
             log.warning(f"Failed to run capacity case, reason = {e}")
             raise e from None
         else:
-            log.info(
+            log.info(
+                f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}"
+            )
             return Metric(max_load_count=count)
 
     def _run_perf_case(self, drop_old: bool = True) -> Metric:
-        """
+        """run performance cases
 
         Returns:
             Metric: load_duration, recall, serial_latency_p99, and, qps
         """
-
-
-        if drop_old:
+        '''
+        if drop_old:
             _, load_dur = self._load_train_data()
             build_dur = self._optimize()
             m.load_duration = round(load_dur+build_dur, 4)
@@ -140,8 +150,43 @@ class CaseRunner(BaseModel):
             )
 
         self._init_search_runner()
-
+
+        m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = self._conc_search()
         m.recall, m.serial_latency_p99 = self._serial_search()
+        '''
+
+        log.info("Start performance case")
+        try:
+            m = Metric()
+            if drop_old:
+                if TaskStage.LOAD in self.config.stages:
+                    # self._load_train_data()
+                    _, load_dur = self._load_train_data()
+                    build_dur = self._optimize()
+                    m.load_duration = round(load_dur + build_dur, 4)
+                    log.info(
+                        f"Finish loading the entire dataset into VectorDB,"
+                        f" insert_duration={load_dur}, optimize_duration={build_dur}"
+                        f" load_duration(insert + optimize) = {m.load_duration}"
+                    )
+                else:
+                    log.info("Data loading skipped")
+            if (
+                TaskStage.SEARCH_SERIAL in self.config.stages
+                or TaskStage.SEARCH_CONCURRENT in self.config.stages
+            ):
+                self._init_search_runner()
+            if TaskStage.SEARCH_SERIAL in self.config.stages:
+                search_results = self._serial_search()
+                '''
+                m.recall = search_results.recall
+                m.serial_latencies = search_results.serial_latencies
+                '''
+                m.recall, m.ndcg, m.serial_latency_p99 = search_results
+            if TaskStage.SEARCH_CONCURRENT in self.config.stages:
+                search_results = self._conc_search()
+                m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = search_results
+
         except Exception as e:
             log.warning(f"Failed to run performance case, reason = {e}")
             traceback.print_exc()
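`TaskStage` is among the additions to `vectordb_bench/models.py` (+45 in the file list); the members referenced in this hunk are `LOAD`, `SEARCH_SERIAL`, and `SEARCH_CONCURRENT`, and `self.config.stages` selects which ones run. A hypothetical stage selection (the exact `TaskConfig` wiring is not shown in this diff):

from vectordb_bench.models import TaskStage

# Load the data and measure serial search only; _run_perf_case will then
# skip _conc_search() because SEARCH_CONCURRENT is absent from the stages.
stages = [TaskStage.LOAD, TaskStage.SEARCH_SERIAL]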
@@ -217,18 +262,23 @@ class CaseRunner(BaseModel):
 
         gt_df = self.ca.dataset.gt_data
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if TaskStage.SEARCH_SERIAL in self.config.stages:
+            self.serial_search_runner = SerialSearchRunner(
+                db=self.db,
+                test_data=self.test_emb,
+                ground_truth=gt_df,
+                filters=self.ca.filters,
+                k=self.config.case_config.k,
+            )
+        if TaskStage.SEARCH_CONCURRENT in self.config.stages:
+            self.search_runner = MultiProcessingSearchRunner(
+                db=self.db,
+                test_data=self.test_emb,
+                filters=self.ca.filters,
+                concurrencies=self.config.case_config.concurrency_search_config.num_concurrency,
+                duration=self.config.case_config.concurrency_search_config.concurrency_duration,
+                k=self.config.case_config.k,
+            )
 
     def stop(self):
         if self.search_runner: