keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from keras.src import backend
|
|
6
|
+
from keras.src import tree
|
|
7
|
+
from keras.src.api_export import keras_export
|
|
8
|
+
from keras.src.callbacks.monitor_callback import (
|
|
9
|
+
MonitorCallback, # For metric monitoring logic
|
|
10
|
+
)
|
|
11
|
+
from keras.src.utils.module_utils import ocp
|
|
12
|
+
|
|
13
|
+
# Context and AsyncOptions are accessed through the lazy-loaded ocp module
|
|
14
|
+
|
|
15
|
+
# JAX monitoring compatibility: ensure record_scalar exists
# to prevent AttributeError in older JAX versions.
# `jax.monitoring` is imported explicitly so the check does not depend on the
# submodule having been loaded as a side effect of `import jax`. If JAX (or
# its monitoring submodule) is unavailable, the shim is skipped instead of
# failing at import time.
try:
    import jax
    import jax.monitoring
except ImportError:
    pass
else:
    if not hasattr(jax.monitoring, "record_scalar"):
        # No-op stand-in that accepts and ignores any arguments.
        jax.monitoring.record_scalar = lambda *args, **kwargs: None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _get_state_tree(model):
    """Return the model's complete state as a nested structure for Orbax.

    On the JAX backend the tree from `model.get_state_tree()` is returned
    unchanged, keeping native arrays. On any other backend the state is
    requested with `value_format="numpy_array"`, and every leaf that is a
    0-d NumPy array or a NumPy scalar (e.g. `np.float32`) is unwrapped to a
    plain Python value via `.item()` for Orbax compatibility.
    """
    if backend.backend() == "jax":
        # Native JAX arrays are kept as-is; no scalar unwrapping is applied.
        return model.get_state_tree()

    numpy_state = model.get_state_tree(value_format="numpy_array")

    def _to_python_scalar(leaf):
        is_zero_dim_array = isinstance(leaf, np.ndarray) and leaf.ndim == 0
        if is_zero_dim_array or isinstance(leaf, np.generic):
            return leaf.item()
        return leaf

    return tree.map_structure(_to_python_scalar, numpy_state)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@keras_export("keras.callbacks.OrbaxCheckpoint")
class OrbaxCheckpoint(MonitorCallback):
    """Callback to save model state using Orbax, with an API similar to
    `ModelCheckpoint`.

    This callback saves the model's weights and optimizer state using Orbax.
    By default saves are issued asynchronously (`save_on_background=True`),
    allowing training to continue without blocking for I/O.

    **Multi-host Support**: When running in a multi-host distributed training
    environment with JAX backend, this callback automatically coordinates
    checkpointing across all hosts to ensure consistency and proper
    synchronization. Multi-host checkpointing is only supported on JAX.

    Example:

    ```python
    model.compile(loss=..., optimizer=..., metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_dir = '/tmp/ckpt'
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])

    # Alternatively, save checkpoints every N batches -
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq=100)  # Save every 100 batches

    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
    ```

    Args:
        directory: path to the directory where to save the checkpoints.
        monitor: The metric name to monitor (e.g., 'val_loss').
        verbose: Verbosity mode, 0 or 1. Currently stored on the callback
            but not used to emit any messages.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" based on the monitored quantity. This
            applies to both epoch-level and batch-level saving.
        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
        save_freq: `'epoch'` or integer. With `'epoch'`, a checkpoint is
            considered at the end of every epoch and stored under the epoch
            index as its step. With an integer `N`, a checkpoint is
            considered every `N` training batches and stored under the
            global batch count as its step.
        max_to_keep: Integer, maximum number of recent checkpoints to keep.
            If None, keeps all. Defaults to 1.
        save_on_background: Boolean, whether to save asynchronously in the
            background. Defaults to True.
        initial_value_threshold: Floating point initial "best" value for the
            monitor, used with `save_best_only`.

    Raises:
        ValueError: If `save_freq` is neither `'epoch'` nor an integer.
    """

    def __init__(
        self,
        directory,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        mode="auto",
        save_freq="epoch",
        initial_value_threshold=None,
        max_to_keep=1,
        save_on_background=True,
    ):
        # Ensure orbax is available (the module is lazily loaded).
        ocp.initialize()

        # MonitorCallback handles 'monitor', 'mode' and 'best' bookkeeping.
        super().__init__(monitor, mode, initial_value_threshold)

        self.directory = directory
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.save_freq = save_freq
        self.max_to_keep = max_to_keep
        self.save_on_background = save_on_background
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self._current_epoch = 0  # Keep track of epoch
        self._total_batches_seen = 0  # Global batch counter for step tracking

        # Multi-host support (resolved once, at construction time).
        self._multihost_initialized = self._is_multihost_initialized()

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq are 'epoch' or integer values"
            )

        # --- Orbax Checkpointer Setup (V1 API) ---
        # Only one retention policy is configurable today, so it is passed
        # directly; `None` means "keep every checkpoint".
        preservation_policy = None
        if max_to_keep is not None:
            preservation_policy = (
                ocp.training.preservation_policies.LatestN(max_to_keep)
            )

        # Orbax will handle directory creation on all processes as needed.
        self.checkpointer = ocp.training.Checkpointer(
            directory=directory,
            preservation_policy=preservation_policy,
        )

    def _is_multihost_initialized(self):
        """Check if multi-host environment is initialized.

        Always False on non-JAX backends, since multi-host checkpointing is
        only supported on JAX.
        """
        if backend.backend() != "jax":
            return False

        multihost = ocp.multihost
        # A JAX distributed client indicates a multi-host setup.
        return multihost.is_jax_distributed_client_initialized()

    def _sync_processes(self, key=None):
        """Synchronize all processes across hosts.

        No-op in single-host runs.

        Args:
            key: Optional barrier name; defaults to "orbax_checkpoint_sync".
        """
        if not self._multihost_initialized:
            return  # No-op for single host

        multihost = ocp.multihost
        sync_key = key or "orbax_checkpoint_sync"
        multihost.sync_global_processes(sync_key)

    def is_multihost_enabled(self):
        """Return True if multi-host checkpointing is enabled and initialized.

        This method can be used to check if the callback is operating in
        a multi-host distributed training environment. Multi-host checkpointing
        is only supported on JAX backend.

        Returns:
            bool: True if multi-host support is active, False otherwise.
        """
        return self._multihost_initialized

    def is_primary_host(self):
        """Return True if this process is the primary host in multi-host setup.

        In multi-host environments, only the primary host typically handles
        logging and coordination tasks. Multi-host checkpointing is only
        supported on JAX backend.

        Returns:
            bool: True if this is the primary host, False otherwise.
                Always returns True in single-host environments.
        """
        if not self._multihost_initialized:
            return True  # Single host is always primary
        multihost = ocp.multihost
        return multihost.is_primary_host()

    def _should_save_on_batch(self, batch):
        """Update batch counters and return True when a batch save is due.

        Args:
            batch: Index of the batch within the current epoch.

        Returns:
            False when `save_freq == "epoch"`; otherwise True once at least
            `save_freq` batches have been seen since the last save (the
            counter is then reset).
        """
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        self._total_batches_seen += add_batches

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _passes_save_best_only(self, logs, location):
        """Apply the `save_best_only` filter for one candidate save.

        Args:
            logs: Dict of metrics passed to the callback hook, or None.
            location: Human-readable position (e.g. "batch 3", "epoch 1")
                used in the warning emitted when the monitored value is
                missing.

        Returns:
            True if a checkpoint should be written. Always True when
            `save_best_only` is False. Otherwise True only when the
            monitored value is present and improves on `self.best`, in
            which case `self.best` is updated.
        """
        if not self.save_best_only:
            return True

        current = logs.get(self.monitor) if logs else None
        if current is None:
            warnings.warn(
                f"Can save best model only with {self.monitor} available, "
                f"skipping save at {location}.",
                # Point at the caller of the public hook, not this helper.
                stacklevel=3,
            )
            return False

        # `monitor_op` is otherwise only resolved in `on_epoch_end`, but
        # batch-level saves can happen before any epoch has finished.
        # Without this, `_is_improvement` would call `None` as soon as
        # `self.best` is set.
        if self.monitor_op is None:
            self._set_monitor_op()

        if not self._is_improvement(current, self.best):
            return False
        self.best = current
        return True

    def _save_checkpoint(self, step, logs=None):
        """Save the current model state as checkpoint `step`.

        Saves asynchronously when `save_on_background` is True, otherwise
        blocks until the save completes. No explicit cross-process sync is
        performed here.

        Args:
            step: Integer step under which Orbax stores the checkpoint.
            logs: Currently unused.
        """
        # Nested state is saved directly, preserving layer names/structure.
        state_tree = _get_state_tree(self.model)

        # Context() with no arguments uses Orbax's default options.
        with ocp.Context():
            if self.save_on_background:
                self.checkpointer.save_pytree_async(step, state_tree)
            else:
                self.checkpointer.save_pytree(step, state_tree)

    def on_train_batch_end(self, batch, logs=None):
        if not self._should_save_on_batch(batch):
            return
        if self._passes_save_best_only(logs, f"batch {batch}"):
            # The global batch count is the Orbax step, so steps keep
            # increasing across epochs.
            self._save_checkpoint(step=self._total_batches_seen, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self._current_epoch = epoch
        if self.monitor_op is None:
            self._set_monitor_op()  # From MonitorCallback

        # Epoch-level saving only applies when save_freq="epoch".
        if self.save_freq != "epoch":
            return
        if self._passes_save_best_only(logs, f"epoch {epoch}"):
            # Keras has already made the save decision; the Checkpointer
            # saves unconditionally under the epoch index as its step.
            self._save_checkpoint(step=epoch, logs=logs)

    def on_train_end(self, logs=None):
        # Close the Checkpointer to ensure all pending saves complete.
        try:
            self.checkpointer.close()
        except Exception as e:
            # Best-effort cleanup: do not crash the end of training, but do
            # not hide the failure either, since a pending save may be lost.
            warnings.warn(
                f"Error while closing the Orbax checkpointer: {e}",
                stacklevel=2,
            )

        # Multi-host synchronization: ensure all hosts complete cleanup
        self._sync_processes("checkpoint_cleanup")

    def wait_until_finished(self):
        """Wait for any in-progress checkpoint operations to complete.

        This method blocks until all asynchronous checkpoint save operations
        have completed across all hosts in a multi-host setup.
        """
        # Wait for any async operations to complete on this host
        self.checkpointer.wait()

        # Multi-host synchronization: ensure all hosts complete
        self._sync_processes("checkpoint_wait_complete")
|
|
@@ -7,14 +7,63 @@ from keras.src.utils import io_utils
|
|
|
7
7
|
|
|
8
8
|
@keras_export("keras.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN or Inf loss is seen.

    After every training batch, this callback reads `logs["loss"]`. If that
    value is NaN or infinite, training is ended in one of two ways:

    - `raise_error=False` (default): a message is printed and
      `model.stop_training` is set to `True`. Training stops gracefully, and
      all callback cleanup methods, including `on_train_end()`, still run.
    - `raise_error=True`: a `RuntimeError` is raised immediately. This
      prevents `on_train_end()` from being called on other callbacks, which
      is useful for preserving backup states or preventing unintended
      cleanup when training fails.

    Args:
        raise_error: Boolean, default False. Selects between the graceful
            stop and the immediate `RuntimeError` described above.

    Example:

    ```
    # Graceful termination (default)
    callback = keras.callbacks.TerminateOnNaN()
    model.fit(x, y, callbacks=[callback])

    # raise_error termination (strict failure)
    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
    model.fit(x, y, callbacks=[callback])
    ```
    """

    def __init__(self, raise_error: bool = False):
        super().__init__()
        self.raise_error = raise_error

    def on_batch_end(self, batch, logs=None):
        """Stop training if the batch loss is NaN or Inf.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step()`.

        Raises:
            RuntimeError: If loss is NaN/Inf and raise_error=True.
        """
        loss = (logs or {}).get("loss")
        if loss is None:
            return
        if not (np.isnan(loss) or np.isinf(loss)):
            return

        if self.raise_error:
            raise RuntimeError(
                f"NaN or Inf loss encountered at batch {batch}. "
                f"Loss value: {loss}. Terminating training immediately."
            )

        io_utils.print_msg(
            f"Batch {batch}: Invalid loss, terminating training"
        )
        self.model.stop_training = True
|
keras/src/datasets/cifar10.py
CHANGED
|
@@ -59,6 +59,11 @@ def load_data():
|
|
|
59
59
|
assert y_train.shape == (50000, 1)
|
|
60
60
|
assert y_test.shape == (10000, 1)
|
|
61
61
|
```
|
|
62
|
+
|
|
63
|
+
**Note**: The CIFAR-10 dataset is known to have a small percentage of
|
|
64
|
+
mislabeled samples, which is inherent to the original dataset. This label
|
|
65
|
+
noise may impact training and evaluation. For more details, refer to
|
|
66
|
+
discussions in the research literature on CIFAR-10 label quality.
|
|
62
67
|
"""
|
|
63
68
|
dirname = "cifar-10-batches-py-target"
|
|
64
69
|
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Distillation module for knowledge distillation in Keras."""
|