keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +3 -0
- keras/ops/numpy/__init__.py +3 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +16 -0
- keras/src/backend/numpy/numpy.py +23 -0
- keras/src/backend/openvino/numpy.py +369 -16
- keras/src/backend/tensorflow/numpy.py +34 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +36 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +109 -96
- keras/src/ops/numpy.py +181 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/saving/saving_api.py
CHANGED
|
@@ -6,13 +6,11 @@ from absl import logging
|
|
|
6
6
|
from keras.src.api_export import keras_export
|
|
7
7
|
from keras.src.legacy.saving import legacy_h5_format
|
|
8
8
|
from keras.src.saving import saving_lib
|
|
9
|
+
from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
|
|
10
|
+
from keras.src.saving.orbax_util import is_orbax_checkpoint
|
|
9
11
|
from keras.src.utils import file_utils
|
|
10
12
|
from keras.src.utils import io_utils
|
|
11
|
-
|
|
12
|
-
try:
|
|
13
|
-
import h5py
|
|
14
|
-
except ImportError:
|
|
15
|
-
h5py = None
|
|
13
|
+
from keras.src.utils.module_utils import h5py
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
|
|
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
149
147
|
keras.layers.Softmax()])
|
|
150
148
|
model.save("model.keras")
|
|
151
149
|
loaded_model = keras.saving.load_model("model.keras")
|
|
152
|
-
x = np.random.random((10, 3))
|
|
153
|
-
assert np.allclose(model.predict(x), loaded_model.predict(x))
|
|
154
150
|
```
|
|
155
151
|
|
|
156
152
|
Note that the model variables may have different name values
|
|
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
208
204
|
else:
|
|
209
205
|
raise ValueError(
|
|
210
206
|
f"File format not supported: filepath={filepath}. "
|
|
211
|
-
"Keras 3 only supports V3 `.keras` files
|
|
207
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
212
208
|
"legacy H5 format files (`.h5` extension). "
|
|
213
209
|
"Note that the legacy SavedModel format is not "
|
|
214
210
|
"supported by `load_model()` in Keras 3. In "
|
|
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
288
284
|
objects_to_skip=objects_to_skip,
|
|
289
285
|
)
|
|
290
286
|
elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
|
|
291
|
-
if not h5py:
|
|
292
|
-
raise ImportError(
|
|
293
|
-
"Loading a H5 file requires `h5py` to be installed."
|
|
294
|
-
)
|
|
295
287
|
if objects_to_skip is not None:
|
|
296
288
|
raise ValueError(
|
|
297
289
|
"`objects_to_skip` only supports loading '.weights.h5' files."
|
|
298
290
|
f"Received: {filepath}"
|
|
299
291
|
)
|
|
292
|
+
if not h5py.available:
|
|
293
|
+
raise ImportError(
|
|
294
|
+
"Loading HDF5 files requires the h5py package. "
|
|
295
|
+
"You can install it via `pip install h5py`"
|
|
296
|
+
)
|
|
300
297
|
with h5py.File(filepath, "r") as f:
|
|
301
298
|
if "layer_names" not in f.attrs and "model_weights" in f:
|
|
302
299
|
f = f["model_weights"]
|
|
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
308
305
|
legacy_h5_format.load_weights_from_hdf5_group(
|
|
309
306
|
f, model, skip_mismatch
|
|
310
307
|
)
|
|
308
|
+
elif is_orbax_checkpoint(filepath):
|
|
309
|
+
# Load weights from Orbax checkpoint
|
|
310
|
+
from keras.src.utils.module_utils import ocp
|
|
311
|
+
|
|
312
|
+
filepath = str(filepath)
|
|
313
|
+
|
|
314
|
+
# Determine if this is a root directory or a step directory
|
|
315
|
+
items = os.listdir(filepath)
|
|
316
|
+
has_step_subdirs = any(
|
|
317
|
+
os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
|
|
318
|
+
for item in items
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
if has_step_subdirs:
|
|
322
|
+
# It's a root directory, find the latest checkpoint
|
|
323
|
+
checkpoint_path = find_latest_orbax_checkpoint(filepath)
|
|
324
|
+
else:
|
|
325
|
+
# It's a step directory, use it directly
|
|
326
|
+
checkpoint_path = filepath
|
|
327
|
+
|
|
328
|
+
# Load checkpoint
|
|
329
|
+
loaded_state = ocp.load_pytree(checkpoint_path)
|
|
330
|
+
|
|
331
|
+
# Set the model state directly from the loaded state
|
|
332
|
+
model.set_state_tree(loaded_state)
|
|
311
333
|
else:
|
|
312
334
|
raise ValueError(
|
|
313
335
|
f"File format not supported: filepath={filepath}. "
|
|
314
|
-
"Keras 3 only supports V3 `.keras`
|
|
315
|
-
"files,
|
|
336
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
337
|
+
"`.weights.h5` files, legacy H5 format files "
|
|
338
|
+
"(`.h5` extension), or Orbax checkpoints."
|
|
316
339
|
)
|
keras/src/utils/jax_layer.py
CHANGED
|
@@ -11,6 +11,7 @@ from keras.src.api_export import keras_export
|
|
|
11
11
|
from keras.src.backend.common.variables import is_float_dtype
|
|
12
12
|
from keras.src.backend.common.variables import standardize_dtype
|
|
13
13
|
from keras.src.layers.layer import Layer
|
|
14
|
+
from keras.src.random.seed_generator import draw_seed
|
|
14
15
|
from keras.src.saving import serialization_lib
|
|
15
16
|
from keras.src.utils import jax_utils
|
|
16
17
|
from keras.src.utils import tracking
|
|
@@ -244,15 +245,9 @@ class JaxLayer(Layer):
|
|
|
244
245
|
f" Tensorflow backend. Current backend: {backend.backend()}"
|
|
245
246
|
)
|
|
246
247
|
|
|
247
|
-
if init_fn is None and params is None and state is None:
|
|
248
|
-
raise ValueError(
|
|
249
|
-
"`init_fn`, `params` and `state` cannot all be `None`."
|
|
250
|
-
)
|
|
251
|
-
|
|
252
248
|
super().__init__(**kwargs)
|
|
253
249
|
self.call_fn = call_fn
|
|
254
250
|
self.init_fn = init_fn
|
|
255
|
-
self.seed_generator = backend.random.SeedGenerator(seed)
|
|
256
251
|
self.tracked_params = self._create_variables(params, trainable=True)
|
|
257
252
|
self.tracked_state = self._create_variables(state, trainable=False)
|
|
258
253
|
if self.params is not None or self.state is not None:
|
|
@@ -264,7 +259,25 @@ class JaxLayer(Layer):
|
|
|
264
259
|
{"params", "state", "rng", "inputs", "training"},
|
|
265
260
|
{"inputs"},
|
|
266
261
|
)
|
|
267
|
-
self.
|
|
262
|
+
self.call_fn_has_params = "params" in self.call_fn_arguments
|
|
263
|
+
self.call_fn_has_state = "state" in self.call_fn_arguments
|
|
264
|
+
call_fn_has_rng = "rng" in self.call_fn_arguments
|
|
265
|
+
|
|
266
|
+
if call_fn_has_rng:
|
|
267
|
+
self.seed_generator = backend.random.SeedGenerator(seed)
|
|
268
|
+
else:
|
|
269
|
+
self.seed_generator = None
|
|
270
|
+
|
|
271
|
+
if (
|
|
272
|
+
init_fn is None
|
|
273
|
+
and params is None
|
|
274
|
+
and state is None
|
|
275
|
+
and (self.call_fn_has_params or self.call_fn_has_state)
|
|
276
|
+
):
|
|
277
|
+
raise ValueError(
|
|
278
|
+
"`init_fn`, `params` and `state` cannot all be `None` when "
|
|
279
|
+
"`call_fn` takes a `params` or a `state` argument."
|
|
280
|
+
)
|
|
268
281
|
|
|
269
282
|
if init_fn:
|
|
270
283
|
self.init_fn_arguments = self._validate_signature(
|
|
@@ -428,37 +441,58 @@ class JaxLayer(Layer):
|
|
|
428
441
|
flat_variables, _ = jax.tree_util.tree_flatten(variables)
|
|
429
442
|
return flat_variables
|
|
430
443
|
|
|
444
|
+
def _get_init_seed(self):
|
|
445
|
+
"""
|
|
446
|
+
Returns a single seed as a tensor of shape [2].
|
|
447
|
+
|
|
448
|
+
Call this within `_get_init_rng()` to obtain a new seed.
|
|
449
|
+
|
|
450
|
+
Returns:
|
|
451
|
+
A native tensor of shape [2] and the backend dtype for seeds.
|
|
452
|
+
"""
|
|
453
|
+
# Use the global SeedGenerator.
|
|
454
|
+
return draw_seed(None)
|
|
455
|
+
|
|
431
456
|
def _get_init_rng(self):
|
|
432
457
|
"""
|
|
433
|
-
Returns a
|
|
434
|
-
|
|
458
|
+
Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
|
|
459
|
+
|
|
460
|
+
By default, this returns a single seed. Override this to return a
|
|
461
|
+
different structure. Overrides should use `self._get_init_seed()` to
|
|
462
|
+
obtain new seeds.
|
|
463
|
+
|
|
464
|
+
Returns:
|
|
465
|
+
RNG key or structure of keys as tensors of shape [2] and the backend
|
|
466
|
+
dtype for seeds.
|
|
467
|
+
"""
|
|
468
|
+
return self._get_init_seed()
|
|
469
|
+
|
|
470
|
+
def _get_call_seed(self):
|
|
471
|
+
"""
|
|
472
|
+
Returns a single seed as a tensor of shape [2].
|
|
435
473
|
|
|
436
|
-
|
|
437
|
-
`self.seed_generator.next()`. Override this to return a different
|
|
438
|
-
structure.
|
|
474
|
+
Call this within `_get_call_rng()` to obtain a new seed.
|
|
439
475
|
|
|
440
476
|
Returns:
|
|
441
|
-
|
|
442
|
-
as the `rng` argument of `init_fn`.
|
|
477
|
+
A native tensor of shape [2] and the backend dtype for seeds.
|
|
443
478
|
"""
|
|
444
479
|
return self.seed_generator.next()
|
|
445
480
|
|
|
446
481
|
def _get_call_rng(self, training):
|
|
447
482
|
"""
|
|
448
|
-
Returns a
|
|
449
|
-
to pass to `call_fn`.
|
|
483
|
+
Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
|
|
450
484
|
|
|
451
|
-
By default, this returns a
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
to
|
|
485
|
+
By default, this returns a seed when `training` is `True`, and `None`
|
|
486
|
+
when `training` is `False`. Override this to return a different
|
|
487
|
+
structure or to pass seeds in inference mode too. Overrides should use
|
|
488
|
+
`self._get_call_seed()` to obtain seeds.
|
|
455
489
|
|
|
456
490
|
Returns:
|
|
457
|
-
|
|
458
|
-
|
|
491
|
+
RNG key or structure of keys as tensors of shape [2] and the backend
|
|
492
|
+
dtype for seeds.
|
|
459
493
|
"""
|
|
460
494
|
if training:
|
|
461
|
-
return self.
|
|
495
|
+
return self._get_call_seed()
|
|
462
496
|
else:
|
|
463
497
|
return None
|
|
464
498
|
|
|
@@ -492,7 +526,7 @@ class JaxLayer(Layer):
|
|
|
492
526
|
init_args.append(True)
|
|
493
527
|
|
|
494
528
|
init_result = self.init_fn(*init_args)
|
|
495
|
-
if self.
|
|
529
|
+
if self.call_fn_has_state:
|
|
496
530
|
init_params, init_state = init_result
|
|
497
531
|
else:
|
|
498
532
|
init_params, init_state = init_result, None
|
|
@@ -503,7 +537,11 @@ class JaxLayer(Layer):
|
|
|
503
537
|
self.tracked_state = self._create_variables(init_state, trainable=False)
|
|
504
538
|
|
|
505
539
|
def build(self, input_shape):
|
|
506
|
-
if
|
|
540
|
+
if (
|
|
541
|
+
self.params is None
|
|
542
|
+
and self.state is None
|
|
543
|
+
and (self.call_fn_has_params or self.call_fn_has_state)
|
|
544
|
+
):
|
|
507
545
|
self._initialize_weights(input_shape)
|
|
508
546
|
|
|
509
547
|
if backend.backend() == "tensorflow":
|
|
@@ -578,7 +616,7 @@ class JaxLayer(Layer):
|
|
|
578
616
|
variable.assign(value)
|
|
579
617
|
|
|
580
618
|
def call_with_fn(fn):
|
|
581
|
-
if self.
|
|
619
|
+
if self.call_fn_has_state:
|
|
582
620
|
predictions, new_state = fn(*call_args)
|
|
583
621
|
jax.tree_util.tree_map(
|
|
584
622
|
assign_state_to_variable, new_state, self.state
|
|
@@ -711,12 +749,12 @@ class FlaxLayer(JaxLayer):
|
|
|
711
749
|
**kwargs,
|
|
712
750
|
):
|
|
713
751
|
# Late import to only require Flax when this is used.
|
|
714
|
-
from flax.
|
|
752
|
+
from flax.linen import DenyList
|
|
715
753
|
|
|
716
754
|
self.module = module
|
|
717
755
|
self.method = method
|
|
718
756
|
|
|
719
|
-
apply_mutable =
|
|
757
|
+
apply_mutable = DenyList(["params"])
|
|
720
758
|
|
|
721
759
|
def apply_with_training(params, state, rng, inputs, training):
|
|
722
760
|
return self.module.apply(
|
|
@@ -801,13 +839,13 @@ class FlaxLayer(JaxLayer):
|
|
|
801
839
|
|
|
802
840
|
def _get_init_rng(self):
|
|
803
841
|
return {
|
|
804
|
-
"params": self.
|
|
805
|
-
"dropout": self.
|
|
842
|
+
"params": self._get_init_seed(),
|
|
843
|
+
"dropout": self._get_init_seed(),
|
|
806
844
|
}
|
|
807
845
|
|
|
808
846
|
def _get_call_rng(self, training):
|
|
809
847
|
if training:
|
|
810
|
-
return {"dropout": self.
|
|
848
|
+
return {"dropout": self._get_call_seed()}
|
|
811
849
|
else:
|
|
812
850
|
return {}
|
|
813
851
|
|
keras/src/utils/module_utils.py
CHANGED
|
@@ -44,15 +44,26 @@ class OrbaxLazyModule(LazyModule):
|
|
|
44
44
|
try:
|
|
45
45
|
parent_module = importlib.import_module("orbax.checkpoint")
|
|
46
46
|
self.module = parent_module.v1
|
|
47
|
+
self.parent_module = parent_module
|
|
47
48
|
except ImportError:
|
|
48
49
|
raise ImportError(self.import_error_msg)
|
|
49
50
|
|
|
51
|
+
def __getattr__(self, name):
|
|
52
|
+
if name == "_api_export_path":
|
|
53
|
+
raise AttributeError
|
|
54
|
+
if self.module is None:
|
|
55
|
+
self.initialize()
|
|
56
|
+
if name == "multihost":
|
|
57
|
+
return self.parent_module.multihost
|
|
58
|
+
return getattr(self.module, name)
|
|
59
|
+
|
|
50
60
|
|
|
51
61
|
tensorflow = LazyModule("tensorflow")
|
|
52
62
|
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
|
|
53
63
|
tensorflow_io = LazyModule("tensorflow_io")
|
|
54
64
|
scipy = LazyModule("scipy")
|
|
55
65
|
jax = LazyModule("jax")
|
|
66
|
+
h5py = LazyModule("h5py")
|
|
56
67
|
torch_xla = LazyModule(
|
|
57
68
|
"torch_xla",
|
|
58
69
|
import_error_msg=(
|
keras/src/utils/tracking.py
CHANGED
|
@@ -31,13 +31,13 @@ def no_automatic_dependency_tracking(fn):
|
|
|
31
31
|
class Tracker:
|
|
32
32
|
"""Attribute tracker, used for e.g. Variable tracking.
|
|
33
33
|
|
|
34
|
-
Monitors certain attribute types
|
|
35
|
-
|
|
34
|
+
Monitors certain attribute types and places matching
|
|
35
|
+
objects into user provided tracking collections.
|
|
36
36
|
|
|
37
37
|
Also passively tracks certain mutable collections
|
|
38
|
-
(dict
|
|
39
|
-
still
|
|
40
|
-
collections
|
|
38
|
+
(e.g. dict and list) ensuring that items added after
|
|
39
|
+
initialization are still tracked. This is done by wrapping
|
|
40
|
+
these collections in tracking-aware proxy objects.
|
|
41
41
|
|
|
42
42
|
Example:
|
|
43
43
|
|
keras/src/version.py
CHANGED