keras-nightly 3.14.0.dev2026011204__py3-none-any.whl → 3.14.0.dev2026011304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/src/backend/tensorflow/numpy.py +1 -1
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +11 -6
- keras/src/ops/image.py +73 -17
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/module_utils.py +11 -0
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/RECORD +13 -12
- {keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/top_level.txt +0 -0
|
@@ -8,7 +8,6 @@ from keras.src.api_export import keras_export
|
|
|
8
8
|
from keras.src.callbacks.monitor_callback import (
|
|
9
9
|
MonitorCallback, # For metric monitoring logic
|
|
10
10
|
)
|
|
11
|
-
from keras.src.utils.io_utils import print_msg
|
|
12
11
|
from keras.src.utils.module_utils import ocp
|
|
13
12
|
|
|
14
13
|
# Context and AsyncOptions are accessed through the lazy-loaded ocp module
|
|
@@ -62,6 +61,11 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
62
61
|
This callback saves the model's weights and optimizer state asynchronously
|
|
63
62
|
using Orbax, allowing training to continue without blocking for I/O.
|
|
64
63
|
|
|
64
|
+
**Multi-host Support**: When running in a multi-host distributed training
|
|
65
|
+
environment with JAX backend, this callback automatically coordinates
|
|
66
|
+
checkpointing across all hosts to ensure consistency and proper
|
|
67
|
+
synchronization. Multi-host checkpointing is only supported on JAX.
|
|
68
|
+
|
|
65
69
|
Example:
|
|
66
70
|
|
|
67
71
|
```python
|
|
@@ -92,10 +96,6 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
92
96
|
verbose: Verbosity mode, 0 or 1.
|
|
93
97
|
save_best_only: if `save_best_only=True`, it only saves when the model
|
|
94
98
|
is considered the "best" based on the monitored quantity.
|
|
95
|
-
save_weights_only: if `save_weights_only=True`, only the model's
|
|
96
|
-
weights will be saved. Otherwise, the full model state
|
|
97
|
-
(weights, non-trainable variables, optimizer state, and
|
|
98
|
-
metrics state) will be saved. Defaults to False.
|
|
99
99
|
mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
|
|
100
100
|
save_freq: `'epoch'` or integer. Frequency to save checkpoints.
|
|
101
101
|
max_to_keep: Integer, maximum number of recent checkpoints to keep.
|
|
@@ -112,7 +112,6 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
112
112
|
monitor="val_loss",
|
|
113
113
|
verbose=0,
|
|
114
114
|
save_best_only=False,
|
|
115
|
-
save_weights_only=False,
|
|
116
115
|
mode="auto",
|
|
117
116
|
save_freq="epoch",
|
|
118
117
|
initial_value_threshold=None,
|
|
@@ -129,7 +128,6 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
129
128
|
self.directory = directory
|
|
130
129
|
self.verbose = verbose
|
|
131
130
|
self.save_best_only = save_best_only
|
|
132
|
-
self.save_weights_only = save_weights_only
|
|
133
131
|
self.save_freq = save_freq
|
|
134
132
|
self.max_to_keep = max_to_keep
|
|
135
133
|
self.save_on_background = save_on_background
|
|
@@ -138,6 +136,9 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
138
136
|
self._current_epoch = 0 # Keep track of epoch
|
|
139
137
|
self._total_batches_seen = 0 # Global batch counter for step tracking
|
|
140
138
|
|
|
139
|
+
# Multi-host support
|
|
140
|
+
self._multihost_initialized = self._is_multihost_initialized()
|
|
141
|
+
|
|
141
142
|
if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
|
|
142
143
|
raise ValueError(
|
|
143
144
|
f"Unrecognized save_freq: {self.save_freq}. "
|
|
@@ -151,14 +152,18 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
151
152
|
ocp.training.preservation_policies.LatestN(max_to_keep)
|
|
152
153
|
)
|
|
153
154
|
|
|
154
|
-
# Use AnyPreservationPolicy to combine them
|
|
155
|
+
# Use AnyPreservationPolicy to combine them, or use directly
|
|
156
|
+
# if single policy
|
|
155
157
|
preservation_policy = None
|
|
156
158
|
if policies:
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
159
|
+
if len(policies) == 1:
|
|
160
|
+
preservation_policy = policies[0]
|
|
161
|
+
else:
|
|
162
|
+
preservation_policy = (
|
|
163
|
+
ocp.training.preservation_policies.AnyPreservationPolicy(
|
|
164
|
+
policies
|
|
165
|
+
)
|
|
160
166
|
)
|
|
161
|
-
)
|
|
162
167
|
|
|
163
168
|
# Create the V1 Checkpointer with direct parameter passing
|
|
164
169
|
# Orbax will handle directory creation on all processes as needed
|
|
@@ -167,6 +172,54 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
167
172
|
preservation_policy=preservation_policy,
|
|
168
173
|
)
|
|
169
174
|
|
|
175
|
+
def _is_multihost_initialized(self):
|
|
176
|
+
"""Check if multi-host environment is initialized."""
|
|
177
|
+
# Multi-host checkpointing is only supported on JAX backend
|
|
178
|
+
if backend.backend() != "jax":
|
|
179
|
+
return False
|
|
180
|
+
|
|
181
|
+
multihost = ocp.multihost
|
|
182
|
+
# Check if JAX distributed client is initialized
|
|
183
|
+
# (indicates multihost setup)
|
|
184
|
+
return multihost.is_jax_distributed_client_initialized()
|
|
185
|
+
|
|
186
|
+
def _sync_processes(self, key=None):
|
|
187
|
+
"""Synchronize all processes across hosts."""
|
|
188
|
+
if not self._multihost_initialized:
|
|
189
|
+
return # No-op for single host
|
|
190
|
+
|
|
191
|
+
multihost = ocp.multihost
|
|
192
|
+
sync_key = key or "orbax_checkpoint_sync"
|
|
193
|
+
multihost.sync_global_processes(sync_key)
|
|
194
|
+
|
|
195
|
+
def is_multihost_enabled(self):
|
|
196
|
+
"""Return True if multi-host checkpointing is enabled and initialized.
|
|
197
|
+
|
|
198
|
+
This method can be used to check if the callback is operating in
|
|
199
|
+
a multi-host distributed training environment. Multi-host checkpointing
|
|
200
|
+
is only supported on JAX backend.
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
bool: True if multi-host support is active, False otherwise.
|
|
204
|
+
"""
|
|
205
|
+
return self._multihost_initialized
|
|
206
|
+
|
|
207
|
+
def is_primary_host(self):
|
|
208
|
+
"""Return True if this process is the primary host in multi-host setup.
|
|
209
|
+
|
|
210
|
+
In multi-host environments, only the primary host typically handles
|
|
211
|
+
logging and coordination tasks. Multi-host checkpointing is only
|
|
212
|
+
supported on JAX backend.
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
bool: True if this is the primary host, False otherwise.
|
|
216
|
+
Always returns True in single-host environments.
|
|
217
|
+
"""
|
|
218
|
+
if not self._multihost_initialized:
|
|
219
|
+
return True # Single host is always primary
|
|
220
|
+
multihost = ocp.multihost
|
|
221
|
+
return multihost.is_primary_host()
|
|
222
|
+
|
|
170
223
|
def _should_save_on_batch(self, batch):
|
|
171
224
|
"""Check if we should save on this batch."""
|
|
172
225
|
if self.save_freq == "epoch":
|
|
@@ -186,32 +239,14 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
186
239
|
return False
|
|
187
240
|
|
|
188
241
|
def _save_checkpoint(self, step, logs=None):
|
|
189
|
-
"""Save a checkpoint at the given step."""
|
|
242
|
+
"""Save a checkpoint at the given step with multi-host coordination."""
|
|
190
243
|
|
|
191
244
|
# --- Prepare Composite State (Backend-Agnostic) ---
|
|
192
245
|
state_tree = _get_state_tree(self.model)
|
|
193
246
|
|
|
194
247
|
# Save the nested state structures directly (preserving layer
|
|
195
248
|
# names and structure)
|
|
196
|
-
|
|
197
|
-
composite_state = {
|
|
198
|
-
"trainable_variables": state_tree["trainable_variables"],
|
|
199
|
-
}
|
|
200
|
-
if "non_trainable_variables" in state_tree:
|
|
201
|
-
composite_state["non_trainable_variables"] = state_tree[
|
|
202
|
-
"non_trainable_variables"
|
|
203
|
-
]
|
|
204
|
-
else:
|
|
205
|
-
composite_state = state_tree
|
|
206
|
-
|
|
207
|
-
# --- Save Logic (V1 API) ---
|
|
208
|
-
# All processes participate in distributed checkpointing
|
|
209
|
-
# Checkpointer is configured to save unconditionally when
|
|
210
|
-
# save_pytree is called
|
|
211
|
-
if self.verbose > 0:
|
|
212
|
-
print_msg(
|
|
213
|
-
f"OrbaxCheckpoint: Triggering async save for step {step}..."
|
|
214
|
-
)
|
|
249
|
+
composite_state = state_tree
|
|
215
250
|
|
|
216
251
|
# Use a single with statement. If context_options is empty,
|
|
217
252
|
# Context() uses defaults.
|
|
@@ -282,18 +317,16 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
282
317
|
except Exception:
|
|
283
318
|
pass # Ignore errors during cleanup
|
|
284
319
|
|
|
320
|
+
# Multi-host synchronization: ensure all hosts complete cleanup
|
|
321
|
+
self._sync_processes("checkpoint_cleanup")
|
|
322
|
+
|
|
285
323
|
def wait_until_finished(self):
|
|
286
324
|
"""Wait for any in-progress checkpoint operations to complete.
|
|
287
325
|
This method blocks until all asynchronous checkpoint save operations
|
|
288
|
-
have completed
|
|
289
|
-
checkpoints if there might be pending save operations.
|
|
326
|
+
have completed across all hosts in a multi-host setup.
|
|
290
327
|
"""
|
|
291
|
-
# Wait for any async operations to complete
|
|
292
|
-
|
|
293
|
-
self.checkpointer.wait()
|
|
294
|
-
else:
|
|
295
|
-
# Fallback for older Orbax versions that don't have wait() method
|
|
296
|
-
while self.checkpointer.is_saving_in_progress():
|
|
297
|
-
import time
|
|
328
|
+
# Wait for any async operations to complete on this host
|
|
329
|
+
self.checkpointer.wait()
|
|
298
330
|
|
|
299
|
-
|
|
331
|
+
# Multi-host synchronization: ensure all hosts complete
|
|
332
|
+
self._sync_processes("checkpoint_wait_complete")
|
keras/src/losses/losses.py
CHANGED
|
@@ -73,6 +73,14 @@ class MeanSquaredError(LossFunctionWrapper):
|
|
|
73
73
|
`"float32"` unless set to different value
|
|
74
74
|
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
|
75
75
|
provided, then the `compute_dtype` will be utilized.
|
|
76
|
+
|
|
77
|
+
Examples:
|
|
78
|
+
|
|
79
|
+
>>> y_true = keras.ops.array([1.0, 0.0, 1.0])
|
|
80
|
+
>>> y_pred = keras.ops.array([0.9, 0.1, 0.8])
|
|
81
|
+
>>> loss = keras.losses.MeanSquaredError()
|
|
82
|
+
>>> loss(y_true, y_pred)
|
|
83
|
+
0.02
|
|
76
84
|
"""
|
|
77
85
|
|
|
78
86
|
def __init__(
|
|
@@ -114,6 +122,14 @@ class MeanAbsoluteError(LossFunctionWrapper):
|
|
|
114
122
|
`"float32"` unless set to different value
|
|
115
123
|
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
|
116
124
|
provided, then the `compute_dtype` will be utilized.
|
|
125
|
+
|
|
126
|
+
Examples:
|
|
127
|
+
|
|
128
|
+
>>> y_true = keras.ops.array([1.0, 0.3, 1.0])
|
|
129
|
+
>>> y_pred = keras.ops.array([1.9, 0.3, 1.8])
|
|
130
|
+
>>> loss = keras.losses.MeanAbsoluteError()
|
|
131
|
+
>>> loss(y_true, y_pred)
|
|
132
|
+
0.5666667
|
|
117
133
|
"""
|
|
118
134
|
|
|
119
135
|
def __init__(
|
|
@@ -155,6 +171,14 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
|
|
|
155
171
|
`"float32"` unless set to different value
|
|
156
172
|
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
|
|
157
173
|
provided, then the `compute_dtype` will be utilized.
|
|
174
|
+
|
|
175
|
+
Examples:
|
|
176
|
+
|
|
177
|
+
>>> y_true = keras.ops.array([100.0, 200.0, 300.0])
|
|
178
|
+
>>> y_pred = keras.ops.array([90.0, 210.0, 310.0])
|
|
179
|
+
>>> loss = keras.losses.MeanAbsolutePercentageError()
|
|
180
|
+
>>> loss(y_true, y_pred)
|
|
181
|
+
6.111111
|
|
158
182
|
"""
|
|
159
183
|
|
|
160
184
|
def __init__(
|
keras/src/models/model.py
CHANGED
|
@@ -992,13 +992,18 @@ class Model(Trainer, base_trainer.Trainer, Layer):
|
|
|
992
992
|
self.non_trainable_variables, path_value_dict
|
|
993
993
|
)
|
|
994
994
|
elif k == "optimizer_variables":
|
|
995
|
-
self.
|
|
996
|
-
self.
|
|
997
|
-
|
|
995
|
+
if hasattr(self, "optimizer") and self.optimizer is not None:
|
|
996
|
+
self._assign_variable_values(
|
|
997
|
+
self.optimizer.variables, path_value_dict
|
|
998
|
+
)
|
|
998
999
|
elif k == "metrics_variables":
|
|
999
|
-
|
|
1000
|
-
self
|
|
1001
|
-
|
|
1000
|
+
if (
|
|
1001
|
+
hasattr(self, "metrics_variables")
|
|
1002
|
+
and self.metrics_variables
|
|
1003
|
+
):
|
|
1004
|
+
self._assign_variable_values(
|
|
1005
|
+
self.metrics_variables, path_value_dict
|
|
1006
|
+
)
|
|
1002
1007
|
else:
|
|
1003
1008
|
raise ValueError(f"Unknown variable name: {k}")
|
|
1004
1009
|
|
keras/src/ops/image.py
CHANGED
|
@@ -616,42 +616,98 @@ def extract_patches(
|
|
|
616
616
|
padding="valid",
|
|
617
617
|
data_format=None,
|
|
618
618
|
):
|
|
619
|
-
"""Extracts patches from the image(s).
|
|
619
|
+
"""Extracts patches from the image(s) or volume(s).
|
|
620
|
+
|
|
621
|
+
This function supports both 2D and 3D patch extraction based on the
|
|
622
|
+
`size` argument length, similar to how `keras.ops.conv` handles
|
|
623
|
+
different dimensions.
|
|
620
624
|
|
|
621
625
|
Args:
|
|
622
|
-
images: Input image or batch of images.
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
626
|
+
images: Input image/volume or batch of images/volumes.
|
|
627
|
+
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
|
|
628
|
+
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
|
|
629
|
+
size: Patch size as int or tuple.
|
|
630
|
+
Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
|
|
631
|
+
Length 3 tuple `(patch_depth, patch_height, patch_width)` for
|
|
632
|
+
3D patches.
|
|
633
|
+
strides: Strides for patch extraction. If not specified, defaults
|
|
634
|
+
to `size` (non-overlapping patches).
|
|
635
|
+
dilation_rate: Dilation rate for patch extraction. Note that
|
|
636
|
+
`dilation_rate > 1` is not supported with `strides > 1`.
|
|
630
637
|
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
|
|
631
638
|
data_format: A string specifying the data format of the input tensor.
|
|
632
639
|
It can be either `"channels_last"` or `"channels_first"`.
|
|
633
|
-
|
|
634
|
-
`(batch, height, width, channels)`, while `"channels_first"`
|
|
635
|
-
corresponds to inputs with shape `(batch, channels, height, width)`.
|
|
636
|
-
If not specified, the value will default to
|
|
637
|
-
`keras.config.image_data_format`.
|
|
640
|
+
If not specified, defaults to `keras.config.image_data_format`.
|
|
638
641
|
|
|
639
642
|
Returns:
|
|
640
|
-
Extracted patches
|
|
643
|
+
Extracted patches with shape depending on input and `size`:
|
|
644
|
+
- 2D patches: 3D (unbatched) or 4D (batched)
|
|
645
|
+
- 3D patches: 4D (unbatched) or 5D (batched)
|
|
641
646
|
|
|
642
647
|
Examples:
|
|
643
648
|
|
|
649
|
+
>>> # 2D patches from batch of images
|
|
644
650
|
>>> image = np.random.random(
|
|
645
651
|
... (2, 20, 20, 3)
|
|
646
|
-
... ).astype("float32")
|
|
652
|
+
... ).astype("float32")
|
|
647
653
|
>>> patches = keras.ops.image.extract_patches(image, (5, 5))
|
|
648
654
|
>>> patches.shape
|
|
649
655
|
(2, 4, 4, 75)
|
|
650
|
-
|
|
656
|
+
|
|
657
|
+
>>> # 2D patches from single image
|
|
658
|
+
>>> image = np.random.random((20, 20, 3)).astype("float32")
|
|
651
659
|
>>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
|
|
652
660
|
>>> patches.shape
|
|
653
661
|
(18, 18, 27)
|
|
662
|
+
|
|
663
|
+
>>> # 3D patches from batch of volumes
|
|
664
|
+
>>> volumes = np.random.random(
|
|
665
|
+
... (2, 10, 10, 10, 3)
|
|
666
|
+
... ).astype("float32")
|
|
667
|
+
>>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
|
|
668
|
+
>>> patches.shape
|
|
669
|
+
(2, 3, 3, 3, 81)
|
|
670
|
+
|
|
671
|
+
>>> # 3D patches from single volume
|
|
672
|
+
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
|
|
673
|
+
>>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
|
|
674
|
+
>>> patches.shape
|
|
675
|
+
(3, 3, 3, 81)
|
|
654
676
|
"""
|
|
677
|
+
# Validate size argument
|
|
678
|
+
if not isinstance(size, int):
|
|
679
|
+
if not isinstance(size, (tuple, list)):
|
|
680
|
+
raise TypeError(
|
|
681
|
+
"Invalid `size` argument. Expected an int or a tuple. "
|
|
682
|
+
f"Received: size={size} of type {type(size).__name__}"
|
|
683
|
+
)
|
|
684
|
+
if len(size) not in (2, 3):
|
|
685
|
+
raise ValueError(
|
|
686
|
+
"Invalid `size` argument. Expected a tuple of length 2 or 3. "
|
|
687
|
+
f"Received: size={size} with length {len(size)}"
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
# Determine 2D vs 3D based on size argument
|
|
691
|
+
if not isinstance(size, int) and len(size) == 3:
|
|
692
|
+
# 3D patch extraction
|
|
693
|
+
if any_symbolic_tensors((images,)):
|
|
694
|
+
return ExtractPatches3D(
|
|
695
|
+
size=size,
|
|
696
|
+
strides=strides,
|
|
697
|
+
dilation_rate=dilation_rate,
|
|
698
|
+
padding=padding,
|
|
699
|
+
data_format=data_format,
|
|
700
|
+
).symbolic_call(images)
|
|
701
|
+
return _extract_patches_3d(
|
|
702
|
+
images,
|
|
703
|
+
size,
|
|
704
|
+
strides,
|
|
705
|
+
dilation_rate,
|
|
706
|
+
padding,
|
|
707
|
+
data_format=data_format,
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
# 2D patch extraction (default)
|
|
655
711
|
if any_symbolic_tensors((images,)):
|
|
656
712
|
return ExtractPatches(
|
|
657
713
|
size=size,
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Orbax checkpoint loading functionality."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from keras.src.utils.module_utils import ocp
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def is_orbax_checkpoint(filepath):
|
|
9
|
+
"""Check if the given path is an Orbax checkpoint directory."""
|
|
10
|
+
if not os.path.exists(filepath):
|
|
11
|
+
return False
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
return ocp.is_orbax_checkpoint(filepath)
|
|
15
|
+
except (ImportError, AttributeError):
|
|
16
|
+
# Fallback to check for orbax.checkpoint file if Orbax API not available
|
|
17
|
+
return os.path.isfile(os.path.join(filepath, "orbax.checkpoint"))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def find_latest_orbax_checkpoint(checkpoint_dir):
|
|
21
|
+
"""Find the latest checkpoint in an Orbax checkpoint directory."""
|
|
22
|
+
checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
|
|
23
|
+
latest = checkpointer.latest
|
|
24
|
+
if latest is None:
|
|
25
|
+
raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
|
|
26
|
+
return os.path.join(checkpoint_dir, str(latest.step))
|
keras/src/saving/saving_api.py
CHANGED
|
@@ -6,13 +6,11 @@ from absl import logging
|
|
|
6
6
|
from keras.src.api_export import keras_export
|
|
7
7
|
from keras.src.legacy.saving import legacy_h5_format
|
|
8
8
|
from keras.src.saving import saving_lib
|
|
9
|
+
from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
|
|
10
|
+
from keras.src.saving.orbax_util import is_orbax_checkpoint
|
|
9
11
|
from keras.src.utils import file_utils
|
|
10
12
|
from keras.src.utils import io_utils
|
|
11
|
-
|
|
12
|
-
try:
|
|
13
|
-
import h5py
|
|
14
|
-
except ImportError:
|
|
15
|
-
h5py = None
|
|
13
|
+
from keras.src.utils.module_utils import h5py
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
|
|
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
149
147
|
keras.layers.Softmax()])
|
|
150
148
|
model.save("model.keras")
|
|
151
149
|
loaded_model = keras.saving.load_model("model.keras")
|
|
152
|
-
x = np.random.random((10, 3))
|
|
153
|
-
assert np.allclose(model.predict(x), loaded_model.predict(x))
|
|
154
150
|
```
|
|
155
151
|
|
|
156
152
|
Note that the model variables may have different name values
|
|
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
208
204
|
else:
|
|
209
205
|
raise ValueError(
|
|
210
206
|
f"File format not supported: filepath={filepath}. "
|
|
211
|
-
"Keras 3 only supports V3 `.keras` files
|
|
207
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
212
208
|
"legacy H5 format files (`.h5` extension). "
|
|
213
209
|
"Note that the legacy SavedModel format is not "
|
|
214
210
|
"supported by `load_model()` in Keras 3. In "
|
|
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
288
284
|
objects_to_skip=objects_to_skip,
|
|
289
285
|
)
|
|
290
286
|
elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
|
|
291
|
-
if not h5py:
|
|
292
|
-
raise ImportError(
|
|
293
|
-
"Loading a H5 file requires `h5py` to be installed."
|
|
294
|
-
)
|
|
295
287
|
if objects_to_skip is not None:
|
|
296
288
|
raise ValueError(
|
|
297
289
|
"`objects_to_skip` only supports loading '.weights.h5' files."
|
|
298
290
|
f"Received: {filepath}"
|
|
299
291
|
)
|
|
292
|
+
if not h5py.available:
|
|
293
|
+
raise ImportError(
|
|
294
|
+
"Loading HDF5 files requires the h5py package. "
|
|
295
|
+
"You can install it via `pip install h5py`"
|
|
296
|
+
)
|
|
300
297
|
with h5py.File(filepath, "r") as f:
|
|
301
298
|
if "layer_names" not in f.attrs and "model_weights" in f:
|
|
302
299
|
f = f["model_weights"]
|
|
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
308
305
|
legacy_h5_format.load_weights_from_hdf5_group(
|
|
309
306
|
f, model, skip_mismatch
|
|
310
307
|
)
|
|
308
|
+
elif is_orbax_checkpoint(filepath):
|
|
309
|
+
# Load weights from Orbax checkpoint
|
|
310
|
+
from keras.src.utils.module_utils import ocp
|
|
311
|
+
|
|
312
|
+
filepath = str(filepath)
|
|
313
|
+
|
|
314
|
+
# Determine if this is a root directory or a step directory
|
|
315
|
+
items = os.listdir(filepath)
|
|
316
|
+
has_step_subdirs = any(
|
|
317
|
+
os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
|
|
318
|
+
for item in items
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
if has_step_subdirs:
|
|
322
|
+
# It's a root directory, find the latest checkpoint
|
|
323
|
+
checkpoint_path = find_latest_orbax_checkpoint(filepath)
|
|
324
|
+
else:
|
|
325
|
+
# It's a step directory, use it directly
|
|
326
|
+
checkpoint_path = filepath
|
|
327
|
+
|
|
328
|
+
# Load checkpoint
|
|
329
|
+
loaded_state = ocp.load_pytree(checkpoint_path)
|
|
330
|
+
|
|
331
|
+
# Set the model state directly from the loaded state
|
|
332
|
+
model.set_state_tree(loaded_state)
|
|
311
333
|
else:
|
|
312
334
|
raise ValueError(
|
|
313
335
|
f"File format not supported: filepath={filepath}. "
|
|
314
|
-
"Keras 3 only supports V3 `.keras`
|
|
315
|
-
"files,
|
|
336
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
337
|
+
"`.weights.h5` files, legacy H5 format files "
|
|
338
|
+
"(`.h5` extension), or Orbax checkpoints."
|
|
316
339
|
)
|
keras/src/utils/module_utils.py
CHANGED
|
@@ -44,15 +44,26 @@ class OrbaxLazyModule(LazyModule):
|
|
|
44
44
|
try:
|
|
45
45
|
parent_module = importlib.import_module("orbax.checkpoint")
|
|
46
46
|
self.module = parent_module.v1
|
|
47
|
+
self.parent_module = parent_module
|
|
47
48
|
except ImportError:
|
|
48
49
|
raise ImportError(self.import_error_msg)
|
|
49
50
|
|
|
51
|
+
def __getattr__(self, name):
|
|
52
|
+
if name == "_api_export_path":
|
|
53
|
+
raise AttributeError
|
|
54
|
+
if self.module is None:
|
|
55
|
+
self.initialize()
|
|
56
|
+
if name == "multihost":
|
|
57
|
+
return self.parent_module.multihost
|
|
58
|
+
return getattr(self.module, name)
|
|
59
|
+
|
|
50
60
|
|
|
51
61
|
tensorflow = LazyModule("tensorflow")
|
|
52
62
|
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
|
|
53
63
|
tensorflow_io = LazyModule("tensorflow_io")
|
|
54
64
|
scipy = LazyModule("scipy")
|
|
55
65
|
jax = LazyModule("jax")
|
|
66
|
+
h5py = LazyModule("h5py")
|
|
56
67
|
torch_xla = LazyModule(
|
|
57
68
|
"torch_xla",
|
|
58
69
|
import_error_msg=(
|
keras/src/version.py
CHANGED
{keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/RECORD
RENAMED
|
@@ -128,7 +128,7 @@ keras/regularizers/__init__.py,sha256=542Shphw7W8h4Dyf2rmqMKUECVZ8IVBvN9g1LWhz-b
|
|
|
128
128
|
keras/saving/__init__.py,sha256=KvL2GZxjvgFgEhvEnkvqjIR9JSNHKz-NWZacXajsjLI,1298
|
|
129
129
|
keras/src/__init__.py,sha256=Gi4S7EiCMkE03PbdGNpFdaUYySWDs_FcAJ8Taz9Y1BE,684
|
|
130
130
|
keras/src/api_export.py,sha256=gXOkBOnmscV013WAc75lc4Up01-Kkg9EylIAT_QWctg,1173
|
|
131
|
-
keras/src/version.py,sha256=
|
|
131
|
+
keras/src/version.py,sha256=6Xtelgtl2s2j5h0H2_ulwbg1pYur71GFnblVZ9ACMxk,204
|
|
132
132
|
keras/src/activations/__init__.py,sha256=0nL3IFDB9unlrMz8ninKOWo-uCHasTUpTo1tXZb2u44,4433
|
|
133
133
|
keras/src/activations/activations.py,sha256=mogPggtp4CGldI3VOPNmesRxp6EbiR1_i4KLGaVwzL8,17614
|
|
134
134
|
keras/src/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -211,7 +211,7 @@ keras/src/backend/tensorflow/layer.py,sha256=69d40LwL4HhKRsCjj1VRpjfrQXXF8VV3vh0
|
|
|
211
211
|
keras/src/backend/tensorflow/linalg.py,sha256=_lZVfdY1tFvrN7xwbt3INGoTR0yC5v-kI1Q0XppVibY,8773
|
|
212
212
|
keras/src/backend/tensorflow/math.py,sha256=zTu_7Ff6B2Ro862z_xH0OCmIWbV74DjsO5UnfjYuOUQ,12370
|
|
213
213
|
keras/src/backend/tensorflow/nn.py,sha256=6vtZHzUED6_blUPE1Tnc3GAxPpJ2ebxoaiMn80tTL9k,51328
|
|
214
|
-
keras/src/backend/tensorflow/numpy.py,sha256=
|
|
214
|
+
keras/src/backend/tensorflow/numpy.py,sha256=nIpMvr-g81I9KF74RD4AbU4e4t-0eFa9MND2Fh1u8Tk,104623
|
|
215
215
|
keras/src/backend/tensorflow/optimizer.py,sha256=kFlyEOnGjEYdLpd8mpwhUeku78__xBfZbbrDWpJrq60,9307
|
|
216
216
|
keras/src/backend/tensorflow/random.py,sha256=iO8V_soaDXZm9ewyAVbjudhsMj08C348c9Bz64nxXC4,6475
|
|
217
217
|
keras/src/backend/tensorflow/rnn.py,sha256=99EJqbPdWddmG14zyjjhUZfU5zo9ObmslF_Mak7EmAs,34602
|
|
@@ -254,7 +254,7 @@ keras/src/callbacks/lambda_callback.py,sha256=q-nNr_k7MyYRP3HIetFsutcLkq78cUYxDD
|
|
|
254
254
|
keras/src/callbacks/learning_rate_scheduler.py,sha256=II0SLxltUX3omRbGTYffd9KTWLRKtzW57SDRe70_t7E,2965
|
|
255
255
|
keras/src/callbacks/model_checkpoint.py,sha256=Jt2mMKHKx0isrQnhiHADDOmwu72J594m93PBHy-zpV8,18570
|
|
256
256
|
keras/src/callbacks/monitor_callback.py,sha256=-QBKqkKJ7Rg6L40Q80IScpvybmLoodLWcJoAgnTe_c4,4184
|
|
257
|
-
keras/src/callbacks/orbax_checkpoint.py,sha256=
|
|
257
|
+
keras/src/callbacks/orbax_checkpoint.py,sha256=hG_OClsm4lYQVTyCLWLJqgdYl6OXtFjz0J6z5eUtsmY,12529
|
|
258
258
|
keras/src/callbacks/progbar_logger.py,sha256=BqddKoOyc8vxxtKriq5QD3n5JhVPUxkuWF2u1UlCriQ,3104
|
|
259
259
|
keras/src/callbacks/reduce_lr_on_plateau.py,sha256=isJ9EzVo8jIu-_kWTFHpM_gaI5PbHTcUBM0keR9FRHA,4766
|
|
260
260
|
keras/src/callbacks/remote_monitor.py,sha256=VDbNzCdddCDe_ZoeVvwV50oJkwOehhT_IDDYD8LzFOg,2727
|
|
@@ -476,7 +476,7 @@ keras/src/legacy/saving/saving_utils.py,sha256=8Sa2rmBGnTv86Tix20OgwF5vTLTpUYbfG
|
|
|
476
476
|
keras/src/legacy/saving/serialization.py,sha256=hiwqO3Il861pkfm0Egaeph2XbhOlQQobmZjbZZgK32c,21368
|
|
477
477
|
keras/src/losses/__init__.py,sha256=rt63Ye0f7YdAR0eV0EOj2J61DI6xNdp2ojonx6rB3wE,6595
|
|
478
478
|
keras/src/losses/loss.py,sha256=8dCOv64yj9QC_GbcKT9M8YEC_Jr01wWuo-BBqFbfg0Q,8783
|
|
479
|
-
keras/src/losses/losses.py,sha256=
|
|
479
|
+
keras/src/losses/losses.py,sha256=MeFB4X3YLiTCw8sOEKpFUrSD4yv8E7hte91Gg-v04ok,100169
|
|
480
480
|
keras/src/metrics/__init__.py,sha256=CydJsY38PR2lRN4irhO_wnlvgruTEAgSHp8eUYE0lwY,7410
|
|
481
481
|
keras/src/metrics/accuracy_metrics.py,sha256=i_7ObnlyyE_UKDj8Nk5h5skakqpMlkMiphJ20eqcYho,18274
|
|
482
482
|
keras/src/metrics/confusion_metrics.py,sha256=EKN1JGndT7pVesg_YAh8mGiM2wieAbGzXlw1ftuUGu4,62640
|
|
@@ -492,14 +492,14 @@ keras/src/metrics/regression_metrics.py,sha256=eLacV_8CKtzA26BJDJuncUDATuL1x8O6S
|
|
|
492
492
|
keras/src/models/__init__.py,sha256=DPbBPSfIGgsufTfJH5U5xJOeN_Ef4FMadT7KKYg3Kjg,143
|
|
493
493
|
keras/src/models/cloning.py,sha256=P0gMH3H9nyz6SMsdt4BQO05rXFa4qiqZk44rFpEnHsM,15945
|
|
494
494
|
keras/src/models/functional.py,sha256=uD-qH9WwAUhaBrAEWAKnsVvKo0tvdHxa1M0dbBOE96Y,34086
|
|
495
|
-
keras/src/models/model.py,sha256=
|
|
495
|
+
keras/src/models/model.py,sha256=9kM6rbiAZOx3ycq2qM7QV6h2P1di57rA2HlljstSkh8,42215
|
|
496
496
|
keras/src/models/sequential.py,sha256=CC9Q1BNB9m7TkgMHRyjOzhQvneng576wJpmdgHrACKY,14352
|
|
497
497
|
keras/src/models/variable_mapping.py,sha256=FVtcgjBRqOxtvkzOE6kjG9SpcB9keDg2gS5LOTlXvG0,2181
|
|
498
498
|
keras/src/ops/__init__.py,sha256=aORlvnrqY_eQl0EFLWdpHsXHnQ6JLSw1qhwJMr-VXJ0,644
|
|
499
499
|
keras/src/ops/core.py,sha256=t06-MvptYb6ZVwmNj083JyUtzU4M6UTVXOT2vVHtKyU,42781
|
|
500
500
|
keras/src/ops/einops.py,sha256=-pxW0_AzDQNsR7t2TJrzvYXBJpmLYA3fJoO0U_U96PY,6268
|
|
501
501
|
keras/src/ops/function.py,sha256=QV9n1-xeTPDK_FJ3sjlHDWVH2sqDj96R6YQnJueMOlA,17821
|
|
502
|
-
keras/src/ops/image.py,sha256
|
|
502
|
+
keras/src/ops/image.py,sha256=-UBomIodLByXhxN9aSt0JUGn-N5yulXgRZ06XpXq8mM,68743
|
|
503
503
|
keras/src/ops/linalg.py,sha256=3V8S_cgNxZZCIFcFj-FBHTdRqWNbimDtumMvfoc0f30,26736
|
|
504
504
|
keras/src/ops/math.py,sha256=4qYMJ5qAPmeSyeF63YWoGbUkQt6f4_VX0enOChU4mXU,37233
|
|
505
505
|
keras/src/ops/nn.py,sha256=04gjHB2BWusy4tWm59EO5Ns1paJC5umDNGwNCKzaJWQ,104658
|
|
@@ -543,7 +543,8 @@ keras/src/saving/__init__.py,sha256=vnrtfvnzW7Gwtxe5COhaMoEnVYB5iDe2YlqJ-DvqFIk,
|
|
|
543
543
|
keras/src/saving/file_editor.py,sha256=tsUo9mQbMa8433tHTnOKWFhDeathYwDb0CeWcDTTTBQ,32089
|
|
544
544
|
keras/src/saving/keras_saveable.py,sha256=aGIt1ajtsaamfUq18LM6ql8JEoQzi3HwzJEuwQ9bmKE,1285
|
|
545
545
|
keras/src/saving/object_registration.py,sha256=OOO-7-SNfPoFkFsR_c5jzE6aSIDIlHlnMcm9IlI_Gbs,7357
|
|
546
|
-
keras/src/saving/
|
|
546
|
+
keras/src/saving/orbax_util.py,sha256=0o05YKFjiePHgeW_d5fvuUAGzymqbJTeuquUR-7uVGE,906
|
|
547
|
+
keras/src/saving/saving_api.py,sha256=PMkxXhtNNKX8GlwIsCP8-Plt19M012wNEk7i8BhxWzo,12670
|
|
547
548
|
keras/src/saving/saving_lib.py,sha256=-uSXsojqzSl19FtW5FogCclvnu_nnVU3S-Si293DNq0,58723
|
|
548
549
|
keras/src/saving/serialization_lib.py,sha256=yzCTm8hin__MGA2N5M5F-8Zbts5ZJVmINbrH4wEtIwI,30334
|
|
549
550
|
keras/src/testing/__init__.py,sha256=7vVsV7Rn3rG99DdURgnH8ncpxagRwIE0uhH-R4qDyok,315
|
|
@@ -584,7 +585,7 @@ keras/src/utils/io_utils.py,sha256=Riv9TCCnz6xQLUvR1QC-UOCoGZ_KiNTwQVvLY6dKcX8,4
|
|
|
584
585
|
keras/src/utils/jax_layer.py,sha256=ytws8NcxWzJ4kViBy3bc-Pk3st3_3L8RqXxgq9sYp1k,32912
|
|
585
586
|
keras/src/utils/jax_utils.py,sha256=vY3P4S9mfWEjdirLd81ocKqeCm-UVfgQ1yTi6UHdBiM,322
|
|
586
587
|
keras/src/utils/model_visualization.py,sha256=0ENeiq8q-qbyGjfcRixyyInb3aTxfcKCooKhZ1hSuI0,17794
|
|
587
|
-
keras/src/utils/module_utils.py,sha256=
|
|
588
|
+
keras/src/utils/module_utils.py,sha256=FTZPMRLurURchLPX1tu-h3b-UoPW28faNOlDzpYDW6A,2894
|
|
588
589
|
keras/src/utils/naming.py,sha256=bPowKBlgiVP_6XtVlNVHxrxheKuJy2c0e-oEM8ocZQY,1776
|
|
589
590
|
keras/src/utils/numerical_utils.py,sha256=Uqe5nu1HXmiZuh5-MznomtDSVSO9FgFaltdDtGnN61o,7205
|
|
590
591
|
keras/src/utils/progbar.py,sha256=Yg2Vp1xzqU7HnfDEGSeZsmOKAKYKA4oEHv7yAMaucYw,10358
|
|
@@ -614,7 +615,7 @@ keras/utils/bounding_boxes/__init__.py,sha256=jtvQll4u8ZY0Z96HwNhP1nxWEG9FM3gI-6
|
|
|
614
615
|
keras/utils/legacy/__init__.py,sha256=oSYZz6uS8UxSElRaaJYWJEoweJ4GAasZjnn7fNaOlog,342
|
|
615
616
|
keras/visualization/__init__.py,sha256=UKWmiy6sps4SWlmQi9WX8_Z53cPpLlphz2zIeHdwJpQ,722
|
|
616
617
|
keras/wrappers/__init__.py,sha256=QkS-O5K8qGS7C3sytF8MpmO6PasATpNVGF8qtb7Ojsw,407
|
|
617
|
-
keras_nightly-3.14.0.
|
|
618
|
-
keras_nightly-3.14.0.
|
|
619
|
-
keras_nightly-3.14.0.
|
|
620
|
-
keras_nightly-3.14.0.
|
|
618
|
+
keras_nightly-3.14.0.dev2026011304.dist-info/METADATA,sha256=VvTqhhwTIlxRsKL9og2HY3vWDoIfXS2ITQ2jU8Mt2X4,6339
|
|
619
|
+
keras_nightly-3.14.0.dev2026011304.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
620
|
+
keras_nightly-3.14.0.dev2026011304.dist-info/top_level.txt,sha256=ptcw_-QuGZ4ZDjMdwi_Z0clZm8QAqFdvzzFnDEOTs9o,6
|
|
621
|
+
keras_nightly-3.14.0.dev2026011304.dist-info/RECORD,,
|
{keras_nightly-3.14.0.dev2026011204.dist-info → keras_nightly-3.14.0.dev2026011304.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
File without changes
|