keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/torch/numpy.py
CHANGED
@@ -313,18 +313,19 @@ def append(x1, x2, axis=None):
     return torch.cat((x1, x2), dim=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = to_torch_dtype(dtype)
     if stop is None:
-        return torch.arange(end=start, dtype=dtype, device=get_device())
+        start, stop = 0, start
+    if step is None:
+        step = 1
     return torch.arange(
         start, stop, step=step, dtype=dtype, device=get_device()
     )
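The net effect of this `arange` change is easier to see with a small example. A hedged sketch using the public `keras.ops` API (the exact promoted dtype still depends on the backend's `dtypes.result_type` rules):

```python
from keras import ops

# Only arguments the caller actually passed now take part in dtype inference;
# the implicit step no longer injects a Python int into the resolution list.
ops.arange(5)           # dtype inferred from `start` alone
ops.arange(0, 1, 0.25)  # an explicit float step still promotes to a float dtype
```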
@@ -410,6 +411,12 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    dtype = to_torch_dtype(dtype)
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
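The new `view` helper is a bitcast: it reinterprets the tensor's underlying bytes under another dtype of the same item size, mirroring `numpy.ndarray.view`. A minimal sketch with plain PyTorch (shown directly on `torch` rather than assuming a public Keras op name):

```python
import torch

x = torch.tensor([1.0, 2.0], dtype=torch.float32)
bits = x.view(torch.int32)  # same memory, read back as int32 bit patterns
```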
@@ -763,6 +770,12 @@ def empty(shape, dtype=None):
     return torch.empty(size=shape, dtype=dtype, device=get_device())
 
 
+def empty_like(x, dtype=None):
+    x = convert_to_tensor(x)
+    dtype = to_torch_dtype(dtype or x.dtype)
+    return torch.empty_like(x, dtype=dtype, device=get_device())
+
+
 def equal(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.eq(x1, x2)
@@ -945,6 +958,11 @@ def isposinf(x):
     return torch.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return torch.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -957,6 +975,20 @@ def lcm(x1, x2):
     return torch.lcm(x1, x2)
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return cast(torch.ldexp(x1, x2), dtype)
+
+
 def less(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.less(x1, x2)
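`ldexp(x1, x2)` computes `x1 * 2**x2`; the backend helper above additionally requires an integer exponent dtype and raises `TypeError` otherwise. A small sketch of the underlying torch call:

```python
import torch

x1 = torch.tensor([1.0, 2.0, 3.0])
x2 = torch.tensor([1, 2, 3])
print(torch.ldexp(x1, x2))  # tensor([ 2.,  8., 24.]) == x1 * 2**x2
```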
@@ -1053,6 +1085,15 @@ def logaddexp(x1, x2):
     return torch.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return torch.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.logical_and(x1, x2)
@@ -1518,6 +1559,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
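Unlike `split`, the new `array_split` follows NumPy semantics and accepts a section count that does not divide the axis evenly, which is why it is implemented with `torch.tensor_split`. Illustrative sketch:

```python
import torch

x = torch.arange(7)
parts = torch.tensor_split(x, 3)  # three chunks of sizes 3, 2 and 2
```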
@@ -1631,8 +1678,9 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype
-
+    if dtype in ("bool", "int8", "int16", "uint8"):
+        # Torch backend doesn't support uint32 dtype.
+        dtype = "int32"
     return torch.sum(
         torch.diagonal(x, offset, axis1, axis2),
         dim=-1,
@@ -1745,6 +1793,16 @@ def negative(x):
     return torch.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, torch.float64)
+    x2 = cast(x2, torch.float64)
+    return cast(torch.nextafter(x1, x2), dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
@@ -1773,6 +1831,24 @@ def transpose(x, axes=None):
     return x.T
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if standardize_dtype(y.dtype) == "bool":
+        y = cast(y, config.floatx())
+    if x is not None:
+        x = convert_to_tensor(x)
+        return torch.trapz(y, x=x, dim=axis)
+    else:
+        dx = convert_to_tensor(dx)
+        return torch.trapz(y, dx=dx, dim=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
keras/src/callbacks/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
keras/src/callbacks/callback_list.py
CHANGED
@@ -39,6 +39,7 @@ class CallbackList(Callback):
                 via `Callback.set_params`.
         """
         self.callbacks = tree.flatten(callbacks) if callbacks else []
+        self._in_begin_end_block_count = 0
         self._executor = None
         self._async_train = False
         self._async_test = False
@@ -78,9 +79,6 @@ class CallbackList(Callback):
             if not utils.is_default(cbk.on_predict_batch_end):
                 async_predict = False
 
-        if async_train or async_test or async_predict:
-            self._executor = concurrent.futures.ThreadPoolExecutor()
-
         self._async_train = async_train
         self._async_test = async_test
         self._async_predict = async_predict
@@ -113,6 +111,33 @@ class CallbackList(Callback):
         for callback in self.callbacks:
             callback.set_model(model)
 
+    def _on_begin(self):
+        """Called by `on_train/test/predict_begin`.
+
+        Start the executor for async calls if needed.
+        """
+        self._in_begin_end_block_count += 1
+        if (
+            self._in_begin_end_block_count == 1
+            and (self._async_train or self._async_test or self._async_predict)
+            and self._executor is None
+        ):
+            self._executor = concurrent.futures.ThreadPoolExecutor()
+
+    def _on_end(self):
+        """Called by `on_train/test/predict_end`.
+
+        Shutdown the executor for async calls if all begin/end blocks completed.
+        """
+        self._in_begin_end_block_count -= 1
+        if self._in_begin_end_block_count < 0:
+            raise ValueError(
+                "`on_xxx_end` called without corresponding `on_xxx_begin`"
+            )
+        if self._in_begin_end_block_count == 0 and self._executor is not None:
+            self._executor.shutdown()
+            self._executor = None
+
     def _async_dispatch(self, fn, *args):
         for future in self._futures:
             if future.done():
@@ -121,7 +146,8 @@ class CallbackList(Callback):
         future = self._executor.submit(fn, *args)
         self._futures.append(future)
 
-    def
+    def _flush_futures(self):
+        """Waits for all futures to complete and clears the list."""
         for future in self._futures:
             future.result()
         self._futures = []
@@ -138,7 +164,7 @@ class CallbackList(Callback):
 
     def on_epoch_end(self, epoch, logs=None):
         if self._async_train:
-            self.
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
@@ -204,44 +230,52 @@ class CallbackList(Callback):
             callback.on_predict_batch_end(batch, logs=logs)
 
     def on_train_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_begin(logs)
 
     def on_train_end(self, logs=None):
         if self._async_train:
-            self.
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_end(logs)
 
+        self._on_end()
+
     def on_test_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_begin(logs)
 
     def on_test_end(self, logs=None):
         if self._async_test:
-            self.
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_end(logs)
 
+        self._on_end()
+
     def on_predict_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_begin(logs)
 
     def on_predict_end(self, logs=None):
         if self._async_predict:
-            self.
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_end(logs)
 
-
-        if self._executor is not None:
-            self._executor.shutdown(cancel_futures=True)
+        self._on_end()
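The `CallbackList` changes above replace eager creation of the `ThreadPoolExecutor` with a ref-counted begin/end pair, so a nested block (for example, evaluation running inside `fit` for validation) can no longer shut down the executor the outer training loop still needs. A minimal, self-contained sketch of the same pattern (illustrative only, not the Keras class):

```python
import concurrent.futures


class RefCountedExecutor:
    def __init__(self):
        self._count = 0
        self._executor = None

    def begin(self):
        # Create the executor only when the outermost block starts.
        self._count += 1
        if self._count == 1 and self._executor is None:
            self._executor = concurrent.futures.ThreadPoolExecutor()

    def end(self):
        # Shut down only when the outermost matching end() runs.
        self._count -= 1
        if self._count < 0:
            raise ValueError("end() called without a matching begin()")
        if self._count == 0 and self._executor is not None:
            self._executor.shutdown()
            self._executor = None
```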
keras/src/callbacks/model_checkpoint.py
CHANGED
@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
                     self.model.save_weights(filepath, overwrite=True)
                 else:
                     self.model.save(filepath, overwrite=True)
+                if self.verbose > 0:
+                    io_utils.print_msg(
+                        f"\nEpoch {epoch + 1}: "
+                        f"finished saving model to {filepath}"
+                    )
             except IsADirectoryError:  # h5py 3.x
                 raise IOError(
                     "Please specify a non-directory filepath for "
keras/src/callbacks/orbax_checkpoint.py
ADDED
@@ -0,0 +1,299 @@
+import warnings
+
+import numpy as np
+
+from keras.src import backend
+from keras.src import tree
+from keras.src.api_export import keras_export
+from keras.src.callbacks.monitor_callback import (
+    MonitorCallback,  # For metric monitoring logic
+)
+from keras.src.utils.io_utils import print_msg
+from keras.src.utils.module_utils import ocp
+
+# Context and AsyncOptions are accessed through the lazy-loaded ocp module
+
+# JAX monitoring compatibility: ensure record_scalar exists
+# to prevent AttributeError in older JAX versions
+try:
+    import jax
+
+    if not hasattr(jax.monitoring, "record_scalar"):
+        jax.monitoring.record_scalar = lambda *args, **kwargs: None
+except ImportError:
+    pass
+
+
+def _get_state_tree(model):
+    """Get the complete model state as a nested tree structure."""
+    # For JAX backend, preserve native arrays for performance
+    # For other backends, convert to numpy arrays
+    if backend.backend() == "jax":
+        state_tree = model.get_state_tree()
+        did_numpy_conversion = False
+    else:
+        state_tree = model.get_state_tree(value_format="numpy_array")
+        did_numpy_conversion = True
+
+    # Convert numpy scalar types to Python types for Orbax compatibility
+    # Only needed when we did numpy conversion
+    if did_numpy_conversion:
+
+        def convert_scalars(obj):
+            if isinstance(obj, np.ndarray) and obj.ndim == 0:
+                # Convert 0-dimensional numpy arrays (scalars) to Python types
+                return obj.item()
+            elif isinstance(obj, np.generic):
+                # Convert numpy scalar types (like np.float32) to Python types
+                return obj.item()
+            else:
+                return obj
+
+        return tree.map_structure(convert_scalars, state_tree)
+    else:
+        return state_tree
+
+
+@keras_export("keras.callbacks.OrbaxCheckpoint")
+class OrbaxCheckpoint(MonitorCallback):
+    """Callback to save and load model state using Orbax with a similar API to
+    ModelCheckpoint.
+
+    This callback saves the model's weights and optimizer state asynchronously
+    using Orbax, allowing training to continue without blocking for I/O.
+
+    Example:
+
+    ```python
+    model.compile(loss=..., optimizer=..., metrics=['accuracy'])
+
+    EPOCHS = 10
+    checkpoint_dir = '/tmp/ckpt'
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True)
+
+    # Model is saved at the end of every epoch, if it's the best seen so far.
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+
+    # Alternatively, save checkpoints every N batches -
+    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
+        directory=checkpoint_dir,
+        save_freq=100)  # Save every 100 batches
+
+    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
+    ```
+
+    Args:
+        directory: path to the directory where to save the checkpoints.
+        monitor: The metric name to monitor (e.g., 'val_loss').
+        verbose: Verbosity mode, 0 or 1.
+        save_best_only: if `save_best_only=True`, it only saves when the model
+            is considered the "best" based on the monitored quantity.
+        save_weights_only: if `save_weights_only=True`, only the model's
+            weights will be saved. Otherwise, the full model state
+            (weights, non-trainable variables, optimizer state, and
+            metrics state) will be saved. Defaults to False.
+        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
+        save_freq: `'epoch'` or integer. Frequency to save checkpoints.
+        max_to_keep: Integer, maximum number of recent checkpoints to keep.
+            If None, keeps all. Defaults to 1.
+        save_on_background: Boolean, whether to save asynchronously in the
+            background. Defaults to True.
+        initial_value_threshold: Floating point initial "best" value for the
+            monitor, used with `save_best_only`.
+    """
+
+    def __init__(
+        self,
+        directory,
+        monitor="val_loss",
+        verbose=0,
+        save_best_only=False,
+        save_weights_only=False,
+        mode="auto",
+        save_freq="epoch",
+        initial_value_threshold=None,
+        max_to_keep=1,
+        save_on_background=True,
+    ):
+        # Ensure orbax is available
+        ocp.initialize()
+
+        # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
+        # logic
+        super().__init__(monitor, mode, initial_value_threshold)
+
+        self.directory = directory
+        self.verbose = verbose
+        self.save_best_only = save_best_only
+        self.save_weights_only = save_weights_only
+        self.save_freq = save_freq
+        self.max_to_keep = max_to_keep
+        self.save_on_background = save_on_background
+        self._batches_seen_since_last_saving = 0
+        self._last_batch_seen = 0
+        self._current_epoch = 0  # Keep track of epoch
+        self._total_batches_seen = 0  # Global batch counter for step tracking
+
+        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
+            raise ValueError(
+                f"Unrecognized save_freq: {self.save_freq}. "
+                "Expected save_freq are 'epoch' or integer values"
+            )
+
+        # --- Orbax Checkpointer Setup (V1 API) ---
+        policies = []
+        if max_to_keep is not None:
+            policies.append(
+                ocp.training.preservation_policies.LatestN(max_to_keep)
+            )
+
+        # Use AnyPreservationPolicy to combine them.
+        preservation_policy = None
+        if policies:
+            preservation_policy = (
+                ocp.training.preservation_policies.AnyPreservationPolicy(
+                    policies
+                )
+            )
+
+        # Create the V1 Checkpointer with direct parameter passing
+        # Orbax will handle directory creation on all processes as needed
+        self.checkpointer = ocp.training.Checkpointer(
+            directory=directory,
+            preservation_policy=preservation_policy,
+        )
+
+    def _should_save_on_batch(self, batch):
+        """Check if we should save on this batch."""
+        if self.save_freq == "epoch":
+            return False
+
+        if batch <= self._last_batch_seen:  # New epoch.
+            add_batches = batch + 1
+        else:
+            add_batches = batch - self._last_batch_seen
+        self._batches_seen_since_last_saving += add_batches
+        self._last_batch_seen = batch
+        self._total_batches_seen += add_batches
+
+        if self._batches_seen_since_last_saving >= self.save_freq:
+            self._batches_seen_since_last_saving = 0
+            return True
+        return False
+
+    def _save_checkpoint(self, step, logs=None):
+        """Save a checkpoint at the given step."""
+
+        # --- Prepare Composite State (Backend-Agnostic) ---
+        state_tree = _get_state_tree(self.model)
+
+        # Save the nested state structures directly (preserving layer
+        # names and structure)
+        if self.save_weights_only:
+            composite_state = {
+                "trainable_variables": state_tree["trainable_variables"],
+            }
+            if "non_trainable_variables" in state_tree:
+                composite_state["non_trainable_variables"] = state_tree[
+                    "non_trainable_variables"
+                ]
+        else:
+            composite_state = state_tree
+
+        # --- Save Logic (V1 API) ---
+        # All processes participate in distributed checkpointing
+        # Checkpointer is configured to save unconditionally when
+        # save_pytree is called
+        if self.verbose > 0:
+            print_msg(
+                f"OrbaxCheckpoint: Triggering async save for step {step}..."
+            )
+
+        # Use a single with statement. If context_options is empty,
+        # Context() uses defaults.
+        with ocp.Context():
+            if self.save_on_background:
+                self.checkpointer.save_pytree_async(step, composite_state)
+            else:
+                self.checkpointer.save_pytree(step, composite_state)
+
+    def on_train_batch_end(self, batch, logs=None):
+        if self._should_save_on_batch(batch):
+            # Handle save_best_only logic for batch-level saving
+            should_save = True
+            if self.save_best_only:
+                current = logs.get(self.monitor) if logs else None
+                if current is None:
+                    warnings.warn(
+                        f"Can save best model only with {self.monitor} "
+                        f"available, skipping save at batch {batch}.",
+                        stacklevel=2,
+                    )
+                    should_save = False
+                elif not self._is_improvement(current, self.best):
+                    should_save = False
+                else:
+                    # Update best value when there's improvement
+                    self.best = current
+
+            if should_save:
+                # Use global batch count for Orbax save step
+                step = self._total_batches_seen
+                self._save_checkpoint(step=step, logs=logs)
+
+    def on_epoch_end(self, epoch, logs=None):
+        self._current_epoch = epoch
+        if self.monitor_op is None:
+            self._set_monitor_op()  # From MonitorCallback
+
+        # For save_freq="epoch", save at every epoch
+        should_save = self.save_freq == "epoch"
+
+        # Handle save_best_only logic
+        if should_save and self.save_best_only:
+            current = logs.get(self.monitor) if logs else None
+            if current is None:
+                warnings.warn(
+                    f"Can save best model only with {self.monitor} available, "
+                    f"skipping save at epoch {epoch}.",
+                    stacklevel=2,
+                )
+                should_save = False
+            elif not self._is_improvement(current, self.best):
+                should_save = False
+            else:
+                # Update best value when there's improvement
+                self.best = current
+
+        if should_save:
+            # Use epoch number as the step for Orbax save
+            # Keras has already made the save decision - Checkpointer will
+            # save unconditionally
+            self._save_checkpoint(step=epoch, logs=logs)
+
+    def on_train_end(self, logs=None):
+        # Close the Checkpointer to ensure all pending saves complete
+        try:
+            self.checkpointer.close()
+        except Exception:
+            pass  # Ignore errors during cleanup
+
+    def wait_until_finished(self):
+        """Wait for any in-progress checkpoint operations to complete.
+        This method blocks until all asynchronous checkpoint save operations
+        have completed. It should be called before attempting to load
+        checkpoints if there might be pending save operations.
+        """
+        # Wait for any async operations to complete
+        if hasattr(self.checkpointer, "wait"):
+            self.checkpointer.wait()
+        else:
+            # Fallback for older Orbax versions that don't have wait() method
+            while self.checkpointer.is_saving_in_progress():
+                import time
+
+                time.sleep(0.1)
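A hedged usage sketch for the new callback. `model`, `x_train` and `y_train` are assumed to exist, as in the docstring examples above; with `save_on_background=True` (the default) saves run asynchronously, so `wait_until_finished()` should be called before reading the checkpoint directory.

```python
import keras

ckpt = keras.callbacks.OrbaxCheckpoint(directory="/tmp/ckpt", save_freq="epoch")
model.fit(x_train, y_train, epochs=3, callbacks=[ckpt])
ckpt.wait_until_finished()  # block until pending async saves have completed
```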
keras/src/callbacks/terminate_on_nan.py
CHANGED
@@ -7,14 +7,63 @@ from keras.src.utils import io_utils
 
 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
+
+    This callback monitors the loss value during training
+    and terminates training when a NaN or Inf loss is detected.
+    By default, training is stopped gracefully
+    by setting `model.stop_training = True`, which triggers all callback cleanup
+    methods including `on_train_end()`.
+
+    Alternatively, you can use `raise_error=True` to immediately raise a
+    RuntimeError when NaN/Inf is detected. This raise_error termination
+    prevents `on_train_end()` from being called on other callbacks, which
+    is useful for preserving backup states or preventing unintended cleanup
+    when training fails.
+
+    Args:
+        raise_error: Boolean, default False. If False, uses graceful stop via
+            `model.stop_training = True`. If True, immediately raises
+            RuntimeError on NaN/Inf loss, bypassing callback cleanup methods.
+
+    Example:
+
+    ```
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # raise_error termination (strict failure)
+    callback = keras.callbacks.TerminateOnNaN(raise_error=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, raise_error: bool = False):
+        super().__init__()
+        self.raise_error = raise_error
 
     def on_batch_end(self, batch, logs=None):
+        """Check for NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If loss is NaN/Inf and raise_error=True.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.raise_error:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
keras/src/datasets/cifar10.py
CHANGED
@@ -59,6 +59,11 @@ def load_data():
     assert y_train.shape == (50000, 1)
     assert y_test.shape == (10000, 1)
     ```
+
+    **Note**: The CIFAR-10 dataset is known to have a small percentage of
+    mislabeled samples, which is inherent to the original dataset. This label
+    noise may impact training and evaluation. For more details, refer to
+    discussions in the research literature on CIFAR-10 label quality.
     """
     dirname = "cifar-10-batches-py-target"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"