keras_nightly-3.12.0.dev2025083103-py3-none-any.whl → keras_nightly-3.14.0.dev2026011604-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/callbacks/orbax_checkpoint.py (new file)
@@ -0,0 +1,332 @@
+ import warnings
+
+ import numpy as np
+
+ from keras.src import backend
+ from keras.src import tree
+ from keras.src.api_export import keras_export
+ from keras.src.callbacks.monitor_callback import (
+     MonitorCallback,  # For metric monitoring logic
+ )
+ from keras.src.utils.module_utils import ocp
+
+ # Context and AsyncOptions are accessed through the lazy-loaded ocp module
+
+ # JAX monitoring compatibility: ensure record_scalar exists
+ # to prevent AttributeError in older JAX versions
+ try:
+     import jax
+
+     if not hasattr(jax.monitoring, "record_scalar"):
+         jax.monitoring.record_scalar = lambda *args, **kwargs: None
+ except ImportError:
+     pass
+
+
+ def _get_state_tree(model):
+     """Get the complete model state as a nested tree structure."""
+     # For JAX backend, preserve native arrays for performance
+     # For other backends, convert to numpy arrays
+     if backend.backend() == "jax":
+         state_tree = model.get_state_tree()
+         did_numpy_conversion = False
+     else:
+         state_tree = model.get_state_tree(value_format="numpy_array")
+         did_numpy_conversion = True
+
+     # Convert numpy scalar types to Python types for Orbax compatibility
+     # Only needed when we did numpy conversion
+     if did_numpy_conversion:
+
+         def convert_scalars(obj):
+             if isinstance(obj, np.ndarray) and obj.ndim == 0:
+                 # Convert 0-dimensional numpy arrays (scalars) to Python types
+                 return obj.item()
+             elif isinstance(obj, np.generic):
+                 # Convert numpy scalar types (like np.float32) to Python types
+                 return obj.item()
+             else:
+                 return obj
+
+         return tree.map_structure(convert_scalars, state_tree)
+     else:
+         return state_tree
+
+
+ @keras_export("keras.callbacks.OrbaxCheckpoint")
+ class OrbaxCheckpoint(MonitorCallback):
+     """Callback to save and load model state using Orbax with a similar API to
+     ModelCheckpoint.
+
+     This callback saves the model's weights and optimizer state asynchronously
+     using Orbax, allowing training to continue without blocking for I/O.
+
+     **Multi-host Support**: When running in a multi-host distributed training
+     environment with JAX backend, this callback automatically coordinates
+     checkpointing across all hosts to ensure consistency and proper
+     synchronization. Multi-host checkpointing is only supported on JAX.
+
+     Example:
+
+     ```python
+     model.compile(loss=..., optimizer=..., metrics=['accuracy'])
+
+     EPOCHS = 10
+     checkpoint_dir = '/tmp/ckpt'
+     orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+         directory=checkpoint_dir,
+         monitor='val_accuracy',
+         mode='max',
+         save_best_only=True)
+
+     # Model is saved at the end of every epoch, if it's the best seen so far.
+     model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+     # Alternatively, save checkpoints every N batches -
+     orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+         directory=checkpoint_dir,
+         save_freq=100)  # Save every 100 batches
+
+     model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+     ```
+
+     Args:
+         directory: path to the directory where to save the checkpoints.
+         monitor: The metric name to monitor (e.g., 'val_loss').
+         verbose: Verbosity mode, 0 or 1.
+         save_best_only: if `save_best_only=True`, it only saves when the model
+             is considered the "best" based on the monitored quantity.
+         mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
+         save_freq: `'epoch'` or integer. Frequency to save checkpoints.
+         max_to_keep: Integer, maximum number of recent checkpoints to keep.
+             If None, keeps all. Defaults to 1.
+         save_on_background: Boolean, whether to save asynchronously in the
+             background. Defaults to True.
+         initial_value_threshold: Floating point initial "best" value for the
+             monitor, used with `save_best_only`.
+     """
+
+     def __init__(
+         self,
+         directory,
+         monitor="val_loss",
+         verbose=0,
+         save_best_only=False,
+         mode="auto",
+         save_freq="epoch",
+         initial_value_threshold=None,
+         max_to_keep=1,
+         save_on_background=True,
+     ):
+         # Ensure orbax is available
+         ocp.initialize()
+
+         # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
+         # logic
+         super().__init__(monitor, mode, initial_value_threshold)
+
+         self.directory = directory
+         self.verbose = verbose
+         self.save_best_only = save_best_only
+         self.save_freq = save_freq
+         self.max_to_keep = max_to_keep
+         self.save_on_background = save_on_background
+         self._batches_seen_since_last_saving = 0
+         self._last_batch_seen = 0
+         self._current_epoch = 0  # Keep track of epoch
+         self._total_batches_seen = 0  # Global batch counter for step tracking
+
+         # Multi-host support
+         self._multihost_initialized = self._is_multihost_initialized()
+
+         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+             raise ValueError(
+                 f"Unrecognized save_freq: {self.save_freq}. "
+                 "Expected save_freq are 'epoch' or integer values"
+             )
+
+         # --- Orbax Checkpointer Setup (V1 API) ---
+         policies = []
+         if max_to_keep is not None:
+             policies.append(
+                 ocp.training.preservation_policies.LatestN(max_to_keep)
+             )
+
+         # Use AnyPreservationPolicy to combine them, or use directly
+         # if single policy
+         preservation_policy = None
+         if policies:
+             if len(policies) == 1:
+                 preservation_policy = policies[0]
+             else:
+                 preservation_policy = (
+                     ocp.training.preservation_policies.AnyPreservationPolicy(
+                         policies
+                     )
+                 )
+
+         # Create the V1 Checkpointer with direct parameter passing
+         # Orbax will handle directory creation on all processes as needed
+         self.checkpointer = ocp.training.Checkpointer(
+             directory=directory,
+             preservation_policy=preservation_policy,
+         )
+
+     def _is_multihost_initialized(self):
+         """Check if multi-host environment is initialized."""
+         # Multi-host checkpointing is only supported on JAX backend
+         if backend.backend() != "jax":
+             return False
+
+         multihost = ocp.multihost
+         # Check if JAX distributed client is initialized
+         # (indicates multihost setup)
+         return multihost.is_jax_distributed_client_initialized()
+
+     def _sync_processes(self, key=None):
+         """Synchronize all processes across hosts."""
+         if not self._multihost_initialized:
+             return  # No-op for single host
+
+         multihost = ocp.multihost
+         sync_key = key or "orbax_checkpoint_sync"
+         multihost.sync_global_processes(sync_key)
+
+     def is_multihost_enabled(self):
+         """Return True if multi-host checkpointing is enabled and initialized.
+
+         This method can be used to check if the callback is operating in
+         a multi-host distributed training environment. Multi-host checkpointing
+         is only supported on JAX backend.
+
+         Returns:
+             bool: True if multi-host support is active, False otherwise.
+         """
+         return self._multihost_initialized
+
+     def is_primary_host(self):
+         """Return True if this process is the primary host in multi-host setup.
+
+         In multi-host environments, only the primary host typically handles
+         logging and coordination tasks. Multi-host checkpointing is only
+         supported on JAX backend.
+
+         Returns:
+             bool: True if this is the primary host, False otherwise.
+                 Always returns True in single-host environments.
+         """
+         if not self._multihost_initialized:
+             return True  # Single host is always primary
+         multihost = ocp.multihost
+         return multihost.is_primary_host()
+
+     def _should_save_on_batch(self, batch):
+         """Check if we should save on this batch."""
+         if self.save_freq == "epoch":
+             return False
+
+         if batch <= self._last_batch_seen:  # New epoch.
+             add_batches = batch + 1
+         else:
+             add_batches = batch - self._last_batch_seen
+         self._batches_seen_since_last_saving += add_batches
+         self._last_batch_seen = batch
+         self._total_batches_seen += add_batches
+
+         if self._batches_seen_since_last_saving >= self.save_freq:
+             self._batches_seen_since_last_saving = 0
+             return True
+         return False
+
+     def _save_checkpoint(self, step, logs=None):
+         """Save a checkpoint at the given step with multi-host coordination."""
+
+         # --- Prepare Composite State (Backend-Agnostic) ---
+         state_tree = _get_state_tree(self.model)
+
+         # Save the nested state structures directly (preserving layer
+         # names and structure)
+         composite_state = state_tree
+
+         # Use a single with statement. If context_options is empty,
+         # Context() uses defaults.
+         with ocp.Context():
+             if self.save_on_background:
+                 self.checkpointer.save_pytree_async(step, composite_state)
+             else:
+                 self.checkpointer.save_pytree(step, composite_state)
+
+     def on_train_batch_end(self, batch, logs=None):
+         if self._should_save_on_batch(batch):
+             # Handle save_best_only logic for batch-level saving
+             should_save = True
+             if self.save_best_only:
+                 current = logs.get(self.monitor) if logs else None
+                 if current is None:
+                     warnings.warn(
+                         f"Can save best model only with {self.monitor} "
+                         f"available, skipping save at batch {batch}.",
+                         stacklevel=2,
+                     )
+                     should_save = False
+                 elif not self._is_improvement(current, self.best):
+                     should_save = False
+                 else:
+                     # Update best value when there's improvement
+                     self.best = current
+
+             if should_save:
+                 # Use global batch count for Orbax save step
+                 step = self._total_batches_seen
+                 self._save_checkpoint(step=step, logs=logs)
+
+     def on_epoch_end(self, epoch, logs=None):
+         self._current_epoch = epoch
+         if self.monitor_op is None:
+             self._set_monitor_op()  # From MonitorCallback
+
+         # For save_freq="epoch", save at every epoch
+         should_save = self.save_freq == "epoch"
+
+         # Handle save_best_only logic
+         if should_save and self.save_best_only:
+             current = logs.get(self.monitor) if logs else None
+             if current is None:
+                 warnings.warn(
+                     f"Can save best model only with {self.monitor} available, "
+                     f"skipping save at epoch {epoch}.",
+                     stacklevel=2,
+                 )
+                 should_save = False
+             elif not self._is_improvement(current, self.best):
+                 should_save = False
+             else:
+                 # Update best value when there's improvement
+                 self.best = current
+
+         if should_save:
+             # Use epoch number as the step for Orbax save
+             # Keras has already made the save decision - Checkpointer will
+             # save unconditionally
+             self._save_checkpoint(step=epoch, logs=logs)
+
+     def on_train_end(self, logs=None):
+         # Close the Checkpointer to ensure all pending saves complete
+         try:
+             self.checkpointer.close()
+         except Exception:
+             pass  # Ignore errors during cleanup
+
+         # Multi-host synchronization: ensure all hosts complete cleanup
+         self._sync_processes("checkpoint_cleanup")
+
+     def wait_until_finished(self):
+         """Wait for any in-progress checkpoint operations to complete.
+         This method blocks until all asynchronous checkpoint save operations
+         have completed across all hosts in a multi-host setup.
+         """
+         # Wait for any async operations to complete on this host
+         self.checkpointer.wait()
+
+         # Multi-host synchronization: ensure all hosts complete
+         self._sync_processes("checkpoint_wait_complete")
keras/src/callbacks/terminate_on_nan.py
@@ -7,14 +7,63 @@ from keras.src.utils import io_utils
  
  @keras_export("keras.callbacks.TerminateOnNaN")
  class TerminateOnNaN(Callback):
-     """Callback that terminates training when a NaN loss is encountered."""
+     """Callback that terminates training when a NaN loss is encountered.
+ 
+     This callback monitors the loss value during training
+     and terminates training when a NaN or Inf loss is detected.
+     By default, training is stopped gracefully
+     by setting `model.stop_training = True`, which triggers all callback cleanup
+     methods including `on_train_end()`.
+ 
+     Alternatively, you can use `raise_error=True` to immediately raise a
+     RuntimeError when NaN/Inf is detected. This raise_error termination
+     prevents `on_train_end()` from being called on other callbacks, which
+     is useful for preserving backup states or preventing unintended cleanup
+     when training fails.
+ 
+     Args:
+         raise_error: Boolean, default False. If False, uses graceful stop via
+             `model.stop_training = True`. If True, immediately raises
+             RuntimeError on NaN/Inf loss, bypassing callback cleanup methods.
+ 
+     Example:
+ 
+     ```
+     # Graceful termination (default)
+     callback = keras.callbacks.TerminateOnNaN()
+     model.fit(x, y, callbacks=[callback])
+ 
+     # raise_error termination (strict failure)
+     callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+     model.fit(x, y, callbacks=[callback])
+     ```
+     """
+ 
+     def __init__(self, raise_error: bool = False):
+         super().__init__()
+         self.raise_error = raise_error
  
      def on_batch_end(self, batch, logs=None):
+         """Check for NaN/Inf loss at the end of each batch.
+ 
+         Args:
+             batch: Integer, index of batch within the current epoch.
+             logs: Dict, contains the return value of `model.train_step()`.
+ 
+         Raises:
+             RuntimeError: If loss is NaN/Inf and raise_error=True.
+         """
          logs = logs or {}
          loss = logs.get("loss")
          if loss is not None:
              if np.isnan(loss) or np.isinf(loss):
-                 io_utils.print_msg(
-                     f"Batch {batch}: Invalid loss, terminating training"
-                 )
-                 self.model.stop_training = True
+                 if self.raise_error:
+                     raise RuntimeError(
+                         f"NaN or Inf loss encountered at batch {batch}. "
+                         f"Loss value: {loss}. Terminating training immediately."
+                     )
+                 else:
+                     io_utils.print_msg(
+                         f"Batch {batch}: Invalid loss, terminating training"
+                     )
+                     self.model.stop_training = True
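To illustrate the behavioral difference the new flag introduces, a small sketch; the one-layer model is a placeholder assumption, and NaN targets are used only to force a NaN loss on the first batch.

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.ones((32, 4), dtype="float32")
y = np.full((32, 1), np.nan, dtype="float32")  # guarantees a NaN loss

# Default: graceful stop via model.stop_training = True; fit() returns
# normally and other callbacks still get their on_train_end() call.
model.fit(x, y, epochs=1, callbacks=[keras.callbacks.TerminateOnNaN()])

# raise_error=True: fit() raises instead, so other callbacks' on_train_end()
# cleanup never runs (useful for preserving backup state, per the docstring).
try:
    model.fit(
        x, y, epochs=1,
        callbacks=[keras.callbacks.TerminateOnNaN(raise_error=True)],
    )
except RuntimeError as e:
    print("Training aborted:", e)
```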
keras/src/datasets/cifar10.py
@@ -59,6 +59,11 @@ def load_data():
      assert y_train.shape == (50000, 1)
      assert y_test.shape == (10000, 1)
      ```
+ 
+     **Note**: The CIFAR-10 dataset is known to have a small percentage of
+     mislabeled samples, which is inherent to the original dataset. This label
+     noise may impact training and evaluation. For more details, refer to
+     discussions in the research literature on CIFAR-10 label quality.
      """
      dirname = "cifar-10-batches-py-target"
      origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
keras/src/distillation/__init__.py (new file)
@@ -0,0 +1 @@
+ """Distillation module for knowledge distillation in Keras."""