keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/torch/numpy.py
CHANGED
@@ -411,6 +411,12 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    dtype = to_torch_dtype(dtype)
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
@@ -764,6 +770,12 @@ def empty(shape, dtype=None):
     return torch.empty(size=shape, dtype=dtype, device=get_device())


+def empty_like(x, dtype=None):
+    x = convert_to_tensor(x)
+    dtype = to_torch_dtype(dtype or x.dtype)
+    return torch.empty_like(x, dtype=dtype, device=get_device())
+
+
 def equal(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.eq(x1, x2)
@@ -946,6 +958,11 @@ def isposinf(x):
     return torch.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return torch.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -958,6 +975,20 @@ def lcm(x1, x2):
     return torch.lcm(x1, x2)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return cast(torch.ldexp(x1, x2), dtype)
+
+
 def less(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.less(x1, x2)
@@ -1351,6 +1382,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return x


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    if axis is None:
+        return x.max() - x.min()
+    elif axis == ():
+        return torch.zeros_like(x)
+    else:
+        return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
+            x, dim=axis, keepdim=keepdims
+        )
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
@@ -1528,6 +1571,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)


+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
@@ -1641,8 +1690,9 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype
-
+    if dtype in ("bool", "int8", "int16", "uint8"):
+        # Torch backend doesn't support uint32 dtype.
+        dtype = "int32"
     return torch.sum(
         torch.diagonal(x, offset, axis1, axis2),
         dim=-1,
@@ -1755,6 +1805,16 @@ def negative(x):
     return torch.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, torch.float64)
+    x2 = cast(x2, torch.float64)
+    return cast(torch.nextafter(x1, x2), dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
@@ -1783,6 +1843,24 @@ def transpose(x, axes=None):
     return x.T


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if standardize_dtype(y.dtype) == "bool":
+        y = cast(y, config.floatx())
+    if x is not None:
+        x = convert_to_tensor(x)
+        return torch.trapz(y, x=x, dim=axis)
+    else:
+        dx = convert_to_tensor(dx)
+        return torch.trapz(y, dx=dx, dim=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
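The hunks above add torch-backend implementations for nine NumPy-style ops: `view`, `empty_like`, `isreal`, `ldexp`, `ptp`, `array_split`, `nextafter`, `trapezoid`, and `vander` — matching the nine new exports listed for `keras/ops/numpy/__init__.py` in the file list. A minimal sketch of a few of them, assuming the torch backend is active and that these ops are surfaced under `keras.ops` (the export assumption is inferred from the `__init__.py` diffs, not shown verbatim here):

```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import numpy as np

from keras import ops

x = np.array([[1.0, 4.0, 2.0], [3.0, 0.0, 5.0]])

# Peak-to-peak (max - min) along each row -> [3., 5.]
print(ops.ptp(x, axis=1))

# Unlike strict `split`, `array_split` tolerates uneven sections:
# 3 columns into 2 parts gives shapes (2, 2) and (2, 1).
print([s.shape for s in ops.array_split(x, 2, axis=1)])

# Trapezoidal integration along the last axis with uniform spacing dx.
print(ops.trapezoid(x, dx=0.5, axis=-1))
```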
keras/src/callbacks/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
keras/src/callbacks/model_checkpoint.py
CHANGED

@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
                 self.model.save_weights(filepath, overwrite=True)
             else:
                 self.model.save(filepath, overwrite=True)
+            if self.verbose > 0:
+                io_utils.print_msg(
+                    f"\nEpoch {epoch + 1}: "
+                    f"finished saving model to {filepath}"
+                )
         except IsADirectoryError:  # h5py 3.x
             raise IOError(
                 "Please specify a non-directory filepath for "
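The `ModelCheckpoint` hunk above only adds a confirmation message once a save completes. A quick sketch of how it surfaces (the filepath and data are illustrative):

```python
import numpy as np

import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x, y = np.random.rand(32, 4), np.random.rand(32, 1)

# With verbose=1, each save now also logs completion, e.g.
# "Epoch 1: finished saving model to /tmp/model.keras"
cb = keras.callbacks.ModelCheckpoint(filepath="/tmp/model.keras", verbose=1)
model.fit(x, y, epochs=2, callbacks=[cb], verbose=0)
```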
keras/src/callbacks/orbax_checkpoint.py
ADDED

@@ -0,0 +1,332 @@
+import warnings
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.callbacks.monitor_callback import (
+    MonitorCallback,  # For metric monitoring logic
+)
+from keras.src.utils.module_utils import ocp
+
+# Context and AsyncOptions are accessed through the lazy-loaded ocp module
+
+# JAX monitoring compatibility: ensure record_scalar exists
+# to prevent AttributeError in older JAX versions
+try:
+    import jax
+
+    if not hasattr(jax.monitoring, "record_scalar"):
+        jax.monitoring.record_scalar = lambda *args, **kwargs: None
+except ImportError:
+    pass
+
+
+def _get_state_tree(model):
+    """Get the complete model state as a nested tree structure."""
+    # For JAX backend, preserve native arrays for performance
+    # For other backends, convert to numpy arrays
+    if backend.backend() == "jax":
+        state_tree = model.get_state_tree()
+        did_numpy_conversion = False
+    else:
+        state_tree = model.get_state_tree(value_format="numpy_array")
+        did_numpy_conversion = True
+
+    # Convert numpy scalar types to Python types for Orbax compatibility
+    # Only needed when we did numpy conversion
+    if did_numpy_conversion:
+
+        def convert_scalars(obj):
+            if isinstance(obj, np.ndarray) and obj.ndim == 0:
+                # Convert 0-dimensional numpy arrays (scalars) to Python types
+                return obj.item()
+            elif isinstance(obj, np.generic):
+                # Convert numpy scalar types (like np.float32) to Python types
+                return obj.item()
+            else:
+                return obj
+
+        return tree.map_structure(convert_scalars, state_tree)
+    else:
+        return state_tree
+
+
+@keras_export("keras.callbacks.OrbaxCheckpoint")
+class OrbaxCheckpoint(MonitorCallback):
+    """Callback to save and load model state using Orbax with a similar API to
+    ModelCheckpoint.
+
+    This callback saves the model's weights and optimizer state asynchronously
+    using Orbax, allowing training to continue without blocking for I/O.
+
+    **Multi-host Support**: When running in a multi-host distributed training
+    environment with JAX backend, this callback automatically coordinates
+    checkpointing across all hosts to ensure consistency and proper
+    synchronization. Multi-host checkpointing is only supported on JAX.
+
+    Example:
+
+    ```python
+    model.compile(loss=..., optimizer=..., metrics=['accuracy'])
+
+    EPOCHS = 10
+    checkpoint_dir = '/tmp/ckpt'
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True)
+
+    # Model is saved at the end of every epoch, if it's the best seen so far.
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # Alternatively, save checkpoints every N batches -
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_freq=100)  # Save every 100 batches
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+    ```
+
+    Args:
+        directory: path to the directory where to save the checkpoints.
+        monitor: The metric name to monitor (e.g., 'val_loss').
+        verbose: Verbosity mode, 0 or 1.
+        save_best_only: if `save_best_only=True`, it only saves when the model
+            is considered the "best" based on the monitored quantity.
+        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
+        save_freq: `'epoch'` or integer. Frequency to save checkpoints.
+        max_to_keep: Integer, maximum number of recent checkpoints to keep.
+            If None, keeps all. Defaults to 1.
+        save_on_background: Boolean, whether to save asynchronously in the
+            background. Defaults to True.
+        initial_value_threshold: Floating point initial "best" value for the
+            monitor, used with `save_best_only`.
+    """
+
+    def __init__(
+        self,
+        directory,
+        monitor="val_loss",
+        verbose=0,
+        save_best_only=False,
+        mode="auto",
+        save_freq="epoch",
+        initial_value_threshold=None,
+        max_to_keep=1,
+        save_on_background=True,
+    ):
+        # Ensure orbax is available
+        ocp.initialize()
+
+        # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
+        # logic
+        super().__init__(monitor, mode, initial_value_threshold)
+
+        self.directory = directory
+        self.verbose = verbose
+        self.save_best_only = save_best_only
+        self.save_freq = save_freq
+        self.max_to_keep = max_to_keep
+        self.save_on_background = save_on_background
+        self._batches_seen_since_last_saving = 0
+        self._last_batch_seen = 0
+        self._current_epoch = 0  # Keep track of epoch
+        self._total_batches_seen = 0  # Global batch counter for step tracking
+
+        # Multi-host support
+        self._multihost_initialized = self._is_multihost_initialized()
+
+        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+            raise ValueError(
+                f"Unrecognized save_freq: {self.save_freq}. "
+                "Expected save_freq are 'epoch' or integer values"
+            )
+
+        # --- Orbax Checkpointer Setup (V1 API) ---
+        policies = []
+        if max_to_keep is not None:
+            policies.append(
+                ocp.training.preservation_policies.LatestN(max_to_keep)
+            )
+
+        # Use AnyPreservationPolicy to combine them, or use directly
+        # if single policy
+        preservation_policy = None
+        if policies:
+            if len(policies) == 1:
+                preservation_policy = policies[0]
+            else:
+                preservation_policy = (
+                    ocp.training.preservation_policies.AnyPreservationPolicy(
+                        policies
+                    )
+                )
+
+        # Create the V1 Checkpointer with direct parameter passing
+        # Orbax will handle directory creation on all processes as needed
+        self.checkpointer = ocp.training.Checkpointer(
+            directory=directory,
+            preservation_policy=preservation_policy,
+        )
+
+    def _is_multihost_initialized(self):
+        """Check if multi-host environment is initialized."""
+        # Multi-host checkpointing is only supported on JAX backend
+        if backend.backend() != "jax":
+            return False
+
+        multihost = ocp.multihost
+        # Check if JAX distributed client is initialized
+        # (indicates multihost setup)
+        return multihost.is_jax_distributed_client_initialized()
+
+    def _sync_processes(self, key=None):
+        """Synchronize all processes across hosts."""
+        if not self._multihost_initialized:
+            return  # No-op for single host
+
+        multihost = ocp.multihost
+        sync_key = key or "orbax_checkpoint_sync"
+        multihost.sync_global_processes(sync_key)
+
+    def is_multihost_enabled(self):
+        """Return True if multi-host checkpointing is enabled and initialized.
+
+        This method can be used to check if the callback is operating in
+        a multi-host distributed training environment. Multi-host checkpointing
+        is only supported on JAX backend.
+
+        Returns:
+            bool: True if multi-host support is active, False otherwise.
+        """
+        return self._multihost_initialized
+
+    def is_primary_host(self):
+        """Return True if this process is the primary host in multi-host setup.
+
+        In multi-host environments, only the primary host typically handles
+        logging and coordination tasks. Multi-host checkpointing is only
+        supported on JAX backend.
+
+        Returns:
+            bool: True if this is the primary host, False otherwise.
+                Always returns True in single-host environments.
+        """
+        if not self._multihost_initialized:
+            return True  # Single host is always primary
+        multihost = ocp.multihost
+        return multihost.is_primary_host()
+
+    def _should_save_on_batch(self, batch):
+        """Check if we should save on this batch."""
+        if self.save_freq == "epoch":
+            return False
+
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+        self._total_batches_seen += add_batches
+
+        if self._batches_seen_since_last_saving >= self.save_freq:
+            self._batches_seen_since_last_saving = 0
+            return True
+        return False
+
+    def _save_checkpoint(self, step, logs=None):
+        """Save a checkpoint at the given step with multi-host coordination."""
+
+        # --- Prepare Composite State (Backend-Agnostic) ---
+        state_tree = _get_state_tree(self.model)
+
+        # Save the nested state structures directly (preserving layer
+        # names and structure)
+        composite_state = state_tree
+
+        # Use a single with statement. If context_options is empty,
+        # Context() uses defaults.
+        with ocp.Context():
+            if self.save_on_background:
+                self.checkpointer.save_pytree_async(step, composite_state)
+            else:
+                self.checkpointer.save_pytree(step, composite_state)
+
+    def on_train_batch_end(self, batch, logs=None):
+        if self._should_save_on_batch(batch):
+            # Handle save_best_only logic for batch-level saving
+            should_save = True
+            if self.save_best_only:
+                current = logs.get(self.monitor) if logs else None
+                if current is None:
+                    warnings.warn(
+                        f"Can save best model only with {self.monitor} "
+                        f"available, skipping save at batch {batch}.",
+                        stacklevel=2,
+                    )
+                    should_save = False
+                elif not self._is_improvement(current, self.best):
+                    should_save = False
+                else:
+                    # Update best value when there's improvement
+                    self.best = current
+
+            if should_save:
+                # Use global batch count for Orbax save step
+                step = self._total_batches_seen
+                self._save_checkpoint(step=step, logs=logs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        self._current_epoch = epoch
+        if self.monitor_op is None:
+            self._set_monitor_op()  # From MonitorCallback
+
+        # For save_freq="epoch", save at every epoch
+        should_save = self.save_freq == "epoch"
+
+        # Handle save_best_only logic
+        if should_save and self.save_best_only:
+            current = logs.get(self.monitor) if logs else None
+            if current is None:
+                warnings.warn(
+                    f"Can save best model only with {self.monitor} available, "
+                    f"skipping save at epoch {epoch}.",
+                    stacklevel=2,
+                )
+                should_save = False
+            elif not self._is_improvement(current, self.best):
+                should_save = False
+            else:
+                # Update best value when there's improvement
+                self.best = current
+
+        if should_save:
+            # Use epoch number as the step for Orbax save
+            # Keras has already made the save decision - Checkpointer will
+            # save unconditionally
+            self._save_checkpoint(step=epoch, logs=logs)
+
+    def on_train_end(self, logs=None):
+        # Close the Checkpointer to ensure all pending saves complete
+        try:
+            self.checkpointer.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
+        # Multi-host synchronization: ensure all hosts complete cleanup
+        self._sync_processes("checkpoint_cleanup")
+
+    def wait_until_finished(self):
+        """Wait for any in-progress checkpoint operations to complete.
+        This method blocks until all asynchronous checkpoint save operations
+        have completed across all hosts in a multi-host setup.
+        """
+        # Wait for any async operations to complete on this host
+        self.checkpointer.wait()
+
+        # Multi-host synchronization: ensure all hosts complete
+        self._sync_processes("checkpoint_wait_complete")
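Because saves default to `save_on_background=True`, checkpoints may still be in flight while training continues; `wait_until_finished()` is the flush point. A minimal sketch, assuming the `orbax-checkpoint` package is installed (the lazy-loaded `ocp` module requires it) — the method names come from the class above:

```python
import keras

ckpt = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",
    save_freq=100,            # checkpoint every 100 batches
    save_on_background=True,  # saves go through save_pytree_async
)
# model.fit(x, y, epochs=2, callbacks=[ckpt])

# Block until pending async saves have committed; on the JAX backend
# this also synchronizes across hosts via _sync_processes().
ckpt.wait_until_finished()
if ckpt.is_primary_host():
    print("All checkpoints flushed.")
```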
keras/src/callbacks/terminate_on_nan.py
CHANGED

@@ -7,14 +7,63 @@ from keras.src.utils import io_utils

 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
+
+    This callback monitors the loss value during training
+    and terminates training when a NaN or Inf loss is detected.
+    By default, training is stopped gracefully
+    by setting `model.stop_training = True`, which triggers all callback cleanup
+    methods including `on_train_end()`.
+
+    Alternatively, you can use `raise_error=True` to immediately raise a
+    RuntimeError when NaN/Inf is detected. This raise_error termination
+    prevents `on_train_end()` from being called on other callbacks, which
+    is useful for preserving backup states or preventing unintended cleanup
+    when training fails.
+
+    Args:
+        raise_error: Boolean, default False. If False, uses graceful stop via
+            `model.stop_training = True`. If True, immediately raises
+            RuntimeError on NaN/Inf loss, bypassing callback cleanup methods.
+
+    Example:
+
+    ```
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # raise_error termination (strict failure)
+    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, raise_error: bool = False):
+        super().__init__()
+        self.raise_error = raise_error

     def on_batch_end(self, batch, logs=None):
+        """Check for NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If loss is NaN/Inf and raise_error=True.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.raise_error:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
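With `raise_error=True`, the failure propagates before `on_train_end()` reaches other callbacks. A small sketch that forces divergence on purpose (the oversized learning rate is just a quick way to overflow the loss):

```python
import numpy as np

import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
# An absurd learning rate makes the loss hit Inf after the first update.
model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e30), loss="mse")
x, y = np.random.rand(64, 4), np.random.rand(64, 1)

try:
    model.fit(
        x,
        y,
        batch_size=8,
        epochs=1,
        callbacks=[keras.callbacks.TerminateOnNaN(raise_error=True)],
        verbose=0,
    )
except RuntimeError as err:
    # Raised before on_train_end() fires on other callbacks, so their
    # state at the failing batch is preserved.
    print(f"Training aborted: {err}")
```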
keras/src/datasets/cifar10.py
CHANGED
@@ -59,6 +59,11 @@ def load_data():
     assert y_train.shape == (50000, 1)
     assert y_test.shape == (10000, 1)
     ```
+
+    **Note**: The CIFAR-10 dataset is known to have a small percentage of
+    mislabeled samples, which is inherent to the original dataset. This label
+    noise may impact training and evaluation. For more details, refer to
+    discussions in the research literature on CIFAR-10 label quality.
     """
     dirname = "cifar-10-batches-py-target"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
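The note added above concerns per-sample label noise, not class balance; the published split is still exactly 5000 training images per class. A quick way to confirm the balance (the first call downloads the ~170 MB archive):

```python
import numpy as np

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Balanced counts; individual mislabeled images cannot be detected this way.
print(np.bincount(y_train.ravel()))  # [5000 5000 ... 5000]
```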
keras/src/distillation/__init__.py
ADDED

@@ -0,0 +1 @@
+"""Distillation module for knowledge distillation in Keras."""