Perception 0.7.4__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- perception/__init__.py +13 -0
- perception/benchmarking/__init__.py +23 -0
- perception/benchmarking/common.py +649 -0
- perception/benchmarking/extensions.c +31307 -0
- perception/benchmarking/extensions.cp312-win_amd64.pyd +0 -0
- perception/benchmarking/extensions.pyx +112 -0
- perception/benchmarking/image.py +202 -0
- perception/benchmarking/image_transforms.py +42 -0
- perception/benchmarking/video.py +224 -0
- perception/benchmarking/video_transforms.py +200 -0
- perception/experimental/__init__.py +0 -0
- perception/experimental/ann/__init__.py +0 -0
- perception/experimental/ann/index.py +430 -0
- perception/experimental/ann/serve.py +152 -0
- perception/experimental/approximate_deduplication.py +301 -0
- perception/experimental/debug.py +240 -0
- perception/experimental/local_descriptor_deduplication.py +710 -0
- perception/extensions.cp312-win_amd64.pyd +0 -0
- perception/extensions.cpp +33751 -0
- perception/extensions.pyx +305 -0
- perception/hashers/__init__.py +27 -0
- perception/hashers/hasher.py +406 -0
- perception/hashers/image/__init__.py +17 -0
- perception/hashers/image/average.py +35 -0
- perception/hashers/image/dhash.py +30 -0
- perception/hashers/image/opencv.py +63 -0
- perception/hashers/image/pdq.py +34 -0
- perception/hashers/image/phash.py +109 -0
- perception/hashers/image/wavelet.py +59 -0
- perception/hashers/tools.py +1075 -0
- perception/hashers/video/__init__.py +5 -0
- perception/hashers/video/framewise.py +106 -0
- perception/hashers/video/scenes.py +241 -0
- perception/hashers/video/tmk.py +215 -0
- perception/py.typed +0 -0
- perception/testing/__init__.py +243 -0
- perception/testing/images/README.md +13 -0
- perception/testing/images/image1.jpg +0 -0
- perception/testing/images/image10.jpg +0 -0
- perception/testing/images/image2.jpg +0 -0
- perception/testing/images/image3.jpg +0 -0
- perception/testing/images/image4.jpg +0 -0
- perception/testing/images/image5.jpg +0 -0
- perception/testing/images/image6.jpg +0 -0
- perception/testing/images/image7.jpg +0 -0
- perception/testing/images/image8.jpg +0 -0
- perception/testing/images/image9.jpg +0 -0
- perception/testing/logos/README.md +4 -0
- perception/testing/logos/logoipsum.png +0 -0
- perception/testing/videos/README.md +6 -0
- perception/testing/videos/expected_tmk.json.gz +0 -0
- perception/testing/videos/rgb.m4v +0 -0
- perception/testing/videos/v1.m4v +0 -0
- perception/testing/videos/v2.m4v +0 -0
- perception/testing/videos/v2s.mov +0 -0
- perception/tools.py +387 -0
- perception/utils.py +2 -0
- perception-0.7.4.dist-info/DELVEWHEEL +1 -0
- perception-0.7.4.dist-info/LICENSE +191 -0
- perception-0.7.4.dist-info/METADATA +112 -0
- perception-0.7.4.dist-info/RECORD +63 -0
- perception-0.7.4.dist-info/WHEEL +4 -0
- perception.libs/msvcp140-370c82302f0983347afe7f970ea2ece2.dll +0 -0
perception/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""""" # start delvewheel patch
|
|
2
|
+
def _delvewheel_patch_1_8_1():
|
|
3
|
+
import os
|
|
4
|
+
libs_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'perception.libs'))
|
|
5
|
+
if os.path.isdir(libs_dir):
|
|
6
|
+
os.add_dll_directory(libs_dir)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Apply the DLL-directory patch once at import time, then remove the helper
# so it does not remain in the package namespace.
_delvewheel_patch_1_8_1()
del _delvewheel_patch_1_8_1
# end delvewheel patch

# Package version string.
__version__ = "0.7.4"
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from perception.benchmarking import video_transforms
|
|
2
|
+
from perception.benchmarking import video
|
|
3
|
+
from perception.benchmarking import image
|
|
4
|
+
from perception.benchmarking.image import (
|
|
5
|
+
BenchmarkImageDataset,
|
|
6
|
+
BenchmarkImageTransforms,
|
|
7
|
+
)
|
|
8
|
+
from perception.benchmarking.video import (
|
|
9
|
+
BenchmarkVideoDataset,
|
|
10
|
+
BenchmarkVideoTransforms,
|
|
11
|
+
)
|
|
12
|
+
from perception.benchmarking.common import BenchmarkHashes
|
|
13
|
+
|
|
14
|
+
# Public API of the benchmarking package: the dataset / transform / hash
# containers followed by the submodules imported above.
__all__ = [
    "BenchmarkImageDataset",
    "BenchmarkImageTransforms",
    "BenchmarkVideoDataset",
    "BenchmarkVideoTransforms",
    "BenchmarkHashes",
    "video_transforms",
    "video",
    "image",
]
|
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import shutil
|
|
5
|
+
import tempfile
|
|
6
|
+
import typing
|
|
7
|
+
import uuid
|
|
8
|
+
import warnings
|
|
9
|
+
import zipfile
|
|
10
|
+
from abc import ABC
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import tqdm
|
|
17
|
+
from scipy import spatial, stats
|
|
18
|
+
|
|
19
|
+
from ..hashers.tools import compute_md5, string_to_vector
|
|
20
|
+
|
|
21
|
+
# The compiled extension module is optional. When it cannot be imported a
# warning is issued and `extensions` is set to None; code in this module
# checks `extensions is None` before using it.
try:
    from . import extensions  # type: ignore
except ImportError:
    warnings.warn(
        "C extensions were not built. Some metrics will be computed more slowly. "
        "Please install from wheels or set up a compiler prior to installation "
        "from source to use extensions."
    )
    extensions = None

# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def create_mask(transformed_guids, noop_guids):
    """Build a boolean match matrix between transformed and noop guids.

    Entry ``[m, n]`` of the result is True exactly when transformed item
    ``m`` and noop item ``n`` have the same guid. The mask is applied to a
    distance matrix to separate correct from incorrect pairs when computing
    recall at different thresholds.

    The guids are matched through a lookup table, so neither sequence has
    to be sorted and the two sequences need not contain the same set of
    guids. A transformed guid with no noop counterpart yields an all-False
    row.

    Args:
        transformed_guids: A sequence of transformed guids (one per row)
        noop_guids: A sequence of noop guids (one per column)

    Returns:
        A boolean array of shape
        `(len(transformed_guids), len(noop_guids))`
    """
    mask = np.zeros((len(transformed_guids), len(noop_guids)), dtype="bool")

    # Map every noop guid to the column positions at which it occurs.
    columns_by_guid = {}
    for column, guid in enumerate(noop_guids):
        columns_by_guid.setdefault(guid, []).append(column)

    # Each `row` is a view into `mask`, so assigning to it fills the mask.
    for row, guid in zip(mask, transformed_guids):
        columns = columns_by_guid.get(guid)
        if columns is not None:
            row[columns] = True
    return mask
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def compute_threshold_precision_recall(pos, neg, precision_threshold=99.9):
    """Find the largest distance threshold that keeps precision acceptable.

    Every value in ``pos`` is tried as a candidate threshold, in ascending
    order. For a candidate ``t`` the true positives are the positives up to
    and including that candidate, and the false positives are the entries
    of ``neg`` that are ``<= t``. The chosen threshold is the last candidate
    before precision first drops below ``precision_threshold``.

    Args:
        pos: Array of distances to the closest correct match.
        neg: Array of distances to the closest incorrect match. Callers in
            this module pass one entry per entry of ``pos``. NaN entries
            never count as false positives.
        precision_threshold: Minimum acceptable precision, in percent.

    Returns:
        A tuple ``(optimal_threshold, optimal_precision, optimal_recall)``
        with precision and recall expressed in percent. When even the
        smallest positive distance violates the precision requirement the
        result is ``(pos.min(), nan, 0)``.
    """
    pos = np.sort(pos)
    # The order of `neg` is irrelevant for counting; sorting it lets the
    # false-positive counts be read off with a binary search instead of a
    # Python-level loop over every candidate threshold.
    sorted_neg = np.sort(neg)

    tp = np.arange(1, len(pos) + 1)
    # side="right" counts the entries of neg that are <= each threshold.
    fp = np.searchsorted(sorted_neg, pos, side="right")
    precision = 100 * tp / (tp + fp)

    bad_threshold_idxs = np.where(precision < precision_threshold)[0]
    if len(bad_threshold_idxs) == 0:
        # Every candidate satisfies the precision requirement.
        # NOTE(review): precision is reported as 100 here rather than the
        # measured minimum, which can be as low as precision_threshold.
        return pos.max(), 100, 100

    first_bad = bad_threshold_idxs[0]
    if first_bad == 0:
        # The closest hash was a false positive.
        return pos[0], np.nan, 0

    optimal_threshold = pos[first_bad - 1]
    # At least one positive (index first_bad - 1) always satisfies this.
    accepted = pos <= optimal_threshold
    optimal_precision = precision[accepted].min()
    optimal_recall = round(100 * accepted.sum() / len(pos), 3)
    return optimal_threshold, optimal_precision, optimal_recall
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Filterable(ABC):
    """Base class for datasets that wrap a ``pandas.DataFrame``.

    Subclasses declare ``expected_columns``; the constructor asserts that
    the wrapped dataframe has exactly those columns (order is ignored).
    """

    # The wrapped dataframe.
    _df: pd.DataFrame
    # Column names the wrapped dataframe must have.
    expected_columns: typing.List

    def __init__(self, df):
        """Wrap ``df`` after checking its columns against ``expected_columns``."""
        assert sorted(self.expected_columns) == sorted(df.columns), (
            f"Column mismatch: Expected {sorted(self.expected_columns)}, "
            f"found {sorted(df.columns)}."
        )
        self._df = df

    @property
    def categories(self):
        """The categories included in the dataset"""
        category_column = self._df["category"]
        return category_column.unique()

    def filter(self, **kwargs):
        """Return a new dataset restricted by the given keyword arguments.

        Each keyword names a column and supplies the collection of values
        to keep in that column. A ``UserWarning`` is issued for requested
        values that do not occur in the column at all.
        """
        filtered = self._df.copy()
        for column, wanted in kwargs.items():
            present = self._df[column].unique()
            absent = [str(value) for value in wanted if value not in present]
            if absent:
                missing = ", ".join(absent)
                message = f"Did not find {missing} in column {column} dataset."
                warnings.warn(message, UserWarning)
            keep = filtered[column].isin(wanted)
            filtered = filtered[keep]
        return self.__class__(filtered.copy())
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Saveable(Filterable):
    """A :class:`Filterable` that can be written to, and read back from,
    a directory or a ZIP file.

    The stored form is an ``index.csv`` (the dataframe with ``filepath``
    replaced by ``filename`` plus an ``md5`` column) alongside the files it
    references. Rows whose ``filepath`` is null are kept in the index but
    have no file associated with them.
    """

    @staticmethod
    def _md5s_match(index, desc):
        """Return True when every indexed file matches its recorded md5.

        Rows with a null ``filepath`` have no file to hash and are skipped.

        Args:
            index: Dataframe with ``filepath`` and ``md5`` columns.
            desc: Label for the progress bar.
        """
        rows_with_files = index[index["filepath"].notnull()]
        return all(
            row["md5"] == compute_md5(row["filepath"])
            for _, row in tqdm.tqdm(
                rows_with_files.iterrows(),
                desc=desc,
                total=len(rows_with_files.index),
            )
        )

    @classmethod
    def load(
        cls,
        path_to_zip_or_directory: str,
        storage_dir: Optional[str] = None,
        verify_md5=True,
    ):
        """Load a dataset from a ZIP file or directory.

        Args:
            path_to_zip_or_directory: Path to a ``.zip`` file or to a
                directory containing ``index.csv`` and the indexed files.
            storage_dir: If providing a ZIP file, where to extract
                the contents. If None, contents will be extracted to
                a folder with the same name as the ZIP file in the
                same directory as the ZIP file.
            verify_md5: Verify md5s when loading

        Returns:
            An instance of ``cls`` built from the index with the
            ``filename`` and ``md5`` columns removed.

        Raises:
            AssertionError: If ``storage_dir`` is given for a directory
                path, or if an md5 check fails.
        """

        # Load index whether from inside ZIP file or from directory.
        if os.path.splitext(path_to_zip_or_directory)[1] == ".zip":
            if storage_dir is None:
                storage_dir = os.path.join(
                    os.path.dirname(os.path.abspath(path_to_zip_or_directory)),
                    os.path.splitext(os.path.basename(path_to_zip_or_directory))[0],
                )
            os.makedirs(storage_dir, exist_ok=True)
            with zipfile.ZipFile(path_to_zip_or_directory, "r") as z:
                # Try extracting only the index at first so we can
                # compare md5.
                z.extract("index.csv", os.path.join(storage_dir))
                index: pd.DataFrame = pd.read_csv(
                    os.path.join(storage_dir, "index.csv")
                )
                index["filepath"] = index["filename"].apply(
                    lambda fn: (
                        os.path.join(storage_dir, fn) if not pd.isnull(fn) else None
                    )
                )
                do_zip_extraction = True
                # Only rows that reference a file can be present on disk;
                # os.path.isfile(None) would raise a TypeError.
                if all(os.path.isfile(fp) for fp in index["filepath"].dropna()):
                    if verify_md5:
                        do_zip_extraction = not cls._md5s_match(
                            index, desc="Checking cache"
                        )
                    else:
                        do_zip_extraction = False
                if do_zip_extraction:
                    z.extractall(storage_dir)
                else:
                    log.info("Found all files already extracted. Skipping extraction.")
                    # The cache was either just verified or verification
                    # was not requested, so skip the final check.
                    verify_md5 = False
        else:
            assert (
                storage_dir is None
            ), "Storage directory only valid if path is to ZIP file."
            index = pd.read_csv(os.path.join(path_to_zip_or_directory, "index.csv"))
            index["filepath"] = index["filename"].apply(
                lambda fn: (
                    os.path.join(path_to_zip_or_directory, fn)
                    if not pd.isnull(fn)
                    else None
                )
            )

        if verify_md5:
            assert cls._md5s_match(
                index, desc="Performing final md5 integrity check."
            ), "An md5 mismatch has occurred."
        return cls(index.drop(["filename", "md5"], axis=1))

    def save(self, path_to_zip_or_directory):
        """Save a dataset to a directory or ZIP file.

        Files are stored under their base name. If two rows share a base
        name, every file is renamed to a random UUID (keeping its
        extension) and a ``UserWarning`` is issued. Rows with a null
        ``filepath`` are written to the index only.

        Args:
            path_to_zip_or_directory: Target path. A path ending in
                ``.zip`` produces a ZIP file; anything else is treated as
                a directory, which is created if needed.
        """
        df = self._df
        assert "filepath" in df.columns, "Index dataframe must contain filepath."

        # Build index using filename instead of filepath.
        index = df.copy()
        index["filename"] = df["filepath"].apply(
            lambda filepath: (
                os.path.basename(filepath) if not pd.isnull(filepath) else None
            )
        )
        if index["filename"].dropna().duplicated().sum() > 0:
            warnings.warn("Changing filenames to UUID due to duplicates.", UserWarning)

            index["filename"] = [
                (
                    str(uuid.uuid4()) + os.path.splitext(row["filename"])[1]
                    if not pd.isnull(row["filename"])
                    else None
                )
                for _, row in index.iterrows()
            ]
        index["md5"] = [
            compute_md5(filepath) if not pd.isnull(filepath) else None
            for filepath in tqdm.tqdm(index["filepath"], desc="Computing md5s.")
        ]

        # Add all files as well as the dataframe index to
        # a ZIP file if path is to ZIP file or to the directory if it is
        # not a ZIP file.
        if os.path.splitext(path_to_zip_or_directory)[1] == ".zip":
            with zipfile.ZipFile(path_to_zip_or_directory, "w") as f:
                with tempfile.TemporaryFile(mode="w+") as index_file:
                    index.drop("filepath", axis=1).to_csv(index_file, index=False)
                    index_file.seek(0)
                    f.writestr("index.csv", index_file.read())
                for _, row in tqdm.tqdm(
                    index.iterrows(), desc="Saving files", total=len(df)
                ):
                    if pd.isnull(row["filepath"]):
                        # There was an error associated with this file.
                        continue
                    f.write(row["filepath"], row["filename"])
        else:
            os.makedirs(path_to_zip_or_directory, exist_ok=True)
            index.drop("filepath", axis=1).to_csv(
                os.path.join(path_to_zip_or_directory, "index.csv"), index=False
            )
            for _, row in tqdm.tqdm(
                index.iterrows(), desc="Saving files", total=len(df)
            ):
                if pd.isnull(row["filepath"]):
                    # There was an error associated with this file.
                    continue
                target = os.path.join(path_to_zip_or_directory, row["filename"])
                # The source file is the same as the target file. Absolute
                # paths are compared as well so that different spellings of
                # one location (e.g. a leading "./") are recognised before
                # shutil.copy is asked to copy a file onto itself.
                if row["filepath"] == target or os.path.abspath(
                    row["filepath"]
                ) == os.path.abspath(target):
                    continue
                shutil.copy(row["filepath"], target)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class BenchmarkHashes(Filterable):
    """A dataset of hashes for transformed images. It is essentially
    a wrapper around a `pandas.DataFrame` with the following columns:

    - guid
    - error
    - filepath
    - input_filepath
    - category
    - transform_name
    - hasher_name
    - hasher_dtype
    - hasher_distance_metric
    - hasher_hash_length
    - hash
    """

    expected_columns = [
        "error",
        "filepath",
        "hash",
        "hasher_name",
        "hasher_dtype",
        "hasher_distance_metric",
        "category",
        "guid",
        "input_filepath",
        "transform_name",
        "hasher_hash_length",
    ]

    def __init__(self, df: pd.DataFrame):
        super().__init__(df)
        # Filled in by the first call to compute_metrics().
        self._metrics: Optional[pd.DataFrame] = None

    def __add__(self, other):
        """Combine two hash sets, dropping duplicate rows."""
        return BenchmarkHashes(df=pd.concat([self._df, other._df]).drop_duplicates())

    def __radd__(self, other):
        """Support ``sum([...])`` over hash sets.

        ``sum`` starts from the integer 0, which has no dataframe to
        concatenate; in that case this instance is the running total.
        """
        if isinstance(other, int) and other == 0:
            return self
        return self.__add__(other)

    @classmethod
    def load(cls, filepath: str):
        """Read a hash set from the CSV file at ``filepath``."""
        return cls(pd.read_csv(filepath))

    def save(self, filepath):
        """Write the hash set to ``filepath`` as CSV, without the index."""
        self._df.to_csv(filepath, index=False)

    @staticmethod
    def _summarize_distances(distance_matrix, mask, noop_guids):
        """Reduce a (transformed x noop) distance matrix to per-row summaries.

        Args:
            distance_matrix: Distances with one row per transformed hash
                and one column per noop hash.
            mask: Boolean array of the same shape, True where row and
                column share a guid.
            noop_guids: Array of guids, one per column.

        Returns:
            A tuple ``(to_correct, to_incorrect, to_closest,
            closest_incorrect_guid)`` holding, per row, the minimum distance
            over same-guid columns, the minimum over other-guid columns,
            the overall minimum, and the guid of the closest other-guid
            column.
        """
        to_correct = np.ma.masked_array(distance_matrix, np.logical_not(mask)).min(
            axis=1
        )
        incorrect_matrix = np.ma.masked_array(distance_matrix, mask)
        to_incorrect = incorrect_matrix.min(axis=1)
        to_closest = distance_matrix.min(axis=1)
        closest_incorrect_guid = noop_guids[incorrect_matrix.argmin(axis=1)]
        return to_correct, to_incorrect, to_closest, closest_incorrect_guid

    @staticmethod
    def _pos_neg_distances(subset):
        """Return per (guid, transform_name) minimum distances as two arrays.

        The first array holds distances to the closest correct image, the
        second distances to the closest incorrect image.
        """
        pos, neg = (
            subset.groupby(["guid", "transform_name"])[
                [
                    "distance_to_closest_correct_image",
                    "distance_to_closest_incorrect_image",
                ]
            ]
            .min()
            .values.T
        )
        return pos, neg

    def compute_metrics(
        self, custom_distance_metrics: Optional[dict] = None
    ) -> pd.DataFrame:
        """Compute distances between every transformed hash and the noop hashes.

        The result is cached on the instance: once computed, later calls
        return the cached frame regardless of the arguments passed.

        Args:
            custom_distance_metrics: Mapping from hasher name to a function
                taking two hash values and returning their distance.
                Required for hashers whose distance metric is ``"custom"``.

        Returns:
            A dataframe with one row per transformed hash and the columns
            guid, transform_name, hasher_name, category,
            distance_to_closest_correct_image,
            distance_to_closest_incorrect_image, distance_to_closest_image
            and closest_incorrect_guid.
        """
        if self._metrics is not None:
            return self._metrics
        metrics = []
        hashsets = self._df.sort_values("guid")
        n_dropped = hashsets["hash"].isnull().sum()
        if n_dropped > 0:
            hashsets = hashsets.dropna(subset=["hash"])
            warnings.warn(f"Dropping {n_dropped} invalid / empty hashes.", UserWarning)
        for (hasher_name, transform_name, category), hashset in tqdm.tqdm(
            hashsets.groupby(["hasher_name", "transform_name", "category"]),
            desc="Computing metrics.",
        ):
            # Note the guid filtering below. We need to include only guids
            # for which we have the transform *and* the guid. One of them
            # may have been dropped due to being invalid.
            noops = hashsets[
                (hashsets["transform_name"] == "noop")
                & (hashsets["hasher_name"] == hasher_name)
                & (hashsets["guid"].isin(hashset["guid"]))
            ]
            valid_hashset = hashset[hashset["guid"].isin(noops["guid"])]
            if valid_hashset.empty:
                # Nothing in this group has a usable noop hash to compare
                # against; iloc[0] below would raise an IndexError.
                warnings.warn(
                    f"Skipping {hasher_name} / {transform_name} / {category}: "
                    "no hashes with a matching noop hash.",
                    UserWarning,
                )
                continue
            dtype, distance_metric, hash_length = valid_hashset.iloc[0][
                ["hasher_dtype", "hasher_distance_metric", "hasher_hash_length"]
            ]
            n_noops = len(noops.guid)
            n_hashset = len(valid_hashset.guid)
            noop_guids = noops.guid.values
            mask = create_mask(valid_hashset.guid.values, noops.guid.values)
            if distance_metric != "custom":
                X_trans = np.array(
                    valid_hashset.hash.apply(
                        string_to_vector,  # type: ignore[arg-type]
                        hash_length=int(hash_length),
                        dtype=dtype,
                        hash_format="base64",
                    ).tolist()
                )
                X_noop = np.array(
                    noops.hash.apply(
                        string_to_vector,
                        dtype=dtype,
                        hash_format="base64",
                        hash_length=int(hash_length),
                    ).tolist()
                )
                if (
                    distance_metric != "euclidean"
                    or "int" not in dtype
                    or extensions is None
                ):
                    distance_matrix = spatial.distance.cdist(
                        XA=X_trans, XB=X_noop, metric=distance_metric
                    )
                    (
                        distance_to_correct_image,
                        distance_to_incorrect_image,
                        distance_to_closest_image,
                        closest_incorrect_guid,
                    ) = self._summarize_distances(distance_matrix, mask, noop_guids)
                else:
                    # Integer euclidean hashes take the compiled fast path.
                    # NOTE(review): column 0 is used as the incorrect and
                    # column 1 as the correct distance -- confirm against
                    # the extension's documentation.
                    distances, indexes = extensions.compute_euclidean_metrics(
                        X_noop.astype("int32"), X_trans.astype("int32"), mask
                    )
                    distance_to_correct_image = distances[:, 1]
                    distance_to_incorrect_image = distances[:, 0]
                    distance_to_closest_image = distances.min(axis=1)
                    closest_incorrect_guid = [noop_guids[idx] for idx in indexes[:, 0]]
            else:
                assert (
                    custom_distance_metrics is not None
                    and hasher_name in custom_distance_metrics
                ), f"You must provide a custom distance metric for {hasher_name}."
                noops_hash_values = noops.hash.values
                hashset_hash_values = valid_hashset.hash.values
                distance_matrix = np.zeros((n_hashset, n_noops))
                distance_function = custom_distance_metrics[hasher_name]
                for i1, i2 in itertools.product(range(n_hashset), range(n_noops)):
                    distance_matrix[i1, i2] = distance_function(
                        hashset_hash_values[i1], noops_hash_values[i2]
                    )
                (
                    distance_to_correct_image,
                    distance_to_incorrect_image,
                    distance_to_closest_image,
                    closest_incorrect_guid,
                ) = self._summarize_distances(distance_matrix, mask, noop_guids)

            metrics.append(
                pd.DataFrame(
                    {
                        "guid": valid_hashset["guid"].values,
                        "transform_name": transform_name,
                        "hasher_name": hasher_name,
                        "category": category,
                        "distance_to_closest_correct_image": distance_to_correct_image,
                        "distance_to_closest_incorrect_image": distance_to_incorrect_image,
                        "distance_to_closest_image": distance_to_closest_image,
                        "closest_incorrect_guid": closest_incorrect_guid,
                    }
                )
            )
        metrics_df = pd.concat(metrics)
        self._metrics = metrics_df
        return metrics_df

    def show_histograms(self, grouping=None, precision_threshold=99.9, **kwargs):
        """Plot histograms for true and false positives, similar
        to https://tech.okcupid.com/evaluating-perceptual-image-hashes-okcupid/
        Additional arguments passed to compute_metrics.

        One subplot is drawn per (group, hasher) pair: hashers form the
        columns and groups form the rows.

        Args:
            grouping: List of fields to group by. By default, all fields are used
                (category, and transform_name). An empty list plots one
                row covering everything.
            precision_threshold: The precision threshold (in percent) used
                to pick the distance threshold shown in each subplot.
        """
        if grouping is None:
            grouping = ["category", "transform_name"]

        metrics = self.compute_metrics(**kwargs)

        hasher_names = metrics["hasher_name"].unique().tolist()
        bounds = (
            metrics.groupby("hasher_name")[
                ["distance_to_closest_image", "distance_to_closest_incorrect_image"]
            ]
            .max()
            .max(axis=1)
        )
        if grouping:
            group_names = [
                ":".join(map(str, row.values))
                for _, row in metrics[grouping].drop_duplicates().iterrows()
            ]
        else:
            group_names = [""]
        ncols = len(hasher_names)
        nrows = len(group_names)

        # squeeze=False always yields a 2-D array of axes, so the same
        # [row, column] indexing works for any grid size.
        fig, axs = plt.subplots(
            ncols=ncols,
            nrows=nrows,
            figsize=(ncols * 4, nrows * 3),
            sharey=True,
            squeeze=False,
        )

        for group_key, subset in metrics.groupby(["hasher_name"] + grouping):
            # Get names of group and hasher. Depending on the pandas
            # version a single-column grouper yields either a scalar or a
            # 1-tuple key; normalise to a tuple so both are handled.
            if not isinstance(group_key, tuple):
                group_key = (group_key,)
            hasher_name = group_key[0]
            group_name = ":".join(map(str, group_key[1:]))

            # Get the correct axis.
            colIdx = hasher_names.index(hasher_name)
            rowIdx = group_names.index(group_name)
            ax = axs[rowIdx, colIdx]

            # Plot the charts
            pos, neg = self._pos_neg_distances(subset)
            optimal_threshold, _, optimal_recall = compute_threshold_precision_recall(
                pos=pos, neg=neg, precision_threshold=precision_threshold
            )
            optimal_threshold = optimal_threshold.round(3)
            emd = stats.wasserstein_distance(pos, neg).round(2)
            ax.hist(neg, label="neg", bins=10)
            ax.hist(pos, label="pos", bins=10)
            ax.text(
                0.5,
                0.5,
                f"Recall: {optimal_recall:.0f}% @ {optimal_threshold}\nemd: {emd:.2f}",
                horizontalalignment="center",
                color="black",
                verticalalignment="center",
                transform=ax.transAxes,
                fontsize=12,
                fontweight=1000,
            )
            ax.set_xlim(-0.05 * bounds[hasher_name], bounds[hasher_name])
            if rowIdx == 0:
                ax.set_title(hasher_name)
                ax.legend()
            if colIdx == 0:
                ax.set_ylabel(group_name)
        fig.tight_layout()

    def compute_threshold_recall(
        self, precision_threshold=99.9, grouping=None, **kwargs
    ) -> pd.DataFrame:
        """Compute a table for threshold and recall for each category, hasher,
        and transformation combinations. Additional arguments passed to compute_metrics.

        Args:
            precision_threshold: The precision threshold to use
                for choosing a distance threshold for each hasher.
            grouping: List of fields to group by. By default, all fields are used
                (category, and transform_name).

        Returns:
            A pandas DataFrame indexed by the grouping fields and the
            hasher name, with the columns threshold (The optimal distance
            threshold for detecting a match for this combination), recall
            (the number of correct matches divided by the number of
            possible matches), precision (the number correct matches
            divided by the total number of matches whether correct or
            incorrect) and n_exemplars (the number of metric rows in the
            group).
        """
        if grouping is None:
            grouping = ["category", "transform_name"]

        def group_func(subset):
            pos, neg = self._pos_neg_distances(subset)

            (
                optimal_threshold,
                optimal_precision,
                optimal_recall,
            ) = compute_threshold_precision_recall(
                pos=pos, neg=neg, precision_threshold=precision_threshold
            )
            return pd.Series(
                {
                    "threshold": optimal_threshold,
                    "recall": optimal_recall,
                    "precision": optimal_precision,
                    "n_exemplars": len(subset),
                }
            )

        return (
            self.compute_metrics(**kwargs)
            .groupby(grouping + ["hasher_name"])
            .apply(group_func)
        )
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
class BenchmarkDataset(Saveable):
    """A collection of files, each assigned to a category.

    Wraps a pandas dataframe with exactly these columns:

    - filepath
    - category
    """

    expected_columns = ["filepath", "category"]

    @classmethod
    def from_tuples(cls, files: typing.List[typing.Tuple[str, str]]):
        """Create a dataset from ``(filepath, category)`` pairs.

        Args:
            files: A list of tuples, each holding a filepath followed by
                the category it belongs to.
        """
        records = [
            {"filepath": filepath, "category": category}
            for filepath, category in files
        ]
        return cls(pd.DataFrame.from_records(records))

    def transform(self, transforms, storage_dir, errors):
        """Not implemented on this base class; always raises."""
        raise NotImplementedError()
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
class BenchmarkTransforms(Saveable):
    """A collection of transformed files.

    Wraps a dataframe with exactly these columns:

    - guid
    - filepath
    - category
    - transform_name
    - input_filepath (for memo purposes only)
    """

    expected_columns = [
        "filepath",
        "category",
        "transform_name",
        "input_filepath",
        "guid",
    ]

    def compute_hashes(self, hashers, max_workers):
        """Not implemented on this base class; always raises."""
        raise NotImplementedError()
|