dataeval 0.65.0__py3-none-any.whl → 0.67.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (61)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +24 -22
  3. dataeval/_internal/detectors/drift/base.py +206 -26
  4. dataeval/_internal/detectors/drift/cvm.py +25 -23
  5. dataeval/_internal/detectors/drift/ks.py +28 -25
  6. dataeval/_internal/detectors/drift/mmd.py +30 -29
  7. dataeval/_internal/detectors/drift/torch.py +66 -58
  8. dataeval/_internal/detectors/drift/uncertainty.py +28 -28
  9. dataeval/_internal/detectors/duplicates.py +28 -18
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +61 -43
  13. dataeval/_internal/detectors/ood/llr.py +27 -24
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
  17. dataeval/_internal/flags.py +5 -3
  18. dataeval/_internal/interop.py +4 -2
  19. dataeval/_internal/metrics/balance.py +33 -4
  20. dataeval/_internal/metrics/ber.py +6 -4
  21. dataeval/_internal/metrics/diversity.py +70 -27
  22. dataeval/_internal/metrics/parity.py +114 -26
  23. dataeval/_internal/metrics/stats.py +154 -16
  24. dataeval/_internal/metrics/uap.py +28 -2
  25. dataeval/_internal/metrics/utils.py +20 -18
  26. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  27. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  28. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  29. dataeval/_internal/models/tensorflow/losses.py +15 -11
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  31. dataeval/_internal/models/tensorflow/trainer.py +8 -6
  32. dataeval/_internal/models/tensorflow/utils.py +21 -19
  33. dataeval/_internal/output.py +13 -10
  34. dataeval/_internal/utils.py +5 -3
  35. dataeval/_internal/workflows/sufficiency.py +42 -30
  36. dataeval/detectors/__init__.py +6 -25
  37. dataeval/detectors/drift/__init__.py +16 -0
  38. dataeval/detectors/drift/kernels/__init__.py +6 -0
  39. dataeval/detectors/drift/updates/__init__.py +3 -0
  40. dataeval/detectors/linters/__init__.py +5 -0
  41. dataeval/detectors/ood/__init__.py +11 -0
  42. dataeval/metrics/__init__.py +2 -26
  43. dataeval/metrics/bias/__init__.py +14 -0
  44. dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval/metrics/stats/__init__.py +6 -0
  46. dataeval/tensorflow/__init__.py +3 -0
  47. dataeval/tensorflow/loss/__init__.py +3 -0
  48. dataeval/tensorflow/models/__init__.py +5 -0
  49. dataeval/tensorflow/recon/__init__.py +3 -0
  50. dataeval/torch/__init__.py +3 -0
  51. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  52. dataeval/torch/trainer/__init__.py +3 -0
  53. dataeval/utils/__init__.py +3 -6
  54. dataeval/workflows/__init__.py +2 -4
  55. {dataeval-0.65.0.dist-info → dataeval-0.67.0.dist-info}/METADATA +1 -1
  56. dataeval-0.67.0.dist-info/RECORD +72 -0
  57. dataeval/models/__init__.py +0 -15
  58. dataeval/models/tensorflow/__init__.py +0 -6
  59. dataeval-0.65.0.dist-info/RECORD +0 -60
  60. {dataeval-0.65.0.dist-info → dataeval-0.67.0.dist-info}/LICENSE.txt +0 -0
  61. {dataeval-0.65.0.dist-info → dataeval-0.67.0.dist-info}/WHEEL +0 -0
dataeval/_internal/models/pytorch/autoencoder.py
@@ -1,4 +1,6 @@
-from typing import Any, List, Union
+from __future__ import annotations
+
+from typing import Any

 import torch
 import torch.nn as nn
@@ -14,40 +16,52 @@ def get_images_from_batch(batch: Any) -> Any:


 class AETrainer:
+    """
+    A class to train and evaluate an autoencoder model.
+
+    Parameters
+    ----------
+    model : nn.Module
+        The model to be trained.
+    device : str or torch.device, default "auto"
+        The hardware device to use for training.
+        If "auto", the device will be set to "cuda" if available, otherwise "cpu".
+    batch_size : int, default 8
+        The number of images to process in a batch.
+    """
+
     def __init__(
         self,
         model: nn.Module,
-        device: Union[str, torch.device] = "auto",
+        device: str | torch.device = "auto",
         batch_size: int = 8,
     ):
-        """
-        model : nn.Module
-            Model to be trained
-        device : str | torch.device, default "cpu"
-            Hardware device for model, optimizer, and data to run on
-        batch_size : int, default 8
-            Number of images to group together in `torch.utils.data.DataLoader`
-        """
         if device == "auto":
             device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = device
         self.model = model.to(device)
         self.batch_size = batch_size

-    def train(self, dataset: Dataset, epochs: int = 25) -> List[float]:
+    def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
         """
-        Basic training function for Autoencoder models for reconstruction tasks
+        Basic image reconstruction training function for Autoencoder models

         Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters

         Parameters
         ----------
         dataset : Dataset
-            Torch Dataset containing images in the first return position
+            The dataset to train on.
+            Torch Dataset containing images in the first return position.
         epochs : int, default 25
             Number of full training loops

-        Note
+        Returns
+        -------
+        List[float]
+            A list of average loss values for each epoch.
+
+        Notes
         ----
         To replace this function with a custom function, do
             AETrainer.train = custom_function
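As the Notes section above says, train can be replaced wholesale. A minimal sketch of what such an override could look like (custom_train, the SGD optimizer, the L1 loss, and the batch handling are illustrative, not part of the package):

    import torch.nn as nn
    from torch.optim import SGD
    from torch.utils.data import DataLoader, Dataset

    def custom_train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
        """Hypothetical replacement loop: SGD + L1 loss instead of Adam + MSE."""
        self.model.train()
        opt = SGD(self.model.parameters(), lr=0.01)
        criterion = nn.L1Loss().to(self.device)
        loss_history: list[float] = []
        dl = DataLoader(dataset, batch_size=self.batch_size)
        for _ in range(epochs):
            epoch_loss = 0.0
            for batch in dl:
                imgs = batch[0].to(self.device)  # images in the first return position
                opt.zero_grad()
                loss = criterion(self.model(imgs), imgs)
                loss.backward()
                opt.step()
                epoch_loss += loss.item()
            loss_history.append(epoch_loss / len(dl))
        return loss_history

    AETrainer.train = custom_train  # swap in the custom loop, per the docstring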
@@ -58,7 +72,7 @@ class AETrainer:
         opt = Adam(self.model.parameters(), lr=0.001)
         criterion = nn.MSELoss().to(self.device)
         # Record loss
-        loss_history: List[float] = []
+        loss_history: list[float] = []

         for _ in range(epochs):
             epoch_loss: float = 0
@@ -89,19 +103,20 @@ class AETrainer:
     @torch.no_grad
     def eval(self, dataset: Dataset) -> float:
         """
-        Basic evaluation function for Autoencoder models for reconstruction tasks
+        Basic image reconstruction evaluation function for Autoencoder models

-        Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
+        Uses `torch.nn.MSELoss` as default loss function.

         Parameters
         ----------
         dataset : Dataset
-            Torch Dataset containing images in the first return position
+            The dataset to evaluate on.
+            Torch Dataset containing images in the first return position.

         Returns
         -------
         float
-            Total reconstruction loss over all data
+            Total reconstruction loss over the entire dataset

         Note
         ----
@@ -124,18 +139,25 @@ class AETrainer:
     @torch.no_grad
     def encode(self, dataset: Dataset) -> torch.Tensor:
         """
-        Encode data through model if it has an encode attribute,
-        otherwise passes data through model.forward
+        Create image embeddings for the dataset using the model's encoder.
+
+        If the model has an `encode` method, it will be used; otherwise,
+        `model.forward` will be used.

         Parameters
         ----------
         dataset: Dataset
-            Dataset containing images to be encoded by the model
+            The dataset to encode.
+            Torch Dataset containing images in the first return position.

         Returns
         -------
         torch.Tensor
             Data encoded by the model
+
+        Notes
+        -----
+        This function should be run after the model has been trained and evaluated.
         """
         self.model.eval()
         dl = DataLoader(dataset, batch_size=self.batch_size)
@@ -155,21 +177,67 @@ class AETrainer:


 class AriaAutoencoder(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
     def __init__(self, channels=3):
         super().__init__()
         self.encoder = Encoder(channels)
         self.decoder = Decoder(channels)

     def forward(self, x):
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
         x = self.encoder(x)
         x = self.decoder(x)
         return x

     def encode(self, x):
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
         return self.encoder(x)


 class Encoder(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
     def __init__(self, channels=3):
         super().__init__()
         self.encoder = nn.Sequential(
@@ -183,10 +251,34 @@ class Encoder(nn.Module):
         )

     def forward(self, x):
+        """
+        Perform a forward pass through the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
         return self.encoder(x)


 class Decoder(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int
+        Number of output channels
+    """
+
     def __init__(self, channels):
         super().__init__()
         self.decoder = nn.Sequential(
@@ -199,4 +291,17 @@ class Decoder(nn.Module):
         )

     def forward(self, x):
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
         return self.decoder(x)
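Taken together, the new docstrings describe a simple train/eval/encode workflow. A quick sketch of that usage (the random TensorDataset is a stand-in for real data; shapes are illustrative):

    import torch
    from torch.utils.data import TensorDataset

    # 100 RGB 32x32 images, returned in the first position of each item
    dataset = TensorDataset(torch.rand(100, 3, 32, 32))

    trainer = AETrainer(AriaAutoencoder(channels=3), device="auto", batch_size=8)
    loss_history = trainer.train(dataset, epochs=25)  # per-epoch average losses
    total_loss = trainer.eval(dataset)                # total reconstruction loss
    embeddings = trainer.encode(dataset)              # latent embeddings; run after training, per the Notes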
dataeval/_internal/models/tensorflow/autoencoder.py
@@ -8,7 +8,9 @@ Licensed under Apache Software License (Apache 2.0)

 # pyright: reportIncompatibleMethodOverride=false

-from typing import Callable, Tuple, cast
+from __future__ import annotations
+
+from typing import Callable, cast

 import keras
 import tensorflow as tf
@@ -56,16 +58,17 @@ def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf

     Parameters
     ----------
-    x
+    x : tf.Tensor
         Tensor used in feature computation.
-    y
+    y : tf.Tensor
         Tensor used in feature computation.
-    max_eucl
+    max_eucl : float, default 1e2
         Maximum value to clip relative Euclidean distance by.

     Returns
     -------
-    Tensor concatenating the relative Euclidean distance and cosine similarity features.
+    tf.Tensor
+        Tensor concatenating the relative Euclidean distance and cosine similarity features.
     """
     if len(x.shape) > 2 or len(y.shape) > 2:
         x = cast(tf.Tensor, Flatten()(x))
@@ -78,9 +81,9 @@ def eucl_cosim_features(x: tf.Tensor, y: tf.Tensor, max_eucl: float = 1e2) -> tf


 class Sampling(Layer):
-    """Reparametrization trick. Uses (z_mean, z_log_var) to sample the latent vector z."""
+    """Reparametrization trick - Uses (z_mean, z_log_var) to sample the latent vector z."""

-    def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
+    def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
         """
         Sample z.

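For reference, the reparametrization trick named in this docstring is the standard VAE formulation (not quoted from the package): the latent vector is a deterministic function of the encoder outputs plus independent noise, so sampling stays differentiable with respect to (z_mean, z_log_var):

    z = \mu + \exp\left(\tfrac{1}{2}\log\sigma^2\right) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

where z_mean = \mu and z_log_var = \log\sigma^2.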
@@ -138,7 +141,7 @@ class EncoderVAE(Layer):
         self.fc_log_var = Dense(latent_dim, activation=None)
         self.sampling = Sampling()

-    def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
         x = cast(tf.Tensor, self.encoder_net(x))
         if len(x.shape) > 2:
             x = cast(tf.Tensor, Flatten()(x))
@@ -173,9 +176,9 @@ class AE(keras.Model):

     Parameters
     ----------
-    encoder_net
+    encoder_net : keras.Model
         Layers for the encoder wrapped in a keras.Sequential class.
-    decoder_net
+    decoder_net : keras.Model
         Layers for the decoder wrapped in a keras.Sequential class.
     """

@@ -196,13 +199,13 @@ class VAE(keras.Model):

     Parameters
     ----------
-    encoder_net
+    encoder_net : keras.Model
         Layers for the encoder wrapped in a keras.Sequential class.
-    decoder_net
+    decoder_net : keras.Model
         Layers for the decoder wrapped in a keras.Sequential class.
-    latent_dim
+    latent_dim : int
         Dimensionality of the latent space.
-    beta
+    beta : float, default 1.0
         Beta parameter for KL-divergence loss term.
     """

@@ -214,7 +217,7 @@ class VAE(keras.Model):
         self.latent_dim = latent_dim

     def call(self, x: tf.Tensor) -> tf.Tensor:
-        z_mean, z_log_var, z = cast(Tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
+        z_mean, z_log_var, z = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
         x_recon = self.decoder(z)
         # add KL divergence loss term
         kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
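The kl_loss line above is the usual closed-form KL divergence between the approximate posterior \mathcal{N}(\mu, \sigma^2) and a standard normal prior; with z_mean = \mu and z_log_var = \log\sigma^2,

    D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = -\tfrac{1}{2}\sum_i \left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right)

Note that the code averages over dimensions (tf.reduce_mean) rather than summing, which rescales the term by the latent dimensionality.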
@@ -228,15 +231,15 @@ class AEGMM(keras.Model):

     Parameters
     ----------
-    encoder_net
+    encoder_net : keras.Model
         Layers for the encoder wrapped in a keras.Sequential class.
-    decoder_net
+    decoder_net : keras.Model
         Layers for the decoder wrapped in a keras.Sequential class.
-    gmm_density_net
+    gmm_density_net : keras.Model
         Layers for the GMM network wrapped in a keras.Sequential class.
-    n_gmm
+    n_gmm : int
         Number of components in GMM.
-    recon_features
+    recon_features : Callable, default eucl_cosim_features
         Function to extract features from the reconstructed instance by the decoder.
     """

@@ -255,7 +258,7 @@ class AEGMM(keras.Model):
         self.n_gmm = n_gmm
         self.recon_features = recon_features

-    def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
         enc = self.encoder(x)
         x_recon = cast(tf.Tensor, self.decoder(enc))
         recon_features = self.recon_features(x, x_recon)
@@ -270,19 +273,19 @@ class VAEGMM(keras.Model):

     Parameters
     ----------
-    encoder_net
+    encoder_net : keras.Model
         Layers for the encoder wrapped in a keras.Sequential class.
-    decoder_net
+    decoder_net : keras.Model
         Layers for the decoder wrapped in a keras.Sequential class.
-    gmm_density_net
+    gmm_density_net : keras.Model
         Layers for the GMM network wrapped in a keras.Sequential class.
-    n_gmm
+    n_gmm : int
         Number of components in GMM.
-    latent_dim
+    latent_dim : int
         Dimensionality of the latent space.
-    recon_features
+    recon_features : Callable, default eucl_cosim_features
         Function to extract features from the reconstructed instance by the decoder.
-    beta
+    beta : float, default 1.0
         Beta parameter for KL-divergence loss term.
     """

@@ -305,8 +308,8 @@ class VAEGMM(keras.Model):
         self.recon_features = recon_features
         self.beta = beta

-    def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-        enc_mean, enc_log_var, enc = cast(Tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
+    def call(self, x: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        enc_mean, enc_log_var, enc = cast(tuple[tf.Tensor, tf.Tensor, tf.Tensor], self.encoder(x))
         x_recon = cast(tf.Tensor, self.decoder(enc))
         recon_features = self.recon_features(x, x_recon)
         z = cast(tf.Tensor, tf.concat([enc, recon_features], -1))
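The Parameters sections for AE, VAE, AEGMM, and VAEGMM all expect the sub-networks pre-wrapped in keras.Sequential. A hedged sketch of wiring up the plain AE (layer sizes and the 32x32x3 input shape are illustrative, not from the package):

    import keras
    from keras.layers import Dense, Flatten, InputLayer, Reshape

    encoder_net = keras.Sequential(
        [
            InputLayer(input_shape=(32, 32, 3)),
            Flatten(),
            Dense(128, activation="relu"),
            Dense(32),  # latent representation
        ]
    )
    decoder_net = keras.Sequential(
        [
            InputLayer(input_shape=(32,)),
            Dense(128, activation="relu"),
            Dense(32 * 32 * 3, activation="sigmoid"),
            Reshape((32, 32, 3)),
        ]
    )
    ae = AE(encoder_net, decoder_net)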
dataeval/_internal/models/tensorflow/gmm.py
@@ -6,7 +6,9 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """

-from typing import NamedTuple, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple

 import numpy as np
 import tensorflow as tf
@@ -75,7 +77,7 @@ def gmm_energy(
     z: tf.Tensor,
     params: GaussianMixtureModelParams,
     return_mean: bool = True,
-) -> Tuple[tf.Tensor, tf.Tensor]:
+) -> tuple[tf.Tensor, tf.Tensor]:
     """
     Compute sample energy from Gaussian Mixture Model.

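For context, the sample energy computed here follows the DAGMM-style formulation this module appears to be built on (an assumption from the surrounding API, not quoted from the package): given mixture weights \phi_k, means \mu_k, and covariances \Sigma_k stored in GaussianMixtureModelParams, the energy of a latent vector z is

    E(z) = -\log \sum_k \phi_k \, \frac{\exp\left(-\tfrac{1}{2}(z - \mu_k)^\top \Sigma_k^{-1} (z - \mu_k)\right)}{\sqrt{\left|2\pi\Sigma_k\right|}}

so higher energy corresponds to lower likelihood under the fitted mixture, which is what makes it usable as an out-of-distribution score.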
dataeval/_internal/models/tensorflow/losses.py
@@ -6,7 +6,9 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """

-from typing import Literal, Optional, Union, cast
+from __future__ import annotations
+
+from typing import Literal, cast

 import tensorflow as tf
 from keras.layers import Flatten
@@ -20,22 +22,24 @@ from dataeval._internal.models.tensorflow.gmm import gmm_energy, gmm_params

 class Elbo:
     """
-    Compute ELBO loss. The covariance matrix can be specified by passing the full covariance matrix, the matrix
+    Compute ELBO loss.
+
+    The covariance matrix can be specified by passing the full covariance matrix, the matrix
     diagonal, or a scale identity multiplier. Only one of these should be specified. If none are specified, the
     identity matrix is used.

     Parameters
     ----------
-    cov_type
+    cov_type : Union[Literal["cov_full", "cov_diag"], float], default 1.0
         Full covariance matrix, diagonal variance matrix, or scale identity multiplier.
-    x
+    x : ArrayLike, optional - default None
         Dataset used to calculate the covariance matrix. Required for full and diagonal covariance matrix types.
     """

     def __init__(
         self,
-        cov_type: Union[Literal["cov_full", "cov_diag"], float] = 1.0,
-        x: Optional[Union[tf.Tensor, NDArray]] = None,
+        cov_type: Literal["cov_full", "cov_diag"] | float = 1.0,
+        x: tf.Tensor | NDArray | None = None,
     ):
         if isinstance(cov_type, float):
             self.cov = ("sim", cov_type)
@@ -67,13 +71,13 @@ class LossGMM:

     Parameters
     ----------
-    w_recon
+    w_recon : float, default 1e-7
         Weight on elbo loss term.
-    w_energy
+    w_energy : float, default 0.1
         Weight on sample energy loss term.
-    w_cov_diag
+    w_cov_diag : float, default 0.005
         Weight on covariance regularizing loss term.
-    elbo
+    elbo : Elbo, optional - default None
         ELBO loss function used to calculate w_recon.
     """

@@ -82,7 +86,7 @@ class LossGMM:
         w_recon: float = 1e-7,
         w_energy: float = 0.1,
         w_cov_diag: float = 0.005,
-        elbo: Optional[Elbo] = None,
+        elbo: Elbo | None = None,
     ):
         self.w_recon = w_recon
         self.w_energy = w_energy
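The updated signatures make typical construction easy to sketch (x_train is a placeholder array of training data; values shown are the documented defaults or arbitrary examples):

    elbo = Elbo(cov_type=0.05)                         # scale identity multiplier
    elbo_diag = Elbo(cov_type="cov_diag", x=x_train)   # diagonal covariance estimated from data

    # Combined loss: ELBO reconstruction term + GMM sample energy + covariance regularizer
    loss_gmm = LossGMM(w_recon=1e-7, w_energy=0.1, w_cov_diag=0.005, elbo=elbo)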
dataeval/_internal/models/tensorflow/pixelcnn.py
@@ -8,9 +8,10 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """

+from __future__ import annotations
+
 import functools
 import warnings
-from typing import Optional

 import keras
 import numpy as np
@@ -238,47 +239,47 @@ class PixelCNN(distribution.Distribution):

     Parameters
     ----------
-    image_shape
+    image_shape : tuple
         3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image.
-    conditional_shape
+    conditional_shape : tuple, optional - default None
         `TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input.
-    num_resnet
+    num_resnet : int, default 5
         The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1].
-    num_hierarchies
+    num_hierarchies : int, default 3
         The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1].)
-    num_filters
+    num_filters : int, default 160
         The number of convolutional filters.
-    num_logistic_mix
+    num_logistic_mix : int, default 10
         Number of components in the logistic mixture distribution.
-    receptive_field_dims
+    receptive_field_dims tuple, default (3, 3)
         Height and width in pixels of the receptive field of the convolutional layers above and to the left
         of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
         shows a receptive field of (3, 5) (the row containing the current pixel is included in the height).
         The default of (3, 3) was used to produce the results in [1].
-    dropout_p
+    dropout_p : float, default 0.0
         The dropout probability. Should be between 0 and 1.
-    resnet_activation
+    resnet_activation : str, default "concat_elu"
         The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.
-    l2_weight
+    l2_weight : float, default 0.0
         The L2 regularization weight.
-    use_weight_norm
+    use_weight_norm : bool, default True
         If `True` then use weight normalization (works only in Eager mode).
-    use_data_init
+    use_data_init : bool, default True
         If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`).
-    high
+    high : int, default 255
         The maximum value of the input data (255 for an 8-bit image).
-    low
+    low : int, default 0
         The minimum value of the input data.
-    dtype
+    dtype : tensorflow dtype, default tf.float32
         Data type of the `Distribution`.
-    name
+    name : str, default "PixelCNN"
         The name of the `Distribution`.
     """

     def __init__(
         self,
         image_shape: tuple,
-        conditional_shape: Optional[tuple] = None,
+        conditional_shape: tuple | None = None,
         num_resnet: int = 5,
         num_hierarchies: int = 3,
         num_filters: int = 160,
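With the documented defaults, construction looks roughly like this (the 28x28x1 image shape and the non-default dropout are illustrative; only image_shape is required):

    pixel_cnn = PixelCNN(
        image_shape=(28, 28, 1),      # [height, width, channels]
        num_resnet=5,
        num_hierarchies=3,
        num_filters=160,
        num_logistic_mix=10,
        receptive_field_dims=(3, 3),
        dropout_p=0.3,
        high=255,                     # 8-bit input range
        low=0,
    )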
dataeval/_internal/models/tensorflow/trainer.py
@@ -6,7 +6,9 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """

-from typing import Callable, Iterable, Optional, Tuple, cast
+from __future__ import annotations
+
+from typing import Callable, Iterable, cast

 import keras
 import numpy as np
@@ -17,10 +19,10 @@ from numpy.typing import NDArray
 def trainer(
     model: keras.Model,
     x_train: NDArray,
-    y_train: Optional[NDArray] = None,
-    loss_fn: Optional[Callable[..., tf.Tensor]] = None,
+    y_train: NDArray | None = None,
+    loss_fn: Callable[..., tf.Tensor] | None = None,
     optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
-    preprocess_fn: Optional[Callable[[tf.Tensor], tf.Tensor]] = None,
+    preprocess_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
     epochs: int = 20,
     reg_loss_fn: Callable[[keras.Model], tf.Tensor] = (lambda _: cast(tf.Tensor, tf.Variable(0, dtype=tf.float32))),
     batch_size: int = 64,
@@ -70,14 +72,14 @@ def trainer(
         dataset.on_epoch_end()  # type: ignore py39
         loss_val_ma = 0.0
         for step, data in enumerate(dataset):
-            x, y = cast(Tuple[tf.Tensor, Optional[tf.Tensor]], data if isinstance(data, tuple) else (data, None))
+            x, y = data if isinstance(data, tuple) else (data, None)
             if isinstance(preprocess_fn, Callable):
                 x = preprocess_fn(x)
             with tf.GradientTape() as tape:
                 y_hat = model(x)
                 y = x if y is None else y
                 if isinstance(loss_fn, Callable):
-                    args = [y] + list(y_hat) if isinstance(y_hat, Tuple) else [y, y_hat]
+                    args = [y] + list(y_hat) if isinstance(y_hat, tuple) else [y, y_hat]
                     loss = loss_fn(*args)
                 else:
                     loss = cast(tf.Tensor, tf.constant(0.0, dtype=tf.float32))
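A minimal usage sketch of the trainer function above (ae_model and the random training array are placeholders; with y_train omitted, the loop trains against y = x, i.e. reconstruction, as the `y = x if y is None else y` line shows):

    import keras
    import numpy as np

    x_train = np.random.rand(1000, 32, 32, 3).astype(np.float32)

    trainer(
        model=ae_model,  # e.g. an AE(encoder_net, decoder_net) keras.Model
        x_train=x_train,
        loss_fn=keras.losses.MeanSquaredError(),
        epochs=20,
        batch_size=64,
    )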