dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +7 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/{_internal → utils}/split_dataset.py +98 -33
  52. dataeval/utils/tensorflow/__init__.py +7 -6
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  67. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -8
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.1.dist-info/RECORD +0 -81
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,156 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = ["HashStatsOutput", "hashstats"]
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Callable, Iterable
7
+
8
+ import numpy as np
9
+ import xxhash as xxh
10
+ from numpy.typing import ArrayLike
11
+ from PIL import Image
12
+ from scipy.fftpack import dct
13
+
14
+ from dataeval.interop import as_numpy
15
+ from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
16
+ from dataeval.output import set_metadata
17
+ from dataeval.utils.image import normalize_image_shape, rescale
18
+
19
+ HASH_SIZE = 8
20
+ MAX_FACTOR = 4
21
+
22
+
23
@dataclass(frozen=True)
class HashStatsOutput(BaseStatsOutput):
    """
    Output class for :func:`hashstats` stats metric

    Attributes
    ----------
    xxhash : List[str]
        Hex string of the xxHash digest computed for each image
    pchash : List[str]
        Hex string of the :term:`Perception-based Hash` computed for each image
    """

    xxhash: list[str]
    pchash: list[str]
38
+
39
+
40
def pchash(image: ArrayLike) -> str:
    """
    Computes a perceptual hash of an image and returns it as a hex string.

    The image is averaged over its channels, rescaled to 8-bit pixel values
    and resized with the Lanczos algorithm to a square image whose side is
    ``HASH_SIZE * min((min_dim - 1) // HASH_SIZE, MAX_FACTOR)``, where
    ``min_dim`` is the smaller of the last two input dimensions. A discrete
    cosine transform is applied along both axes, the lowest frequency
    ``HASH_SIZE`` x ``HASH_SIZE`` block is thresholded against its median
    value and the resulting bit array is encoded as a hex string.

    Parameters
    ----------
    image : ArrayLike
        An image as a numpy array in CxHxW format

    Returns
    -------
    str
        The hex string hash of the image using perceptual hashing, with
        leading zeros stripped ("0" if nothing remains)

    Raises
    ------
    ValueError
        If the smaller of the last two image dimensions is not larger than ``HASH_SIZE``
    """
    pixels = as_numpy(image)

    # The image has to be strictly larger than the hash block in both spatial dimensions
    smallest_side = min(pixels.shape[-2:])
    if smallest_side <= HASH_SIZE:
        raise ValueError(f"Image must be larger than {HASH_SIZE}x{HASH_SIZE} for fuzzy hashing.")

    # Side length of the square image fed to the transform
    side = min((smallest_side - 1) // HASH_SIZE, MAX_FACTOR) * HASH_SIZE

    # Collapse the channels of the CxHxW image into a single averaged plane
    grayscale = np.mean(normalize_image_shape(pixels), axis=0).squeeze()

    # Bring the pixel values into the 8-bit 0-255 range
    eight_bit = rescale(grayscale, 8).astype(np.uint8)

    # Lanczos resampling down to the square working size
    resized = Image.fromarray(eight_bit).resize((side, side), Image.Resampling.LANCZOS)
    im = np.array(resized)

    # Transform along both axes and keep only the lowest frequency block
    low_freq = dct(dct(im.T).T)[:HASH_SIZE, :HASH_SIZE]

    # One bit per coefficient: set when the coefficient exceeds the median
    bits = (low_freq > np.median(low_freq)).ravel()

    # Left-pad with unset bits so the total bit count is a multiple of 8
    padded = np.zeros(-(-bits.size // 8) * 8, dtype=bool)
    padded[-bits.size :] = bits

    # Hex encode and drop leading zeros, falling back to "0" for an all-zero hash
    return np.packbits(padded).tobytes().hex().lstrip("0") or "0"
90
+
91
+
92
def xxhash(image: ArrayLike) -> str:
    """
    Hashes the raw bytes of an image with the fast, non-cryptographic
    xxHash algorithm (xxhash.com) and returns the digest as a hex string.

    Parameters
    ----------
    image : ArrayLike
        An image as a numpy array

    Returns
    -------
    str
        The hex string hash of the image using the xxHash algorithm
    """
    # The image is flattened before being serialized to bytes
    raw = as_numpy(image).ravel().tobytes()
    return xxh.xxh3_64_hexdigest(raw)
109
+
110
+
111
class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
    # Output type associated with this processor
    output_class: type = HashStatsOutput
    # Maps each HashStatsOutput field name to the function computing it from the processor's image
    image_function_map: dict[str, Callable[[StatsProcessor[HashStatsOutput]], str]] = {
        "xxhash": lambda proc: xxhash(proc.image),
        "pchash": lambda proc: pchash(proc.image),
    }
117
+
118
+
119
@set_metadata()
def hashstats(
    images: Iterable[ArrayLike],
    bboxes: Iterable[ArrayLike] | None = None,
) -> HashStatsOutput:
    """
    Calculates hashes for each image

    Both an exact hash and a perception-based hash are computed for every image.
    The hash values can be compared to determine if images are exact or near matches.

    Parameters
    ----------
    images : Iterable[ArrayLike]
        Images to hash
    bboxes : Iterable[ArrayLike] or None
        Bounding boxes in `xyxy` format for each image

    Returns
    -------
    HashStatsOutput
        A dictionary-like object containing the computed hashes for each image.

    See Also
    --------
    Duplicates

    Examples
    --------
    Calculating the statistics on the images, whose shape is (C, H, W)

    >>> results = hashstats(images)
    >>> print(results.xxhash)
    ['a72434443d6e7336', 'efc12c2f14581d79', '4a1e03483a27d674', '3a3ecedbcf814226']
    >>> print(results.pchash)
    ['8f25506af46a7c6a', '8000808000008080', '8e71f18e0ef18e0e', 'a956d6a956d6a928']
    """
    # A single processor is requested, so the first (only) output is the result
    outputs = run_stats(images, bboxes, False, [HashStatsProcessor])
    return outputs[0]
@@ -1,13 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["LabelStatsOutput", "labelstats"]
4
+
3
5
  from collections import Counter, defaultdict
4
6
  from dataclasses import dataclass
5
7
  from typing import Any, Iterable, Mapping, TypeVar
6
8
 
7
9
  from numpy.typing import ArrayLike
8
10
 
9
- from dataeval._internal.interop import to_numpy
10
- from dataeval._internal.output import OutputMetadata, set_metadata
11
+ from dataeval.interop import to_numpy
12
+ from dataeval.output import OutputMetadata, set_metadata
11
13
 
12
14
 
13
15
  @dataclass(frozen=True)
@@ -55,7 +57,7 @@ def sort(d: Mapping[TKey, Any]) -> dict[TKey, Any]:
55
57
  return dict(sorted(d.items(), key=lambda x: x[0]))
56
58
 
57
59
 
58
- @set_metadata("dataeval.metrics")
60
+ @set_metadata()
59
61
  def labelstats(
60
62
  labels: Iterable[ArrayLike],
61
63
  ) -> LabelStatsOutput:
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["PixelStatsOutput", "pixelstats"]
4
+
3
5
  from dataclasses import dataclass
4
- from typing import Iterable
6
+ from typing import Any, Callable, Iterable
5
7
 
6
8
  import numpy as np
7
9
  from numpy.typing import ArrayLike, NDArray
8
10
  from scipy.stats import entropy, kurtosis, skew
9
11
 
10
- from dataeval._internal.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
11
- from dataeval._internal.output import set_metadata
12
+ from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
13
+ from dataeval.output import set_metadata
12
14
 
13
15
 
14
16
  @dataclass(frozen=True)
@@ -44,9 +46,8 @@ class PixelStatsOutput(BaseStatsOutput):
44
46
 
45
47
 
46
48
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
47
- output_class = PixelStatsOutput
48
- cache_keys = ["histogram"]
49
- image_function_map = {
49
+ output_class: type = PixelStatsOutput
50
+ image_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
50
51
  "mean": lambda self: np.mean(self.scaled),
51
52
  "std": lambda x: np.std(x.scaled),
52
53
  "var": lambda x: np.var(x.scaled),
@@ -55,7 +56,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
55
56
  "histogram": lambda x: np.histogram(x.scaled, 256, (0, 1))[0],
56
57
  "entropy": lambda x: entropy(x.get("histogram")),
57
58
  }
58
- channel_function_map = {
59
+ channel_function_map: dict[str, Callable[[StatsProcessor[PixelStatsOutput]], Any]] = {
59
60
  "mean": lambda x: np.mean(x.scaled, axis=1),
60
61
  "std": lambda x: np.std(x.scaled, axis=1),
61
62
  "var": lambda x: np.var(x.scaled, axis=1),
@@ -66,7 +67,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
66
67
  }
67
68
 
68
69
 
69
- @set_metadata("dataeval.metrics")
70
+ @set_metadata()
70
71
  def pixelstats(
71
72
  images: Iterable[ArrayLike],
72
73
  bboxes: Iterable[ArrayLike] | None = None,
@@ -1,14 +1,16 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["VisualStatsOutput", "visualstats"]
4
+
3
5
  from dataclasses import dataclass
4
- from typing import Iterable
6
+ from typing import Any, Callable, Iterable
5
7
 
6
8
  import numpy as np
7
9
  from numpy.typing import ArrayLike, NDArray
8
10
 
9
- from dataeval._internal.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
10
- from dataeval._internal.metrics.utils import edge_filter
11
- from dataeval._internal.output import set_metadata
11
+ from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
12
+ from dataeval.output import set_metadata
13
+ from dataeval.utils.image import edge_filter
12
14
 
13
15
  QUARTILES = (0, 25, 50, 75, 100)
14
16
 
@@ -46,9 +48,8 @@ class VisualStatsOutput(BaseStatsOutput):
46
48
 
47
49
 
48
50
  class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
49
- output_class = VisualStatsOutput
50
- cache_keys = ["percentiles"]
51
- image_function_map = {
51
+ output_class: type = VisualStatsOutput
52
+ image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
52
53
  "brightness": lambda x: x.get("percentiles")[1],
53
54
  "contrast": lambda x: np.nan_to_num(
54
55
  (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles"))
@@ -59,7 +60,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
59
60
  "zeros": lambda x: np.count_nonzero(np.sum(x.image, axis=0) == 0) / np.prod(x.shape[-2:]),
60
61
  "percentiles": lambda x: np.nanpercentile(x.scaled, q=QUARTILES),
61
62
  }
62
- channel_function_map = {
63
+ channel_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
63
64
  "brightness": lambda x: x.get("percentiles")[:, 1],
64
65
  "contrast": lambda x: np.nan_to_num(
65
66
  (np.max(x.get("percentiles"), axis=1) - np.min(x.get("percentiles"), axis=1))
@@ -73,7 +74,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
73
74
  }
74
75
 
75
76
 
76
- @set_metadata("dataeval.metrics")
77
+ @set_metadata()
77
78
  def visualstats(
78
79
  images: Iterable[ArrayLike],
79
80
  bboxes: Iterable[ArrayLike] | None = None,
@@ -1,12 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = []
4
+
3
5
  import inspect
6
+ import sys
4
7
  from datetime import datetime, timezone
5
8
  from functools import wraps
6
- from typing import Any
9
+ from typing import Any, Callable, Iterable, TypeVar
7
10
 
8
11
  import numpy as np
9
12
 
13
+ if sys.version_info >= (3, 10):
14
+ from typing import ParamSpec
15
+ else:
16
+ from typing_extensions import ParamSpec
17
+
10
18
  from dataeval import __version__
11
19
 
12
20
 
@@ -25,10 +33,18 @@ class OutputMetadata:
25
33
  return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
26
34
 
27
35
 
28
- def set_metadata(module_name: str = "", state_attr: list[str] | None = None):
29
- def decorator(fn):
36
+ P = ParamSpec("P")
37
+ R = TypeVar("R", bound=OutputMetadata)
38
+
39
+
40
+ def set_metadata(
41
+ state_attr: Iterable[str] | None = None,
42
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
43
+ """Decorator to stamp OutputMetadata classes with runtime metadata"""
44
+
45
+ def decorator(fn: Callable[P, R]) -> Callable[P, R]:
30
46
  @wraps(fn)
31
- def wrapper(*args, **kwargs):
47
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
32
48
  def fmt(v):
33
49
  if np.isscalar(v):
34
50
  return v
@@ -52,9 +68,13 @@ def set_metadata(module_name: str = "", state_attr: list[str] | None = None):
52
68
  if "self" in arguments and state_attr
53
69
  else {}
54
70
  )
55
- name = args[0].__class__.__name__ if "self" in arguments else fn.__name__
71
+ name = (
72
+ f"{args[0].__class__.__module__}.{args[0].__class__.__name__}.{fn.__name__}"
73
+ if "self" in arguments
74
+ else f"{fn.__module__}.{fn.__qualname__}"
75
+ )
56
76
  metadata = {
57
- "_name": f"{module_name}.{name}",
77
+ "_name": name,
58
78
  "_execution_time": time,
59
79
  "_execution_duration": duration,
60
80
  "_arguments": {k: v for k, v in arguments.items() if k != "self"},
@@ -5,15 +5,19 @@ metrics. Currently DataEval supports both :term:`TensorFlow` and PyTorch backend
5
5
  """
6
6
 
7
7
  from dataeval import _IS_TENSORFLOW_AVAILABLE, _IS_TORCH_AVAILABLE
8
+ from dataeval.utils.split_dataset import split_dataset
8
9
 
9
- __all__ = []
10
+ __all__ = ["split_dataset"]
10
11
 
11
12
  if _IS_TORCH_AVAILABLE: # pragma: no cover
12
- from . import torch
13
+ from dataeval.utils import torch
13
14
 
14
15
  __all__ += ["torch"]
15
16
 
16
17
  if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
17
- from . import tensorflow
18
+ from dataeval.utils import tensorflow
18
19
 
19
20
  __all__ += ["tensorflow"]
21
+
22
+ del _IS_TENSORFLOW_AVAILABLE
23
+ del _IS_TORCH_AVAILABLE
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, NamedTuple
6
+
7
+ import numpy as np
8
+ from numpy.typing import ArrayLike, NDArray
9
+ from scipy.signal import convolve2d
10
+
11
+ EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
12
# Candidate bit depths, in ascending order, considered when approximating an image's depth
BIT_DEPTH = (1, 8, 12, 16, 32)


class BitDepth(NamedTuple):
    """
    Approximated bit depth of an image and the pixel value range that goes with it.
    """

    # Bit depth; 0 marks an image containing negative values
    depth: int
    # Lower bound of the pixel value range
    pmin: float | int
    # Upper bound of the pixel value range
    pmax: float | int


def get_bitdepth(image: NDArray[Any]) -> BitDepth:
    """
    Approximates the bit depth of the image using the
    min and max pixel values.

    Parameters
    ----------
    image : NDArray
        Image to inspect

    Returns
    -------
    BitDepth
        For images containing negative values, a depth of 0 with the observed
        min and max pixel values. Otherwise the smallest depth in ``BIT_DEPTH``
        whose range holds the max pixel value (or the largest depth if none does),
        with a range of 0 to ``2**depth - 1``.
    """
    pmin, pmax = np.min(image), np.max(image)
    if pmin < 0:
        return BitDepth(0, pmin, pmax)
    else:
        depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
        return BitDepth(depth, 0, 2**depth - 1)


def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
    """
    Rescales the image using the bit depth provided.

    Parameters
    ----------
    image : NDArray
        Image to rescale
    depth : int, default 1
        Target bit depth; output values span 0 to ``2**depth - 1``

    Returns
    -------
    NDArray
        The input image itself when it already has the requested bit depth,
        otherwise a new array mapped from the range reported by
        :func:`get_bitdepth` onto 0 to ``2**depth - 1``.
    """
    bitdepth = get_bitdepth(image)
    if bitdepth.depth == depth:
        return image

    # Python floats avoid integer overflow for narrow dtypes (e.g. int8 spanning -128..127)
    pmin = float(bitdepth.pmin)
    span = float(bitdepth.pmax) - pmin
    if span == 0:
        # Constant image with a negative value: there is no range to stretch
        return np.zeros_like(image, dtype=np.float64)

    # Shift so the minimum maps to 0 (subtracting, not adding, the minimum), then scale to 0-1
    normalized = (image - pmin) / span
    return normalized * (2**depth - 1)
44
+
45
+
46
def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
    """
    Normalizes the image shape into (C,H,W).

    Two dimensional input gains a leading axis of length one, three
    dimensional input is returned unchanged, and for higher dimensional
    input the first element along every leading axis is taken until three
    dimensions remain.

    Raises
    ------
    ValueError
        If the image has fewer than 2 dimensions
    """
    dims = image.ndim
    if dims < 2:
        raise ValueError("Images must have 2 or more dimensions.")
    if dims == 2:
        return np.expand_dims(image, axis=0)
    if dims == 3:
        return image
    # Index 0 on each of the extra leading axes, keeping the last 3 dimensions
    leading = (0,) * (dims - 3)
    return image[leading]
60
+
61
+
62
def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[Any]:
    """
    Returns the image filtered using a 3x3 edge detection kernel:
    [[ -1, -1, -1 ],
    [ -1, 8, -1 ],
    [ -1, -1, -1 ]]

    Parameters
    ----------
    image : ArrayLike
        Two dimensional image to filter
    offset : float, default 0.5
        Value added to every filtered pixel before clipping

    Returns
    -------
    NDArray
        The filtered image, same shape as the input, with values clipped to
        the range 0-255. The result is not cast to an integer type, so with a
        float ``offset`` it is a floating point array.
    """
    # "same" keeps the input shape; "symm" mirrors the image at its borders
    edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
    # Clip in place to avoid allocating a second array
    np.clip(edges, 0, 255, out=edges)
    return edges
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import sys
6
+ from typing import Any, Callable, Literal, TypeVar
7
+
8
+ import numpy as np
9
+ from numpy.typing import ArrayLike, NDArray
10
+ from scipy.sparse import csr_matrix
11
+ from scipy.sparse.csgraph import minimum_spanning_tree as mst
12
+ from scipy.spatial.distance import pdist, squareform
13
+ from sklearn.neighbors import NearestNeighbors
14
+
15
+ if sys.version_info >= (3, 10):
16
+ from typing import ParamSpec
17
+ else:
18
+ from typing_extensions import ParamSpec
19
+
20
+ from dataeval.interop import as_numpy
21
+
22
+ EPSILON = 1e-5
23
+ HASH_SIZE = 8
24
+ MAX_FACTOR = 4
25
+
26
+
27
+ P = ParamSpec("P")
28
+ R = TypeVar("R")
29
+
30
+
31
def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
    """
    Looks up a callable by name in a mapping of method names to callables.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``method_map``
    """
    if method in method_map:
        return method_map[method]
    raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
35
+
36
+
37
def flatten(array: ArrayLike) -> NDArray[Any]:
    """
    Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension

    Parameters
    ----------
    array : ArrayLike, shape - (N, ... )
        Input array

    Returns
    -------
    NDArray, shape - (N, -1)
    """
    data = as_numpy(array)
    # Keep the first axis and collapse every remaining axis into one
    n_samples = data.shape[0]
    return data.reshape((n_samples, -1))
52
+
53
+
54
def minimum_spanning_tree(X: NDArray[Any]) -> Any:
    """
    Returns the minimum spanning tree from a :term:`NumPy` image array.

    Parameters
    ----------
    X : NDArray
        Numpy image array

    Returns
    -------
    Data representing the minimum spanning tree
    """
    # One row per sample with all features on the second dimension
    samples = flatten(X)
    pairwise = squareform(pdist(samples))
    # A small constant is added to every distance so that scipy interprets
    # the input graph as fully-connected.
    graph = csr_matrix(pairwise + EPSILON)
    return mst(graph)
74
+
75
+
76
def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
    """
    Returns the number of unique classes and the total number of labels
    from an array of labels

    Parameters
    ----------
    labels : NDArray
        Numpy labels array

    Returns
    -------
    tuple[int, int]
        The number of unique classes and the total number of labels

    Raises
    ------
    ValueError
        If the number of unique classes is less than 2
    """
    # Only the per-class counts are needed: one entry per unique class
    counts = np.unique(labels, return_counts=True)[1]
    n_classes = len(counts)
    if n_classes < 2:
        raise ValueError("Label vector contains less than 2 classes!")
    # Convert to a builtin int so the result matches the annotated return type
    return n_classes, int(counts.sum())
100
+
101
+
102
def compute_neighbors(
    A: NDArray[Any],
    B: NDArray[Any],
    k: int = 1,
    algorithm: Literal["auto", "ball_tree", "kd_tree"] = "auto",
) -> NDArray[Any]:
    """
    For each sample in A, compute the nearest neighbors in B

    Parameters
    ----------
    A, B : NDArray
        Sample arrays; each is flattened to (n_samples, n_features) before the search
    k : int
        The number of neighbors to find
    algorithm : Literal
        Tree method for nearest neighbor (auto, ball_tree or kd_tree)

    Note
    ----
    Do not use kd_tree if n_features > 20

    Returns
    -------
    NDArray
        Indices into B of the neighbors found for each sample in A, with
        length-one axes squeezed out

    Raises
    ------
    ValueError
        If k is less than 1
    ValueError
        If algorithm is not "auto", "ball_tree", or "kd_tree"

    See Also
    --------
    sklearn.neighbors.NearestNeighbors
    """

    if k < 1:
        raise ValueError("k must be >= 1")
    if algorithm not in ("auto", "ball_tree", "kd_tree"):
        raise ValueError("Algorithm must be 'auto', 'ball_tree', or 'kd_tree'")

    queries = flatten(A)
    references = flatten(B)

    # k + 1 neighbors are requested and the closest one is discarded below;
    # presumably this drops a self-match when A's samples are also in B - TODO confirm
    model = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm)
    model.fit(references)

    # kneighbors returns (distances, indices); only the indices are used
    indices = model.kneighbors(queries)[1]
    return indices[:, 1:].squeeze()