keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +2 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +2 -0
  7. keras/ops/numpy/__init__.py +2 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +10 -0
  11. keras/src/backend/numpy/numpy.py +15 -0
  12. keras/src/backend/openvino/numpy.py +338 -17
  13. keras/src/backend/tensorflow/numpy.py +24 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +26 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +106 -93
  33. keras/src/ops/numpy.py +138 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/orbax_util.py +50 -0
  44. keras/src/saving/saving_api.py +37 -14
  45. keras/src/utils/jax_layer.py +69 -31
  46. keras/src/utils/module_utils.py +11 -0
  47. keras/src/utils/tracking.py +5 -5
  48. keras/src/version.py +1 -1
  49. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  50. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
  51. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  52. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/quantizers/awq_core.py (new file)
@@ -0,0 +1,217 @@
+ """AWQ core functionality for layer-wise quantization.
+
+ This module provides the orchestration logic for applying AWQ quantization
+ to transformer models in a layer-by-layer fashion.
+ """
+
+ from contextlib import contextmanager
+
+ from absl import logging
+
+ from keras.src import ops
+ from keras.src import utils as keras_utils
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+ from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+ from keras.src.quantizers.awq import AWQ
+ from keras.src.quantizers.awq_config import AWQConfig
+ from keras.src.quantizers.gptq_core import find_layers_in_block
+ from keras.src.quantizers.gptq_core import get_dataloader
+ from keras.src.quantizers.utils import should_quantize_layer
+
+
+ @contextmanager
+ def stream_activations(layers_map, awq_objects):
+     """Context manager to capture activations for AWQ calibration.
+
+     Temporarily patches layer.call methods to capture activation statistics
+     for computing per-channel scaling factors.
+
+     Args:
+         layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+         awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+     Yields:
+         None: The patched state is active only within the `with` block.
+     """
+     original_calls = {}
+
+     def create_hook(name, original_call_func):
+         def hook(*args, **kwargs):
+             inp = args[0] if args else kwargs["inputs"]
+             num_features = awq_objects[name].rows
+             input_2d = ops.reshape(inp, (-1, num_features))
+             awq_objects[name].update_activation_magnitudes(input_2d)
+             return original_call_func(*args, **kwargs)
+
+         return hook
+
+     try:
+         for name, layer in layers_map.items():
+             original_calls[name] = layer.call
+             layer.call = create_hook(name, layer.call)
+         yield
+     finally:
+         for name, layer in layers_map.items():
+             layer.call = original_calls[name]
+
+
+ def apply_awq_layerwise(dataloader, config, structure, filters=None):
+     """Apply AWQ quantization layer-by-layer to a Keras model.
+
+     This function processes the model sequentially, one block at a time:
+     1. Captures activation statistics through calibration data forward pass
+     2. Uses activation magnitudes to determine weight saliency
+     3. Finds optimal per-channel scales via grid search
+     4. Quantizes weights with AWQ scaling
+
+     Args:
+         dataloader: Calibration data as numpy array.
+         config: AWQConfig instance.
+         structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+         filters: Optional layer filters.
+     """
+     num_samples = config.num_samples
+     logging.info("Starting AWQ quantization...")
+
+     pre_layers = structure.get("pre_block_layers", [])
+     transformer_blocks = structure.get("sequential_blocks", [])
+
+     if not transformer_blocks:
+         raise ValueError(
+             "No sequential blocks found in the structure to quantize."
+         )
+
+     # Process inputs through pre-block layers (e.g., embedding)
+     inputs = []
+     for batch in dataloader:
+         batch = ops.convert_to_tensor(batch, dtype="int32")
+         for layer in pre_layers:
+             batch = layer(batch)
+         inputs.append(batch)
+
+     num_samples = min(num_samples, len(inputs))
+     progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+     for block_idx, block in enumerate(transformer_blocks):
+         logging.info(f"Quantizing Block {block_idx}")
+         sub_layers_map = find_layers_in_block(block)
+
+         # Apply filters
+         final_sub_layers_map = {}
+         for name, layer in sub_layers_map.items():
+             if not should_quantize_layer(layer, filters):
+                 continue
+             final_sub_layers_map[name] = layer
+
+         sub_layers_map = final_sub_layers_map
+
+         if not sub_layers_map:
+             logging.info(
+                 f" No quantizable layers found in block {block_idx}. Skipping."
+             )
+         else:
+             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+             # Create AWQ objects for each layer
+             awq_objects = {
+                 name: AWQ(layer, config)
+                 for name, layer in sub_layers_map.items()
+             }
+
+             # Capture activation statistics
+             with stream_activations(sub_layers_map, awq_objects):
+                 for sample_idx in range(num_samples):
+                     current_input = inputs[sample_idx]
+                     if len(current_input.shape) == 2:
+                         current_input = ops.expand_dims(current_input, axis=0)
+                     _ = block(current_input)
+
+             # Quantize each layer
+             for name, awq_object in awq_objects.items():
+                 logging.info(f"Quantizing {name}...")
+                 awq_object.quantize_layer()
+                 awq_object.free()
+
+             del awq_objects
+
+         # Generate inputs for next block
+         if block_idx < len(transformer_blocks) - 1:
+             logging.info(f"Generating inputs for block {block_idx + 1}...")
+             next_block_inputs = []
+             for sample_idx in range(num_samples):
+                 current_input = inputs[sample_idx]
+                 if len(current_input.shape) == 2:
+                     current_input = ops.expand_dims(current_input, axis=0)
+                 output = block(current_input)[0]
+                 next_block_inputs.append(output)
+             inputs = next_block_inputs
+
+         progbar.update(current=block_idx + 1)
+
+     logging.info("AWQ quantization complete.")
+
+
+ def awq_quantize(config, quantization_layer_structure, filters=None):
+     """Main entry point for AWQ quantization.
+
+     Args:
+         config: AWQConfig instance.
+         quantization_layer_structure: Model structure dictionary.
+         filters: Optional layer filters.
+     """
+     if config.dataset is None or config.tokenizer is None:
+         raise ValueError(
+             "AWQ quantization requires a dataset and tokenizer. "
+             "Please provide them in the AWQConfig."
+         )
+
+     if quantization_layer_structure is None:
+         raise ValueError(
+             "For 'awq' mode, a valid quantization structure must be provided "
+             "either via `config.quantization_layer_structure` or by overriding "
+             "`model.get_quantization_layer_structure(mode)`. The structure "
+             "should be a dictionary with keys 'pre_block_layers' and "
+             "'sequential_blocks'."
+         )
+
+     # Load calibration data
+     dataloader = get_dataloader(
+         config.tokenizer,
+         config.sequence_length,
+         config.dataset,
+         num_samples=config.num_samples,
+     )
+
+     apply_awq_layerwise(
+         dataloader[: config.num_samples],
+         config,
+         quantization_layer_structure,
+         filters=filters,
+     )
+
+
+ def get_group_size_for_layer(layer, config):
+     """Get group size from config or dtype policy.
+
+     Args:
+         layer: The layer to get group size for.
+         config: Optional AWQConfig instance.
+
+     Returns:
+         int: The group size for quantization.
+
+     Raises:
+         ValueError: If group size cannot be determined.
+     """
+     if config and isinstance(config, AWQConfig):
+         return config.group_size
+     elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+         return layer.dtype_policy.group_size
+     elif isinstance(layer.dtype_policy, DTypePolicyMap):
+         policy = layer.dtype_policy[layer.path]
+         if isinstance(policy, AWQDTypePolicy):
+             return policy.group_size
+     raise ValueError(
+         "For AWQ quantization, group_size must be specified "
+         "through AWQConfig or AWQDTypePolicy."
+     )
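Note: `stream_activations` above works by monkey-patching `layer.call` and restoring it in a `finally` block. Below is a minimal self-contained sketch of that hook pattern (a toy class, not the Keras API), illustrating that the original method survives even if the forward pass raises:

    from contextlib import contextmanager

    class ToyLayer:
        def call(self, x):
            return [v * 2 for v in x]

    @contextmanager
    def capture_inputs(layer, captured):
        original_call = layer.call

        def hook(x):
            captured.append(x)       # record the input, like update_activation_magnitudes
            return original_call(x)  # then run the real forward pass

        layer.call = hook
        try:
            yield
        finally:
            layer.call = original_call  # always restored, even on error

    layer, seen = ToyLayer(), []
    with capture_inputs(layer, seen):
        layer.call([1, 2, 3])
    assert seen == [[1, 2, 3]]
    assert layer.call([1]) == [2]  # original method is back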
keras/src/quantizers/gptq.py
@@ -1,5 +1,4 @@
  import types
- from functools import partial

  from keras.src import ops
  from keras.src import quantizers
@@ -466,7 +465,7 @@ class GPTQ:
              group_size=self.config.group_size,
              activation_order=self.config.activation_order,
              order_metric=ops.diagonal(hessian_matrix),
-             compute_scale_zero=partial(self.quantizer.find_params, weight=True),
+             compute_scale_zero=self.quantizer.find_params,
          )
          quantized = ops.cast(
              quantized, self.original_layer.quantized_kernel.dtype
keras/src/quantizers/gptq_core.py
@@ -131,7 +131,7 @@ def get_dataloader(
      pieces = []
      if isinstance(dataset_list[0], str):
          for i, s in enumerate(dataset_list):
-             toks = np.asarray(tokenizer.tokenize(s)).reshape(-1)
+             toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
              pieces.append(toks)
              # avoid windows that span document boundaries
              if eos_id is not None and i < len(dataset_list) - 1:
keras/src/quantizers/quantization_config.py
@@ -182,6 +182,11 @@ def validate_and_resolve_config(mode, config):
                  "For GPTQ, you must pass a `GPTQConfig` object in the "
                  "`config` argument."
              )
+         elif mode == "awq":
+             raise ValueError(
+                 "For AWQ, you must pass an `AWQConfig` object in the "
+                 "`config` argument."
+             )
          else:
              if mode is not None:
                  raise ValueError(
@@ -220,6 +225,15 @@ def validate_and_resolve_config(mode, config):
                  f"`GPTQConfig`. Received: {type(config)}"
              )

+     if mode == "awq":
+         from keras.src.quantizers.awq_config import AWQConfig
+
+         if not isinstance(config, AWQConfig):
+             raise ValueError(
+                 "Mode 'awq' requires a valid `config` argument of type "
+                 f"`AWQConfig`. Received: {type(config)}"
+             )
+
      return config


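Note: a short sketch of the two new failure modes added to `validate_and_resolve_config`. The error strings are the ones introduced in this hunk; constructing a valid `AWQConfig` is not shown, since its required arguments lie outside this diff, and the second call assumes no earlier type check in the unshown parts of the function intercepts it:

    from keras.src.quantizers.quantization_config import validate_and_resolve_config

    # mode="awq" with no config at all:
    try:
        validate_and_resolve_config("awq", None)
    except ValueError as e:
        print(e)  # "For AWQ, you must pass an `AWQConfig` object in the `config` argument."

    # mode="awq" with a config of the wrong type:
    try:
        validate_and_resolve_config("awq", object())
    except ValueError as e:
        print(e)  # "Mode 'awq' requires a valid `config` argument of type `AWQConfig`. ..."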
keras/src/quantizers/quantizers.py
@@ -653,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
      )

      def to_signed(x):
-         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+         Uses a branchless XOR approach: (x ^ 8) - 8
+         This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+         """
          dtype_x = backend.standardize_dtype(x.dtype)
          eight = ops.cast(8, dtype_x)
-         sixteen = ops.cast(16, dtype_x)
-         return ops.where(x < eight, x, x - sixteen)
+         return ops.subtract(ops.bitwise_xor(x, eight), eight)

      rank = getattr(packed.shape, "rank", None) or len(packed.shape)
      if axis < 0:
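Note: the branchless `(x ^ 8) - 8` trick can be checked exhaustively over all sixteen nibble values; it reproduces the two's-complement reading of a 4-bit value that the old `ops.where` select computed, without a data-dependent branch:

    # Every nibble 0..15, compared against the plain "subtract 16 if >= 8" rule.
    for x in range(16):
        assert (x ^ 8) - 8 == (x if x < 8 else x - 16)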
@@ -748,7 +751,7 @@ class GPTQQuantizer(Quantizer):
          self.zero = None
          self.maxq = None

-     def find_params(self, input_tensor, weight=True):
+     def find_params(self, input_tensor):
          """Finds quantization parameters (scale and zero) for a given tensor."""
          self.scale, self.zero, self.maxq = compute_quantization_parameters(
              input_tensor,
@@ -756,7 +759,6 @@ class GPTQQuantizer(Quantizer):
              symmetric=self.symmetric,
              per_channel=self.per_channel,
              group_size=self.group_size,
-             weight=weight,
              compute_dtype=self.compute_dtype,
          )
          return self.scale, self.zero, self.maxq
@@ -793,98 +795,105 @@ def compute_quantization_parameters(
      symmetric=False,
      per_channel=False,
      group_size=-1,
-     weight=False,
      compute_dtype="float32",
  ):
      """
-     Computes the scale and zero-point for quantization.
+     Computes the scale and zero-point for quantizing weight tensors.

      This function calculates the scale and zero-point required for quantizing
-     a given tensor `x` based on the specified parameters. It supports grouped,
-     per-channel, per-tensor, symmetric, and asymmetric quantization - along
-     with any combinations of these.
+     a given weight tensor `x` based on the specified parameters. It supports
+     grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+     For grouped quantization (per_channel=True, group_size > 0), the output
+     shapes are [out_features, n_groups] where n_groups is the number of groups
+     along the in_features dimension.

      Args:
-         x: KerasTensor. The input tensor to quantize.
+         x: KerasTensor. The weight tensor to quantize with shape
+             [out_features, in_features].
          bits: int. The number of bits to quantize to (e.g., 4).
          symmetric: bool. Whether to use symmetric quantization.
          per_channel: bool. Whether to quantize per channel.
-         group_size: int. The group size for quantization.
-         weight: bool. Whether the input tensor is a weight tensor.
+         group_size: int. The group size for quantization. -1 means no grouping.
+         compute_dtype: str. The dtype for computation. Defaults to "float32".

      Returns:
          scale: KerasTensor. The scale tensor for quantization.
          zero: KerasTensor. The zero tensor for quantization.
          maxq: scalar. The maximum quantization value.
      """
+     # Input validation
      if x is None:
          raise ValueError(f"Input tensor {x} cannot be None.")
-
-     # For weights, we typically expect at least a 2D tensor.
-     if weight and len(x.shape) < 2:
+     if len(x.shape) < 2:
          raise ValueError(
              f"Input weight tensor {x} must have a rank of at "
              f"least 2, but got rank {len(x.shape)}."
          )
-
      if ops.size(x) == 0:
          raise ValueError("Input tensor 'x' cannot be empty.")

-     original_shape = x.shape
-
-     if per_channel:
-         if weight:
-             if group_size != -1:
-                 input_reshaped = ops.reshape(x, [-1, group_size])
-             else:
-                 input_reshaped = ops.reshape(x, [original_shape[0], -1])
-     else:  # per-tensor
-         input_reshaped = ops.reshape(x, [1, -1])
+     out_features, in_features = x.shape[0], x.shape[1]

-     # Find min/max values
-     min_values = ops.min(input_reshaped, axis=1)
-     max_values = ops.max(input_reshaped, axis=1)
+     # Determine number of groups for quantization
+     if per_channel and group_size > 0:
+         n_groups = (in_features + group_size - 1) // group_size
+     else:
+         n_groups = 1
+
+     # Compute min/max values based on quantization mode
+     if n_groups > 1:
+         # Grouped quantization: output shape [out_features, n_groups]
+         remainder = in_features % group_size
+         if remainder != 0:
+             pad_size = group_size - remainder
+             x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+         x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+         min_values = ops.min(x_grouped, axis=2)
+         max_values = ops.max(x_grouped, axis=2)
+     else:
+         # Per-channel or per-tensor: compute stats along rows
+         reduction_shape = [out_features, -1] if per_channel else [1, -1]
+         x_reshaped = ops.reshape(x, reduction_shape)
+         min_values = ops.min(x_reshaped, axis=1)
+         max_values = ops.max(x_reshaped, axis=1)

-     # Apply symmetric quantization logic if enabled
+     # Symmetric quantization: make range symmetric around zero
      if symmetric:
-         max_values = ops.maximum(ops.abs(min_values), max_values)
+         max_abs = ops.maximum(ops.abs(min_values), max_values)
          min_values = ops.where(
-             ops.less(min_values, 0), ops.negative(max_values), min_values
+             ops.less(min_values, 0), ops.negative(max_abs), min_values
          )
+         max_values = max_abs

-     # Ensure range is not zero to avoid division errors
+     # Ensure non-zero range to avoid division errors
      zero_range = ops.equal(min_values, max_values)
      min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
      max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)

+     # Compute scale and zero-point
      maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-     # Calculate scale and zero-point
      scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
      if symmetric:
          zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
      else:
          zero = ops.round(ops.divide(ops.negative(min_values), scale))

-     # Ensure scale is non-zero
-     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
-
-     if weight:
-         # Per-channel, non-grouped case: simple reshape is correct.
-         if per_channel and group_size == -1:
-             scale = ops.reshape(scale, [-1, 1])
-             zero = ops.reshape(zero, [-1, 1])
-         elif not per_channel:
-             num_rows = original_shape[0]
-             scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-             zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-     if per_channel:
+     # Reshape output to [out_features, n_groups] or [out_features, 1]
+     if n_groups > 1:
+         pass  # Already [out_features, n_groups]
+     elif per_channel:
          scale = ops.reshape(scale, [-1, 1])
          zero = ops.reshape(zero, [-1, 1])
+     else:
+         # Per-tensor: tile single value to [out_features, 1]
+         scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+         zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))

-     zero = ops.cast(zero, "uint8")
-
-     return scale, zero, maxq
+     return scale, ops.cast(zero, "uint8"), maxq


  def quantize_with_zero_point(input_tensor, scale, zero, maxq):
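Note: the shape arithmetic of the new grouped path, worked through for a small example. A `[4, 10]` weight with `group_size=4` yields `ceil(10 / 4) = 3` groups, so the matrix is padded by 2 zero columns to `[4, 12]`, reshaped to `[4, 3, 4]`, and reduced over axis 2, giving `scale` and `zero` of shape `[4, 3]`:

    out_features, in_features, group_size = 4, 10, 4
    n_groups = (in_features + group_size - 1) // group_size  # ceil division -> 3
    remainder = in_features % group_size                     # -> 2
    pad_size = group_size - remainder                        # -> 2 padded columns
    assert (n_groups, remainder, pad_size) == (3, 2, 2)
    # x: [4, 10] -> pad -> [4, 12] -> reshape -> [4, 3, 4] -> min/max over axis 2 -> [4, 3]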
keras/src/random/seed_generator.py
@@ -29,7 +29,7 @@ class SeedGenerator:
      a local `StateGenerator` with either a deterministic or random initial
      state.

-     Remark concerning the JAX backen: Note that the use of a local
+     Remark concerning the JAX backend: Note that the use of a local
      `StateGenerator` as seed argument is required for JIT compilation of
      RNG with the JAX backend, because the use of global state is not
      supported.
@@ -111,7 +111,7 @@ class SeedGenerator:
          return new_seed_value

      def get_config(self):
-         return {"seed": self._initial_seed}
+         return {"seed": self._initial_seed, "name": self.name}

      @classmethod
      def from_config(cls, config):
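Note: with `name` now serialized, a config round trip preserves the generator's name instead of regenerating an auto-assigned one. A minimal sketch using the public `keras.random.SeedGenerator` path (assuming the standard `seed`/`name` constructor arguments):

    import keras

    gen = keras.random.SeedGenerator(seed=42, name="my_seed_gen")
    config = gen.get_config()  # {"seed": 42, "name": "my_seed_gen"}
    restored = keras.random.SeedGenerator.from_config(config)
    assert restored.name == "my_seed_gen"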
keras/src/saving/orbax_util.py (new file)
@@ -0,0 +1,50 @@
+ """Orbax checkpoint loading functionality."""
+
+ import os
+
+ from keras.src.utils.module_utils import ocp
+
+
+ def is_orbax_checkpoint(filepath):
+     """Check if the given path is an Orbax checkpoint directory.
+
+     This function implements custom detection logic instead of relying on
+     Orbax APIs which may be unreliable in some environments.
+     """
+     if not os.path.exists(filepath) or not os.path.isdir(filepath):
+         return False
+
+     try:
+         # List directory contents
+         contents = os.listdir(filepath)
+
+         # A set is more efficient for membership testing
+         orbax_indicators = {
+             "orbax.checkpoint",
+             "pytree.orbax-checkpoint",
+             "checkpoint_metadata",
+         }
+
+         # Fast check for standard files
+         if not orbax_indicators.isdisjoint(contents):
+             return True
+
+         # Check for step directories or temporary files in a single pass
+         return any(
+             ".orbax-checkpoint-tmp" in item
+             or (item.isdigit() and os.path.isdir(os.path.join(filepath, item)))
+             for item in contents
+         )
+
+     except (OSError, PermissionError):
+         # If we can't read the directory, assume it's not a checkpoint
+         return False
+
+
+ def find_latest_orbax_checkpoint(checkpoint_dir):
+     """Find the latest checkpoint in an Orbax checkpoint directory."""
+     checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+     latest = checkpointer.latest
+     if latest is None:
+         raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+     return os.path.join(checkpoint_dir, str(latest.step))
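Note: `is_orbax_checkpoint` only inspects the directory, so it can be exercised without Orbax installed; a quick check based on the detection rules above:

    import os
    import tempfile

    from keras.src.saving.orbax_util import is_orbax_checkpoint

    root = tempfile.mkdtemp()
    assert not is_orbax_checkpoint(root)  # empty directory: no indicators
    os.mkdir(os.path.join(root, "42"))    # digit-named step subdirectory
    assert is_orbax_checkpoint(root)      # now detected as a checkpoint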
keras/src/saving/saving_api.py
@@ -6,13 +6,11 @@ from absl import logging
  from keras.src.api_export import keras_export
  from keras.src.legacy.saving import legacy_h5_format
  from keras.src.saving import saving_lib
+ from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
+ from keras.src.saving.orbax_util import is_orbax_checkpoint
  from keras.src.utils import file_utils
  from keras.src.utils import io_utils
-
- try:
-     import h5py
- except ImportError:
-     h5py = None
+ from keras.src.utils.module_utils import h5py


  @keras_export(["keras.saving.save_model", "keras.models.save_model"])
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
          keras.layers.Softmax()])
      model.save("model.keras")
      loaded_model = keras.saving.load_model("model.keras")
-     x = np.random.random((10, 3))
-     assert np.allclose(model.predict(x), loaded_model.predict(x))
      ```

      Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
      else:
          raise ValueError(
              f"File format not supported: filepath={filepath}. "
-             "Keras 3 only supports V3 `.keras` files and "
+             "Keras 3 only supports V3 `.keras` files, "
              "legacy H5 format files (`.h5` extension). "
              "Note that the legacy SavedModel format is not "
              "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
              objects_to_skip=objects_to_skip,
          )
      elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-         if not h5py:
-             raise ImportError(
-                 "Loading a H5 file requires `h5py` to be installed."
-             )
          if objects_to_skip is not None:
              raise ValueError(
                  "`objects_to_skip` only supports loading '.weights.h5' files."
                  f"Received: {filepath}"
              )
+         if not h5py.available:
+             raise ImportError(
+                 "Loading HDF5 files requires the h5py package. "
+                 "You can install it via `pip install h5py`"
+             )
          with h5py.File(filepath, "r") as f:
              if "layer_names" not in f.attrs and "model_weights" in f:
                  f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
          legacy_h5_format.load_weights_from_hdf5_group(
              f, model, skip_mismatch
          )
+     elif is_orbax_checkpoint(filepath):
+         # Load weights from Orbax checkpoint
+         from keras.src.utils.module_utils import ocp
+
+         filepath = str(filepath)
+
+         # Determine if this is a root directory or a step directory
+         items = os.listdir(filepath)
+         has_step_subdirs = any(
+             os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
+             for item in items
+         )
+
+         if has_step_subdirs:
+             # It's a root directory, find the latest checkpoint
+             checkpoint_path = find_latest_orbax_checkpoint(filepath)
+         else:
+             # It's a step directory, use it directly
+             checkpoint_path = filepath
+
+         # Load checkpoint
+         loaded_state = ocp.load_pytree(checkpoint_path)
+
+         # Set the model state directly from the loaded state
+         model.set_state_tree(loaded_state)
      else:
          raise ValueError(
              f"File format not supported: filepath={filepath}. "
-             "Keras 3 only supports V3 `.keras` and `.weights.h5` "
-             "files, or legacy V1/V2 `.h5` files."
+             "Keras 3 only supports V3 `.keras` files, "
+             "`.weights.h5` files, legacy H5 format files "
+             "(`.h5` extension), or Orbax checkpoints."
          )
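Note: with this branch in place, `model.load_weights` accepts either an Orbax root directory (the latest step is resolved via `find_latest_orbax_checkpoint`) or a single step directory. A hedged usage sketch, with placeholder paths assumed to have been written by the Orbax checkpoint callback:

    import keras

    model = keras.Sequential([keras.Input((8,)), keras.layers.Dense(4)])
    model.load_weights("ckpts")     # root dir: latest step is picked
    model.load_weights("ckpts/42")  # step dir: loaded directly via ocp.load_pytree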