keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/quantizers.py

@@ -9,6 +9,7 @@ from keras.src.backend import any_symbolic_tensors
 from keras.src.backend.common.backend_utils import canonicalize_axis
 from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
 from keras.src.ops.operation import Operation
+from keras.src.quantizers.gptq_config import GPTQConfig

 """Int8-related classes and methods"""

@@ -72,6 +73,23 @@ def abs_max_quantize(
     epsilon=backend.epsilon(),
     to_numpy=False,
 ):
+    """
+    Quantizes the input tensor using the absolute maximum quantization scheme.
+
+    Args:
+        inputs: Input tensor to quantize.
+        axis: Axis along which to compute the quantization range.
+        value_range: Tuple of the minimum and maximum values of the quantization
+            range.
+        dtype: Data type of the quantized output.
+        epsilon: Small value to avoid division by zero.
+        to_numpy: Whether to perform the quantization in numpy. This performs
+            the computation on the host CPU and can be useful for saving memory
+            on the device. If False, the computation is performed on the device.
+
+    Returns:
+        A tuple of the quantized tensor and the scale.
+    """
     if to_numpy:
         # Save memory on the device using numpy
         original_dtype = backend.standardize_dtype(inputs.dtype)

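
To illustrate the documented call shape, a minimal sketch using the public `keras.quantizers` export (argument names and defaults as listed in the new docstring above; this is an illustration, not part of the diff):

```python
import numpy as np
from keras import quantizers

x = np.random.uniform(-1.0, 1.0, size=(4, 8)).astype("float32")
# Quantize to int8 along the last axis; returns the quantized tensor and the
# scale used, as described in the docstring above.
quantized, scale = quantizers.abs_max_quantize(x, axis=-1)
```
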
@@ -104,31 +122,69 @@ def abs_max_quantize(
 class AbsMaxQuantizer(Quantizer):
     def __init__(
         self,
-        axis,
+        axis=None,  # Deprecated, provide axis in __call__ instead.
         value_range=(-127, 127),
         epsilon=backend.epsilon(),
         output_dtype="int8",
     ):
         Quantizer.__init__(self, output_dtype=output_dtype)
-        if isinstance(axis, int):
-            axis = (axis,)
-        self.axis = tuple(axis)
+        if axis is not None:
+            if isinstance(axis, int):
+                axis = (axis,)
+            self.axis = tuple(axis)
+        else:
+            self.axis = None
         self.value_range = value_range
         self.epsilon = epsilon
+        if output_dtype == "int8":
+            if value_range[0] < -128 or value_range[1] > 127:
+                raise ValueError(
+                    f"Quantizer with output_dtype='int8' requires value_range "
+                    f"to be within the interval [-128, 127]. Received: "
+                    f"value_range={value_range}"
+                )

-    def __call__(self, x):
+    def __call__(self, x, axis=None, to_numpy=False):
+        """
+        Quantizes the input tensor.
+
+        Args:
+            x: Input tensor to quantize.
+            axis: Axis along which to compute the quantization range. If None,
+                uses the axis specified in the constructor. If None and no axis
+                was specified in the constructor, defaults to -1.
+            to_numpy: Whether to perform the quantization in numpy. This
+                performs the computation on the host CPU and can be useful for
+                saving memory on the device. If False, the computation is
+                performed on the device.
+
+        Returns:
+            A tuple of the quantized tensor and the scale.
+        """
+        if axis is None:
+            axis = self.axis
+        if axis is None:
+            # Default to -1 if no axis is specified
+            axis = -1
         quantized_x, scale = abs_max_quantize(
-            x, self.axis, self.value_range, self.output_dtype, self.epsilon
+            x,
+            axis,
+            self.value_range,
+            self.output_dtype,
+            self.epsilon,
+            to_numpy,
         )
         return quantized_x, scale

     def get_config(self):
-        return {
-            "axis": self.axis,
+        config = {
             "value_range": self.value_range,
             "epsilon": self.epsilon,
             "output_dtype": self.output_dtype,
         }
+        if self.axis is not None:
+            config["axis"] = self.axis
+        return config


 def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):

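
A hedged usage sketch of the constructor/`__call__` change above: `axis` moves from construction time to call time (illustration only, not part of the diff):

```python
import numpy as np
from keras import quantizers

x = np.random.normal(size=(2, 16)).astype("float32")

quantizer = quantizers.AbsMaxQuantizer()  # axis no longer required here
q, scale = quantizer(x, axis=-1)          # axis is now chosen per call
# The old style, AbsMaxQuantizer(axis=-1), still works but is deprecated.
```
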
@@ -280,7 +336,7 @@ def fake_quant_with_min_max_vars(
             ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
         )
         x_clamped = ops.clip(
-            x, nudged_min, nudged_max
+            ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
         )
         x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
         result = ops.multiply(

@@ -317,6 +373,7 @@ def fake_quant_with_min_max_vars(
             grad_min = ops.sum(grad_min, axis=axes)
         else:
             grad_min = ops.sum(grad_min)
+        grad_min = ops.reshape(grad_min, ops.shape(min_val))

         # Gradient for max_val
         # When x is clipped to max, the gradient flows to max_val

@@ -326,6 +383,7 @@ def fake_quant_with_min_max_vars(
             grad_max = ops.sum(grad_max, axis=axes)
         else:
             grad_max = ops.sum(grad_max)
+        grad_max = ops.reshape(grad_max, ops.shape(max_val))

         return dx, grad_min, grad_max

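
For context, the two hunks above only adjust how the custom gradient of `fake_quant_with_min_max_vars` reshapes `grad_min`/`grad_max`. A rough call sketch of the public op follows; the positional argument order is assumed from the Keras 3 API and is not shown in this diff:

```python
import numpy as np
from keras import quantizers

x = np.linspace(-1.0, 1.0, 8).astype("float32")
# Simulate 8-bit quantization of x into the range [-0.5, 0.5]; gradients also
# flow to the range limits, which is what the reshape fix above touches.
y = quantizers.fake_quant_with_min_max_vars(x, -0.5, 0.5, 8)
```
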
@@ -377,7 +435,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):


 @keras_export("keras.quantizers.pack_int4")
-def pack_int4(arr, axis=0):
+def pack_int4(arr, axis=0, dtype="int8"):
     """Pack an int4 tensor into an int8 tensor with packed nibbles.

     The input values must already be int8 in the signed range `[-8, 7]` and

@@ -389,8 +447,11 @@ def pack_int4(arr, axis=0):
     the value from the second row.

     Args:
-        arr: An int8 tensor containing int4 values in the range `[-8, 7]`.
+        arr: An `int8` or `uint8` tensor containing int4 values in the range
+            `[-8, 7]`.
         axis: The axis along which to pack the tensor. Defaults to 0.
+        dtype: The data type of the input and packed tensor. Can be
+            `"int8"` or `"uint8"`. Defaults to `"int8"`.

     Returns:
         tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is

@@ -450,9 +511,14 @@ def pack_int4(arr, axis=0):
     True
     ```
     """
-    if backend.standardize_dtype(arr.dtype) != "int8":
+    if dtype not in ("int8", "uint8"):
+        raise ValueError(
+            f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+        )
+    if backend.standardize_dtype(arr.dtype) != dtype:
         raise TypeError(
-            "Expected int8 tensor for packing, got " f"{arr.dtype}"
+            f"Expected {dtype} tensor for packing, got "
+            f"{backend.standardize_dtype(arr.dtype)}."
         )

     rank = getattr(arr.shape, "rank", None) or len(arr.shape)

@@ -486,12 +552,12 @@ def pack_int4(arr, axis=0):
     low = padded[::2, ...]
     high = padded[1::2, ...]

-    mask = ops.array(0x0F, dtype="int8")
+    mask = ops.array(0x0F, dtype=dtype)
     low_u = ops.bitwise_and(low, mask)
     high_u = ops.bitwise_and(high, mask)

     packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
-    packed = ops.cast(packed, "int8")
+    packed = ops.cast(packed, dtype)

     # 5-6. Restore shape.
     packed = ops.transpose(packed, inv_perm)  # back to original order

@@ -500,7 +566,7 @@ def pack_int4(arr, axis=0):


 @keras_export("keras.quantizers.unpack_int4")
-def unpack_int4(packed, orig_len, axis=0):
+def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
     """Unpack a packed int4 back to an int8 tensor in the range [-8, 7].

     This function reverses the packing performed by `pack_int4`, restoring

@@ -518,6 +584,8 @@ def unpack_int4(packed, orig_len, axis=0):
             packed. This is used to remove any padding that may have
             been added during packing to ensure an even number of rows.
         axis: The axis along which the tensor was packed. Defaults to 0.
+        dtype: The data type of the input and unpacked tensor. Can be
+            `"int8"` or `"uint8"`. Defaults to `"int8"`.

     Returns:
         unpacked: An int8 tensor with the same shape as the original

@@ -574,13 +642,27 @@ def unpack_int4(packed, orig_len, axis=0):
     True
     ```
     """
-    if backend.standardize_dtype(packed.dtype) != "int8":
+    if dtype not in ("int8", "uint8"):
+        raise ValueError(
+            f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+        )
+
+    if backend.standardize_dtype(packed.dtype) not in ("int8", "uint8"):
         raise TypeError(
-            f"Expected int8 tensor for unpacking, got {packed.dtype}"
+            f"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}"
         )

-    rank = getattr(packed.shape, "rank", None) or len(packed.shape)
+    def to_signed(x):
+        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+        Uses a branchless XOR approach: (x ^ 8) - 8
+        This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+        """
+        dtype_x = backend.standardize_dtype(x.dtype)
+        eight = ops.cast(8, dtype_x)
+        return ops.subtract(ops.bitwise_xor(x, eight), eight)

+    rank = getattr(packed.shape, "rank", None) or len(packed.shape)
     if axis < 0:
         axis += rank

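
A quick arithmetic check of the branchless mapping documented in `to_signed` above (plain Python, illustration only):

```python
# (x ^ 8) - 8 maps the unsigned nibbles 0..15 onto the signed int4 range.
for x in range(16):
    print(x, "->", (x ^ 8) - 8)
# 0..7 map to themselves; 8..15 map to -8..-1, matching the docstring above.
```
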
@@ -591,16 +673,15 @@ def unpack_int4(packed, orig_len, axis=0):
     low_unpacked = ops.bitwise_and(packed, mask)
     high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)

-
-
-
-
-
-
-    )
+    if dtype == "int8":
+        low_unpacked = to_signed(low_unpacked)
+        high_unpacked = to_signed(high_unpacked)
+
+    low_final = ops.cast(low_unpacked, dtype)
+    high_final = ops.cast(high_unpacked, dtype)

     # Interleave and reshape
-    stacked = ops.stack([
+    stacked = ops.stack([low_final, high_final], axis=1)
     unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))

     # Remove padding and return

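
Putting the `pack_int4`/`unpack_int4` changes together, a hedged round-trip sketch via the public `keras.quantizers` exports (the new `dtype` argument defaults to the previous int8 behavior; illustration only):

```python
import numpy as np
from keras import quantizers

values = np.array([[-8, -1, 0, 7], [1, 2, 3, 4]], dtype="int8")

# Two int4 values per byte; `orig_rows` records the unpadded length.
packed, packed_shape, orig_rows = quantizers.pack_int4(values, axis=0)
restored = quantizers.unpack_int4(packed, orig_rows, axis=0)
# `restored` should match `values`. Passing dtype="uint8" to both calls
# selects the new unsigned storage path added in this release.
```
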
@@ -612,25 +693,313 @@ def unpack_int4(packed, orig_len, axis=0):
     transposed = ops.transpose(packed, perm)

     # 1. Split nibbles.
-    mask = ops.array(0x0F, dtype="int8")
+    mask = ops.array(0x0F, dtype=packed.dtype)
     low = ops.bitwise_and(transposed, mask)
     high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)

-
-
-
-
-        return ops.where(x < eight, x, x - sixteen)
+    # 2. Conditionally convert to signed.
+    if dtype == "int8":
+        low = to_signed(low)
+        high = to_signed(high)

-    low =
-    high =
+    low = ops.cast(low, dtype)
+    high = ops.cast(high, dtype)

-    # Interleave and reshape
-    stacked = ops.stack([low, high], axis=1)
+    # 3. Interleave and reshape.
+    stacked = ops.stack([low, high], axis=1)
     unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))

     # 4. Remove padding and restore original layout.
     unpacked = unpacked[:orig_len, ...]
     unpacked = ops.transpose(unpacked, inv_perm)

-    return unpacked
+    return unpacked
+
+
+class GPTQQuantizer(Quantizer):
+    """A class that handles the quantization of weights using GPTQ method.
+
+    This class provides methods to find quantization parameters (scale and zero)
+    for a given tensor and can be used to quantize weights in a GPTQ context.
+
+    Args:
+        weight_bits: (int) The number of bits to quantize to (e.g., 4).
+        per_channel: (bool) A flag indicating whether quantization is
+            applied per-channel (`True`) or per-tensor (`False`).
+            Defaults to `False`.
+        symmetric: (bool) A flag indicating whether symmetric (`True`) or
+            asymmetric (`False`) quantization is used. Defaults to `False`.
+        group_size: (int) The size of weight groups for quantization. A
+            value of -1 indicates that grouping is not used.
+            Defaults to -1.
+    """
+
+    def __init__(
+        self,
+        config=GPTQConfig(tokenizer=None, dataset=None),
+        compute_dtype="float32",
+    ):
+        Quantizer.__init__(self)
+        self.weight_bits = config.weight_bits
+        self.per_channel = config.per_channel
+        self.symmetric = config.symmetric
+        self.group_size = config.group_size
+        self.compute_dtype = compute_dtype
+
+        # These are now determined later by `find_params`
+        self.scale = None
+        self.zero = None
+        self.maxq = None
+
+    def find_params(self, input_tensor):
+        """Finds quantization parameters (scale and zero) for a given tensor."""
+        self.scale, self.zero, self.maxq = compute_quantization_parameters(
+            input_tensor,
+            bits=self.weight_bits,
+            symmetric=self.symmetric,
+            per_channel=self.per_channel,
+            group_size=self.group_size,
+            compute_dtype=self.compute_dtype,
+        )
+        return self.scale, self.zero, self.maxq
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "weight_bits": self.weight_bits,
+                "per_channel": self.per_channel,
+                "symmetric": self.symmetric,
+                "group_size": self.group_size,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        gptq = GPTQConfig(
+            tokenizer=None,
+            dataset=None,
+            weight_bits=config["weight_bits"],
+            per_channel=config["per_channel"],
+            symmetric=config["symmetric"],
+            group_size=config["group_size"],
+        )
+        return cls(gptq)
+
+
+def compute_quantization_parameters(
+    x,
+    *,
+    bits,
+    symmetric=False,
+    per_channel=False,
+    group_size=-1,
+    compute_dtype="float32",
+):
+    """
+    Computes the scale and zero-point for quantizing weight tensors.
+
+    This function calculates the scale and zero-point required for quantizing
+    a given weight tensor `x` based on the specified parameters. It supports
+    grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+    For grouped quantization (per_channel=True, group_size > 0), the output
+    shapes are [out_features, n_groups] where n_groups is the number of groups
+    along the in_features dimension.
+
+    Args:
+        x: KerasTensor. The weight tensor to quantize with shape
+            [out_features, in_features].
+        bits: int. The number of bits to quantize to (e.g., 4).
+        symmetric: bool. Whether to use symmetric quantization.
+        per_channel: bool. Whether to quantize per channel.
+        group_size: int. The group size for quantization. -1 means no grouping.
+        compute_dtype: str. The dtype for computation. Defaults to "float32".
+
+    Returns:
+        scale: KerasTensor. The scale tensor for quantization.
+        zero: KerasTensor. The zero tensor for quantization.
+        maxq: scalar. The maximum quantization value.
+    """
+    # Input validation
+    if x is None:
+        raise ValueError(f"Input tensor {x} cannot be None.")
+    if len(x.shape) < 2:
+        raise ValueError(
+            f"Input weight tensor {x} must have a rank of at "
+            f"least 2, but got rank {len(x.shape)}."
+        )
+    if ops.size(x) == 0:
+        raise ValueError("Input tensor 'x' cannot be empty.")
+
+    out_features, in_features = x.shape[0], x.shape[1]
+
+    # Determine number of groups for quantization
+    if per_channel and group_size > 0:
+        n_groups = (in_features + group_size - 1) // group_size
+    else:
+        n_groups = 1
+
+    # Compute min/max values based on quantization mode
+    if n_groups > 1:
+        # Grouped quantization: output shape [out_features, n_groups]
+        remainder = in_features % group_size
+        if remainder != 0:
+            pad_size = group_size - remainder
+            x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+        x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+        min_values = ops.min(x_grouped, axis=2)
+        max_values = ops.max(x_grouped, axis=2)
+    else:
+        # Per-channel or per-tensor: compute stats along rows
+        reduction_shape = [out_features, -1] if per_channel else [1, -1]
+        x_reshaped = ops.reshape(x, reduction_shape)
+        min_values = ops.min(x_reshaped, axis=1)
+        max_values = ops.max(x_reshaped, axis=1)
+
+    # Symmetric quantization: make range symmetric around zero
+    if symmetric:
+        max_abs = ops.maximum(ops.abs(min_values), max_values)
+        min_values = ops.where(
+            ops.less(min_values, 0), ops.negative(max_abs), min_values
+        )
+        max_values = max_abs
+
+    # Ensure non-zero range to avoid division errors
+    zero_range = ops.equal(min_values, max_values)
+    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
+    max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
+
+    # Compute scale and zero-point
+    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
+    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
+    if symmetric:
+        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
+    else:
+        zero = ops.round(ops.divide(ops.negative(min_values), scale))
+
+    # Reshape output to [out_features, n_groups] or [out_features, 1]
+    if n_groups > 1:
+        pass  # Already [out_features, n_groups]
+    elif per_channel:
+        scale = ops.reshape(scale, [-1, 1])
+        zero = ops.reshape(zero, [-1, 1])
+    else:
+        # Per-tensor: tile single value to [out_features, 1]
+        scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+        zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
+
+    return scale, ops.cast(zero, "uint8"), maxq
+
+
+def quantize_with_zero_point(input_tensor, scale, zero, maxq):
+    """Quantize a float tensor into discrete levels [0, maxq] using
+    per-tensor/per-channel/grouped scaling.
+
+    Returns `q` (same dtype as inputs/scales; float is fine) where values are in
+    [0, maxq].
+
+    Args:
+        input_tensor: KerasTensor. The input tensor to quantize.
+        scale: KerasTensor. The scale tensor for quantization.
+        zero: KerasTensor. The zero tensor for quantization.
+        maxq: KerasTensor. The maximum quantization value.
+
+    Returns:
+        KerasTensor. The quantized tensor.
+    """
+    # Guard against divide-by-zero
+    epsilon = ops.cast(1e-8, dtype=scale.dtype)
+    safe_scale = ops.where(ops.equal(scale, 0), epsilon, scale)
+
+    quantized_tensor = ops.round(
+        ops.add(
+            ops.divide(input_tensor, safe_scale), ops.cast(zero, scale.dtype)
+        )
+    )
+    quantized_tensor = ops.clip(quantized_tensor, 0, maxq)
+    return quantized_tensor
+
+
+def dequantize_with_zero_point(input_tensor, scale, zero):
+    """
+    Dequantizes a quantized tensor using the provided scale and zero tensors.
+
+    Args:
+        input_tensor: KerasTensor. The quantized tensor to dequantize.
+        scale: KerasTensor. The scale tensor for dequantization.
+        zero: KerasTensor. The zero tensor for dequantization.
+
+    Returns:
+        KerasTensor. The dequantized tensor.
+    """
+    return ops.multiply(
+        scale, ops.subtract(input_tensor, ops.cast(zero, scale.dtype))
+    )
+
+
+def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq):
+    """Quantize the weight matrix from group params.
+
+    This function uses the provided scale and zero tensors to quantize the
+    input weights_matrix according to the group indices. It maps each column
+    of the weights_matrix to its corresponding group parameters and performs
+    the quantization operation.
+
+    Args:
+        weights_matrix: 2D tensor of shape [out_features, in_features].
+        scale: Per-group scale tensor of shape [out_features, n_groups].
+        zero: Per-group zero-point tensor of shape [out_features, n_groups].
+        g_idx: Integer tensor of shape [in_features,] mapping each column to
+            its group index.
+        maxq: Scalar (float) representing the maximum integer quantization
+            level (e.g., 2^bits - 1).
+
+    Returns:
+        A tensor with the same shape as `weights_matrix` containing the
+        quantized weights produced using the provided group parameters.
+    """
+    groups = ops.cast(g_idx, "int32")
+    scale_cols = ops.take(scale, groups, axis=1)  # [out_features, in_features]
+    zero_cols = ops.take(zero, groups, axis=1)  # [out_features, in_features]
+
+    # Quantize elementwise, then cast to int
+    return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)
+
+
+def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
+    """Rebuild a dequantized weight matrix from group params.
+
+    This function uses the provided scale and zero tensors to dequantize the
+    input weights_matrix according to the group indices. It maps each column
+    of the weights_matrix to its corresponding group parameters and performs
+    the dequantization operation.
+
+    Args:
+        weights_matrix: 2D tensor of shape [out_features, in_features].
+        scale: Per-group scale tensor of shape [out_features, n_groups].
+        zero: Per-group zero-point tensor of shape [out_features, n_groups].
+        g_idx: Integer tensor of shape [in_features,] mapping each column to
+            its group index.
+        maxq: Scalar (float) representing the maximum integer quantization
+            level (e.g., 2^bits - 1).
+
+    Returns:
+        A tensor with the same shape as `weights_matrix` containing the
+        dequantized weights produced using the provided group parameters.
+    """
+    # Map group indices to scales and zeros
+    groups = ops.cast(g_idx, "int32")
+    scales_mapped = ops.take(scale, groups, axis=1)
+    zeros_mapped = ops.take(zero, groups, axis=1)
+    zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)
+
+    quantized = ops.multiply(
+        ops.subtract(weights_matrix, zeros_mapped), scales_mapped
+    )
+
+    return quantized

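
A hedged sketch of how the new GPTQ helpers above fit together. These functions live in the private `keras.src.quantizers.quantizers` module shown in this diff, so the import path is internal and may change:

```python
import numpy as np
from keras.src.quantizers.quantizers import (
    compute_quantization_parameters,
    dequantize_with_zero_point,
    quantize_with_zero_point,
)

weights = np.random.normal(size=(8, 32)).astype("float32")

# Per-channel, asymmetric 4-bit parameters: scale and zero have shape [8, 1].
scale, zero, maxq = compute_quantization_parameters(
    weights, bits=4, per_channel=True, symmetric=False
)
q = quantize_with_zero_point(weights, scale, zero, maxq)  # levels in [0, maxq]
w_approx = dequantize_with_zero_point(q, scale, zero)  # reconstructed weights
```
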
keras/src/quantizers/utils.py (new file)

@@ -0,0 +1,23 @@
+import re
+
+
+def should_quantize_layer(layer, filters):
+    """Determines if a layer should be quantized based on filters.
+
+    Args:
+        layer: The layer to check.
+        filters: A regex string, a list of regex strings, or a callable.
+            If None, returns True.
+
+    Returns:
+        True if the layer should be quantized, False otherwise.
+    """
+    if filters is None:
+        return True
+    if isinstance(filters, str):
+        return bool(re.search(filters, layer.name))
+    if isinstance(filters, (list, tuple)):
+        return any(re.search(pat, layer.name) for pat in filters)
+    if callable(filters):
+        return filters(layer)
+    return True

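
An illustrative sketch of the new filtering helper (again a private module; the layer name and patterns below are made up for the example):

```python
import keras
from keras.src.quantizers.utils import should_quantize_layer

layer = keras.layers.Dense(4, name="decoder_ffn_dense")

should_quantize_layer(layer, None)                   # True: no filter given
should_quantize_layer(layer, r"ffn")                 # True: regex matches name
should_quantize_layer(layer, [r"^encoder", r"ffn"])  # True: any pattern matches
should_quantize_layer(layer, lambda l: isinstance(l, keras.layers.Dense))
```
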
keras/src/random/seed_generator.py

@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
 from keras.src.utils import jax_utils
 from keras.src.utils.naming import auto_name

+GLOBAL_SEED_GENERATOR = "global_seed_generator"
+

 @keras_export("keras.random.SeedGenerator")
 class SeedGenerator:

@@ -27,7 +29,7 @@ class SeedGenerator:
     a local `StateGenerator` with either a deterministic or random initial
     state.

-    Remark concerning the JAX
+    Remark concerning the JAX backend: Note that the use of a local
     `StateGenerator` as seed argument is required for JIT compilation of
     RNG with the JAX backend, because the use of global state is not
     supported.

@@ -109,7 +111,7 @@ class SeedGenerator:
         return new_seed_value

     def get_config(self):
-        return {"seed": self._initial_seed}
+        return {"seed": self._initial_seed, "name": self.name}

     @classmethod
     def from_config(cls, config):

@@ -133,10 +135,10 @@ def global_seed_generator():
             "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
             "```"
         )
-    gen = global_state.get_global_attribute("global_seed_generator")
+    gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
     if gen is None:
         gen = SeedGenerator()
-        global_state.set_global_attribute("global_seed_generator", gen)
+        global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
     return gen

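
Finally, a small sketch of what the `SeedGenerator.get_config` change means in practice (the `name` argument already existed; it is now round-tripped through the config):

```python
from keras.random import SeedGenerator

gen = SeedGenerator(seed=42, name="dropout_seed")
config = gen.get_config()  # now {"seed": 42, "name": "dropout_seed"}
restored = SeedGenerator.from_config(config)
assert restored.name == "dropout_seed"
```
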