vectordb-bench 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff shows the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (105)
  1. vectordb_bench/__init__.py +49 -24
  2. vectordb_bench/__main__.py +4 -3
  3. vectordb_bench/backend/assembler.py +12 -13
  4. vectordb_bench/backend/cases.py +55 -45
  5. vectordb_bench/backend/clients/__init__.py +75 -14
  6. vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +1 -2
  7. vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +3 -4
  8. vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +111 -70
  9. vectordb_bench/backend/clients/aliyun_opensearch/config.py +6 -7
  10. vectordb_bench/backend/clients/alloydb/alloydb.py +58 -80
  11. vectordb_bench/backend/clients/alloydb/cli.py +51 -34
  12. vectordb_bench/backend/clients/alloydb/config.py +30 -30
  13. vectordb_bench/backend/clients/api.py +5 -9
  14. vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +46 -47
  15. vectordb_bench/backend/clients/aws_opensearch/cli.py +4 -7
  16. vectordb_bench/backend/clients/aws_opensearch/config.py +13 -9
  17. vectordb_bench/backend/clients/aws_opensearch/run.py +69 -59
  18. vectordb_bench/backend/clients/chroma/chroma.py +38 -36
  19. vectordb_bench/backend/clients/chroma/config.py +4 -2
  20. vectordb_bench/backend/clients/elastic_cloud/config.py +5 -5
  21. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +23 -22
  22. vectordb_bench/backend/clients/memorydb/cli.py +8 -8
  23. vectordb_bench/backend/clients/memorydb/config.py +2 -2
  24. vectordb_bench/backend/clients/memorydb/memorydb.py +65 -53
  25. vectordb_bench/backend/clients/milvus/cli.py +41 -83
  26. vectordb_bench/backend/clients/milvus/config.py +18 -8
  27. vectordb_bench/backend/clients/milvus/milvus.py +18 -19
  28. vectordb_bench/backend/clients/pgdiskann/cli.py +29 -22
  29. vectordb_bench/backend/clients/pgdiskann/config.py +29 -26
  30. vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +55 -73
  31. vectordb_bench/backend/clients/pgvecto_rs/cli.py +9 -11
  32. vectordb_bench/backend/clients/pgvecto_rs/config.py +8 -14
  33. vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +33 -34
  34. vectordb_bench/backend/clients/pgvector/cli.py +40 -31
  35. vectordb_bench/backend/clients/pgvector/config.py +63 -73
  36. vectordb_bench/backend/clients/pgvector/pgvector.py +97 -98
  37. vectordb_bench/backend/clients/pgvectorscale/cli.py +38 -24
  38. vectordb_bench/backend/clients/pgvectorscale/config.py +14 -15
  39. vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +38 -43
  40. vectordb_bench/backend/clients/pinecone/config.py +1 -0
  41. vectordb_bench/backend/clients/pinecone/pinecone.py +14 -21
  42. vectordb_bench/backend/clients/qdrant_cloud/config.py +11 -10
  43. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +40 -31
  44. vectordb_bench/backend/clients/redis/cli.py +6 -12
  45. vectordb_bench/backend/clients/redis/config.py +7 -5
  46. vectordb_bench/backend/clients/redis/redis.py +94 -58
  47. vectordb_bench/backend/clients/test/cli.py +1 -2
  48. vectordb_bench/backend/clients/test/config.py +2 -2
  49. vectordb_bench/backend/clients/test/test.py +4 -5
  50. vectordb_bench/backend/clients/weaviate_cloud/cli.py +3 -4
  51. vectordb_bench/backend/clients/weaviate_cloud/config.py +2 -2
  52. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +36 -22
  53. vectordb_bench/backend/clients/zilliz_cloud/cli.py +14 -11
  54. vectordb_bench/backend/clients/zilliz_cloud/config.py +2 -4
  55. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +1 -1
  56. vectordb_bench/backend/data_source.py +30 -18
  57. vectordb_bench/backend/dataset.py +47 -27
  58. vectordb_bench/backend/result_collector.py +2 -3
  59. vectordb_bench/backend/runner/__init__.py +4 -6
  60. vectordb_bench/backend/runner/mp_runner.py +85 -34
  61. vectordb_bench/backend/runner/rate_runner.py +30 -19
  62. vectordb_bench/backend/runner/read_write_runner.py +51 -23
  63. vectordb_bench/backend/runner/serial_runner.py +91 -48
  64. vectordb_bench/backend/runner/util.py +4 -3
  65. vectordb_bench/backend/task_runner.py +92 -72
  66. vectordb_bench/backend/utils.py +17 -10
  67. vectordb_bench/base.py +0 -1
  68. vectordb_bench/cli/cli.py +65 -60
  69. vectordb_bench/cli/vectordbbench.py +6 -7
  70. vectordb_bench/frontend/components/check_results/charts.py +8 -19
  71. vectordb_bench/frontend/components/check_results/data.py +4 -16
  72. vectordb_bench/frontend/components/check_results/filters.py +8 -16
  73. vectordb_bench/frontend/components/check_results/nav.py +4 -4
  74. vectordb_bench/frontend/components/check_results/priceTable.py +1 -3
  75. vectordb_bench/frontend/components/check_results/stPageConfig.py +2 -1
  76. vectordb_bench/frontend/components/concurrent/charts.py +12 -12
  77. vectordb_bench/frontend/components/custom/displayCustomCase.py +17 -11
  78. vectordb_bench/frontend/components/custom/displaypPrams.py +4 -2
  79. vectordb_bench/frontend/components/custom/getCustomConfig.py +1 -2
  80. vectordb_bench/frontend/components/custom/initStyle.py +1 -1
  81. vectordb_bench/frontend/components/get_results/saveAsImage.py +2 -0
  82. vectordb_bench/frontend/components/run_test/caseSelector.py +3 -9
  83. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +1 -4
  84. vectordb_bench/frontend/components/run_test/dbSelector.py +1 -1
  85. vectordb_bench/frontend/components/run_test/generateTasks.py +8 -8
  86. vectordb_bench/frontend/components/run_test/submitTask.py +14 -18
  87. vectordb_bench/frontend/components/tables/data.py +3 -6
  88. vectordb_bench/frontend/config/dbCaseConfigs.py +51 -84
  89. vectordb_bench/frontend/pages/concurrent.py +3 -5
  90. vectordb_bench/frontend/pages/custom.py +30 -9
  91. vectordb_bench/frontend/pages/quries_per_dollar.py +3 -3
  92. vectordb_bench/frontend/pages/run_test.py +3 -7
  93. vectordb_bench/frontend/utils.py +1 -1
  94. vectordb_bench/frontend/vdb_benchmark.py +4 -6
  95. vectordb_bench/interface.py +56 -26
  96. vectordb_bench/log_util.py +59 -64
  97. vectordb_bench/metric.py +10 -11
  98. vectordb_bench/models.py +26 -43
  99. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/METADATA +22 -15
  100. vectordb_bench-0.0.20.dist-info/RECORD +135 -0
  101. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/WHEEL +1 -1
  102. vectordb_bench-0.0.19.dist-info/RECORD +0 -135
  103. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/LICENSE +0 -0
  104. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/entry_points.txt +0 -0
  105. {vectordb_bench-0.0.19.dist-info → vectordb_bench-0.0.20.dist-info}/top_level.txt +0 -0

vectordb_bench/backend/data_source.py
@@ -1,12 +1,12 @@
 import logging
 import pathlib
 import typing
+from abc import ABC, abstractmethod
 from enum import Enum
+
 from tqdm import tqdm
-import os
-from abc import ABC, abstractmethod
 
-from .. import config
+from vectordb_bench import config
 
 logging.getLogger("s3fs").setLevel(logging.CRITICAL)
 
@@ -14,6 +14,7 @@ log = logging.getLogger(__name__)
 
 DatasetReader = typing.TypeVar("DatasetReader")
 
+
 class DatasetSource(Enum):
     S3 = "S3"
     AliyunOSS = "AliyunOSS"
@@ -25,6 +26,8 @@ class DatasetSource(Enum):
         if self == DatasetSource.AliyunOSS:
             return AliyunOSSReader()
 
+        return None
+
 
 class DatasetReader(ABC):
     source: DatasetSource
@@ -39,7 +42,6 @@ class DatasetReader(ABC):
             files(list[str]): all filenames of the dataset
             local_ds_root(pathlib.Path): whether to write the remote data.
         """
-        pass
 
     @abstractmethod
     def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
@@ -52,15 +54,18 @@ class AliyunOSSReader(DatasetReader):
 
     def __init__(self):
         import oss2
+
         self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)
 
     def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         info = self.bucket.get_object_meta(remote.as_posix())
 
         # check size equal
-        remote_size, local_size = info.content_length, os.path.getsize(local)
+        remote_size, local_size = info.content_length, local.stat().st_size
         if remote_size != local_size:
-            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+            log.info(
+                f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]",
+            )
             return False
 
         return True
@@ -70,7 +75,13 @@ class AliyunOSSReader(DatasetReader):
         if not local_ds_root.exists():
             log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
             local_ds_root.mkdir(parents=True)
-            downloads = [(pathlib.PurePosixPath("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]
+            downloads = [
+                (
+                    pathlib.PurePosixPath("benchmark", dataset, f),
+                    local_ds_root.joinpath(f),
+                )
+                for f in files
+            ]
 
         else:
             for file in files:
@@ -78,7 +89,9 @@ class AliyunOSSReader(DatasetReader):
                 local_file = local_ds_root.joinpath(file)
 
                 if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
-                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+                    log.info(
+                        f"local file: {local_file} not match with remote: {remote_file}; add to downloading list",
+                    )
                     downloads.append((remote_file, local_file))
 
         if len(downloads) == 0:
@@ -92,17 +105,14 @@ class AliyunOSSReader(DatasetReader):
         log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
 
 
-
 class AwsS3Reader(DatasetReader):
     source: DatasetSource = DatasetSource.S3
     remote_root: str = config.AWS_S3_URL
 
     def __init__(self):
         import s3fs
-        self.fs = s3fs.S3FileSystem(
-            anon=True,
-            client_kwargs={'region_name': 'us-west-2'}
-        )
+
+        self.fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-west-2"})
 
     def ls_all(self, dataset: str):
         dataset_root_dir = pathlib.Path(self.remote_root, dataset)
@@ -112,7 +122,6 @@ class AwsS3Reader(DatasetReader):
             log.info(n)
         return names
 
-
     def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         downloads = []
         if not local_ds_root.exists():
@@ -126,7 +135,9 @@ class AwsS3Reader(DatasetReader):
                 local_file = local_ds_root.joinpath(file)
 
                 if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
-                    log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+                    log.info(
+                        f"local file: {local_file} not match with remote: {remote_file}; add to downloading list",
+                    )
                     downloads.append(remote_file)
 
         if len(downloads) == 0:
@@ -139,15 +150,16 @@ class AwsS3Reader(DatasetReader):
 
         log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
 
-
     def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         # info() uses ls() inside, maybe we only need to ls once
        info = self.fs.info(remote)
 
        # check size equal
-        remote_size, local_size = info.get("size"), os.path.getsize(local)
+        remote_size, local_size = info.get("size"), local.stat().st_size
        if remote_size != local_size:
-            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+            log.info(
+                f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]",
+            )
            return False
 
        return True
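
Both readers now validate downloads with pathlib alone: local.stat().st_size replaces os.path.getsize(local), which is what lets the module drop import os. A minimal standalone sketch of that size check (the sizes_match helper and its arguments are illustrative names, not part of the package):

    import logging
    import pathlib

    log = logging.getLogger(__name__)


    def sizes_match(local: pathlib.Path, remote_size: int) -> bool:
        # pathlib equivalent of os.path.getsize(local): size of the local copy in bytes
        local_size = local.stat().st_size
        if remote_size != local_size:
            log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
            return False
        return True

os.path.getsize and Path.stat().st_size both return the file size in bytes; the pathlib form simply avoids the extra import.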

vectordb_bench/backend/dataset.py
@@ -4,25 +4,30 @@ Usage:
 >>> Dataset.Cohere.get(100_000)
 """
 
-from collections import namedtuple
 import logging
 import pathlib
+import typing
 from enum import Enum
+
 import pandas as pd
-from pydantic import validator, PrivateAttr
 import polars as pl
 from pyarrow.parquet import ParquetFile
+from pydantic import PrivateAttr, validator
+
+from vectordb_bench import config
+from vectordb_bench.base import BaseModel
 
-from ..base import BaseModel
-from .. import config
-from ..backend.clients import MetricType
 from . import utils
-from .data_source import DatasetSource, DatasetReader
+from .clients import MetricType
+from .data_source import DatasetReader, DatasetSource
 
 log = logging.getLogger(__name__)
 
 
-SizeLabel = namedtuple('SizeLabel', ['size', 'label', 'file_count'])
+class SizeLabel(typing.NamedTuple):
+    size: int
+    label: str
+    file_count: int
 
 
 class BaseDataset(BaseModel):
@@ -33,12 +38,13 @@ class BaseDataset(BaseModel):
     use_shuffled: bool
     with_gt: bool = False
     _size_label: dict[int, SizeLabel] = PrivateAttr()
-    isCustom: bool = False
+    is_custom: bool = False
 
     @validator("size")
-    def verify_size(cls, v):
+    def verify_size(cls, v: int):
         if v not in cls._size_label:
-            raise ValueError(f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}")
+            msg = f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}"
+            raise ValueError(msg)
         return v
 
     @property
@@ -53,13 +59,14 @@
     def file_count(self) -> int:
         return self._size_label.get(self.size).file_count
 
+
 class CustomDataset(BaseDataset):
     dir: str
     file_num: int
-    isCustom: bool = True
+    is_custom: bool = True
 
     @validator("size")
-    def verify_size(cls, v):
+    def verify_size(cls, v: int):
         return v
 
     @property
@@ -102,7 +109,7 @@ class Cohere(BaseDataset):
     dim: int = 768
     metric_type: MetricType = MetricType.COSINE
     use_shuffled: bool = config.USE_SHUFFLED_DATA
-    with_gt: bool = True,
+    with_gt: bool = (True,)
     _size_label: dict = {
         100_000: SizeLabel(100_000, "SMALL", 1),
         1_000_000: SizeLabel(1_000_000, "MEDIUM", 1),
@@ -124,7 +131,11 @@ class SIFT(BaseDataset):
     metric_type: MetricType = MetricType.L2
     use_shuffled: bool = False
     _size_label: dict = {
-        500_000: SizeLabel(500_000, "SMALL", 1,),
+        500_000: SizeLabel(
+            500_000,
+            "SMALL",
+            1,
+        ),
         5_000_000: SizeLabel(5_000_000, "MEDIUM", 1),
         # 50_000_000: SizeLabel(50_000_000, "LARGE", 50),
     }
@@ -135,7 +146,7 @@ class OpenAI(BaseDataset):
     dim: int = 1536
     metric_type: MetricType = MetricType.COSINE
     use_shuffled: bool = config.USE_SHUFFLED_DATA
-    with_gt: bool = True,
+    with_gt: bool = (True,)
     _size_label: dict = {
         50_000: SizeLabel(50_000, "SMALL", 1),
         500_000: SizeLabel(500_000, "MEDIUM", 1),
@@ -153,13 +164,14 @@ class DatasetManager(BaseModel):
     >>> for data in cohere:
    >>>     print(data.columns)
    """
-    data: BaseDataset
+
+    data: BaseDataset
    test_data: pd.DataFrame | None = None
    gt_data: pd.DataFrame | None = None
-    train_files : list[str] = []
+    train_files: list[str] = []
    reader: DatasetReader | None = None
 
-    def __eq__(self, obj):
+    def __eq__(self, obj: any):
        if isinstance(obj, DatasetManager):
            return self.data.name == obj.data.name and self.data.label == obj.data.label
        return False
@@ -169,22 +181,27 @@
 
     @property
     def data_dir(self) -> pathlib.Path:
-        """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
+        """data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
 
         Examples:
             >>> sift_s = Dataset.SIFT.manager(500_000)
             >>> sift_s.relative_path
             '/tmp/vectordb_bench/dataset/sift/sift_small_500k/'
         """
-        return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower())
+        return pathlib.Path(
+            config.DATASET_LOCAL_DIR,
+            self.data.name.lower(),
+            self.data.dir_name.lower(),
+        )
 
     def __iter__(self):
         return DataSetIterator(self)
 
     # TODO passing use_shuffle from outside
-    def prepare(self,
-                source: DatasetSource=DatasetSource.S3,
-                filters: int | float | str | None = None,
+    def prepare(
+        self,
+        source: DatasetSource = DatasetSource.S3,
+        filters: float | str | None = None,
     ) -> bool:
         """Download the dataset from DatasetSource
         url = f"{source}/{self.data.dir_name}"
@@ -208,7 +225,7 @@
             gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
             all_files.extend([gt_file, test_file])
 
-        if not self.data.isCustom:
+        if not self.data.is_custom:
             source.reader().read(
                 dataset=self.data.dir_name.lower(),
                 files=all_files,
@@ -220,7 +237,7 @@
             self.gt_data = self._read_file(gt_file)
 
         prefix = "shuffle_train" if use_shuffled else "train"
-        self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')])
+        self.train_files = sorted([f.name for f in self.data_dir.glob(f"{prefix}*.parquet")])
         log.debug(f"{self.data.name}: available train files {self.train_files}")
 
         return True
@@ -241,7 +258,7 @@ class DataSetIterator:
         self._ds = dataset
         self._idx = 0  # file number
         self._cur = None
-        self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file
+        self._sub_idx = [0 for i in range(len(self._ds.train_files))]  # iter num for each file
 
     def __iter__(self):
         return self
@@ -250,7 +267,9 @@
         p = pathlib.Path(self._ds.data_dir, file_name)
         log.info(f"Get iterator for {p.name}")
         if not p.exists():
-            raise IndexError(f"No such file {p}")
+            msg = f"No such file: {p}"
+            log.warning(msg)
+            raise IndexError(msg)
         return ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(config.NUM_PER_BATCH)
 
     def __next__(self) -> pd.DataFrame:
@@ -281,6 +300,7 @@ class Dataset(Enum):
     >>> Dataset.COHERE.manager(100_000)
    >>> Dataset.COHERE.get(100_000)
    """
+
    LAION = LAION
    GIST = GIST
    COHERE = Cohere
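
dataset.py also swaps the untyped collections.namedtuple for a typing.NamedTuple subclass with annotated fields. A small standalone sketch of the two equivalent forms (outside the package, for illustration only):

    import typing
    from collections import namedtuple

    # 0.0.19 style: field names as strings, no type information
    SizeLabelTuple = namedtuple("SizeLabel", ["size", "label", "file_count"])

    # 0.0.20 style: same field order, but typed and friendlier to linters and IDEs
    class SizeLabel(typing.NamedTuple):
        size: int
        label: str
        file_count: int

    # Both construct, index, and unpack identically
    assert tuple(SizeLabelTuple(100_000, "SMALL", 1)) == tuple(SizeLabel(100_000, "SMALL", 1))

Instances of either form behave as plain tuples, so the _size_label dictionaries above keep working unchanged.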

vectordb_bench/backend/result_collector.py
@@ -1,7 +1,7 @@
+import logging
 import pathlib
-from ..models import TestResult
 
-import logging
+from vectordb_bench.models import TestResult
 
 log = logging.getLogger(__name__)
 
@@ -14,7 +14,6 @@ class ResultCollector:
         if not result_dir.exists() or len(list(result_dir.rglob(reg))) == 0:
             return []
 
-
         for json_file in result_dir.rglob(reg):
             file_result = TestResult.read_file(json_file, trans_unit=True)
 

vectordb_bench/backend/runner/__init__.py
@@ -1,12 +1,10 @@
 from .mp_runner import (
     MultiProcessingSearchRunner,
 )
-
-from .serial_runner import SerialSearchRunner, SerialInsertRunner
-
+from .serial_runner import SerialInsertRunner, SerialSearchRunner
 
 __all__ = [
-    'MultiProcessingSearchRunner',
-    'SerialSearchRunner',
-    'SerialInsertRunner',
+    "MultiProcessingSearchRunner",
+    "SerialInsertRunner",
+    "SerialSearchRunner",
 ]

vectordb_bench/backend/runner/mp_runner.py
@@ -1,27 +1,29 @@
-import time
-import traceback
 import concurrent
+import logging
 import multiprocessing as mp
 import random
-import logging
-from typing import Iterable
+import time
+import traceback
+from collections.abc import Iterable
+
 import numpy as np
-from ..clients import api
-from ... import config
 
+from ... import config
+from ..clients import api
 
 NUM_PER_BATCH = config.NUM_PER_BATCH
 log = logging.getLogger(__name__)
 
 
 class MultiProcessingSearchRunner:
-    """ multiprocessing search runner
+    """multiprocessing search runner
 
     Args:
         k(int): search topk, default to 100
         concurrency(Iterable): concurrencies, default [1, 5, 10, 15, 20, 25, 30, 35]
         duration(int): duration for each concurency, default to 30s
     """
+
     def __init__(
         self,
         db: api.VectorDB,
@@ -40,7 +42,12 @@ class MultiProcessingSearchRunner:
         self.test_data = test_data
         log.debug(f"test dataset columns: {len(test_data)}")
 
-    def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> tuple[int, float]:
+    def search(
+        self,
+        test_data: list[list[float]],
+        q: mp.Queue,
+        cond: mp.Condition,
+    ) -> tuple[int, float]:
         # sync all process
         q.put(1)
         with cond:
@@ -71,24 +78,27 @@ class MultiProcessingSearchRunner:
             idx = idx + 1 if idx < num - 1 else 0
 
             if count % 500 == 0:
-                log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+                log.debug(
+                    f"({mp.current_process().name:16}) ",
+                    f"search_count: {count}, latest_latency={time.perf_counter()-s}",
+                )
 
         total_dur = round(time.perf_counter() - start_time, 4)
         log.info(
             f"{mp.current_process().name:16} search {self.duration}s: "
-            f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
-        )
+            f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}",
+        )
 
         return (count, total_dur, latencies)
 
     @staticmethod
     def get_mp_context():
         mp_start_method = "spawn"
-        log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
+        log.debug(
+            f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}",
+        )
         return mp.get_context(mp_start_method)
 
-
-
     def _run_all_concurrencies_mem_efficient(self):
         max_qps = 0
         conc_num_list = []
@@ -99,8 +109,13 @@
             for conc in self.concurrencies:
                 with mp.Manager() as m:
                     q, cond = m.Queue(), m.Condition()
-                    with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
-                        log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}")
+                    with concurrent.futures.ProcessPoolExecutor(
+                        mp_context=self.get_mp_context(),
+                        max_workers=conc,
+                    ) as executor:
+                        log.info(
+                            f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}",
+                        )
                         future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)]
                         # Sync all processes
                         while q.qsize() < conc:
@@ -109,7 +124,9 @@
 
                         with cond:
                             cond.notify_all()
-                            log.info(f"Syncing all process and start concurrency search, concurrency={conc}")
+                            log.info(
+                                f"Syncing all process and start concurrency search, concurrency={conc}",
+                            )
 
                         start = time.perf_counter()
                         all_count = sum([r.result()[0] for r in future_iter])
@@ -123,13 +140,19 @@
                         conc_qps_list.append(qps)
                         conc_latency_p99_list.append(latency_p99)
                         conc_latency_avg_list.append(latency_avg)
-                        log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
+                        log.info(
+                            f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}",
+                        )
 
                         if qps > max_qps:
                             max_qps = qps
-                            log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
+                            log.info(
+                                f"Update largest qps with concurrency {conc}: current max_qps={max_qps}",
+                            )
         except Exception as e:
-            log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+            log.warning(
+                f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}",
+            )
             traceback.print_exc()
 
             # No results available, raise exception
@@ -139,7 +162,13 @@
         finally:
             self.stop()
 
-        return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list, conc_latency_avg_list
+        return (
+            max_qps,
+            conc_num_list,
+            conc_qps_list,
+            conc_latency_p99_list,
+            conc_latency_avg_list,
+        )
 
     def run(self) -> float:
         """
@@ -160,9 +189,16 @@
             for conc in self.concurrencies:
                 with mp.Manager() as m:
                     q, cond = m.Queue(), m.Condition()
-                    with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
-                        log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}")
-                        future_iter = [executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)]
+                    with concurrent.futures.ProcessPoolExecutor(
+                        mp_context=self.get_mp_context(),
+                        max_workers=conc,
+                    ) as executor:
+                        log.info(
+                            f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}",
+                        )
+                        future_iter = [
+                            executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)
+                        ]
                         # Sync all processes
                         while q.qsize() < conc:
                             sleep_t = conc if conc < 10 else 10
@@ -170,20 +206,28 @@
 
                         with cond:
                             cond.notify_all()
-                            log.info(f"Syncing all process and start concurrency search, concurrency={conc}")
+                            log.info(
+                                f"Syncing all process and start concurrency search, concurrency={conc}",
+                            )
 
                         start = time.perf_counter()
                         all_count = sum([r.result() for r in future_iter])
                         cost = time.perf_counter() - start
 
                         qps = round(all_count / cost, 4)
-                        log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
+                        log.info(
+                            f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}",
+                        )
 
                         if qps > max_qps:
                             max_qps = qps
-                            log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
+                            log.info(
+                                f"Update largest qps with concurrency {conc}: current max_qps={max_qps}",
+                            )
         except Exception as e:
-            log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+            log.warning(
+                f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}",
+            )
             traceback.print_exc()
 
             # No results available, raise exception
@@ -195,8 +239,13 @@
 
         return max_qps
 
-
-    def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> int:
+    def search_by_dur(
+        self,
+        dur: int,
+        test_data: list[list[float]],
+        q: mp.Queue,
+        cond: mp.Condition,
+    ) -> int:
         # sync all process
         q.put(1)
         with cond:
@@ -225,13 +274,15 @@
             idx = idx + 1 if idx < num - 1 else 0
 
             if count % 500 == 0:
-                log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+                log.debug(
+                    f"({mp.current_process().name:16}) search_count: {count}, ",
+                    f"latest_latency={time.perf_counter()-s}",
+                )
 
         total_dur = round(time.perf_counter() - start_time, 4)
         log.debug(
             f"{mp.current_process().name:16} search {self.duration}s: "
-            f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
-        )
+            f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}",
+        )
 
         return count
-
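
The concurrency loops above keep the same start-up barrier on both code paths: every worker process puts a token on a Manager queue and blocks on a shared Condition, while the parent waits until the queue holds one token per worker and then calls notify_all() so all workers start searching at the same instant. A stripped-down sketch of that pattern (the worker function, its return value, and the sleep interval are illustrative, not the package's API):

    import concurrent.futures
    import multiprocessing as mp
    import time


    def worker(q, cond) -> int:
        q.put(1)          # announce readiness to the parent
        with cond:
            cond.wait()   # block until the parent releases all workers together
        # ... the timed search loop would run here ...
        return 1


    if __name__ == "__main__":
        conc = 4
        with mp.Manager() as m:
            q, cond = m.Queue(), m.Condition()
            with concurrent.futures.ProcessPoolExecutor(
                mp_context=mp.get_context("spawn"),
                max_workers=conc,
            ) as executor:
                futures = [executor.submit(worker, q, cond) for _ in range(conc)]
                while q.qsize() < conc:   # wait for every worker to reach the barrier
                    time.sleep(1)
                with cond:
                    cond.notify_all()     # release all workers at once
                print(sum(f.result() for f in futures))

Starting the QPS measurement only after the barrier is released is what keeps the per-concurrency numbers comparable: no process begins its timed loop before the others are ready.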