keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
@@ -313,18 +313,19 @@ def append(x1, x2, axis=None):
     return torch.cat((x1, x2), dim=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = to_torch_dtype(dtype)
     if stop is None:
-        return torch.arange(end=start, dtype=dtype, device=get_device())
+        start, stop = 0, start
+    if step is None:
+        step = 1
     return torch.arange(
         start, stop, step=step, dtype=dtype, device=get_device()
     )
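With this change, `step` no longer defaults to `1` during dtype resolution, so only the arguments the caller actually passes influence the result dtype. A rough illustration of the intended semantics (exact dtypes depend on the active backend and Keras dtype configuration):

```python
import keras

# step omitted: dtype is resolved from `start` alone
print(keras.ops.arange(5).dtype)  # an integer dtype, e.g. int32

# an explicit float step participates in dtype resolution
print(keras.ops.arange(0, 1, 0.25).dtype)  # a float dtype, e.g. float32
```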
@@ -410,6 +411,12 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    dtype = to_torch_dtype(dtype)
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
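The new `view` delegates to `torch.Tensor.view(dtype=...)`, which reinterprets the underlying bytes rather than casting values. A minimal sketch of that semantics in plain PyTorch:

```python
import torch

x = torch.tensor([1.0], dtype=torch.float32)
# Bitwise reinterpretation, not a value cast:
# 1.0f has the IEEE-754 bit pattern 0x3F800000 == 1065353216.
print(x.view(dtype=torch.int32))  # tensor([1065353216], dtype=torch.int32)
```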
@@ -763,6 +770,12 @@ def empty(shape, dtype=None):
     return torch.empty(size=shape, dtype=dtype, device=get_device())
 
 
+def empty_like(x, dtype=None):
+    x = convert_to_tensor(x)
+    dtype = to_torch_dtype(dtype or x.dtype)
+    return torch.empty_like(x, dtype=dtype, device=get_device())
+
+
 def equal(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.eq(x1, x2)
@@ -945,6 +958,11 @@ def isposinf(x):
     return torch.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return torch.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -957,6 +975,20 @@ def lcm(x1, x2):
     return torch.lcm(x1, x2)
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return cast(torch.ldexp(x1, x2), dtype)
+
+
 def less(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.less(x1, x2)
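For reference, `ldexp(x1, x2)` computes `x1 * 2 ** x2` with an integer exponent, matching the NumPy semantics the backend mirrors:

```python
import numpy as np

# ldexp(x1, x2) == x1 * 2.0 ** x2; the exponent must have an integer dtype
print(np.ldexp(np.float32(1.5), np.int32(3)))  # 12.0
```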
@@ -1053,6 +1085,15 @@ def logaddexp(x1, x2):
     return torch.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return torch.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.logical_and(x1, x2)
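`logaddexp2(a, b)` evaluates `log2(2**a + 2**b)` without overflowing the intermediate powers; the NumPy equivalent shows the identity:

```python
import numpy as np

a, b = 3.0, 4.0
# Same value, but logaddexp2 stays finite even for very large a or b
print(np.logaddexp2(a, b))            # 4.584962500721156
print(np.log2(2.0 ** a + 2.0 ** b))   # 4.584962500721156
```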
@@ -1518,6 +1559,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
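`array_split` maps onto `torch.tensor_split`, which, unlike `split`, accepts a section count that does not divide the axis length evenly:

```python
import torch

x = torch.arange(7)
# 7 elements into 3 sections: the leading sections receive the extra element
print([t.tolist() for t in torch.tensor_split(x, 3)])
# [[0, 1, 2], [3, 4], [5, 6]]
```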
@@ -1631,8 +1678,9 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype != "int64":
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16", "uint8"):
+        # Torch backend doesn't support uint32 dtype.
+        dtype = "int32"
     return torch.sum(
         torch.diagonal(x, offset, axis1, axis2),
         dim=-1,
@@ -1745,6 +1793,16 @@ def negative(x):
     return torch.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, torch.float64)
+    x2 = cast(x2, torch.float64)
+    return cast(torch.nextafter(x1, x2), dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
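`nextafter(x1, x2)` returns the next machine-representable float after `x1` in the direction of `x2`; the NumPy equivalent illustrates:

```python
import numpy as np

# The smallest representable step above 1.0 in float32 is 2**-23
print(np.nextafter(np.float32(1.0), np.float32(2.0)))  # 1.0000001
```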
@@ -1773,6 +1831,24 @@ def transpose(x, axes=None):
     return x.T
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if standardize_dtype(y.dtype) == "bool":
+        y = cast(y, config.floatx())
+    if x is not None:
+        x = convert_to_tensor(x)
+        return torch.trapz(y, x=x, dim=axis)
+    else:
+        dx = convert_to_tensor(dx)
+        return torch.trapz(y, dx=dx, dim=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
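`trapezoid` integrates sampled values with the trapezoidal rule (via `torch.trapz`), and `vander` builds a Vandermonde matrix. A quick check of the integration semantics in plain PyTorch:

```python
import torch

y = torch.tensor([0.0, 1.0, 2.0])
# Trapezoidal rule with unit spacing: (0 + 1) / 2 + (1 + 2) / 2 = 2.0
print(torch.trapz(y, dx=1.0))  # tensor(2.)
```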
@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
@@ -39,6 +39,7 @@ class CallbackList(Callback):
             via `Callback.set_params`.
         """
         self.callbacks = tree.flatten(callbacks) if callbacks else []
+        self._in_begin_end_block_count = 0
         self._executor = None
         self._async_train = False
         self._async_test = False
@@ -78,9 +79,6 @@ class CallbackList(Callback):
             if not utils.is_default(cbk.on_predict_batch_end):
                 async_predict = False
 
-        if async_train or async_test or async_predict:
-            self._executor = concurrent.futures.ThreadPoolExecutor()
-
         self._async_train = async_train
         self._async_test = async_test
         self._async_predict = async_predict
@@ -113,6 +111,33 @@ class CallbackList(Callback):
         for callback in self.callbacks:
             callback.set_model(model)
 
+    def _on_begin(self):
+        """Called by `on_train/test/predict_begin`.
+
+        Start the executor for async calls if needed.
+        """
+        self._in_begin_end_block_count += 1
+        if (
+            self._in_begin_end_block_count == 1
+            and (self._async_train or self._async_test or self._async_predict)
+            and self._executor is None
+        ):
+            self._executor = concurrent.futures.ThreadPoolExecutor()
+
+    def _on_end(self):
+        """Called by `on_train/test/predict_end`.
+
+        Shutdown the executor for async calls if all begin/end blocks completed.
+        """
+        self._in_begin_end_block_count -= 1
+        if self._in_begin_end_block_count < 0:
+            raise ValueError(
+                "`on_xxx_end` called without corresponding `on_xxx_begin`"
+            )
+        if self._in_begin_end_block_count == 0 and self._executor is not None:
+            self._executor.shutdown()
+            self._executor = None
+
     def _async_dispatch(self, fn, *args):
         for future in self._futures:
             if future.done():
@@ -121,7 +146,8 @@ class CallbackList(Callback):
         future = self._executor.submit(fn, *args)
         self._futures.append(future)
 
-    def _clear_futures(self):
+    def _flush_futures(self):
+        """Waits for all futures to complete and clears the list."""
         for future in self._futures:
             future.result()
         self._futures = []
@@ -138,7 +164,7 @@ class CallbackList(Callback):
 
     def on_epoch_end(self, epoch, logs=None):
         if self._async_train:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
@@ -204,44 +230,52 @@ class CallbackList(Callback):
             callback.on_predict_batch_end(batch, logs=logs)
 
     def on_train_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_begin(logs)
 
     def on_train_end(self, logs=None):
         if self._async_train:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_end(logs)
 
+        self._on_end()
+
     def on_test_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_begin(logs)
 
     def on_test_end(self, logs=None):
         if self._async_test:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_end(logs)
 
+        self._on_end()
+
     def on_predict_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_begin(logs)
 
     def on_predict_end(self, logs=None):
         if self._async_predict:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_end(logs)
 
-    def __del__(self):
-        if self._executor is not None:
-            self._executor.shutdown(cancel_futures=True)
+        self._on_end()
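The begin/end counter exists because these blocks can nest: `fit` with validation data runs an evaluate block inside the training block, so the executor must survive until the outermost block closes. A minimal lifecycle sketch using the private methods shown above (no callbacks attached; with no async hooks the executor simply stays `None`):

```python
from keras.src.callbacks.callback_list import CallbackList

cbks = CallbackList([])
cbks.on_train_begin()  # count 0 -> 1: executor created here if any async hooks
cbks.on_test_begin()   # count 1 -> 2: nested validation block, executor reused
cbks.on_test_end()     # count 2 -> 1: executor kept alive
cbks.on_train_end()    # count 1 -> 0: executor shut down and cleared
```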
@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
                 self.model.save_weights(filepath, overwrite=True)
             else:
                 self.model.save(filepath, overwrite=True)
+            if self.verbose > 0:
+                io_utils.print_msg(
+                    f"\nEpoch {epoch + 1}: "
+                    f"finished saving model to {filepath}"
+                )
         except IsADirectoryError:  # h5py 3.x
             raise IOError(
                 "Please specify a non-directory filepath for "
@@ -0,0 +1,299 @@
+import warnings
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.callbacks.monitor_callback import (
+    MonitorCallback,  # For metric monitoring logic
+)
+from keras.src.utils.io_utils import print_msg
+from keras.src.utils.module_utils import ocp
+
+# Context and AsyncOptions are accessed through the lazy-loaded ocp module
+
+# JAX monitoring compatibility: ensure record_scalar exists
+# to prevent AttributeError in older JAX versions
+try:
+    import jax
+
+    if not hasattr(jax.monitoring, "record_scalar"):
+        jax.monitoring.record_scalar = lambda *args, **kwargs: None
+except ImportError:
+    pass
+
+
+def _get_state_tree(model):
+    """Get the complete model state as a nested tree structure."""
+    # For the JAX backend, preserve native arrays for performance.
+    # For other backends, convert to numpy arrays.
+    if backend.backend() == "jax":
+        state_tree = model.get_state_tree()
+        did_numpy_conversion = False
+    else:
+        state_tree = model.get_state_tree(value_format="numpy_array")
+        did_numpy_conversion = True
+
+    # Convert numpy scalar types to Python types for Orbax compatibility.
+    # Only needed when we did numpy conversion.
+    if did_numpy_conversion:
+
+        def convert_scalars(obj):
+            if isinstance(obj, np.ndarray) and obj.ndim == 0:
+                # Convert 0-dimensional numpy arrays (scalars) to Python types
+                return obj.item()
+            elif isinstance(obj, np.generic):
+                # Convert numpy scalar types (like np.float32) to Python types
+                return obj.item()
+            else:
+                return obj
+
+        return tree.map_structure(convert_scalars, state_tree)
+    else:
+        return state_tree
+
+
+@keras_export("keras.callbacks.OrbaxCheckpoint")
+class OrbaxCheckpoint(MonitorCallback):
+    """Callback to save and load model state using Orbax, with an API
+    similar to `ModelCheckpoint`.
+
+    This callback saves the model's weights and optimizer state
+    asynchronously using Orbax, allowing training to continue without
+    blocking for I/O.
+
+    Example:
+
+    ```python
+    model.compile(loss=..., optimizer=..., metrics=['accuracy'])
+
+    EPOCHS = 10
+    checkpoint_dir = '/tmp/ckpt'
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True)
+
+    # Model is saved at the end of every epoch, if it's the best seen so far.
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # Alternatively, save checkpoints every N batches:
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_freq=100)  # Save every 100 batches
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+    ```
+
+    Args:
+        directory: Path to the directory where checkpoints are saved.
+        monitor: The metric name to monitor (e.g., 'val_loss').
+        verbose: Verbosity mode, 0 or 1.
+        save_best_only: If `save_best_only=True`, only saves when the model
+            is considered the "best" based on the monitored quantity.
+        save_weights_only: If `save_weights_only=True`, only the model's
+            weights will be saved. Otherwise, the full model state
+            (weights, non-trainable variables, optimizer state, and
+            metrics state) will be saved. Defaults to `False`.
+        mode: One of {'auto', 'min', 'max'}. Used with `save_best_only`.
+        save_freq: `'epoch'` or integer. Frequency to save checkpoints.
+        max_to_keep: Integer, maximum number of recent checkpoints to keep.
+            If `None`, keeps all. Defaults to 1.
+        save_on_background: Boolean, whether to save asynchronously in the
+            background. Defaults to `True`.
+        initial_value_threshold: Floating point initial "best" value for
+            the monitor, used with `save_best_only`.
+    """
+
+    def __init__(
+        self,
+        directory,
+        monitor="val_loss",
+        verbose=0,
+        save_best_only=False,
+        save_weights_only=False,
+        mode="auto",
+        save_freq="epoch",
+        initial_value_threshold=None,
+        max_to_keep=1,
+        save_on_background=True,
+    ):
+        # Ensure orbax is available
+        ocp.initialize()
+
+        # Initialize MonitorCallback for handling 'monitor', 'mode', and
+        # 'best' logic
+        super().__init__(monitor, mode, initial_value_threshold)
+
+        self.directory = directory
+        self.verbose = verbose
+        self.save_best_only = save_best_only
+        self.save_weights_only = save_weights_only
+        self.save_freq = save_freq
+        self.max_to_keep = max_to_keep
+        self.save_on_background = save_on_background
+        self._batches_seen_since_last_saving = 0
+        self._last_batch_seen = 0
+        self._current_epoch = 0  # Keep track of epoch
+        self._total_batches_seen = 0  # Global batch counter for step tracking
+
+        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+            raise ValueError(
+                f"Unrecognized save_freq: {self.save_freq}. "
+                "Expected save_freq to be 'epoch' or an integer value"
+            )
+
+        # --- Orbax Checkpointer Setup (V1 API) ---
+        policies = []
+        if max_to_keep is not None:
+            policies.append(
+                ocp.training.preservation_policies.LatestN(max_to_keep)
+            )
+
+        # Use AnyPreservationPolicy to combine them.
+        preservation_policy = None
+        if policies:
+            preservation_policy = (
+                ocp.training.preservation_policies.AnyPreservationPolicy(
+                    policies
+                )
+            )
+
+        # Create the V1 Checkpointer with direct parameter passing.
+        # Orbax will handle directory creation on all processes as needed.
+        self.checkpointer = ocp.training.Checkpointer(
+            directory=directory,
+            preservation_policy=preservation_policy,
+        )
+
+    def _should_save_on_batch(self, batch):
+        """Check if we should save on this batch."""
+        if self.save_freq == "epoch":
+            return False
+
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+        self._total_batches_seen += add_batches
+
+        if self._batches_seen_since_last_saving >= self.save_freq:
+            self._batches_seen_since_last_saving = 0
+            return True
+        return False
+
+    def _save_checkpoint(self, step, logs=None):
+        """Save a checkpoint at the given step."""
+        # --- Prepare Composite State (Backend-Agnostic) ---
+        state_tree = _get_state_tree(self.model)
+
+        # Save the nested state structures directly (preserving layer
+        # names and structure)
+        if self.save_weights_only:
+            composite_state = {
+                "trainable_variables": state_tree["trainable_variables"],
+            }
+            if "non_trainable_variables" in state_tree:
+                composite_state["non_trainable_variables"] = state_tree[
+                    "non_trainable_variables"
+                ]
+        else:
+            composite_state = state_tree
+
+        # --- Save Logic (V1 API) ---
+        # All processes participate in distributed checkpointing.
+        # The Checkpointer is configured to save unconditionally when
+        # save_pytree is called.
+        if self.verbose > 0:
+            print_msg(
+                f"OrbaxCheckpoint: Triggering async save for step {step}..."
+            )
+
+        # Use a single with statement. If context_options is empty,
+        # Context() uses defaults.
+        with ocp.Context():
+            if self.save_on_background:
+                self.checkpointer.save_pytree_async(step, composite_state)
+            else:
+                self.checkpointer.save_pytree(step, composite_state)
+
+    def on_train_batch_end(self, batch, logs=None):
+        if self._should_save_on_batch(batch):
+            # Handle save_best_only logic for batch-level saving
+            should_save = True
+            if self.save_best_only:
+                current = logs.get(self.monitor) if logs else None
+                if current is None:
+                    warnings.warn(
+                        f"Can save best model only with {self.monitor} "
+                        f"available, skipping save at batch {batch}.",
+                        stacklevel=2,
+                    )
+                    should_save = False
+                elif not self._is_improvement(current, self.best):
+                    should_save = False
+                else:
+                    # Update best value when there's improvement
+                    self.best = current
+
+            if should_save:
+                # Use the global batch count as the Orbax save step
+                step = self._total_batches_seen
+                self._save_checkpoint(step=step, logs=logs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        self._current_epoch = epoch
+        if self.monitor_op is None:
+            self._set_monitor_op()  # From MonitorCallback
+
+        # For save_freq="epoch", save at every epoch
+        should_save = self.save_freq == "epoch"
+
+        # Handle save_best_only logic
+        if should_save and self.save_best_only:
+            current = logs.get(self.monitor) if logs else None
+            if current is None:
+                warnings.warn(
+                    f"Can save best model only with {self.monitor} available, "
+                    f"skipping save at epoch {epoch}.",
+                    stacklevel=2,
+                )
+                should_save = False
+            elif not self._is_improvement(current, self.best):
+                should_save = False
+            else:
+                # Update best value when there's improvement
+                self.best = current
+
+        if should_save:
+            # Use the epoch number as the step for the Orbax save.
+            # Keras has already made the save decision - the Checkpointer
+            # will save unconditionally.
+            self._save_checkpoint(step=epoch, logs=logs)
+
+    def on_train_end(self, logs=None):
+        # Close the Checkpointer to ensure all pending saves complete
+        try:
+            self.checkpointer.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
+    def wait_until_finished(self):
+        """Wait for any in-progress checkpoint operations to complete.
+
+        This method blocks until all asynchronous checkpoint save operations
+        have completed. It should be called before attempting to load
+        checkpoints if there might be pending save operations.
+        """
+        # Wait for any async operations to complete
+        if hasattr(self.checkpointer, "wait"):
+            self.checkpointer.wait()
+        else:
+            # Fallback for older Orbax versions that don't have wait()
+            while self.checkpointer.is_saving_in_progress():
+                import time
+
+                time.sleep(0.1)
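Because saves are asynchronous by default, call `wait_until_finished()` before reading the checkpoint directory. A hedged end-to-end sketch (requires the `orbax-checkpoint` package; the toy model, data, and directory are placeholders):

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x, y = np.random.rand(32, 4), np.random.rand(32, 1)
ckpt_cb = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/ckpt", save_freq="epoch", max_to_keep=2
)
model.fit(x, y, epochs=2, callbacks=[ckpt_cb])
ckpt_cb.wait_until_finished()  # block until background saves hit disk
```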
@@ -7,14 +7,63 @@ from keras.src.utils import io_utils
 
 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
+
+    This callback monitors the loss value during training and terminates
+    training when a NaN or Inf loss is detected. By default, training is
+    stopped gracefully by setting `model.stop_training = True`, which
+    triggers all callback cleanup methods, including `on_train_end()`.
+
+    Alternatively, you can pass `raise_error=True` to raise a `RuntimeError`
+    as soon as a NaN/Inf loss is detected. This prevents `on_train_end()`
+    from being called on other callbacks, which is useful for preserving
+    backup states or avoiding unintended cleanup when training fails.
+
+    Args:
+        raise_error: Boolean, defaults to `False`. If `False`, stops
+            training gracefully via `model.stop_training = True`. If `True`,
+            immediately raises a `RuntimeError` on a NaN/Inf loss, bypassing
+            callback cleanup methods.
+
+    Example:
+
+    ```python
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # Strict termination: raise immediately on a NaN/Inf loss
+    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, raise_error: bool = False):
+        super().__init__()
+        self.raise_error = raise_error
 
     def on_batch_end(self, batch, logs=None):
+        """Check for NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of the batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If the loss is NaN/Inf and `raise_error=True`.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.raise_error:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training "
+                        "immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
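A hedged sketch of the strict mode: the `RuntimeError` propagates out of `fit`, so surrounding code can react while the `on_train_end()` cleanup of other callbacks is skipped. The absurd learning rate below is only a way to force a divergent loss for illustration; when and whether it diverges varies.

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
# Deliberately destabilize training so the loss goes to Inf/NaN.
model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e30), loss="mse")

x, y = np.random.rand(16, 2), np.random.rand(16, 1)
try:
    model.fit(
        x, y, epochs=5,
        callbacks=[keras.callbacks.TerminateOnNaN(raise_error=True)],
    )
except RuntimeError as err:
    print(err)  # training aborted; other callbacks' on_train_end() skipped
```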
@@ -59,6 +59,11 @@ def load_data():
     assert y_train.shape == (50000, 1)
     assert y_test.shape == (10000, 1)
     ```
+
+    **Note**: The CIFAR-10 dataset is known to have a small percentage of
+    mislabeled samples, which is inherent to the original dataset. This label
+    noise may impact training and evaluation. For more details, refer to
+    discussions in the research literature on CIFAR-10 label quality.
     """
     dirname = "cifar-10-batches-py-target"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"