keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +3 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +3 -0
  7. keras/ops/numpy/__init__.py +3 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +16 -0
  11. keras/src/backend/numpy/numpy.py +23 -0
  12. keras/src/backend/openvino/numpy.py +369 -16
  13. keras/src/backend/tensorflow/numpy.py +34 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +36 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +109 -96
  33. keras/src/ops/numpy.py +181 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/file_editor.py +81 -6
  44. keras/src/saving/orbax_util.py +50 -0
  45. keras/src/saving/saving_api.py +37 -14
  46. keras/src/utils/jax_layer.py +69 -31
  47. keras/src/utils/module_utils.py +11 -0
  48. keras/src/utils/tracking.py +5 -5
  49. keras/src/version.py +1 -1
  50. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  51. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
  52. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  53. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/quantizers/awq_core.py (new file)
@@ -0,0 +1,217 @@
+ """AWQ core functionality for layer-wise quantization.
+
+ This module provides the orchestration logic for applying AWQ quantization
+ to transformer models in a layer-by-layer fashion.
+ """
+
+ from contextlib import contextmanager
+
+ from absl import logging
+
+ from keras.src import ops
+ from keras.src import utils as keras_utils
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+ from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+ from keras.src.quantizers.awq import AWQ
+ from keras.src.quantizers.awq_config import AWQConfig
+ from keras.src.quantizers.gptq_core import find_layers_in_block
+ from keras.src.quantizers.gptq_core import get_dataloader
+ from keras.src.quantizers.utils import should_quantize_layer
+
+
+ @contextmanager
+ def stream_activations(layers_map, awq_objects):
+     """Context manager to capture activations for AWQ calibration.
+
+     Temporarily patches layer.call methods to capture activation statistics
+     for computing per-channel scaling factors.
+
+     Args:
+         layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+         awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+     Yields:
+         None: The patched state is active only within the `with` block.
+     """
+     original_calls = {}
+
+     def create_hook(name, original_call_func):
+         def hook(*args, **kwargs):
+             inp = args[0] if args else kwargs["inputs"]
+             num_features = awq_objects[name].rows
+             input_2d = ops.reshape(inp, (-1, num_features))
+             awq_objects[name].update_activation_magnitudes(input_2d)
+             return original_call_func(*args, **kwargs)
+
+         return hook
+
+     try:
+         for name, layer in layers_map.items():
+             original_calls[name] = layer.call
+             layer.call = create_hook(name, layer.call)
+         yield
+     finally:
+         for name, layer in layers_map.items():
+             layer.call = original_calls[name]
+
+
+ def apply_awq_layerwise(dataloader, config, structure, filters=None):
+     """Apply AWQ quantization layer-by-layer to a Keras model.
+
+     This function processes the model sequentially, one block at a time:
+     1. Captures activation statistics through calibration data forward pass
+     2. Uses activation magnitudes to determine weight saliency
+     3. Finds optimal per-channel scales via grid search
+     4. Quantizes weights with AWQ scaling
+
+     Args:
+         dataloader: Calibration data as numpy array.
+         config: AWQConfig instance.
+         structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+         filters: Optional layer filters.
+     """
+     num_samples = config.num_samples
+     logging.info("Starting AWQ quantization...")
+
+     pre_layers = structure.get("pre_block_layers", [])
+     transformer_blocks = structure.get("sequential_blocks", [])
+
+     if not transformer_blocks:
+         raise ValueError(
+             "No sequential blocks found in the structure to quantize."
+         )
+
+     # Process inputs through pre-block layers (e.g., embedding)
+     inputs = []
+     for batch in dataloader:
+         batch = ops.convert_to_tensor(batch, dtype="int32")
+         for layer in pre_layers:
+             batch = layer(batch)
+         inputs.append(batch)
+
+     num_samples = min(num_samples, len(inputs))
+     progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+     for block_idx, block in enumerate(transformer_blocks):
+         logging.info(f"Quantizing Block {block_idx}")
+         sub_layers_map = find_layers_in_block(block)
+
+         # Apply filters
+         final_sub_layers_map = {}
+         for name, layer in sub_layers_map.items():
+             if not should_quantize_layer(layer, filters):
+                 continue
+             final_sub_layers_map[name] = layer
+
+         sub_layers_map = final_sub_layers_map
+
+         if not sub_layers_map:
+             logging.info(
+                 f" No quantizable layers found in block {block_idx}. Skipping."
+             )
+         else:
+             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+             # Create AWQ objects for each layer
+             awq_objects = {
+                 name: AWQ(layer, config)
+                 for name, layer in sub_layers_map.items()
+             }
+
+             # Capture activation statistics
+             with stream_activations(sub_layers_map, awq_objects):
+                 for sample_idx in range(num_samples):
+                     current_input = inputs[sample_idx]
+                     if len(current_input.shape) == 2:
+                         current_input = ops.expand_dims(current_input, axis=0)
+                     _ = block(current_input)
+
+             # Quantize each layer
+             for name, awq_object in awq_objects.items():
+                 logging.info(f"Quantizing {name}...")
+                 awq_object.quantize_layer()
+                 awq_object.free()
+
+             del awq_objects
+
+         # Generate inputs for next block
+         if block_idx < len(transformer_blocks) - 1:
+             logging.info(f"Generating inputs for block {block_idx + 1}...")
+             next_block_inputs = []
+             for sample_idx in range(num_samples):
+                 current_input = inputs[sample_idx]
+                 if len(current_input.shape) == 2:
+                     current_input = ops.expand_dims(current_input, axis=0)
+                 output = block(current_input)[0]
+                 next_block_inputs.append(output)
+             inputs = next_block_inputs
+
+         progbar.update(current=block_idx + 1)
+
+     logging.info("AWQ quantization complete.")
+
+
+ def awq_quantize(config, quantization_layer_structure, filters=None):
+     """Main entry point for AWQ quantization.
+
+     Args:
+         config: AWQConfig instance.
+         quantization_layer_structure: Model structure dictionary.
+         filters: Optional layer filters.
+     """
+     if config.dataset is None or config.tokenizer is None:
+         raise ValueError(
+             "AWQ quantization requires a dataset and tokenizer. "
+             "Please provide them in the AWQConfig."
+         )
+
+     if quantization_layer_structure is None:
+         raise ValueError(
+             "For 'awq' mode, a valid quantization structure must be provided "
+             "either via `config.quantization_layer_structure` or by overriding "
+             "`model.get_quantization_layer_structure(mode)`. The structure "
+             "should be a dictionary with keys 'pre_block_layers' and "
+             "'sequential_blocks'."
+         )
+
+     # Load calibration data
+     dataloader = get_dataloader(
+         config.tokenizer,
+         config.sequence_length,
+         config.dataset,
+         num_samples=config.num_samples,
+     )
+
+     apply_awq_layerwise(
+         dataloader[: config.num_samples],
+         config,
+         quantization_layer_structure,
+         filters=filters,
+     )
+
+
+ def get_group_size_for_layer(layer, config):
+     """Get group size from config or dtype policy.
+
+     Args:
+         layer: The layer to get group size for.
+         config: Optional AWQConfig instance.
+
+     Returns:
+         int: The group size for quantization.
+
+     Raises:
+         ValueError: If group size cannot be determined.
+     """
+     if config and isinstance(config, AWQConfig):
+         return config.group_size
+     elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+         return layer.dtype_policy.group_size
+     elif isinstance(layer.dtype_policy, DTypePolicyMap):
+         policy = layer.dtype_policy[layer.path]
+         if isinstance(policy, AWQDTypePolicy):
+             return policy.group_size
+     raise ValueError(
+         "For AWQ quantization, group_size must be specified "
+         "through AWQConfig or AWQDTypePolicy."
+     )
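The key mechanism in this new module is `stream_activations`: it temporarily swaps each layer's `call` method for a hook that records activation statistics during the calibration forward passes, then restores the originals in a `finally` block even if quantization fails. Below is a minimal, self-contained sketch of the same patching pattern; the helper name `capture_inputs` and the mean-absolute-magnitude statistic are illustrative, not part of the package:

```python
from contextlib import contextmanager

import numpy as np

import keras


@contextmanager
def capture_inputs(layer, store):
    """Temporarily wrap `layer.call` to record per-feature input magnitudes."""
    original_call = layer.call

    def hook(*args, **kwargs):
        inp = args[0] if args else kwargs["inputs"]
        flat = keras.ops.reshape(inp, (-1, inp.shape[-1]))
        store.append(keras.ops.mean(keras.ops.abs(flat), axis=0))
        return original_call(*args, **kwargs)

    try:
        layer.call = hook
        yield
    finally:
        layer.call = original_call  # always restore, as awq_core does


layer = keras.layers.Dense(4)
stats = []
with capture_inputs(layer, stats):
    layer(np.random.rand(8, 16).astype("float32"))
print(stats[0].shape)  # (16,): one magnitude per input feature
```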
keras/src/quantizers/gptq.py
@@ -1,5 +1,4 @@
  import types
- from functools import partial
 
  from keras.src import ops
  from keras.src import quantizers
@@ -466,7 +465,7 @@ class GPTQ:
              group_size=self.config.group_size,
              activation_order=self.config.activation_order,
              order_metric=ops.diagonal(hessian_matrix),
-             compute_scale_zero=partial(self.quantizer.find_params, weight=True),
+             compute_scale_zero=self.quantizer.find_params,
          )
          quantized = ops.cast(
              quantized, self.original_layer.quantized_kernel.dtype
keras/src/quantizers/gptq_core.py
@@ -131,7 +131,7 @@ def get_dataloader(
      pieces = []
      if isinstance(dataset_list[0], str):
          for i, s in enumerate(dataset_list):
-             toks = np.asarray(tokenizer.tokenize(s)).reshape(-1)
+             toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
              pieces.append(toks)
              # avoid windows that span document boundaries
              if eos_id is not None and i < len(dataset_list) - 1:
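This `get_dataloader` fix matters because `tokenizer.tokenize` may return a backend tensor rather than something NumPy can ingest directly; `ops.convert_to_numpy` first brings the value to host memory on any backend (e.g., a CUDA-backed torch tensor would make `np.asarray` raise). A small sketch, with the tensor standing in for real tokenizer output:

```python
from keras import ops

t = ops.convert_to_tensor([101, 2023, 102])  # stand-in for tokenize() output
toks = ops.convert_to_numpy(t).reshape(-1)   # backend-agnostic host copy
print(toks)  # [ 101 2023  102]
```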
keras/src/quantizers/quantization_config.py
@@ -182,6 +182,11 @@ def validate_and_resolve_config(mode, config):
              "For GPTQ, you must pass a `GPTQConfig` object in the "
              "`config` argument."
          )
+     elif mode == "awq":
+         raise ValueError(
+             "For AWQ, you must pass an `AWQConfig` object in the "
+             "`config` argument."
+         )
      else:
          if mode is not None:
              raise ValueError(
@@ -220,6 +225,15 @@ def validate_and_resolve_config(mode, config):
              f"`GPTQConfig`. Received: {type(config)}"
          )
 
+     if mode == "awq":
+         from keras.src.quantizers.awq_config import AWQConfig
+
+         if not isinstance(config, AWQConfig):
+             raise ValueError(
+                 "Mode 'awq' requires a valid `config` argument of type "
+                 f"`AWQConfig`. Received: {type(config)}"
+             )
+
      return config
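With these checks in place, the AWQ path mirrors the existing GPTQ flow: a mode string plus a matching config object. A hedged sketch of what a call site presumably looks like; the `AWQConfig` constructor fields are inferred from `awq_core.py` above, and `my_tokenizer` and `model` are placeholders:

```python
from keras.src.quantizers.awq_config import AWQConfig

# Placeholder calibration inputs; field names inferred from awq_core.py.
config = AWQConfig(
    dataset=["some calibration text", "more calibration text"],
    tokenizer=my_tokenizer,  # assumed to expose .tokenize()
    sequence_length=512,
    num_samples=16,
)

model.quantize("awq", config=config)  # passes validate_and_resolve_config
model.quantize("awq")                 # raises: AWQ requires an `AWQConfig`
```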
keras/src/quantizers/quantizers.py
@@ -653,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
      )
 
      def to_signed(x):
-         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+         Uses a branchless XOR approach: (x ^ 8) - 8
+         This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+         """
          dtype_x = backend.standardize_dtype(x.dtype)
          eight = ops.cast(8, dtype_x)
-         sixteen = ops.cast(16, dtype_x)
-         return ops.where(x < eight, x, x - sixteen)
+         return ops.subtract(ops.bitwise_xor(x, eight), eight)
 
      rank = getattr(packed.shape, "rank", None) or len(packed.shape)
      if axis < 0:
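The XOR trick replaces a data-dependent `ops.where` with two cheap elementwise ops. A quick NumPy check of the mapping documented in the new docstring:

```python
import numpy as np

nibbles = np.arange(16, dtype=np.int8)  # unpacked values in [0, 15]
signed = (nibbles ^ 8) - 8              # branchless (x ^ 8) - 8
print(signed)
# [ 0  1  2  3  4  5  6  7 -8 -7 -6 -5 -4 -3 -2 -1]
```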
@@ -748,7 +751,7 @@ class GPTQQuantizer(Quantizer):
          self.zero = None
          self.maxq = None
 
-     def find_params(self, input_tensor, weight=True):
+     def find_params(self, input_tensor):
          """Finds quantization parameters (scale and zero) for a given tensor."""
          self.scale, self.zero, self.maxq = compute_quantization_parameters(
              input_tensor,
@@ -756,7 +759,6 @@ class GPTQQuantizer(Quantizer):
              symmetric=self.symmetric,
              per_channel=self.per_channel,
              group_size=self.group_size,
-             weight=weight,
              compute_dtype=self.compute_dtype,
          )
          return self.scale, self.zero, self.maxq
@@ -793,98 +795,105 @@ def compute_quantization_parameters(
      symmetric=False,
      per_channel=False,
      group_size=-1,
-     weight=False,
      compute_dtype="float32",
  ):
      """
-     Computes the scale and zero-point for quantization.
+     Computes the scale and zero-point for quantizing weight tensors.
 
      This function calculates the scale and zero-point required for quantizing
-     a given tensor `x` based on the specified parameters. It supports grouped,
-     per-channel, per-tensor, symmetric, and asymmetric quantization - along
-     with any combinations of these.
+     a given weight tensor `x` based on the specified parameters. It supports
+     grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+     For grouped quantization (per_channel=True, group_size > 0), the output
+     shapes are [out_features, n_groups] where n_groups is the number of groups
+     along the in_features dimension.
 
      Args:
-         x: KerasTensor. The input tensor to quantize.
+         x: KerasTensor. The weight tensor to quantize with shape
+             [out_features, in_features].
          bits: int. The number of bits to quantize to (e.g., 4).
          symmetric: bool. Whether to use symmetric quantization.
          per_channel: bool. Whether to quantize per channel.
-         group_size: int. The group size for quantization.
-         weight: bool. Whether the input tensor is a weight tensor.
+         group_size: int. The group size for quantization. -1 means no grouping.
+         compute_dtype: str. The dtype for computation. Defaults to "float32".
 
      Returns:
          scale: KerasTensor. The scale tensor for quantization.
          zero: KerasTensor. The zero tensor for quantization.
          maxq: scalar. The maximum quantization value.
      """
+     # Input validation
      if x is None:
          raise ValueError(f"Input tensor {x} cannot be None.")
-
-     # For weights, we typically expect at least a 2D tensor.
-     if weight and len(x.shape) < 2:
+     if len(x.shape) < 2:
          raise ValueError(
              f"Input weight tensor {x} must have a rank of at "
              f"least 2, but got rank {len(x.shape)}."
          )
-
      if ops.size(x) == 0:
          raise ValueError("Input tensor 'x' cannot be empty.")
 
-     original_shape = x.shape
-
-     if per_channel:
-         if weight:
-             if group_size != -1:
-                 input_reshaped = ops.reshape(x, [-1, group_size])
-             else:
-                 input_reshaped = ops.reshape(x, [original_shape[0], -1])
-     else:  # per-tensor
-         input_reshaped = ops.reshape(x, [1, -1])
+     out_features, in_features = x.shape[0], x.shape[1]
 
-     # Find min/max values
-     min_values = ops.min(input_reshaped, axis=1)
-     max_values = ops.max(input_reshaped, axis=1)
+     # Determine number of groups for quantization
+     if per_channel and group_size > 0:
+         n_groups = (in_features + group_size - 1) // group_size
+     else:
+         n_groups = 1
+
+     # Compute min/max values based on quantization mode
+     if n_groups > 1:
+         # Grouped quantization: output shape [out_features, n_groups]
+         remainder = in_features % group_size
+         if remainder != 0:
+             pad_size = group_size - remainder
+             x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+         x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+         min_values = ops.min(x_grouped, axis=2)
+         max_values = ops.max(x_grouped, axis=2)
+     else:
+         # Per-channel or per-tensor: compute stats along rows
+         reduction_shape = [out_features, -1] if per_channel else [1, -1]
+         x_reshaped = ops.reshape(x, reduction_shape)
+         min_values = ops.min(x_reshaped, axis=1)
+         max_values = ops.max(x_reshaped, axis=1)
 
-     # Apply symmetric quantization logic if enabled
+     # Symmetric quantization: make range symmetric around zero
      if symmetric:
-         max_values = ops.maximum(ops.abs(min_values), max_values)
+         max_abs = ops.maximum(ops.abs(min_values), max_values)
          min_values = ops.where(
-             ops.less(min_values, 0), ops.negative(max_values), min_values
+             ops.less(min_values, 0), ops.negative(max_abs), min_values
          )
+         max_values = max_abs
 
-     # Ensure range is not zero to avoid division errors
+     # Ensure non-zero range to avoid division errors
      zero_range = ops.equal(min_values, max_values)
      min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
      max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
 
+     # Compute scale and zero-point
      maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-     # Calculate scale and zero-point
      scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
      if symmetric:
          zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
      else:
          zero = ops.round(ops.divide(ops.negative(min_values), scale))
 
-     # Ensure scale is non-zero
-     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
-
-     if weight:
-         # Per-channel, non-grouped case: simple reshape is correct.
-         if per_channel and group_size == -1:
-             scale = ops.reshape(scale, [-1, 1])
-             zero = ops.reshape(zero, [-1, 1])
-         elif not per_channel:
-             num_rows = original_shape[0]
-             scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-             zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-     if per_channel:
+     # Reshape output to [out_features, n_groups] or [out_features, 1]
+     if n_groups > 1:
+         pass  # Already [out_features, n_groups]
+     elif per_channel:
          scale = ops.reshape(scale, [-1, 1])
          zero = ops.reshape(zero, [-1, 1])
+     else:
+         # Per-tensor: tile single value to [out_features, 1]
+         scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+         zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
 
-     zero = ops.cast(zero, "uint8")
-
-     return scale, zero, maxq
+     return scale, ops.cast(zero, "uint8"), maxq
 
 
  def quantize_with_zero_point(input_tensor, scale, zero, maxq):
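The new shape contract is easiest to see with a concrete example: for a `[out_features=2, in_features=10]` weight and `group_size=4`, `n_groups = ceil(10 / 4) = 3`, so the tensor is zero-padded by two columns and reshaped to `[2, 3, 4]` before the min/max reduction. A minimal sketch:

```python
import numpy as np

from keras.src.quantizers.quantizers import compute_quantization_parameters

w = np.random.randn(2, 10).astype("float32")  # [out_features, in_features]
scale, zero, maxq = compute_quantization_parameters(
    w, bits=4, per_channel=True, group_size=4
)
print(scale.shape, zero.shape)  # (2, 3) (2, 3): one (scale, zero) per group
print(maxq)                     # 15.0, i.e. 2**4 - 1
```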
keras/src/random/seed_generator.py
@@ -29,7 +29,7 @@ class SeedGenerator:
      a local `StateGenerator` with either a deterministic or random initial
      state.
 
-     Remark concerning the JAX backen: Note that the use of a local
+     Remark concerning the JAX backend: Note that the use of a local
      `StateGenerator` as seed argument is required for JIT compilation of
      RNG with the JAX backend, because the use of global state is not
      supported.
@@ -111,7 +111,7 @@ class SeedGenerator:
          return new_seed_value
 
      def get_config(self):
-         return {"seed": self._initial_seed}
+         return {"seed": self._initial_seed, "name": self.name}
 
      @classmethod
      def from_config(cls, config):
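Serializing the `name` means a `SeedGenerator` now survives a config round trip without being silently renamed. Assuming the public `keras.random.SeedGenerator` constructor accepts the same fields, the round trip looks like:

```python
import keras

gen = keras.random.SeedGenerator(seed=42, name="dropout_seed")
config = gen.get_config()  # {"seed": 42, "name": "dropout_seed"}
clone = keras.random.SeedGenerator.from_config(config)
print(clone.name)          # "dropout_seed", no longer auto-generated
```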
keras/src/saving/file_editor.py
@@ -455,6 +455,9 @@ class KerasFileEditor:
      def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
          metadata = metadata or {}
 
+         # ------------------------------------------------------
+         # Collect metadata for this HDF5 group
+         # ------------------------------------------------------
          object_metadata = {}
          for k, v in data.attrs.items():
              object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
          metadata[inner_path] = object_metadata
 
          result = collections.OrderedDict()
+
+         # ------------------------------------------------------
+         # Iterate over all keys in this HDF5 group
+         # ------------------------------------------------------
          for key in data.keys():
-             inner_path = f"{inner_path}/{key}"
+             # IMPORTANT:
+             # Never mutate inner_path; use local variable.
+             current_inner_path = f"{inner_path}/{key}"
              value = data[key]
+
+             # ------------------------------------------------------
+             # CASE 1 — HDF5 GROUP → RECURSE
+             # ------------------------------------------------------
              if isinstance(value, h5py.Group):
+                 # Skip empty groups
                  if len(value) == 0:
                      continue
+
+                 # Skip empty "vars" groups
                  if "vars" in value.keys() and len(value["vars"]) == 0:
                      continue
 
-             if hasattr(value, "keys"):
+                 # Recurse into "vars" subgroup when present
                  if "vars" in value.keys():
                      result[key], metadata = self._extract_weights_from_store(
-                         value["vars"], metadata=metadata, inner_path=inner_path
+                         value["vars"],
+                         metadata=metadata,
+                         inner_path=current_inner_path,
                      )
                  else:
+                     # Recurse normally
                      result[key], metadata = self._extract_weights_from_store(
-                         value, metadata=metadata, inner_path=inner_path
+                         value,
+                         metadata=metadata,
+                         inner_path=current_inner_path,
                      )
-             else:
-                 result[key] = value[()]
+
+                 continue  # finished processing this key
+
+             # ------------------------------------------------------
+             # CASE 2 — HDF5 DATASET → SAFE LOADING
+             # ------------------------------------------------------
+
+             # Skip any objects that are not proper datasets
+             if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                 continue
+
+             shape = value.shape
+             dtype = value.dtype
+
+             # ------------------------------------------------------
+             # Validate SHAPE (avoid malformed / malicious metadata)
+             # ------------------------------------------------------
+
+             # No negative dimensions
+             if any(dim < 0 for dim in shape):
+                 raise ValueError(
+                     "Malformed HDF5 dataset shape encountered in .keras file; "
+                     "negative dimension detected."
+                 )
+
+             # Prevent absurdly high-rank tensors
+             if len(shape) > 64:
+                 raise ValueError(
+                     "Malformed HDF5 dataset shape encountered in .keras file; "
+                     "tensor rank exceeds safety limit."
+                 )
+
+             # Safe product computation (Python int is unbounded)
+             num_elems = int(np.prod(shape))
+
+             # ------------------------------------------------------
+             # Validate TOTAL memory size
+             # ------------------------------------------------------
+             MAX_BYTES = 1 << 32  # 4 GiB
+
+             size_bytes = num_elems * dtype.itemsize
+
+             if size_bytes > MAX_BYTES:
+                 raise ValueError(
+                     f"HDF5 dataset too large to load safely "
+                     f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                 )
+
+             # ------------------------------------------------------
+             # SAFE — load dataset (guaranteed ≤ 4 GiB)
+             # ------------------------------------------------------
+             result[key] = value[()]
+
+         # ------------------------------------------------------
+         # Return final tree and metadata
+         # ------------------------------------------------------
          return result, metadata
 
      def _generate_filepath_info(self, rich_style=False):
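The same guardrails are easy to reuse outside the editor: validate a dataset's declared shape and byte size before materializing it. A standalone sketch against an arbitrary h5py dataset; the file name and key below are hypothetical:

```python
import h5py
import numpy as np

MAX_BYTES = 1 << 32  # 4 GiB, matching the editor's limit


def load_dataset_safely(dataset):
    """Reject malformed or oversized datasets before reading them."""
    shape = dataset.shape
    if any(dim < 0 for dim in shape) or len(shape) > 64:
        raise ValueError("Malformed HDF5 dataset shape.")
    if int(np.prod(shape)) * dataset.dtype.itemsize > MAX_BYTES:
        raise ValueError("HDF5 dataset too large to load safely.")
    return dataset[()]


with h5py.File("model.weights.h5", "r") as f:  # hypothetical file
    kernel = load_dataset_safely(f["dense/vars/0"])  # hypothetical key
```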
keras/src/saving/orbax_util.py (new file)
@@ -0,0 +1,50 @@
+ """Orbax checkpoint loading functionality."""
+
+ import os
+
+ from keras.src.utils.module_utils import ocp
+
+
+ def is_orbax_checkpoint(filepath):
+     """Check if the given path is an Orbax checkpoint directory.
+
+     This function implements custom detection logic instead of relying on
+     Orbax APIs which may be unreliable in some environments.
+     """
+     if not os.path.exists(filepath) or not os.path.isdir(filepath):
+         return False
+
+     try:
+         # List directory contents
+         contents = os.listdir(filepath)
+
+         # A set is more efficient for membership testing
+         orbax_indicators = {
+             "orbax.checkpoint",
+             "pytree.orbax-checkpoint",
+             "checkpoint_metadata",
+         }
+
+         # Fast check for standard files
+         if not orbax_indicators.isdisjoint(contents):
+             return True
+
+         # Check for step directories or temporary files in a single pass
+         return any(
+             ".orbax-checkpoint-tmp" in item
+             or (item.isdigit() and os.path.isdir(os.path.join(filepath, item)))
+             for item in contents
+         )
+
+     except (OSError, PermissionError):
+         # If we can't read the directory, assume it's not a checkpoint
+         return False
+
+
+ def find_latest_orbax_checkpoint(checkpoint_dir):
+     """Find the latest checkpoint in an Orbax checkpoint directory."""
+     checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+     latest = checkpointer.latest
+     if latest is None:
+         raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+     return os.path.join(checkpoint_dir, str(latest.step))
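Together, the two helpers let the saving code treat a directory as an Orbax checkpoint and resolve its newest step. A short usage sketch; the directory path is hypothetical:

```python
from keras.src.saving.orbax_util import (
    find_latest_orbax_checkpoint,
    is_orbax_checkpoint,
)

path = "/tmp/orbax_ckpts"  # hypothetical checkpoint directory
if is_orbax_checkpoint(path):
    latest = find_latest_orbax_checkpoint(path)
    print(f"Restoring from {latest}")  # e.g. /tmp/orbax_ckpts/42
```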