vectordb-bench 0.0.1__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/__init__.py +30 -0
- vectordb_bench/__main__.py +39 -0
- vectordb_bench/backend/__init__.py +0 -0
- vectordb_bench/backend/assembler.py +57 -0
- vectordb_bench/backend/cases.py +124 -0
- vectordb_bench/backend/clients/__init__.py +57 -0
- vectordb_bench/backend/clients/api.py +179 -0
- vectordb_bench/backend/clients/elastic_cloud/config.py +56 -0
- vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +152 -0
- vectordb_bench/backend/clients/milvus/config.py +123 -0
- vectordb_bench/backend/clients/milvus/milvus.py +182 -0
- vectordb_bench/backend/clients/pinecone/config.py +15 -0
- vectordb_bench/backend/clients/pinecone/pinecone.py +113 -0
- vectordb_bench/backend/clients/qdrant_cloud/config.py +16 -0
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +169 -0
- vectordb_bench/backend/clients/weaviate_cloud/config.py +45 -0
- vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +151 -0
- vectordb_bench/backend/clients/zilliz_cloud/config.py +34 -0
- vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +35 -0
- vectordb_bench/backend/dataset.py +393 -0
- vectordb_bench/backend/result_collector.py +15 -0
- vectordb_bench/backend/runner/__init__.py +12 -0
- vectordb_bench/backend/runner/mp_runner.py +124 -0
- vectordb_bench/backend/runner/serial_runner.py +164 -0
- vectordb_bench/backend/task_runner.py +290 -0
- vectordb_bench/backend/utils.py +85 -0
- vectordb_bench/base.py +6 -0
- vectordb_bench/frontend/components/check_results/charts.py +175 -0
- vectordb_bench/frontend/components/check_results/data.py +86 -0
- vectordb_bench/frontend/components/check_results/filters.py +97 -0
- vectordb_bench/frontend/components/check_results/headerIcon.py +18 -0
- vectordb_bench/frontend/components/check_results/nav.py +21 -0
- vectordb_bench/frontend/components/check_results/priceTable.py +48 -0
- vectordb_bench/frontend/components/run_test/autoRefresh.py +10 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +87 -0
- vectordb_bench/frontend/components/run_test/dbConfigSetting.py +47 -0
- vectordb_bench/frontend/components/run_test/dbSelector.py +36 -0
- vectordb_bench/frontend/components/run_test/generateTasks.py +21 -0
- vectordb_bench/frontend/components/run_test/hideSidebar.py +10 -0
- vectordb_bench/frontend/components/run_test/submitTask.py +69 -0
- vectordb_bench/frontend/const.py +391 -0
- vectordb_bench/frontend/pages/qps_with_price.py +60 -0
- vectordb_bench/frontend/pages/run_test.py +59 -0
- vectordb_bench/frontend/utils.py +6 -0
- vectordb_bench/frontend/vdb_benchmark.py +42 -0
- vectordb_bench/interface.py +239 -0
- vectordb_bench/log_util.py +103 -0
- vectordb_bench/metric.py +53 -0
- vectordb_bench/models.py +234 -0
- vectordb_bench/results/result_20230609_standard.json +3228 -0
- vectordb_bench-0.0.1.dist-info/LICENSE +21 -0
- vectordb_bench-0.0.1.dist-info/METADATA +226 -0
- vectordb_bench-0.0.1.dist-info/RECORD +56 -0
- vectordb_bench-0.0.1.dist-info/WHEEL +5 -0
- vectordb_bench-0.0.1.dist-info/entry_points.txt +2 -0
- vectordb_bench-0.0.1.dist-info/top_level.txt +1 -0
vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
@@ -0,0 +1,151 @@
"""Wrapper around the Weaviate vector database over VectorDB"""

import logging
from typing import Any, Iterable, Type
from contextlib import contextmanager

from weaviate.exceptions import WeaviateBaseError

from ..api import VectorDB, DBConfig, DBCaseConfig, IndexType
from .config import WeaviateConfig, WeaviateIndexConfig


log = logging.getLogger(__name__)


class WeaviateCloud(VectorDB):
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "VectorDBBenchCollection",
        drop_old: bool = False,
    ):
        """Initialize wrapper around the weaviate vector database."""
        self.db_config = db_config
        self.case_config = db_case_config
        self.collection_name = collection_name

        self._scalar_field = "key"
        self._vector_field = "vector"
        self._index_name = "vector_idx"

        from weaviate import Client
        client = Client(**db_config)
        if drop_old:
            try:
                if client.schema.exists(self.collection_name):
                    log.info(f"weaviate client drop_old collection: {self.collection_name}")
                    client.schema.delete_class(self.collection_name)
            except WeaviateBaseError as e:
                log.warning(f"Failed to drop collection: {self.collection_name} error: {str(e)}")
                raise e from None
        self._create_collection(client)
        client = None

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return WeaviateConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return WeaviateIndexConfig

    @contextmanager
    def init(self) -> None:
        """
        Examples:
            >>> with self.init():
            >>>     self.insert_embeddings()
            >>>     self.search_embedding()
        """
        from weaviate import Client
        self.client = Client(**self.db_config)
        yield
        self.client = None
        del(self.client)

    def ready_to_load(self):
        """Should call insert first, do nothing"""
        pass

    def ready_to_search(self):
        assert self.client.schema.exists(self.collection_name)
        self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param()})

    def _create_collection(self, client):
        if not client.schema.exists(self.collection_name):
            log.info(f"Create collection: {self.collection_name}")
            class_obj = {
                "class": self.collection_name,
                "vectorizer": "none",
                "properties": [
                    {
                        "dataType": ["int"],
                        "name": self._scalar_field,
                    },
                ]
            }
            class_obj["vectorIndexConfig"] = self.case_config.index_param()
            try:
                client.schema.create_class(class_obj)
            except WeaviateBaseError as e:
                log.warning(f"Failed to create collection: {self.collection_name} error: {str(e)}")
                raise e from None

    def insert_embeddings(
        self,
        embeddings: Iterable[list[float]],
        metadata: list[int],
        **kwargs: Any,
    ) -> int:
        """Insert embeddings into Weaviate"""
        assert self.client.schema.exists(self.collection_name)

        try:
            with self.client.batch as batch:
                batch.batch_size = len(metadata)
                batch.dynamic = True
                res = []
                for i in range(len(metadata)):
                    res.append(batch.add_data_object(
                        {self._scalar_field: metadata[i]},
                        class_name=self.collection_name,
                        vector=embeddings[i]
                    ))
                return len(res)
        except WeaviateBaseError as e:
            log.warning(f"Failed to insert data, error: {str(e)}")
            raise e from None

    def search_embedding(
        self,
        query: list[float],
        k: int = 100,
        filters: dict | None = None,
        timeout: int | None = None,
        **kwargs: Any,
    ) -> list[int]:
        """Perform a search on a query embedding and return results with distance.
        Should call self.init() first.
        """
        assert self.client.schema.exists(self.collection_name)

        query_obj = self.client.query.get(self.collection_name, [self._scalar_field]).with_additional("distance").with_near_vector({"vector": query}).with_limit(k)
        if filters:
            where_filter = {
                "path": "key",
                "operator": "GreaterThanEqual",
                "valueInt": filters.get('id')
            }
            query_obj = query_obj.with_where(where_filter)

        # Perform the search.
        res = query_obj.do()

        # Organize results.
        ret = [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]]

        return ret
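
The wrapper above leans entirely on the weaviate v3 Python client: `schema.exists`/`schema.create_class` for collection management, the `batch` context manager for bulk inserts, and the GraphQL-style `query.get(...).with_near_vector(...)` chain for ANN search. The following is a minimal sketch of those calls in isolation, assuming a reachable Weaviate endpoint (the URL and class name below are placeholders, not part of the package):

import weaviate

# Placeholder endpoint; a real Weaviate Cloud instance would also need auth credentials.
client = weaviate.Client(url="https://example.weaviate.network")

class_name = "DemoCollection"
if not client.schema.exists(class_name):
    client.schema.create_class({
        "class": class_name,
        "vectorizer": "none",  # vectors are supplied by the caller, not computed by Weaviate
        "properties": [{"dataType": ["int"], "name": "key"}],
    })

# Bulk insert: one data object per vector, with the scalar id stored in "key".
with client.batch as batch:
    batch.add_data_object({"key": 0}, class_name=class_name, vector=[0.1, 0.2, 0.3])

# ANN query: top-5 neighbors of a query vector, returning only the "key" field.
res = (
    client.query.get(class_name, ["key"])
    .with_additional("distance")
    .with_near_vector({"vector": [0.1, 0.2, 0.3]})
    .with_limit(5)
    .do()
)
ids = [r["key"] for r in res["data"]["Get"][class_name]]
print(ids)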
vectordb_bench/backend/clients/zilliz_cloud/config.py
@@ -0,0 +1,34 @@
from pydantic import BaseModel, SecretStr
from ..api import DBCaseConfig, DBConfig
from ..milvus.config import MilvusIndexConfig, IndexType


class ZillizCloudConfig(DBConfig, BaseModel):
    uri: SecretStr | None = None
    user: str
    password: SecretStr | None = None

    def to_dict(self) -> dict:
        return {
            "uri": self.uri.get_secret_value(),
            "user": self.user,
            "password": self.password.get_secret_value(),
        }


class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
    index: IndexType = IndexType.AUTOINDEX

    def index_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
            "index_type": self.index.value,
            "params": {},
        }

    def search_param(self) -> dict:
        return {
            "metric_type": self.parse_metric(),
        }
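
`ZillizCloudConfig` is a plain pydantic model, so fields typed as `SecretStr` stay masked in logs and reprs until `to_dict()` unwraps them for the client wrapper. A small sketch, assuming the import path below and obviously fake credentials:

from vectordb_bench.backend.clients.zilliz_cloud.config import ZillizCloudConfig

cfg = ZillizCloudConfig(
    uri="https://example.zillizcloud.com:19530",  # placeholder endpoint
    user="db_admin",
    password="not-a-real-password",
)
print(cfg.uri)        # masked: SecretStr('**********')
print(cfg.to_dict())  # plaintext dict consumed by the underlying Milvus client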
vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py
@@ -0,0 +1,35 @@
"""Wrapper around the ZillizCloud vector database over VectorDB"""

from typing import Type
from ..milvus.milvus import Milvus
from ..api import DBConfig, DBCaseConfig, IndexType
from .config import ZillizCloudConfig, AutoIndexConfig


class ZillizCloud(Milvus):
    def __init__(
        self,
        dim: int,
        db_config: dict,
        db_case_config: DBCaseConfig,
        collection_name: str = "ZillizCloudVectorDBBench",
        drop_old: bool = False,
        name: str = "ZillizCloud"
    ):
        super().__init__(
            dim=dim,
            db_config=db_config,
            db_case_config=db_case_config,
            collection_name=collection_name,
            drop_old=drop_old,
            name=name,
        )

    @classmethod
    def config_cls(cls) -> Type[DBConfig]:
        return ZillizCloudConfig

    @classmethod
    def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
        return AutoIndexConfig
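
`ZillizCloud` adds no behaviour of its own: it reuses the `Milvus` implementation and only swaps in its own config classes via the two classmethods, which is how generic benchmark code can discover the right config types without hard-coding a client. Illustrative only:

from vectordb_bench.backend.clients.zilliz_cloud.zilliz_cloud import ZillizCloud

config_cls = ZillizCloud.config_cls()            # ZillizCloudConfig
case_config_cls = ZillizCloud.case_config_cls()  # AutoIndexConfig
print(config_cls.__name__, case_config_cls.__name__)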
vectordb_bench/backend/dataset.py
@@ -0,0 +1,393 @@
"""
Usage:
    >>> from xxx import dataset as ds
    >>> gist_s = ds.get(ds.Name.GIST, ds.Label.SMALL)
    >>> gist_s.dict()
    dataset: {'data': {'name': 'GIST', 'dim': 128, 'metric_type': 'L2', 'label': 'SMALL', 'size': 50000000}, 'data_dir': 'xxx'}
"""

import os
import logging
import pathlib
import math
from hashlib import md5
from enum import Enum, auto
from typing import Any

import s3fs
import pandas as pd
from tqdm import tqdm
from pydantic.dataclasses import dataclass

from ..base import BaseModel
from .. import config
from ..backend.clients import MetricType
from . import utils

log = logging.getLogger(__name__)

@dataclass
class LAION:
    name: str = "LAION"
    dim: int = 768
    metric_type: MetricType = MetricType.COSINE
    use_shuffled: bool = False

    @property
    def dir_name(self) -> str:
        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()

@dataclass
class GIST:
    name: str = "GIST"
    dim: int = 960
    metric_type: MetricType = MetricType.L2
    use_shuffled: bool = False

    @property
    def dir_name(self) -> str:
        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()

@dataclass
class Cohere:
    name: str = "Cohere"
    dim: int = 768
    metric_type: MetricType = MetricType.COSINE
    use_shuffled: bool = config.USE_SHUFFLED_DATA

    @property
    def dir_name(self) -> str:
        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()

@dataclass
class Glove:
    name: str = "Glove"
    dim: int = 200
    metric_type: MetricType = MetricType.COSINE
    use_shuffled: bool = False

    @property
    def dir_name(self) -> str:
        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()

@dataclass
class SIFT:
    name: str = "SIFT"
    dim: int = 128
    metric_type: MetricType = MetricType.COSINE
    use_shuffled: bool = False

    @property
    def dir_name(self) -> str:
        return f"{self.name}_{self.label}_{utils.numerize(self.size)}".lower()

@dataclass
class LAION_L(LAION):
    label: str = "LARGE"
    size: int = 100_000_000

@dataclass
class GIST_S(GIST):
    label: str = "SMALL"
    size: int = 100_000

@dataclass
class GIST_M(GIST):
    label: str = "MEDIUM"
    size: int = 1_000_000

@dataclass
class Cohere_S(Cohere):
    label: str = "SMALL"
    size: int = 100_000

@dataclass
class Cohere_M(Cohere):
    label: str = "MEDIUM"
    size: int = 1_000_000

@dataclass
class Cohere_L(Cohere):
    label : str = "LARGE"
    size : int = 10_000_000

@dataclass
class Glove_S(Glove):
    label: str = "SMALL"
    size : int = 100_000

@dataclass
class Glove_M(Glove):
    label: str = "MEDIUM"
    size : int = 1_000_000

@dataclass
class SIFT_S(SIFT):
    label: str = "SMALL"
    size : int = 500_000

@dataclass
class SIFT_M(SIFT):
    label: str = "MEDIUM"
    size : int = 5_000_000

@dataclass
class SIFT_L(SIFT):
    label: str = "LARGE"
    size : int = 50_000_000


class DataSet(BaseModel):
    """Download dataset if not int the local directory. Provide data for cases.

    DataSet is iterable, each iteration will return the next batch of data in pandas.DataFrame

    Examples:
        >>> cohere_s = DataSet(data=Cohere_S)
        >>> for data in cohere_s:
        >>>     print(data.columns)
    """
    data: GIST | Cohere | Glove | SIFT | Any
    test_data: pd.DataFrame | None = None
    train_files : list[str] = []

    def __eq__(self, obj):
        if isinstance(obj, DataSet):
            return self.data.name == obj.data.name and \
                self.data.label == obj.data.label
        return False

    @property
    def data_dir(self) -> pathlib.Path:
        """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}

        Examples:
            >>> sift_s = DataSet(data=SIFT_L())
            >>> sift_s.relative_path
            '/tmp/vectordb_bench/dataset/sift/sift_small_500k/'
        """
        return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower())

    @property
    def download_dir(self) -> str:
        """ data s3 directory: config.DEFAULT_DATASET_URL/{dataset_dirname}

        Examples:
            >>> sift_s = DataSet(data=SIFT_L())
            >>> sift_s.download_dir
            'assets.zilliz.com/benchmark/sift_small_500k'
        """
        return f"{config.DEFAULT_DATASET_URL}{self.data.dir_name}"

    def __iter__(self):
        return DataSetIterator(self)

    def _validate_local_file(self):
        if not self.data_dir.exists():
            log.info(f"local file path not exist, creating it: {self.data_dir}")
            self.data_dir.mkdir(parents=True)

        fs = s3fs.S3FileSystem(
            anon=True,
            client_kwargs={'region_name': 'us-west-2'}
        )
        dataset_info = fs.ls(self.download_dir, detail=True)
        if len(dataset_info) == 0:
            raise ValueError(f"No data in s3 for dataset: {self.download_dir}")
        path2etag = {info['Key']: info['ETag'].split('"')[1] for info in dataset_info}

        perfix_to_filter = "train" if self.data.use_shuffled else "shuffle_train"
        filtered_keys = [key for key in path2etag.keys() if key.split("/")[-1].startswith(perfix_to_filter)]
        for k in filtered_keys:
            path2etag.pop(k)

        # get local files ended with '.parquet'
        file_names = [p.name for p in self.data_dir.glob("*.parquet")]
        log.info(f"local files: {file_names}")
        log.info(f"s3 files: {path2etag.keys()}")
        downloads = []
        if len(file_names) == 0:
            log.info("no local files, set all to downloading lists")
            downloads = path2etag.keys()
        else:
            # if local file exists, check the etag of local file with s3,
            # make sure data files aren't corrupted.
            for name in tqdm([key.split("/")[-1] for key in path2etag.keys()]):
                s3_path = f"{self.download_dir}/{name}"
                local_path = self.data_dir.joinpath(name)
                log.debug(f"s3 path: {s3_path}, local_path: {local_path}")
                if not local_path.exists():
                    log.info(f"local file not exists: {local_path}, add to downloading lists")
                    downloads.append(s3_path)

                elif not self.match_etag(path2etag.get(s3_path), local_path):
                    log.info(f"local file etag not match with s3 file: {local_path}, add to downloading lists")
                    downloads.append(s3_path)

        for s3_file in tqdm(downloads):
            log.debug(f"downloading file {s3_file} to {self.data_dir}")
            fs.download(s3_file, self.data_dir.as_posix())

    def match_etag(self, expected_etag: str, local_file) -> bool:
        """Check if local files' etag match with S3"""
        def factor_of_1MB(filesize, num_parts):
            x = filesize / int(num_parts)
            y = x % 1048576
            return int(x + 1048576 - y)

        def calc_etag(inputfile, partsize):
            md5_digests = []
            with open(inputfile, 'rb') as f:
                for chunk in iter(lambda: f.read(partsize), b''):
                    md5_digests.append(md5(chunk).digest())
            return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))

        def possible_partsizes(filesize, num_parts):
            return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts

        filesize = os.path.getsize(local_file)
        le = ""
        if '-' not in expected_etag: # no spliting uploading
            with open(local_file, 'rb') as f:
                le = md5(f.read()).hexdigest()
            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
            return expected_etag == le
        else:
            num_parts = int(expected_etag.split('-')[-1])
            partsizes = [ ## Default Partsizes Map
                8388608, # aws_cli/boto3
                15728640, # s3cmd
                factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
            ]

            for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
                le = calc_etag(local_file, partsize)
                log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
                if expected_etag == le:
                    return True
            return False

    def prepare(self, check=True) -> bool:
        """Download the dataset from S3
        url = f"{config.DEFAULT_DATASET_URL}/{self.data.dir_name}"

        download files from url to self.data_dir, there'll be 4 types of files in the data_dir
            - train*.parquet: for training
            - test.parquet: for testing
            - neighbors.parquet: ground_truth of the test.parquet
            - neighbors_90p.parquet: ground_truth of the test.parquet after filtering 90% data
            - neighbors_head_1p.parquet: ground_truth of the test.parquet after filtering 1% data
            - neighbors_99p.parquet: ground_truth of the test.parquet after filtering 99% data
        """
        if check:
            self._validate_local_file()

        prefix = "shuffle_train" if self.data.use_shuffled else "train"
        self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')])
        log.debug(f"{self.data.name}: available train files {self.train_files}")
        self.test_data = self._read_file("test.parquet")
        return True

    def get_ground_truth(self, filters: int | float | None = None) -> pd.DataFrame:

        file_name = ""
        if filters is None:
            file_name = "neighbors.parquet"
        elif filters == 0.01:
            file_name = "neighbors_head_1p.parquet"
        elif filters == 0.99:
            file_name = "neighbors_tail_1p.parquet"
        else:
            raise ValueError(f"Filters not supported: {filters}")
        return self._read_file(file_name)

    def _read_file(self, file_name: str) -> pd.DataFrame:
        """read one file from disk into memory"""
        import pyarrow.parquet as pq

        p = pathlib.Path(self.data_dir, file_name)
        log.info(f"reading file into memory: {p}")
        if not p.exists():
            log.warning(f"No such file: {p}")
            return pd.DataFrame()
        data = pq.read_table(p)
        df = data.to_pandas()
        return df


class DataSetIterator:
    def __init__(self, dataset: DataSet):
        self._ds = dataset
        self._idx = 0  # file number
        self._curr: pd.DataFrame | None = None
        self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file

    def __next__(self) -> pd.DataFrame:
        """return the data in the next file of the training list"""
        if self._idx < len(self._ds.train_files):
            _sub = self._sub_idx[self._idx]
            if _sub == 0 and self._idx == 0: # init
                file_name = self._ds.train_files[self._idx]
                self._curr = self._ds._read_file(file_name)
                self._iter_num = math.ceil(self._curr.shape[0]/100_000)

            if _sub == self._iter_num:
                if self._idx == len(self._ds.train_files) - 1:
                    self._curr = None
                    raise StopIteration
                else:
                    self._idx += 1
                    _sub = self._sub_idx[self._idx]

                    self._curr = None
                    file_name = self._ds.train_files[self._idx]
                    self._curr = self._ds._read_file(file_name)

            sub_df = self._curr[_sub*100_000: (_sub+1)*100_000]
            self._sub_idx[self._idx] += 1
            log.info(f"Get the [{_sub+1}/{self._iter_num}] batch of {self._idx+1}/{len(self._ds.train_files)} train file")
            return sub_df
        self._curr = None
        raise StopIteration


class Name(Enum):
    GIST = auto()
    Cohere = auto()
    Glove = auto()
    SIFT = auto()
    LAION = auto()


class Label(Enum):
    SMALL = auto()
    MEDIUM = auto()
    LARGE = auto()

_global_ds_mapping = {
    Name.GIST: {
        Label.SMALL: DataSet(data=GIST_S()),
        Label.MEDIUM: DataSet(data=GIST_M()),
    },
    Name.Cohere: {
        Label.SMALL: DataSet(data=Cohere_S()),
        Label.MEDIUM: DataSet(data=Cohere_M()),
        Label.LARGE: DataSet(data=Cohere_L()),
    },
    Name.Glove:{
        Label.SMALL: DataSet(data=Glove_S()),
        Label.MEDIUM: DataSet(data=Glove_M()),
    },
    Name.SIFT: {
        Label.SMALL: DataSet(data=SIFT_S()),
        Label.MEDIUM: DataSet(data=SIFT_M()),
        Label.LARGE: DataSet(data=SIFT_L()),
    },
    Name.LAION: {
        Label.LARGE: DataSet(data=LAION_L()),
    },
}

def get(ds: Name, label: Label):
    return _global_ds_mapping.get(ds, {}).get(label)
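
Putting the module together: `get()` returns a `DataSet` descriptor, `prepare()` downloads and ETag-verifies the parquet files from the public S3 bucket (when needed) and loads `test.parquet`, and iterating the object yields train data in `DataFrame` batches of at most 100 000 rows. A sketch mirroring the module docstring, assuming network access and the default `config.DATASET_LOCAL_DIR`:

from vectordb_bench.backend import dataset as ds

cohere_s = ds.get(ds.Name.Cohere, ds.Label.SMALL)
cohere_s.prepare()            # download/verify parquet files, then read test.parquet
print(cohere_s.data_dir)      # local directory holding the dataset files

for batch in cohere_s:        # pandas.DataFrame batches of up to 100_000 rows
    print(batch.shape)
    break

ground_truth = cohere_s.get_ground_truth()  # neighbors.parquet as a DataFrame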
vectordb_bench/backend/result_collector.py
@@ -0,0 +1,15 @@
import pathlib
from ..models import TestResult


class ResultCollector:
    @classmethod
    def collect(cls, result_dir: pathlib.Path) -> list[TestResult]:
        results = []
        if not result_dir.exists() or len(list(result_dir.glob("*.json"))) == 0:
            return []

        for json_file in result_dir.glob("*.json"):
            results.append(TestResult.read_file(json_file, trans_unit=True))

        return results
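
`ResultCollector.collect` simply deserializes every `*.json` file under a directory into `TestResult` objects, so loading persisted runs is a one-liner. Illustrative only; the path below assumes the `results/` directory shipped in this wheel (it contains result_20230609_standard.json):

import pathlib

from vectordb_bench.backend.result_collector import ResultCollector

results = ResultCollector.collect(pathlib.Path("vectordb_bench/results"))
print(f"loaded {len(results)} result file(s)")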