dataeval 0.88.0__py3-none-any.whl → 0.88.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.88.0'
21
- __version_tuple__ = version_tuple = (0, 88, 0)
20
+ __version__ = version = '0.88.1'
21
+ __version_tuple__ = version_tuple = (0, 88, 1)
@@ -62,9 +62,12 @@ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any
62
62
  def plot_measure(
63
63
  name: str,
64
64
  steps: NDArray[Any],
65
- measure: NDArray[Any],
65
+ averaged_measure: NDArray[Any],
66
+ measures: NDArray[Any] | None,
66
67
  params: NDArray[Any],
67
68
  projection: NDArray[Any],
69
+ error_bars: bool,
70
+ asymptote: bool,
68
71
  ) -> Figure:
69
72
  import matplotlib.pyplot
70
73
 
@@ -73,21 +76,51 @@ def plot_measure(
73
76
  fig.tight_layout()
74
77
 
75
78
  ax = fig.add_subplot(111)
76
-
77
79
  ax.set_title(f"{name} Sufficiency")
78
80
  ax.set_ylabel(f"{name}")
79
81
  ax.set_xlabel("Steps")
80
- # Plot measure over each step
81
- ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
82
+ # Plot asymptote
83
+ if asymptote:
84
+ bound = 1 - params[2]
85
+ ax.axhline(y=bound, color="r", label=f"Asymptote: {bound:.4g}", zorder=1)
86
+ # Calculate error bars
87
+ # Plot measure over each step with associated error
88
+ if error_bars:
89
+ if measures is None:
90
+ warnings.warn(
91
+ "Error bars cannot be plotted without full, unaveraged data",
92
+ UserWarning,
93
+ )
94
+ else:
95
+ error = np.std(measures, axis=0)
96
+ ax.errorbar(
97
+ steps,
98
+ averaged_measure,
99
+ yerr=error,
100
+ capsize=7,
101
+ capthick=1.5,
102
+ elinewidth=1.5,
103
+ fmt="o",
104
+ label=f"Model Results ({name})",
105
+ markersize=5,
106
+ color="black",
107
+ ecolor="orange",
108
+ zorder=3,
109
+ )
110
+ else:
111
+ ax.scatter(steps, averaged_measure, label=f"Model Results ({name})", zorder=3, c="black")
82
112
  # Plot extrapolation
83
113
  ax.plot(
84
114
  projection,
85
115
  project_steps(params, projection),
86
116
  linestyle="dashed",
87
117
  label=f"Potential Model Results ({name})",
118
+ linewidth=2,
119
+ zorder=2,
88
120
  )
121
+ ax.set_xscale("log")
89
122
 
90
- ax.legend()
123
+ ax.legend(loc="best")
91
124
  return fig
92
125
 
93
126
 
@@ -286,11 +319,13 @@ class SufficiencyOutput(Output):
286
319
  output[name] = np.array(result)
287
320
  else:
288
321
  output[name] = project_steps(self.params[name], projection)
289
- proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
322
+ proj = SufficiencyOutput(projection, {}, output, self.n_iter)
290
323
  proj._params = self._params
291
324
  return proj
292
325
 
293
- def plot(self, class_names: Sequence[str] | None = None) -> Sequence[Figure]:
326
+ def plot(
327
+ self, class_names: Sequence[str] | None = None, error_bars: bool = False, asymptote: bool = False
328
+ ) -> Sequence[Figure]:
294
329
  """
295
330
  Plotting function for data :term:`sufficience<Sufficiency>` tasks.
296
331
 
@@ -298,6 +333,10 @@ class SufficiencyOutput(Output):
298
333
  ----------
299
334
  class_names : Sequence[str] | None, default None
300
335
  List of class names
336
+ error_bars : bool, default False
337
+ True if error bars should be plotted, False if not
338
+ asymptote : bool, default False
339
+ True if asymptote should be plotted, False if not
301
340
 
302
341
  Returns
303
342
  -------
@@ -320,25 +359,36 @@ class SufficiencyOutput(Output):
320
359
 
321
360
  # Stores all plots
322
361
  plots = []
323
-
324
362
  # Create a plot for each measure on one figure
325
- for name, averaged_measures in self.averaged_measures.items():
326
- if averaged_measures.ndim > 1:
327
- if class_names is not None and len(averaged_measures) != len(class_names):
363
+ for name, measures in self.averaged_measures.items():
364
+ if measures.ndim > 1:
365
+ if class_names is not None and len(measures) != len(class_names):
328
366
  raise IndexError("Class name count does not align with measures")
329
- for i, measure in enumerate(averaged_measures):
367
+ for i, values in enumerate(measures):
330
368
  class_name = str(i) if class_names is None else class_names[i]
331
369
  fig = plot_measure(
332
370
  f"{name}_{class_name}",
333
371
  self.steps,
334
- measure,
372
+ values,
373
+ self.measures[name][:, :, i] if len(self.measures) else None,
335
374
  self.params[name][i],
336
375
  extrapolated,
376
+ error_bars,
377
+ asymptote,
337
378
  )
338
379
  plots.append(fig)
339
380
 
340
381
  else:
341
- fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
382
+ fig = plot_measure(
383
+ name,
384
+ self.steps,
385
+ measures,
386
+ self.measures.get(name),
387
+ self.params[name],
388
+ extrapolated,
389
+ error_bars,
390
+ asymptote,
391
+ )
342
392
  plots.append(fig)
343
393
 
344
394
  return plots
@@ -280,5 +280,4 @@ class Sufficiency(Generic[T]):
280
280
  )
281
281
 
282
282
  measures[name][run, iteration] = value
283
- # The mean for each measure must be calculated before being returned
284
- return SufficiencyOutput(ranges, measures=measures)
283
+ return SufficiencyOutput(ranges, measures)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dataeval
3
- Version: 0.88.0
3
+ Version: 0.88.1
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Project-URL: Homepage, https://dataeval.ai/
6
6
  Project-URL: Repository, https://github.com/aria-ml/dataeval/
@@ -1,6 +1,6 @@
1
1
  dataeval/__init__.py,sha256=aFzX3SLx8wgc763RY772P41ZLqeHcUHRKW9XAN0KfHQ,1793
2
2
  dataeval/_log.py,sha256=Q2d6oqYKXyn1wkgMdNX9iswod4Jq0jPADShrCFVgJI0,374
3
- dataeval/_version.py,sha256=p36W3DcVLrkAWnGoljUjU-PF8_IvHjfGbC98bXZ2g_c,513
3
+ dataeval/_version.py,sha256=CKtd7X5fA88g3vtlmrUWb2oMZ7hUnqfrivEo9r-T_BU,513
4
4
  dataeval/config.py,sha256=lL73s_xa9pBxHHCnBKi59D_tl4vS7ig1rfWbIYkM_ac,3839
5
5
  dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  dataeval/typing.py,sha256=cKpK8rY7iVf-KL9kuye6qi_6LS6hKbMxHpurdWlYY44,7445
@@ -76,7 +76,7 @@ dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI
76
76
  dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
77
77
  dataeval/outputs/_stats.py,sha256=PsDV0uw41aTy-X9tjz-PqOj78TTnH4JQVpOrU3OThAE,17423
78
78
  dataeval/outputs/_utils.py,sha256=KJ1P8tcMFIkGi2A6VfqbZwLcT1cD0c2YssTbWbHALjE,938
79
- dataeval/outputs/_workflows.py,sha256=sw13FNx1vANX7DBsKeOLfP2bkp5r6SexBorfb9dxYxU,12160
79
+ dataeval/outputs/_workflows.py,sha256=Q6lvEjrqdazs0WZTp5hP9wLsrR7-Cofmb3b12OYZZUA,13771
80
80
  dataeval/utils/__init__.py,sha256=sjelzMPaTImF6isiRcp8UGDE3tppEpWS5GoR8WKPZ1k,242
81
81
  dataeval/utils/_array.py,sha256=P4_gyH3kkksUJm9Vqx-oPtLWxFmqMacUJzhj0vmrUd8,6361
82
82
  dataeval/utils/_bin.py,sha256=QjlRCB5mOauETdxSbvRxRG17riO6gScsMd_lNnnvqxs,7391
@@ -98,8 +98,8 @@ dataeval/utils/torch/_internal.py,sha256=LiuqZGIzKewp_29_Lskj0mnNqdMffMheMdgGeXL
98
98
  dataeval/utils/torch/models.py,sha256=1idpXyjrYcCBSsbxxRUOto8xr4MJNjDEqQHiIXVU5Zc,9700
99
99
  dataeval/utils/torch/trainer.py,sha256=kBdgxd9TL1Pvz-dyZbS__POAKeFrDiQ4vKFh8ltJApc,5543
100
100
  dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
101
- dataeval/workflows/sufficiency.py,sha256=4DTDaYyEuAfO0LTFpQGXXXayV5aCIbziSL2Rddd1vQ0,10360
102
- dataeval-0.88.0.dist-info/METADATA,sha256=Y5NRZgrhfpyGQKHUnqnO6rAItVR3oWUqIp646_0xluQ,5601
103
- dataeval-0.88.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
104
- dataeval-0.88.0.dist-info/licenses/LICENSE,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
105
- dataeval-0.88.0.dist-info/RECORD,,
101
+ dataeval/workflows/sufficiency.py,sha256=m3Z8VquGxefai6nOqoMveYA1XAA_mUf_IL21W-enyxQ,10274
102
+ dataeval-0.88.1.dist-info/METADATA,sha256=9YDLUVCwj9Owh25uBSvsNkudiCot1jYH_4nTQCSkAEM,5601
103
+ dataeval-0.88.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
104
+ dataeval-0.88.1.dist-info/licenses/LICENSE,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
105
+ dataeval-0.88.1.dist-info/RECORD,,