vectordb-bench 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +30 -0
- vectordb_bench/__main__.py +39 -0
- vectordb_bench/backend/__init__.py +0 -0
- vectordb_bench/backend/assembler.py +57 -0
- vectordb_bench/backend/cases.py +124 -0
- vectordb_bench/backend/clients/__init__.py +57 -0
- vectordb_bench/backend/clients/api.py +179 -0
- vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
- vectordb_bench/backend/clients/milvus/config.py +123 -0
- vectordb_bench/backend/clients/milvus/milvus.py +182 -0
- vectordb_bench/backend/clients/pinecone/config.py +15 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
- vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
- vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
- vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
- vectordb_bench/backend/dataset.py +393 -0
- vectordb_bench/backend/result_collector.py +15 -0
- vectordb_bench/backend/runner/__init__.py +12 -0
- vectordb_bench/backend/runner/mp_runner.py +124 -0
- vectordb_bench/backend/runner/serial_runner.py +164 -0
- vectordb_bench/backend/task_runner.py +290 -0
- vectordb_bench/backend/utils.py +85 -0
- vectordb_bench/base.py +6 -0
- vectordb_bench/frontend/components/check_results/charts.py +175 -0
- vectordb_bench/frontend/components/check_results/data.py +86 -0
- vectordb_bench/frontend/components/check_results/filters.py +97 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
- vectordb_bench/frontend/components/check_results/nav.py +21 -0
- vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
- vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
- vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
- vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
- vectordb_bench/frontend/const.py +391 -0
- vectordb_bench/frontend/pages/qps_with_price.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +59 -0
- vectordb_bench/frontend/utils.py +6 -0
- vectordb_bench/frontend/vdb_benchmark.py +42 -0
- vectordb_bench/interface.py +239 -0
- vectordb_bench/log_util.py +103 -0
- vectordb_bench/metric.py +53 -0
- vectordb_bench/models.py +234 -0
- vectordb_bench/results/result_20230609_standard.json +3228 -0
- vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
- vectordb_bench-0.0.1.dist-info/METADATA +226 -0
- vectordb_bench-0.0.1.dist-info/RECORD +56 -0
- vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
- vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
- vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
vectordb_bench/interface.py
ADDED
@@ -0,0 +1,239 @@
import traceback
import pathlib
import signal
import logging
import uuid
import concurrent
import multiprocessing as mp
from multiprocessing.connection import Connection

import psutil
from enum import Enum

from . import config
from .metric import Metric
from .models import (
    TaskConfig,
    TestResult,
    CaseResult,
    LoadTimeoutError,
    PerformanceTimeoutError,
    ResultLabel,
)
from .backend.result_collector import ResultCollector
from .backend.assembler import Assembler
from .backend.task_runner import TaskRunner

log = logging.getLogger(__name__)

global_result_future: concurrent.futures.Future | None = None

class SIGNAL(Enum):
    SUCCESS=0
    ERROR=1
    WIP=2


class BenchMarkRunner:
    def __init__(self):
        self.running_task: TaskRunner | None = None
        self.latest_error: str | None = None

    def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool:
        """run all the tasks in the configs, write one result into the path"""
        if self.running_task is not None:
            log.warning("There're still tasks running in the background")
            return False

        if len(tasks) == 0:
            log.warning("Empty tasks submitted")
            return False

        log.debug(f"tasks: {tasks}")

        # Generate run_id
        run_id = uuid.uuid4().hex
        log.info(f"generated uuid for the tasks: {run_id}")
        task_label = task_label if task_label else run_id

        self.receive_conn, send_conn = mp.Pipe()
        self.latest_error = ""
        self.running_task = Assembler.assemble_all(run_id, task_label, tasks)
        self.running_task.display()

        return self._run_async(send_conn)

    def get_results(self, result_dir: pathlib.Path | None = None) -> list[TestResult]:
        """results of all runs, each TestResult represents one run."""
        target_dir = result_dir if result_dir else config.RESULTS_LOCAL_DIR
        return ResultCollector.collect(target_dir)

    def _try_get_signal(self):
        if self.receive_conn and self.receive_conn.poll():
            sig, received = self.receive_conn.recv()
            log.debug(f"Sigal received to process: {sig}, {received}")
            if sig == SIGNAL.ERROR:
                self.latest_error = received
                self._clear_running_task()
            elif sig == SIGNAL.SUCCESS:
                global global_result_future
                global_result_future = None
                self.running_task = None
                self.receive_conn = None
            elif sig == SIGNAL.WIP:
                self.running_task.set_finished(received)
            else:
                self._clear_running_task()

    def has_running(self) -> bool:
        """check if there're running benchmarks"""
        if self.running_task:
            self._try_get_signal()
        return self.running_task is not None

    def stop_running(self):
        """force stop if ther're running benchmarks"""
        self._clear_running_task()

    def get_tasks_count(self) -> int:
        """the count of all tasks"""
        if self.running_task:
            return self.running_task.num_cases()
        return 0


    def get_current_task_id(self) -> int:
        """ the index of current running task
        return -1 if not running
        """
        if not self.running_task:
            return -1
        return self.running_task.num_finished()

    def _sync_running_task(self):
        if not self.running_task:
            return

        global global_result_future
        try:
            if global_result_future:
                global_result_future.result()
        except Exception as e:
            log.warning(f"task running failed: {e}", exc_info=True)
        finally:
            global_result_future = None
            self.running_task = None

    def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> None:
        try:
            if not running_task:
                return

            c_results = []
            latest_runner, cached_load_duration = None, None
            for idx, runner in enumerate(running_task.case_runners):
                case_res = CaseResult(
                    result_id=idx,
                    metrics=Metric(),
                    task_config=runner.config,
                )

                drop_old = False if latest_runner and runner == latest_runner else config.DROP_OLD
                try:
                    log.info(f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}")
                    case_res.metrics = runner.run(drop_old)
                    log.info(f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, "
                             f"result={case_res.metrics}, label={case_res.label}")

                    # cache the latest succeeded runner
                    latest_runner = runner

                    # cache the latest drop_old=True load_duration of the latest succeeded runner
                    cached_load_duration = case_res.metrics.load_duration if drop_old else cached_load_duration

                    # use the cached load duration if this case didn't drop the existing collection
                    if not drop_old:
                        case_res.metrics.load_duration = cached_load_duration if cached_load_duration else 0.0
                except (LoadTimeoutError, PerformanceTimeoutError) as e:
                    log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
                    case_res.label = ResultLabel.OUTOFRANGE
                    continue

                except Exception as e:
                    log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
                    traceback.print_exc()
                    case_res.label = ResultLabel.FAILED
                    continue

                finally:
                    c_results.append(case_res)
                    send_conn.send((SIGNAL.WIP, idx))


            test_result = TestResult(
                run_id=running_task.run_id,
                task_label=running_task.task_label,
                results=c_results,
            )
            test_result.display()
            test_result.write_file()

            send_conn.send((SIGNAL.SUCCESS, None))
            send_conn.close()
            log.info(f"Succes to finish task: label={running_task.task_label}, run_id={running_task.run_id}")

        except Exception as e:
            err_msg = f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}"
            traceback.print_exc()
            log.warning(err_msg)
            send_conn.send((SIGNAL.ERROR, err_msg))
            send_conn.close()
            return

    def _clear_running_task(self):
        global global_result_future
        global_result_future = None

        if self.running_task:
            log.info(f"will force stop running task: {self.running_task.run_id}")
            for r in self.running_task.case_runners:
                r.stop()

            self.kill_proc_tree(timeout=5)
            self.running_task = None

        if self.receive_conn:
            self.receive_conn.close()
            self.receive_conn = None


    def _run_async(self, conn: Connection) -> bool:
        log.info(f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, case number: {len(self.running_task.case_runners)}")
        global global_result_future
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("spawn"))
        global_result_future = executor.submit(self._async_task_v2, self.running_task, conn)

        return True

    def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None):
        """Kill a process tree (including grandchildren) with signal
        "sig" and return a (gone, still_alive) tuple.
        "on_terminate", if specified, is a callback function which is
        called as soon as a child terminates.
        """
        children = psutil.Process().children(recursive=True)
        for p in children:
            try:
                log.warning(f"sending SIGTERM to child process: {p}")
                p.send_signal(sig)
            except psutil.NoSuchProcess:
                pass
        gone, alive = psutil.wait_procs(children, timeout=timeout,
                                        callback=on_terminate)

        for p in alive:
            log.warning(f"force killing child process: {p}")
            p.kill()


benchMarkRunner = BenchMarkRunner()
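A minimal usage sketch of the module-level benchMarkRunner singleton defined above, assuming a list of TaskConfig objects has already been assembled elsewhere (the frontend's run_test/generateTasks component builds such a list); the helper name run_and_wait, the polling loop, and the one-second sleep are illustrative, not part of the package.

import time
from vectordb_bench.interface import benchMarkRunner

def run_and_wait(tasks, label=None):
    """Submit tasks, poll progress, and return collected results (hypothetical helper)."""
    if not benchMarkRunner.run(tasks, task_label=label):
        raise RuntimeError(f"submission rejected: {benchMarkRunner.latest_error}")

    total = benchMarkRunner.get_tasks_count()
    while benchMarkRunner.has_running():
        # get_current_task_id() reports the number of finished cases (-1 when idle)
        print(f"progress: {benchMarkRunner.get_current_task_id()}/{total}")
        time.sleep(1)

    if benchMarkRunner.latest_error:
        raise RuntimeError(benchMarkRunner.latest_error)

    # one TestResult per run found under config.RESULTS_LOCAL_DIR
    return benchMarkRunner.get_results()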
vectordb_bench/log_util.py
ADDED
@@ -0,0 +1,103 @@
import logging
from logging import config


def init(log_level):
    LOGGING = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'default': {
                'format': '%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)',
            },
            'colorful_console': {
                'format': '%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)',
                '()': ColorfulFormatter,
            },
        },
        'handlers': {
            'console': {
                'class': 'logging.StreamHandler',
                'formatter': 'colorful_console',
            },
            'no_color_console': {
                'class': 'logging.StreamHandler',
                'formatter': 'default',
            },
        },
        'loggers': {
            'vectordb_bench': {
                'handlers': ['console'],
                'level': log_level,
                'propagate': False
            },
            'no_color': {
                'handlers': ['no_color_console'],
                'level': log_level,
                'propagate': False
            },
        },
        'propagate': False,
    }

    config.dictConfig(LOGGING)

class colors:
    HEADER= '\033[95m'
    INFO= '\033[92m'
    DEBUG= '\033[94m'
    WARNING= '\033[93m'
    ERROR= '\033[95m'
    CRITICAL= '\033[91m'
    ENDC= '\033[0m'



COLORS = {
    'INFO': colors.INFO,
    'INFOM': colors.INFO,
    'DEBUG': colors.DEBUG,
    'DEBUGM': colors.DEBUG,
    'WARNING': colors.WARNING,
    'WARNINGM': colors.WARNING,
    'CRITICAL': colors.CRITICAL,
    'CRITICALM': colors.CRITICAL,
    'ERROR': colors.ERROR,
    'ERRORM': colors.ERROR,
    'ENDC': colors.ENDC,
}


class ColorFulFormatColMixin:
    def format_col(self, message_str, level_name):
        if level_name in COLORS.keys():
            message_str = COLORS[level_name] + message_str + COLORS['ENDC']
        return message_str

    def formatTime(self, record, datefmt=None):
        ret = super().formatTime(record, datefmt)
        return ret


class ColorfulLogRecordProxy(logging.LogRecord):
    def __init__(self, record):
        self._record = record
        msg_level = record.levelname + 'M'
        self.msg = f"{COLORS[msg_level]}{record.msg}{COLORS['ENDC']}"
        self.filename = record.filename
        self.lineno = f'{record.lineno}'
        self.process = f'{record.process}'
        self.levelname = f"{COLORS[record.levelname]}{record.levelname}{COLORS['ENDC']}"

    def __getattr__(self, attr):
        if attr not in self.__dict__:
            return getattr(self._record, attr)
        return getattr(self, attr)


class ColorfulFormatter(ColorFulFormatColMixin, logging.Formatter):
    def format(self, record):
        proxy = ColorfulLogRecordProxy(record)
        message_str = super().format(proxy)

        return message_str
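A short sketch of how init from log_util.py might be wired up, assuming it is called once at startup before any logging happens; the example messages are illustrative. The 'vectordb_bench' logger goes through ColorfulFormatter (ANSI-colored console output), while the 'no_color' logger uses the plain 'default' formatter, which is why TestResult.display() in models.py below emits its table via logging.getLogger("no_color").

import logging
from vectordb_bench import log_util

# Configure both console handlers; the level is a standard logging level name.
log_util.init("INFO")

logging.getLogger("vectordb_bench").info("colored output via ColorfulFormatter")
logging.getLogger("no_color").info("plain output via the 'default' formatter")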
vectordb_bench/metric.py
ADDED
@@ -0,0 +1,53 @@
import logging
import numpy as np

from dataclasses import dataclass


log = logging.getLogger(__name__)


@dataclass
class Metric:
    """result metrics"""

    # for load cases
    max_load_count: int = 0

    # for performance cases
    load_duration: float = 0.0 # duration to load all dataset into DB
    qps: float = 0.0
    serial_latency_p99: float = 0.0
    recall: float = 0.0

metricUnitMap = {
    'load_duration': 's',
    'serial_latency_p99': 'ms',
    'max_load_count': 'K'
}

lowerIsBetterMetricList = [
    "load_duration",
    "serial_latency_p99",
]

metricOrder = [
    "qps",
    "recall",
    "load_duration",
    "serial_latency_p99",
    "max_load_count",
]


def isLowerIsBetterMetric(metric: str) -> bool:
    return metric in lowerIsBetterMetricList


def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float:
    recalls = np.zeros(count)
    for i, result in enumerate(got):
        if result in ground_truth:
            recalls[i] = 1

    return np.mean(recalls)
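A worked example of calc_recall with hypothetical ids: a returned id counts as a hit if it appears anywhere in the ground-truth list, so 4 hits out of a top-5 result give a recall of 0.8.

from vectordb_bench.metric import Metric, calc_recall, isLowerIsBetterMetric

ground_truth = [7, 3, 9, 42, 15]   # hypothetical true top-5 neighbors
got = [3, 7, 15, 42, 99]           # hypothetical ANN result; 99 is a miss

print(calc_recall(5, ground_truth, got))            # 0.8
print(isLowerIsBetterMetric("serial_latency_p99"))  # True

# Metric is a plain dataclass, so a result can be assembled field by field.
m = Metric(qps=1250.0, recall=0.8, serial_latency_p99=0.012, load_duration=321.5)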
vectordb_bench/models.py
ADDED
@@ -0,0 +1,234 @@
import logging
import pathlib
from datetime import date
from typing import Self
from enum import Enum

import ujson

from .backend.clients import (
    DB,
    DBConfig,
    DBCaseConfig,
    IndexType,
)
from .base import BaseModel
from . import config
from .metric import Metric


log = logging.getLogger(__name__)


class LoadTimeoutError(TimeoutError):
    pass

class PerformanceTimeoutError(TimeoutError):
    pass


class CaseType(Enum):
    """
    Value will be displayed in UI
    """

    CapacitySDim = "Capacity Test (Large-dim)"
    CapacityLDim = "Capacity Test (Small-dim)"

    Performance100M = "Search Performance Test (XLarge Dataset)"
    PerformanceLZero = "Search Performance Test (Large Dataset)"
    PerformanceMZero = "Search Performance Test (Medium Dataset)"
    PerformanceSZero = "Search Performance Test (Small Dataset)"

    PerformanceLLow = (
        "Filtering Search Performance Test (Large Dataset, Low Filtering Rate)"
    )
    PerformanceMLow = (
        "Filtering Search Performance Test (Medium Dataset, Low Filtering Rate)"
    )
    PerformanceSLow = (
        "Filtering Search Performance Test (Small Dataset, Low Filtering Rate)"
    )
    PerformanceLHigh = (
        "Filtering Search Performance Test (Large Dataset, High Filtering Rate)"
    )
    PerformanceMHigh = (
        "Filtering Search Performance Test (Medium Dataset, High Filtering Rate)"
    )
    PerformanceSHigh = (
        "Filtering Search Performance Test (Small Dataset, High Filtering Rate)"
    )


class CaseConfigParamType(Enum):
    """
    Value will be the key of CaseConfig.params and displayed in UI
    """

    IndexType = "IndexType"
    M = "M"
    EFConstruction = "efConstruction"
    EF = "ef"
    SearchList = "search_list"
    Nlist = "nlist"
    Nprobe = "nprobe"
    MaxConnections = "maxConnections"
    numCandidates = "num_candidates"


class CustomizedCase(BaseModel):
    pass


class CaseConfig(BaseModel):
    """cases, dataset, test cases, filter rate, params"""

    case_id: CaseType
    custom_case: dict | None = None


class TaskConfig(BaseModel):
    db: DB
    db_config: DBConfig
    db_case_config: DBCaseConfig
    case_config: CaseConfig

    @property
    def db_name(self):
        db = self.db.value
        db_label = self.db_config.db_label
        return f"{db}-{db_label}" if db_label else db


class ResultLabel(Enum):
    NORMAL = ":)"
    FAILED = "x"
    OUTOFRANGE = "?"


class CaseResult(BaseModel):
    metrics: Metric
    task_config: TaskConfig
    label: ResultLabel = ResultLabel.NORMAL


class TestResult(BaseModel):
    """ROOT/result_{date.today()}_{task_label}.json"""

    run_id: str
    task_label: str
    results: list[CaseResult]

    def write_file(self):
        result_dir = config.RESULTS_LOCAL_DIR
        if not result_dir.exists():
            log.info(f"local result directory not exist, creating it: {result_dir}")
            result_dir.mkdir(parents=True)

        file_name = f'result_{date.today().strftime("%Y%m%d")}_{self.task_label}.json'
        result_file = result_dir.joinpath(file_name)
        if result_file.exists():
            log.warning(
                f"Replacing existing result with the same file_name: {result_file}"
            )

        log.info(f"write results to disk {result_file}")
        with open(result_file, "w") as f:
            b = self.json(exclude={"db_config": {"password", "api_key"}})
            f.write(b)

    @classmethod
    def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self:
        if not full_path.exists():
            raise ValueError(f"No such file: {full_path}")

        with open(full_path) as f:
            test_result = ujson.loads(f.read())
            if "task_label" not in test_result:
                test_result["task_label"] = test_result["run_id"]

            for case_result in test_result["results"]:
                task_config = case_result.get("task_config")
                db = DB(task_config.get("db"))
                dbcls = db.init_cls
                task_config["db_config"] = dbcls.config_cls()(
                    **task_config["db_config"]
                )
                task_config["db_case_config"] = dbcls.case_config_cls(
                    index_type=task_config["db_case_config"].get("index", None),
                )(**task_config["db_case_config"])

                case_result["task_config"] = task_config

                if trans_unit:
                    cur_max_count = case_result["metrics"]["max_load_count"]
                    case_result["metrics"]["max_load_count"] = (
                        cur_max_count / 1000
                        if int(cur_max_count) > 0
                        else cur_max_count
                    )

                    cur_latency = case_result["metrics"]["serial_latency_p99"]
                    case_result["metrics"]["serial_latency_p99"] = (
                        cur_latency * 1000 if cur_latency > 0 else cur_latency
                    )
            c = TestResult.validate(test_result)

            return c

    def display(self, dbs: list[DB] | None = None):
        filter_list = dbs if dbs and isinstance(dbs, list) else None
        sorted_results = sorted(self.results, key=lambda x: (
            x.task_config.db.name,
            x.task_config.db_config.db_label,
            x.task_config.case_config.case_id.name,
        ), reverse=True)

        filtered_results = [r for r in sorted_results if not filter_list or r.task_config.db not in filter_list]

        def append_return(x, y):
            x.append(y)
            return x

        max_db = max(map(len, [f.task_config.db.name for f in filtered_results]))
        max_db_labels = max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) + 3
        max_case = max(map(len, [f.task_config.case_config.case_id.name for f in filtered_results]))
        max_load_dur = max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
        max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3
        max_recall = max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3

        max_db_labels = 8 if max_db_labels == 0 else max_db_labels
        max_load_dur = 11 if max_load_dur == 0 else max_load_dur + 3
        max_qps = 10 if max_qps == 0 else max_load_dur + 3
        max_recall = 13 if max_recall == 0 else max_recall + 3

        LENGTH = (max_db, max_db_labels, max_case, len(self.task_label), max_load_dur, max_qps, 15, max_recall, 14)

        DATA_FORMAT = (
            f"%-{max_db}s | %-{max_db_labels}s %-{max_case}s %-{len(self.task_label)}s "
            f"| %-{max_load_dur}s %-{max_qps}s %-15s %-{max_recall}s %-14s"
        )

        TITLE = DATA_FORMAT % (
            "DB", "db_label", "case", "label", "load_dur", "qps", "latency(p99)", "recall", "max_load_count")
        SPLIT = DATA_FORMAT%tuple(map(lambda x:"-"*x, LENGTH))
        SUMMERY_FORMAT = ("Task summery: run_id=%s, task_label=%s") % (self.run_id[:5], self.task_label)
        fmt = [SUMMERY_FORMAT, TITLE, SPLIT]


        for f in filtered_results:
            fmt.append(DATA_FORMAT%(
                f.task_config.db.name,
                f.task_config.db_config.db_label,
                f.task_config.case_config.case_id.name,
                self.task_label,
                f.metrics.load_duration,
                f.metrics.qps,
                f.metrics.serial_latency_p99,
                f.metrics.recall,
                f.metrics.max_load_count,
            ))

        tmp_logger = logging.getLogger("no_color")
        for f in fmt:
            tmp_logger.info(f)