vectordb-bench 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +56 -46
  5. vectordb_bench/backend/clients/__init__.py +101 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +26 -0
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +18 -0
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +345 -0
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +47 -0
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +52 -35
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +8 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +62 -80
  26. vectordb_bench/backend/clients/milvus/config.py +31 -7
  27. vectordb_bench/backend/clients/milvus/milvus.py +23 -26
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +51 -23
  62. vectordb_bench/backend/runner/read_write_runner.py +140 -46
  63. vectordb_bench/backend/runner/serial_runner.py +99 -50
  64. vectordb_bench/backend/runner/util.py +4 -19
  65. vectordb_bench/backend/task_runner.py +95 -74
  66. vectordb_bench/backend/utils.py +17 -9
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +108 -83
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +34 -42
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.18.dist-info/RECORD +0 -131
  103. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.18.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0
vectordb_bench/interface.py CHANGED
@@ -5,6 +5,7 @@ import pathlib
 import signal
 import traceback
 import uuid
+from collections.abc import Callable
 from enum import Enum
 from multiprocessing.connection import Connection
 
@@ -16,8 +17,15 @@ from .backend.data_source import DatasetSource
 from .backend.result_collector import ResultCollector
 from .backend.task_runner import TaskRunner
 from .metric import Metric
-from .models import (CaseResult, LoadTimeoutError, PerformanceTimeoutError,
-                     ResultLabel, TaskConfig, TaskStage, TestResult)
+from .models import (
+    CaseResult,
+    LoadTimeoutError,
+    PerformanceTimeoutError,
+    ResultLabel,
+    TaskConfig,
+    TaskStage,
+    TestResult,
+)
 
 log = logging.getLogger(__name__)
 
@@ -37,11 +45,9 @@ class BenchMarkRunner:
         self.drop_old: bool = True
         self.dataset_source: DatasetSource = DatasetSource.S3
 
-
     def set_drop_old(self, drop_old: bool):
         self.drop_old = drop_old
 
-
     def set_download_address(self, use_aliyun: bool):
         if use_aliyun:
             self.dataset_source = DatasetSource.AliyunOSS
@@ -59,7 +65,9 @@ class BenchMarkRunner:
             log.warning("Empty tasks submitted")
             return False
 
-        log.debug(f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}")
+        log.debug(
+            f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}",
+        )
 
         # Generate run_id
         run_id = uuid.uuid4().hex
@@ -70,7 +78,12 @@ class BenchMarkRunner:
         self.latest_error = ""
 
         try:
-            self.running_task = Assembler.assemble_all(run_id, task_label, tasks, self.dataset_source)
+            self.running_task = Assembler.assemble_all(
+                run_id,
+                task_label,
+                tasks,
+                self.dataset_source,
+            )
             self.running_task.display()
         except ModuleNotFoundError as e:
             msg = f"Please install client for database, error={e}"
@@ -119,7 +132,7 @@ class BenchMarkRunner:
         return 0
 
     def get_current_task_id(self) -> int:
-        """ the index of current running task
+        """the index of current running task
        return -1 if not running
        """
        if not self.running_task:
@@ -153,18 +166,18 @@ class BenchMarkRunner:
                 task_config=runner.config,
             )
 
-            # drop_old = False if latest_runner and runner == latest_runner else config.DROP_OLD
-            # drop_old = config.DROP_OLD
             drop_old = TaskStage.DROP_OLD in runner.config.stages
-            if latest_runner and runner == latest_runner:
-                drop_old = False
-            elif not self.drop_old:
+            if (latest_runner and runner == latest_runner) or not self.drop_old:
                 drop_old = False
             try:
-                log.info(f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}")
+                log.info(
+                    f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}",
+                )
                 case_res.metrics = runner.run(drop_old)
-                log.info(f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, "
-                         f"result={case_res.metrics}, label={case_res.label}")
+                log.info(
+                    f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, "
+                    f"result={case_res.metrics}, label={case_res.label}",
+                )
 
                 # cache the latest succeeded runner
                 latest_runner = runner
@@ -176,12 +189,16 @@ class BenchMarkRunner:
                 if not drop_old:
                     case_res.metrics.load_duration = cached_load_duration if cached_load_duration else 0.0
             except (LoadTimeoutError, PerformanceTimeoutError) as e:
-                log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
+                log.warning(
+                    f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}",
+                )
                 case_res.label = ResultLabel.OUTOFRANGE
                 continue
 
             except Exception as e:
-                log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
+                log.warning(
+                    f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}",
+                )
                 traceback.print_exc()
                 case_res.label = ResultLabel.FAILED
                 continue
@@ -200,10 +217,14 @@ class BenchMarkRunner:
 
             send_conn.send((SIGNAL.SUCCESS, None))
             send_conn.close()
-            log.info(f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}")
+            log.info(
+                f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}",
+            )
 
         except Exception as e:
-            err_msg = f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}"
+            err_msg = (
+                f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}"
+            )
             traceback.print_exc()
             log.warning(err_msg)
             send_conn.send((SIGNAL.ERROR, err_msg))
@@ -226,16 +247,26 @@ class BenchMarkRunner:
         self.receive_conn.close()
         self.receive_conn = None
 
-
    def _run_async(self, conn: Connection) -> bool:
-        log.info(f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, case number: {len(self.running_task.case_runners)}")
+        log.info(
+            f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, ",
+            f"case number: {len(self.running_task.case_runners)}",
+        )
        global global_result_future
-        executor = concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("spawn"))
+        executor = concurrent.futures.ProcessPoolExecutor(
+            max_workers=1,
+            mp_context=mp.get_context("spawn"),
+        )
        global_result_future = executor.submit(self._async_task_v2, self.running_task, conn)
 
        return True
 
-    def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None):
+    def kill_proc_tree(
+        self,
+        sig: int = signal.SIGTERM,
+        timeout: float | None = None,
+        on_terminate: Callable | None = None,
+    ):
        """Kill a process tree (including grandchildren) with signal
        "sig" and return a (gone, still_alive) tuple.
        "on_terminate", if specified, is a callback function which is
@@ -248,12 +279,11 @@ class BenchMarkRunner:
                 p.send_signal(sig)
             except psutil.NoSuchProcess:
                 pass
-        gone, alive = psutil.wait_procs(children, timeout=timeout,
-                                        callback=on_terminate)
+        gone, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate)
 
         for p in alive:
             log.warning(f"force killing child process: {p}")
             p.kill()
 
 
-benchMarkRunner = BenchMarkRunner()
+benchmark_runner = BenchMarkRunner()
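
The reworked kill_proc_tree above follows psutil's documented recipe for terminating a process tree. A standalone sketch of the same pattern (a hypothetical helper: the explicit pid parameter and the omitted on_terminate callback are simplifications of the method above):

    import signal

    import psutil


    def kill_proc_tree(pid: int, sig: int = signal.SIGTERM, timeout: float | None = None) -> None:
        # Collect the whole tree up front, grandchildren included.
        children = psutil.Process(pid).children(recursive=True)
        for p in children:
            try:
                p.send_signal(sig)  # polite request first (SIGTERM by default)
            except psutil.NoSuchProcess:
                pass  # the process exited between listing and signaling
        # wait_procs() returns a (gone, alive) tuple after `timeout` seconds.
        gone, alive = psutil.wait_procs(children, timeout=timeout)
        for p in alive:
            p.kill()  # escalate to SIGKILL for anything that ignored the signal
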
vectordb_bench/log_util.py CHANGED
@@ -1,102 +1,97 @@
 import logging
 from logging import config
 
-def init(log_level):
-    LOGGING = {
-        'version': 1,
-        'disable_existing_loggers': False,
-        'formatters': {
-            'default': {
-                'format': '%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)',
+
+def init(log_level: str):
+    log_config = {
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "default": {
+                "format": "%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)",
             },
-            'colorful_console': {
-                'format': '%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)',
-                '()': ColorfulFormatter,
+            "colorful_console": {
+                "format": "%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)",
+                "()": ColorfulFormatter,
             },
         },
-        'handlers': {
-            'console': {
-                'class': 'logging.StreamHandler',
-                'formatter': 'colorful_console',
+        "handlers": {
+            "console": {
+                "class": "logging.StreamHandler",
+                "formatter": "colorful_console",
             },
-            'no_color_console': {
-                'class': 'logging.StreamHandler',
-                'formatter': 'default',
+            "no_color_console": {
+                "class": "logging.StreamHandler",
+                "formatter": "default",
             },
         },
-        'loggers': {
-            'vectordb_bench': {
-                'handlers': ['console'],
-                'level': log_level,
-                'propagate': False
+        "loggers": {
+            "vectordb_bench": {
+                "handlers": ["console"],
+                "level": log_level,
+                "propagate": False,
            },
-            'no_color': {
-                'handlers': ['no_color_console'],
-                'level': log_level,
-                'propagate': False
+            "no_color": {
+                "handlers": ["no_color_console"],
+                "level": log_level,
+                "propagate": False,
            },
        },
-        'propagate': False,
+        "propagate": False,
    }
 
-    config.dictConfig(LOGGING)
+    config.dictConfig(log_config)
 
-class colors:
-    HEADER= '\033[95m'
-    INFO= '\033[92m'
-    DEBUG= '\033[94m'
-    WARNING= '\033[93m'
-    ERROR= '\033[95m'
-    CRITICAL= '\033[91m'
-    ENDC= '\033[0m'
 
+class colors:
+    HEADER = "\033[95m"
+    INFO = "\033[92m"
+    DEBUG = "\033[94m"
+    WARNING = "\033[93m"
+    ERROR = "\033[95m"
+    CRITICAL = "\033[91m"
+    ENDC = "\033[0m"
 
 
 COLORS = {
-    'INFO': colors.INFO,
-    'INFOM': colors.INFO,
-    'DEBUG': colors.DEBUG,
-    'DEBUGM': colors.DEBUG,
-    'WARNING': colors.WARNING,
-    'WARNINGM': colors.WARNING,
-    'CRITICAL': colors.CRITICAL,
-    'CRITICALM': colors.CRITICAL,
-    'ERROR': colors.ERROR,
-    'ERRORM': colors.ERROR,
-    'ENDC': colors.ENDC,
+    "INFO": colors.INFO,
+    "INFOM": colors.INFO,
+    "DEBUG": colors.DEBUG,
+    "DEBUGM": colors.DEBUG,
+    "WARNING": colors.WARNING,
+    "WARNINGM": colors.WARNING,
+    "CRITICAL": colors.CRITICAL,
+    "CRITICALM": colors.CRITICAL,
+    "ERROR": colors.ERROR,
+    "ERRORM": colors.ERROR,
+    "ENDC": colors.ENDC,
 }
 
 
 class ColorFulFormatColMixin:
-    def format_col(self, message_str, level_name):
-        if level_name in COLORS.keys():
-            message_str = COLORS[level_name] + message_str + COLORS['ENDC']
-        return message_str
-
-    def formatTime(self, record, datefmt=None):
-        ret = super().formatTime(record, datefmt)
-        return ret
+    def format_col(self, message: str, level_name: str):
+        if level_name in COLORS:
+            message = COLORS[level_name] + message + COLORS["ENDC"]
+        return message
 
 
 class ColorfulLogRecordProxy(logging.LogRecord):
-    def __init__(self, record):
+    def __init__(self, record: any):
         self._record = record
-        msg_level = record.levelname + 'M'
+        msg_level = record.levelname + "M"
         self.msg = f"{COLORS[msg_level]}{record.msg}{COLORS['ENDC']}"
         self.filename = record.filename
-        self.lineno = f'{record.lineno}'
-        self.process = f'{record.process}'
+        self.lineno = f"{record.lineno}"
+        self.process = f"{record.process}"
         self.levelname = f"{COLORS[record.levelname]}{record.levelname}{COLORS['ENDC']}"
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: any):
         if attr not in self.__dict__:
             return getattr(self._record, attr)
         return getattr(self, attr)
 
 
 class ColorfulFormatter(ColorFulFormatColMixin, logging.Formatter):
-    def format(self, record):
+    def format(self, record: any):
         proxy = ColorfulLogRecordProxy(record)
-        message_str = super().format(proxy)
-
-        return message_str
+        return super().format(proxy)
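
The rewritten config keeps two parallel loggers: "vectordb_bench" routes through the ANSI-colorizing ColorfulFormatter, while "no_color" uses the plain formatter. A minimal usage sketch, assuming the module is importable as vectordb_bench.log_util:

    import logging

    from vectordb_bench import log_util

    log_util.init(log_level="INFO")

    color_log = logging.getLogger("vectordb_bench")  # colorful_console handler
    plain_log = logging.getLogger("no_color")        # default, uncolored handler

    color_log.info("wrapped in ANSI color escapes")
    plain_log.info("plain text, safe for files and pipes")
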
vectordb_bench/metric.py CHANGED
@@ -1,8 +1,7 @@
 import logging
-import numpy as np
-
 from dataclasses import dataclass, field
 
+import numpy as np
 
 log = logging.getLogger(__name__)
 
@@ -33,19 +32,19 @@ MAX_LOAD_COUNT_METRIC = "max_load_count"
 QPS_METRIC = "qps"
 RECALL_METRIC = "recall"
 
-metricUnitMap = {
+metric_unit_map = {
     LOAD_DURATION_METRIC: "s",
     SERIAL_LATENCY_P99_METRIC: "ms",
     MAX_LOAD_COUNT_METRIC: "K",
     QURIES_PER_DOLLAR_METRIC: "K",
 }
 
-lowerIsBetterMetricList = [
+lower_is_better_metrics = [
     LOAD_DURATION_METRIC,
     SERIAL_LATENCY_P99_METRIC,
 ]
 
-metricOrder = [
+metric_order = [
     QPS_METRIC,
     RECALL_METRIC,
     LOAD_DURATION_METRIC,
@@ -55,7 +54,7 @@ metricOrder = [
 
 
 def isLowerIsBetterMetric(metric: str) -> bool:
-    return metric in lowerIsBetterMetricList
+    return metric in lower_is_better_metrics
 
 
 def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float:
@@ -70,7 +69,7 @@ def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float:
 def get_ideal_dcg(k: int):
     ideal_dcg = 0
     for i in range(k):
-        ideal_dcg += 1 / np.log2(i+2)
+        ideal_dcg += 1 / np.log2(i + 2)
 
     return ideal_dcg
 
@@ -78,8 +77,8 @@ def get_ideal_dcg(k: int):
 def calc_ndcg(ground_truth: list[int], got: list[int], ideal_dcg: float) -> float:
     dcg = 0
     ground_truth = list(ground_truth)
-    for id in set(got):
-        if id in ground_truth:
-            idx = ground_truth.index(id)
-            dcg += 1 / np.log2(idx+2)
+    for got_id in set(got):
+        if got_id in ground_truth:
+            idx = ground_truth.index(got_id)
+            dcg += 1 / np.log2(idx + 2)
     return dcg / ideal_dcg
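
calc_ndcg discounts each retrieved id by its rank in the ground-truth list and normalizes by the ideal DCG from get_ideal_dcg. A small worked example with made-up ids and k = 3:

    import numpy as np

    # Ideal DCG for k=3: 1/log2(2) + 1/log2(3) + 1/log2(4) ≈ 2.1309
    ideal_dcg = sum(1 / np.log2(i + 2) for i in range(3))

    ground_truth = [7, 3, 9]
    got = [3, 9, 42]  # 42 is a miss; 3 and 9 sit at ranks 1 and 2 of ground_truth

    # DCG: 1/log2(3) + 1/log2(4) ≈ 0.6309 + 0.5 = 1.1309
    dcg = sum(1 / np.log2(ground_truth.index(g) + 2) for g in set(got) if g in ground_truth)

    print(dcg / ideal_dcg)  # ≈ 0.53
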
vectordb_bench/models.py CHANGED
@@ -2,29 +2,31 @@ import logging
 import pathlib
 from datetime import date, datetime
 from enum import Enum, StrEnum, auto
-from typing import List, Self
+from typing import Self
 
 import ujson
 
+from . import config
+from .backend.cases import CaseType
 from .backend.clients import (
     DB,
-    DBConfig,
     DBCaseConfig,
+    DBConfig,
 )
-from .backend.cases import CaseType
 from .base import BaseModel
-from . import config
 from .metric import Metric
 
 log = logging.getLogger(__name__)
 
 
 class LoadTimeoutError(TimeoutError):
-    pass
+    def __init__(self, duration: int):
+        super().__init__(f"capacity case load timeout in {duration}s")
 
 
 class PerformanceTimeoutError(TimeoutError):
-    pass
+    def __init__(self):
+        super().__init__("Performance case optimize timeout")
 
 
 class CaseConfigParamType(Enum):
@@ -92,7 +94,7 @@ class CustomizedCase(BaseModel):
 
 
 class ConcurrencySearchConfig(BaseModel):
-    num_concurrency: List[int] = config.NUM_CONCURRENCY
+    num_concurrency: list[int] = config.NUM_CONCURRENCY
     concurrency_duration: int = config.CONCURRENCY_DURATION
 
 
@@ -146,7 +148,7 @@ class TaskConfig(BaseModel):
     db_config: DBConfig
     db_case_config: DBCaseConfig
     case_config: CaseConfig
-    stages: List[TaskStage] = ALL_TASK_STAGES
+    stages: list[TaskStage] = ALL_TASK_STAGES
 
     @property
     def db_name(self):
@@ -210,26 +212,23 @@ class TestResult(BaseModel):
             log.info(f"local result directory not exist, creating it: {result_dir}")
             result_dir.mkdir(parents=True)
 
-        file_name = self.file_fmt.format(
-            date.today().strftime("%Y%m%d"), partial.task_label, db
-        )
+        file_name = self.file_fmt.format(date.today().strftime("%Y%m%d"), partial.task_label, db)
         result_file = result_dir.joinpath(file_name)
         if result_file.exists():
-            log.warning(
-                f"Replacing existing result with the same file_name: {result_file}"
-            )
+            log.warning(f"Replacing existing result with the same file_name: {result_file}")
 
         log.info(f"write results to disk {result_file}")
-        with open(result_file, "w") as f:
+        with pathlib.Path(result_file).open("w") as f:
             b = partial.json(exclude={"db_config": {"password", "api_key"}})
             f.write(b)
 
     @classmethod
     def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self:
         if not full_path.exists():
-            raise ValueError(f"No such file: {full_path}")
+            msg = f"No such file: {full_path}"
+            raise ValueError(msg)
 
-        with open(full_path) as f:
+        with pathlib.Path(full_path).open("r") as f:
             test_result = ujson.loads(f.read())
             if "task_label" not in test_result:
                 test_result["task_label"] = test_result["run_id"]
@@ -248,19 +247,16 @@ class TestResult(BaseModel):
             if trans_unit:
                 cur_max_count = case_result["metrics"]["max_load_count"]
                 case_result["metrics"]["max_load_count"] = (
-                    cur_max_count / 1000
-                    if int(cur_max_count) > 0
-                    else cur_max_count
+                    cur_max_count / 1000 if int(cur_max_count) > 0 else cur_max_count
                 )
 
                 cur_latency = case_result["metrics"]["serial_latency_p99"]
                 case_result["metrics"]["serial_latency_p99"] = (
                     cur_latency * 1000 if cur_latency > 0 else cur_latency
                 )
-        c = TestResult.validate(test_result)
-
-        return c
+        return TestResult.validate(test_result)
 
+    # ruff: noqa
     def display(self, dbs: list[DB] | None = None):
         filter_list = dbs if dbs and isinstance(dbs, list) else None
         sorted_results = sorted(
@@ -273,31 +269,18 @@ class TestResult(BaseModel):
             reverse=True,
         )
 
-        filtered_results = [
-            r
-            for r in sorted_results
-            if not filter_list or r.task_config.db not in filter_list
-        ]
+        filtered_results = [r for r in sorted_results if not filter_list or r.task_config.db not in filter_list]
 
-        def append_return(x, y):
+        def append_return(x: any, y: any):
             x.append(y)
             return x
 
         max_db = max(map(len, [f.task_config.db.name for f in filtered_results]))
-        max_db_labels = (
-            max(map(len, [f.task_config.db_config.db_label for f in filtered_results]))
-            + 3
-        )
-        max_case = max(
-            map(len, [f.task_config.case_config.case_id.name for f in filtered_results])
-        )
-        max_load_dur = (
-            max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
-        )
+        max_db_labels = max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) + 3
+        max_case = max(map(len, [f.task_config.case_config.case_id.name for f in filtered_results]))
+        max_load_dur = max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
         max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3
-        max_recall = (
-            max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
-        )
+        max_recall = max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
 
         max_db_labels = 8 if max_db_labels < 8 else max_db_labels
         max_load_dur = 11 if max_load_dur < 11 else max_load_dur
@@ -356,7 +339,7 @@ class TestResult(BaseModel):
                     f.metrics.recall,
                     f.metrics.max_load_count,
                     f.label.value,
-                )
+                ),
             )
 
         tmp_logger = logging.getLogger("no_color")
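
With the messages baked into the exception classes, call sites now raise them with only the relevant context. A short sketch of the new behavior (the 3600 is an illustrative duration, not a project default):

    from vectordb_bench.models import LoadTimeoutError, PerformanceTimeoutError

    try:
        raise LoadTimeoutError(3600)
    except TimeoutError as e:
        print(e)  # capacity case load timeout in 3600s

    try:
        raise PerformanceTimeoutError()
    except TimeoutError as e:
        print(e)  # Performance case optimize timeout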