vectordb-bench 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectordb_bench/backend/clients/__init__.py +4 -4
- vectordb_bench/backend/clients/api.py +1 -0
- vectordb_bench/backend/clients/chroma/chroma.py +2 -14
- vectordb_bench/backend/clients/milvus/config.py +19 -0
- vectordb_bench/backend/clients/pgvecto_rs/config.py +44 -32
- vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +16 -16
- vectordb_bench/backend/clients/pgvector/config.py +63 -12
- vectordb_bench/backend/clients/pgvector/pgvector.py +105 -77
- vectordb_bench/backend/clients/qdrant_cloud/config.py +19 -6
- vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +11 -7
- vectordb_bench/backend/clients/zilliz_cloud/config.py +4 -0
- vectordb_bench/backend/data_source.py +13 -64
- vectordb_bench/backend/dataset.py +45 -67
- vectordb_bench/backend/runner/serial_runner.py +1 -1
- vectordb_bench/backend/task_runner.py +2 -2
- vectordb_bench/backend/utils.py +30 -0
- vectordb_bench/frontend/components/run_test/caseSelector.py +1 -1
- vectordb_bench/frontend/const/dbCaseConfigs.py +41 -77
- vectordb_bench/models.py +1 -0
- vectordb_bench/results/PgVector/result_20230727_standard_pgvector.json +8 -0
- vectordb_bench/results/PgVector/result_20230808_standard_pgvector.json +9 -3
- vectordb_bench/results/ZillizCloud/{result_20240105_beta_202401_zillizcloud.json → result_20240105_standard_202401_zillizcloud.json} +365 -41
- vectordb_bench/results/getLeaderboardData.py +1 -1
- vectordb_bench/results/leaderboard.json +1 -1
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/METADATA +15 -2
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/RECORD +30 -30
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/WHEEL +1 -1
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/LICENSE +0 -0
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/entry_points.txt +0 -0
- {vectordb_bench-0.0.6.dist-info → vectordb_bench-0.0.8.dist-info}/top_level.txt +0 -0
vectordb_bench/backend/clients/pgvector/pgvector.py

@@ -1,28 +1,14 @@
 """Wrapper around the Pgvector vector database over VectorDB"""

+import io
 import logging
 from contextlib import contextmanager
 from typing import Any
+import pandas as pd
+import psycopg2
+import psycopg2.extras

-from ..api import VectorDB, DBCaseConfig
-from pgvector.sqlalchemy import Vector
-from sqlalchemy import (
-    MetaData,
-    create_engine,
-    insert,
-    select,
-    Index,
-    Table,
-    text,
-    Column,
-    Float,
-    Integer
-)
-from sqlalchemy.orm import (
-    declarative_base,
-    mapped_column,
-    Session
-)
+from ..api import IndexType, VectorDB, DBCaseConfig

 log = logging.getLogger(__name__)

@@ -37,6 +23,7 @@ class PgVector(VectorDB):
         drop_old: bool = False,
         **kwargs,
     ):
+        self.name = "PgVector"
         self.db_config = db_config
         self.case_config = db_case_config
         self.table_name = collection_name

@@ -47,22 +34,26 @@ class PgVector(VectorDB):
         self._vector_field = "embedding"

         # construct basic units
-
-
-
-        pq_metadata.reflect(pg_engine)
+        self.conn = psycopg2.connect(**self.db_config)
+        self.conn.autocommit = False
+        self.cursor = self.conn.cursor()

         # create vector extension
-
-
-        conn.commit()
+        self.cursor.execute('CREATE EXTENSION IF NOT EXISTS vector')
+        self.conn.commit()

-
-        if drop_old and self.table_name in pq_metadata.tables:
+        if drop_old :
             log.info(f"Pgvector client drop table : {self.table_name}")
             # self.pg_table.drop(pg_engine, checkfirst=True)
-
-            self.
+            self._drop_index()
+            self._drop_table()
+            self._create_table(dim)
+            self._create_index()
+
+        self.cursor.close()
+        self.conn.close()
+        self.cursor = None
+        self.conn = None

     @contextmanager
     def init(self) -> None:
@@ -72,53 +63,70 @@ class PgVector(VectorDB):
         >>> self.insert_embeddings()
         >>> self.search_embedding()
         """
-        self.
-
-
-
-
-
-
-
-
-
-
-
+        self.conn = psycopg2.connect(**self.db_config)
+        self.conn.autocommit = False
+        self.cursor = self.conn.cursor()
+
+        try:
+            yield
+        finally:
+            self.cursor.close()
+            self.conn.close()
+            self.cursor = None
+            self.conn = None
+
+    def _drop_table(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        self.cursor.execute(f'DROP TABLE IF EXISTS public."{self.table_name}"')
+        self.conn.commit()

     def ready_to_load(self):
         pass

     def optimize(self):
         pass
+
+    def _post_insert(self):
+        log.info(f"{self.name} post insert before optimize")
+        self._drop_index()
+        self._create_index()

     def ready_to_search(self):
         pass
-
-    def
-
-
-
-
-
-            extend_existing=True
-        )
+
+    def _drop_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
+        self.cursor.execute(f'DROP INDEX IF EXISTS "{self._index_name}"')
+        self.conn.commit()

-    def _create_index(self
+    def _create_index(self):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
         index_param = self.case_config.index_param()
-
-
-
-
-
-
-
-
-
+        if self.case_config.index == IndexType.HNSW:
+            log.debug(f'Creating HNSW index. m={index_param["m"]}, ef_construction={index_param["ef_construction"]}')
+            self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING hnsw (embedding {index_param["metric"]}) WITH (m={index_param["m"]}, ef_construction={index_param["ef_construction"]});')
+        elif self.case_config.index == IndexType.IVFFlat:
+            log.debug(f'Creating IVFFLAT index. list={index_param["lists"]}')
+            self.cursor.execute(f'CREATE INDEX IF NOT EXISTS {self._index_name} ON public."{self.table_name}" USING ivfflat (embedding {index_param["metric"]}) WITH (lists={index_param["lists"]});')
+        else:
+            assert "Invalid index type {self.case_config.index}"
+        self.conn.commit()
+
+    def _create_table(self, dim : int):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
         try:
             # create table
-            self.
-
-            self.
+            self.cursor.execute(f'CREATE TABLE IF NOT EXISTS public."{self.table_name}" (id BIGINT PRIMARY KEY, embedding vector({dim}));')
+            self.cursor.execute(f'ALTER TABLE public."{self.table_name}" ALTER COLUMN embedding SET STORAGE PLAIN;')
+            self.conn.commit()
         except Exception as e:
             log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
             raise e from None
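Taken together, the hunks above replace the old SQLAlchemy setup with DDL issued directly through psycopg2: create the extension, create a plain `(id, vector)` table with in-line storage, and build either an HNSW or IVFFlat index from `index_param()`. Below is a minimal, self-contained sketch of that sequence, assuming a local Postgres with pgvector 0.5+ available; the connection parameters, table name, index name, dimension, operator class, and HNSW parameters are hypothetical examples, not values taken from the package.

```python
# Sketch only: the DDL path of the 0.0.8 PgVector client, reduced to a script.
import psycopg2

conn = psycopg2.connect(host="localhost", dbname="vectordb", user="postgres", password="postgres")
conn.autocommit = False
cursor = conn.cursor()

dim = 768
table = "pg_vector_collection"      # stands in for self.table_name
index = "pg_vector_collection_idx"  # stands in for self._index_name

cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
cursor.execute(
    f'CREATE TABLE IF NOT EXISTS public."{table}" (id BIGINT PRIMARY KEY, embedding vector({dim}));'
)
# Keep vectors stored in-line, matching the ALTER COLUMN ... SET STORAGE PLAIN statement above.
cursor.execute(f'ALTER TABLE public."{table}" ALTER COLUMN embedding SET STORAGE PLAIN;')
# HNSW branch of _create_index(); the IVFFlat branch uses "USING ivfflat ... WITH (lists=...)".
cursor.execute(
    f'CREATE INDEX IF NOT EXISTS {index} ON public."{table}" '
    f"USING hnsw (embedding vector_cosine_ops) WITH (m=16, ef_construction=64);"
)
conn.commit()
cursor.close()
conn.close()
```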
@@ -129,10 +137,24 @@ class PgVector(VectorDB):
         metadata: list[int],
         **kwargs: Any,
     ) -> (int, Exception):
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
         try:
-            items =
-
-
+            items = {
+                "id": metadata,
+                "embedding": embeddings
+            }
+            df = pd.DataFrame(items)
+            csv_buffer = io.StringIO()
+            df.to_csv(csv_buffer, index=False, header=False)
+            csv_buffer.seek(0)
+            self.cursor.copy_expert(f"COPY public.\"{self.table_name}\" FROM STDIN WITH (FORMAT CSV)", csv_buffer)
+            self.conn.commit()
+
+            if kwargs.get("last_batch"):
+                self._post_insert()
+
             return len(metadata), None
         except Exception as e:
             log.warning(f"Failed to insert data into pgvector table ({self.table_name}), error: {e}")
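The new `insert_embeddings()` stages each batch as an in-memory CSV and streams it through `COPY` instead of issuing per-row `INSERT`s, which is considerably cheaper for bulk loads. A standalone sketch of the same pattern; the connection details, table name, and data below are hypothetical examples.

```python
# Sketch only: COPY-based bulk load as used by insert_embeddings() above.
import io

import pandas as pd
import psycopg2

conn = psycopg2.connect(host="localhost", dbname="vectordb", user="postgres", password="postgres")
cursor = conn.cursor()

ids = [1, 2, 3]
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]

# pandas writes the list column as its Python repr (e.g. "[0.1, 0.2, 0.3]"),
# which pgvector's COPY input accepts as a vector literal.
df = pd.DataFrame({"id": ids, "embedding": embeddings})
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False, header=False)
csv_buffer.seek(0)

# Stream the whole batch through COPY instead of one INSERT per row.
cursor.copy_expert('COPY public."example_table" FROM STDIN WITH (FORMAT CSV)', csv_buffer)
conn.commit()
cursor.close()
conn.close()
```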
@@ -145,15 +167,21 @@ class PgVector(VectorDB):
         filters: dict | None = None,
         timeout: int | None = None,
     ) -> list[int]:
-        assert self.
+        assert self.conn is not None, "Connection is not initialized"
+        assert self.cursor is not None, "Cursor is not initialized"
+
         search_param =self.case_config.search_param()
-
-
-
-
-
-
-
-
-
+
+        if self.case_config.index == IndexType.HNSW:
+            self.cursor.execute(f'SET hnsw.ef_search = {search_param["ef"]}')
+            self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
+        elif self.case_config.index == IndexType.IVFFlat:
+            self.cursor.execute(f'SET ivfflat.probes = {search_param["probes"]}')
+            self.cursor.execute(f"SELECT id FROM public.\"{self.table_name}\" ORDER BY embedding {search_param['metric_fun_op']} '{query}' LIMIT {k};")
+        else:
+            assert "Invalid index type {self.case_config.index}"
+        self.conn.commit()
+        result = self.cursor.fetchall()
+
+        return [int(i[0]) for i in result]

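The reworked `search_embedding()` sets the index-specific session knob (`hnsw.ef_search` or `ivfflat.probes`) and then runs a plain `ORDER BY <distance op> ... LIMIT k` scan. A minimal sketch of that query path; the connection details, table name, query vector, and `ef_search` value are hypothetical, and the cosine-distance operator `<=>` stands in for whatever `metric_fun_op` resolves to.

```python
# Sketch only: k-NN query with a per-session HNSW search parameter.
import psycopg2

conn = psycopg2.connect(host="localhost", dbname="vectordb", user="postgres", password="postgres")
cursor = conn.cursor()

k = 10
query_vector = [0.1, 0.2, 0.3]

cursor.execute("SET hnsw.ef_search = 128")  # the IVFFlat branch sets ivfflat.probes instead
cursor.execute(
    f'SELECT id FROM public."example_table" '
    f"ORDER BY embedding <=> '{query_vector}' LIMIT {k};"
)
ids = [int(row[0]) for row in cursor.fetchall()]
print(ids)
```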
vectordb_bench/backend/clients/qdrant_cloud/config.py

@@ -1,18 +1,31 @@
 from pydantic import BaseModel, SecretStr

 from ..api import DBConfig, DBCaseConfig, MetricType
+from pydantic import validator

-
+# Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
 class QdrantConfig(DBConfig):
     url: SecretStr
     api_key: SecretStr

     def to_dict(self) -> dict:
-
-
-
-
-
+        api_key = self.api_key.get_secret_value()
+        if len(api_key) > 0:
+            return {
+                "url": self.url.get_secret_value(),
+                "api_key": self.api_key.get_secret_value(),
+                "prefer_grpc": True,
+            }
+        else:
+            return {"url": self.url.get_secret_value(),}
+
+    @validator("*")
+    def not_empty_field(cls, v, field):
+        if field.name in ["api_key", "db_label"]:
+            return v
+        if isinstance(v, (str, SecretStr)) and len(v) == 0:
+            raise ValueError("Empty string!")
+        return v

 class QdrantIndexConfig(BaseModel, DBCaseConfig):
     metric_type: MetricType | None = None
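The practical effect of the new `to_dict()` is that an empty `api_key` yields a URL-only client config, which is what a self-hosted open-source Qdrant expects, while a non-empty key keeps the Qdrant Cloud settings including `prefer_grpc`. A small usage sketch with made-up values; only the `url` and `api_key` fields shown in this diff are passed, and `db_label` is assumed to default in the shared `DBConfig` base.

```python
# Sketch only: how the two branches of QdrantConfig.to_dict() behave.
from pydantic import SecretStr

from vectordb_bench.backend.clients.qdrant_cloud.config import QdrantConfig

cloud = QdrantConfig(url=SecretStr("https://example.cloud.qdrant.io"), api_key=SecretStr("my-api-key"))
print(cloud.to_dict())
# {'url': 'https://example.cloud.qdrant.io', 'api_key': 'my-api-key', 'prefer_grpc': True}

local = QdrantConfig(url=SecretStr("http://localhost:6333"), api_key=SecretStr(""))
print(local.to_dict())
# {'url': 'http://localhost:6333'}
```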
vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py

@@ -43,8 +43,7 @@ class QdrantCloud(VectorDB):
         if drop_old:
             log.info(f"QdrantCloud client drop_old collection: {self.collection_name}")
             tmp_client.delete_collection(self.collection_name)
-
-            self._create_collection(dim, tmp_client)
+        self._create_collection(dim, tmp_client)
         tmp_client = None

         @contextmanager

@@ -110,13 +109,18 @@ class QdrantCloud(VectorDB):
     ) -> (int, Exception):
         """Insert embeddings into Milvus. should call self.init() first"""
         assert self.qdrant_client is not None
+        QDRANT_BATCH_SIZE = 500
         try:
             # TODO: counts
-
-
-
-
-
+            for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
+                vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE]
+                ids = metadata[offset: offset + QDRANT_BATCH_SIZE]
+                payloads=[{self._primary_field: v} for v in ids]
+                _ = self.qdrant_client.upsert(
+                    collection_name=self.collection_name,
+                    wait=True,
+                    points=Batch(ids=ids, payloads=payloads, vectors=vectors),
+                )
         except Exception as e:
             log.info(f"Failed to insert data, {e}")
             return 0, e
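The insert path now chunks the upload into fixed batches of 500 points and sends each chunk as a `Batch` through `upsert(wait=True)`, rather than pushing everything in one call. A standalone sketch of that pattern with the qdrant-client library; the URL, collection name, payload key, batch size, and sample data are hypothetical stand-ins.

```python
# Sketch only: batched upsert into a Qdrant collection.
from qdrant_client import QdrantClient
from qdrant_client.http.models import Batch

QDRANT_BATCH_SIZE = 500


def upsert_in_batches(client: QdrantClient, collection: str,
                      ids: list[int], vectors: list[list[float]]) -> None:
    for offset in range(0, len(vectors), QDRANT_BATCH_SIZE):
        batch_ids = ids[offset: offset + QDRANT_BATCH_SIZE]
        batch_vectors = vectors[offset: offset + QDRANT_BATCH_SIZE]
        payloads = [{"id": i} for i in batch_ids]
        client.upsert(
            collection_name=collection,
            wait=True,  # block until the batch is persisted, as in the diff
            points=Batch(ids=batch_ids, payloads=payloads, vectors=batch_vectors),
        )


client = QdrantClient(url="http://localhost:6333")
upsert_in_batches(client, "example_collection",
                  ids=[1, 2, 3],
                  vectors=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
```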
vectordb_bench/backend/clients/zilliz_cloud/config.py

@@ -19,6 +19,7 @@ class ZillizCloudConfig(DBConfig):

 class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
     index: IndexType = IndexType.AUTOINDEX
+    level: int = 1

     def index_param(self) -> dict:
         return {

@@ -30,6 +31,9 @@ class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig):
     def search_param(self) -> dict:
         return {
             "metric_type": self.parse_metric(),
+            "params": {
+                "level": self.level,
+            }
         }

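The new `level` field defaults to 1 and is now surfaced under `params` in `search_param()`, so Zilliz Cloud searches carry an explicit accuracy level. A small illustrative sketch; the import paths follow this package's layout, and the metric and level values are arbitrary examples.

```python
# Sketch only: the `level` knob now shows up in the search parameters.
from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.backend.clients.zilliz_cloud.config import AutoIndexConfig

config = AutoIndexConfig(metric_type=MetricType.L2, level=3)
print(config.search_param())
# {'metric_type': <string from parse_metric()>, 'params': {'level': 3}}
```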
vectordb_bench/backend/data_source.py

@@ -3,7 +3,6 @@ import pathlib
 import typing
 from enum import Enum
 from tqdm import tqdm
-from hashlib import md5
 import os
 from abc import ABC, abstractmethod

@@ -32,14 +31,13 @@ class DatasetReader(ABC):
     remote_root: str

     @abstractmethod
-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         """read dataset files from remote_root to local_ds_root,

         Args:
             dataset(str): for instance "sift_small_500k"
             files(list[str]): all filenames of the dataset
             local_ds_root(pathlib.Path): whether to write the remote data.
-            check_etag(bool): whether to check the etag
         """
         pass

@@ -56,7 +54,7 @@ class AliyunOSSReader(DatasetReader):
         import oss2
         self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)

-    def validate_file(self, remote: pathlib.Path, local: pathlib.Path
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         info = self.bucket.get_object_meta(remote.as_posix())

         # check size equal

@@ -65,26 +63,21 @@ class AliyunOSSReader(DatasetReader):
             log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
             return False

-        # check etag equal
-        if check_etag:
-            return match_etag(info.etag.strip('"').lower(), local)
-
-
         return True

-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         downloads = []
         if not local_ds_root.exists():
             log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
             local_ds_root.mkdir(parents=True)
-            downloads = [(pathlib.
+            downloads = [(pathlib.PurePosixPath("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]

         else:
             for file in files:
-                remote_file = pathlib.
+                remote_file = pathlib.PurePosixPath("benchmark", dataset, file)
                 local_file = local_ds_root.joinpath(file)

-                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file
+                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
                     log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                     downloads.append((remote_file, local_file))

@@ -93,8 +86,8 @@ class AliyunOSSReader(DatasetReader):

         log.info(f"Start to downloading files, total count: {len(downloads)}")
         for remote_file, local_file in tqdm(downloads):
-            log.debug(f"downloading file {remote_file} to {
-            self.bucket.get_object_to_file(remote_file.as_posix(), local_file.
+            log.debug(f"downloading file {remote_file} to {local_file}")
+            self.bucket.get_object_to_file(remote_file.as_posix(), local_file.absolute())

         log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")

@@ -120,19 +113,19 @@ class AwsS3Reader(DatasetReader):
         return names


-    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path
+    def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
         downloads = []
         if not local_ds_root.exists():
             log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
             local_ds_root.mkdir(parents=True)
-            downloads = [pathlib.
+            downloads = [pathlib.PurePosixPath(self.remote_root, dataset, f) for f in files]

         else:
             for file in files:
-                remote_file = pathlib.
+                remote_file = pathlib.PurePosixPath(self.remote_root, dataset, file)
                 local_file = local_ds_root.joinpath(file)

-                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file
+                if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
                     log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
                     downloads.append(remote_file)

@@ -147,7 +140,7 @@ class AwsS3Reader(DatasetReader):
         log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")


-    def validate_file(self, remote: pathlib.Path, local: pathlib.Path
+    def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
         # info() uses ls() inside, maybe we only need to ls once
         info = self.fs.info(remote)

@@ -157,48 +150,4 @@ class AwsS3Reader(DatasetReader):
             log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
             return False

-        # check etag equal
-        if check_etag:
-            return match_etag(info.get('ETag', "").strip('"'), local)
-
         return True
-
-
-def match_etag(expected_etag: str, local_file) -> bool:
-    """Check if local files' etag match with S3"""
-    def factor_of_1MB(filesize, num_parts):
-        x = filesize / int(num_parts)
-        y = x % 1048576
-        return int(x + 1048576 - y)
-
-    def calc_etag(inputfile, partsize):
-        md5_digests = []
-        with open(inputfile, 'rb') as f:
-            for chunk in iter(lambda: f.read(partsize), b''):
-                md5_digests.append(md5(chunk).digest())
-        return md5(b''.join(md5_digests)).hexdigest() + '-' + str(len(md5_digests))
-
-    def possible_partsizes(filesize, num_parts):
-        return lambda partsize: partsize < filesize and (float(filesize) / float(partsize)) <= num_parts
-
-    filesize = os.path.getsize(local_file)
-    le = ""
-    if '-' not in expected_etag: # no spliting uploading
-        with open(local_file, 'rb') as f:
-            le = md5(f.read()).hexdigest()
-        log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
-        return expected_etag == le
-    else:
-        num_parts = int(expected_etag.split('-')[-1])
-        partsizes = [ ## Default Partsizes Map
-            8388608, # aws_cli/boto3
-            15728640, # s3cmd
-            factor_of_1MB(filesize, num_parts) # Used by many clients to upload large files
-        ]
-
-        for partsize in filter(possible_partsizes(filesize, num_parts), partsizes):
-            le = calc_etag(local_file, partsize)
-            log.debug(f"calculated local etag {le}, expected etag: {expected_etag}")
-            if expected_etag == le:
-                return True
-        return False