keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/torch/numpy.py

@@ -411,6 +411,12 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    dtype = to_torch_dtype(dtype)
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
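For context, the new `view` performs a bitcast rather than a value cast; the snippet below is an illustrative sketch of the `torch.Tensor.view(dtype=...)` behavior it delegates to, not part of the diff:

```python
# Illustrative only: view reinterprets the underlying bytes, so a float32
# value comes back as the integer holding its bit pattern.
import torch

x = torch.tensor([1.0], dtype=torch.float32)
# 1.0f has bit pattern 0x3F800000 == 1065353216
print(x.view(dtype=torch.int32))  # tensor([1065353216], dtype=torch.int32)
```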
@@ -764,6 +770,12 @@ def empty(shape, dtype=None):
     return torch.empty(size=shape, dtype=dtype, device=get_device())


+def empty_like(x, dtype=None):
+    x = convert_to_tensor(x)
+    dtype = to_torch_dtype(dtype or x.dtype)
+    return torch.empty_like(x, dtype=dtype, device=get_device())
+
+
 def equal(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.eq(x1, x2)
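`empty_like` mirrors `numpy.empty_like`: it allocates an uninitialized tensor with the input's shape. An illustrative sketch of the underlying torch call, not part of the diff:

```python
# Illustrative only: empty_like allocates uninitialized memory matching x's
# shape; the contents are arbitrary until written.
import torch

x = torch.ones((2, 3), dtype=torch.float32)
out = torch.empty_like(x, dtype=torch.float64)
print(out.shape, out.dtype)  # torch.Size([2, 3]) torch.float64
```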
@@ -946,6 +958,11 @@ def isposinf(x):
     return torch.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return torch.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
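`isreal` is elementwise: real dtypes are always `True`, complex values only when their imaginary part is zero. Illustrative sketch of the wrapped `torch.isreal`, not part of the diff:

```python
# Illustrative only: complex entries are "real" iff their imaginary part is 0.
import torch

z = torch.tensor([1 + 0j, 2 + 1j, 3 + 0j])
print(torch.isreal(z))  # tensor([ True, False,  True])
```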
@@ -958,6 +975,20 @@ def lcm(x1, x2):
     return torch.lcm(x1, x2)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return cast(torch.ldexp(x1, x2), dtype)
+
+
 def less(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.less(x1, x2)
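`ldexp(x1, x2)` computes `x1 * 2**x2`, and the new guard rejects non-integer exponents up front. Illustrative sketch of the underlying `torch.ldexp`, not part of the diff:

```python
# Illustrative only: ldexp scales each mantissa by 2 raised to an integer
# exponent, elementwise.
import torch

mantissa = torch.tensor([0.5, 1.0, 1.5])
exponent = torch.tensor([1, 2, 3])
print(torch.ldexp(mantissa, exponent))  # tensor([ 1.,  4., 12.])
```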
@@ -1351,6 +1382,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return x


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    if axis is None:
+        return x.max() - x.min()
+    elif axis == ():
+        return torch.zeros_like(x)
+    else:
+        return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
+            x, dim=axis, keepdim=keepdims
+        )
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
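`ptp` ("peak to peak") reduces to `max - min`, either globally or per axis via `torch.amax`/`torch.amin`. Illustrative sketch, not part of the diff:

```python
# Illustrative only: global and per-axis peak-to-peak range.
import torch

x = torch.tensor([[1.0, 5.0], [4.0, 2.0]])
print(x.max() - x.min())                            # tensor(4.)
print(torch.amax(x, dim=0) - torch.amin(x, dim=0))  # tensor([3., 3.])
```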
@@ -1528,6 +1571,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)


+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
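Unlike `split`, `array_split` tolerates section counts that do not divide the axis length evenly, which is exactly what `torch.tensor_split` provides. Illustrative sketch, not part of the diff:

```python
# Illustrative only: 7 elements split into 3 uneven sections of 3, 2, 2.
import torch

x = torch.arange(7)
print([t.tolist() for t in torch.tensor_split(x, 3)])
# [[0, 1, 2], [3, 4], [5, 6]]
```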
@@ -1641,8 +1690,9 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype != "int64":
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16", "uint8"):
+        # Torch backend doesn't support uint32 dtype.
+        dtype = "int32"
     return torch.sum(
         torch.diagonal(x, offset, axis1, axis2),
         dim=-1,
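The narrowed condition widens only the dtypes whose diagonal sums could overflow (per the code comment, torch has no uint32 to promote `uint8` into). Illustrative sketch of the overflow being avoided, not part of the diff:

```python
# Illustrative only: summing a diagonal in a narrow integer dtype wraps
# around, hence the widening to int32 above.
import torch

x = torch.full((4, 4), 100, dtype=torch.int8)
diag = torch.diagonal(x)                   # four int8 values of 100
print(torch.sum(diag, dtype=torch.int32))  # tensor(400, dtype=torch.int32)
print(torch.sum(diag, dtype=torch.int8))   # wraps: tensor(-112, dtype=torch.int8)
```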
@@ -1755,6 +1805,16 @@ def negative(x):
     return torch.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, torch.float64)
+    x2 = cast(x2, torch.float64)
+    return cast(torch.nextafter(x1, x2), dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
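`nextafter` steps to the adjacent representable float; the implementation computes in float64 and casts back to the promoted dtype. Illustrative sketch, not part of the diff:

```python
# Illustrative only: the step from 1.0 to the next float64 is machine epsilon.
import torch

one = torch.tensor(1.0, dtype=torch.float64)
two = torch.tensor(2.0, dtype=torch.float64)
print(torch.nextafter(one, two) - one)  # tensor(2.2204e-16) == float64 eps
```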
@@ -1783,6 +1843,24 @@ def transpose(x, axes=None):
     return x.T


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if standardize_dtype(y.dtype) == "bool":
+        y = cast(y, config.floatx())
+    if x is not None:
+        x = convert_to_tensor(x)
+        return torch.trapz(y, x=x, dim=axis)
+    else:
+        dx = convert_to_tensor(dx)
+        return torch.trapz(y, dx=dx, dim=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
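`trapezoid` integrates `y` with the trapezoidal rule, and `vander` builds a Vandermonde matrix (decreasing powers by default). Illustrative sketch of the wrapped torch calls, not part of the diff:

```python
# Illustrative only: trapezoidal integration and a Vandermonde matrix.
import torch

y = torch.tensor([1.0, 2.0, 3.0])
print(torch.trapz(y, dx=1.0))  # tensor(4.) == (1+2)/2 + (2+3)/2

x = torch.tensor([1, 2, 3])
print(torch.vander(x, N=3))
# tensor([[1, 1, 1],
#         [4, 2, 1],
#         [9, 3, 1]])
```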
keras/src/callbacks/__init__.py

@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
keras/src/callbacks/model_checkpoint.py

@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
                 self.model.save_weights(filepath, overwrite=True)
             else:
                 self.model.save(filepath, overwrite=True)
+            if self.verbose > 0:
+                io_utils.print_msg(
+                    f"\nEpoch {epoch + 1}: "
+                    f"finished saving model to {filepath}"
+                )
         except IsADirectoryError:  # h5py 3.x
             raise IOError(
                 "Please specify a non-directory filepath for "
keras/src/callbacks/orbax_checkpoint.py (new file)

@@ -0,0 +1,332 @@
+import warnings
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.callbacks.monitor_callback import (
+    MonitorCallback,  # For metric monitoring logic
+)
+from keras.src.utils.module_utils import ocp
+
+# Context and AsyncOptions are accessed through the lazy-loaded ocp module
+
+# JAX monitoring compatibility: ensure record_scalar exists
+# to prevent AttributeError in older JAX versions
+try:
+    import jax
+
+    if not hasattr(jax.monitoring, "record_scalar"):
+        jax.monitoring.record_scalar = lambda *args, **kwargs: None
+except ImportError:
+    pass
+
+
+def _get_state_tree(model):
+    """Get the complete model state as a nested tree structure."""
+    # For JAX backend, preserve native arrays for performance
+    # For other backends, convert to numpy arrays
+    if backend.backend() == "jax":
+        state_tree = model.get_state_tree()
+        did_numpy_conversion = False
+    else:
+        state_tree = model.get_state_tree(value_format="numpy_array")
+        did_numpy_conversion = True
+
+    # Convert numpy scalar types to Python types for Orbax compatibility
+    # Only needed when we did numpy conversion
+    if did_numpy_conversion:
+
+        def convert_scalars(obj):
+            if isinstance(obj, np.ndarray) and obj.ndim == 0:
+                # Convert 0-dimensional numpy arrays (scalars) to Python types
+                return obj.item()
+            elif isinstance(obj, np.generic):
+                # Convert numpy scalar types (like np.float32) to Python types
+                return obj.item()
+            else:
+                return obj
+
+        return tree.map_structure(convert_scalars, state_tree)
+    else:
+        return state_tree
+
+
+@keras_export("keras.callbacks.OrbaxCheckpoint")
+class OrbaxCheckpoint(MonitorCallback):
+    """Callback to save and load model state using Orbax with a similar API to
+    ModelCheckpoint.
+
+    This callback saves the model's weights and optimizer state asynchronously
+    using Orbax, allowing training to continue without blocking for I/O.
+
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
+    Example:
+
+    ```python
+    model.compile(loss=..., optimizer=..., metrics=['accuracy'])
+
+    EPOCHS = 10
+    checkpoint_dir = '/tmp/ckpt'
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True)
+
+    # Model is saved at the end of every epoch, if it's the best seen so far.
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # Alternatively, save checkpoints every N batches -
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_freq=100)  # Save every 100 batches
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+    ```
+
+    Args:
+        directory: path to the directory where to save the checkpoints.
+        monitor: The metric name to monitor (e.g., 'val_loss').
+        verbose: Verbosity mode, 0 or 1.
+        save_best_only: if `save_best_only=True`, it only saves when the model
+            is considered the "best" based on the monitored quantity.
+        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
+        save_freq: `'epoch'` or integer. Frequency to save checkpoints.
+        max_to_keep: Integer, maximum number of recent checkpoints to keep.
+            If None, keeps all. Defaults to 1.
+        save_on_background: Boolean, whether to save asynchronously in the
+            background. Defaults to True.
+        initial_value_threshold: Floating point initial "best" value for the
+            monitor, used with `save_best_only`.
+    """
+
+    def __init__(
+        self,
+        directory,
+        monitor="val_loss",
+        verbose=0,
+        save_best_only=False,
+        mode="auto",
+        save_freq="epoch",
+        initial_value_threshold=None,
+        max_to_keep=1,
+        save_on_background=True,
+    ):
+        # Ensure orbax is available
+        ocp.initialize()
+
+        # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
+        # logic
+        super().__init__(monitor, mode, initial_value_threshold)
+
+        self.directory = directory
+        self.verbose = verbose
+        self.save_best_only = save_best_only
+        self.save_freq = save_freq
+        self.max_to_keep = max_to_keep
+        self.save_on_background = save_on_background
+        self._batches_seen_since_last_saving = 0
+        self._last_batch_seen = 0
+        self._current_epoch = 0  # Keep track of epoch
+        self._total_batches_seen = 0  # Global batch counter for step tracking
+
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
+        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+            raise ValueError(
+                f"Unrecognized save_freq: {self.save_freq}. "
+                "Expected save_freq are 'epoch' or integer values"
+            )
+
+        # --- Orbax Checkpointer Setup (V1 API) ---
+        policies = []
+        if max_to_keep is not None:
+            policies.append(
+                ocp.training.preservation_policies.LatestN(max_to_keep)
+            )
+
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
+        preservation_policy = None
+        if policies:
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
+                )
+
+        # Create the V1 Checkpointer with direct parameter passing
+        # Orbax will handle directory creation on all processes as needed
+        self.checkpointer = ocp.training.Checkpointer(
+            directory=directory,
+            preservation_policy=preservation_policy,
+        )
+
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+                Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
+    def _should_save_on_batch(self, batch):
+        """Check if we should save on this batch."""
+        if self.save_freq == "epoch":
+            return False
+
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+        self._total_batches_seen += add_batches
+
+        if self._batches_seen_since_last_saving >= self.save_freq:
+            self._batches_seen_since_last_saving = 0
+            return True
+        return False
+
+    def _save_checkpoint(self, step, logs=None):
+        """Save a checkpoint at the given step with multi-host coordination."""
+
+        # --- Prepare Composite State (Backend-Agnostic) ---
+        state_tree = _get_state_tree(self.model)
+
+        # Save the nested state structures directly (preserving layer
+        # names and structure)
+        composite_state = state_tree
+
+        # Use a single with statement. If context_options is empty,
+        # Context() uses defaults.
+        with ocp.Context():
+            if self.save_on_background:
+                self.checkpointer.save_pytree_async(step, composite_state)
+            else:
+                self.checkpointer.save_pytree(step, composite_state)
+
+    def on_train_batch_end(self, batch, logs=None):
+        if self._should_save_on_batch(batch):
+            # Handle save_best_only logic for batch-level saving
+            should_save = True
+            if self.save_best_only:
+                current = logs.get(self.monitor) if logs else None
+                if current is None:
+                    warnings.warn(
+                        f"Can save best model only with {self.monitor} "
+                        f"available, skipping save at batch {batch}.",
+                        stacklevel=2,
+                    )
+                    should_save = False
+                elif not self._is_improvement(current, self.best):
+                    should_save = False
+                else:
+                    # Update best value when there's improvement
+                    self.best = current
+
+            if should_save:
+                # Use global batch count for Orbax save step
+                step = self._total_batches_seen
+                self._save_checkpoint(step=step, logs=logs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        self._current_epoch = epoch
+        if self.monitor_op is None:
+            self._set_monitor_op()  # From MonitorCallback
+
+        # For save_freq="epoch", save at every epoch
+        should_save = self.save_freq == "epoch"
+
+        # Handle save_best_only logic
+        if should_save and self.save_best_only:
+            current = logs.get(self.monitor) if logs else None
+            if current is None:
+                warnings.warn(
+                    f"Can save best model only with {self.monitor} available, "
+                    f"skipping save at epoch {epoch}.",
+                    stacklevel=2,
+                )
+                should_save = False
+            elif not self._is_improvement(current, self.best):
+                should_save = False
+            else:
+                # Update best value when there's improvement
+                self.best = current
+
+        if should_save:
+            # Use epoch number as the step for Orbax save
+            # Keras has already made the save decision - Checkpointer will
+            # save unconditionally
+            self._save_checkpoint(step=epoch, logs=logs)
+
+    def on_train_end(self, logs=None):
+        # Close the Checkpointer to ensure all pending saves complete
+        try:
+            self.checkpointer.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
+    def wait_until_finished(self):
+        """Wait for any in-progress checkpoint operations to complete.
+        This method blocks until all asynchronous checkpoint save operations
+        have completed across all hosts in a multi-host setup.
+        """
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
+
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
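A minimal usage sketch of the new callback (not part of the diff), exercising only constructor arguments documented above; it assumes a compatible `orbax-checkpoint` release is installed:

```python
import numpy as np
import keras

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Save every 10 batches, keeping only the 3 most recent checkpoints; saves
# run in the background (save_on_background=True is the default) and are
# drained when training ends.
cb = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",
    save_freq=10,
    max_to_keep=3,
)
model.fit(x, y, batch_size=8, epochs=3, callbacks=[cb])
```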
keras/src/callbacks/terminate_on_nan.py

@@ -7,14 +7,63 @@ from keras.src.utils import io_utils

 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
+
+    This callback monitors the loss value during training
+    and terminates training when a NaN or Inf loss is detected.
+    By default, training is stopped gracefully
+    by setting `model.stop_training = True`, which triggers all callback cleanup
+    methods including `on_train_end()`.
+
+    Alternatively, you can use `raise_error=True` to immediately raise a
+    RuntimeError when NaN/Inf is detected. This raise_error termination
+    prevents `on_train_end()` from being called on other callbacks, which
+    is useful for preserving backup states or preventing unintended cleanup
+    when training fails.
+
+    Args:
+        raise_error: Boolean, default False. If False, uses graceful stop via
+            `model.stop_training = True`. If True, immediately raises
+            RuntimeError on NaN/Inf loss, bypassing callback cleanup methods.
+
+    Example:
+
+    ```
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # raise_error termination (strict failure)
+    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, raise_error: bool = False):
+        super().__init__()
+        self.raise_error = raise_error

     def on_batch_end(self, batch, logs=None):
+        """Check for NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If loss is NaN/Inf and raise_error=True.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.raise_error:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
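A small self-contained sketch of the new `raise_error=True` path (not part of the diff); the NaN targets are a contrived way to force a NaN loss on the first batch:

```python
import numpy as np
import keras

x = np.random.rand(32, 4).astype("float32")
y = np.full((32, 1), np.nan, dtype="float32")  # guarantees a NaN loss

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

cb = keras.callbacks.TerminateOnNaN(raise_error=True)
try:
    model.fit(x, y, epochs=1, callbacks=[cb])
except RuntimeError as err:
    # e.g. "NaN or Inf loss encountered at batch 0. Loss value: nan. ..."
    print(f"Training aborted: {err}")
```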
keras/src/datasets/cifar10.py

@@ -59,6 +59,11 @@ def load_data():
     assert y_train.shape == (50000, 1)
     assert y_test.shape == (10000, 1)
     ```
+
+    **Note**: The CIFAR-10 dataset is known to have a small percentage of
+    mislabeled samples, which is inherent to the original dataset. This label
+    noise may impact training and evaluation. For more details, refer to
+    discussions in the research literature on CIFAR-10 label quality.
     """
     dirname = "cifar-10-batches-py-target"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
keras/src/distillation/__init__.py (new file)

@@ -0,0 +1 @@
+"""Distillation module for knowledge distillation in Keras."""