keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53):
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +3 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +3 -0
  7. keras/ops/numpy/__init__.py +3 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +16 -0
  11. keras/src/backend/numpy/numpy.py +23 -0
  12. keras/src/backend/openvino/numpy.py +369 -16
  13. keras/src/backend/tensorflow/numpy.py +34 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +36 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +109 -96
  33. keras/src/ops/numpy.py +181 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/file_editor.py +81 -6
  44. keras/src/saving/orbax_util.py +50 -0
  45. keras/src/saving/saving_api.py +37 -14
  46. keras/src/utils/jax_layer.py +69 -31
  47. keras/src/utils/module_utils.py +11 -0
  48. keras/src/utils/tracking.py +5 -5
  49. keras/src/version.py +1 -1
  50. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  51. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
  52. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  53. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,11 @@ from absl import logging
6
6
  from keras.src.api_export import keras_export
7
7
  from keras.src.legacy.saving import legacy_h5_format
8
8
  from keras.src.saving import saving_lib
9
+ from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
10
+ from keras.src.saving.orbax_util import is_orbax_checkpoint
9
11
  from keras.src.utils import file_utils
10
12
  from keras.src.utils import io_utils
11
-
12
- try:
13
- import h5py
14
- except ImportError:
15
- h5py = None
13
+ from keras.src.utils.module_utils import h5py
16
14
 
17
15
 
18
16
  @keras_export(["keras.saving.save_model", "keras.models.save_model"])
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
149
147
  keras.layers.Softmax()])
150
148
  model.save("model.keras")
151
149
  loaded_model = keras.saving.load_model("model.keras")
152
- x = np.random.random((10, 3))
153
- assert np.allclose(model.predict(x), loaded_model.predict(x))
154
150
  ```
155
151
 
156
152
  Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
208
204
  else:
209
205
  raise ValueError(
210
206
  f"File format not supported: filepath={filepath}. "
211
- "Keras 3 only supports V3 `.keras` files and "
207
+ "Keras 3 only supports V3 `.keras` files, "
212
208
  "legacy H5 format files (`.h5` extension). "
213
209
  "Note that the legacy SavedModel format is not "
214
210
  "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
288
284
  objects_to_skip=objects_to_skip,
289
285
  )
290
286
  elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
291
- if not h5py:
292
- raise ImportError(
293
- "Loading a H5 file requires `h5py` to be installed."
294
- )
295
287
  if objects_to_skip is not None:
296
288
  raise ValueError(
297
289
  "`objects_to_skip` only supports loading '.weights.h5' files."
298
290
  f"Received: {filepath}"
299
291
  )
292
+ if not h5py.available:
293
+ raise ImportError(
294
+ "Loading HDF5 files requires the h5py package. "
295
+ "You can install it via `pip install h5py`"
296
+ )
300
297
  with h5py.File(filepath, "r") as f:
301
298
  if "layer_names" not in f.attrs and "model_weights" in f:
302
299
  f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
308
305
  legacy_h5_format.load_weights_from_hdf5_group(
309
306
  f, model, skip_mismatch
310
307
  )
308
+ elif is_orbax_checkpoint(filepath):
309
+ # Load weights from Orbax checkpoint
310
+ from keras.src.utils.module_utils import ocp
311
+
312
+ filepath = str(filepath)
313
+
314
+ # Determine if this is a root directory or a step directory
315
+ items = os.listdir(filepath)
316
+ has_step_subdirs = any(
317
+ os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
318
+ for item in items
319
+ )
320
+
321
+ if has_step_subdirs:
322
+ # It's a root directory, find the latest checkpoint
323
+ checkpoint_path = find_latest_orbax_checkpoint(filepath)
324
+ else:
325
+ # It's a step directory, use it directly
326
+ checkpoint_path = filepath
327
+
328
+ # Load checkpoint
329
+ loaded_state = ocp.load_pytree(checkpoint_path)
330
+
331
+ # Set the model state directly from the loaded state
332
+ model.set_state_tree(loaded_state)
311
333
  else:
312
334
  raise ValueError(
313
335
  f"File format not supported: filepath={filepath}. "
314
- "Keras 3 only supports V3 `.keras` and `.weights.h5` "
315
- "files, or legacy V1/V2 `.h5` files."
336
+ "Keras 3 only supports V3 `.keras` files, "
337
+ "`.weights.h5` files, legacy H5 format files "
338
+ "(`.h5` extension), or Orbax checkpoints."
316
339
  )
@@ -11,6 +11,7 @@ from keras.src.api_export import keras_export
11
11
  from keras.src.backend.common.variables import is_float_dtype
12
12
  from keras.src.backend.common.variables import standardize_dtype
13
13
  from keras.src.layers.layer import Layer
14
+ from keras.src.random.seed_generator import draw_seed
14
15
  from keras.src.saving import serialization_lib
15
16
  from keras.src.utils import jax_utils
16
17
  from keras.src.utils import tracking
@@ -244,15 +245,9 @@ class JaxLayer(Layer):
244
245
  f" Tensorflow backend. Current backend: {backend.backend()}"
245
246
  )
246
247
 
247
- if init_fn is None and params is None and state is None:
248
- raise ValueError(
249
- "`init_fn`, `params` and `state` cannot all be `None`."
250
- )
251
-
252
248
  super().__init__(**kwargs)
253
249
  self.call_fn = call_fn
254
250
  self.init_fn = init_fn
255
- self.seed_generator = backend.random.SeedGenerator(seed)
256
251
  self.tracked_params = self._create_variables(params, trainable=True)
257
252
  self.tracked_state = self._create_variables(state, trainable=False)
258
253
  if self.params is not None or self.state is not None:
@@ -264,7 +259,25 @@ class JaxLayer(Layer):
264
259
  {"params", "state", "rng", "inputs", "training"},
265
260
  {"inputs"},
266
261
  )
267
- self.has_state = "state" in self.call_fn_arguments
262
+ self.call_fn_has_params = "params" in self.call_fn_arguments
263
+ self.call_fn_has_state = "state" in self.call_fn_arguments
264
+ call_fn_has_rng = "rng" in self.call_fn_arguments
265
+
266
+ if call_fn_has_rng:
267
+ self.seed_generator = backend.random.SeedGenerator(seed)
268
+ else:
269
+ self.seed_generator = None
270
+
271
+ if (
272
+ init_fn is None
273
+ and params is None
274
+ and state is None
275
+ and (self.call_fn_has_params or self.call_fn_has_state)
276
+ ):
277
+ raise ValueError(
278
+ "`init_fn`, `params` and `state` cannot all be `None` when "
279
+ "`call_fn` takes a `params` or a `state` argument."
280
+ )
268
281
 
269
282
  if init_fn:
270
283
  self.init_fn_arguments = self._validate_signature(
@@ -428,37 +441,58 @@ class JaxLayer(Layer):
428
441
  flat_variables, _ = jax.tree_util.tree_flatten(variables)
429
442
  return flat_variables
430
443
 
444
+ def _get_init_seed(self):
445
+ """
446
+ Returns a single seed as a tensor of shape [2].
447
+
448
+ Call this within `_get_init_rng()` to obtain a new seed.
449
+
450
+ Returns:
451
+ A native tensor of shape [2] and the backend dtype for seeds.
452
+ """
453
+ # Use the global SeedGenerator.
454
+ return draw_seed(None)
455
+
431
456
  def _get_init_rng(self):
432
457
  """
433
- Returns a key in form of the backend array of size 2 dtype uint32
434
- to pass to `init_fn`.
458
+ Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
459
+
460
+ By default, this returns a single seed. Override this to return a
461
+ different structure. Overrides should use `self._get_init_seed()` to
462
+ obtain new seeds.
463
+
464
+ Returns:
465
+ RNG key or structure of keys as tensors of shape [2] and the backend
466
+ dtype for seeds.
467
+ """
468
+ return self._get_init_seed()
469
+
470
+ def _get_call_seed(self):
471
+ """
472
+ Returns a single seed as a tensor of shape [2].
435
473
 
436
- By default, this returns a Jax or TF array of size 2 by calling
437
- `self.seed_generator.next()`. Override this to return a different
438
- structure.
474
+ Call this within `_get_call_rng()` to obtain a new seed.
439
475
 
440
476
  Returns:
441
- a key as an Jax or TF array of size 2 dtype uint32 will be passed
442
- as the `rng` argument of `init_fn`.
477
+ A native tensor of shape [2] and the backend dtype for seeds.
443
478
  """
444
479
  return self.seed_generator.next()
445
480
 
446
481
  def _get_call_rng(self, training):
447
482
  """
448
- Returns a key in form of the backend array of size 2 dtype uint32
449
- to pass to `call_fn`.
483
+ Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
450
484
 
451
- By default, this returns a Jax or TF array of size 2 by calling
452
- `self.seed_generator.next()` when `training` is `True`, and `None` when
453
- `training` is `False`. Override this to return a different structure or
454
- to pass RNGs in inference mode too.
485
+ By default, this returns a seed when `training` is `True`, and `None`
486
+ when `training` is `False`. Override this to return a different
487
+ structure or to pass seeds in inference mode too. Overrides should use
488
+ `self._get_call_seed()` to obtain seeds.
455
489
 
456
490
  Returns:
457
- a key as an Jax or TF array of size 2 dtype uint32 will be passed
458
- as the `rng` argument of `call_fn`.
491
+ RNG key or structure of keys as tensors of shape [2] and the backend
492
+ dtype for seeds.
459
493
  """
460
494
  if training:
461
- return self.seed_generator.next()
495
+ return self._get_call_seed()
462
496
  else:
463
497
  return None
464
498
 
@@ -492,7 +526,7 @@ class JaxLayer(Layer):
492
526
  init_args.append(True)
493
527
 
494
528
  init_result = self.init_fn(*init_args)
495
- if self.has_state:
529
+ if self.call_fn_has_state:
496
530
  init_params, init_state = init_result
497
531
  else:
498
532
  init_params, init_state = init_result, None
@@ -503,7 +537,11 @@ class JaxLayer(Layer):
503
537
  self.tracked_state = self._create_variables(init_state, trainable=False)
504
538
 
505
539
  def build(self, input_shape):
506
- if self.params is None and self.state is None:
540
+ if (
541
+ self.params is None
542
+ and self.state is None
543
+ and (self.call_fn_has_params or self.call_fn_has_state)
544
+ ):
507
545
  self._initialize_weights(input_shape)
508
546
 
509
547
  if backend.backend() == "tensorflow":
@@ -578,7 +616,7 @@ class JaxLayer(Layer):
578
616
  variable.assign(value)
579
617
 
580
618
  def call_with_fn(fn):
581
- if self.has_state:
619
+ if self.call_fn_has_state:
582
620
  predictions, new_state = fn(*call_args)
583
621
  jax.tree_util.tree_map(
584
622
  assign_state_to_variable, new_state, self.state
@@ -711,12 +749,12 @@ class FlaxLayer(JaxLayer):
711
749
  **kwargs,
712
750
  ):
713
751
  # Late import to only require Flax when this is used.
714
- from flax.core import scope as flax_scope
752
+ from flax.linen import DenyList
715
753
 
716
754
  self.module = module
717
755
  self.method = method
718
756
 
719
- apply_mutable = flax_scope.DenyList(["params"])
757
+ apply_mutable = DenyList(["params"])
720
758
 
721
759
  def apply_with_training(params, state, rng, inputs, training):
722
760
  return self.module.apply(
@@ -801,13 +839,13 @@ class FlaxLayer(JaxLayer):
801
839
 
802
840
  def _get_init_rng(self):
803
841
  return {
804
- "params": self.seed_generator.next(),
805
- "dropout": self.seed_generator.next(),
842
+ "params": self._get_init_seed(),
843
+ "dropout": self._get_init_seed(),
806
844
  }
807
845
 
808
846
  def _get_call_rng(self, training):
809
847
  if training:
810
- return {"dropout": self.seed_generator.next()}
848
+ return {"dropout": self._get_call_seed()}
811
849
  else:
812
850
  return {}
813
851
 
@@ -44,15 +44,26 @@ class OrbaxLazyModule(LazyModule):
44
44
  try:
45
45
  parent_module = importlib.import_module("orbax.checkpoint")
46
46
  self.module = parent_module.v1
47
+ self.parent_module = parent_module
47
48
  except ImportError:
48
49
  raise ImportError(self.import_error_msg)
49
50
 
51
+ def __getattr__(self, name):
52
+ if name == "_api_export_path":
53
+ raise AttributeError
54
+ if self.module is None:
55
+ self.initialize()
56
+ if name == "multihost":
57
+ return self.parent_module.multihost
58
+ return getattr(self.module, name)
59
+
50
60
 
51
61
  tensorflow = LazyModule("tensorflow")
52
62
  gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
53
63
  tensorflow_io = LazyModule("tensorflow_io")
54
64
  scipy = LazyModule("scipy")
55
65
  jax = LazyModule("jax")
66
+ h5py = LazyModule("h5py")
56
67
  torch_xla = LazyModule(
57
68
  "torch_xla",
58
69
  import_error_msg=(
@@ -31,13 +31,13 @@ def no_automatic_dependency_tracking(fn):
31
31
  class Tracker:
32
32
  """Attribute tracker, used for e.g. Variable tracking.
33
33
 
34
- Monitors certain attribute types
35
- and put them in appropriate lists in case of a match.
34
+ Monitors certain attribute types and places matching
35
+ objects into user provided tracking collections.
36
36
 
37
37
  Also passively tracks certain mutable collections
38
- (dict, list) so that items added to them later
39
- still get tracked. This is done by wrapping these
40
- collections into an equivalent, tracking-aware object.
38
+ (e.g. dict and list) ensuring that items added after
39
+ initialization are still tracked. This is done by wrapping
40
+ these collections in tracking-aware proxy objects.
41
41
 
42
42
  Example:
43
43
 
keras/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras.src.api_export import keras_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "3.14.0.dev2025122704"
4
+ __version__ = "3.14.0.dev2026012204"
5
5
 
6
6
 
7
7
  @keras_export("keras.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-nightly
3
- Version: 3.14.0.dev2025122704
3
+ Version: 3.14.0.dev2026012204
4
4
  Summary: Multi-backend Keras
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0