keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py
CHANGED
@@ -1,3 +1,5 @@
+import math
+
 import ml_dtypes
 
 from keras.src import activations
@@ -9,6 +11,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.Dense")
@@ -20,7 +25,9 @@ class Dense(Layer):
     where `activation` is the element-wise activation function
     passed as the `activation` argument, `kernel` is a weights matrix
     created by the layer, and `bias` is a bias vector created by the layer
-    (only applicable if `use_bias` is `True`).
+    (only applicable if `use_bias` is `True`). When this layer is
+    followed by a `BatchNormalization` layer, it is recommended to set
+    `use_bias=False` as `BatchNormalization` has its own bias term.
 
     Note: If the input to the layer has a rank greater than 2, `Dense`
     computes the dot product between the `inputs` and the `kernel` along the
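The new docstring sentence describes a standard pattern rather than new behavior. A minimal sketch of what it recommends, using only ordinary Keras layers (nothing specific to this release):

    import keras

    inputs = keras.Input(shape=(16,))
    # BatchNormalization adds its own learned offset (beta), so Dense's bias
    # would be redundant and is disabled here.
    x = keras.layers.Dense(32, use_bias=False)(inputs)
    x = keras.layers.BatchNormalization()(x)
    outputs = keras.layers.Activation("relu")(x)
    model = keras.Model(inputs, outputs)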
@@ -87,8 +94,15 @@ class Dense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
+        if not isinstance(units, int) or units <= 0:
+            raise ValueError(
+                "Received an invalid value for `units`, expected a positive "
+                f"integer. Received: units={units}"
+            )
+
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.units = units
         self.activation = activations.get(activation)
@@ -102,14 +116,19 @@ class Dense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True
 
     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode not in ("int8", "int4"):
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -137,15 +156,58 @@ class Dense(Layer):
 
     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
+        is_int4 = mode == "int4"
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+            # Handle int4 unpacking cases in one place
+            if is_int4:
+                kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+            elif is_gptq and gptq_calibrated and gptq_bits == 4:
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+            elif is_awq and awq_calibrated:
+                # AWQ always uses 4-bit quantization
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+
+        # Apply LoRA once at the end.
         if self.lora_enabled:
-            return self._kernel + (
-                self.lora_alpha / self.lora_rank
-            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
-        return self._kernel
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
+            )
+
+        return kernel
 
     def call(self, inputs, training=None):
         x = ops.matmul(inputs, self.kernel)
@@ -181,6 +243,10 @@ class Dense(Layer):
             raise ValueError(
                 "lora is already enabled. This can only be done once per layer."
             )
+        if self.quantization_mode == "gptq":
+            raise NotImplementedError(
+                "lora is not currently supported with GPTQ quantization."
+            )
         self._tracker.unlock()
         # Determine the correct input dimension for the LoRA A matrix. When
         # the layer has been int4-quantized, `self._kernel` stores a *packed*
@@ -217,26 +283,26 @@ class Dense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        [16 lines removed; content not captured in this diff view]
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
+        # for None/gptq)
+        kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -244,25 +310,23 @@ class Dense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        [14 lines removed; content not captured in this diff view]
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -283,61 +347,97 @@ class Dense(Layer):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
+
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
-    [1 line removed; content not captured in this diff view]
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "float8": [
+                "kernel",
+                "bias",
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
+        }
 
-    def quantized_build(self, kernel_shape, mode):
+    def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
+        elif mode == "gptq":
+            self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
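The `variable_serialization_spec` property introduced above determines the numeric keys under which weights are written and read back. A standalone sketch of that ordering rule (plain Python; `spec` and `variables` are stand-in dicts for illustration, not part of the Keras API):

    spec = {"int8": ["kernel", "bias", "kernel_scale"]}
    variables = {"kernel": "kernel-tensor", "bias": None, "kernel_scale": "scale-tensor"}

    store = {}
    idx = 0
    for name in spec["int8"]:
        if name == "bias" and variables[name] is None:
            continue  # an absent bias is skipped and does not consume an index
        store[str(idx)] = variables[name]
        idx += 1

    assert store == {"0": "kernel-tensor", "1": "scale-tensor"}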
@@ -352,7 +452,182 @@ class Dense(Layer):
             trainable=False,
         )
 
-    def _int4_build(self, kernel_shape):
+    def _gptq_build(self, kernel_shape, config):
+        from keras.src.quantizers import gptq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_gptq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        units = (
+            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]
+        )
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = (
+            1
+            if group_size == -1
+            else math.ceil(self.kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, n_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, n_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(self.kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
+        if not self.is_gptq_calibrated:
+            W = self._kernel
+        else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
+            W = (
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
         `kernel_shape` is the *original* float32 kernel shape
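Both `_gptq_call` and `_awq_call` unpack 4-bit codes and then dequantize them with `dequantize_with_sz_map` before the matmul. That helper is not shown in this diff; the sketch below assumes the conventional group-wise affine rule w = (q - zero) * scale, with `g_idx` mapping each input feature to its group, and mirrors the transpose-and-divide steps of `_awq_call`:

    import numpy as np

    def dequantize_with_sz_map_sketch(q, scale, zero, g_idx):
        # q: (units, in_features) unsigned 4-bit codes
        # scale, zero: (units, n_groups); g_idx: (in_features,) group id per input feature
        g = g_idx.astype(int)
        return (q.astype(np.float32) - zero[:, g]) * scale[:, g]

    units, in_features, n_groups = 4, 8, 2
    q = np.random.randint(0, 16, size=(units, in_features))
    scale = np.random.rand(units, n_groups).astype(np.float32)
    zero = np.random.randint(0, 16, size=(units, n_groups))
    g_idx = np.repeat(np.arange(n_groups), in_features // n_groups)

    w = dequantize_with_sz_map_sketch(q, scale, zero, g_idx)
    # As in `_awq_call`: transpose to (in_features, units), then undo the
    # per-input-channel AWQ scaling that was applied before quantization.
    awq_scales = np.random.rand(in_features).astype(np.float32) + 0.5
    w = w.T / awq_scales[:, None]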
@@ -361,8 +636,10 @@ class Dense(Layer):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
@@ -451,11 +728,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)
 
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn
 
         x = matmul_with_inputs_gradient(
@@ -502,10 +783,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)
 
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn
 
         x = matmul_with_inputs_gradient(
@@ -617,30 +903,37 @@ class Dense(Layer):
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
@@ -648,10 +941,14 @@ class Dense(Layer):
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "gptq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -661,7 +958,14 @@ class Dense(Layer):
         if self.dtype_policy.quantization_mode is None:
             from keras.src import dtype_policies  # local import to avoid cycle
 
-            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            policy_name = mode
+            if mode == "gptq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "awq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
 
     def _get_kernel_with_merged_lora(self):
@@ -693,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode is None:
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel
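Overall, the `Dense` changes keep the existing post-training quantization entry point (`quantize()`) and thread an optional configuration object through `__init__`, `build()`, and `quantized_build()`. A minimal usage sketch that stays on the long-standing int8 path and leaves the new `config`/`quantization_config` arguments at their defaults (their classes live in the new `keras/src/quantizers/quantization_config.py`, which is not shown in this excerpt):

    import numpy as np
    import keras

    layer = keras.layers.Dense(8)
    layer.build((None, 4))

    # Post-training int8 quantization: the float kernel is replaced by an
    # int8 kernel plus a per-output-channel scale, as built above.
    layer.quantize("int8")

    y = layer(np.random.rand(2, 4).astype("float32"))
    print(y.shape)  # (2, 8)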