keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/quantizers/quantizers.py

@@ -73,6 +73,23 @@ def abs_max_quantize(
     epsilon=backend.epsilon(),
     to_numpy=False,
 ):
+    """
+    Quantizes the input tensor using the absolute maximum quantization scheme.
+
+    Args:
+        inputs: Input tensor to quantize.
+        axis: Axis along which to compute the quantization range.
+        value_range: Tuple of the minimum and maximum values of the quantization
+            range.
+        dtype: Data type of the quantized output.
+        epsilon: Small value to avoid division by zero.
+        to_numpy: Whether to perform the quantization in numpy. This performs
+            the computation on the host CPU and can be useful for saving memory
+            on the device. If False, the computation is performed on the device.
+
+    Returns:
+        A tuple of the quantized tensor and the scale.
+    """
     if to_numpy:
         # Save memory on the device using numpy
         original_dtype = backend.standardize_dtype(inputs.dtype)
@@ -105,31 +122,69 @@ def abs_max_quantize(
 class AbsMaxQuantizer(Quantizer):
     def __init__(
         self,
-        axis,
+        axis=None,  # Deprecated, provide axis in __call__ instead.
         value_range=(-127, 127),
         epsilon=backend.epsilon(),
         output_dtype="int8",
     ):
         Quantizer.__init__(self, output_dtype=output_dtype)
-        if isinstance(axis, int):
-            axis = (axis,)
-        self.axis = tuple(axis)
+        if axis is not None:
+            if isinstance(axis, int):
+                axis = (axis,)
+            self.axis = tuple(axis)
+        else:
+            self.axis = None
         self.value_range = value_range
         self.epsilon = epsilon
+        if output_dtype == "int8":
+            if value_range[0] < -128 or value_range[1] > 127:
+                raise ValueError(
+                    f"Quantizer with output_dtype='int8' requires value_range "
+                    f"to be within the interval [-128, 127]. Received: "
+                    f"value_range={value_range}"
+                )

-    def __call__(self, x):
+    def __call__(self, x, axis=None, to_numpy=False):
+        """
+        Quantizes the input tensor.
+
+        Args:
+            x: Input tensor to quantize.
+            axis: Axis along which to compute the quantization range. If None,
+                uses the axis specified in the constructor. If None and no axis
+                was specified in the constructor, defaults to -1.
+            to_numpy: Whether to perform the quantization in numpy. This
+                performs the computation on the host CPU and can be useful for
+                saving memory on the device. If False, the computation is
+                performed on the device.
+
+        Returns:
+            A tuple of the quantized tensor and the scale.
+        """
+        if axis is None:
+            axis = self.axis
+        if axis is None:
+            # Default to -1 if no axis is specified
+            axis = -1
         quantized_x, scale = abs_max_quantize(
-            x, self.axis, self.value_range, self.output_dtype, self.epsilon
+            x,
+            axis,
+            self.value_range,
+            self.output_dtype,
+            self.epsilon,
+            to_numpy,
         )
         return quantized_x, scale

     def get_config(self):
-        return {
-            "axis": self.axis,
+        config = {
             "value_range": self.value_range,
             "epsilon": self.epsilon,
             "output_dtype": self.output_dtype,
         }
+        if self.axis is not None:
+            config["axis"] = self.axis
+        return config


 def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
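
The constructor `axis` is now optional and deprecated in favor of a per-call `axis`, and `to_numpy=True` runs the reduction on the host CPU. A minimal usage sketch of the updated API (input values are illustrative):

```python
from keras import ops
from keras.quantizers import AbsMaxQuantizer, abs_max_quantize

x = ops.array([[0.1, -2.0], [4.0, 0.5]])

# Functional form: int8 quantization along the last axis, on-device.
quantized, scale = abs_max_quantize(x, axis=-1)

# Class form: no constructor axis; choose it per call instead.
quantizer = AbsMaxQuantizer()
q0, s0 = quantizer(x, axis=0)                  # per-column range
q1, s1 = quantizer(x, axis=-1, to_numpy=True)  # reduction on host CPU
```

Serialization follows suit: `get_config()` only records `axis` when one was passed to the constructor, so configs written by older code still round-trip.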
@@ -281,7 +336,7 @@ def fake_quant_with_min_max_vars(
         ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
     )
     x_clamped = ops.clip(
-        x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype)
+        ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
     )
     x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
     result = ops.multiply(
@@ -318,6 +373,7 @@ def fake_quant_with_min_max_vars(
             grad_min = ops.sum(grad_min, axis=axes)
         else:
             grad_min = ops.sum(grad_min)
+        grad_min = ops.reshape(grad_min, ops.shape(min_val))

         # Gradient for max_val
         # When x is clipped to max, the gradient flows to max_val
@@ -327,6 +383,7 @@ def fake_quant_with_min_max_vars(
             grad_max = ops.sum(grad_max, axis=axes)
         else:
             grad_max = ops.sum(grad_max)
+        grad_max = ops.reshape(grad_max, ops.shape(max_val))

         return dx, grad_min, grad_max

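Both `ops.reshape` fixes pin the returned gradients to the exact shapes of `min_val` and `max_val`, which matters once those are per-channel vectors rather than scalars. A hedged sketch of a per-channel call; the argument names follow the variables visible in this hunk, but verify the exported signature of `keras.quantizers.fake_quant_with_min_max_vars` before relying on it:

```python
from keras import ops
from keras import quantizers

x = ops.reshape(ops.arange(12, dtype="float32"), (3, 4))
min_vals = ops.array([-1.0, -2.0, -1.5, -0.5])  # one range per channel
max_vals = ops.array([1.0, 2.0, 1.5, 0.5])

# Assumed signature: (inputs, min_vals, max_vals, num_bits, narrow_range, axis)
y = quantizers.fake_quant_with_min_max_vars(
    x, min_vals, max_vals, num_bits=8, axis=-1
)
# With the reshape fix, gradients w.r.t. min_vals/max_vals keep shape (4,).
```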
@@ -378,7 +435,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):


 @keras_export("keras.quantizers.pack_int4")
-def pack_int4(arr, axis=0):
+def pack_int4(arr, axis=0, dtype="int8"):
     """Pack an int4 tensor into an int8 tensor with packed nibbles.

     The input values must already be int8 in the signed range `[-8, 7]` and
@@ -390,8 +447,11 @@ def pack_int4(arr, axis=0):
     the value from the second row.

     Args:
-        arr: An int8 tensor containing int4 values in the range `[-8, 7]`.
+        arr: An `int8` or `uint8` tensor containing int4 values in the range
+            `[-8, 7]`.
         axis: The axis along which to pack the tensor. Defaults to 0.
+        dtype: The data type of the input and packed tensor. Can be
+            `"int8"` or `"uint8"`. Defaults to `"int8"`.

     Returns:
         tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is
@@ -451,9 +511,14 @@ def pack_int4(arr, axis=0):
     True
     ```
     """
-    if backend.standardize_dtype(arr.dtype) != "int8":
+    if dtype not in ("int8", "uint8"):
+        raise ValueError(
+            f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+        )
+    if backend.standardize_dtype(arr.dtype) != dtype:
         raise TypeError(
-            "Expected int8 tensor for packing, got {}".format(arr.dtype)
+            f"Expected {dtype} tensor for packing, got "
+            f"{backend.standardize_dtype(arr.dtype)}."
         )

     rank = getattr(arr.shape, "rank", None) or len(arr.shape)
@@ -487,12 +552,12 @@ def pack_int4(arr, axis=0):
     low = padded[::2, ...]
     high = padded[1::2, ...]

-    mask = ops.array(0x0F, dtype="int8")
+    mask = ops.array(0x0F, dtype=dtype)
     low_u = ops.bitwise_and(low, mask)
     high_u = ops.bitwise_and(high, mask)

     packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
-    packed = ops.cast(packed, "int8")
+    packed = ops.cast(packed, dtype)

     # 5-6. Restore shape.
     packed = ops.transpose(packed, inv_perm)  # back to original order
@@ -501,7 +566,7 @@ def pack_int4(arr, axis=0):


 @keras_export("keras.quantizers.unpack_int4")
-def unpack_int4(packed, orig_len, axis=0):
+def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
     """Unpack a packed int4 back to an int8 tensor in the range [-8, 7].

     This function reverses the packing performed by `pack_int4`, restoring
@@ -519,6 +584,8 @@ def unpack_int4(packed, orig_len, axis=0):
             packed. This is used to remove any padding that may have
             been added during packing to ensure an even number of rows.
         axis: The axis along which the tensor was packed. Defaults to 0.
+        dtype: The data type of the input and unpacked tensor. Can be
+            `"int8"` or `"uint8"`. Defaults to `"int8"`.

     Returns:
         unpacked: An int8 tensor with the same shape as the original
@@ -575,13 +642,24 @@ def unpack_int4(packed, orig_len, axis=0):
     True
     ```
     """
-    if backend.standardize_dtype(packed.dtype) != "int8":
+    if dtype not in ("int8", "uint8"):
+        raise ValueError(
+            f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+        )
+
+    if backend.standardize_dtype(packed.dtype) not in ("int8", "uint8"):
         raise TypeError(
-            f"Expected int8 tensor for unpacking, got {packed.dtype}"
+            f"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}"
         )

-    rank = getattr(packed.shape, "rank", None) or len(packed.shape)
+    def to_signed(x):
+        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+        dtype_x = backend.standardize_dtype(x.dtype)
+        eight = ops.cast(8, dtype_x)
+        sixteen = ops.cast(16, dtype_x)
+        return ops.where(x < eight, x, x - sixteen)

+    rank = getattr(packed.shape, "rank", None) or len(packed.shape)
     if axis < 0:
         axis += rank

@@ -592,16 +670,15 @@
         low_unpacked = ops.bitwise_and(packed, mask)
         high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)

-        # Convert values from [0, 15] to [-8, 7].
-        low_signed = ops.where(
-            low_unpacked < 8, low_unpacked, low_unpacked - 16
-        )
-        high_signed = ops.where(
-            high_unpacked < 8, high_unpacked, high_unpacked - 16
-        )
+        if dtype == "int8":
+            low_unpacked = to_signed(low_unpacked)
+            high_unpacked = to_signed(high_unpacked)
+
+        low_final = ops.cast(low_unpacked, dtype)
+        high_final = ops.cast(high_unpacked, dtype)

         # Interleave and reshape
-        stacked = ops.stack([low_signed, high_signed], axis=1)
+        stacked = ops.stack([low_final, high_final], axis=1)
         unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))

         # Remove padding and return
@@ -613,28 +690,27 @@
     transposed = ops.transpose(packed, perm)

     # 1. Split nibbles.
-    mask = ops.array(0x0F, dtype="int8")  # int8 arrays
+    mask = ops.array(0x0F, dtype=packed.dtype)
     low = ops.bitwise_and(transposed, mask)
     high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)

-    eight = ops.array(8, dtype="int8")
-    sixteen = ops.array(16, dtype="int8")
-
-    def to_signed(x):
-        return ops.where(x < eight, x, x - sixteen)
+    # 2. Conditionally convert to signed.
+    if dtype == "int8":
+        low = to_signed(low)
+        high = to_signed(high)

-    low = to_signed(low)
-    high = to_signed(high)
+    low = ops.cast(low, dtype)
+    high = ops.cast(high, dtype)

-    # 2. Interleave and reshape.
-    stacked = ops.stack([low, high], axis=1)  # (pairs, 2, ...)
+    # 3. Interleave and reshape.
+    stacked = ops.stack([low, high], axis=1)
     unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))

     # 4. Remove padding and restore original layout.
     unpacked = unpacked[:orig_len, ...]
     unpacked = ops.transpose(unpacked, inv_perm)

-    return unpacked  # dtype is int8
+    return unpacked


 class GPTQQuantizer(Quantizer):
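
With `dtype` threaded through both functions, a pack/unpack round trip stays in the chosen representation end to end. A minimal round-trip sketch using the documented `(packed, packed_shape, orig_rows)` return values; the sample values are illustrative:

```python
import numpy as np
from keras import ops
from keras.quantizers import pack_int4, unpack_int4

arr = ops.array([[-8, 7, 3], [-1, 0, 5]], dtype="int8")
packed, packed_shape, orig_rows = pack_int4(arr, axis=0)  # two rows per byte
restored = unpack_int4(packed, orig_rows, axis=0)
np.testing.assert_array_equal(
    ops.convert_to_numpy(restored), ops.convert_to_numpy(arr)
)
```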
keras/src/quantizers/utils.py (new file)

@@ -0,0 +1,23 @@
+import re
+
+
+def should_quantize_layer(layer, filters):
+    """Determines if a layer should be quantized based on filters.
+
+    Args:
+        layer: The layer to check.
+        filters: A regex string, a list of regex strings, or a callable.
+            If None, returns True.
+
+    Returns:
+        True if the layer should be quantized, False otherwise.
+    """
+    if filters is None:
+        return True
+    if isinstance(filters, str):
+        return bool(re.search(filters, layer.name))
+    if isinstance(filters, (list, tuple)):
+        return any(re.search(pat, layer.name) for pat in filters)
+    if callable(filters):
+        return filters(layer)
+    return True
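
A quick behavioral sketch of the new helper; the layer stub is a stand-in for a real Keras layer, since only `.name` is consulted by the regex forms:

```python
from keras.src.quantizers.utils import should_quantize_layer

class _Stub:
    def __init__(self, name):
        self.name = name

assert should_quantize_layer(_Stub("dense_1"), None)      # no filter: always True
assert should_quantize_layer(_Stub("dense_1"), r"dense")  # regex match
assert not should_quantize_layer(_Stub("embedding"), ["dense", "conv"])
assert should_quantize_layer(_Stub("lm_head"), lambda l: l.name == "lm_head")
```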
keras/src/random/seed_generator.py

@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
 from keras.src.utils import jax_utils
 from keras.src.utils.naming import auto_name

+GLOBAL_SEED_GENERATOR = "global_seed_generator"
+

 @keras_export("keras.random.SeedGenerator")
 class SeedGenerator:
@@ -133,10 +135,10 @@ def global_seed_generator():
            "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
            "```"
        )
-    gen = global_state.get_global_attribute("global_seed_generator")
+    gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
     if gen is None:
         gen = SeedGenerator()
-        global_state.set_global_attribute("global_seed_generator", gen)
+        global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
     return gen

keras/src/saving/file_editor.py

@@ -455,6 +455,9 @@ class KerasFileEditor:
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}

+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
         metadata[inner_path] = object_metadata

         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue

-            if hasattr(value, "keys"):
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata

     def _generate_filepath_info(self, rich_style=False):
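
The boundary of the new size guard is easy to check by hand; since the comparison is strict, a dataset of exactly 4 GiB still loads:

```python
# A float32 tensor of shape (32768, 32768) is exactly 2**32 bytes:
# 32768 * 32768 elements * 4 bytes each. Only strictly larger datasets raise.
assert 32768 * 32768 * 4 == 1 << 32
```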
keras/src/saving/saving_lib.py

@@ -943,7 +943,7 @@ class DiskIOStore:
         if self.archive:
             self.tmp_dir = get_temp_dir()
             if self.mode == "r":
-                self.archive.extractall(path=self.tmp_dir)
+                file_utils.extract_open_archive(self.archive, self.tmp_dir)
             self.working_dir = file_utils.join(
                 self.tmp_dir, self.root_path
             ).replace("\\", "/")
keras/src/testing/__init__.py

@@ -3,3 +3,4 @@ from keras.src.testing.test_case import jax_uses_gpu
 from keras.src.testing.test_case import tensorflow_uses_gpu
 from keras.src.testing.test_case import torch_uses_gpu
 from keras.src.testing.test_case import uses_gpu
+from keras.src.testing.test_case import uses_tpu
keras/src/testing/test_case.py

@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         self.addCleanup(lambda: shutil.rmtree(temp_dir))
         return temp_dir

-    def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
+    def assertAllClose(
+        self,
+        x1,
+        x2,
+        atol=1e-6,
+        rtol=1e-6,
+        tpu_atol=None,
+        tpu_rtol=None,
+        msg=None,
+    ):
+        if tpu_atol is not None and uses_tpu():
+            atol = tpu_atol
+        if tpu_rtol is not None and uses_tpu():
+            rtol = tpu_rtol
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
         if not isinstance(x2, np.ndarray):
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
         )

-    def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
+    def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
+        if tpu_decimal is not None and uses_tpu():
+            decimal = tpu_decimal
         msg = msg or ""
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         run_training_check=True,
         run_mixed_precision_check=True,
         assert_built_after_instantiation=False,
+        tpu_atol=None,
+        tpu_rtol=None,
     ):
         """Run basic checks on a layer.

@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                 msg="Unexpected number of torch_params",
             )

-        def run_output_asserts(layer, output, eager=False):
+        def run_output_asserts(
+            layer, output, eager=False, tpu_atol=None, tpu_rtol=None
+        ):
             if expected_output_shape is not None:

                 def verify_shape(expected_shape, x):
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                     tree.flatten(expected_output), tree.flatten(output)
                 ):
                     self.assertAllClose(
-                        ref_v, v, msg="Unexpected output value"
+                        ref_v,
+                        v,
+                        msg="Unexpected output value",
+                        tpu_atol=tpu_atol,
+                        tpu_rtol=tpu_rtol,
                     )
             if expected_num_losses is not None:
                 self.assertLen(layer.losses, expected_num_losses)
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             output_data = layer(**input_data, **call_kwargs)
         else:
             output_data = layer(input_data, **call_kwargs)
-        run_output_asserts(layer, output_data, eager=True)
+        run_output_asserts(
+            layer,
+            output_data,
+            eager=True,
+            tpu_atol=tpu_atol,
+            tpu_rtol=tpu_rtol,
+        )

         if run_training_check:
             run_training_step(layer, input_data, output_data)
@@ -621,6 +650,17 @@ def uses_gpu():
     return False


+def uses_tpu():
+    # Condition used to skip tests when using the TPU
+    try:
+        devices = distribution.list_devices()
+        if any(d.startswith("tpu") for d in devices):
+            return True
+    except AttributeError:
+        return False
+    return False
+
+
 def uses_cpu():
     devices = distribution.list_devices()
     if any(d.startswith("cpu") for d in devices):
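
Together, these test_case.py changes let a test keep tight tolerances on CPU/GPU and relax them only when `uses_tpu()` reports a TPU device. A minimal sketch; the computation and tolerance values are illustrative:

```python
from keras import ops
from keras.src import testing

class ToleranceTest(testing.TestCase):
    def test_matmul_is_close(self):
        a = ops.ones((4, 4))
        out = ops.matmul(a, a)
        # atol/rtol apply everywhere; tpu_atol/tpu_rtol take over
        # only when a TPU device is listed.
        self.assertAllClose(
            out,
            ops.full((4, 4), 4.0),
            atol=1e-6,
            rtol=1e-6,
            tpu_atol=1e-3,
            tpu_rtol=1e-3,
        )
```

`run_layer_test` forwards its own `tpu_atol`/`tpu_rtol` into the output assertions the same way.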
keras/src/trainers/compile_utils.py

@@ -148,6 +148,7 @@ class CompileMetrics(metrics_module.Metric):
         self.built = False
         self.name = "compile_metrics"
         self.output_names = output_names
+        self._resolved_output_names = None

     @property
     def metrics(self):
@@ -175,10 +176,16 @@ class CompileMetrics(metrics_module.Metric):

     def build(self, y_true, y_pred):
         num_outputs = 1  # default
-        if self.output_names:
+        # Resolve output names. If y_pred is a dict, prefer its keys.
+        if isinstance(y_pred, dict):
+            keys = sorted(list(y_pred.keys()))
+            if self.output_names and set(self.output_names) == set(keys):
+                # If there is a perfect match, use the user-provided order.
+                output_names = self.output_names
+            else:
+                output_names = keys
+        elif self.output_names:
             output_names = self.output_names
-        elif isinstance(y_pred, dict):
-            output_names = sorted(list(y_pred.keys()))
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -187,6 +194,7 @@ class CompileMetrics(metrics_module.Metric):
                 output_names = None
         else:
             output_names = None
+        self._resolved_output_names = output_names
         if output_names:
             num_outputs = len(output_names)

@@ -316,9 +324,10 @@ class CompileMetrics(metrics_module.Metric):
         return flat_metrics

     def _flatten_y(self, y):
-        if isinstance(y, dict) and self.output_names:
+        names = self._resolved_output_names
+        if isinstance(y, dict) and names:
             result = []
-            for name in self.output_names:
+            for name in names:
                 if name in y:
                     result.append(y[name])
             return result
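
The effect of the resolved-name bookkeeping is easiest to see with dict outputs: metric flattening now follows the keys of `y_pred` even when they don't match the model's internal `output_names`. A minimal sketch (layer sizes and data are illustrative):

```python
import numpy as np
import keras
from keras import layers

inputs = keras.Input(shape=(4,))
outputs = {
    "a": layers.Dense(1, name="a")(inputs),
    "b": layers.Dense(1, name="b")(inputs),
}
model = keras.Model(inputs, outputs)
model.compile(
    optimizer="sgd",
    loss={"a": "mse", "b": "mse"},
    metrics={"a": ["mae"], "b": ["mae"]},
)
x = np.random.random((8, 4)).astype("float32")
y = {"a": np.zeros((8, 1), "float32"), "b": np.ones((8, 1), "float32")}
model.fit(x, y, epochs=1, verbose=0)  # metrics resolve per output key
```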
keras/src/utils/backend_utils.py

@@ -3,6 +3,7 @@ import importlib
 import inspect
 import os
 import sys
+import warnings

 from keras.src import backend as backend_module
 from keras.src.api_export import keras_export
@@ -124,9 +125,22 @@ def set_backend(backend):

     Example:

-    ```python
-    keras.config.set_backend("jax")
-    ```
+    >>> import os
+    >>> os.environ["KERAS_BACKEND"] = "tensorflow"
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'tensorflow.python.framework.ops.EagerTensor'>
+    >>>
+    >>> keras.config.set_backend("jax")
+    UserWarning: Using `keras.config.set_backend` is dangerous...
+    >>> del keras, ops
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'jaxlib.xla_extension.ArrayImpl'>

     ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
     carefully. Changing the backend will **NOT** convert
@@ -138,7 +152,7 @@ def set_backend(backend):

     This includes any function or class instance that uses any Keras
     functionality. All such code needs to be re-executed after calling
-    `set_backend()`.
+    `set_backend()` and re-importing all imported `keras` modules.
     """
     os.environ["KERAS_BACKEND"] = backend
     # Clear module cache.
@@ -159,3 +173,16 @@ def set_backend(backend):
         module_name = module_name[module_name.find("'") + 1 :]
         module_name = module_name[: module_name.find("'")]
         globals()[key] = importlib.import_module(module_name)
+
+    warnings.warn(
+        "Using `keras.config.set_backend` is dangerous and should be done "
+        "carefully. Already-instantiated objects will not be converted. Thus, "
+        "any layers / tensors / etc. already created will no longer be usable "
+        "without errors. It is strongly recommended not to keep around any "
+        "Keras-originated objects instances created before calling "
+        "`set_backend()`. This includes any function or class instance that "
+        "uses any Keras functionality. All such code needs to be re-executed "
+        "after calling `set_backend()` and re-importing all imported `keras` "
+        "modules.",
+        stacklevel=2,
+    )