keras-nightly 3.14.0.dev2026012704__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +1 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +3 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +1 -0
  7. keras/ops/numpy/__init__.py +1 -0
  8. keras/quantizers/__init__.py +3 -0
  9. keras/src/backend/jax/core.py +12 -2
  10. keras/src/backend/jax/numpy.py +5 -0
  11. keras/src/backend/numpy/numpy.py +5 -0
  12. keras/src/backend/openvino/numpy.py +6 -0
  13. keras/src/backend/tensorflow/numpy.py +21 -0
  14. keras/src/backend/torch/numpy.py +10 -0
  15. keras/src/callbacks/orbax_checkpoint.py +41 -8
  16. keras/src/dtype_policies/__init__.py +2 -0
  17. keras/src/dtype_policies/dtype_policy.py +80 -1
  18. keras/src/layers/core/dense.py +278 -95
  19. keras/src/layers/core/einsum_dense.py +350 -181
  20. keras/src/layers/core/embedding.py +236 -49
  21. keras/src/layers/core/reversible_embedding.py +177 -35
  22. keras/src/layers/preprocessing/discretization.py +30 -1
  23. keras/src/ops/numpy.py +54 -0
  24. keras/src/quantizers/__init__.py +6 -0
  25. keras/src/quantizers/quantization_config.py +98 -4
  26. keras/src/quantizers/quantizers.py +262 -32
  27. keras/src/saving/file_editor.py +7 -1
  28. keras/src/saving/saving_api.py +66 -2
  29. keras/src/saving/saving_lib.py +46 -47
  30. keras/src/version.py +1 -1
  31. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
  32. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +34 -34
  33. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
  34. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0
keras/src/quantizers/quantizers.py CHANGED
@@ -1,3 +1,5 @@
1
+ import math
2
+
1
3
  import ml_dtypes
2
4
  import numpy as np
3
5
 
@@ -118,6 +120,190 @@ def abs_max_quantize(
118
120
  return outputs, scale
119
121
 
120
122
 
123
+ @keras_export("keras.quantizers.abs_max_quantize_grouped_with_zero_point")
124
+ def abs_max_quantize_grouped_with_zero_point(
125
+ inputs,
126
+ block_size,
127
+ value_range=(-8, 7),
128
+ dtype="int8",
129
+ epsilon=backend.epsilon(),
130
+ to_numpy=False,
131
+ ):
132
+ """Quantizes a 2D tensor using grouped asymmetric quantization with
133
+ zero point.
134
+
135
+ Groups are formed along axis 0 (the input/contracting dimension).
136
+ Each group of `block_size` rows gets its own scale factor and zero point
137
+ per column. This is useful for weight distributions that are not centered
138
+ around zero.
139
+
140
+ Args:
141
+ inputs: Input tensor to quantize. Shape: `(input_dim, output_dim)`.
142
+ block_size: Number of elements per group along axis 0.
143
+ value_range: Tuple of `(min, max)` quantization range.
144
+ dtype: Data type of quantized output.
145
+ epsilon: Small value to avoid division by zero.
146
+ to_numpy: Whether to perform computation in numpy for memory
147
+ efficiency.
148
+
149
+ Returns:
150
+ A tuple `(quantized_tensor, scale, zero_point)` where:
151
+ - `quantized_tensor`: Same shape as inputs, dtype=`dtype`.
152
+ - `scale`: Shape `(n_groups, output_dim)` where
153
+ `n_groups = ceil(input_dim / block_size)`.
154
+ - `zero_point`: Shape `(n_groups, output_dim)`, dtype=`uint8`.
155
+
156
+ Example:
157
+
158
+ ```python
159
+ >>> import numpy as np
160
+ >>> from keras.quantizers import abs_max_quantize_grouped_with_zero_point
161
+ >>> kernel = np.random.randn(512, 256).astype("float32")
162
+ >>> quantized, scale, zero_point = abs_max_quantize_grouped_with_zero_point(
163
+ ... kernel, block_size=128, value_range=(-8, 7)
164
+ ... )
165
+ >>> quantized.shape
166
+ (512, 256)
167
+ >>> scale.shape # 512 / 128 = 4 groups
168
+ (4, 256)
169
+ >>> zero_point.shape
170
+ (4, 256)
171
+ ```
172
+ """
173
+ if to_numpy:
174
+ return _abs_max_quantize_grouped_with_zero_point_numpy(
175
+ inputs, block_size, value_range, dtype, epsilon
176
+ )
177
+ return _abs_max_quantize_grouped_with_zero_point_tensor(
178
+ inputs, block_size, value_range, dtype, epsilon
179
+ )
180
+
181
+
182
+ def _abs_max_quantize_grouped_with_zero_point_numpy(
183
+ inputs, block_size, value_range, dtype, epsilon
184
+ ):
185
+ """NumPy implementation of grouped asymmetric quantization.
186
+
187
+ Uses NumPy for computation to reduce GPU memory usage during
188
+ model quantization.
189
+ """
190
+ original_dtype = backend.standardize_dtype(inputs.dtype)
191
+ inputs = ops.convert_to_numpy(inputs)
192
+
193
+ input_dim, output_dim = inputs.shape
194
+ n_groups = math.ceil(input_dim / block_size)
195
+ qmin, qmax = value_range
196
+
197
+ # Zero-pad rows so input_dim is divisible by block_size
198
+ padded_input_dim = n_groups * block_size
199
+ if padded_input_dim > input_dim:
200
+ padding = np.zeros(
201
+ (padded_input_dim - input_dim, output_dim), dtype=inputs.dtype
202
+ )
203
+ inputs_padded = np.concatenate([inputs, padding], axis=0)
204
+ else:
205
+ inputs_padded = inputs
206
+
207
+ inputs_reshaped = inputs_padded.reshape(n_groups, block_size, output_dim)
208
+
209
+ # Compute per-group min/max for asymmetric quantization
210
+ min_val = np.min(inputs_reshaped, axis=1, keepdims=True)
211
+ max_val = np.max(inputs_reshaped, axis=1, keepdims=True)
212
+
213
+ # Scale maps the [min, max] range to [qmin, qmax]
214
+ scale = np.divide(np.subtract(max_val, min_val) + epsilon, qmax - qmin)
215
+
216
+ # Zero point shifts the quantized range to include the original zero
217
+ zero_point = np.round(np.divide(-min_val, scale)) + qmin
218
+ zero_point = np.clip(zero_point, qmin, qmax)
219
+
220
+ # Quantize: q = round(input / scale) + zero_point
221
+ outputs = np.round(np.divide(inputs_reshaped, scale)) + zero_point
222
+ outputs = np.clip(outputs, qmin, qmax)
223
+ outputs = outputs.astype(dtype)
224
+
225
+ # Remove padding and squeeze to (n_groups, output_dim)
226
+ outputs = outputs.reshape(padded_input_dim, output_dim)[:input_dim, :]
227
+ scale = np.squeeze(scale, axis=1)
228
+ zero_point = np.squeeze(zero_point, axis=1).astype("int8")
229
+
230
+ return (
231
+ ops.convert_to_tensor(outputs),
232
+ ops.convert_to_tensor(scale, dtype=original_dtype),
233
+ ops.convert_to_tensor(zero_point),
234
+ )
235
+
236
+
237
+ def _abs_max_quantize_grouped_with_zero_point_tensor(
238
+ inputs, block_size, value_range, dtype, epsilon
239
+ ):
240
+ """Tensor backend implementation of grouped asymmetric quantization."""
241
+ original_dtype = backend.standardize_dtype(inputs.dtype)
242
+ inputs = ops.convert_to_tensor(inputs)
243
+
244
+ input_shape = ops.shape(inputs)
245
+ input_dim = input_shape[0]
246
+ output_dim = input_shape[1]
247
+ qmin, qmax = value_range
248
+
249
+ # Infer bit-width from quantization range (e.g., [-8, 7] -> 4 bits)
250
+ num_levels = qmax - qmin + 1
251
+ bits = int(math.log2(num_levels))
252
+
253
+ n_groups = int(math.ceil(int(input_dim) / block_size))
254
+ padded_input_dim = n_groups * block_size
255
+
256
+ # Transpose to [out_features, in_features] for
257
+ # compute_quantization_parameters
258
+ inputs_t = ops.transpose(inputs)
259
+
260
+ # Compute scale and zero point using the unified quantization function
261
+ scale_t, zero_point_t, _ = compute_quantization_parameters(
262
+ inputs_t,
263
+ bits=bits,
264
+ symmetric=False,
265
+ per_channel=True,
266
+ group_size=block_size,
267
+ compute_dtype=original_dtype,
268
+ epsilon=epsilon,
269
+ signed=True,
270
+ )
271
+
272
+ # Transpose results back to (n_groups, output_dim)
273
+ scale = ops.transpose(scale_t)
274
+ zero_point = ops.transpose(zero_point_t)
275
+
276
+ # Zero-pad rows so input_dim is divisible by block_size
277
+ pad_size = padded_input_dim - int(input_dim)
278
+ if pad_size > 0:
279
+ padding = ops.zeros((pad_size, output_dim), dtype=inputs.dtype)
280
+ inputs_padded = ops.concatenate([inputs, padding], axis=0)
281
+ else:
282
+ inputs_padded = inputs
283
+
284
+ inputs_reshaped = ops.reshape(
285
+ inputs_padded, (n_groups, block_size, output_dim)
286
+ )
287
+
288
+ # Expand scale and zero_point for broadcasting across block_size
289
+ scale_expanded = ops.expand_dims(scale, axis=1)
290
+ zero_point_expanded = ops.expand_dims(zero_point, axis=1)
291
+
292
+ # Quantize: q = round(input / scale) + zero_point
293
+ outputs = ops.add(
294
+ ops.round(ops.divide(inputs_reshaped, scale_expanded)),
295
+ zero_point_expanded,
296
+ )
297
+ outputs = ops.clip(outputs, qmin, qmax)
298
+ outputs = ops.cast(outputs, dtype)
299
+
300
+ # Remove padding
301
+ outputs = ops.reshape(outputs, (padded_input_dim, output_dim))
302
+ outputs = outputs[:input_dim, :]
303
+
304
+ return outputs, scale, zero_point
305
+
306
+
121
307
  @keras_export("keras.quantizers.AbsMaxQuantizer")
122
308
  class AbsMaxQuantizer(Quantizer):
123
309
  def __init__(
@@ -796,6 +982,8 @@ def compute_quantization_parameters(
796
982
  per_channel=False,
797
983
  group_size=-1,
798
984
  compute_dtype="float32",
985
+ epsilon=0.0,
986
+ signed=False,
799
987
  ):
800
988
  """
801
989
  Computes the scale and zero-point for quantizing weight tensors.
@@ -816,10 +1004,17 @@ def compute_quantization_parameters(
816
1004
  per_channel: bool. Whether to quantize per channel.
817
1005
  group_size: int. The group size for quantization. -1 means no grouping.
818
1006
  compute_dtype: str. The dtype for computation. Defaults to "float32".
1007
+ epsilon: float. Small value added to (max - min) before computing
1008
+ scale to avoid division by zero. Defaults to 0.0.
1009
+ signed: bool. Whether to use signed quantization range. If True, uses
1010
+ range [-2^(bits-1), 2^(bits-1)-1] (e.g., [-8, 7] for 4-bit).
1011
+ If False, uses range [0, 2^bits-1] (e.g., [0, 15] for 4-bit).
1012
+ Defaults to False.
819
1013
 
820
1014
  Returns:
821
1015
  scale: KerasTensor. The scale tensor for quantization.
822
- zero: KerasTensor. The zero tensor for quantization.
1016
+ zero: KerasTensor. The zero tensor for quantization (int8 if signed,
1017
+ uint8 if unsigned).
823
1018
  maxq: scalar. The maximum quantization value.
824
1019
  """
825
1020
  # Input validation
@@ -874,13 +1069,31 @@ def compute_quantization_parameters(
874
1069
 
875
1070
  # Compute scale and zero-point
876
1071
  maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
877
- scale = ops.divide(ops.subtract(max_values, min_values), maxq)
1072
+ range_values = ops.subtract(max_values, min_values)
1073
+ if epsilon > 0:
1074
+ range_values = ops.add(range_values, epsilon)
1075
+ scale = ops.divide(range_values, maxq)
878
1076
  scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
879
1077
 
880
- if symmetric:
881
- zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
1078
+ # Compute zero-point based on signed/unsigned mode
1079
+ if signed:
1080
+ # For signed range [-2^(bits-1), 2^(bits-1)-1], e.g., [-8, 7] for 4-bit
1081
+ qmin = -(2 ** (bits - 1)) # e.g., -8 for 4-bit
1082
+ qmax_signed = 2 ** (bits - 1) - 1 # e.g., 7 for 4-bit
1083
+ if symmetric:
1084
+ zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2) + qmin)
1085
+ else:
1086
+ # zero_signed = round(-min / scale) + qmin
1087
+ zero = ops.add(
1088
+ ops.round(ops.divide(ops.negative(min_values), scale)), qmin
1089
+ )
1090
+ zero = ops.clip(zero, qmin, qmax_signed)
882
1091
  else:
883
- zero = ops.round(ops.divide(ops.negative(min_values), scale))
1092
+ # For unsigned range [0, 2^bits-1], e.g., [0, 15] for 4-bit
1093
+ if symmetric:
1094
+ zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
1095
+ else:
1096
+ zero = ops.round(ops.divide(ops.negative(min_values), scale))
884
1097
 
885
1098
  # Reshape output to [out_features, n_groups] or [out_features, 1]
886
1099
  if n_groups > 1:
@@ -893,7 +1106,8 @@ def compute_quantization_parameters(
893
1106
  scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
894
1107
  zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
895
1108
 
896
- return scale, ops.cast(zero, "uint8"), maxq
1109
+ zero_dtype = "int8" if signed else "uint8"
1110
+ return scale, ops.cast(zero, zero_dtype), maxq
897
1111
 
898
1112
 
899
1113
  def quantize_with_zero_point(input_tensor, scale, zero, maxq):
@@ -942,51 +1156,67 @@ def dequantize_with_zero_point(input_tensor, scale, zero):
942
1156
  )
943
1157
 
944
1158
 
945
- def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq):
1159
+ def quantize_with_sz_map(
1160
+ weights_matrix, scale, zero, g_idx, maxq, group_axis=-1
1161
+ ):
946
1162
  """Quantize the weight matrix from group params.
947
1163
 
948
1164
  This function uses the provided scale and zero tensors to quantize the
949
- input weights_matrix according to the group indices. It maps each column
950
- of the weights_matrix to its corresponding group parameters and performs
951
- the quantization operation.
1165
+ input weights_matrix according to the group indices. It maps each position
1166
+ along group_axis of the weights_matrix to its corresponding group
1167
+ parameters and performs the quantization operation.
952
1168
 
953
1169
  Args:
954
- weights_matrix: 2D tensor of shape [out_features, in_features].
955
- scale: Per-group scale tensor of shape [out_features, n_groups].
956
- zero: Per-group zero-point tensor of shape [out_features, n_groups].
957
- g_idx: Integer tensor of shape [in_features,] mapping each column to
958
- its group index.
1170
+ weights_matrix: Tensor to quantize.
1171
+ scale: Per-group scale tensor with n_groups along group_axis.
1172
+ zero: Per-group zero-point tensor with n_groups along group_axis.
1173
+ g_idx: 1D integer tensor of length equal to the size of
1174
+ `weights_matrix` along the dimension being quantized. Each
1175
+ element specifies which group index (0 to n_groups-1) that
1176
+ position belongs to. For example, with 128 columns and
1177
+ group_size=32, g_idx would be
1178
+ `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).
959
1179
  maxq: Scalar (float) representing the maximum integer quantization
960
1180
  level (e.g., 2^bits - 1).
1181
+ group_axis: The axis in `scale` and `zero` along which to index
1182
+ using `g_idx`. This determines which dimension of the
1183
+ scale/zero tensors contains the per-group values. Default: -1
1184
+ (last axis).
961
1185
 
962
1186
  Returns:
963
1187
  A tensor with the same shape as `weights_matrix` containing the
964
1188
  quantized weights produced using the provided group parameters.
965
1189
  """
966
1190
  groups = ops.cast(g_idx, "int32")
967
- scale_cols = ops.take(scale, groups, axis=1) # [out_features, in_features]
968
- zero_cols = ops.take(zero, groups, axis=1) # [out_features, in_features]
1191
+ scale_cols = ops.take(scale, groups, axis=group_axis)
1192
+ zero_cols = ops.take(zero, groups, axis=group_axis)
969
1193
 
970
1194
  # Quantize elementwise, then cast to int
971
1195
  return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)
972
1196
 
973
1197
 
974
- def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
1198
+ def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx, group_axis=-1):
975
1199
  """Rebuild a dequantized weight matrix from group params.
976
1200
 
977
1201
  This function uses the provided scale and zero tensors to dequantize the
978
- input weights_matrix according to the group indices. It maps each column
979
- of the weights_matrix to its corresponding group parameters and performs
980
- the dequantization operation.
1202
+ input weights_matrix according to the group indices. It maps each position
1203
+ along group_axis of the weights_matrix to its corresponding group
1204
+ parameters and performs the dequantization operation.
981
1205
 
982
1206
  Args:
983
- weights_matrix: 2D tensor of shape [out_features, in_features].
984
- scale: Per-group scale tensor of shape [out_features, n_groups].
985
- zero: Per-group zero-point tensor of shape [out_features, n_groups].
986
- g_idx: Integer tensor of shape [in_features,] mapping each column to
987
- its group index.
988
- maxq: Scalar (float) representing the maximum integer quantization
989
- level (e.g., 2^bits - 1).
1207
+ weights_matrix: Tensor to dequantize.
1208
+ scale: Per-group scale tensor with n_groups along group_axis.
1209
+ zero: Per-group zero-point tensor with n_groups along group_axis.
1210
+ g_idx: 1D integer tensor of length equal to the size of
1211
+ `weights_matrix` along the dimension being dequantized. Each
1212
+ element specifies which group index (0 to n_groups-1) that
1213
+ position belongs to. For example, with 128 columns and
1214
+ group_size=32, g_idx would be
1215
+ `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).
1216
+ group_axis: The axis in `scale` and `zero` along which to index
1217
+ using `g_idx`. This determines which dimension of the
1218
+ scale/zero tensors contains the per-group values. Default: -1
1219
+ (last axis).
990
1220
 
991
1221
  Returns:
992
1222
  A tensor with the same shape as `weights_matrix` containing the
@@ -994,12 +1224,12 @@ def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
994
1224
  """
995
1225
  # Map group indices to scales and zeros
996
1226
  groups = ops.cast(g_idx, "int32")
997
- scales_mapped = ops.take(scale, groups, axis=1)
998
- zeros_mapped = ops.take(zero, groups, axis=1)
1227
+ scales_mapped = ops.take(scale, groups, axis=group_axis)
1228
+ zeros_mapped = ops.take(zero, groups, axis=group_axis)
999
1229
  zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)
1000
1230
 
1001
- quantized = ops.multiply(
1231
+ dequantized = ops.multiply(
1002
1232
  ops.subtract(weights_matrix, zeros_mapped), scales_mapped
1003
1233
  )
1004
1234
 
1005
- return quantized
1235
+ return dequantized
keras/src/saving/file_editor.py CHANGED
@@ -509,9 +509,15 @@ class KerasFileEditor:
509
509
  # ------------------------------------------------------
510
510
 
511
511
  # Skip any objects that are not proper datasets
512
- if not hasattr(value, "shape") or not hasattr(value, "dtype"):
512
+ if not isinstance(value, h5py.Dataset):
513
513
  continue
514
514
 
515
+ if value.external:
516
+ raise ValueError(
517
+ "Not allowed: H5 file Dataset with external links: "
518
+ f"{value.external}"
519
+ )
520
+
515
521
  shape = value.shape
516
522
  dtype = value.dtype
517
523
 
keras/src/saving/saving_api.py CHANGED
@@ -121,10 +121,11 @@ def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
121
121
 
122
122
  @keras_export(["keras.saving.load_model", "keras.models.load_model"])
123
123
  def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
124
- """Loads a model saved via `model.save()`.
124
+ """Loads a model saved via `model.save()` or from an Orbax checkpoint.
125
125
 
126
126
  Args:
127
- filepath: `str` or `pathlib.Path` object, path to the saved model file.
127
+ filepath: `str` or `pathlib.Path` object, path to the saved model file
128
+ or Orbax checkpoint directory.
128
129
  custom_objects: Optional dictionary mapping names
129
130
  (strings) to custom classes or functions to be
130
131
  considered during deserialization.
@@ -195,6 +196,16 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
195
196
  compile=compile,
196
197
  safe_mode=safe_mode,
197
198
  )
199
+
200
+ # Check for Orbax checkpoint directory using utility function
201
+ if is_orbax_checkpoint(filepath):
202
+ return _load_model_from_orbax_checkpoint(
203
+ filepath,
204
+ custom_objects=custom_objects,
205
+ compile=compile,
206
+ safe_mode=safe_mode,
207
+ )
208
+
198
209
  elif str(filepath).endswith(".keras"):
199
210
  raise ValueError(
200
211
  f"File not found: filepath={filepath}. "
@@ -337,3 +348,56 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
337
348
  "`.weights.h5` files, legacy H5 format files "
338
349
  "(`.h5` extension), or Orbax checkpoints."
339
350
  )
351
+
352
+
353
+ def _load_model_from_orbax_checkpoint(
354
+ filepath, custom_objects=None, compile=True, safe_mode=True
355
+ ):
356
+ """Load a model from an Orbax checkpoint directory."""
357
+
358
+ from keras.src.utils.module_utils import ocp
359
+
360
+ # Ensure orbax is available
361
+ ocp.initialize()
362
+
363
+ # Find the latest checkpoint step using the utility function
364
+ checkpoint_path = find_latest_orbax_checkpoint(filepath)
365
+ step = int(os.path.basename(checkpoint_path))
366
+
367
+ # Load the composite state efficiently
368
+ checkpointer = ocp.training.Checkpointer(directory=filepath)
369
+ with ocp.Context():
370
+ composite_state = checkpointer.load_pytree(step)
371
+
372
+ # Validate and extract model config
373
+ if "model_config" not in composite_state:
374
+ raise ValueError(
375
+ "Checkpoint does not contain model configuration. "
376
+ "This checkpoint may have been saved with save_weights_only=True."
377
+ )
378
+
379
+ # Create and build model from config using saving_lib helper
380
+ # This properly handles shared objects and compile_config
381
+ model = saving_lib._model_from_config(
382
+ composite_state["model_config"],
383
+ custom_objects=custom_objects,
384
+ compile=compile,
385
+ safe_mode=safe_mode,
386
+ )
387
+
388
+ # Prepare state tree with only variable keys for set_state_tree
389
+ variable_keys = [
390
+ "trainable_variables",
391
+ "non_trainable_variables",
392
+ "optimizer_variables",
393
+ "metrics_variables",
394
+ ]
395
+ state_tree = {
396
+ key: composite_state[key]
397
+ for key in variable_keys
398
+ if key in composite_state
399
+ }
400
+
401
+ # Apply the loaded state to the model
402
+ model.set_state_tree(state_tree)
403
+ return model
keras/src/saving/saving_lib.py CHANGED
@@ -796,7 +796,8 @@ def _load_state(
796
796
  try:
797
797
  saveable.load_own_variables(weights_store.get(inner_path))
798
798
  except Exception as e:
799
- failed_saveables.add(id(saveable))
799
+ if failed_saveables is not None:
800
+ failed_saveables.add(id(saveable))
800
801
  error_msgs[id(saveable)] = saveable, e
801
802
  failure = True
802
803
  else:
@@ -807,7 +808,8 @@ def _load_state(
807
808
  try:
808
809
  saveable.load_assets(assets_store.get(inner_path))
809
810
  except Exception as e:
810
- failed_saveables.add(id(saveable))
811
+ if failed_saveables is not None:
812
+ failed_saveables.add(id(saveable))
811
813
  error_msgs[id(saveable)] = saveable, e
812
814
  failure = True
813
815
  else:
@@ -855,7 +857,7 @@ def _load_state(
855
857
  if not failure:
856
858
  if visited_saveables is not None and newly_failed <= 0:
857
859
  visited_saveables.add(id(saveable))
858
- if id(saveable) in failed_saveables:
860
+ if failed_saveables is not None and id(saveable) in failed_saveables:
859
861
  failed_saveables.remove(id(saveable))
860
862
  error_msgs.pop(id(saveable))
861
863
 
@@ -1035,6 +1037,25 @@ class H5IOStore:
1035
1037
  # will mistakenly using `__len__` to determine the value.
1036
1038
  return self.h5_file.__bool__()
1037
1039
 
1040
+ def _verify_group(self, group):
1041
+ if not isinstance(group, h5py.Group):
1042
+ raise ValueError(
1043
+ f"Invalid H5 file, expected Group but received {type(group)}"
1044
+ )
1045
+ return group
1046
+
1047
+ def _verify_dataset(self, dataset):
1048
+ if not isinstance(dataset, h5py.Dataset):
1049
+ raise ValueError(
1050
+ f"Invalid H5 file, expected Dataset, received {type(dataset)}"
1051
+ )
1052
+ if dataset.external:
1053
+ raise ValueError(
1054
+ "Not allowed: H5 file Dataset with external links: "
1055
+ f"{dataset.external}"
1056
+ )
1057
+ return dataset
1058
+
1038
1059
  def _get_h5_file(self, path_or_io, mode=None):
1039
1060
  mode = mode or self.mode
1040
1061
  if mode not in ("r", "w", "a"):
@@ -1094,15 +1115,19 @@ class H5IOStore:
1094
1115
  self._h5_entry_group = {} # Defaults to an empty dict if not found.
1095
1116
  if not path:
1096
1117
  if "vars" in self.h5_file:
1097
- self._h5_entry_group = self.h5_file["vars"]
1118
+ self._h5_entry_group = self._verify_group(self.h5_file["vars"])
1098
1119
  elif path in self.h5_file and "vars" in self.h5_file[path]:
1099
- self._h5_entry_group = self.h5_file[path]["vars"]
1120
+ self._h5_entry_group = self._verify_group(
1121
+ self._verify_group(self.h5_file[path])["vars"]
1122
+ )
1100
1123
  else:
1101
1124
  # No hit. Fix for 2.13 compatibility.
1102
1125
  if "_layer_checkpoint_dependencies" in self.h5_file:
1103
1126
  path = path.replace("layers", "_layer_checkpoint_dependencies")
1104
1127
  if path in self.h5_file and "vars" in self.h5_file[path]:
1105
- self._h5_entry_group = self.h5_file[path]["vars"]
1128
+ self._h5_entry_group = self._verify_group(
1129
+ self._verify_group(self.h5_file[path])["vars"]
1130
+ )
1106
1131
  self._h5_entry_initialized = True
1107
1132
  return self
1108
1133
 
@@ -1134,25 +1159,15 @@ class H5IOStore:
1134
1159
  def keys(self):
1135
1160
  return self._h5_entry_group.keys()
1136
1161
 
1137
- def items(self):
1138
- return self._h5_entry_group.items()
1139
-
1140
- def values(self):
1141
- return self._h5_entry_group.values()
1142
-
1143
1162
  def __getitem__(self, key):
1144
- value = self._h5_entry_group[key]
1163
+ value = self._verify_dataset(self._h5_entry_group[key])
1145
1164
  if (
1146
1165
  hasattr(value, "attrs")
1147
1166
  and "dtype" in value.attrs
1148
1167
  and value.attrs["dtype"] == "bfloat16"
1149
1168
  ):
1150
1169
  value = np.array(value, dtype=ml_dtypes.bfloat16)
1151
- elif (
1152
- hasattr(value, "shape")
1153
- and hasattr(value, "dtype")
1154
- and not isinstance(value, np.ndarray)
1155
- ):
1170
+ elif not isinstance(value, np.ndarray):
1156
1171
  value = np.array(value)
1157
1172
  return value
1158
1173
 
@@ -1355,15 +1370,13 @@ class ShardedH5IOStore(H5IOStore):
1355
1370
  self._get_h5_group(self._h5_entry_path)
1356
1371
 
1357
1372
  def _restore_h5_file(self):
1358
- """Ensure the current shard is the last one created.
1359
-
1360
- We use mode="a" to avoid truncating the file during the switching.
1361
- """
1373
+ """Ensure the current shard is the last one created."""
1362
1374
  if (
1363
1375
  pathlib.Path(self.h5_file.filename).name
1364
1376
  != self.current_shard_path.name
1365
1377
  ):
1366
- self._switch_h5_file(self.current_shard_path.name, mode="a")
1378
+ mode = "a" if self.mode == "w" else "r"
1379
+ self._switch_h5_file(self.current_shard_path.name, mode=mode)
1367
1380
 
1368
1381
  # H5 entry level methods.
1369
1382
 
@@ -1371,9 +1384,11 @@ class ShardedH5IOStore(H5IOStore):
1371
1384
  """Get the H5 entry group. If it doesn't exist, return an empty dict."""
1372
1385
  try:
1373
1386
  if not path:
1374
- self._h5_entry_group = self.h5_file["vars"]
1387
+ self._h5_entry_group = self._verify_group(self.h5_file["vars"])
1375
1388
  else:
1376
- self._h5_entry_group = self.h5_file[path]["vars"]
1389
+ self._h5_entry_group = self._verify_group(
1390
+ self._verify_group(self.h5_file[path])["vars"]
1391
+ )
1377
1392
  self._h5_entry_initialized = True
1378
1393
  except KeyError:
1379
1394
  self._h5_entry_group = {}
@@ -1392,33 +1407,17 @@ class ShardedH5IOStore(H5IOStore):
1392
1407
  return total_len
1393
1408
 
1394
1409
  def keys(self):
1395
- keys = set(self._h5_entry_group.keys())
1410
+ keys = []
1411
+ current_shard_keys = list(self._h5_entry_group.keys())
1396
1412
  for filename in self.current_shard_filenames:
1397
1413
  if filename == self.current_shard_path.name:
1398
- continue
1399
- self._switch_h5_file(filename, mode="r")
1400
- keys.update(self._h5_entry_group.keys())
1414
+ keys += current_shard_keys
1415
+ else:
1416
+ self._switch_h5_file(filename, mode="r")
1417
+ keys += list(self._h5_entry_group.keys())
1401
1418
  self._restore_h5_file()
1402
1419
  return keys
1403
1420
 
1404
- def items(self):
1405
- yield from self._h5_entry_group.items()
1406
- for filename in self.current_shard_filenames:
1407
- if filename == self.current_shard_path.name:
1408
- continue
1409
- self._switch_h5_file(filename, mode="r")
1410
- yield from self._h5_entry_group.items()
1411
- self._restore_h5_file()
1412
-
1413
- def values(self):
1414
- yield from self._h5_entry_group.values()
1415
- for filename in self.current_shard_filenames:
1416
- if filename == self.current_shard_path.name:
1417
- continue
1418
- self._switch_h5_file(filename, mode="r")
1419
- yield from self._h5_entry_group.values()
1420
- self._restore_h5_file()
1421
-
1422
1421
  def __getitem__(self, key):
1423
1422
  if key in self._h5_entry_group:
1424
1423
  return super().__getitem__(key)
keras/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras.src.api_export import keras_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "3.14.0.dev2026012704"
4
+ __version__ = "3.14.0.dev2026012904"
5
5
 
6
6
 
7
7
  @keras_export("keras.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-nightly
3
- Version: 3.14.0.dev2026012704
3
+ Version: 3.14.0.dev2026012904
4
4
  Summary: Multi-backend Keras
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0