vectordb-bench 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +85 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +112 -77
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +59 -84
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +13 -24
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +50 -54
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +39 -40
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +24 -26
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +67 -58
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +19 -39
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +56 -77
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +34 -43
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +98 -104
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +39 -49
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +15 -25
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +41 -35
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +95 -62
  47. vectordb_bench/backend/clients/test/cli.py +2 -3
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +5 -9
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +37 -26
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +18 -14
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +56 -23
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +46 -22
  63. vectordb_bench/backend/runner/serial_runner.py +81 -46
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -92
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +45 -24
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.21.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.21.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/data_source.py
@@ -1,12 +1,12 @@
  import logging
  import pathlib
  import typing
+ from abc import ABC, abstractmethod
  from enum import Enum
+
  from tqdm import tqdm
- import os
- from abc import ABC, abstractmethod

- from .. import config
+ from vectordb_bench import config

  logging.getLogger("s3fs").setLevel(logging.CRITICAL)

@@ -14,6 +14,7 @@ log = logging.getLogger(__name__)

  DatasetReader = typing.TypeVar("DatasetReader")

+
  class DatasetSource(Enum):
      S3 = "S3"
      AliyunOSS = "AliyunOSS"
@@ -25,6 +26,8 @@ class DatasetSource(Enum):
          if self == DatasetSource.AliyunOSS:
              return AliyunOSSReader()

+         return None
+

  class DatasetReader(ABC):
      source: DatasetSource
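The hunk above adds an explicit return None fall-through to DatasetSource.reader(). A minimal sketch of that enum-to-reader dispatch, with stub reader classes standing in for the real implementations:

# Minimal sketch of the enum-to-reader dispatch touched above; the reader
# classes here are stand-in stubs, not the real vectordb_bench readers.
from enum import Enum


class AwsS3Reader:
    """Stub standing in for the S3-backed reader."""


class AliyunOSSReader:
    """Stub standing in for the OSS-backed reader."""


class DatasetSource(Enum):
    S3 = "S3"
    AliyunOSS = "AliyunOSS"

    def reader(self):
        # Each member maps to a concrete reader; anything else now hits the
        # explicit `return None` added in 0.0.21 instead of an implicit None.
        if self == DatasetSource.S3:
            return AwsS3Reader()
        if self == DatasetSource.AliyunOSS:
            return AliyunOSSReader()
        return None


print(type(DatasetSource.S3.reader()).__name__)  # AwsS3Reader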
@@ -39,7 +42,6 @@ class DatasetReader(ABC):
              files(list[str]): all filenames of the dataset
              local_ds_root(pathlib.Path): whether to write the remote data.
          """
-         pass

      @abstractmethod
      def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
@@ -52,13 +54,14 @@ class AliyunOSSReader(DatasetReader):

      def __init__(self):
          import oss2
+
          self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

      def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
          info = self.bucket.get_object_meta(remote.as_posix())

          # check size equal
-         remote_size, local_size = info.content_length, os.path.getsize(local)
+         remote_size, local_size = info.content_length, local.stat().st_size
          if remote_size != local_size:
              log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
              return False
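The size check now stays on pathlib: local.stat().st_size replaces os.path.getsize(local), which is why the os import could be dropped. A small sketch of the equivalence, using a throwaway file:

# Both expressions return the file size in bytes; the new code keeps the
# whole check on pathlib.Path and avoids importing os just for this.
import os
import pathlib

local = pathlib.Path("size_check_demo.bin")  # throwaway demo file
local.write_bytes(b"x" * 1024)

assert os.path.getsize(local) == local.stat().st_size == 1024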
@@ -70,7 +73,13 @@ class AliyunOSSReader(DatasetReader):
          if not local_ds_root.exists():
              log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
              local_ds_root.mkdir(parents=True)
-             downloads = [(pathlib.PurePosixPath("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]
+             downloads = [
+                 (
+                     pathlib.PurePosixPath("benchmark", dataset, f),
+                     local_ds_root.joinpath(f),
+                 )
+                 for f in files
+             ]

          else:
              for file in files:
@@ -92,17 +101,14 @@ class AliyunOSSReader(DatasetReader):
          log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")


-
  class AwsS3Reader(DatasetReader):
      source: DatasetSource = DatasetSource.S3
      remote_root: str = config.AWS_S3_URL

      def __init__(self):
          import s3fs
-         self.fs = s3fs.S3FileSystem(
-             anon=True,
-             client_kwargs={'region_name': 'us-west-2'}
-         )
+
+         self.fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-west-2"})

      def ls_all(self, dataset: str):
          dataset_root_dir = pathlib.Path(self.remote_root, dataset)
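The AwsS3Reader constructor collapses to a single anonymous s3fs.S3FileSystem(...) call. A brief usage sketch of that client; the bucket prefix below is a placeholder, not the package's configured remote_root:

# Anonymous, read-only S3 access in the style of AwsS3Reader; the bucket
# prefix is illustrative only.
import s3fs

fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-west-2"})

names = fs.ls("some-public-bucket/benchmark/sift_small_500k")  # list remote files
fs.download(names[0], "/tmp/" + names[0].split("/")[-1])       # fetch one locally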
@@ -112,7 +118,6 @@ class AwsS3Reader(DatasetReader):
              log.info(n)
          return names

-
      def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
          downloads = []
          if not local_ds_root.exists():
@@ -139,13 +144,12 @@

          log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")

-
      def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
          # info() uses ls() inside, maybe we only need to ls once
          info = self.fs.info(remote)

          # check size equal
-         remote_size, local_size = info.get("size"), os.path.getsize(local)
+         remote_size, local_size = info.get("size"), local.stat().st_size
          if remote_size != local_size:
              log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
              return False
vectordb_bench/backend/dataset.py
@@ -4,25 +4,30 @@ Usage:
  >>> Dataset.Cohere.get(100_000)
  """

- from collections import namedtuple
  import logging
  import pathlib
+ import typing
  from enum import Enum
+
  import pandas as pd
- from pydantic import validator, PrivateAttr
  import polars as pl
  from pyarrow.parquet import ParquetFile
+ from pydantic import PrivateAttr, validator
+
+ from vectordb_bench import config
+ from vectordb_bench.base import BaseModel

- from ..base import BaseModel
- from .. import config
- from ..backend.clients import MetricType
  from . import utils
- from .data_source import DatasetSource, DatasetReader
+ from .clients import MetricType
+ from .data_source import DatasetReader, DatasetSource

  log = logging.getLogger(__name__)


- SizeLabel = namedtuple('SizeLabel', ['size', 'label', 'file_count'])
+ class SizeLabel(typing.NamedTuple):
+     size: int
+     label: str
+     file_count: int


  class BaseDataset(BaseModel):
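SizeLabel changes from a collections.namedtuple to a typed typing.NamedTuple; call sites are unchanged because both build the same tuple. A quick sketch of the equivalence:

# Old and new SizeLabel definitions are interchangeable at call sites;
# the NamedTuple form just adds field types for tooling.
import typing
from collections import namedtuple

SizeLabelOld = namedtuple("SizeLabel", ["size", "label", "file_count"])


class SizeLabelNew(typing.NamedTuple):
    size: int
    label: str
    file_count: int


old = SizeLabelOld(100_000, "SMALL", 1)
new = SizeLabelNew(100_000, "SMALL", 1)
assert tuple(old) == tuple(new) and new.label == "SMALL"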
@@ -33,12 +38,13 @@ class BaseDataset(BaseModel):
      use_shuffled: bool
      with_gt: bool = False
      _size_label: dict[int, SizeLabel] = PrivateAttr()
-     isCustom: bool = False
+     is_custom: bool = False

      @validator("size")
-     def verify_size(cls, v):
+     def verify_size(cls, v: int):
          if v not in cls._size_label:
-             raise ValueError(f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}")
+             msg = f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}"
+             raise ValueError(msg)
          return v

      @property
@@ -53,13 +59,14 @@ class BaseDataset(BaseModel):
      def file_count(self) -> int:
          return self._size_label.get(self.size).file_count

+
  class CustomDataset(BaseDataset):
      dir: str
      file_num: int
-     isCustom: bool = True
+     is_custom: bool = True

      @validator("size")
-     def verify_size(cls, v):
+     def verify_size(cls, v: int):
          return v

      @property
@@ -102,7 +109,7 @@ class Cohere(BaseDataset):
      dim: int = 768
      metric_type: MetricType = MetricType.COSINE
      use_shuffled: bool = config.USE_SHUFFLED_DATA
-     with_gt: bool = True,
+     with_gt: bool = (True,)
      _size_label: dict = {
          100_000: SizeLabel(100_000, "SMALL", 1),
          1_000_000: SizeLabel(1_000_000, "MEDIUM", 1),
@@ -124,7 +131,11 @@ class SIFT(BaseDataset):
      metric_type: MetricType = MetricType.L2
      use_shuffled: bool = False
      _size_label: dict = {
-         500_000: SizeLabel(500_000, "SMALL", 1,),
+         500_000: SizeLabel(
+             500_000,
+             "SMALL",
+             1,
+         ),
          5_000_000: SizeLabel(5_000_000, "MEDIUM", 1),
          # 50_000_000: SizeLabel(50_000_000, "LARGE", 50),
      }
@@ -135,7 +146,7 @@ class OpenAI(BaseDataset):
      dim: int = 1536
      metric_type: MetricType = MetricType.COSINE
      use_shuffled: bool = config.USE_SHUFFLED_DATA
-     with_gt: bool = True,
+     with_gt: bool = (True,)
      _size_label: dict = {
          50_000: SizeLabel(50_000, "SMALL", 1),
          500_000: SizeLabel(500_000, "MEDIUM", 1),
@@ -153,13 +164,14 @@ class DatasetManager(BaseModel):
      >>> for data in cohere:
      >>>     print(data.columns)
      """
-     data: BaseDataset
+
+     data: BaseDataset
      test_data: pd.DataFrame | None = None
      gt_data: pd.DataFrame | None = None
-     train_files : list[str] = []
+     train_files: list[str] = []
      reader: DatasetReader | None = None

-     def __eq__(self, obj):
+     def __eq__(self, obj: any):
          if isinstance(obj, DatasetManager):
              return self.data.name == obj.data.name and self.data.label == obj.data.label
          return False
@@ -169,22 +181,27 @@ class DatasetManager(BaseModel):

      @property
      def data_dir(self) -> pathlib.Path:
-         """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
+         """data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}

          Examples:
              >>> sift_s = Dataset.SIFT.manager(500_000)
              >>> sift_s.relative_path
              '/tmp/vectordb_bench/dataset/sift/sift_small_500k/'
          """
-         return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower())
+         return pathlib.Path(
+             config.DATASET_LOCAL_DIR,
+             self.data.name.lower(),
+             self.data.dir_name.lower(),
+         )

      def __iter__(self):
          return DataSetIterator(self)

      # TODO passing use_shuffle from outside
-     def prepare(self,
-                 source: DatasetSource=DatasetSource.S3,
-                 filters: int | float | str | None = None,
+     def prepare(
+         self,
+         source: DatasetSource = DatasetSource.S3,
+         filters: float | str | None = None,
      ) -> bool:
          """Download the dataset from DatasetSource
              url = f"{source}/{self.data.dir_name}"
@@ -208,7 +225,7 @@
          gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
          all_files.extend([gt_file, test_file])

-         if not self.data.isCustom:
+         if not self.data.is_custom:
              source.reader().read(
                  dataset=self.data.dir_name.lower(),
                  files=all_files,
@@ -220,7 +237,7 @@
          self.gt_data = self._read_file(gt_file)

          prefix = "shuffle_train" if use_shuffled else "train"
-         self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')])
+         self.train_files = sorted([f.name for f in self.data_dir.glob(f"{prefix}*.parquet")])
          log.debug(f"{self.data.name}: available train files {self.train_files}")

          return True
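The prepare() hunks rename isCustom to is_custom and narrow the filters type. For orientation, a hedged usage sketch of the dataset API as the module docstring describes it (network access and a writable DATASET_LOCAL_DIR are assumed):

# Hedged sketch of the documented usage: fetch (if needed) and iterate the
# Cohere 100k split. Import paths follow the package layout shown above.
from vectordb_bench.backend.data_source import DatasetSource
from vectordb_bench.backend.dataset import Dataset

cohere = Dataset.COHERE.manager(100_000)
cohere.prepare(source=DatasetSource.S3)   # downloads test/gt/train parquet files

for batch in cohere:                      # pd.DataFrame batches from the train files
    print(batch.columns)
    break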
@@ -241,7 +258,7 @@ class DataSetIterator:
          self._ds = dataset
          self._idx = 0  # file number
          self._cur = None
-         self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file
+         self._sub_idx = [0 for i in range(len(self._ds.train_files))]  # iter num for each file

      def __iter__(self):
          return self
@@ -250,7 +267,9 @@
          p = pathlib.Path(self._ds.data_dir, file_name)
          log.info(f"Get iterator for {p.name}")
          if not p.exists():
-             raise IndexError(f"No such file {p}")
+             msg = f"No such file: {p}"
+             log.warning(msg)
+             raise IndexError(msg)
          return ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(config.NUM_PER_BATCH)

      def __next__(self) -> pd.DataFrame:
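DataSetIterator streams each train file with pyarrow instead of loading it whole. A small sketch of that batched-read pattern, with a hypothetical file path and a literal batch size standing in for config.NUM_PER_BATCH:

# Stream a parquet file in fixed-size record batches instead of reading it
# into memory at once; 100 stands in for config.NUM_PER_BATCH.
import pathlib

from pyarrow.parquet import ParquetFile

p = pathlib.Path("/tmp/vectordb_bench/dataset/sift/sift_small_500k/train.parquet")
batches = ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(100)

for batch in batches:
    df = batch.to_pandas()   # each RecordBatch converts to a small DataFrame
    print(len(df), list(df.columns))
    break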
@@ -281,6 +300,7 @@ class Dataset(Enum):
      >>> Dataset.COHERE.manager(100_000)
      >>> Dataset.COHERE.get(100_000)
      """
+
      LAION = LAION
      GIST = GIST
      COHERE = Cohere
vectordb_bench/backend/result_collector.py
@@ -1,7 +1,7 @@
+ import logging
  import pathlib
- from ..models import TestResult

- import logging
+ from vectordb_bench.models import TestResult

  log = logging.getLogger(__name__)

@@ -14,7 +14,6 @@ class ResultCollector:
          if not result_dir.exists() or len(list(result_dir.rglob(reg))) == 0:
              return []

-
          for json_file in result_dir.rglob(reg):
              file_result = TestResult.read_file(json_file, trans_unit=True)

vectordb_bench/backend/runner/__init__.py
@@ -1,12 +1,10 @@
  from .mp_runner import (
      MultiProcessingSearchRunner,
  )
-
- from .serial_runner import SerialSearchRunner, SerialInsertRunner
-
+ from .serial_runner import SerialInsertRunner, SerialSearchRunner

  __all__ = [
-     'MultiProcessingSearchRunner',
-     'SerialSearchRunner',
-     'SerialInsertRunner',
+     "MultiProcessingSearchRunner",
+     "SerialInsertRunner",
+     "SerialSearchRunner",
  ]
vectordb_bench/backend/runner/mp_runner.py
@@ -1,27 +1,29 @@
- import time
- import traceback
  import concurrent
+ import logging
  import multiprocessing as mp
  import random
- import logging
- from typing import Iterable
+ import time
+ import traceback
+ from collections.abc import Iterable
+
  import numpy as np
- from ..clients import api
- from ... import config

+ from ... import config
+ from ..clients import api

  NUM_PER_BATCH = config.NUM_PER_BATCH
  log = logging.getLogger(__name__)


  class MultiProcessingSearchRunner:
-     """ multiprocessing search runner
+     """multiprocessing search runner

      Args:
          k(int): search topk, default to 100
          concurrency(Iterable): concurrencies, default [1, 5, 10, 15, 20, 25, 30, 35]
          duration(int): duration for each concurency, default to 30s
      """
+
      def __init__(
          self,
          db: api.VectorDB,
@@ -40,7 +42,12 @@ class MultiProcessingSearchRunner:
          self.test_data = test_data
          log.debug(f"test dataset columns: {len(test_data)}")

-     def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> tuple[int, float]:
+     def search(
+         self,
+         test_data: list[list[float]],
+         q: mp.Queue,
+         cond: mp.Condition,
+     ) -> tuple[int, float]:
          # sync all process
          q.put(1)
          with cond:
@@ -71,13 +78,16 @@ class MultiProcessingSearchRunner:
              idx = idx + 1 if idx < num - 1 else 0

              if count % 500 == 0:
-                 log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+                 log.debug(
+                     f"({mp.current_process().name:16}) "
+                     f"search_count: {count}, latest_latency={time.perf_counter()-s}"
+                 )

          total_dur = round(time.perf_counter() - start_time, 4)
          log.info(
              f"{mp.current_process().name:16} search {self.duration}s: "
              f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
-             )
+         )

          return (count, total_dur, latencies)

@@ -87,8 +97,6 @@ class MultiProcessingSearchRunner:
          log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
          return mp.get_context(mp_start_method)

-
-
      def _run_all_concurrencies_mem_efficient(self):
          max_qps = 0
          conc_num_list = []
@@ -99,7 +107,10 @@ class MultiProcessingSearchRunner:
          for conc in self.concurrencies:
              with mp.Manager() as m:
                  q, cond = m.Queue(), m.Condition()
-                 with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
+                 with concurrent.futures.ProcessPoolExecutor(
+                     mp_context=self.get_mp_context(),
+                     max_workers=conc,
+                 ) as executor:
                      log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}")
                      future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)]
                      # Sync all processes
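Each concurrency level fans out to worker processes that first check in on a shared Queue and then block on a Condition, so the timed search window starts for all workers at once. A stripped-down sketch of that synchronization pattern; the worker body is a placeholder, not the real search loop:

# Start `conc` workers, wait until all have registered, then release them at
# once with notify_all(); each worker reports how many ops it completed.
import concurrent.futures
import multiprocessing as mp
import time


def worker(q, cond):
    q.put(1)                     # tell the parent this process is ready
    with cond:
        cond.wait()              # block until the parent says "go"
    count = 0
    deadline = time.perf_counter() + 1.0
    while time.perf_counter() < deadline:
        count += 1               # placeholder for db.search_embedding(...)
    return count


if __name__ == "__main__":
    conc = 4
    with mp.Manager() as m:
        q, cond = m.Queue(), m.Condition()
        with concurrent.futures.ProcessPoolExecutor(max_workers=conc) as executor:
            futures = [executor.submit(worker, q, cond) for _ in range(conc)]
            while q.qsize() < conc:      # wait for every worker to check in
                time.sleep(0.1)
            with cond:
                cond.notify_all()        # release all workers together
            print(sum(f.result() for f in futures))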
@@ -129,7 +140,9 @@
                          max_qps = qps
                          log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
              except Exception as e:
-                 log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+                 log.warning(
+                     f"Fail to search, concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}"
+                 )
                  traceback.print_exc()

                  # No results available, raise exception
@@ -139,7 +152,13 @@

              finally:
                  self.stop()

-         return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list, conc_latency_avg_list
+         return (
+             max_qps,
+             conc_num_list,
+             conc_qps_list,
+             conc_latency_p99_list,
+             conc_latency_avg_list,
+         )

      def run(self) -> float:
          """
@@ -160,9 +179,14 @@
          for conc in self.concurrencies:
              with mp.Manager() as m:
                  q, cond = m.Queue(), m.Condition()
-                 with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
+                 with concurrent.futures.ProcessPoolExecutor(
+                     mp_context=self.get_mp_context(),
+                     max_workers=conc,
+                 ) as executor:
                      log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}")
-                     future_iter = [executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)]
+                     future_iter = [
+                         executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)
+                     ]
                      # Sync all processes
                      while q.qsize() < conc:
                          sleep_t = conc if conc < 10 else 10
@@ -183,7 +207,9 @@
                          max_qps = qps
                          log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
              except Exception as e:
-                 log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+                 log.warning(
+                     f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}",
+                 )
                  traceback.print_exc()

                  # No results available, raise exception
@@ -195,8 +221,13 @@

          return max_qps

-
-     def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> int:
+     def search_by_dur(
+         self,
+         dur: int,
+         test_data: list[list[float]],
+         q: mp.Queue,
+         cond: mp.Condition,
+     ) -> int:
          # sync all process
          q.put(1)
          with cond:
@@ -225,13 +256,15 @@
              idx = idx + 1 if idx < num - 1 else 0

              if count % 500 == 0:
-                 log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+                 log.debug(
+                     f"({mp.current_process().name:16}) search_count: {count}, "
+                     f"latest_latency={time.perf_counter()-s}"
+                 )

          total_dur = round(time.perf_counter() - start_time, 4)
          log.debug(
              f"{mp.current_process().name:16} search {self.duration}s: "
              f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
-             )
+         )

          return count
-
vectordb_bench/backend/runner/rate_runner.py
@@ -1,36 +1,36 @@
+ import concurrent
  import logging
+ import multiprocessing as mp
  import time
- import concurrent
  from concurrent.futures import ThreadPoolExecutor
- import multiprocessing as mp
-

+ from vectordb_bench import config
  from vectordb_bench.backend.clients import api
  from vectordb_bench.backend.dataset import DataSetIterator
  from vectordb_bench.backend.utils import time_it
- from vectordb_bench import config

  from .util import get_data
+
  log = logging.getLogger(__name__)


  class RatedMultiThreadingInsertRunner:
      def __init__(
          self,
-         rate: int, # numRows per second
+         rate: int,  # numRows per second
          db: api.VectorDB,
          dataset_iter: DataSetIterator,
          normalize: bool = False,
          timeout: float | None = None,
      ):
-         self.timeout = timeout if isinstance(timeout, (int, float)) else None
+         self.timeout = timeout if isinstance(timeout, int | float) else None
          self.dataset = dataset_iter
          self.db = db
          self.normalize = normalize
          self.insert_rate = rate
          self.batch_rate = rate // config.NUM_PER_BATCH

-     def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]):
+     def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
          db.insert_embeddings(emb, metadata)

      @time_it
@@ -43,7 +43,9 @@ class RatedMultiThreadingInsertRunner:
              rate = self.batch_rate
              for data in self.dataset:
                  emb, metadata = get_data(data, self.normalize)
-                 executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
+                 executing_futures.append(
+                     executor.submit(self.send_insert_task, self.db, emb, metadata),
+                 )
                  rate -= 1

                  if rate == 0:
@@ -66,19 +68,26 @@ class RatedMultiThreadingInsertRunner:
                      done, not_done = concurrent.futures.wait(
                          executing_futures,
                          timeout=wait_interval,
-                         return_when=concurrent.futures.FIRST_EXCEPTION)
+                         return_when=concurrent.futures.FIRST_EXCEPTION,
+                     )

                      if len(not_done) > 0:
-                         log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round")
+                         log.warning(
+                             f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] "
+                             f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round"
+                         )
                          executing_futures = list(not_done)
                      else:
-                         log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
+                         log.debug(
+                             f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} "
+                             f"task in 1s, wait_interval={wait_interval:.2f}"
+                         )
                          executing_futures = []
                  except Exception as e:
-                     log.warn(f"task error, terminating, err={e}")
-                     q.put(None, block=True)
-                     executor.shutdown(wait=True, cancel_futures=True)
-                     raise e
+                     log.warning(f"task error, terminating, err={e}")
+                     q.put(None, block=True)
+                     executor.shutdown(wait=True, cancel_futures=True)
+                     raise e from e

              dur = time.perf_counter() - start_time
              if dur < 1:
@@ -87,10 +96,12 @@ class RatedMultiThreadingInsertRunner:

          # wait for all tasks in executing_futures to complete
          if len(executing_futures) > 0:
              try:
-                 done, _ = concurrent.futures.wait(executing_futures,
-                     return_when=concurrent.futures.FIRST_EXCEPTION)
+                 done, _ = concurrent.futures.wait(
+                     executing_futures,
+                     return_when=concurrent.futures.FIRST_EXCEPTION,
+                 )
              except Exception as e:
-                 log.warn(f"task error, terminating, err={e}")
+                 log.warning(f"task error, terminating, err={e}")
                  q.put(None, block=True)
                  executor.shutdown(wait=True, cancel_futures=True)
-                 raise e
+                 raise e from e
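rate_runner.py paces inserts by submitting at most rate // NUM_PER_BATCH batches per one-second window to a ThreadPoolExecutor, then waiting on the outstanding futures with FIRST_EXCEPTION so the first failure surfaces early. A condensed sketch of that pacing loop, with a dummy task standing in for db.insert_embeddings:

# Submit batch_rate tasks per one-second window and surface the first failure
# early; insert_batch is a stand-in for the real insert_embeddings call.
import concurrent.futures
import time
from concurrent.futures import ThreadPoolExecutor


def insert_batch(batch_id: int) -> int:
    time.sleep(0.05)  # pretend to write one batch of embeddings
    return batch_id


batch_rate = 5        # e.g. rate // config.NUM_PER_BATCH
batches = list(range(20))
pending: list[concurrent.futures.Future] = []

with ThreadPoolExecutor(max_workers=batch_rate) as executor:
    for start in range(0, len(batches), batch_rate):
        window_start = time.perf_counter()
        pending += [executor.submit(insert_batch, b) for b in batches[start : start + batch_rate]]

        # Check on outstanding work, stopping early if any task raised.
        done, not_done = concurrent.futures.wait(
            pending,
            timeout=1.0,
            return_when=concurrent.futures.FIRST_EXCEPTION,
        )
        pending = list(not_done)  # unfinished batches roll into the next window

        # Sleep off whatever is left of the 1s window so the submit rate holds.
        leftover = 1.0 - (time.perf_counter() - window_start)
        if leftover > 0:
            time.sleep(leftover)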