vectordb-bench 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +14 -3
- vectordb_bench/backend/cases.py +34 -13
- vectordb_bench/backend/clients/__init__.py +6 -1
- vectordb_bench/backend/clients/api.py +12 -8
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +4 -2
- vectordb_bench/backend/clients/milvus/milvus.py +17 -10
- vectordb_bench/backend/clients/pgvector/config.py +49 -0
- vectordb_bench/backend/clients/pgvector/pgvector.py +171 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +4 -3
- vectordb_bench/backend/clients/qdrant_cloud/config.py +20 -2
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +11 -11
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +5 -5
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +3 -1
- vectordb_bench/backend/dataset.py +99 -149
- vectordb_bench/backend/result_collector.py +2 -2
- vectordb_bench/backend/runner/mp_runner.py +29 -13
- vectordb_bench/backend/runner/serial_runner.py +69 -51
- vectordb_bench/backend/task_runner.py +43 -48
- vectordb_bench/frontend/components/get_results/saveAsImage.py +4 -2
- vectordb_bench/frontend/const/dbCaseConfigs.py +35 -4
- vectordb_bench/frontend/const/dbPrices.py +5 -33
- vectordb_bench/frontend/const/styles.py +9 -3
- vectordb_bench/metric.py +0 -1
- vectordb_bench/models.py +12 -8
- vectordb_bench/results/dbPrices.json +32 -0
- vectordb_bench/results/getLeaderboardData.py +52 -0
- vectordb_bench/results/leaderboard.json +1 -0
- vectordb_bench/results/{result_20230609_standard.json → result_20230705_standard.json} +670 -214
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/METADATA +98 -13
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/RECORD +34 -29
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/WHEEL +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.2.dist-info → vectordb_bench-0.0.3.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py

```diff
@@ -3,13 +3,12 @@
 import logging
 import time
 from contextlib import contextmanager
-from typing import
+from typing import Type
 
-from ..api import VectorDB, DBConfig, DBCaseConfig,
-from .config import QdrantConfig
+from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
+from .config import QdrantConfig, QdrantIndexConfig
 from qdrant_client.http.models import (
     CollectionStatus,
-    Distance,
     VectorParams,
     PayloadSchemaType,
     Batch,
@@ -32,6 +31,7 @@ class QdrantCloud(VectorDB):
         db_case_config: DBCaseConfig,
         collection_name: str = "QdrantCloudCollection",
         drop_old: bool = False,
+        **kwargs,
    ):
        """Initialize wrapper around the QdrantCloud vector database."""
        self.db_config = db_config
@@ -55,7 +55,7 @@ class QdrantCloud(VectorDB):
 
    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
-        return
+        return QdrantIndexConfig
 
    @contextmanager
    def init(self) -> None:
@@ -74,7 +74,7 @@ class QdrantCloud(VectorDB):
        pass
 
 
-    def
+    def optimize(self):
        assert self.qdrant_client, "Please call self.init() before"
        # wait for vectors to be fully indexed
        SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
```
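`optimize()` polls until Qdrant finishes building the index. A minimal standalone sketch of that wait loop against the public `qdrant_client` API (`wait_until_indexed` is a hypothetical helper name):

```python
import time

from qdrant_client import QdrantClient
from qdrant_client.http.models import CollectionStatus

SECONDS_WAITING_FOR_INDEXING_API_CALL = 5

def wait_until_indexed(client: QdrantClient, collection_name: str) -> None:
    """Poll the collection until Qdrant reports its status as green (fully indexed)."""
    while True:
        info = client.get_collection(collection_name=collection_name)
        if info.status == CollectionStatus.GREEN:
            return
        time.sleep(SECONDS_WAITING_FOR_INDEXING_API_CALL)
```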
```diff
@@ -97,7 +97,7 @@ class QdrantCloud(VectorDB):
        try:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
-                vectors_config=VectorParams(size=dim, distance=
+                vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
            )
 
            qdrant_client.create_payload_index(
@@ -116,7 +116,7 @@ class QdrantCloud(VectorDB):
        self,
        embeddings: list[list[float]],
        metadata: list[int],
-        **kwargs
+        **kwargs,
    ) -> (int, Exception):
        """Insert embeddings into Milvus. should call self.init() first"""
        assert self.qdrant_client is not None
@@ -127,10 +127,11 @@ class QdrantCloud(VectorDB):
                wait=True,
                points=Batch(ids=metadata, payloads=[{self._primary_field: v} for v in metadata], vectors=embeddings)
            )
-            return (len(metadata), None)
        except Exception as e:
            log.info(f"Failed to insert data, {e}")
-            return
+            return 0, e
+        else:
+            return len(metadata), None
 
    def search_embedding(
        self,
```
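The success return moves from the `try:` body into an `else:` clause, so the `except` handler can only ever see failures of the upsert itself. The pattern in isolation, with hypothetical names:

```python
def insert_batch(upsert, ids: list[int]) -> tuple[int, Exception | None]:
    try:
        upsert(ids)
    except Exception as e:
        return 0, e            # failure: report zero inserted and the error
    else:
        return len(ids), None  # runs only when the try body raised nothing
```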
```diff
@@ -138,7 +139,6 @@ class QdrantCloud(VectorDB):
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
-        **kwargs: Any,
    ) -> list[int]:
        """Perform a search on a query embedding and return results with score.
            Should call self.init() first.
```

vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py

```diff
@@ -1,7 +1,7 @@
 """Wrapper around the Weaviate vector database over VectorDB"""
 
 import logging
-from typing import
+from typing import Iterable, Type
 from contextlib import contextmanager
 
 from weaviate.exceptions import WeaviateBaseError
@@ -21,6 +21,7 @@ class WeaviateCloud(VectorDB):
        db_case_config: DBCaseConfig,
        collection_name: str = "VectorDBBenchCollection",
        drop_old: bool = False,
+        **kwargs,
    ):
        """Initialize wrapper around the weaviate vector database."""
        self.db_config = db_config
@@ -70,7 +71,7 @@ class WeaviateCloud(VectorDB):
        """Should call insert first, do nothing"""
        pass
 
-    def
+    def optimize(self):
        assert self.client.schema.exists(self.collection_name)
        self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param() } )
 
```
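`optimize()` now pushes the case's search parameters into the live schema through `update_config`. A hedged sketch against the weaviate-client v3 API; the endpoint and the `ef` value are illustrative:

```python
import weaviate

client = weaviate.Client("https://example-cluster.weaviate.network")  # hypothetical endpoint
client.schema.update_config(
    "VectorDBBenchCollection",
    {"vectorIndexConfig": {"ef": 256}},  # HNSW search-time parameter from the case config
)
```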
```diff
@@ -98,13 +99,13 @@ class WeaviateCloud(VectorDB):
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
-        **kwargs
+        **kwargs,
    ) -> (int, Exception):
        """Insert embeddings into Weaviate"""
        assert self.client.schema.exists(self.collection_name)
        insert_count = 0
        try:
-            with self.client.batch as batch:
+            with self.client.batch as batch:
                batch.batch_size = len(metadata)
                batch.dynamic = True
                res = []
@@ -126,7 +127,6 @@ class WeaviateCloud(VectorDB):
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
-        **kwargs: Any,
    ) -> list[int]:
        """Perform a search on a query embedding and return results with distance.
            Should call self.init() first.
```

vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py

```diff
@@ -14,7 +14,8 @@ class ZillizCloud(Milvus):
        db_case_config: DBCaseConfig,
        collection_name: str = "ZillizCloudVectorDBBench",
        drop_old: bool = False,
-        name: str = "ZillizCloud"
+        name: str = "ZillizCloud",
+        **kwargs,
    ):
        super().__init__(
            dim=dim,
@@ -23,6 +24,7 @@ class ZillizCloud(Milvus):
            collection_name=collection_name,
            drop_old=drop_old,
            name=name,
+            **kwargs,
        )
 
    @classmethod
```
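Each client constructor in this release gains `**kwargs` (QdrantCloud, WeaviateCloud, and ZillizCloud above), so a single factory can hand the same argument bag to any backend and let unused keys be absorbed. A sketch of the idea with hypothetical clients:

```python
class ClientA:
    def __init__(self, dim: int, drop_old: bool = False, **kwargs):  # absorbs extras
        self.dim = dim

class ClientB:
    def __init__(self, dim: int, name: str = "B", **kwargs):
        self.dim, self.name = dim, name

def build(cls, **common):
    # One argument bag works for every backend; keys a constructor doesn't
    # declare are swallowed by **kwargs instead of raising TypeError.
    return cls(**common)

a = build(ClientA, dim=768, name="ignored", drop_old=True)
b = build(ClientB, dim=768, drop_old=True)
```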
vectordb_bench/backend/dataset.py

```diff
@@ -1,23 +1,20 @@
 """
 Usage:
->>> from xxx import
->>>
->>> gist_s.dict()
-dataset: {'data': {'name': 'GIST', 'dim': 128, 'metric_type': 'L2', 'label': 'SMALL', 'size': 50000000}, 'data_dir': 'xxx'}
+>>> from xxx.dataset import Dataset
+>>> Dataset.Cohere.get(100_000)
 """
 
 import os
 import logging
 import pathlib
-import math
 from hashlib import md5
-from enum import Enum
-from typing import Any
-
+from enum import Enum
 import s3fs
 import pandas as pd
 from tqdm import tqdm
-from pydantic
+from pydantic import validator, PrivateAttr
+import polars as pl
+from pyarrow.parquet import ParquetFile
 
 from ..base import BaseModel
 from .. import config
```
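The new imports set up two read paths in `dataset.py`: polars for whole-file loads and pyarrow for streaming fixed-size batches. An illustrative comparison (the file path is hypothetical):

```python
import polars as pl
from pyarrow.parquet import ParquetFile

path = "cohere_small_100k/train.parquet"  # hypothetical local file

df = pl.read_parquet(path)                # whole file at once -> polars.DataFrame

for batch in ParquetFile(path).iter_batches(batch_size=100_000):
    pdf = batch.to_pandas()               # stream fixed-size chunks -> pandas.DataFrame
```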
```diff
@@ -26,108 +23,83 @@ from . import utils
 
 log = logging.getLogger(__name__)
 
-
-class LAION:
+
+class BaseDataset(BaseModel):
+    name: str
+    size: int
+    dim: int
+    metric_type: MetricType
+    use_shuffled: bool
+    _size_label: dict = PrivateAttr()
+
+    @validator("size")
+    def verify_size(cls, v):
+        if v not in cls._size_label:
+            raise ValueError(f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}")
+        return v
+
+    @property
+    def label(self) -> str:
+        return self._size_label.get(self.size)
+
+    @property
+    def dir_name(self) -> str:
+        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+class LAION(BaseDataset):
     name: str = "LAION"
     dim: int = 768
     metric_type: MetricType = MetricType.L2
     use_shuffled: bool = False
+    _size_label: dict = {100_000_000: "LARGE"}
 
-    @property
-    def dir_name(self) -> str:
-        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
 
-
-class GIST:
+class GIST(BaseDataset):
     name: str = "GIST"
     dim: int = 960
     metric_type: MetricType = MetricType.L2
     use_shuffled: bool = False
+    _size_label: dict = {
+        100_000: "SMALL",
+        1_000_000: "MEDIUM",
+    }
 
-    @property
-    def dir_name(self) -> str:
-        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
 
-
-class Cohere:
+class Cohere(BaseDataset):
     name: str = "Cohere"
     dim: int = 768
     metric_type: MetricType = MetricType.COSINE
     use_shuffled: bool = config.USE_SHUFFLED_DATA
+    _size_label: dict = {
+        100_000: "SMALL",
+        1_000_000: "MEDIUM",
+        10_000_000: "LARGE",
+    }
 
-    @property
-    def dir_name(self) -> str:
-        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
 
-
-class Glove:
+class Glove(BaseDataset):
     name: str = "Glove"
     dim: int = 200
     metric_type: MetricType = MetricType.COSINE
     use_shuffled: bool = False
+    _size_label: dict = {1_000_000: "MEDIUM"}
 
-    @property
-    def dir_name(self) -> str:
-        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
 
-
-class SIFT:
+class SIFT(BaseDataset):
     name: str = "SIFT"
     dim: int = 128
-    metric_type: MetricType = MetricType.
+    metric_type: MetricType = MetricType.L2
     use_shuffled: bool = False
+    _size_label: dict = {
 
-
-
-
+        500_000: "SMALL",
+        5_000_000: "MEDIUM",
+        50_000_000: "LARGE",
+    }
 
-
-class LAION_L(LAION):
-    label: str = "LARGE"
-    size: int = 100_000_000
-
-@dataclass
-class GIST_S(GIST):
-    label: str = "SMALL"
-    size: int = 100_000
-
-@dataclass
-class GIST_M(GIST):
-    label: str = "MEDIUM"
-    size: int = 1_000_000
-
-@dataclass
-class Cohere_M(Cohere):
-    label: str = "MEDIUM"
-    size: int = 1_000_000
-
-@dataclass
-class Cohere_L(Cohere):
-    label : str = "LARGE"
-    size : int = 10_000_000
-
-@dataclass
-class Glove_M(Glove):
-    label: str = "MEDIUM"
-    size : int = 1_000_000
-
-@dataclass
-class SIFT_S(SIFT):
-    label: str = "SMALL"
-    size : int = 500_000
-
-@dataclass
-class SIFT_M(SIFT):
-    label: str = "MEDIUM"
-    size : int = 5_000_000
-
-@dataclass
-class SIFT_L(SIFT):
-    label: str = "LARGE"
-    size : int = 50_000_000
-
-
-class DataSet(BaseModel):
+
+class DatasetManager(BaseModel):
     """Download dataset if not int the local directory. Provide data for cases.
 
     DataSet is iterable, each iteration will return the next batch of data in pandas.DataFrame
```
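`BaseDataset` centralizes the per-dataset size tables and rejects unsupported sizes with a pydantic validator. A simplified standalone sketch of that pattern, using a plain `ClassVar` table in place of `PrivateAttr`:

```python
from typing import ClassVar

from pydantic import BaseModel, validator  # pydantic v1 API

class BaseDataset(BaseModel):
    size: int
    SIZE_LABEL: ClassVar[dict[int, str]] = {}

    @validator("size")
    def verify_size(cls, v):
        # cls is the concrete subclass, so its own table is consulted
        if v not in cls.SIZE_LABEL:
            raise ValueError(f"size {v} not supported, expected one of {list(cls.SIZE_LABEL)}")
        return v

    @property
    def label(self) -> str:
        return self.SIZE_LABEL[self.size]

class Cohere(BaseDataset):
    SIZE_LABEL: ClassVar[dict[int, str]] = {100_000: "SMALL", 1_000_000: "MEDIUM", 10_000_000: "LARGE"}

print(Cohere(size=100_000).label)  # -> SMALL; Cohere(size=123) raises ValidationError
```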
```diff
@@ -137,12 +109,12 @@ class DataSet(BaseModel):
     >>> for data in cohere_s:
     >>>     print(data.columns)
     """
-    data:
+    data: BaseDataset
     test_data: pd.DataFrame | None = None
     train_files : list[str] = []
 
     def __eq__(self, obj):
-        if isinstance(obj,
+        if isinstance(obj, DatasetManager):
             return self.data.name == obj.data.name and \
                 self.data.label == obj.data.label
         return False
@@ -294,51 +266,47 @@ class DataSet(BaseModel):
 
     def _read_file(self, file_name: str) -> pd.DataFrame:
         """read one file from disk into memory"""
-
-
+        log.info(f"Read the entire file into memory: {file_name}")
         p = pathlib.Path(self.data_dir, file_name)
-        log.info(f"reading file into memory: {p}")
         if not p.exists():
             log.warning(f"No such file: {p}")
             return pd.DataFrame()
-
-
-        return df
+
+        return pl.read_parquet(p)
 
 
 class DataSetIterator:
-    def __init__(self, dataset:
+    def __init__(self, dataset: DatasetManager):
         self._ds = dataset
         self._idx = 0  # file number
-        self.
+        self._cur = None
         self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file
 
+    def _get_iter(self, file_name: str):
+        p = pathlib.Path(self._ds.data_dir, file_name)
+        log.info(f"Get iterator for {p.name}")
+        if not p.exists():
+            raise IndexError(f"No such file {p}")
+            log.warning(f"No such file: {p}")
+        return ParquetFile(p).iter_batches(config.NUM_PER_BATCH)
+
     def __next__(self) -> pd.DataFrame:
         """return the data in the next file of the training list"""
         if self._idx < len(self._ds.train_files):
-
-            if _sub == 0 and self._idx == 0: # init
+            if self._cur is None:
                 file_name = self._ds.train_files[self._idx]
-                self.
-                self._iter_num = math.ceil(self._curr.shape[0]/100_000)
+                self._cur = self._get_iter(file_name)
 
-
+            try:
+                return next(self._cur).to_pandas()
+            except StopIteration:
                 if self._idx == len(self._ds.train_files) - 1:
-
-
-
-
-
-
-                    self._curr = None
-                    file_name = self._ds.train_files[self._idx]
-                    self._curr = self._ds._read_file(file_name)
-
-            sub_df = self._curr[_sub*100_000: (_sub+1)*100_000]
-            self._sub_idx[self._idx] += 1
-            log.info(f"Get the [{_sub+1}/{self._iter_num}] batch of {self._idx+1}/{len(self._ds.train_files)} train file")
-            return sub_df
-        self._curr = None
+                    raise StopIteration from None
+
+                self._idx += 1
+                file_name = self._ds.train_files[self._idx]
+                self._cur = self._get_iter(file_name)
+                return next(self._cur).to_pandas()
         raise StopIteration
 
 
```
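The rewritten iterator holds one pyarrow batch iterator at a time and hops to the next train file when the current one is exhausted, instead of slicing a fully loaded DataFrame. The same behavior compresses to a generator (directory and file names are hypothetical):

```python
import pathlib
from pyarrow.parquet import ParquetFile

def iter_train_batches(data_dir: str, train_files: list[str], batch_size: int = 100_000):
    """Yield fixed-size pandas batches across a list of parquet files."""
    for file_name in train_files:
        p = pathlib.Path(data_dir, file_name)
        for batch in ParquetFile(p).iter_batches(batch_size):
            yield batch.to_pandas()

# for df in iter_train_batches("/data/cohere", ["train-00.parquet", "train-01.parquet"]):
#     ...
```

The class form keeps explicit per-file counters (`_idx`, `_sub_idx`) for progress logging, which a bare generator would not expose.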
|
345
|
-
class
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
Label.MEDIUM: DataSet(data=Cohere_M()),
|
365
|
-
Label.LARGE: DataSet(data=Cohere_L()),
|
366
|
-
},
|
367
|
-
Name.Glove:{
|
368
|
-
Label.MEDIUM: DataSet(data=Glove_M()),
|
369
|
-
},
|
370
|
-
Name.SIFT: {
|
371
|
-
Label.SMALL: DataSet(data=SIFT_S()),
|
372
|
-
Label.MEDIUM: DataSet(data=SIFT_M()),
|
373
|
-
Label.LARGE: DataSet(data=SIFT_L()),
|
374
|
-
},
|
375
|
-
Name.LAION: {
|
376
|
-
Label.LARGE: DataSet(data=LAION_L()),
|
377
|
-
},
|
378
|
-
}
|
379
|
-
|
380
|
-
def get(ds: Name, label: Label):
|
381
|
-
return _global_ds_mapping.get(ds, {}).get(label)
|
313
|
+
class Dataset(Enum):
|
314
|
+
"""
|
315
|
+
Value is Dataset classes, DO NOT use it
|
316
|
+
Example:
|
317
|
+
>>> all_dataset = [ds.name for ds in Dataset]
|
318
|
+
>>> Dataset.COHERE.manager(100_000)
|
319
|
+
>>> Dataset.COHERE.get(100_000)
|
320
|
+
"""
|
321
|
+
LAION = LAION
|
322
|
+
GIST = GIST
|
323
|
+
COHERE = Cohere
|
324
|
+
GLOVE = Glove
|
325
|
+
SIFT = SIFT
|
326
|
+
|
327
|
+
def get(self, size: int) -> BaseDataset:
|
328
|
+
return self.value(size=size)
|
329
|
+
|
330
|
+
def manager(self, size: int) -> DatasetManager:
|
331
|
+
return DatasetManager(data=self.get(size))
|
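Putting the docstrings together, the intended entry point after this refactor looks like the following; this assumes the parquet files are already present in the local data directory:

```python
from vectordb_bench.backend.dataset import Dataset

cohere_small = Dataset.COHERE.manager(100_000)  # DatasetManager over Cohere "SMALL"
# Dataset.COHERE.get(999) would raise a ValidationError via verify_size

for df in cohere_small:  # each batch arrives as a pandas.DataFrame
    print(df.columns)
```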
vectordb_bench/backend/result_collector.py

```diff
@@ -6,10 +6,10 @@ class ResultCollector:
    @classmethod
    def collect(cls, result_dir: pathlib.Path) -> list[TestResult]:
        results = []
-        if not result_dir.exists() or len(list(result_dir.glob("*.json"))) == 0:
+        if not result_dir.exists() or len(list(result_dir.glob("result_*.json"))) == 0:
            return []
 
-        for json_file in result_dir.glob("*.json"):
+        for json_file in result_dir.glob("result_*.json"):
            results.append(TestResult.read_file(json_file, trans_unit=True))
 
        return results
```
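The narrower `result_*.json` glob matters because this release also ships `dbPrices.json` and `leaderboard.json` in the results directory (see the file list above), and those must not be parsed as test results:

```python
import pathlib

result_dir = pathlib.Path("vectordb_bench/results")
sorted(p.name for p in result_dir.glob("*.json"))
# ['dbPrices.json', 'leaderboard.json', 'result_20230705_standard.json']
sorted(p.name for p in result_dir.glob("result_*.json"))
# ['result_20230705_standard.json']
```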
vectordb_bench/backend/runner/mp_runner.py

```diff
@@ -40,7 +40,12 @@ class MultiProcessingSearchRunner:
        self.test_data = utils.SharedNumpyArray(test_data)
        log.debug(f"test dataset columns: {len(test_data)}")
 
-    def search(self, test_np: utils.SharedNumpyArray) -> tuple[int, float]:
+    def search(self, test_np: utils.SharedNumpyArray, q: mp.Queue, cond: mp.Condition) -> tuple[int, float]:
+        # sync all process
+        q.put(1)
+        with cond:
+            cond.wait()
+
        with self.db.init():
            test_data = test_np.read().tolist()
            num, idx = len(test_data), 0
```
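Workers now check in on a shared queue and then block on a condition variable, so every process starts searching at the same instant and the measured QPS window is not skewed by slow pool startup. A runnable sketch of the barrier; the timed body stands in for the real DB search:

```python
import time
import multiprocessing as mp
import concurrent.futures

def worker(q, cond) -> float:
    q.put(1)           # check in: this process is ready
    with cond:
        cond.wait()    # block until the parent broadcasts "go"
    start = time.perf_counter()
    # ... timed work (the real runner searches the vector DB here) ...
    return time.perf_counter() - start

if __name__ == "__main__":
    conc = 4
    with mp.Manager() as m:
        q, cond = m.Queue(), m.Condition()
        ctx = mp.get_context("spawn")  # matches the start method pinned in the next hunk
        with concurrent.futures.ProcessPoolExecutor(mp_context=ctx, max_workers=conc) as ex:
            futures = [ex.submit(worker, q, cond) for _ in range(conc)]
            while q.qsize() < conc:    # wait until every worker has checked in
                time.sleep(0.1)
            with cond:
                cond.notify_all()      # release all workers at once
            print([f.result() for f in futures])
```

Manager proxies are used (rather than raw `mp.Queue`/`mp.Condition`) because proxies can be pickled into `ProcessPoolExecutor` workers.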
```diff
@@ -77,7 +82,7 @@ class MultiProcessingSearchRunner:
 
    @staticmethod
    def get_mp_context():
-        mp_start_method = "
+        mp_start_method = "spawn"
        log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
        return mp.get_context(mp_start_method)
 
@@ -85,21 +90,32 @@ class MultiProcessingSearchRunner:
        max_qps = 0
        try:
            for conc in self.concurrencies:
-                with
-
-
-
-
-
-
-
-
+                with mp.Manager() as m:
+                    q, cond = m.Queue(), m.Condition()
+                    with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
+                        log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}")
+                        future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)]
+                        # Sync all processes
+                        while q.qsize() < conc:
+                            sleep_t = conc if conc < 10 else 10
+                            time.sleep(sleep_t)
+
+                        with cond:
+                            cond.notify_all()
+                            log.info(f"Syncing all process and start concurrency search, concurrency={conc}")
+
+                        start = time.perf_counter()
+                        all_count = sum([r.result()[0] for r in future_iter])
+                        cost = time.perf_counter() - start
+
+                        qps = round(all_count / cost, 4)
+                        log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
 
                if qps > max_qps:
                    max_qps = qps
-                    log.info(f"
+                    log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
        except Exception as e:
-            log.warning(f"
+            log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
            traceback.print_exc()
 
        # No results available, raise exception
```