dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/{visualstats.py → _visualstats.py}
@@ -6,11 +6,12 @@ from dataclasses import dataclass
  from typing import Any, Callable, Iterable

  import numpy as np
- from numpy.typing import ArrayLike, NDArray
+ from numpy.typing import NDArray

- from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
- from dataeval.output import set_metadata
- from dataeval.utils.image import edge_filter
+ from dataeval._output import set_metadata
+ from dataeval.metrics.stats._base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
+ from dataeval.typing import ArrayLike
+ from dataeval.utils._image import edge_filter

  QUARTILES = (0, 25, 50, 75, 100)

@@ -18,7 +19,7 @@ QUARTILES = (0, 25, 50, 75, 100)
  @dataclass(frozen=True)
  class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
      """
-     Output class for :func:`visualstats` stats metric.
+     Output class for :func:`.visualstats` stats metric.

      Attributes
      ----------
@@ -53,9 +54,9 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
      output_class: type = VisualStatsOutput
      image_function_map: dict[str, Callable[[StatsProcessor[VisualStatsOutput]], Any]] = {
          "brightness": lambda x: x.get("percentiles")[1],
-         "contrast": lambda x: np.nan_to_num(
-             (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles"))
-         ),
+         "contrast": lambda x: 0
+         if np.mean(x.get("percentiles")) == 0
+         else (np.max(x.get("percentiles")) - np.min(x.get("percentiles"))) / np.mean(x.get("percentiles")),
          "darkness": lambda x: x.get("percentiles")[-2],
          "missing": lambda x: np.count_nonzero(np.isnan(np.sum(x.image, axis=0))) / np.prod(x.shape[-2:]),
          "sharpness": lambda x: np.std(edge_filter(np.mean(x.image, axis=0))),
dataeval/typing.py ADDED
@@ -0,0 +1,54 @@
+ """
+ Common type hints used for interoperability with DataEval.
+ """
+
+ __all__ = ["Array", "ArrayLike"]
+
+ from typing import Any, Iterator, Protocol, Sequence, TypeVar, Union, runtime_checkable
+
+
+ @runtime_checkable
+ class Array(Protocol):
+     """
+     Protocol for array objects providing interoperability with DataEval.
+
+     Supports common array representations with popular libraries like
+     PyTorch, Tensorflow and JAX, as well as NumPy arrays.
+
+     Example
+     -------
+     >>> import numpy as np
+     >>> import torch
+     >>> from dataeval.typing import Array
+
+     Create array objects
+
+     >>> ndarray = np.random.random((10, 10))
+     >>> tensor = torch.tensor([1, 2, 3])
+
+     Check type at runtime
+
+     >>> isinstance(ndarray, Array)
+     True
+
+     >>> isinstance(tensor, Array)
+     True
+     """
+
+     @property
+     def shape(self) -> tuple[int, ...]: ...
+     def __array__(self) -> Any: ...
+     def __getitem__(self, key: Any, /) -> Any: ...
+     def __iter__(self) -> Iterator[Any]: ...
+     def __len__(self) -> int: ...
+
+
+ TArray = TypeVar("TArray", bound=Array)
+
+ ArrayLike = Union[Sequence[Any], Array]
+ """
+ Type alias for array-like objects used for interoperability with DataEval.
+
+ This includes native Python sequences, as well as objects that conform to
+ the `Array` protocol.
+ """
dataeval/utils/__init__.py
@@ -4,6 +4,6 @@ in setting up data and architectures that are guaranteed to work with applicable
  DataEval metrics.
  """

- __all__ = ["dataset", "metadata", "torch"]
+ __all__ = ["data", "metadata", "torch"]

- from dataeval.utils import dataset, metadata, torch
+ from . import data, metadata, torch
dataeval/utils/_array.py ADDED
@@ -0,0 +1,169 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import logging
+ import warnings
+ from importlib import import_module
+ from types import ModuleType
+ from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
+
+ import numpy as np
+ import torch
+ from numpy.typing import NDArray
+
+ from dataeval._log import LogMessage
+ from dataeval.typing import ArrayLike
+
+ _logger = logging.getLogger(__name__)
+
+ _MODULE_CACHE = {}
+
+ T = TypeVar("T", ArrayLike, np.ndarray, torch.Tensor)
+ _np_dtype = TypeVar("_np_dtype", bound=np.generic)
+
+
+ def _try_import(module_name) -> ModuleType | None:
+     if module_name in _MODULE_CACHE:
+         return _MODULE_CACHE[module_name]
+
+     try:
+         module = import_module(module_name)
+     except ImportError:  # pragma: no cover
+         _logger.log(logging.INFO, f"Unable to import {module_name}.")
+         module = None
+
+     _MODULE_CACHE[module_name] = module
+     return module
+
+
+ def as_numpy(array: ArrayLike | None) -> NDArray[Any]:
+     """Converts an ArrayLike to Numpy array without copying (if possible)"""
+     return to_numpy(array, copy=False)
+
+
+ def to_numpy(array: ArrayLike | None, copy: bool = True) -> NDArray[Any]:
+     """Converts an ArrayLike to new Numpy array"""
+     if array is None:
+         return np.ndarray([])
+
+     if isinstance(array, np.ndarray):
+         return array.copy() if copy else array
+
+     if array.__class__.__module__.startswith("tensorflow"):  # pragma: no cover - removed tf from deps
+         tf = _try_import("tensorflow")
+         if tf and tf.is_tensor(array):
+             _logger.log(logging.INFO, "Converting Tensorflow array to NumPy array.")
+             return array.numpy().copy() if copy else array.numpy()  # type: ignore
+
+     if array.__class__.__module__.startswith("torch"):
+         torch = _try_import("torch")
+         if torch and isinstance(array, torch.Tensor):
+             _logger.log(logging.INFO, "Converting PyTorch array to NumPy array.")
+             numpy = array.detach().cpu().numpy().copy() if copy else array.detach().cpu().numpy()  # type: ignore
+             _logger.log(logging.DEBUG, LogMessage(lambda: f"{str(array)} -> {str(numpy)}"))
+             return numpy
+
+     return np.array(array) if copy else np.asarray(array)
+
+
+ def to_numpy_iter(iterable: Iterable[ArrayLike]) -> Iterator[NDArray[Any]]:
+     """Yields an iterator of numpy arrays from an ArrayLike"""
+     for array in iterable:
+         yield to_numpy(array)
+
+
+ @overload
+ def ensure_embeddings(
+     embeddings: T,
+     dtype: torch.dtype,
+     unit_interval: Literal[True, False, "force"] = False,
+ ) -> torch.Tensor: ...
+
+
+ @overload
+ def ensure_embeddings(
+     embeddings: T,
+     dtype: type[_np_dtype],
+     unit_interval: Literal[True, False, "force"] = False,
+ ) -> NDArray[_np_dtype]: ...
+
+
+ @overload
+ def ensure_embeddings(
+     embeddings: T,
+     dtype: None,
+     unit_interval: Literal[True, False, "force"] = False,
+ ) -> T: ...
+
+
+ def ensure_embeddings(
+     embeddings: T,
+     dtype: type[_np_dtype] | torch.dtype | None = None,
+     unit_interval: Literal[True, False, "force"] = False,
+ ) -> torch.Tensor | NDArray[_np_dtype] | T:
+     """
+     Validates the embeddings array and converts it to the specified type
+
+     Parameters
+     ----------
+     embeddings : ArrayLike
+         Embeddings array
+     dtype : numpy dtype or torch dtype or None, default None
+         The desired dtype of the output array, None to skip conversion
+     unit_interval : bool or "force", default False
+         Whether to validate or force the embeddings to unit interval
+
+     Returns
+     -------
+     Converted embeddings array
+
+     Raises
+     ------
+     ValueError
+         If the embeddings array is not 2D
+     ValueError
+         If the embeddings array is not unit interval [0, 1]
+     """
+     if isinstance(dtype, torch.dtype):
+         arr = torch.as_tensor(embeddings, dtype=dtype)
+     else:
+         arr = (
+             embeddings.detach().cpu().numpy().astype(dtype)
+             if isinstance(embeddings, torch.Tensor)
+             else np.asarray(embeddings, dtype=dtype)
+         )
+
+     if arr.ndim != 2:
+         raise ValueError(f"Expected a 2D array, but got a {arr.ndim}D array.")
+
+     if unit_interval:
+         arr_min, arr_max = arr.min(), arr.max()
+         if arr_min < 0 or arr_max > 1:
+             if unit_interval == "force":
+                 warnings.warn("Embeddings are not unit interval [0, 1]. Forcing to unit interval.")
+                 arr = (arr - arr_min) / (arr_max - arr_min)
+             else:
+                 raise ValueError("Embeddings must be unit interval [0, 1].")
+
+     if dtype is None:
+         return embeddings
+     else:
+         return arr
+
+
+ def flatten(array: ArrayLike) -> NDArray[Any]:
+     """
+     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
+
+     Parameters
+     ----------
+     X : NDArray, shape - (N, ... )
+         Input array
+
+     Returns
+     -------
+     NDArray, shape - (N, -1)
+     """
+     nparr = as_numpy(array)
+     return nparr.reshape((nparr.shape[0], -1))
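ensure_embeddings dispatches on the dtype argument: a torch dtype yields a torch.Tensor, a numpy dtype yields an NDArray, and None returns the input untouched, with optional unit-interval validation or rescaling on top. A minimal sketch of the intended call patterns (illustrative only; this module is private and exports nothing via __all__):

import numpy as np
import torch
from dataeval.utils._array import ensure_embeddings, flatten

embs = np.random.default_rng(0).random((8, 16))  # already in [0, 1]

as_tensor = ensure_embeddings(embs, dtype=torch.float32, unit_interval=True)  # torch.Tensor
as_float64 = ensure_embeddings(embs, dtype=np.float64, unit_interval=True)    # NDArray[float64]
unchanged = ensure_embeddings(embs, dtype=None)                               # input returned as-is

# flatten() reshapes (N, ...) to (N, -1) after converting to NumPy
images = np.zeros((4, 3, 32, 32))
assert flatten(images).shape == (4, 3 * 32 * 32)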
dataeval/utils/_bin.py ADDED
@@ -0,0 +1,199 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import warnings
+ from typing import Any, Iterable
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from scipy.stats import wasserstein_distance as wd
+
+ DISCRETE_MIN_WD = 0.054
+ CONTINUOUS_MIN_SAMPLE_SIZE = 20
+
+
+ def get_counts(data: NDArray[np.int_], min_num_bins: int | None = None) -> NDArray[np.int_]:
+     """
+     Returns columnwise unique counts for discrete data.
+
+     Parameters
+     ----------
+     data : NDArray
+         Array containing integer values for metadata factors
+     min_num_bins : int | None, default None
+         Minimum number of bins for bincount, helps force consistency across runs
+
+     Returns
+     -------
+     NDArray[np.int]
+         Bin counts per column of data.
+     """
+     max_value = data.max() + 1 if min_num_bins is None else min_num_bins
+     cnt_array = np.zeros((max_value, data.shape[1]), dtype=np.int_)
+     for idx in range(data.shape[1]):
+         cnt_array[:, idx] = np.bincount(data[:, idx], minlength=max_value)
+
+     return cnt_array
+
+
+ def digitize_data(data: list[Any] | NDArray[Any], bins: int | Iterable[float]) -> NDArray[np.intp]:
+     """
+     Digitizes a list of values into a given number of bins.
+
+     Parameters
+     ----------
+     data : list | NDArray
+         The values to be digitized.
+     bins : int | Iterable[float]
+         The number of bins or list of bin edges for the discrete values that data will be digitized into.
+
+     Returns
+     -------
+     NDArray[np.intp]
+         The digitized values
+     """
+
+     if not np.all([np.issubdtype(type(n), np.number) for n in data]):
+         raise TypeError(
+             "Encountered a data value with non-numeric type when digitizing a factor. "
+             "Ensure all occurrences of continuous factors are numeric types."
+         )
+     if isinstance(bins, int):
+         _, bin_edges = np.histogram(data, bins=bins)
+         bin_edges[-1] = np.inf
+         bin_edges[0] = -np.inf
+     else:
+         bin_edges = list(bins)
+     return np.digitize(data, bin_edges)
+
+
+ def bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
+     """
+     Bins continuous data through either equal width bins, equal amounts in each bin, or by clusters.
+     """
+     if bin_method == "clusters":
+         bin_edges = bin_by_clusters(data)
+
+     else:
+         counts, bin_edges = np.histogram(data, bins="auto")
+         n_bins = counts.size
+         if counts[counts > 0].min() < 10:
+             counter = 20
+             while counts[counts > 0].min() < 10 and n_bins >= 2 and counter > 0:
+                 counter -= 1
+                 n_bins -= 1
+                 counts, bin_edges = np.histogram(data, bins=n_bins)
+
+         if bin_method == "uniform_count":
+             quantiles = np.linspace(0, 100, n_bins + 1)
+             bin_edges = np.asarray(np.percentile(data, quantiles))
+
+     bin_edges[0] = -np.inf
+     bin_edges[-1] = np.inf
+     return np.digitize(data, bin_edges)
+
+
+ def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
+     """
+     Determines whether the data is continuous or discrete using the Wasserstein distance.
+
+     Given a 1D sample, we consider the intervals between adjacent points. For a continuous distribution,
+     a point is equally likely to lie anywhere in the interval bounded by its two neighbors. Furthermore,
+     we can put all "between neighbor" locations on the same scale of 0 to 1 by subtracting the smaller
+     neighbor and dividing out the length of the interval. (Duplicates are either assigned to zero or
+     ignored, depending on context). These normalized locations will be much more uniformly distributed
+     for continuous data than for discrete, and this gives us a way to distinguish them. Call this the
+     Normalized Near Neighbor distribution (NNN), defined on the interval [0,1].
+
+     The Wasserstein distance is available in scipy.stats.wasserstein_distance. We can use it to measure
+     how close the NNN is to a uniform distribution over [0,1]. We found that as long as a sample has at
+     least 20 points, and furthermore at least half as many points as there are discrete values, we can
+     reliably distinguish discrete from continuous samples by testing that the Wasserstein distance
+     measured from a uniform distribution is greater or less than 0.054, respectively.
+     """
+     # Check if the metadata is image specific
+     _, data_indices_unsorted = np.unique(data, return_index=True)
+     if data_indices_unsorted.size == image_indices.size:
+         data_indices = np.sort(data_indices_unsorted)
+         if (data_indices == image_indices).all():
+             data = data[data_indices]
+
+     n_examples = len(data)
+
+     if n_examples < CONTINUOUS_MIN_SAMPLE_SIZE:
+         warnings.warn(
+             f"All samples look discrete with so few data points (< {CONTINUOUS_MIN_SAMPLE_SIZE})", UserWarning
+         )
+         return False
+
+     # Require at least 3 unique values before bothering with NNN
+     xu = np.unique(data, axis=None)
+     if xu.size < 3:
+         return False
+
+     Xs = np.sort(data)
+
+     X0, X1 = Xs[0:-2], Xs[2:]  # left and right neighbors
+
+     dx = np.zeros(n_examples - 2)  # no dx at end points
+     gtz = (X1 - X0) > 0  # check for dups; dx will be zero for them
+     dx[np.logical_not(gtz)] = 0.0
+
+     dx[gtz] = (Xs[1:-1] - X0)[gtz] / (X1 - X0)[gtz]  # the core idea: dx is NNN samples.
+
+     shift = wd(dx, np.linspace(0, 1, dx.size))  # how far is dx from uniform, for this feature?
+
+     return shift < DISCRETE_MIN_WD  # if NNN is close enough to uniform, consider the sample continuous.
+
+
+ def bin_by_clusters(data: NDArray[np.number[Any]]) -> NDArray[np.float64]:
+     """
+     Bins continuous data by using the Clusterer to identify clusters
+     and incorporates outliers by adding them to the nearest bin.
+     """
+     # Delay load numba compiled functions
+     from dataeval.utils._clusterer import cluster
+
+     # Create initial clusters
+     c = cluster(data)
+
+     # Create bins from clusters
+     bin_edges = np.zeros(c.clusters.max() + 2)
+     for group in range(c.clusters.max() + 1):
+         points = np.nonzero(c.clusters == group)[0]
+         bin_edges[group] = data[points].min()
+
+     # Get the outliers
+     outliers = np.nonzero(c.clusters == -1)[0]
+
+     # Identify non-outlier neighbors
+     nbrs = c.k_neighbors[outliers]
+     nbrs = np.where(np.isin(nbrs, outliers), -1, nbrs)
+
+     # Find the nearest non-outlier neighbor for each outlier
+     nn = np.full(outliers.size, -1, dtype=np.int32)
+     for row in range(outliers.size):
+         non_outliers = nbrs[row, nbrs[row] != -1]
+         if non_outliers.size > 0:
+             nn[row] = non_outliers[0]
+
+     # Group outliers by their neighbors
+     unique_nnbrs, same_nbr, counts = np.unique(nn, return_inverse=True, return_counts=True)
+
+     # Adjust bin_edges based on each unique neighbor group
+     extend_bins = []
+     for i, nnbr in enumerate(unique_nnbrs):
+         outlier_indices = np.nonzero(same_nbr == i)[0]
+         min2add = data[outliers[outlier_indices]].min()
+         if counts[i] >= 4:
+             extend_bins.append(min2add)
+         else:
+             if min2add < data[nnbr]:
+                 clusters = c.clusters[nnbr]
+                 bin_edges[clusters] = min2add
+     if extend_bins:
+         bin_edges = np.concatenate([bin_edges, extend_bins])
+
+     bin_edges = np.sort(bin_edges)
+     return bin_edges
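bin_data falls back to equal-width histogram bins for any method other than "clusters", tightens the bin count until every non-empty bin has at least 10 members, switches to percentile edges for "uniform_count", and opens the outer edges to ±inf; is_continuous applies the Wasserstein-distance heuristic described in its docstring. A brief usage sketch (illustrative only; the module is private, and the "uniform_width" string shown simply selects the histogram branch since it is neither "clusters" nor "uniform_count"):

import numpy as np
from dataeval.utils._bin import bin_data, digitize_data, is_continuous

rng = np.random.default_rng(0)
angles = rng.uniform(0.0, 90.0, size=200)  # a continuous metadata factor

# With 200 unique samples the NNN distribution should sit close to uniform,
# so the Wasserstein shift is expected to fall below 0.054 and report True.
print(is_continuous(angles, np.arange(angles.size)))

# Equal-width binning via np.histogram(..., bins="auto") with opened outer edges
uniform_width_bins = bin_data(angles, "uniform_width")

# Explicit edges: values are mapped onto the provided bin edges with np.digitize
explicit_bins = digitize_data(angles, bins=[0, 15, 30, 45, 60, 90])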
dataeval/utils/_clusterer.py ADDED
@@ -0,0 +1,144 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import warnings
+ from dataclasses import dataclass
+
+ import numba
+ import numpy as np
+ from numpy.typing import NDArray
+
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore", category=FutureWarning)
+     from fast_hdbscan.cluster_trees import (
+         CondensedTree,
+         cluster_tree_from_condensed_tree,
+         condense_tree,
+         ds_find,
+         ds_rank_create,
+         ds_union_by_rank,
+         extract_eom_clusters,
+         get_cluster_label_vector,
+         get_point_membership_strength_vector,
+         mst_to_linkage_tree,
+     )
+
+ from dataeval.typing import ArrayLike
+ from dataeval.utils._array import flatten, to_numpy
+ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spanning_tree
+
+
+ @numba.njit(parallel=True, locals={"i": numba.types.int32})
+ def compare_links_to_cluster_std(mst, clusters):
+     cluster_ids = np.unique(clusters)
+     cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
+
+     for i in numba.prange(mst.shape[0]):
+         cluster_id = clusters[np.int32(mst[i, 0])]
+         if cluster_id == clusters[np.int32(mst[i, 1])]:
+             cluster_grouping[i] = np.int16(cluster_id)
+
+     overall_mean = mst.T[2].mean()
+     order_mag = np.floor(np.log10(overall_mean)) if overall_mean > 0 else 0
+     compare_mag = -3 if order_mag >= 0 else order_mag - 3
+
+     exact_dup = np.full((mst.shape[0], 2), -1, dtype=np.int32)
+     exact_dups_index = np.nonzero(mst[:, 2] < 10**compare_mag)[0]
+     exact_dup[exact_dups_index] = mst[exact_dups_index, :2]
+
+     near_dup = np.full((mst.shape[0], 2), -1, dtype=np.int32)
+     for i in range(cluster_ids.size):
+         cluster_links = np.nonzero(cluster_grouping == cluster_ids[i])[0]
+         cluster_std = mst[cluster_links, 2].std()
+
+         near_dups = np.nonzero(mst[cluster_links, 2] < cluster_std)[0]
+         near_dups_index = cluster_links[near_dups]
+         near_dup[near_dups_index] = mst[near_dups_index, :2]
+
+     exact_idx = np.nonzero(exact_dup.T[0] != -1)[0]
+     near_dup[exact_idx] = np.full((exact_idx.size, 2), -1, dtype=np.int32)
+     near_idx = np.nonzero(near_dup.T[0] != -1)[0]
+
+     return exact_dup[exact_idx], near_dup[near_idx]
+
+
+ @dataclass
+ class ClusterData:
+     clusters: NDArray[np.intp]
+     mst: NDArray[np.double]
+     linkage_tree: NDArray[np.double]
+     condensed_tree: CondensedTree
+     membership_strengths: NDArray[np.double]
+     k_neighbors: NDArray[np.int32]
+     k_distances: NDArray[np.double]
+
+
+ def cluster(data: ArrayLike) -> ClusterData:
+     single_cluster = False
+     cluster_selection_epsilon = 0.0
+     # cluster_selection_method = "eom"
+
+     x = flatten(to_numpy(data))
+     samples, features = x.shape  # Due to flatten(), we know shape has a length of 2
+     if samples < 2:
+         raise ValueError(f"Data should have at least 2 samples; got {samples}")
+     if features < 1:
+         raise ValueError(f"Samples should have at least 1 feature; got {features}")
+
+     num_samples = len(x)
+     min_num = int(num_samples * 0.05)
+     min_cluster_size: int = min(max(5, min_num), 100)
+
+     max_neighbors = min(25, num_samples - 1)
+     kneighbors, kdistances = calculate_neighbor_distances(x, max_neighbors)
+     unsorted_mst: NDArray[np.double] = minimum_spanning_tree(x, kneighbors, kdistances)
+     mst: NDArray[np.double] = unsorted_mst[np.argsort(unsorted_mst.T[2])]
+     linkage_tree: NDArray[np.double] = mst_to_linkage_tree(mst)
+     condensed_tree: CondensedTree = condense_tree(linkage_tree, min_cluster_size, None)
+
+     cluster_tree = cluster_tree_from_condensed_tree(condensed_tree)
+
+     selected_clusters = extract_eom_clusters(condensed_tree, cluster_tree, allow_single_cluster=single_cluster)
+
+     # Uncomment if cluster_selection_method is made a parameter
+     # if cluster_selection_method != "eom":
+     #     selected_clusters = extract_leaves(condensed_tree, allow_single_cluster=single_cluster)
+
+     # Uncomment if cluster_selection_epsilon is made a parameter
+     # if len(selected_clusters) > 1 and cluster_selection_epsilon > 0.0:
+     #     selected_clusters = cluster_epsilon_search(
+     #         selected_clusters,
+     #         cluster_tree,
+     #         min_persistence=cluster_selection_epsilon,
+     #     )
+
+     clusters = get_cluster_label_vector(
+         condensed_tree,
+         selected_clusters,
+         cluster_selection_epsilon,
+         n_samples=x.shape[0],
+     )
+
+     membership_strengths = get_point_membership_strength_vector(condensed_tree, selected_clusters, clusters)
+
+     return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
+
+
+ def sorted_union_find(index_groups):
+     """Merges and sorts groups of indices that share any common index"""
+     groups = [[np.int32(x) for x in range(0)] for y in range(0)]
+     uniques, inverse = np.unique(index_groups, return_inverse=True)
+     inverse = inverse.flatten()
+     disjoint_set = ds_rank_create(uniques.size)
+     cluster_points = np.empty(uniques.size, dtype=np.uint32)
+     for i in range(index_groups.shape[0]):
+         point, nbr = np.int32(inverse[i * 2]), np.int32(inverse[i * 2 + 1])
+         ds_union_by_rank(disjoint_set, point, nbr)
+     for i in range(uniques.size):
+         cluster_points[i] = ds_find(disjoint_set, i)
+     for i in range(uniques.size):
+         dups = np.nonzero(cluster_points == i)[0]
+         if dups.size > 0:
+             groups.append(uniques[dups].tolist())
+     return sorted(groups)
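cluster() appears to take over the role of the removed detectors/linters/clusterer.py: it flattens the input, builds a k-neighbor graph and minimum spanning tree, condenses it with fast_hdbscan, and returns labels, membership strengths, and the intermediate trees in a ClusterData record. A short illustrative sketch (assumes fast_hdbscan, numba, and the new _fast_mst helpers are importable; the module itself is private):

import numpy as np
from dataeval.utils._clusterer import cluster, sorted_union_find

rng = np.random.default_rng(0)
# Two well-separated blobs; cluster() flattens input to (N, -1) before building the MST
data = np.vstack([rng.normal(0.0, 0.1, (50, 2)), rng.normal(5.0, 0.1, (50, 2))])

result = cluster(data)
print(result.clusters.shape)            # one label per sample; -1 marks outliers
print(result.mst.shape)                 # (N - 1, 3): two edge endpoints plus the edge weight
print(result.membership_strengths[:5])

# sorted_union_find merges index pairs that share any endpoint into sorted groups
pairs = np.array([[0, 1], [1, 2], [5, 6]])
print(sorted_union_find(pairs))         # [[0, 1, 2], [5, 6]]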