vectordb-bench 0.0.1__1-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (56)
  1. vectordb_bench/__init__.py +30 -0
  2. vectordb_bench/__main__.py +39 -0
  3. vectordb_bench/backend/__init__.py +0 -0
  4. vectordb_bench/backend/assembler.py +57 -0
  5. vectordb_bench/backend/cases.py +124 -0
  6. vectordb_bench/backend/clients/__init__.py +57 -0
  7. vectordb_bench/backend/clients/api.py +179 -0
  8. vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
  9. vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
  10. vectordb_bench/backend/clients/milvus/config.py +123 -0
  11. vectordb_bench/backend/clients/milvus/milvus.py +182 -0
  12. vectordb_bench/backend/clients/pinecone/config.py +15 -0
  13. vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
  14. vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
  15. vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
  16. vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
  17. vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
  18. vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
  19. vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
  20. vectordb_bench/backend/dataset.py +393 -0
  21. vectordb_bench/backend/result_collector.py +15 -0
  22. vectordb_bench/backend/runner/__init__.py +12 -0
  23. vectordb_bench/backend/runner/mp_runner.py +124 -0
  24. vectordb_bench/backend/runner/serial_runner.py +164 -0
  25. vectordb_bench/backend/task_runner.py +290 -0
  26. vectordb_bench/backend/utils.py +85 -0
  27. vectordb_bench/base.py +6 -0
  28. vectordb_bench/frontend/components/check_results/charts.py +175 -0
  29. vectordb_bench/frontend/components/check_results/data.py +86 -0
  30. vectordb_bench/frontend/components/check_results/filters.py +97 -0
  31. vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
  32. vectordb_bench/frontend/components/check_results/nav.py +21 -0
  33. vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
  34. vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
  35. vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
  36. vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
  37. vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
  38. vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
  39. vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
  40. vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
  41. vectordb_bench/frontend/const.py +391 -0
  42. vectordb_bench/frontend/pages/qps_with_price.py +60 -0
  43. vectordb_bench/frontend/pages/run_test.py +59 -0
  44. vectordb_bench/frontend/utils.py +6 -0
  45. vectordb_bench/frontend/vdb_benchmark.py +42 -0
  46. vectordb_bench/interface.py +239 -0
  47. vectordb_bench/log_util.py +103 -0
  48. vectordb_bench/metric.py +53 -0
  49. vectordb_bench/models.py +234 -0
  50. vectordb_bench/results/result_20230609_standard.json +3228 -0
  51. vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
  52. vectordb_bench-0.0.1.dist-info/METADATA +226 -0
  53. vectordb_bench-0.0.1.dist-info/RECORD +56 -0
  54. vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
  55. vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
  56. vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
@@ -0,0 +1,151 @@
+ """Wrapper around the Weaviate vector database over VectorDB"""
+
+ import logging
+ from typing import Any, Iterable, Type
+ from contextlib import contextmanager
+
+ from weaviate.exceptions import WeaviateBaseError
+
+ from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
+ from .config import WeaviateConfig, WeaviateIndexConfig
+
+
+ log = logging.getLogger(__name__)
+
+
+ class WeaviateCloud(VectorDB):
+     def __init__(
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: DBCaseConfig,
+         collection_name: str = "VectorDBBenchCollection",
+         drop_old: bool = False,
+     ):
+         """Initialize wrapper around the weaviate vector database."""
+         self.db_config = db_config
+         self.case_config = db_case_config
+         self.collection_name = collection_name
+
+         self._scalar_field = "key"
+         self._vector_field = "vector"
+         self._index_name = "vector_idx"
+
+         from weaviate import Client
+         client = Client(**db_config)
+         if drop_old:
+             try:
+                 if client.schema.exists(self.collection_name):
+                     log.info(f"weaviate client drop_old collection: {self.collection_name}")
+                     client.schema.delete_class(self.collection_name)
+             except WeaviateBaseError as e:
+                 log.warning(f"Failed to drop collection: {self.collection_name} error: {str(e)}")
+                 raise e from None
+         self._create_collection(client)
+         client = None
+
+     @classmethod
+     def config_cls(cls) -> Type[DBConfig]:
+         return WeaviateConfig
+
+     @classmethod
+     def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+         return WeaviateIndexConfig
+
+     @contextmanager
+     def init(self) -> None:
+         """
+         Examples:
+             >>> with self.init():
+             >>>     self.insert_embeddings()
+             >>>     self.search_embedding()
+         """
+         from weaviate import Client
+         self.client = Client(**self.db_config)
+         yield
+         self.client = None
+         del self.client
+
+     def ready_to_load(self):
+         """No-op: insert_embeddings should be called first; nothing to prepare for loading."""
+         pass
+
+     def ready_to_search(self):
+         assert self.client.schema.exists(self.collection_name)
+         self.client.schema.update_config(
+             self.collection_name,
+             {"vectorIndexConfig": self.case_config.search_param()},
+         )
+
+     def _create_collection(self, client):
+         if not client.schema.exists(self.collection_name):
+             log.info(f"Create collection: {self.collection_name}")
+             class_obj = {
+                 "class": self.collection_name,
+                 "vectorizer": "none",
+                 "properties": [
+                     {
+                         "dataType": ["int"],
+                         "name": self._scalar_field,
+                     },
+                 ]
+             }
+             class_obj["vectorIndexConfig"] = self.case_config.index_param()
+             try:
+                 client.schema.create_class(class_obj)
+             except WeaviateBaseError as e:
+                 log.warning(f"Failed to create collection: {self.collection_name} error: {str(e)}")
+                 raise e from None
+
+     def insert_embeddings(
+         self,
+         embeddings: Iterable[list[float]],
+         metadata: list[int],
+         **kwargs: Any,
+     ) -> int:
+         """Insert embeddings into Weaviate"""
+         assert self.client.schema.exists(self.collection_name)
+
+         try:
+             with self.client.batch as batch:
+                 batch.batch_size = len(metadata)
+                 batch.dynamic = True
+                 res = []
+                 for i in range(len(metadata)):
+                     res.append(batch.add_data_object(
+                         {self._scalar_field: metadata[i]},
+                         class_name=self.collection_name,
+                         vector=embeddings[i]
+                     ))
+                 return len(res)
+         except WeaviateBaseError as e:
+             log.warning(f"Failed to insert data, error: {str(e)}")
+             raise e from None
+
+     def search_embedding(
+         self,
+         query: list[float],
+         k: int = 100,
+         filters: dict | None = None,
+         timeout: int | None = None,
+         **kwargs: Any,
+     ) -> list[int]:
+         """Perform a search on a query embedding and return results with distance.
+         Should call self.init() first.
+         """
+         assert self.client.schema.exists(self.collection_name)
+
+         query_obj = (
+             self.client.query.get(self.collection_name, [self._scalar_field])
+             .with_additional("distance")
+             .with_near_vector({"vector": query})
+             .with_limit(k)
+         )
+         if filters:
+             where_filter = {
+                 "path": "key",
+                 "operator": "GreaterThanEqual",
+                 "valueInt": filters.get('id'),
+             }
+             query_obj = query_obj.with_where(where_filter)
+
+         # Perform the search.
+         res = query_obj.do()
+
+         # Organize results.
+         ret = [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]]
+
+         return ret
+
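The wrapper above follows the VectorDB interface: the constructor creates the collection, and all data-path calls happen inside the init() context manager, which owns the client connection. Below is a minimal driving sketch, not part of the package: the Weaviate URL is a placeholder, and case_config stands in for a WeaviateIndexConfig instance whose constructor fields live in config.py and are not shown in this hunk.

import numpy as np

from vectordb_bench.backend.clients.weaviate_cloud.weaviate_cloud import WeaviateCloud


def smoke_test(case_config, dim: int = 4) -> list[int]:
    # case_config is assumed to be a WeaviateIndexConfig providing
    # index_param() / search_param(); its constructor is not spelled out here.
    db = WeaviateCloud(
        dim=dim,
        db_config={"url": "https://example.weaviate.network"},  # placeholder endpoint
        db_case_config=case_config,
        drop_old=True,
    )

    rng = np.random.default_rng(0)
    vectors = rng.random((100, dim)).tolist()
    ids = list(range(100))

    with db.init():  # the client only exists inside this context
        db.insert_embeddings(embeddings=vectors, metadata=ids)
        db.ready_to_search()
        return db.search_embedding(query=vectors[0], k=10)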
vectordb_bench/backend/clients/zilliz_cloud/config.py
@@ -0,0 +1,34 @@
+ from pydantic import BaseModel, SecretStr
+ from ..api import DBCaseConfig, DBConfig
+ from ..milvus.config import MilvusIndexConfig, IndexType
+
+
+ class ZillizCloudConfig(DBConfig, BaseModel):
+     uri: SecretStr | None = None
+     user: str
+     password: SecretStr | None = None
+
+     def to_dict(self) -> dict:
+         return {
+             "uri": self.uri.get_secret_value(),
+             "user": self.user,
+             "password": self.password.get_secret_value(),
+         }
+
+
+ class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
+     index: IndexType = IndexType.AUTOINDEX
+
+     def index_param(self) -> dict:
+         return {
+             "metric_type": self.parse_metric(),
+             "index_type": self.index.value,
+             "params": {},
+         }
+
+     def search_param(self) -> dict:
+         return {
+             "metric_type": self.parse_metric(),
+         }
+
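A small illustration of how the SecretStr fields above behave, with placeholder credentials; it assumes DBConfig (defined in api.py, not shown here) adds no further required fields. Pydantic keeps the secrets masked in reprs, while to_dict() deliberately unwraps them for the underlying client.

from vectordb_bench.backend.clients.zilliz_cloud.config import ZillizCloudConfig

cfg = ZillizCloudConfig(
    uri="https://in01-example.vectordb.zillizcloud.com:19538",  # placeholder
    user="db_admin",
    password="example-password",                                # placeholder
)

print(cfg)            # uri/password appear masked as SecretStr values
print(cfg.to_dict())  # plain strings, ready to hand to the Milvus/Zilliz client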
vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py
@@ -0,0 +1,35 @@
+ """Wrapper around the ZillizCloud vector database over VectorDB"""
+
+ from typing import Type
+ from ..milvus.milvus import Milvus
+ from ..api import DBConfig, DBCaseConfig, IndexType
+ from .config import ZillizCloudConfig, AutoIndexConfig
+
+
+ class ZillizCloud(Milvus):
+     def __init__(
+         self,
+         dim: int,
+         db_config: dict,
+         db_case_config: DBCaseConfig,
+         collection_name: str = "ZillizCloudVectorDBBench",
+         drop_old: bool = False,
+         name: str = "ZillizCloud",
+     ):
+         super().__init__(
+             dim=dim,
+             db_config=db_config,
+             db_case_config=db_case_config,
+             collection_name=collection_name,
+             drop_old=drop_old,
+             name=name,
+         )
+
+     @classmethod
+     def config_cls(cls) -> Type[DBConfig]:
+         return ZillizCloudConfig
+
+     @classmethod
+     def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+         return AutoIndexConfig
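ZillizCloud only swaps in its own config classes; collection management, insert, and search are inherited unchanged from the Milvus wrapper it subclasses. A quick sanity-check sketch of that relationship:

from vectordb_bench.backend.clients.milvus.milvus import Milvus
from vectordb_bench.backend.clients.zilliz_cloud.zilliz_cloud import ZillizCloud
from vectordb_bench.backend.clients.zilliz_cloud.config import ZillizCloudConfig, AutoIndexConfig

# The class only overrides the two config hooks and the constructor defaults.
assert issubclass(ZillizCloud, Milvus)
assert ZillizCloud.config_cls() is ZillizCloudConfig
assert ZillizCloud.case_config_cls() is AutoIndexConfig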
vectordb_bench/backend/dataset.py
@@ -0,0 +1,393 @@
+ """
+ Usage:
+     >>> from xxx import dataset as ds
+     >>> gist_s = ds.get(ds.Name.GIST, ds.Label.SMALL)
+     >>> gist_s.dict()
+     dataset: {'data': {'name': 'GIST', 'dim': 960, 'metric_type': 'L2', 'label': 'SMALL', 'size': 100000}, 'data_dir': 'xxx'}
+ """
+
+ import os
+ import logging
+ import pathlib
+ import math
+ from hashlib import md5
+ from enum import Enum, auto
+ from typing import Any
+
+ import s3fs
+ import pandas as pd
+ from tqdm import tqdm
+ from pydantic.dataclasses import dataclass
+
+ from ..base import BaseModel
+ from .. import config
+ from ..backend.clients import MetricType
+ from . import utils
+
+ log = logging.getLogger(__name__)
+
+
+ @dataclass
+ class LAION:
+     name: str = "LAION"
+     dim: int = 768
+     metric_type: MetricType = MetricType.COSINE
+     use_shuffled: bool = False
+
+     @property
+     def dir_name(self) -> str:
+         return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+ @dataclass
+ class GIST:
+     name: str = "GIST"
+     dim: int = 960
+     metric_type: MetricType = MetricType.L2
+     use_shuffled: bool = False
+
+     @property
+     def dir_name(self) -> str:
+         return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+ @dataclass
+ class Cohere:
+     name: str = "Cohere"
+     dim: int = 768
+     metric_type: MetricType = MetricType.COSINE
+     use_shuffled: bool = config.USE_SHUFFLED_DATA
+
+     @property
+     def dir_name(self) -> str:
+         return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+ @dataclass
+ class Glove:
+     name: str = "Glove"
+     dim: int = 200
+     metric_type: MetricType = MetricType.COSINE
+     use_shuffled: bool = False
+
+     @property
+     def dir_name(self) -> str:
+         return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+ @dataclass
+ class SIFT:
+     name: str = "SIFT"
+     dim: int = 128
+     metric_type: MetricType = MetricType.COSINE
+     use_shuffled: bool = False
+
+     @property
+     def dir_name(self) -> str:
+         return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()
+
+
+ @dataclass
+ class LAION_L(LAION):
+     label: str = "LARGE"
+     size: int = 100_000_000
+
+
+ @dataclass
+ class GIST_S(GIST):
+     label: str = "SMALL"
+     size: int = 100_000
+
+
+ @dataclass
+ class GIST_M(GIST):
+     label: str = "MEDIUM"
+     size: int = 1_000_000
+
+
+ @dataclass
+ class Cohere_S(Cohere):
+     label: str = "SMALL"
+     size: int = 100_000
+
+
+ @dataclass
+ class Cohere_M(Cohere):
+     label: str = "MEDIUM"
+     size: int = 1_000_000
+
+
+ @dataclass
+ class Cohere_L(Cohere):
+     label: str = "LARGE"
+     size: int = 10_000_000
+
+
+ @dataclass
+ class Glove_S(Glove):
+     label: str = "SMALL"
+     size: int = 100_000
+
+
+ @dataclass
+ class Glove_M(Glove):
+     label: str = "MEDIUM"
+     size: int = 1_000_000
+
+
+ @dataclass
+ class SIFT_S(SIFT):
+     label: str = "SMALL"
+     size: int = 500_000
+
+
+ @dataclass
+ class SIFT_M(SIFT):
+     label: str = "MEDIUM"
+     size: int = 5_000_000
+
+
+ @dataclass
+ class SIFT_L(SIFT):
+     label: str = "LARGE"
+     size: int = 50_000_000
+
+
+ class DataSet(BaseModel):
+     """Download the dataset if it is not in the local directory. Provide data for cases.
+
+     DataSet is iterable; each iteration returns the next batch of data as a pandas.DataFrame.
+
+     Examples:
+         >>> cohere_s = DataSet(data=Cohere_S())
+         >>> for data in cohere_s:
+         >>>     print(data.columns)
+     """
+     data: GIST | Cohere | Glove | SIFT | Any
+     test_data: pd.DataFrame | None = None
+     train_files: list[str] = []
+
+     def __eq__(self, obj):
+         if isinstance(obj, DataSet):
+             return self.data.name == obj.data.name and \
+                 self.data.label == obj.data.label
+         return False
+
+     @property
+     def data_dir(self) -> pathlib.Path:
+         """Local data directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
+
+         Examples:
+             >>> sift_s = DataSet(data=SIFT_S())
+             >>> sift_s.data_dir
+             '/tmp/vectordb_bench/dataset/sift/sift_small_500k/'
+         """
+         return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower())
+
+     @property
+     def download_dir(self) -> str:
+         """Data S3 directory: config.DEFAULT_DATASET_URL/{dataset_dirname}
+
+         Examples:
+             >>> sift_s = DataSet(data=SIFT_S())
+             >>> sift_s.download_dir
+             'assets.zilliz.com/benchmark/sift_small_500k'
+         """
+         return f"{config.DEFAULT_DATASET_URL}{self.data.dir_name}"
+
+     def __iter__(self):
+         return DataSetIterator(self)
+
+     def _validate_local_file(self):
+         if not self.data_dir.exists():
+             log.info(f"local file path does not exist, creating it: {self.data_dir}")
+             self.data_dir.mkdir(parents=True)
+
+         fs = s3fs.S3FileSystem(
+             anon=True,
+             client_kwargs={'region_name': 'us-west-2'}
+         )
+         dataset_info = fs.ls(self.download_dir, detail=True)
+         if len(dataset_info) == 0:
+             raise ValueError(f"No data in s3 for dataset: {self.download_dir}")
+         path2etag = {info['Key']: info['ETag'].split('"')[1] for info in dataset_info}
+
+         # Drop the train files of the variant not in use (shuffled vs. unshuffled)
+         # from the download candidates.
+         prefix_to_filter = "train" if self.data.use_shuffled else "shuffle_train"
+         filtered_keys = [key for key in path2etag.keys() if key.split("/")[-1].startswith(prefix_to_filter)]
+         for k in filtered_keys:
+             path2etag.pop(k)
+
+         # Get local files ending with '.parquet'.
+         file_names = [p.name for p in self.data_dir.glob("*.parquet")]
+         log.info(f"local files: {file_names}")
+         log.info(f"s3 files: {path2etag.keys()}")
+         downloads = []
+         if len(file_names) == 0:
+             log.info("no local files, add all s3 files to the download list")
+             downloads = list(path2etag.keys())
+         else:
+             # If a local file exists, compare its etag with the one on s3
+             # to make sure the data files aren't corrupted.
+             for name in tqdm([key.split("/")[-1] for key in path2etag.keys()]):
+                 s3_path = f"{self.download_dir}/{name}"
+                 local_path = self.data_dir.joinpath(name)
+                 log.debug(f"s3 path: {s3_path}, local_path: {local_path}")
+                 if not local_path.exists():
+                     log.info(f"local file does not exist: {local_path}, add to the download list")
+                     downloads.append(s3_path)
+
+                 elif not self.match_etag(path2etag.get(s3_path), local_path):
+                     log.info(f"local file etag does not match the s3 file: {local_path}, add to the download list")
+                     downloads.append(s3_path)
+
+         for s3_file in tqdm(downloads):
+             log.debug(f"downloading file {s3_file} to {self.data_dir}")
+             fs.download(s3_file, self.data_dir.as_posix())
+
+     def match_etag(self, expected_etag: str, local_file) -> bool:
+         """Check whether a local file's etag matches the one reported by S3."""
+         def factor_of_1MB(filesize, num_parts):
+             x = filesize / int(num_parts)
+             y = x % 1048576
+             return int(x + 1048576 - y)
+
+         def calc_etag(inputfile, partsize):
+             md5_digests = []
+             with open(inputfile, 'rb') as f:
+                 for chunk in iter(lambda: f.read(partsize), b''):
+                     md5_digests.append(md5(chunk).digest())
+             return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))
+
+         def possible_partsizes(filesize, num_parts):
+             return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts
+
+         filesize = os.path.getsize(local_file)
+         le = ""
+         if '-' not in expected_etag:  # not a multipart upload
+             with open(local_file, 'rb') as f:
+                 le = md5(f.read()).hexdigest()
+             log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
+             return expected_etag == le
+         else:
+             num_parts = int(expected_etag.split('-')[-1])
+             partsizes = [  # Default part sizes
+                 8388608,   # aws_cli/boto3
+                 15728640,  # s3cmd
+                 factor_of_1MB(filesize, num_parts)  # Used by many clients to upload large files
+             ]
+
+             for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
+                 le = calc_etag(local_file, partsize)
+                 log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
+                 if expected_etag == le:
+                     return True
+         return False
+
+     def prepare(self, check=True) -> bool:
+         """Download the dataset from S3
+         url = f"{config.DEFAULT_DATASET_URL}/{self.data.dir_name}"
+
+         Download files from url to self.data_dir; the data_dir will contain these types of files:
+             - train*.parquet: for training
+             - test.parquet: for testing
+             - neighbors.parquet: ground truth of test.parquet
+             - neighbors_90p.parquet: ground truth of test.parquet after filtering 90% of the data
+             - neighbors_head_1p.parquet: ground truth of test.parquet after filtering 1% of the data
+             - neighbors_99p.parquet: ground truth of test.parquet after filtering 99% of the data
+         """
+         if check:
+             self._validate_local_file()
+
+         prefix = "shuffle_train" if self.data.use_shuffled else "train"
+         self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')])
+         log.debug(f"{self.data.name}: available train files {self.train_files}")
+         self.test_data = self._read_file("test.parquet")
+         return True
+
+     def get_ground_truth(self, filters: int | float | None = None) -> pd.DataFrame:
+         file_name = ""
+         if filters is None:
+             file_name = "neighbors.parquet"
+         elif filters == 0.01:
+             file_name = "neighbors_head_1p.parquet"
+         elif filters == 0.99:
+             file_name = "neighbors_tail_1p.parquet"
+         else:
+             raise ValueError(f"Filters not supported: {filters}")
+         return self._read_file(file_name)
+
+     def _read_file(self, file_name: str) -> pd.DataFrame:
+         """Read one file from disk into memory."""
+         import pyarrow.parquet as pq
+
+         p = pathlib.Path(self.data_dir, file_name)
+         log.info(f"reading file into memory: {p}")
+         if not p.exists():
+             log.warning(f"No such file: {p}")
+             return pd.DataFrame()
+         data = pq.read_table(p)
+         df = data.to_pandas()
+         return df
+
+
+ class DataSetIterator:
+     def __init__(self, dataset: DataSet):
+         self._ds = dataset
+         self._idx = 0  # file number
+         self._curr: pd.DataFrame | None = None
+         self._sub_idx = [0 for i in range(len(self._ds.train_files))]  # iteration count per file
+
+     def __next__(self) -> pd.DataFrame:
+         """Return the next 100k-row batch from the current file of the training list."""
+         if self._idx < len(self._ds.train_files):
+             _sub = self._sub_idx[self._idx]
+             if _sub == 0 and self._idx == 0:  # init
+                 file_name = self._ds.train_files[self._idx]
+                 self._curr = self._ds._read_file(file_name)
+                 self._iter_num = math.ceil(self._curr.shape[0] / 100_000)
+
+             if _sub == self._iter_num:
+                 if self._idx == len(self._ds.train_files) - 1:
+                     self._curr = None
+                     raise StopIteration
+                 else:
+                     self._idx += 1
+                     _sub = self._sub_idx[self._idx]
+
+                     self._curr = None
+                     file_name = self._ds.train_files[self._idx]
+                     self._curr = self._ds._read_file(file_name)
+
+             sub_df = self._curr[_sub * 100_000: (_sub + 1) * 100_000]
+             self._sub_idx[self._idx] += 1
+             log.info(f"Get the [{_sub+1}/{self._iter_num}] batch of {self._idx+1}/{len(self._ds.train_files)} train file")
+             return sub_df
+         self._curr = None
+         raise StopIteration
+
+
+ class Name(Enum):
+     GIST = auto()
+     Cohere = auto()
+     Glove = auto()
+     SIFT = auto()
+     LAION = auto()
+
+
+ class Label(Enum):
+     SMALL = auto()
+     MEDIUM = auto()
+     LARGE = auto()
+
+
+ _global_ds_mapping = {
+     Name.GIST: {
+         Label.SMALL: DataSet(data=GIST_S()),
+         Label.MEDIUM: DataSet(data=GIST_M()),
+     },
+     Name.Cohere: {
+         Label.SMALL: DataSet(data=Cohere_S()),
+         Label.MEDIUM: DataSet(data=Cohere_M()),
+         Label.LARGE: DataSet(data=Cohere_L()),
+     },
+     Name.Glove: {
+         Label.SMALL: DataSet(data=Glove_S()),
+         Label.MEDIUM: DataSet(data=Glove_M()),
+     },
+     Name.SIFT: {
+         Label.SMALL: DataSet(data=SIFT_S()),
+         Label.MEDIUM: DataSet(data=SIFT_M()),
+         Label.LARGE: DataSet(data=SIFT_L()),
+     },
+     Name.LAION: {
+         Label.LARGE: DataSet(data=LAION_L()),
+     },
+ }
+
+
+ def get(ds: Name, label: Label):
+     return _global_ds_mapping.get(ds, {}).get(label)
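Putting the module together, the typical flow is: look up a DataSet via get(), call prepare() to download and validate the parquet files, then iterate over it in batches of up to 100,000 rows. A sketch, assuming the default config.DATASET_LOCAL_DIR and the public S3 bucket behind config.DEFAULT_DATASET_URL are reachable:

from vectordb_bench.backend import dataset as ds

cohere_s = ds.get(ds.Name.Cohere, ds.Label.SMALL)
cohere_s.prepare()                     # downloads/validates parquet files, loads test.parquet

print(cohere_s.data_dir)               # <DATASET_LOCAL_DIR>/cohere/cohere_small_100k
print(cohere_s.test_data.shape)

for batch in cohere_s:                 # pandas.DataFrame batches of at most 100_000 rows
    print(batch.shape)
    break

gt = cohere_s.get_ground_truth()            # neighbors.parquet
gt_99p = cohere_s.get_ground_truth(0.99)    # neighbors_tail_1p.parquet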
vectordb_bench/backend/result_collector.py
@@ -0,0 +1,15 @@
+ import pathlib
+ from ..models import TestResult
+
+
+ class ResultCollector:
+     @classmethod
+     def collect(cls, result_dir: pathlib.Path) -> list[TestResult]:
+         results = []
+         if not result_dir.exists() or len(list(result_dir.glob("*.json"))) == 0:
+             return []
+
+         for json_file in result_dir.glob("*.json"):
+             results.append(TestResult.read_file(json_file, trans_unit=True))
+
+         return results
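Usage is a one-liner; the directory below is a placeholder for wherever the benchmark writes its JSON result files.

import pathlib

from vectordb_bench.backend.result_collector import ResultCollector

results = ResultCollector.collect(pathlib.Path("/tmp/vectordb_bench/results"))  # placeholder path
print(f"loaded {len(results)} result files")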
vectordb_bench/backend/runner/__init__.py
@@ -0,0 +1,12 @@
+ from .mp_runner import (
+     MultiProcessingSearchRunner,
+ )
+
+ from .serial_runner import SerialSearchRunner, SerialInsertRunner
+
+
+ __all__ = [
+     'MultiProcessingSearchRunner',
+     'SerialSearchRunner',
+     'SerialInsertRunner',
+ ]