keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
|
@@ -73,6 +73,23 @@ def abs_max_quantize(
|
|
|
73
73
|
epsilon=backend.epsilon(),
|
|
74
74
|
to_numpy=False,
|
|
75
75
|
):
|
|
76
|
+
"""
|
|
77
|
+
Quantizes the input tensor using the absolute maximum quantization scheme.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
inputs: Input tensor to quantize.
|
|
81
|
+
axis: Axis along which to compute the quantization range.
|
|
82
|
+
value_range: Tuple of the minimum and maximum values of the quantization
|
|
83
|
+
range.
|
|
84
|
+
dtype: Data type of the quantized output.
|
|
85
|
+
epsilon: Small value to avoid division by zero.
|
|
86
|
+
to_numpy: Whether to perform the quantization in numpy. This performs
|
|
87
|
+
the computation on the host CPU and can be useful for saving memory
|
|
88
|
+
on the device. If False, the computation is performed on the device.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
A tuple of the quantized tensor and the scale.
|
|
92
|
+
"""
|
|
76
93
|
if to_numpy:
|
|
77
94
|
# Save memory on the device using numpy
|
|
78
95
|
original_dtype = backend.standardize_dtype(inputs.dtype)
|
|
@@ -105,31 +122,69 @@ def abs_max_quantize(
|
|
|
105
122
|
class AbsMaxQuantizer(Quantizer):
|
|
106
123
|
def __init__(
|
|
107
124
|
self,
|
|
108
|
-
axis,
|
|
125
|
+
axis=None, # Deprecated, provide axis in __call__ instead.
|
|
109
126
|
value_range=(-127, 127),
|
|
110
127
|
epsilon=backend.epsilon(),
|
|
111
128
|
output_dtype="int8",
|
|
112
129
|
):
|
|
113
130
|
Quantizer.__init__(self, output_dtype=output_dtype)
|
|
114
|
-
if
|
|
115
|
-
|
|
116
|
-
|
|
131
|
+
if axis is not None:
|
|
132
|
+
if isinstance(axis, int):
|
|
133
|
+
axis = (axis,)
|
|
134
|
+
self.axis = tuple(axis)
|
|
135
|
+
else:
|
|
136
|
+
self.axis = None
|
|
117
137
|
self.value_range = value_range
|
|
118
138
|
self.epsilon = epsilon
|
|
139
|
+
if output_dtype == "int8":
|
|
140
|
+
if value_range[0] < -128 or value_range[1] > 127:
|
|
141
|
+
raise ValueError(
|
|
142
|
+
f"Quantizer with output_dtype='int8' requires value_range "
|
|
143
|
+
f"to be within the interval [-128, 127]. Received: "
|
|
144
|
+
f"value_range={value_range}"
|
|
145
|
+
)
|
|
119
146
|
|
|
120
|
-
def __call__(self, x):
|
|
147
|
+
def __call__(self, x, axis=None, to_numpy=False):
|
|
148
|
+
"""
|
|
149
|
+
Quantizes the input tensor.
|
|
150
|
+
|
|
151
|
+
Args:
|
|
152
|
+
x: Input tensor to quantize.
|
|
153
|
+
axis: Axis along which to compute the quantization range. If None,
|
|
154
|
+
uses the axis specified in the constructor. If None and no axis
|
|
155
|
+
was specified in the constructor, defaults to -1.
|
|
156
|
+
to_numpy: Whether to perform the quantization in numpy. This
|
|
157
|
+
performs the computation on the host CPU and can be useful for
|
|
158
|
+
saving memory on the device. If False, the computation is
|
|
159
|
+
performed on the device.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
A tuple of the quantized tensor and the scale.
|
|
163
|
+
"""
|
|
164
|
+
if axis is None:
|
|
165
|
+
axis = self.axis
|
|
166
|
+
if axis is None:
|
|
167
|
+
# Default to -1 if no axis is specified
|
|
168
|
+
axis = -1
|
|
121
169
|
quantized_x, scale = abs_max_quantize(
|
|
122
|
-
x,
|
|
170
|
+
x,
|
|
171
|
+
axis,
|
|
172
|
+
self.value_range,
|
|
173
|
+
self.output_dtype,
|
|
174
|
+
self.epsilon,
|
|
175
|
+
to_numpy,
|
|
123
176
|
)
|
|
124
177
|
return quantized_x, scale
|
|
125
178
|
|
|
126
179
|
def get_config(self):
|
|
127
|
-
|
|
128
|
-
"axis": self.axis,
|
|
180
|
+
config = {
|
|
129
181
|
"value_range": self.value_range,
|
|
130
182
|
"epsilon": self.epsilon,
|
|
131
183
|
"output_dtype": self.output_dtype,
|
|
132
184
|
}
|
|
185
|
+
if self.axis is not None:
|
|
186
|
+
config["axis"] = self.axis
|
|
187
|
+
return config
|
|
133
188
|
|
|
134
189
|
|
|
135
190
|
def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
|
|
@@ -281,7 +336,7 @@ def fake_quant_with_min_max_vars(
|
|
|
281
336
|
ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
|
|
282
337
|
)
|
|
283
338
|
x_clamped = ops.clip(
|
|
284
|
-
|
|
339
|
+
ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
|
|
285
340
|
)
|
|
286
341
|
x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
|
|
287
342
|
result = ops.multiply(
|
|
@@ -318,6 +373,7 @@ def fake_quant_with_min_max_vars(
|
|
|
318
373
|
grad_min = ops.sum(grad_min, axis=axes)
|
|
319
374
|
else:
|
|
320
375
|
grad_min = ops.sum(grad_min)
|
|
376
|
+
grad_min = ops.reshape(grad_min, ops.shape(min_val))
|
|
321
377
|
|
|
322
378
|
# Gradient for max_val
|
|
323
379
|
# When x is clipped to max, the gradient flows to max_val
|
|
@@ -327,6 +383,7 @@ def fake_quant_with_min_max_vars(
|
|
|
327
383
|
grad_max = ops.sum(grad_max, axis=axes)
|
|
328
384
|
else:
|
|
329
385
|
grad_max = ops.sum(grad_max)
|
|
386
|
+
grad_max = ops.reshape(grad_max, ops.shape(max_val))
|
|
330
387
|
|
|
331
388
|
return dx, grad_min, grad_max
|
|
332
389
|
|
|
@@ -378,7 +435,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
|
|
|
378
435
|
|
|
379
436
|
|
|
380
437
|
@keras_export("keras.quantizers.pack_int4")
|
|
381
|
-
def pack_int4(arr, axis=0):
|
|
438
|
+
def pack_int4(arr, axis=0, dtype="int8"):
|
|
382
439
|
"""Pack an int4 tensor into an int8 tensor with packed nibbles.
|
|
383
440
|
|
|
384
441
|
The input values must already be int8 in the signed range `[-8, 7]` and
|
|
@@ -390,8 +447,11 @@ def pack_int4(arr, axis=0):
|
|
|
390
447
|
the value from the second row.
|
|
391
448
|
|
|
392
449
|
Args:
|
|
393
|
-
arr: An int8 tensor containing int4 values in the range
|
|
450
|
+
arr: An `int8` or `uint8` tensor containing int4 values in the range
|
|
451
|
+
`[-8, 7]`.
|
|
394
452
|
axis: The axis along which to pack the tensor. Defaults to 0.
|
|
453
|
+
dtype: The data type of the input and packed tensor. Can be
|
|
454
|
+
`"int8"` or `"uint8"`. Defaults to `"int8"`.
|
|
395
455
|
|
|
396
456
|
Returns:
|
|
397
457
|
tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is
|
|
@@ -451,9 +511,14 @@ def pack_int4(arr, axis=0):
|
|
|
451
511
|
True
|
|
452
512
|
```
|
|
453
513
|
"""
|
|
454
|
-
if
|
|
514
|
+
if dtype not in ("int8", "uint8"):
|
|
515
|
+
raise ValueError(
|
|
516
|
+
f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
|
|
517
|
+
)
|
|
518
|
+
if backend.standardize_dtype(arr.dtype) != dtype:
|
|
455
519
|
raise TypeError(
|
|
456
|
-
"Expected
|
|
520
|
+
f"Expected {dtype} tensor for packing, got "
|
|
521
|
+
f"{backend.standardize_dtype(arr.dtype)}."
|
|
457
522
|
)
|
|
458
523
|
|
|
459
524
|
rank = getattr(arr.shape, "rank", None) or len(arr.shape)
|
|
@@ -487,12 +552,12 @@ def pack_int4(arr, axis=0):
|
|
|
487
552
|
low = padded[::2, ...]
|
|
488
553
|
high = padded[1::2, ...]
|
|
489
554
|
|
|
490
|
-
mask = ops.array(0x0F, dtype=
|
|
555
|
+
mask = ops.array(0x0F, dtype=dtype)
|
|
491
556
|
low_u = ops.bitwise_and(low, mask)
|
|
492
557
|
high_u = ops.bitwise_and(high, mask)
|
|
493
558
|
|
|
494
559
|
packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
|
|
495
|
-
packed = ops.cast(packed,
|
|
560
|
+
packed = ops.cast(packed, dtype)
|
|
496
561
|
|
|
497
562
|
# 5-6. Restore shape.
|
|
498
563
|
packed = ops.transpose(packed, inv_perm) # back to original order
|
|
@@ -501,7 +566,7 @@ def pack_int4(arr, axis=0):
|
|
|
501
566
|
|
|
502
567
|
|
|
503
568
|
@keras_export("keras.quantizers.unpack_int4")
|
|
504
|
-
def unpack_int4(packed, orig_len, axis=0):
|
|
569
|
+
def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
|
|
505
570
|
"""Unpack a packed int4 back to an int8 tensor in the range [-8, 7].
|
|
506
571
|
|
|
507
572
|
This function reverses the packing performed by `pack_int4`, restoring
|
|
@@ -519,6 +584,8 @@ def unpack_int4(packed, orig_len, axis=0):
|
|
|
519
584
|
packed. This is used to remove any padding that may have
|
|
520
585
|
been added during packing to ensure an even number of rows.
|
|
521
586
|
axis: The axis along which the tensor was packed. Defaults to 0.
|
|
587
|
+
dtype: The data type of the input and unpacked tensor. Can be
|
|
588
|
+
`"int8"` or `"uint8"`. Defaults to `"int8"`.
|
|
522
589
|
|
|
523
590
|
Returns:
|
|
524
591
|
unpacked: An int8 tensor with the same shape as the original
|
|
@@ -575,13 +642,24 @@ def unpack_int4(packed, orig_len, axis=0):
|
|
|
575
642
|
True
|
|
576
643
|
```
|
|
577
644
|
"""
|
|
578
|
-
if
|
|
645
|
+
if dtype not in ("int8", "uint8"):
|
|
646
|
+
raise ValueError(
|
|
647
|
+
f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
if backend.standardize_dtype(packed.dtype) not in ("int8", "uint8"):
|
|
579
651
|
raise TypeError(
|
|
580
|
-
f"Expected int8 tensor for unpacking, got {packed.dtype}"
|
|
652
|
+
f"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}"
|
|
581
653
|
)
|
|
582
654
|
|
|
583
|
-
|
|
655
|
+
def to_signed(x):
|
|
656
|
+
"""Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
|
|
657
|
+
dtype_x = backend.standardize_dtype(x.dtype)
|
|
658
|
+
eight = ops.cast(8, dtype_x)
|
|
659
|
+
sixteen = ops.cast(16, dtype_x)
|
|
660
|
+
return ops.where(x < eight, x, x - sixteen)
|
|
584
661
|
|
|
662
|
+
rank = getattr(packed.shape, "rank", None) or len(packed.shape)
|
|
585
663
|
if axis < 0:
|
|
586
664
|
axis += rank
|
|
587
665
|
|
|
@@ -592,16 +670,15 @@ def unpack_int4(packed, orig_len, axis=0):
|
|
|
592
670
|
low_unpacked = ops.bitwise_and(packed, mask)
|
|
593
671
|
high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)
|
|
594
672
|
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
)
|
|
673
|
+
if dtype == "int8":
|
|
674
|
+
low_unpacked = to_signed(low_unpacked)
|
|
675
|
+
high_unpacked = to_signed(high_unpacked)
|
|
676
|
+
|
|
677
|
+
low_final = ops.cast(low_unpacked, dtype)
|
|
678
|
+
high_final = ops.cast(high_unpacked, dtype)
|
|
602
679
|
|
|
603
680
|
# Interleave and reshape
|
|
604
|
-
stacked = ops.stack([
|
|
681
|
+
stacked = ops.stack([low_final, high_final], axis=1)
|
|
605
682
|
unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))
|
|
606
683
|
|
|
607
684
|
# Remove padding and return
|
|
@@ -613,28 +690,27 @@ def unpack_int4(packed, orig_len, axis=0):
|
|
|
613
690
|
transposed = ops.transpose(packed, perm)
|
|
614
691
|
|
|
615
692
|
# 1. Split nibbles.
|
|
616
|
-
mask = ops.array(0x0F, dtype=
|
|
693
|
+
mask = ops.array(0x0F, dtype=packed.dtype)
|
|
617
694
|
low = ops.bitwise_and(transposed, mask)
|
|
618
695
|
high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)
|
|
619
696
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
return ops.where(x < eight, x, x - sixteen)
|
|
697
|
+
# 2. Conditionally convert to signed.
|
|
698
|
+
if dtype == "int8":
|
|
699
|
+
low = to_signed(low)
|
|
700
|
+
high = to_signed(high)
|
|
625
701
|
|
|
626
|
-
low =
|
|
627
|
-
high =
|
|
702
|
+
low = ops.cast(low, dtype)
|
|
703
|
+
high = ops.cast(high, dtype)
|
|
628
704
|
|
|
629
|
-
#
|
|
630
|
-
stacked = ops.stack([low, high], axis=1)
|
|
705
|
+
# 3. Interleave and reshape.
|
|
706
|
+
stacked = ops.stack([low, high], axis=1)
|
|
631
707
|
unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))
|
|
632
708
|
|
|
633
709
|
# 4. Remove padding and restore original layout.
|
|
634
710
|
unpacked = unpacked[:orig_len, ...]
|
|
635
711
|
unpacked = ops.transpose(unpacked, inv_perm)
|
|
636
712
|
|
|
637
|
-
return unpacked
|
|
713
|
+
return unpacked
|
|
638
714
|
|
|
639
715
|
|
|
640
716
|
class GPTQQuantizer(Quantizer):
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def should_quantize_layer(layer, filters):
|
|
5
|
+
"""Determines if a layer should be quantized based on filters.
|
|
6
|
+
|
|
7
|
+
Args:
|
|
8
|
+
layer: The layer to check.
|
|
9
|
+
filters: A regex string, a list of regex strings, or a callable.
|
|
10
|
+
If None, returns True.
|
|
11
|
+
|
|
12
|
+
Returns:
|
|
13
|
+
True if the layer should be quantized, False otherwise.
|
|
14
|
+
"""
|
|
15
|
+
if filters is None:
|
|
16
|
+
return True
|
|
17
|
+
if isinstance(filters, str):
|
|
18
|
+
return bool(re.search(filters, layer.name))
|
|
19
|
+
if isinstance(filters, (list, tuple)):
|
|
20
|
+
return any(re.search(pat, layer.name) for pat in filters)
|
|
21
|
+
if callable(filters):
|
|
22
|
+
return filters(layer)
|
|
23
|
+
return True
|
|
@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
|
|
|
8
8
|
from keras.src.utils import jax_utils
|
|
9
9
|
from keras.src.utils.naming import auto_name
|
|
10
10
|
|
|
11
|
+
GLOBAL_SEED_GENERATOR = "global_seed_generator"
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
@keras_export("keras.random.SeedGenerator")
|
|
13
15
|
class SeedGenerator:
|
|
@@ -133,10 +135,10 @@ def global_seed_generator():
|
|
|
133
135
|
"out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
|
|
134
136
|
"```"
|
|
135
137
|
)
|
|
136
|
-
gen = global_state.get_global_attribute(
|
|
138
|
+
gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
|
|
137
139
|
if gen is None:
|
|
138
140
|
gen = SeedGenerator()
|
|
139
|
-
global_state.set_global_attribute(
|
|
141
|
+
global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
|
|
140
142
|
return gen
|
|
141
143
|
|
|
142
144
|
|
keras/src/saving/file_editor.py
CHANGED
|
@@ -455,6 +455,9 @@ class KerasFileEditor:
|
|
|
455
455
|
def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
|
|
456
456
|
metadata = metadata or {}
|
|
457
457
|
|
|
458
|
+
# ------------------------------------------------------
|
|
459
|
+
# Collect metadata for this HDF5 group
|
|
460
|
+
# ------------------------------------------------------
|
|
458
461
|
object_metadata = {}
|
|
459
462
|
for k, v in data.attrs.items():
|
|
460
463
|
object_metadata[k] = v
|
|
@@ -462,26 +465,98 @@ class KerasFileEditor:
|
|
|
462
465
|
metadata[inner_path] = object_metadata
|
|
463
466
|
|
|
464
467
|
result = collections.OrderedDict()
|
|
468
|
+
|
|
469
|
+
# ------------------------------------------------------
|
|
470
|
+
# Iterate over all keys in this HDF5 group
|
|
471
|
+
# ------------------------------------------------------
|
|
465
472
|
for key in data.keys():
|
|
466
|
-
|
|
473
|
+
# IMPORTANT:
|
|
474
|
+
# Never mutate inner_path; use local variable.
|
|
475
|
+
current_inner_path = f"{inner_path}/{key}"
|
|
467
476
|
value = data[key]
|
|
477
|
+
|
|
478
|
+
# ------------------------------------------------------
|
|
479
|
+
# CASE 1 — HDF5 GROUP → RECURSE
|
|
480
|
+
# ------------------------------------------------------
|
|
468
481
|
if isinstance(value, h5py.Group):
|
|
482
|
+
# Skip empty groups
|
|
469
483
|
if len(value) == 0:
|
|
470
484
|
continue
|
|
485
|
+
|
|
486
|
+
# Skip empty "vars" groups
|
|
471
487
|
if "vars" in value.keys() and len(value["vars"]) == 0:
|
|
472
488
|
continue
|
|
473
489
|
|
|
474
|
-
|
|
490
|
+
# Recurse into "vars" subgroup when present
|
|
475
491
|
if "vars" in value.keys():
|
|
476
492
|
result[key], metadata = self._extract_weights_from_store(
|
|
477
|
-
value["vars"],
|
|
493
|
+
value["vars"],
|
|
494
|
+
metadata=metadata,
|
|
495
|
+
inner_path=current_inner_path,
|
|
478
496
|
)
|
|
479
497
|
else:
|
|
498
|
+
# Recurse normally
|
|
480
499
|
result[key], metadata = self._extract_weights_from_store(
|
|
481
|
-
value,
|
|
500
|
+
value,
|
|
501
|
+
metadata=metadata,
|
|
502
|
+
inner_path=current_inner_path,
|
|
482
503
|
)
|
|
483
|
-
|
|
484
|
-
|
|
504
|
+
|
|
505
|
+
continue # finished processing this key
|
|
506
|
+
|
|
507
|
+
# ------------------------------------------------------
|
|
508
|
+
# CASE 2 — HDF5 DATASET → SAFE LOADING
|
|
509
|
+
# ------------------------------------------------------
|
|
510
|
+
|
|
511
|
+
# Skip any objects that are not proper datasets
|
|
512
|
+
if not hasattr(value, "shape") or not hasattr(value, "dtype"):
|
|
513
|
+
continue
|
|
514
|
+
|
|
515
|
+
shape = value.shape
|
|
516
|
+
dtype = value.dtype
|
|
517
|
+
|
|
518
|
+
# ------------------------------------------------------
|
|
519
|
+
# Validate SHAPE (avoid malformed / malicious metadata)
|
|
520
|
+
# ------------------------------------------------------
|
|
521
|
+
|
|
522
|
+
# No negative dimensions
|
|
523
|
+
if any(dim < 0 for dim in shape):
|
|
524
|
+
raise ValueError(
|
|
525
|
+
"Malformed HDF5 dataset shape encountered in .keras file; "
|
|
526
|
+
"negative dimension detected."
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
# Prevent absurdly high-rank tensors
|
|
530
|
+
if len(shape) > 64:
|
|
531
|
+
raise ValueError(
|
|
532
|
+
"Malformed HDF5 dataset shape encountered in .keras file; "
|
|
533
|
+
"tensor rank exceeds safety limit."
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
# Safe product computation (Python int is unbounded)
|
|
537
|
+
num_elems = int(np.prod(shape))
|
|
538
|
+
|
|
539
|
+
# ------------------------------------------------------
|
|
540
|
+
# Validate TOTAL memory size
|
|
541
|
+
# ------------------------------------------------------
|
|
542
|
+
MAX_BYTES = 1 << 32 # 4 GiB
|
|
543
|
+
|
|
544
|
+
size_bytes = num_elems * dtype.itemsize
|
|
545
|
+
|
|
546
|
+
if size_bytes > MAX_BYTES:
|
|
547
|
+
raise ValueError(
|
|
548
|
+
f"HDF5 dataset too large to load safely "
|
|
549
|
+
f"({size_bytes} bytes; limit is {MAX_BYTES})."
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
# ------------------------------------------------------
|
|
553
|
+
# SAFE — load dataset (guaranteed ≤ 4 GiB)
|
|
554
|
+
# ------------------------------------------------------
|
|
555
|
+
result[key] = value[()]
|
|
556
|
+
|
|
557
|
+
# ------------------------------------------------------
|
|
558
|
+
# Return final tree and metadata
|
|
559
|
+
# ------------------------------------------------------
|
|
485
560
|
return result, metadata
|
|
486
561
|
|
|
487
562
|
def _generate_filepath_info(self, rich_style=False):
|
keras/src/saving/saving_lib.py
CHANGED
|
@@ -943,7 +943,7 @@ class DiskIOStore:
|
|
|
943
943
|
if self.archive:
|
|
944
944
|
self.tmp_dir = get_temp_dir()
|
|
945
945
|
if self.mode == "r":
|
|
946
|
-
self.archive
|
|
946
|
+
file_utils.extract_open_archive(self.archive, self.tmp_dir)
|
|
947
947
|
self.working_dir = file_utils.join(
|
|
948
948
|
self.tmp_dir, self.root_path
|
|
949
949
|
).replace("\\", "/")
|
keras/src/testing/__init__.py
CHANGED
keras/src/testing/test_case.py
CHANGED
|
@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
40
40
|
self.addCleanup(lambda: shutil.rmtree(temp_dir))
|
|
41
41
|
return temp_dir
|
|
42
42
|
|
|
43
|
-
def assertAllClose(
|
|
43
|
+
def assertAllClose(
|
|
44
|
+
self,
|
|
45
|
+
x1,
|
|
46
|
+
x2,
|
|
47
|
+
atol=1e-6,
|
|
48
|
+
rtol=1e-6,
|
|
49
|
+
tpu_atol=None,
|
|
50
|
+
tpu_rtol=None,
|
|
51
|
+
msg=None,
|
|
52
|
+
):
|
|
53
|
+
if tpu_atol is not None and uses_tpu():
|
|
54
|
+
atol = tpu_atol
|
|
55
|
+
if tpu_rtol is not None and uses_tpu():
|
|
56
|
+
rtol = tpu_rtol
|
|
44
57
|
if not isinstance(x1, np.ndarray):
|
|
45
58
|
x1 = backend.convert_to_numpy(x1)
|
|
46
59
|
if not isinstance(x2, np.ndarray):
|
|
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
57
70
|
f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
|
|
58
71
|
)
|
|
59
72
|
|
|
60
|
-
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
|
|
73
|
+
def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
|
|
74
|
+
if tpu_decimal is not None and uses_tpu():
|
|
75
|
+
decimal = tpu_decimal
|
|
61
76
|
msg = msg or ""
|
|
62
77
|
if not isinstance(x1, np.ndarray):
|
|
63
78
|
x1 = backend.convert_to_numpy(x1)
|
|
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
195
210
|
run_training_check=True,
|
|
196
211
|
run_mixed_precision_check=True,
|
|
197
212
|
assert_built_after_instantiation=False,
|
|
213
|
+
tpu_atol=None,
|
|
214
|
+
tpu_rtol=None,
|
|
198
215
|
):
|
|
199
216
|
"""Run basic checks on a layer.
|
|
200
217
|
|
|
@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
376
393
|
msg="Unexpected number of torch_params",
|
|
377
394
|
)
|
|
378
395
|
|
|
379
|
-
def run_output_asserts(
|
|
396
|
+
def run_output_asserts(
|
|
397
|
+
layer, output, eager=False, tpu_atol=None, tpu_rtol=None
|
|
398
|
+
):
|
|
380
399
|
if expected_output_shape is not None:
|
|
381
400
|
|
|
382
401
|
def verify_shape(expected_shape, x):
|
|
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
422
441
|
tree.flatten(expected_output), tree.flatten(output)
|
|
423
442
|
):
|
|
424
443
|
self.assertAllClose(
|
|
425
|
-
ref_v,
|
|
444
|
+
ref_v,
|
|
445
|
+
v,
|
|
446
|
+
msg="Unexpected output value",
|
|
447
|
+
tpu_atol=tpu_atol,
|
|
448
|
+
tpu_rtol=tpu_rtol,
|
|
426
449
|
)
|
|
427
450
|
if expected_num_losses is not None:
|
|
428
451
|
self.assertLen(layer.losses, expected_num_losses)
|
|
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
551
574
|
output_data = layer(**input_data, **call_kwargs)
|
|
552
575
|
else:
|
|
553
576
|
output_data = layer(input_data, **call_kwargs)
|
|
554
|
-
run_output_asserts(
|
|
577
|
+
run_output_asserts(
|
|
578
|
+
layer,
|
|
579
|
+
output_data,
|
|
580
|
+
eager=True,
|
|
581
|
+
tpu_atol=tpu_atol,
|
|
582
|
+
tpu_rtol=tpu_rtol,
|
|
583
|
+
)
|
|
555
584
|
|
|
556
585
|
if run_training_check:
|
|
557
586
|
run_training_step(layer, input_data, output_data)
|
|
@@ -621,6 +650,17 @@ def uses_gpu():
|
|
|
621
650
|
return False
|
|
622
651
|
|
|
623
652
|
|
|
653
|
+
def uses_tpu():
|
|
654
|
+
# Condition used to skip tests when using the TPU
|
|
655
|
+
try:
|
|
656
|
+
devices = distribution.list_devices()
|
|
657
|
+
if any(d.startswith("tpu") for d in devices):
|
|
658
|
+
return True
|
|
659
|
+
except AttributeError:
|
|
660
|
+
return False
|
|
661
|
+
return False
|
|
662
|
+
|
|
663
|
+
|
|
624
664
|
def uses_cpu():
|
|
625
665
|
devices = distribution.list_devices()
|
|
626
666
|
if any(d.startswith("cpu") for d in devices):
|
|
@@ -148,6 +148,7 @@ class CompileMetrics(metrics_module.Metric):
|
|
|
148
148
|
self.built = False
|
|
149
149
|
self.name = "compile_metrics"
|
|
150
150
|
self.output_names = output_names
|
|
151
|
+
self._resolved_output_names = None
|
|
151
152
|
|
|
152
153
|
@property
|
|
153
154
|
def metrics(self):
|
|
@@ -175,10 +176,16 @@ class CompileMetrics(metrics_module.Metric):
|
|
|
175
176
|
|
|
176
177
|
def build(self, y_true, y_pred):
|
|
177
178
|
num_outputs = 1 # default
|
|
178
|
-
|
|
179
|
+
# Resolve output names. If y_pred is a dict, prefer its keys.
|
|
180
|
+
if isinstance(y_pred, dict):
|
|
181
|
+
keys = sorted(list(y_pred.keys()))
|
|
182
|
+
if self.output_names and set(self.output_names) == set(keys):
|
|
183
|
+
# If there is a perfect match, use the user-provided order.
|
|
184
|
+
output_names = self.output_names
|
|
185
|
+
else:
|
|
186
|
+
output_names = keys
|
|
187
|
+
elif self.output_names:
|
|
179
188
|
output_names = self.output_names
|
|
180
|
-
elif isinstance(y_pred, dict):
|
|
181
|
-
output_names = sorted(list(y_pred.keys()))
|
|
182
189
|
elif isinstance(y_pred, (list, tuple)):
|
|
183
190
|
num_outputs = len(y_pred)
|
|
184
191
|
if all(hasattr(x, "_keras_history") for x in y_pred):
|
|
@@ -187,6 +194,7 @@ class CompileMetrics(metrics_module.Metric):
|
|
|
187
194
|
output_names = None
|
|
188
195
|
else:
|
|
189
196
|
output_names = None
|
|
197
|
+
self._resolved_output_names = output_names
|
|
190
198
|
if output_names:
|
|
191
199
|
num_outputs = len(output_names)
|
|
192
200
|
|
|
@@ -316,9 +324,10 @@ class CompileMetrics(metrics_module.Metric):
|
|
|
316
324
|
return flat_metrics
|
|
317
325
|
|
|
318
326
|
def _flatten_y(self, y):
|
|
319
|
-
|
|
327
|
+
names = self._resolved_output_names
|
|
328
|
+
if isinstance(y, dict) and names:
|
|
320
329
|
result = []
|
|
321
|
-
for name in
|
|
330
|
+
for name in names:
|
|
322
331
|
if name in y:
|
|
323
332
|
result.append(y[name])
|
|
324
333
|
return result
|
keras/src/utils/backend_utils.py
CHANGED
|
@@ -3,6 +3,7 @@ import importlib
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
|
+
import warnings
|
|
6
7
|
|
|
7
8
|
from keras.src import backend as backend_module
|
|
8
9
|
from keras.src.api_export import keras_export
|
|
@@ -124,9 +125,22 @@ def set_backend(backend):
|
|
|
124
125
|
|
|
125
126
|
Example:
|
|
126
127
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
128
|
+
>>> import os
|
|
129
|
+
>>> os.environ["KERAS_BACKEND"] = "tensorflow"
|
|
130
|
+
>>>
|
|
131
|
+
>>> import keras
|
|
132
|
+
>>> from keras import ops
|
|
133
|
+
>>> type(ops.ones(()))
|
|
134
|
+
<class 'tensorflow.python.framework.ops.EagerTensor'>
|
|
135
|
+
>>>
|
|
136
|
+
>>> keras.config.set_backend("jax")
|
|
137
|
+
UserWarning: Using `keras.config.set_backend` is dangerous...
|
|
138
|
+
>>> del keras, ops
|
|
139
|
+
>>>
|
|
140
|
+
>>> import keras
|
|
141
|
+
>>> from keras import ops
|
|
142
|
+
>>> type(ops.ones(()))
|
|
143
|
+
<class 'jaxlib.xla_extension.ArrayImpl'>
|
|
130
144
|
|
|
131
145
|
⚠️ WARNING ⚠️: Using this function is dangerous and should be done
|
|
132
146
|
carefully. Changing the backend will **NOT** convert
|
|
@@ -138,7 +152,7 @@ def set_backend(backend):
|
|
|
138
152
|
|
|
139
153
|
This includes any function or class instance that uses any Keras
|
|
140
154
|
functionality. All such code needs to be re-executed after calling
|
|
141
|
-
`set_backend()
|
|
155
|
+
`set_backend()` and re-importing all imported `keras` modules.
|
|
142
156
|
"""
|
|
143
157
|
os.environ["KERAS_BACKEND"] = backend
|
|
144
158
|
# Clear module cache.
|
|
@@ -159,3 +173,16 @@ def set_backend(backend):
|
|
|
159
173
|
module_name = module_name[module_name.find("'") + 1 :]
|
|
160
174
|
module_name = module_name[: module_name.find("'")]
|
|
161
175
|
globals()[key] = importlib.import_module(module_name)
|
|
176
|
+
|
|
177
|
+
warnings.warn(
|
|
178
|
+
"Using `keras.config.set_backend` is dangerous and should be done "
|
|
179
|
+
"carefully. Already-instantiated objects will not be converted. Thus, "
|
|
180
|
+
"any layers / tensors / etc. already created will no longer be usable "
|
|
181
|
+
"without errors. It is strongly recommended not to keep around any "
|
|
182
|
+
"Keras-originated objects instances created before calling "
|
|
183
|
+
"`set_backend()`. This includes any function or class instance that "
|
|
184
|
+
"uses any Keras functionality. All such code needs to be re-executed "
|
|
185
|
+
"after calling `set_backend()` and re-importing all imported `keras` "
|
|
186
|
+
"modules.",
|
|
187
|
+
stacklevel=2,
|
|
188
|
+
)
|