dataeval 0.74.2__py3-none-any.whl → 0.76.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +27 -23
- dataeval/detectors/__init__.py +2 -2
- dataeval/detectors/drift/__init__.py +14 -12
- dataeval/detectors/drift/base.py +3 -3
- dataeval/detectors/drift/cvm.py +1 -1
- dataeval/detectors/drift/ks.py +3 -2
- dataeval/detectors/drift/mmd.py +9 -7
- dataeval/detectors/drift/torch.py +12 -12
- dataeval/detectors/drift/uncertainty.py +5 -4
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +4 -4
- dataeval/detectors/linters/clusterer.py +5 -9
- dataeval/detectors/linters/duplicates.py +10 -14
- dataeval/detectors/linters/outliers.py +100 -5
- dataeval/detectors/ood/__init__.py +4 -11
- dataeval/detectors/ood/{ae_torch.py → ae.py} +6 -4
- dataeval/detectors/ood/base.py +47 -160
- dataeval/detectors/ood/metadata_ks_compare.py +34 -42
- dataeval/detectors/ood/metadata_least_likely.py +3 -3
- dataeval/detectors/ood/metadata_ood_mi.py +6 -5
- dataeval/detectors/ood/mixin.py +146 -0
- dataeval/detectors/ood/output.py +63 -0
- dataeval/interop.py +7 -6
- dataeval/{logging.py → log.py} +2 -0
- dataeval/metrics/__init__.py +3 -3
- dataeval/metrics/bias/__init__.py +10 -13
- dataeval/metrics/bias/balance.py +13 -11
- dataeval/metrics/bias/coverage.py +53 -5
- dataeval/metrics/bias/diversity.py +56 -24
- dataeval/metrics/bias/parity.py +20 -17
- dataeval/metrics/estimators/__init__.py +2 -2
- dataeval/metrics/estimators/ber.py +7 -4
- dataeval/metrics/estimators/divergence.py +4 -4
- dataeval/metrics/estimators/uap.py +4 -4
- dataeval/metrics/stats/__init__.py +19 -19
- dataeval/metrics/stats/base.py +28 -12
- dataeval/metrics/stats/boxratiostats.py +13 -14
- dataeval/metrics/stats/datasetstats.py +49 -20
- dataeval/metrics/stats/dimensionstats.py +8 -8
- dataeval/metrics/stats/hashstats.py +14 -10
- dataeval/metrics/stats/labelstats.py +94 -11
- dataeval/metrics/stats/pixelstats.py +11 -14
- dataeval/metrics/stats/visualstats.py +10 -13
- dataeval/output.py +23 -14
- dataeval/utils/__init__.py +5 -14
- dataeval/utils/dataset/__init__.py +7 -0
- dataeval/utils/{torch → dataset}/datasets.py +2 -0
- dataeval/utils/dataset/read.py +63 -0
- dataeval/utils/{split_dataset.py → dataset/split.py} +38 -30
- dataeval/utils/image.py +2 -2
- dataeval/utils/metadata.py +317 -14
- dataeval/{metrics/bias/metadata_utils.py → utils/plot.py} +91 -71
- dataeval/utils/torch/__init__.py +2 -17
- dataeval/utils/torch/gmm.py +29 -6
- dataeval/utils/torch/{utils.py → internal.py} +82 -58
- dataeval/utils/torch/models.py +10 -8
- dataeval/utils/torch/trainer.py +6 -85
- dataeval/workflows/__init__.py +2 -5
- dataeval/workflows/sufficiency.py +18 -8
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/LICENSE.txt +2 -2
- dataeval-0.76.0.dist-info/METADATA +137 -0
- dataeval-0.76.0.dist-info/RECORD +67 -0
- dataeval/detectors/ood/base_torch.py +0 -109
- dataeval/metrics/bias/metadata_preprocessing.py +0 -285
- dataeval/utils/gmm.py +0 -26
- dataeval-0.74.2.dist-info/METADATA +0 -120
- dataeval-0.74.2.dist-info/RECORD +0 -66
- {dataeval-0.74.2.dist-info → dataeval-0.76.0.dist-info}/WHEEL +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
__all__ = [
|
3
|
+
__all__ = []
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
from typing import Any, Callable, Iterable
|
@@ -8,7 +8,7 @@ from typing import Any, Callable, Iterable
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import ArrayLike, NDArray
|
10
10
|
|
11
|
-
from dataeval.metrics.stats.base import BaseStatsOutput, StatsProcessor, run_stats
|
11
|
+
from dataeval.metrics.stats.base import BaseStatsOutput, HistogramPlotMixin, StatsProcessor, run_stats
|
12
12
|
from dataeval.output import set_metadata
|
13
13
|
from dataeval.utils.image import edge_filter
|
14
14
|
|
@@ -16,9 +16,9 @@ QUARTILES = (0, 25, 50, 75, 100)
|
|
16
16
|
|
17
17
|
|
18
18
|
@dataclass(frozen=True)
|
19
|
-
class VisualStatsOutput(BaseStatsOutput):
|
19
|
+
class VisualStatsOutput(BaseStatsOutput, HistogramPlotMixin):
|
20
20
|
"""
|
21
|
-
Output class for :func:`visualstats` stats metric
|
21
|
+
Output class for :func:`visualstats` stats metric.
|
22
22
|
|
23
23
|
Attributes
|
24
24
|
----------
|
@@ -46,6 +46,8 @@ class VisualStatsOutput(BaseStatsOutput):
|
|
46
46
|
zeros: NDArray[np.float16]
|
47
47
|
percentiles: NDArray[np.float16]
|
48
48
|
|
49
|
+
_excluded_keys = ["percentiles"]
|
50
|
+
|
49
51
|
|
50
52
|
class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
51
53
|
output_class: type = VisualStatsOutput
|
@@ -81,7 +83,7 @@ def visualstats(
|
|
81
83
|
per_channel: bool = False,
|
82
84
|
) -> VisualStatsOutput:
|
83
85
|
"""
|
84
|
-
Calculates visual statistics for each image
|
86
|
+
Calculates visual :term:`statistics` for each image.
|
85
87
|
|
86
88
|
This function computes various visual metrics (e.g., :term:`brightness<Brightness>`, darkness, contrast, blurriness)
|
87
89
|
on the images as a whole.
|
@@ -112,15 +114,10 @@ def visualstats(
|
|
112
114
|
--------
|
113
115
|
Calculating the :term:`statistics<Statistics>` on the images, whose shape is (C, H, W)
|
114
116
|
|
115
|
-
>>> results = visualstats(
|
117
|
+
>>> results = visualstats(stats_images)
|
116
118
|
>>> print(results.brightness)
|
117
|
-
[0.
|
118
|
-
0.3015 0.3347 0.3682 0.4014 0.4348 0.468 0.5015 0.5347 0.568
|
119
|
-
0.6016 0.635 0.668 0.701 0.735 0.768 0.8013 0.835 0.868
|
120
|
-
0.9014 0.9346 0.9683 ]
|
119
|
+
[0.1353 0.2085 0.4143 0.6084 0.8135]
|
121
120
|
>>> print(results.contrast)
|
122
|
-
[2.
|
123
|
-
1.258 1.257 1.257 1.256 1.256 1.255 1.255 1.255 1.255 1.254 1.254 1.254
|
124
|
-
1.254 1.254 1.254 1.253 1.253 1.253]
|
121
|
+
[2.04 1.331 1.261 1.279 1.253]
|
125
122
|
"""
|
126
123
|
return run_stats(images, bboxes, per_channel, [VisualStatsProcessor])[0]
|
dataeval/output.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import inspect
|
6
|
+
import logging
|
6
7
|
import sys
|
7
8
|
from collections.abc import Mapping
|
8
9
|
from datetime import datetime, timezone
|
@@ -81,29 +82,37 @@ def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None =
|
|
81
82
|
return f"{v.__class__.__name__}: len={len(v)}"
|
82
83
|
return f"{v.__class__.__name__}"
|
83
84
|
|
84
|
-
|
85
|
-
result = fn(*args, **kwargs)
|
86
|
-
duration = (datetime.now(timezone.utc) - time).total_seconds()
|
87
|
-
fn_params = inspect.signature(fn).parameters
|
88
|
-
|
85
|
+
# Collect function metadata
|
89
86
|
# set all params with defaults then update params with mapped arguments and explicit keyword args
|
87
|
+
fn_params = inspect.signature(fn).parameters
|
90
88
|
arguments = {k: None if v.default is inspect.Parameter.empty else v.default for k, v in fn_params.items()}
|
91
89
|
arguments.update(zip(fn_params, args))
|
92
90
|
arguments.update(kwargs)
|
93
91
|
arguments = {k: fmt(v) for k, v in arguments.items()}
|
94
|
-
|
95
|
-
|
96
|
-
)
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
)
|
92
|
+
is_method = "self" in arguments
|
93
|
+
state_attrs = {k: fmt(getattr(args[0], k)) for k in state or []} if is_method else {}
|
94
|
+
module = args[0].__class__.__module__ if is_method else fn.__module__.removeprefix("src.")
|
95
|
+
class_prefix = f".{args[0].__class__.__name__}." if is_method else "."
|
96
|
+
name = f"{module}{class_prefix}{fn.__name__}"
|
97
|
+
arguments = {k: v for k, v in arguments.items() if k != "self"}
|
98
|
+
|
99
|
+
_logger = logging.getLogger(module)
|
100
|
+
time = datetime.now(timezone.utc)
|
101
|
+
_logger.log(logging.INFO, f">>> Executing '{name}': args={arguments} state={state} <<<")
|
102
|
+
|
103
|
+
##### EXECUTE FUNCTION #####
|
104
|
+
result = fn(*args, **kwargs)
|
105
|
+
############################
|
106
|
+
|
107
|
+
duration = (datetime.now(timezone.utc) - time).total_seconds()
|
108
|
+
_logger.log(logging.INFO, f">>> Completed '{name}': args={arguments} state={state} duration={duration} <<<")
|
109
|
+
|
110
|
+
# Update output with recorded metadata
|
102
111
|
metadata = {
|
103
112
|
"_name": name,
|
104
113
|
"_execution_time": time,
|
105
114
|
"_execution_duration": duration,
|
106
|
-
"_arguments":
|
115
|
+
"_arguments": arguments,
|
107
116
|
"_state": state_attrs,
|
108
117
|
"_version": __version__,
|
109
118
|
}
|
dataeval/utils/__init__.py
CHANGED
@@ -1,18 +1,9 @@
|
|
1
1
|
"""
|
2
|
-
The utility classes and functions are provided by DataEval to assist users
|
3
|
-
in setting up architectures that are guaranteed to work with applicable
|
4
|
-
|
2
|
+
The utility classes and functions are provided by DataEval to assist users \
|
3
|
+
in setting up data and architectures that are guaranteed to work with applicable \
|
4
|
+
DataEval metrics.
|
5
5
|
"""
|
6
6
|
|
7
|
-
|
8
|
-
from dataeval.utils.metadata import merge_metadata
|
9
|
-
from dataeval.utils.split_dataset import split_dataset
|
7
|
+
__all__ = ["dataset", "metadata", "torch"]
|
10
8
|
|
11
|
-
|
12
|
-
|
13
|
-
if _IS_TORCH_AVAILABLE:
|
14
|
-
from dataeval.utils import torch
|
15
|
-
|
16
|
-
__all__ += ["torch"]
|
17
|
-
|
18
|
-
del _IS_TORCH_AVAILABLE
|
9
|
+
from dataeval.utils import dataset, metadata, torch
|
@@ -0,0 +1,7 @@
|
|
1
|
+
"""Provides utility functions for interacting with Computer Vision datasets."""
|
2
|
+
|
3
|
+
__all__ = ["datasets", "read_dataset", "SplitDatasetOutput", "split_dataset"]
|
4
|
+
|
5
|
+
from dataeval.utils.dataset import datasets
|
6
|
+
from dataeval.utils.dataset.read import read_dataset
|
7
|
+
from dataeval.utils.dataset.split import SplitDatasetOutput, split_dataset
|
@@ -0,0 +1,63 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from collections import defaultdict
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
from torch.utils.data import Dataset
|
9
|
+
|
10
|
+
|
11
|
+
def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
|
12
|
+
"""
|
13
|
+
Extract information from a dataset at each index into individual lists of each information position.
|
14
|
+
|
15
|
+
Parameters
|
16
|
+
----------
|
17
|
+
dataset : torch.utils.data.Dataset
|
18
|
+
Input dataset
|
19
|
+
|
20
|
+
Returns
|
21
|
+
-------
|
22
|
+
List[List[Any]]
|
23
|
+
All objects in individual lists based on return position from dataset
|
24
|
+
|
25
|
+
Warning
|
26
|
+
-------
|
27
|
+
No type checking is done between lists or data inside lists
|
28
|
+
|
29
|
+
See Also
|
30
|
+
--------
|
31
|
+
torch.utils.data.Dataset
|
32
|
+
|
33
|
+
Examples
|
34
|
+
--------
|
35
|
+
>>> import numpy as np
|
36
|
+
>>> data = np.ones((10, 1, 3, 3))
|
37
|
+
>>> labels = np.ones((10,))
|
38
|
+
>>> class ICDataset:
|
39
|
+
... def __init__(self, data, labels):
|
40
|
+
... self.data = data
|
41
|
+
... self.labels = labels
|
42
|
+
...
|
43
|
+
... def __getitem__(self, idx):
|
44
|
+
... return self.data[idx], self.labels[idx]
|
45
|
+
|
46
|
+
>>> ds = ICDataset(data, labels)
|
47
|
+
|
48
|
+
>>> result = read_dataset(ds)
|
49
|
+
>>> len(result) # images and labels
|
50
|
+
2
|
51
|
+
>>> np.asarray(result[0]).shape # images
|
52
|
+
(10, 1, 3, 3)
|
53
|
+
>>> np.asarray(result[1]).shape # labels
|
54
|
+
(10,)
|
55
|
+
"""
|
56
|
+
|
57
|
+
ddict: dict[int, list[Any]] = defaultdict(list[Any])
|
58
|
+
|
59
|
+
for data in dataset:
|
60
|
+
for i, d in enumerate(data if isinstance(data, tuple) else (data,)):
|
61
|
+
ddict[i].append(d)
|
62
|
+
|
63
|
+
return list(ddict.values())
|
@@ -1,12 +1,9 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
from dataeval.output import Output, set_metadata
|
6
|
-
|
7
|
-
__all__ = ["split_dataset", "SplitDatasetOutput"]
|
3
|
+
__all__ = []
|
8
4
|
|
9
5
|
import warnings
|
6
|
+
from dataclasses import dataclass
|
10
7
|
from typing import Any, Iterator, NamedTuple, Protocol
|
11
8
|
|
12
9
|
import numpy as np
|
@@ -16,19 +13,30 @@ from sklearn.metrics import silhouette_score
|
|
16
13
|
from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
|
17
14
|
from sklearn.utils.multiclass import type_of_target
|
18
15
|
|
16
|
+
from dataeval.output import Output, set_metadata
|
17
|
+
|
19
18
|
|
20
19
|
class TrainValSplit(NamedTuple):
|
21
20
|
"""Tuple containing train and validation indices"""
|
22
21
|
|
23
|
-
train: NDArray[np.int_]
|
24
|
-
val: NDArray[np.int_]
|
22
|
+
train: NDArray[np.intp]
|
23
|
+
val: NDArray[np.intp]
|
25
24
|
|
26
25
|
|
27
26
|
@dataclass(frozen=True)
|
28
27
|
class SplitDatasetOutput(Output):
|
29
|
-
"""
|
28
|
+
"""
|
29
|
+
Output class containing test indices and a list of TrainValSplits.
|
30
|
+
|
31
|
+
Attributes
|
32
|
+
----------
|
33
|
+
test: NDArray[np.intp]
|
34
|
+
Indices for the test set
|
35
|
+
folds: list[TrainValSplit]
|
36
|
+
List where each index contains the indices for the train and validation splits
|
37
|
+
"""
|
30
38
|
|
31
|
-
test: NDArray[np.int_]
|
39
|
+
test: NDArray[np.intp]
|
32
40
|
folds: list[TrainValSplit]
|
33
41
|
|
34
42
|
|
@@ -100,7 +108,7 @@ def calculate_validation_fraction(num_folds: int, test_frac: float, val_frac: fl
|
|
100
108
|
return val_base * (1.0 / num_folds) * (1.0 - test_frac)
|
101
109
|
|
102
110
|
|
103
|
-
def _validate_labels(labels: NDArray[np.int_], total_partitions: int) -> None:
|
111
|
+
def _validate_labels(labels: NDArray[np.intp], total_partitions: int) -> None:
|
104
112
|
"""
|
105
113
|
Check to make sure there is more input data than the total number of partitions requested
|
106
114
|
|
@@ -131,7 +139,7 @@ def _validate_labels(labels: NDArray[np.int_], total_partitions: int) -> None:
|
|
131
139
|
raise ValueError("Detected continuous labels. Labels must be discrete for proper stratification")
|
132
140
|
|
133
141
|
|
134
|
-
def is_stratifiable(labels: NDArray[np.int_], num_partitions: int) -> bool:
|
142
|
+
def is_stratifiable(labels: NDArray[np.intp], num_partitions: int) -> bool:
|
135
143
|
"""
|
136
144
|
Check if the dataset can be stratified by class label over the given number of partitions
|
137
145
|
|
@@ -166,7 +174,7 @@ def is_stratifiable(labels: NDArray[np.int_], num_partitions: int) -> bool:
|
|
166
174
|
return True
|
167
175
|
|
168
176
|
|
169
|
-
def is_groupable(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
|
177
|
+
def is_groupable(group_ids: NDArray[np.intp], num_partitions: int) -> bool:
|
170
178
|
"""
|
171
179
|
Warns user if the number of unique group_ids is incompatible with a grouped partition containing
|
172
180
|
num_folds folds. If this is the case, returns groups=None, which tells the partitioner not to
|
@@ -205,7 +213,7 @@ def is_groupable(group_ids: NDArray[np.int_], num_partitions: int) -> bool:
|
|
205
213
|
return True
|
206
214
|
|
207
215
|
|
208
|
-
def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
|
216
|
+
def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
|
209
217
|
"""
|
210
218
|
Find bins of continuous data by iteratively applying k-means clustering, and keeping the
|
211
219
|
clustering with the highest silhouette score.
|
@@ -226,18 +234,18 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.int_]:
|
|
226
234
|
best_score = 0.60
|
227
235
|
else:
|
228
236
|
best_score = 0.50
|
229
|
-
bin_index = np.zeros(len(array), dtype=np.
|
237
|
+
bin_index = np.zeros(len(array), dtype=np.intp)
|
230
238
|
for k in range(2, 20):
|
231
239
|
clusterer = KMeans(n_clusters=k)
|
232
240
|
cluster_labels = clusterer.fit_predict(array)
|
233
241
|
score = silhouette_score(array, cluster_labels, sample_size=25_000)
|
234
242
|
if score > best_score:
|
235
243
|
best_score = score
|
236
|
-
bin_index = cluster_labels.astype(np.
|
244
|
+
bin_index = cluster_labels.astype(np.intp)
|
237
245
|
return bin_index
|
238
246
|
|
239
247
|
|
240
|
-
def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.
|
248
|
+
def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples: int) -> NDArray[np.intp]:
|
241
249
|
"""
|
242
250
|
Returns individual group numbers based on a subset of metadata defined by groupnames
|
243
251
|
|
@@ -262,7 +270,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
|
|
262
270
|
"""
|
263
271
|
features2group = {k: np.array(v) for k, v in metadata.items() if k in group_names}
|
264
272
|
if not features2group:
|
265
|
-
return np.zeros(num_samples, dtype=np.
|
273
|
+
return np.zeros(num_samples, dtype=np.intp)
|
266
274
|
for name, feature in features2group.items():
|
267
275
|
if len(feature) != num_samples:
|
268
276
|
raise ValueError(
|
@@ -278,10 +286,10 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
|
|
278
286
|
|
279
287
|
|
280
288
|
def make_splits(
|
281
|
-
index: NDArray[np.
|
282
|
-
labels: NDArray[np.
|
289
|
+
index: NDArray[np.intp],
|
290
|
+
labels: NDArray[np.intp],
|
283
291
|
n_folds: int,
|
284
|
-
groups: NDArray[np.
|
292
|
+
groups: NDArray[np.intp] | None,
|
285
293
|
stratified: bool,
|
286
294
|
) -> list[TrainValSplit]:
|
287
295
|
"""
|
@@ -318,8 +326,8 @@ def make_splits(
|
|
318
326
|
split_defs.clear()
|
319
327
|
for train_idx, eval_idx in splits:
|
320
328
|
# test_ratio = len(eval_idx) / len(index)
|
321
|
-
t = np.atleast_1d(train_idx).astype(np.
|
322
|
-
v = np.atleast_1d(eval_idx).astype(np.
|
329
|
+
t = np.atleast_1d(train_idx).astype(np.intp)
|
330
|
+
v = np.atleast_1d(eval_idx).astype(np.intp)
|
323
331
|
good = good or (len(np.unique(labels[t])) == n_labels and len(np.unique(labels[v])) == n_labels)
|
324
332
|
split_defs.append(TrainValSplit(t, v))
|
325
333
|
if not good and attempts == 3:
|
@@ -328,7 +336,7 @@ def make_splits(
|
|
328
336
|
|
329
337
|
|
330
338
|
def find_best_split(
|
331
|
-
labels: NDArray[np.
|
339
|
+
labels: NDArray[np.intp], split_defs: list[TrainValSplit], stratified: bool, split_frac: float
|
332
340
|
) -> TrainValSplit:
|
333
341
|
"""
|
334
342
|
Finds the split that most closely satisfies a criterion determined by the arguments passed.
|
@@ -385,10 +393,10 @@ def find_best_split(
|
|
385
393
|
|
386
394
|
|
387
395
|
def single_split(
|
388
|
-
index: NDArray[np.
|
389
|
-
labels: NDArray[np.
|
396
|
+
index: NDArray[np.intp],
|
397
|
+
labels: NDArray[np.intp],
|
390
398
|
split_frac: float,
|
391
|
-
groups: NDArray[np.
|
399
|
+
groups: NDArray[np.intp] | None = None,
|
392
400
|
stratified: bool = False,
|
393
401
|
) -> TrainValSplit:
|
394
402
|
"""
|
@@ -427,7 +435,7 @@ def single_split(
|
|
427
435
|
|
428
436
|
@set_metadata
|
429
437
|
def split_dataset(
|
430
|
-
labels: list[int] | NDArray[np.
|
438
|
+
labels: list[int] | NDArray[np.intp],
|
431
439
|
num_folds: int = 1,
|
432
440
|
stratify: bool = False,
|
433
441
|
split_on: list[str] | None = None,
|
@@ -481,7 +489,7 @@ def split_dataset(
|
|
481
489
|
total_partitions = num_folds + 1 if test_frac else num_folds
|
482
490
|
|
483
491
|
if isinstance(labels, list):
|
484
|
-
labels = np.array(labels, dtype=np.
|
492
|
+
labels = np.array(labels, dtype=np.intp)
|
485
493
|
|
486
494
|
label_length: int = len(labels)
|
487
495
|
|
@@ -497,13 +505,13 @@ def split_dataset(
|
|
497
505
|
if is_groupable(possible_groups, group_partitions):
|
498
506
|
groups = possible_groups
|
499
507
|
|
500
|
-
test_indices: NDArray[np.
|
508
|
+
test_indices: NDArray[np.intp]
|
501
509
|
index = np.arange(label_length)
|
502
510
|
|
503
511
|
tv_indices, test_indices = (
|
504
512
|
single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
|
505
513
|
if test_frac
|
506
|
-
else (index, np.array([], dtype=np.
|
514
|
+
else (index, np.array([], dtype=np.intp))
|
507
515
|
)
|
508
516
|
|
509
517
|
tv_labels = labels[tv_indices]
|
dataeval/utils/image.py
CHANGED
@@ -63,8 +63,8 @@ def edge_filter(image: ArrayLike, offset: float = 0.5) -> NDArray[np.uint8]:
|
|
63
63
|
"""
|
64
64
|
Returns the image filtered using a 3x3 edge detection kernel:
|
65
65
|
[[ -1, -1, -1 ],
|
66
|
-
|
67
|
-
|
66
|
+
[ -1, 8, -1 ],
|
67
|
+
[ -1, -1, -1 ]]
|
68
68
|
"""
|
69
69
|
edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
|
70
70
|
np.clip(edges, 0, 255, edges)
|