dataeval 0.86.1__py3-none-any.whl → 0.86.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
dataeval/outputs/_bias.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
  import contextlib
  from dataclasses import asdict, dataclass
- from typing import Any, Literal, TypeVar, overload
+ from typing import Any, Mapping, Sequence, TypeVar
 
  import numpy as np
  import pandas as pd
@@ -39,7 +39,7 @@ class ToDataFrameMixin:
  This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
  """
  return pd.DataFrame(
- index=self.factor_names, # type: ignore - list[str] is documented as acceptable index type
+ index=self.factor_names, # type: ignore - Sequence[str] is documented as acceptable index type
  data={
  "score": self.score.round(2),
  "p-value": self.p_value.round(2),
@@ -58,7 +58,7 @@ class ParityOutput(ToDataFrameMixin, Output):
  chi-squared score(s) of the test
  p_value : NDArray[np.float64]
  p-value(s) of the test
- factor_names : list[str]
+ factor_names : Sequence[str]
  Names of each metadata factor
  insufficient_data: dict
  Dictionary of metadata factors with less than 5 class occurrences per value
@@ -66,8 +66,8 @@ class ParityOutput(ToDataFrameMixin, Output):
 
  score: NDArray[np.float64]
  p_value: NDArray[np.float64]
- factor_names: list[str]
- insufficient_data: dict[str, dict[int, dict[str, int]]]
+ factor_names: Sequence[str]
+ insufficient_data: Mapping[str, Mapping[int, Mapping[str, int]]]
 
 
  @dataclass(frozen=True)
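The change above is the pattern repeated throughout this release: public output fields and signatures move from concrete list/dict annotations to the abstract Sequence/Mapping protocols. Existing callers that pass or read plain lists and dicts are unaffected, since list satisfies Sequence and dict satisfies Mapping; the annotations simply stop promising mutability or a specific container type. A minimal standalone sketch of the typing semantics (the summarize helper and its argument names are illustrative, not dataeval code):

    from typing import Mapping, Sequence

    def summarize(factor_names: Sequence[str], insufficient: Mapping[str, int]) -> None:
        # Read-only access is all the abstract types guarantee
        print(len(factor_names), sorted(insufficient))

    summarize(["age", "size"], {"age": 3})   # a plain list/dict still type-checks
    summarize(("age", "size"), {"age": 3})   # tuples and other sequences are now valid too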
@@ -145,12 +145,15 @@ class CoverageOutput(Output):
  cols = min(3, num_images)
  fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
 
- for image, ax in zip(images[:num_images], axs.flat):
+ # Flatten axes using numpy array explicitly for compatibility
+ axs_flat = np.asarray(axs).flatten()
+
+ for image, ax in zip(images[:num_images], axs_flat):
  image = channels_first_to_last(as_numpy(image))
  ax.imshow(image)
  ax.axis("off")
 
- for ax in axs.flat[num_images:]:
+ for ax in axs_flat[num_images:]:
  ax.axis("off")
 
  fig.tight_layout()
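The flattening change above guards against the shape of the plt.subplots return value: with a single row and column it returns a bare Axes object, which has no .flat attribute, while np.asarray(...).flatten() yields a 1-D array of Axes in every case. A minimal standalone sketch of the behavior being relied on (not dataeval code):

    import numpy as np
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 1)            # a single Axes object, not an ndarray
    axs_flat = np.asarray(axs).flatten()     # always a 1-D ndarray of Axes, length >= 1
    for ax in axs_flat:
        ax.axis("off")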
@@ -187,65 +190,23 @@ class BalanceOutput(Output):
  Estimate of inter/intra-factor mutual information
  classwise : NDArray[np.float64]
  Estimate of mutual information between metadata factors and individual class labels
- factor_names : list[str]
+ factor_names : Sequence[str]
  Names of each metadata factor
- class_names : list[str]
+ class_names : Sequence[str]
  List of the class labels present in the dataset
  """
 
  balance: NDArray[np.float64]
  factors: NDArray[np.float64]
  classwise: NDArray[np.float64]
- factor_names: list[str]
- class_names: list[str]
-
- @overload
- def _by_factor_type(
- self,
- attr: Literal["factor_names"],
- factor_type: Literal["discrete", "continuous", "both"],
- ) -> list[str]: ...
-
- @overload
- def _by_factor_type(
- self,
- attr: Literal["balance", "factors", "classwise"],
- factor_type: Literal["discrete", "continuous", "both"],
- ) -> NDArray[np.float64]: ...
-
- def _by_factor_type(
- self,
- attr: Literal["balance", "factors", "classwise", "factor_names"],
- factor_type: Literal["discrete", "continuous", "both"],
- ) -> NDArray[np.float64] | list[str]:
- # if not filtering by factor_type then just return the requested attribute without mask
- if factor_type == "both":
- return getattr(self, attr)
-
- # create the mask for the selected factor_type
- mask_lambda = (
- (lambda x: "-continuous" not in x) if factor_type == "discrete" else (lambda x: "-discrete" not in x)
- )
-
- # return the masked attribute
- if attr == "factor_names":
- return [x.replace(f"-{factor_type}", "") for x in self.factor_names if mask_lambda(x)]
- factor_type_mask = np.asarray([mask_lambda(x) for x in self.factor_names])
- if attr == "factors":
- return self.factors[factor_type_mask[1:]][:, factor_type_mask[1:]]
- if attr == "balance":
- return self.balance[factor_type_mask]
- if attr == "classwise":
- return self.classwise[:, factor_type_mask]
-
- raise ValueError(f"Unknown attr {attr} specified.")
+ factor_names: Sequence[str]
+ class_names: Sequence[str]
 
  def plot(
  self,
- row_labels: list[Any] | NDArray[Any] | None = None,
- col_labels: list[Any] | NDArray[Any] | None = None,
+ row_labels: Sequence[Any] | NDArray[Any] | None = None,
+ col_labels: Sequence[Any] | NDArray[Any] | None = None,
  plot_classwise: bool = False,
- factor_type: Literal["discrete", "continuous", "both"] = "discrete",
  ) -> Figure:
  """
  Plot a heatmap of balance information.
@@ -258,8 +219,6 @@ class BalanceOutput(Output):
  List/Array containing the labels for columns in the histogram
  plot_classwise : bool, default False
  Whether to plot per-class balance instead of global balance
- factor_type : "discrete", "continuous", or "both", default "discrete"
- Whether to plot discretized values, continuous values, or to include both
 
  Returns
  -------
@@ -273,10 +232,10 @@ class BalanceOutput(Output):
  if row_labels is None:
  row_labels = self.class_names
  if col_labels is None:
- col_labels = self._by_factor_type("factor_names", factor_type)
+ col_labels = self.factor_names
 
  fig = heatmap(
- self._by_factor_type("classwise", factor_type),
+ self.classwise,
  row_labels,
  col_labels,
  xlabel="Factors",
@@ -287,8 +246,8 @@ class BalanceOutput(Output):
  # Combine balance and factors results
  data = np.concatenate(
  [
- self._by_factor_type("balance", factor_type)[np.newaxis, 1:],
- self._by_factor_type("factors", factor_type),
+ self.balance[np.newaxis, 1:],
+ self.factors,
  ],
  axis=0,
  )
@@ -297,7 +256,7 @@ class BalanceOutput(Output):
  # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
  heat_data = np.where(mask, np.nan, data)[:-1]
  # Creating label array for heat map axes
- heat_labels = self._by_factor_type("factor_names", factor_type)
+ heat_labels = self.factor_names
 
  if row_labels is None:
  row_labels = heat_labels[:-1]
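With the removal of _by_factor_type, BalanceOutput.plot no longer accepts a factor_type argument and always plots over the full set of factor names. A hedged migration sketch (balance_output is a placeholder for a BalanceOutput instance obtained elsewhere):

    # 0.86.1: balance_output.plot(plot_classwise=True, factor_type="continuous")
    # 0.86.3: the keyword is gone; the heatmap covers all factors
    fig = balance_output.plot(plot_classwise=True)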
@@ -320,16 +279,16 @@ class DiversityOutput(Output):
  :term:`Diversity` index for classes and factors
  classwise : NDArray[np.double]
  Classwise diversity index [n_class x n_factor]
- factor_names : list[str]
+ factor_names : Sequence[str]
  Names of each metadata factor
- class_names : list[str]
+ class_names : Sequence[str]
  Class labels for each value in the dataset
  """
 
  diversity_index: NDArray[np.double]
  classwise: NDArray[np.double]
- factor_names: list[str]
- class_names: list[str]
+ factor_names: Sequence[str]
+ class_names: Sequence[str]
 
  def plot(
  self,
@@ -377,7 +336,7 @@ class DiversityOutput(Output):
  import matplotlib.pyplot as plt
 
  fig, ax = plt.subplots(figsize=(8, 8))
- heat_labels = np.concatenate((["class"], self.factor_names))
+ heat_labels = ["class_labels"] + list(self.factor_names)
  ax.bar(heat_labels, self.diversity_index)
  ax.set_xlabel("Factors")
  plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
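The bar-label construction above switches from np.concatenate, which coerces the labels into a numpy string array, to plain list concatenation, which works for any Sequence[str] and keeps the labels as a Python list of str. A standalone sketch of the two behaviors (the factor names are placeholders):

    import numpy as np

    factor_names = ("age", "size")                    # any Sequence[str], not only list
    labels = ["class_labels"] + list(factor_names)    # -> ['class_labels', 'age', 'size']
    arr = np.concatenate((["class"], factor_names))   # -> numpy array of numpy str_ values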
@@ -3,6 +3,7 @@ from __future__ import annotations
  __all__ = []
 
  from dataclasses import dataclass
+ from typing import Sequence
 
  import numpy as np
  from numpy.typing import NDArray
@@ -64,7 +65,7 @@ class ClustererOutput(Output):
  """
  return np.nonzero(self.clusters == -1)[0]
 
- def find_duplicates(self) -> tuple[list[list[int]], list[list[int]]]:
+ def find_duplicates(self) -> tuple[Sequence[Sequence[int]], Sequence[Sequence[int]]]:
  """
  Finds duplicate and near duplicate data based on cluster average distance
 
@@ -3,7 +3,7 @@ from __future__ import annotations
  __all__ = []
 
  from dataclasses import dataclass
- from typing import Generic, TypeVar, Union
+ from typing import Generic, Mapping, Sequence, TypeVar, Union
 
  import pandas as pd
  from typing_extensions import TypeAlias
@@ -11,13 +11,13 @@ from typing_extensions import TypeAlias
  from dataeval.outputs._base import Output
  from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
 
- DuplicateGroup: TypeAlias = list[int]
- DatasetDuplicateGroupMap: TypeAlias = dict[int, DuplicateGroup]
+ DuplicateGroup: TypeAlias = Sequence[int]
+ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
  TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
 
- IndexIssueMap: TypeAlias = dict[int, dict[str, float]]
+ IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
  OutlierStatsOutput: TypeAlias = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
- TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
+ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
 
 
  @dataclass(frozen=True)
@@ -27,9 +27,9 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
 
  Attributes
  ----------
- exact : list[list[int] | dict[int, list[int]]]
+ exact : Sequence[Sequence[int] | Mapping[int, Sequence[int]]]
  Indices of images that are exact matches
- near: list[list[int] | dict[int, list[int]]]
+ near: Sequence[Sequence[int] | Mapping[int, Sequence[int]]]
  Indices of images that are near matches
 
  Notes
@@ -39,13 +39,13 @@ class DuplicatesOutput(Output, Generic[TIndexCollection]):
  index of the dataset, and the value is the list index groups from that dataset.
  """
 
- exact: list[TIndexCollection]
- near: list[TIndexCollection]
+ exact: Sequence[TIndexCollection]
+ near: Sequence[TIndexCollection]
 
 
  def _reorganize_by_class_and_metric(
  result: IndexIssueMap, lstats: LabelStatsOutput
- ) -> tuple[dict[str, list[int]], dict[str, dict[str, int]]]:
+ ) -> tuple[Mapping[str, Sequence[int]], Mapping[str, Mapping[str, int]]]:
  """Flip result from grouping by image to grouping by class and metric"""
  metrics: dict[str, list[int]] = {}
  class_wise: dict[str, dict[str, int]] = {label: {} for label in lstats.class_names}
@@ -61,7 +61,7 @@ def _reorganize_by_class_and_metric(
  return metrics, class_wise
 
 
- def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str, int]]) -> list[str]:
+ def _create_table(metrics: Mapping[str, Sequence[int]], class_wise: Mapping[str, Mapping[str, int]]) -> Sequence[str]:
  """Create table for displaying the results"""
  max_class_length = max(len(str(label)) for label in class_wise) + 2
  max_total = max(len(metrics[group]) for group in metrics) + 2
@@ -71,7 +71,7 @@ def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str,
  + [f"{group:^{max(5, len(str(group))) + 2}}" for group in sorted(metrics.keys())]
  + [f"{'Total':<{max_total}}"]
  )
- table_rows: list[str] = []
+ table_rows: Sequence[str] = []
 
  for class_cat, results in class_wise.items():
  table_value = [f"{class_cat:>{max_class_length}}"]
@@ -86,7 +86,7 @@ def _create_table(metrics: dict[str, list[int]], class_wise: dict[str, dict[str,
  return [table_header] + table_rows
 
 
- def _create_pandas_dataframe(class_wise: dict[str, dict[str, int]]) -> list[dict[str, str | int]]:
+ def _create_pandas_dataframe(class_wise: Mapping[str, Mapping[str, int]]) -> Sequence[Mapping[str, str | int]]:
  """Create data for pandas dataframe"""
  data = []
  for label, metrics_dict in class_wise.items():
@@ -105,7 +105,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
 
  Attributes
  ----------
- issues : dict[int, dict[str, float]] | list[dict[int, dict[str, float]]]
+ issues : Mapping[int, Mapping[str, float]] | Sequence[Mapping[int, Mapping[str, float]]]
  Indices of image Outliers with their associated issue type and calculated values.
 
  - For a single dataset, a dictionary containing the indices of outliers and
@@ -117,7 +117,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
  issues: TIndexIssueMap
 
  def __len__(self) -> int:
- if isinstance(self.issues, dict):
+ if isinstance(self.issues, Mapping):
  return len(self.issues)
  return sum(len(d) for d in self.issues)
 
@@ -134,7 +134,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
  -------
  str
  """
- if isinstance(self.issues, dict):
+ if isinstance(self.issues, Mapping):
  metrics, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
  listed_table = _create_table(metrics, classwise)
  table = "\n".join(listed_table)
@@ -165,7 +165,7 @@ class OutliersOutput(Output, Generic[TIndexIssueMap]):
  -----
  This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
  """
- if isinstance(self.issues, dict):
+ if isinstance(self.issues, Mapping):
  _, classwise = _reorganize_by_class_and_metric(self.issues, labelstats)
  data = _create_pandas_dataframe(classwise)
  df = pd.DataFrame(data)
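The isinstance checks above switch from dict to Mapping, which keeps dict-backed results working while also matching any other mapping type; typing.Mapping aliases collections.abc.Mapping, so it is valid as an isinstance target. A minimal sketch of the dispatch (the "brightness" metric name is illustrative, but the shapes match IndexIssueMap and Sequence[IndexIssueMap] above):

    from typing import Mapping

    single = {0: {"brightness": 0.1}}        # issues for one dataset
    multi = [{0: {"brightness": 0.1}}]       # issues for several datasets

    assert isinstance(single, Mapping)       # Mapping branch: len(issues)
    assert not isinstance(multi, Mapping)    # Sequence branch: sum of per-dataset lengths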
@@ -3,7 +3,7 @@ from __future__ import annotations
  __all__ = []
 
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional, Sequence, Union
+ from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Optional, Sequence, Union
 
  import numpy as np
  import pandas as pd
@@ -61,7 +61,7 @@ class BaseStatsOutput(Output):
  The number of detected objects in each image
  """
 
- source_index: list[SourceIndex]
+ source_index: Sequence[SourceIndex]
  object_count: NDArray[np.uint16]
  image_count: int
 
@@ -80,7 +80,7 @@ class BaseStatsOutput(Output):
  self,
  channel_index: OptionalRange,
  channel_count: OptionalRange = None,
- ) -> list[bool]:
+ ) -> Sequence[bool]:
  """
  Boolean mask for results filtered to specified channel index and optionally the count
  of the channels per image.
@@ -92,8 +92,8 @@ class BaseStatsOutput(Output):
  channel_count : int | Iterable[int] | None
  Optional count(s) of channels to filter for
  """
- mask: list[bool] = []
- cur_mask: list[bool] = []
+ mask: Sequence[bool] = []
+ cur_mask: Sequence[bool] = []
  cur_image = 0
  cur_max_channel = 0
  for source_index in list(self.source_index) + [None]:
@@ -113,7 +113,7 @@ class BaseStatsOutput(Output):
 
  def _get_channels(
  self, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
- ) -> tuple[int, list[bool] | None]:
+ ) -> tuple[int, Sequence[bool] | None]:
  source_index = self.data()[SOURCE_INDEX]
  raw_channels = int(max([si.channel or 0 for si in source_index])) + 1
  if isinstance(channel_index, int):
@@ -140,7 +140,7 @@ class BaseStatsOutput(Output):
  self,
  filter: str | Sequence[str] | None = None, # noqa: A002
  exclude_constant: bool = False,
- ) -> dict[str, NDArray[Any]]:
+ ) -> Mapping[str, NDArray[Any]]:
  """
  Returns all 1-dimensional data as a dictionary of numpy arrays.
 
@@ -153,7 +153,7 @@
 
  Returns
  -------
- dict[str, NDArray[Any]]
+ Mapping[str, NDArray[Any]]
  """
  filter_ = [filter] if isinstance(filter, str) else filter
  return {
@@ -253,8 +253,8 @@ class HashStatsOutput(BaseStatsOutput):
  :term:`Perception-based Hash` of the images as a hex string
  """
 
- xxhash: list[str]
- pchash: list[str]
+ xxhash: Sequence[str]
+ pchash: Sequence[str]
 
 
  @dataclass(frozen=True)
@@ -264,15 +264,15 @@ class LabelStatsOutput(Output):
 
  Attributes
  ----------
- label_counts_per_class : dict[int, int]
+ label_counts_per_class : Mapping[int, int]
  Dictionary whose keys are the different label classes and
  values are total counts of each class
- label_counts_per_image : list[int]
+ label_counts_per_image : Sequence[int]
  Number of labels per image
- image_counts_per_class : dict[int, int]
+ image_counts_per_class : Mapping[int, int]
  Dictionary whose keys are the different label classes and
  values are total counts of each image the class is present in
- image_indices_per_class : dict[int, list]
+ image_indices_per_class : Mapping[int, list]
  Dictionary whose keys are the different label classes and
  values are lists containing the images that have that label
  image_count : int
@@ -281,17 +281,17 @@ class LabelStatsOutput(Output):
  Total number of classes present
  label_count : int
  Total number of labels present
- class_names : list[str]
+ class_names : Sequence[str]
  """
 
- label_counts_per_class: list[int]
- label_counts_per_image: list[int]
- image_counts_per_class: list[int]
- image_indices_per_class: list[list[int]]
+ label_counts_per_class: Sequence[int]
+ label_counts_per_image: Sequence[int]
+ image_counts_per_class: Sequence[int]
+ image_indices_per_class: Sequence[Sequence[int]]
  image_count: int
  class_count: int
  label_count: int
- class_names: list[str]
+ class_names: Sequence[str]
 
  def to_table(self) -> str:
  """
@@ -3,6 +3,7 @@ from __future__ import annotations
  __all__ = []
 
  from dataclasses import dataclass
+ from typing import Sequence
 
  import numpy as np
  from numpy.typing import NDArray
@@ -36,9 +37,9 @@ class SplitDatasetOutput(Output):
  ----------
  test: NDArray[np.intp]
  Indices for the test set
- folds: list[TrainValSplit]
+ folds: Sequence[TrainValSplit]
  List of train and validation split indices
  """
 
  test: NDArray[np.intp]
- folds: list[TrainValSplit]
+ folds: Sequence[TrainValSplit]
@@ -177,7 +177,9 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
  return res.x
 
 
- def get_curve_params(measures: dict[str, NDArray[Any]], ranges: NDArray[Any], niter: int) -> dict[str, NDArray[Any]]:
+ def get_curve_params(
+ measures: Mapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
+ ) -> Mapping[str, NDArray[Any]]:
  """Calculates and aggregates parameters for both single and multi-class metrics"""
  output = {}
  for name, measure in measures.items():
@@ -208,7 +210,7 @@ class SufficiencyOutput(Output):
  """
 
  steps: NDArray[np.uint32]
- measures: dict[str, NDArray[np.float64]]
+ measures: Mapping[str, NDArray[np.float64]]
  n_iter: int = 1000
 
  def __post_init__(self) -> None:
@@ -220,7 +222,7 @@ class SufficiencyOutput(Output):
  self._params = None
 
  @property
- def params(self) -> dict[str, NDArray[Any]]:
+ def params(self) -> Mapping[str, NDArray[Any]]:
  if self._params is None:
  self._params = {}
  if self.n_iter not in self._params:
@@ -270,7 +272,7 @@ class SufficiencyOutput(Output):
  proj._params = self._params
  return proj
 
- def plot(self, class_names: Sequence[str] | None = None) -> list[Figure]:
+ def plot(self, class_names: Sequence[str] | None = None) -> Sequence[Figure]:
  """
  Plotting function for data :term:`sufficience<Sufficiency>` tasks.
 
@@ -281,7 +283,7 @@ class SufficiencyOutput(Output):
 
  Returns
  -------
- list[Figure]
+ Sequence[Figure]
  List of Figures for each measure
 
  Raises
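Together with the measures field above and the inv_project hunks below, SufficiencyOutput now advertises Mapping and Sequence across its public surface; a dict of measures still satisfies Mapping, and callers only need iteration and indexing on the results. A hedged usage sketch (output stands in for a SufficiencyOutput produced elsewhere, and the "Accuracy" key is illustrative):

    import numpy as np

    figs = output.plot()                                    # Sequence[Figure], one per measure
    needed = output.inv_project({"Accuracy": np.array([0.9])})
    for measure, counts in needed.items():                  # Mapping iteration, same as with a dict
        print(measure, counts)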
@@ -325,7 +327,7 @@ class SufficiencyOutput(Output):
 
  def inv_project(
  self, targets: Mapping[str, ArrayLike], n_iter: int | None = None
- ) -> dict[str, NDArray[np.float64]]:
+ ) -> Mapping[str, NDArray[np.float64]]:
  """
  Calculate training samples needed to achieve target model metric values.
 
@@ -339,7 +341,7 @@ class SufficiencyOutput(Output):
 
  Returns
  -------
- dict[str, NDArray]
+ Mapping[str, NDArray]
  List of the number of training samples needed to achieve each
  corresponding entry in targets
  """
dataeval/typing.py CHANGED
@@ -21,7 +21,7 @@ __all__ = [
 
 
  import sys
- from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable
+ from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, runtime_checkable
 
  import numpy.typing
  from typing_extensions import NotRequired, ReadOnly, Required
@@ -159,7 +159,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
  # ========== IMAGE CLASSIFICATION DATASETS ==========
 
 
- ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
+ ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, Mapping[str, Any]]
  """
  Type alias for an image classification datum tuple.
 
@@ -199,7 +199,7 @@ class ObjectDetectionTarget(Protocol):
  def scores(self) -> ArrayLike: ...
 
 
- ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
+ ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, Mapping[str, Any]]
  """
  Type alias for an object detection datum tuple.
 
@@ -240,7 +240,7 @@ class SegmentationTarget(Protocol):
  def scores(self) -> ArrayLike: ...
 
 
- SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
+ SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, Mapping[str, Any]]
  """
  Type alias for an image classification datum tuple.
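In dataeval.typing the datum aliases now type the metadata element as Mapping[str, Any] rather than dict[str, Any], so a dataset may yield any read-only mapping for metadata while existing dict-based datasets remain valid. A minimal sketch of a tuple that fits the widened ImageClassificationDatum alias (the array shapes and the "id" key are placeholders):

    import numpy as np

    image = np.zeros((3, 32, 32), dtype=np.float32)
    label = np.array([1])
    metadata = {"id": 0}                      # a plain dict still satisfies Mapping[str, Any]
    datum = (image, label, metadata)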
 
dataeval/utils/_plot.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
  import contextlib
  import math
- from typing import Any
+ from typing import Any, Mapping, Sequence
 
  import numpy as np
 
@@ -134,7 +134,7 @@ def format_text(*args: str) -> str:
 
 
  def histogram_plot(
- data_dict: dict[str, Any],
+ data_dict: Mapping[str, Any],
  log: bool = True,
  xlabel: str = "values",
  ylabel: str = "counts",
@@ -186,10 +186,10 @@ def histogram_plot(
 
 
  def channel_histogram_plot(
- data_dict: dict[str, Any],
+ data_dict: Mapping[str, Any],
  log: bool = True,
  max_channels: int = 3,
- ch_mask: list[bool] | None = None,
+ ch_mask: Sequence[bool] | None = None,
  xlabel: str = "values",
  ylabel: str = "counts",
  ) -> Figure: