Perception 0.7.4__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. perception/__init__.py +13 -0
  2. perception/benchmarking/__init__.py +23 -0
  3. perception/benchmarking/common.py +649 -0
  4. perception/benchmarking/extensions.c +31307 -0
  5. perception/benchmarking/extensions.cp312-win_amd64.pyd +0 -0
  6. perception/benchmarking/extensions.pyx +112 -0
  7. perception/benchmarking/image.py +202 -0
  8. perception/benchmarking/image_transforms.py +42 -0
  9. perception/benchmarking/video.py +224 -0
  10. perception/benchmarking/video_transforms.py +200 -0
  11. perception/experimental/__init__.py +0 -0
  12. perception/experimental/ann/__init__.py +0 -0
  13. perception/experimental/ann/index.py +430 -0
  14. perception/experimental/ann/serve.py +152 -0
  15. perception/experimental/approximate_deduplication.py +301 -0
  16. perception/experimental/debug.py +240 -0
  17. perception/experimental/local_descriptor_deduplication.py +710 -0
  18. perception/extensions.cp312-win_amd64.pyd +0 -0
  19. perception/extensions.cpp +33751 -0
  20. perception/extensions.pyx +305 -0
  21. perception/hashers/__init__.py +27 -0
  22. perception/hashers/hasher.py +406 -0
  23. perception/hashers/image/__init__.py +17 -0
  24. perception/hashers/image/average.py +35 -0
  25. perception/hashers/image/dhash.py +30 -0
  26. perception/hashers/image/opencv.py +63 -0
  27. perception/hashers/image/pdq.py +34 -0
  28. perception/hashers/image/phash.py +109 -0
  29. perception/hashers/image/wavelet.py +59 -0
  30. perception/hashers/tools.py +1075 -0
  31. perception/hashers/video/__init__.py +5 -0
  32. perception/hashers/video/framewise.py +106 -0
  33. perception/hashers/video/scenes.py +241 -0
  34. perception/hashers/video/tmk.py +215 -0
  35. perception/py.typed +0 -0
  36. perception/testing/__init__.py +243 -0
  37. perception/testing/images/README.md +13 -0
  38. perception/testing/images/image1.jpg +0 -0
  39. perception/testing/images/image10.jpg +0 -0
  40. perception/testing/images/image2.jpg +0 -0
  41. perception/testing/images/image3.jpg +0 -0
  42. perception/testing/images/image4.jpg +0 -0
  43. perception/testing/images/image5.jpg +0 -0
  44. perception/testing/images/image6.jpg +0 -0
  45. perception/testing/images/image7.jpg +0 -0
  46. perception/testing/images/image8.jpg +0 -0
  47. perception/testing/images/image9.jpg +0 -0
  48. perception/testing/logos/README.md +4 -0
  49. perception/testing/logos/logoipsum.png +0 -0
  50. perception/testing/videos/README.md +6 -0
  51. perception/testing/videos/expected_tmk.json.gz +0 -0
  52. perception/testing/videos/rgb.m4v +0 -0
  53. perception/testing/videos/v1.m4v +0 -0
  54. perception/testing/videos/v2.m4v +0 -0
  55. perception/testing/videos/v2s.mov +0 -0
  56. perception/tools.py +387 -0
  57. perception/utils.py +2 -0
  58. perception-0.7.4.dist-info/DELVEWHEEL +1 -0
  59. perception-0.7.4.dist-info/LICENSE +191 -0
  60. perception-0.7.4.dist-info/METADATA +112 -0
  61. perception-0.7.4.dist-info/RECORD +63 -0
  62. perception-0.7.4.dist-info/WHEEL +4 -0
  63. perception.libs/msvcp140-370c82302f0983347afe7f970ea2ece2.dll +0 -0
perception/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ """""" # start delvewheel patch
2
def _delvewheel_patch_1_8_1():
    """Register the bundled ``perception.libs`` directory as a DLL search path.

    The directory is looked up next to the package directory (one level above
    this file). Nothing happens when that directory does not exist.
    """
    # Imported locally so the module namespace is not polluted with ``os``.
    import os
    # <package dir>/../perception.libs, normalised to an absolute path.
    libs_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'perception.libs'))
    if os.path.isdir(libs_dir):
        # NOTE(review): os.add_dll_directory is only available on Windows;
        # this wheel is tagged win_amd64 so that is presumably intentional.
        os.add_dll_directory(libs_dir)


# Run the patch once at import time, then remove the helper so it is not
# exposed as a package attribute.
_delvewheel_patch_1_8_1()
del _delvewheel_patch_1_8_1
11
+ # end delvewheel patch
12
+
13
+ __version__ = "0.7.4"
@@ -0,0 +1,23 @@
1
+ from perception.benchmarking import video_transforms
2
+ from perception.benchmarking import video
3
+ from perception.benchmarking import image
4
+ from perception.benchmarking.image import (
5
+ BenchmarkImageDataset,
6
+ BenchmarkImageTransforms,
7
+ )
8
+ from perception.benchmarking.video import (
9
+ BenchmarkVideoDataset,
10
+ BenchmarkVideoTransforms,
11
+ )
12
+ from perception.benchmarking.common import BenchmarkHashes
13
+
14
+ __all__ = [
15
+ "BenchmarkImageDataset",
16
+ "BenchmarkImageTransforms",
17
+ "BenchmarkVideoDataset",
18
+ "BenchmarkVideoTransforms",
19
+ "BenchmarkHashes",
20
+ "video_transforms",
21
+ "video",
22
+ "image",
23
+ ]
@@ -0,0 +1,649 @@
1
+ import itertools
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import tempfile
6
+ import typing
7
+ import uuid
8
+ import warnings
9
+ import zipfile
10
+ from abc import ABC
11
+ from typing import Optional
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import pandas as pd
16
+ import tqdm
17
+ from scipy import spatial, stats
18
+
19
+ from ..hashers.tools import compute_md5, string_to_vector
20
+
21
+ try:
22
+ from . import extensions # type: ignore
23
+ except ImportError:
24
+ warnings.warn(
25
+ "C extensions were not built. Some metrics will be computed more slowly. "
26
+ "Please install from wheels or set up a compiler prior to installation "
27
+ "from source to use extensions."
28
+ )
29
+ extensions = None
30
+
31
+ log = logging.getLogger(__name__)
32
+
33
+
34
def create_mask(transformed_guids, noop_guids):
    """Build a boolean matrix marking which noop shares a guid with which transform.

    Both inputs are walked front to back, so equal guids are expected to be
    stored contiguously and in the same order in both sequences (callers in
    this module pass guid-sorted values). The result is used to mask a
    distance matrix when computing recall at different thresholds.

    Args:
        transformed_guids: A sized, ordered collection of transformed guids.
        noop_guids: A sized, sliceable collection of noop guids.

    Returns:
        A boolean array of shape `(len(transformed_guids), len(noop_guids))`
        where entry `[m, n]` is True when noop `n` was matched to transform `m`.
    """
    total_noops = len(noop_guids)
    mask = np.zeros((len(transformed_guids), total_noops), dtype="bool")
    last_guid = None
    lo = None
    hi = 0
    for row_index, guid in enumerate(transformed_guids):
        if last_guid is None or guid != last_guid:
            # A new guid starts where the previous run of noops stopped.
            lo = hi
            # Length of the run of noops equal to this guid; when no differing
            # noop follows, fall back to the total count (slicing clips it).
            run_length = total_noops
            for offset, candidate in enumerate(noop_guids[lo:]):
                if candidate != guid:
                    run_length = offset
                    break
            hi = lo + run_length
            last_guid = guid
        mask[row_index, lo:hi] = True
    return mask
67
+
68
+
69
def compute_threshold_precision_recall(pos, neg, precision_threshold=99.9):
    """Pick the largest distance threshold that keeps precision above a floor.

    Args:
        pos: Array of distances to the closest correct match, one per query.
        neg: Array of distances to the closest incorrect match, aligned with
            `pos`.
        precision_threshold: Minimum acceptable precision, in percent.

    Returns:
        A tuple `(threshold, precision, recall)` with precision and recall
        expressed in percent. Precision is NaN (and recall 0) when even the
        smallest positive distance fails the precision floor.
    """
    # Order both arrays by ascending positive distance.
    order = pos.argsort()
    neg = neg[order]
    pos = pos[order]

    # Treat every positive distance as a candidate threshold: the true
    # positives are the positives up to and including it, the false positives
    # are the negatives at or below it.
    true_positives = np.arange(1, len(pos) + 1)
    false_positives = np.array([(neg <= candidate).sum() for candidate in pos])
    precision = 100 * true_positives / (true_positives + false_positives)

    failing = np.where(precision < precision_threshold)[0]

    if len(failing) == 0:
        # Every candidate threshold satisfies the precision floor.
        return pos.max(), 100, 100

    if not failing[0] > 0:
        # The closest hash was a false positive.
        return pos[0], np.nan, 0

    # Use the last candidate before precision first drops below the floor.
    optimal_threshold = pos[failing[0] - 1]
    within_threshold = pos <= optimal_threshold
    recovered = within_threshold.sum()
    if recovered == 0:
        optimal_precision = np.nan
    else:
        optimal_precision = precision[within_threshold].min()
    optimal_recall = round(100 * recovered / len(pos), 3)
    return optimal_threshold, optimal_precision, optimal_recall
100
+
101
+
102
class Filterable(ABC):
    """Base class for datasets backed by a DataFrame with a fixed column set.

    Subclasses declare `expected_columns`; the constructor checks that the
    supplied frame has exactly those columns (in any order).
    """

    _df: pd.DataFrame
    expected_columns: typing.List

    def __init__(self, df):
        found = sorted(df.columns)
        wanted = sorted(self.expected_columns)
        assert (
            found == wanted
        ), f"Column mismatch: Expected {wanted}, found {found}."
        self._df = df

    @property
    def categories(self):
        """The distinct values of the `category` column."""
        return self._df["category"].unique()

    def filter(self, **kwargs):
        """Return a new dataset restricted to the requested values.

        Each keyword names a column and maps to a collection of values to
        keep. Requested values that do not occur in the column trigger a
        `UserWarning` but are otherwise ignored.
        """
        remaining = self._df.copy()
        for field, included in kwargs.items():
            # Compare against the unfiltered data so that the warning reflects
            # the dataset itself, not the effect of earlier keyword filters.
            existing = self._df[field].unique()
            absent = [str(value) for value in included if value not in existing]
            if absent:
                missing = ", ".join(absent)
                message = f"Did not find {missing} in column {field} dataset."
                warnings.warn(message, UserWarning)
            remaining = remaining[remaining[field].isin(included)]
        return self.__class__(remaining.copy())
131
+
132
+
133
class Saveable(Filterable):
    """A `Filterable` dataset that can be persisted to, and restored from,
    a directory or ZIP archive containing the files plus an `index.csv`."""

    @classmethod
    def load(
        cls,
        path_to_zip_or_directory: str,
        storage_dir: Optional[str] = None,
        verify_md5=True,
    ):
        """Load a dataset from a ZIP file or directory.

        Args:
            path_to_zip_or_directory: Path to a `.zip` archive or to a
                directory; either must contain an `index.csv` with
                `filename` and `md5` columns.
            storage_dir: If providing a ZIP file, where to extract
                the contents. If None, contents will be extracted to
                a folder with the same name as the ZIP file in the
                same directory as the ZIP file.
            verify_md5: Verify md5s when loading

        Returns:
            An instance of `cls` built from the index with the `filename`
            and `md5` columns removed and a `filepath` column added.

        Raises:
            AssertionError: If `storage_dir` is given for a directory path,
                or if the final md5 check finds a mismatch.
        """

        # Load index whether from inside ZIP file or from directory.
        if os.path.splitext(path_to_zip_or_directory)[1] == ".zip":
            if storage_dir is None:
                # Default: <directory of the zip>/<zip name without extension>
                storage_dir = os.path.join(
                    os.path.dirname(os.path.abspath(path_to_zip_or_directory)),
                    os.path.splitext(os.path.basename(path_to_zip_or_directory))[0],
                )
            os.makedirs(storage_dir, exist_ok=True)
            with zipfile.ZipFile(path_to_zip_or_directory, "r") as z:
                # Try extracting only the index at first so we can
                # compare md5.
                z.extract("index.csv", os.path.join(storage_dir))
                index: pd.DataFrame = pd.read_csv(
                    os.path.join(storage_dir, "index.csv")
                )
                # NOTE(review): rows with a null filename get filepath None,
                # which is then handed to os.path.isfile / compute_md5 below -
                # confirm those calls tolerate None.
                index["filepath"] = index["filename"].apply(
                    lambda fn: (
                        os.path.join(storage_dir, fn) if not pd.isnull(fn) else None
                    )
                )
                # Extract everything unless a complete (and, if requested,
                # md5-verified) copy is already present in storage_dir.
                do_zip_extraction = True
                if index["filepath"].apply(os.path.isfile).all():
                    if verify_md5:
                        do_zip_extraction = not all(
                            row["md5"] == compute_md5(row["filepath"])
                            for _, row in tqdm.tqdm(
                                index.iterrows(), desc="Checking cache"
                            )
                        )
                    else:
                        do_zip_extraction = False
                if do_zip_extraction:
                    z.extractall(storage_dir)
                else:
                    log.info("Found all files already extracted. Skipping extraction.")
                    # The cache was either just verified above or verification
                    # was not requested, so the final check would be redundant.
                    verify_md5 = False
        else:
            assert (
                storage_dir is None
            ), "Storage directory only valid if path is to ZIP file."
            index = pd.read_csv(os.path.join(path_to_zip_or_directory, "index.csv"))
            index["filepath"] = index["filename"].apply(
                lambda fn: (
                    os.path.join(path_to_zip_or_directory, fn)
                    if not pd.isnull(fn)
                    else None
                )
            )

        if verify_md5:
            # Reached for directory loads and for freshly extracted archives.
            assert all(
                row["md5"] == compute_md5(row["filepath"])
                for _, row in tqdm.tqdm(
                    index.iterrows(),
                    desc="Performing final md5 integrity check.",
                    total=len(index.index),
                )
            ), "An md5 mismatch has occurred."
        return cls(index.drop(["filename", "md5"], axis=1))

    def save(self, path_to_zip_or_directory):
        """Save a dataset to a directory or ZIP file.

        An `index.csv` (the dataframe with `filepath` replaced by `filename`
        and an added `md5` column) is written alongside the files. Rows whose
        `filepath` is null are kept in the index but no file is written.

        Args:
            path_to_zip_or_directory: Destination. A path ending in `.zip`
                produces a ZIP archive; anything else is treated as a
                directory, which is created if needed.

        Raises:
            AssertionError: If the dataframe has no `filepath` column.
        """
        df = self._df
        assert "filepath" in df.columns, "Index dataframe must contain filepath."

        # Build index using filename instead of filepath.
        index = df.copy()
        index["filename"] = df["filepath"].apply(
            lambda filepath: (
                os.path.basename(filepath) if not pd.isnull(filepath) else None
            )
        )
        if index["filename"].dropna().duplicated().sum() > 0:
            warnings.warn("Changing filenames to UUID due to duplicates.", UserWarning)

            # Files from different directories may share a basename; replace
            # every name with a fresh UUID, keeping the original extension.
            index["filename"] = [
                (
                    str(uuid.uuid4()) + os.path.splitext(row["filename"])[1]
                    if not pd.isnull(row["filename"])
                    else None
                )
                for _, row in index.iterrows()
            ]
        index["md5"] = [
            compute_md5(filepath) if not pd.isnull(filepath) else None
            for filepath in tqdm.tqdm(index["filepath"], desc="Computing md5s.")
        ]

        # Add all files as well as the dataframe index to
        # a ZIP file if path is to ZIP file or to the directory if it is
        # not a ZIP file.
        if os.path.splitext(path_to_zip_or_directory)[1] == ".zip":
            with zipfile.ZipFile(path_to_zip_or_directory, "w") as f:
                # Serialise the index through a temporary text file and copy
                # its contents into the archive as index.csv.
                with tempfile.TemporaryFile(mode="w+") as index_file:
                    index.drop("filepath", axis=1).to_csv(index_file, index=False)
                    index_file.seek(0)
                    f.writestr("index.csv", index_file.read())
                for _, row in tqdm.tqdm(
                    index.iterrows(), desc="Saving files", total=len(df)
                ):
                    if pd.isnull(row["filepath"]):
                        # There was an error associated with this file.
                        continue
                    f.write(row["filepath"], row["filename"])
        else:
            os.makedirs(path_to_zip_or_directory, exist_ok=True)
            index.drop("filepath", axis=1).to_csv(
                os.path.join(path_to_zip_or_directory, "index.csv"), index=False
            )
            for _, row in tqdm.tqdm(
                index.iterrows(), desc="Saving files", total=len(df)
            ):
                if pd.isnull(row["filepath"]):
                    # There was an error associated with this file.
                    continue
                # NOTE(review): this is a plain string comparison, so the same
                # file reached through a differently spelled path is not
                # detected here - confirm shutil.copy's behaviour is acceptable.
                if row["filepath"] == os.path.join(
                    path_to_zip_or_directory, row["filename"]
                ):
                    # The source file is the same as the target file.
                    continue
                shutil.copy(
                    row["filepath"],
                    os.path.join(path_to_zip_or_directory, row["filename"]),
                )
280
+
281
+
282
class BenchmarkHashes(Filterable):
    """A dataset of hashes for transformed images. It is essentially
    a wrapper around a `pandas.DataFrame` with the following columns:

    - guid
    - error
    - filepath
    - input_filepath
    - category
    - transform_name
    - hasher_name
    - hasher_dtype
    - hasher_distance_metric
    - hasher_hash_length
    - hash
    """

    expected_columns = [
        "error",
        "filepath",
        "hash",
        "hasher_name",
        "hasher_dtype",
        "hasher_distance_metric",
        "category",
        "guid",
        "input_filepath",
        "transform_name",
        "hasher_hash_length",
    ]

    def __init__(self, df: pd.DataFrame):
        super().__init__(df)
        # Cache for compute_metrics(); filled on first use.
        self._metrics: Optional[pd.DataFrame] = None

    def __add__(self, other):
        """Concatenate two hash datasets, dropping exact duplicate rows."""
        return BenchmarkHashes(df=pd.concat([self._df, other._df]).drop_duplicates())

    def __radd__(self, other):
        """Support `other + self`, including the `0 + self` step of `sum()`."""
        # sum() starts from the integer 0, which has no `_df`; treat it as
        # the identity so `sum(list_of_hashes)` works.
        if isinstance(other, int) and other == 0:
            return self
        return self.__add__(other)

    @classmethod
    def load(cls, filepath: str):
        """Read a dataset previously written with `save` from a CSV file."""
        return cls(pd.read_csv(filepath))

    def save(self, filepath):
        """Write the underlying dataframe to a CSV file without the index."""
        self._df.to_csv(filepath, index=False)

    def compute_metrics(
        self, custom_distance_metrics: Optional[dict] = None
    ) -> pd.DataFrame:
        """Compute, for every transformed hash, distances to the noop hashes.

        The result is cached on the instance: later calls return the first
        result regardless of the arguments passed.

        Args:
            custom_distance_metrics: Mapping from hasher name to a callable
                `f(hash_a, hash_b) -> distance`. Required for every hasher
                whose `hasher_distance_metric` is `"custom"`.

        Returns:
            A dataframe with one row per valid transformed hash and the
            columns guid, transform_name, hasher_name, category,
            distance_to_closest_correct_image,
            distance_to_closest_incorrect_image, distance_to_closest_image
            and closest_incorrect_guid.

        Raises:
            AssertionError: If a custom metric is needed but not provided.
        """
        if self._metrics is not None:
            return self._metrics
        metrics = []
        # create_mask() relies on equal guids being adjacent and in the same
        # order in both the transformed and the noop set.
        hashsets = self._df.sort_values("guid")
        n_dropped = hashsets["hash"].isnull().sum()
        if n_dropped > 0:
            hashsets = hashsets.dropna(subset=["hash"])
            warnings.warn(f"Dropping {n_dropped} invalid / empty hashes.", UserWarning)
        for (hasher_name, transform_name, category), hashset in tqdm.tqdm(
            hashsets.groupby(["hasher_name", "transform_name", "category"]),
            desc="Computing metrics.",
        ):
            # Note the guid filtering below. We need to include only guids
            # for which we have the transform *and* the guid. One of them
            # may have been dropped due to being invalid.
            noops = hashsets[
                (hashsets["transform_name"] == "noop")
                & (hashsets["hasher_name"] == hasher_name)
                & (hashsets["guid"].isin(hashset["guid"]))
            ]
            valid_hashset = hashset[hashset["guid"].isin(noops["guid"])]
            if valid_hashset.empty:
                # No noop hash survived for any guid in this group, so there
                # is nothing to compare against; `.iloc[0]` below would fail.
                warnings.warn(
                    f"Skipping {hasher_name} / {transform_name} / {category}: "
                    "no matching noop hashes.",
                    UserWarning,
                )
                continue
            dtype, distance_metric, hash_length = valid_hashset.iloc[0][
                ["hasher_dtype", "hasher_distance_metric", "hasher_hash_length"]
            ]
            n_noops = len(noops.guid)
            n_hashset = len(valid_hashset.guid)
            noop_guids = noops.guid.values
            mask = create_mask(valid_hashset.guid.values, noops.guid.values)
            if distance_metric != "custom":
                X_trans = np.array(
                    valid_hashset.hash.apply(
                        string_to_vector,  # type: ignore[arg-type]
                        hash_length=int(hash_length),
                        dtype=dtype,
                        hash_format="base64",
                    ).tolist()
                )
                X_noop = np.array(
                    noops.hash.apply(
                        string_to_vector,
                        dtype=dtype,
                        hash_format="base64",
                        hash_length=int(hash_length),
                    ).tolist()
                )
                if (
                    distance_metric != "euclidean"
                    or "int" not in dtype
                    or extensions is None
                ):
                    # Generic path: full pairwise distance matrix via scipy.
                    distance_matrix = spatial.distance.cdist(
                        XA=X_trans, XB=X_noop, metric=distance_metric
                    )
                    distance_to_closest_image = distance_matrix.min(axis=1)
                    # Masked arrays hide entries where the mask is True, so
                    # invert the mask to keep only same-guid noops.
                    distance_to_correct_image = np.ma.masked_array(
                        distance_matrix, np.logical_not(mask)
                    ).min(axis=1)
                    distance_matrix_incorrect_image: np.ndarray = np.ma.masked_array(
                        distance_matrix, mask
                    )
                    distance_to_incorrect_image = distance_matrix_incorrect_image.min(
                        axis=1
                    )
                    closest_incorrect_guid = noop_guids[
                        distance_matrix_incorrect_image.argmin(axis=1)
                    ]
                else:
                    # Compiled path for integer hashes with euclidean distance.
                    # Column 1 is read as the correct-match distance and
                    # column 0 as the incorrect-match distance.
                    distances, indexes = extensions.compute_euclidean_metrics(
                        X_noop.astype("int32"), X_trans.astype("int32"), mask
                    )
                    distance_to_correct_image = distances[:, 1]
                    distance_to_incorrect_image = distances[:, 0]
                    distance_to_closest_image = distances.min(axis=1)
                    closest_incorrect_guid = [noop_guids[idx] for idx in indexes[:, 0]]
            else:
                assert (
                    custom_distance_metrics is not None
                    and hasher_name in custom_distance_metrics
                ), f"You must provide a custom distance metric for {hasher_name}."
                noops_hash_values = noops.hash.values
                hashset_hash_values = valid_hashset.hash.values
                distance_matrix = np.zeros((n_hashset, n_noops))
                distance_function = custom_distance_metrics[hasher_name]
                # The custom metric works on the raw hash strings, pair by pair.
                for i1, i2 in itertools.product(range(n_hashset), range(n_noops)):
                    distance_matrix[i1, i2] = distance_function(
                        hashset_hash_values[i1], noops_hash_values[i2]
                    )
                distance_to_closest_image = distance_matrix.min(axis=1)
                distance_to_correct_image = np.ma.masked_array(
                    distance_matrix, np.logical_not(mask)
                ).min(axis=1)
                distance_matrix_incorrect_image = np.ma.masked_array(
                    distance_matrix, mask
                )
                distance_to_incorrect_image = distance_matrix_incorrect_image.min(
                    axis=1
                )
                closest_incorrect_guid = noop_guids[
                    distance_matrix_incorrect_image.argmin(axis=1)
                ]

            metrics.append(
                pd.DataFrame(
                    {
                        "guid": valid_hashset["guid"].values,
                        "transform_name": transform_name,
                        "hasher_name": hasher_name,
                        "category": category,
                        "distance_to_closest_correct_image": distance_to_correct_image,
                        "distance_to_closest_incorrect_image": distance_to_incorrect_image,
                        "distance_to_closest_image": distance_to_closest_image,
                        "closest_incorrect_guid": closest_incorrect_guid,
                    }
                )
            )
        metrics_df = pd.concat(metrics)
        self._metrics = metrics_df
        return metrics_df

    def show_histograms(self, grouping=None, precision_threshold=99.9, **kwargs):
        """Plot histograms for true and false positives, similar
        to https://tech.okcupid.com/evaluating-perceptual-image-hashes-okcupid/
        Additional arguments passed to compute_metrics.

        One column of axes is drawn per hasher and one row per group.

        Args:
            grouping: List of fields to group by. By default, all fields are used
                (category, and transform_name). Pass an empty list to pool
                everything per hasher.
            precision_threshold: The precision threshold (in percent) used
                to pick the distance threshold annotated on each plot.
        """
        if grouping is None:
            grouping = ["category", "transform_name"]

        metrics = self.compute_metrics(**kwargs)

        hasher_names = metrics["hasher_name"].unique().tolist()
        # Largest distance seen per hasher; used as a shared x-axis limit.
        bounds = (
            metrics.groupby("hasher_name")[
                ["distance_to_closest_image", "distance_to_closest_incorrect_image"]
            ]
            .max()
            .max(axis=1)
        )
        if grouping:
            group_names = [
                ":".join(map(str, row.values))
                for _, row in metrics[grouping].drop_duplicates().iterrows()
            ]
        else:
            group_names = [""]
        ncols = len(hasher_names)
        nrows = len(group_names)

        fig, axs = plt.subplots(
            ncols=ncols, nrows=nrows, figsize=(ncols * 4, nrows * 3), sharey=True
        )

        for group_key, subset in metrics.groupby(["hasher_name"] + grouping):
            # Get names of group and hasher. Depending on the pandas version,
            # grouping by a one-element list yields either a scalar or a
            # 1-tuple key, so normalise to a tuple first.
            key_parts = group_key if isinstance(group_key, tuple) else (group_key,)
            hasher_name = key_parts[0]
            group_name = ":".join(map(str, key_parts[1:]))

            # Get the correct axis. plt.subplots returns a single Axes, a 1-D
            # array or a 2-D array depending on the grid shape.
            colIdx = hasher_names.index(hasher_name)
            rowIdx = group_names.index(group_name)
            if ncols > 1 and nrows > 1:
                ax = axs[rowIdx, colIdx]
            elif ncols == 1 and nrows == 1:
                ax = axs
            else:
                ax = axs[rowIdx if nrows > 1 else colIdx]

            # Plot the charts
            pos, neg = (
                subset.groupby(["guid", "transform_name"])[
                    [
                        "distance_to_closest_correct_image",
                        "distance_to_closest_incorrect_image",
                    ]
                ]
                .min()
                .values.T
            )
            optimal_threshold, _, optimal_recall = compute_threshold_precision_recall(
                pos=pos, neg=neg, precision_threshold=precision_threshold
            )
            optimal_threshold = optimal_threshold.round(3)
            emd = stats.wasserstein_distance(pos, neg).round(2)
            ax.hist(neg, label="neg", bins=10)
            ax.hist(pos, label="pos", bins=10)
            ax.text(
                0.5,
                0.5,
                f"Recall: {optimal_recall:.0f}% @ {optimal_threshold}\nemd: {emd:.2f}",
                horizontalalignment="center",
                color="black",
                verticalalignment="center",
                transform=ax.transAxes,
                fontsize=12,
                fontweight=1000,
            )
            ax.set_xlim(-0.05 * bounds[hasher_name], bounds[hasher_name])
            if rowIdx == 0:
                ax.set_title(hasher_name)
                ax.legend()
            if colIdx == 0:
                ax.set_ylabel(group_name)
        fig.tight_layout()

    def compute_threshold_recall(
        self, precision_threshold=99.9, grouping=None, **kwargs
    ) -> pd.DataFrame:
        """Compute a table for threshold and recall for each category, hasher,
        and transformation combinations. Additional arguments passed to compute_metrics.

        Args:
            precision_threshold: The precision threshold to use
                for choosing a distance threshold for each hasher.
            grouping: List of fields to group by. By default, all fields are used
                (category, and transform_name).

        Returns:
            A pandas DataFrame indexed by the grouping fields plus
            hasher_name, with the columns threshold (The optimal distance
            threshold for detecting a match for this combination), recall
            (the number of correct matches divided by the number of possible
            matches), precision (the number correct matches divided by the
            total number of matches whether correct or incorrect) and
            n_exemplars (the number of metric rows in the group).
        """
        if grouping is None:
            grouping = ["category", "transform_name"]

        def group_func(subset):
            # Collapse to one (pos, neg) pair per guid/transform combination.
            pos, neg = (
                subset.groupby(["guid", "transform_name"])[
                    [
                        "distance_to_closest_correct_image",
                        "distance_to_closest_incorrect_image",
                    ]
                ]
                .min()
                .values.T
            )

            (
                optimal_threshold,
                optimal_precision,
                optimal_recall,
            ) = compute_threshold_precision_recall(
                pos=pos, neg=neg, precision_threshold=precision_threshold
            )
            return pd.Series(
                {
                    "threshold": optimal_threshold,
                    "recall": optimal_recall,
                    "precision": optimal_precision,
                    "n_exemplars": len(subset),
                }
            )

        return (
            self.compute_metrics(**kwargs)
            .groupby(grouping + ["hasher_name"])
            .apply(group_func)
        )
599
+
600
+
601
class BenchmarkDataset(Saveable):
    """A collection of files, each assigned to a category.

    Backed by a pandas dataframe with exactly these columns:

    - filepath
    - category
    """

    expected_columns = ["filepath", "category"]

    @classmethod
    def from_tuples(cls, files: typing.List[typing.Tuple[str, str]]):
        """Create a dataset from `(filepath, category)` pairs.

        Args:
            files: A list of two-element tuples, each holding a file path
                followed by the category that file belongs to.
        """
        records = []
        for filepath, category in files:
            records.append({"filepath": filepath, "category": category})
        return cls(pd.DataFrame.from_records(records))

    def transform(self, transforms, storage_dir, errors):
        """Placeholder that always raises; subclasses supply the behaviour."""
        raise NotImplementedError()
627
+
628
+
629
class BenchmarkTransforms(Saveable):
    """The output of applying transforms to a dataset.

    Backed by a pandas dataframe with exactly these columns:

    - guid
    - filepath
    - category
    - transform_name
    - input_filepath (for memo purposes only)
    """

    expected_columns = [
        "filepath",
        "category",
        "transform_name",
        "input_filepath",
        "guid",
    ]

    def compute_hashes(self, hashers, max_workers):
        """Placeholder that always raises; subclasses supply the behaviour."""
        raise NotImplementedError()