keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/optimizers/muon.py
CHANGED

@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:
 
-    - For any variable that isn't 2D,
+    - For any variable that isn't 2D, the AdamW step
       will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
       to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
            that takes no arguments and returns the actual value to use.
            The exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
-        adam_beta_2: A float value or a constant float tensor,
+        adam_beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use.
            The exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
        epsilon: A small constant for numerical stability. This is
            "epsilon hat" in the Kingma and Ba paper
            (in the formula just before Section 2.1),
@@ -67,11 +69,15 @@ class Muon(optimizer.Optimizer):
            It is recommended to use the default value
        adam_lr_ratio: Float, the ratio of the learning rate when
            using Adam to the main learning rate.
-
+            It is recommended to set it to 1
        momentum: Float, momentum used by internal SGD.
        ns_steps: Integer, number of Newton-Schulz iterations to run.
        nesterov: Boolean, whether to use Nesterov-style momentum
        {{base_optimizer_keyword_args}}
+        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+            that can enhance the stability of Muon, allowing it to use the
+            same learning rate and weight decay as Adam. Defaults to `0.2`.
+            Set to `None` to disable this feature.
    """
 
    def __init__(
@@ -79,8 +85,9 @@ class Muon(optimizer.Optimizer):
        learning_rate=0.001,
        adam_beta_1=0.9,
        adam_beta_2=0.999,
+        adam_weight_decay=0.004,
        epsilon=1e-7,
-        weight_decay=0.
+        weight_decay=0.004,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
@@ -95,10 +102,11 @@ class Muon(optimizer.Optimizer):
        muon_a=3.4445,
        muon_b=-4.7750,
        muon_c=2.0315,
-        adam_lr_ratio=
+        adam_lr_ratio=1,
        momentum=0.95,
-        ns_steps=
+        ns_steps=5,
        nesterov=True,
+        rms_rate=0.2,
        **kwargs,
    ):
        super().__init__(
@@ -127,12 +135,13 @@ class Muon(optimizer.Optimizer):
        self.nesterov = nesterov
        self.exclude_embeddings = exclude_embeddings
        self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate
 
    def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
        # it works well to just flatten their last 3 dimensions.
        # any {0,1}-D parameters should all be optimized by adam
-        if
+        if len(variable.shape) != 2:
            return True
        if self.exclude_embeddings and "embedding" in variable.path.lower():
            return True
@@ -153,52 +162,50 @@ class Muon(optimizer.Optimizer):
        if self.built:
            return
        super().build(var_list)
-
-        self.
-
-        self.
-        self.muon_velocities = {}
+        # Momentums are for both Muon and Adam
+        self.momentums = [None] * len(var_list)
+        # Velocities are just for Adam
+        self.adam_velocities = [None] * len(var_list)
 
        for var in var_list:
            if not self._overwrite_variable_with_gradient(var):
-                self.
+                self.momentums[self._get_variable_index(var)] = (
                    self.add_variable_from_reference(
                        reference_variable=var, name="momentum"
                    )
                )
                if self._should_use_adamw(var):
-                    self.adam_velocities[var
+                    self.adam_velocities[self._get_variable_index(var)] = (
                        self.add_variable_from_reference(
                            reference_variable=var, name="velocity"
                        )
                    )
 
    def update_step(self, gradient, variable, learning_rate):
-
+        variable_index = self._get_variable_index(variable)
+        m = self.momentums[variable_index]
+        v = self.adam_velocities[variable_index]
+
+        # The presence of the velocity tells us that this variable is for Adam
+        if v is not None:
            # It should be noted that lr is one-tenth when using adamw.
            self._adamw_update_step(
-                gradient, variable, learning_rate * self.adam_lr_ratio
+                gradient, variable, learning_rate * self.adam_lr_ratio, m, v
            )
        else:
-            self._muon_update_step(gradient, variable, learning_rate)
+            self._muon_update_step(gradient, variable, learning_rate, m)
 
-    def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+    def _muon_update_step(self, gradient, variable, lr, m):
        self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
        if self.nesterov:
            g = ops.add(gradient, self.momentum * m)
        else:
            g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
 
-        self.assign_sub(
-            variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
-        )
+        self.assign_sub(variable, self.lr_adjust(lr * update))
 
-    def _adamw_update_step(self, gradient, variable, learning_rate):
+    def _adamw_update_step(self, gradient, variable, learning_rate, m, v):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
@@ -210,9 +217,6 @@ class Muon(optimizer.Optimizer):
            ops.cast(self.adam_beta_2, variable.dtype), local_step
        )
 
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
-
        alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 
        self.assign_add(
@@ -239,6 +243,20 @@ class Muon(optimizer.Optimizer):
        X = ops.transpose(X, temp_order)
        return X
 
+    def lr_adjust(self, x):
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the same
+        learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
    def zeropower_via_newtonschulz5(self, x, steps: int):
        """We apply the Newton-Schulz iteration to compute matrix G.
 
@@ -268,6 +286,20 @@ class Muon(optimizer.Optimizer):
        x = self.transpose_last_axis(x)
        return x
 
+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
+
    def get_config(self):
        config = super().get_config()
        config.update(
@@ -284,6 +316,8 @@ class Muon(optimizer.Optimizer):
                "ns_steps": self.ns_steps,
                "nesterov": self.nesterov,
                "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
            }
        )
        return config
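The new `adam_weight_decay` and `rms_rate` arguments surface directly on the constructor. A minimal usage sketch, not taken from the diff; the hyperparameter values and the toy model are illustrative:

    import keras

    optimizer = keras.optimizers.Muon(
        learning_rate=1e-3,
        weight_decay=0.004,       # decay for variables taking the Muon path
        adam_weight_decay=0.004,  # new: decay for variables taking the AdamW path
        rms_rate=0.2,             # new: Moonlight-style update scaling; None disables it
    )

    model = keras.Sequential(
        [keras.Input(shape=(16,)), keras.layers.Dense(8), keras.layers.Dense(1)]
    )
    model.compile(optimizer=optimizer, loss="mse")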
keras/src/optimizers/schedules/learning_rate_schedule.py
CHANGED

@@ -584,9 +584,10 @@ class CosineDecay(LearningRateSchedule):
    schedule applies a linear increase per optimizer step to our learning rate
    from `initial_learning_rate` to `warmup_target` for a duration of
    `warmup_steps`. Afterwards, it applies a cosine decay function taking our
-    learning rate from `warmup_target` to `alpha` for a
-    `decay_steps`. If `warmup_target` is None we skip warmup and
-    will take our learning rate from `initial_learning_rate` to
+    learning rate from `warmup_target` to `warmup_target * alpha` for a
+    duration of `decay_steps`. If `warmup_target` is None we skip warmup and
+    our decay will take our learning rate from `initial_learning_rate` to
+    `initial_learning_rate * alpha`.
    It requires a `step` value to compute the learning rate. You can
    just pass a backend variable that you increment at each training step.
 
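The corrected wording can be checked numerically: with warmup enabled, the schedule settles at `warmup_target * alpha` once the decay window has elapsed. A small sketch with illustrative step counts and rates:

    import keras

    schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.0,   # start of the linear warmup
        decay_steps=1000,
        alpha=0.1,                   # final learning-rate fraction
        warmup_target=1e-3,
        warmup_steps=100,
    )

    # After warmup (100 steps) plus the full decay window (1000 steps), the
    # learning rate settles at warmup_target * alpha = 1e-4.
    print(float(schedule(1100)))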
|
keras/src/quantizers/__init__.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
|
|
3
3
|
from keras.src.api_export import keras_export
|
|
4
|
+
from keras.src.quantizers.awq_config import AWQConfig
|
|
5
|
+
from keras.src.quantizers.quantization_config import Float8QuantizationConfig
|
|
6
|
+
from keras.src.quantizers.quantization_config import Int4QuantizationConfig
|
|
7
|
+
from keras.src.quantizers.quantization_config import Int8QuantizationConfig
|
|
8
|
+
from keras.src.quantizers.quantization_config import QuantizationConfig
|
|
4
9
|
from keras.src.quantizers.quantizers import AbsMaxQuantizer
|
|
5
10
|
from keras.src.quantizers.quantizers import Quantizer
|
|
6
11
|
from keras.src.quantizers.quantizers import abs_max_quantize
|
|
@@ -13,7 +18,15 @@ from keras.src.quantizers.quantizers import unpack_int4
|
|
|
13
18
|
from keras.src.saving import serialization_lib
|
|
14
19
|
from keras.src.utils.naming import to_snake_case
|
|
15
20
|
|
|
16
|
-
ALL_OBJECTS = {
|
|
21
|
+
ALL_OBJECTS = {
|
|
22
|
+
Quantizer,
|
|
23
|
+
AbsMaxQuantizer,
|
|
24
|
+
QuantizationConfig,
|
|
25
|
+
Int8QuantizationConfig,
|
|
26
|
+
Int4QuantizationConfig,
|
|
27
|
+
Float8QuantizationConfig,
|
|
28
|
+
AWQConfig,
|
|
29
|
+
}
|
|
17
30
|
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
|
|
18
31
|
ALL_OBJECTS_DICT.update(
|
|
19
32
|
{to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}
|
|
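The registry above exposes each exported class under both its class name and a snake_case alias. A standalone sketch of that lookup pattern, using stand-in classes and a simplified `to_snake_case` (the real Keras helper lives in `keras.src.utils.naming` and handles acronyms differently):

    import re

    def to_snake_case(name):
        # Simplified CamelCase -> snake_case conversion, for illustration only.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

    class QuantizationConfig: ...
    class Int8QuantizationConfig(QuantizationConfig): ...

    ALL_OBJECTS = {QuantizationConfig, Int8QuantizationConfig}
    ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
    ALL_OBJECTS_DICT.update(
        {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}
    )

    assert ALL_OBJECTS_DICT["Int8QuantizationConfig"] is Int8QuantizationConfig
    assert ALL_OBJECTS_DICT["int8_quantization_config"] is Int8QuantizationConfig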
keras/src/quantizers/awq.py
ADDED

@@ -0,0 +1,361 @@
+"""AWQ (Activation-aware Weight Quantization) algorithm implementation.
+
+AWQ protects salient weights by finding optimal per-channel scales based on
+activation magnitudes, then applies those scales before quantization.
+
+Reference: https://arxiv.org/abs/2306.00978
+"""
+
+import types
+
+from keras.src import ops
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_sz_map
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def awq_search_optimal_scales(
+    weights,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Search for optimal AWQ scales using grid search.
+
+    The AWQ algorithm finds scaling factors that protect salient weights.
+    For each channel, we search for an optimal ratio in [0, 1] that minimizes
+    the activation-weighted quantization error.
+
+    The key insight: we MULTIPLY weights by scales before quantization to
+    expand salient weights. This ensures quantization noise is small relative
+    to the expanded weight magnitude. During inference, we divide by scales
+    to restore the original magnitude.
+
+    Scale formula: scales = x_max.pow(ratio).clamp(min=1e-4)
+    Loss function: Activation-weighted MSE (approximates output error)
+
+    Args:
+        weights: Weight tensor [out_features, in_features] (transposed kernel).
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points. Defaults to 20.
+        group_size: Group size for quantization (-1 for per-channel).
+
+    Returns:
+        best_scales: Optimal per-channel scales [in_features].
+    """
+    in_features = ops.shape(weights)[1]
+
+    # Compute per-channel activation magnitudes (x_max)
+    # activations should already be per-channel max magnitudes
+    x_max = ops.cast(activation_magnitudes, "float32")
+    # Avoid zero or very small values
+    x_max = ops.where(ops.less(x_max, 1e-8), ops.ones_like(x_max), x_max)
+
+    best_loss = None
+    best_scales = ops.ones((in_features,), dtype="float32")
+
+    # Grid search over ratio values from 0 to 1
+    for i in range(num_grid_points + 1):
+        ratio = i / num_grid_points
+
+        # Compute scales: x_max^ratio (clipped to avoid numerical issues)
+        if ratio == 0:
+            scales = ops.ones_like(x_max)
+        else:
+            scales = ops.power(x_max, ratio)
+        scales = ops.maximum(scales, 1e-4)
+
+        # Normalize scales to avoid extreme values
+        scale_mean = ops.sqrt(ops.multiply(ops.max(scales), ops.min(scales)))
+        scale_mean = ops.maximum(scale_mean, 1e-8)
+        scales = ops.divide(scales, scale_mean)
+
+        # Apply scales to weights by MULTIPLYING (expand salient weights)
+        # weights_scaled: [out_features, in_features]
+        weights_scaled = ops.multiply(weights, scales)
+
+        if group_size == -1:
+            # Per-channel quantization (no grouping)
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=-1,
+                compute_dtype="float32",
+            )
+
+            # Quantize and dequantize
+            quantized = quantize_with_zero_point(
+                weights_scaled, scale_q, zero_q, maxq
+            )
+            dequantized = dequantize_with_zero_point(quantized, scale_q, zero_q)
+        else:
+            # Grouped quantization - use proper per-row grouping
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=group_size,
+                compute_dtype="float32",
+            )
+
+            # Compute group indices: maps each input feature to its group
+            g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+            # Quantize and dequantize using group index mapping
+            quantized = quantize_with_sz_map(
+                weights_scaled, scale_q, zero_q, g_idx, maxq
+            )
+            dequantized = dequantize_with_sz_map(
+                quantized, scale_q, zero_q, g_idx
+            )
+
+        # Scale back down by DIVIDING to restore original magnitude
+        reconstructed = ops.divide(dequantized, scales)
+
+        # Compute activation-weighted MSE loss
+        # This approximates the output error: ||W*X - W_hat*X||^2
+        # by weighting each channel's error by x_max^2
+        weight_error = ops.square(ops.subtract(weights, reconstructed))
+        # Weight by activation magnitudes squared (broadcast over out_features)
+        weighted_error = ops.multiply(weight_error, ops.square(x_max))
+        loss = ops.mean(weighted_error)
+
+        # Track best
+        if best_loss is None:
+            best_loss = loss
+            best_scales = scales
+        else:
+            is_better = ops.less(loss, best_loss)
+            if is_better:
+                best_loss = loss
+                best_scales = scales
+
+    return best_scales
+
+
+def awq_quantize_matrix(
+    weights_transpose,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Quantize a weight matrix using AWQ.
+
+    This function performs the complete AWQ quantization process:
+    1. Find optimal per-channel scales via grid search
+    2. Apply scales to weights
+    3. Compute quantization parameters
+    4. Quantize weights
+
+    Args:
+        weights_transpose: Weight matrix [out_features, in_features].
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points.
+        group_size: Group size for quantization.
+
+    Returns:
+        quantized_weights: Quantized weights [out_features, in_features].
+        scales: Quantization scales [out_features, num_groups].
+        zeros: Zero points [out_features, num_groups].
+        awq_scales: AWQ per-channel scales [in_features].
+        g_idx: Group indices [in_features].
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Step 1: Find optimal AWQ scales via grid search
+    awq_scales = awq_search_optimal_scales(
+        weights_transpose,
+        activation_magnitudes,
+        num_grid_points=num_grid_points,
+        group_size=group_size,
+    )
+
+    # Step 2: Apply AWQ scales by MULTIPLYING (expand salient weights)
+    # weights_scaled: [out_features, in_features]
+    weights_scaled = ops.multiply(weights_transpose, awq_scales)
+
+    if group_size == -1:
+        # Per-channel quantization (no grouping)
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=-1,
+            compute_dtype="float32",
+        )
+
+        # Quantize
+        quantized = quantize_with_zero_point(
+            weights_scaled, scale_q, zero_q, maxq
+        )
+
+        # Build group indices (all 0s for per-channel)
+        g_idx = ops.zeros((in_features,), dtype="float32")
+    else:
+        # Grouped quantization - use proper per-row grouping
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=group_size,
+            compute_dtype="float32",
+        )
+
+        # Compute group indices: maps each input feature to its group
+        g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+        # Quantize using group index mapping
+        quantized = quantize_with_sz_map(
+            weights_scaled, scale_q, zero_q, g_idx, maxq
+        )
+
+        # Convert g_idx to float for storage
+        g_idx = ops.cast(g_idx, "float32")
+
+    return quantized, scale_q, zero_q, awq_scales, g_idx
+
+
+class AWQ:
+    """AWQ quantizer for a single layer.
+
+    This class accumulates activation statistics during calibration and
+    performs AWQ quantization on layer weights.
+
+    The AWQ algorithm works by:
+    1. Collecting per-channel maximum activation magnitudes
+    2. Using activation magnitudes to determine weight saliency
+    3. Finding optimal per-channel scales via grid search
+    4. Applying scales before quantization to protect salient weights
+
+    Args:
+        layer: The layer to quantize (Dense or EinsumDense).
+        config: AWQConfig instance with quantization parameters.
+    """
+
+    def __init__(self, layer, config=None):
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        self.original_layer = layer
+        self.config = config or AWQConfig(dataset=None, tokenizer=None)
+        self.num_samples = 0
+
+        # Handle Dense and EinsumDense layers
+        if isinstance(layer, Dense) or (
+            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+        ):
+            self.kernel_shape = layer.kernel.shape
+            self.rows = self.kernel_shape[0]  # in_features
+            self.columns = self.kernel_shape[1]  # out_features
+            self.layer = layer
+        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+            # Handle 3D EinsumDense layers (typically from attention blocks)
+            self.kernel_shape = layer.kernel.shape
+            shape = list(self.kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                self.rows = in_features
+                self.columns = heads * head_dim
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                self.rows = heads * head_dim
+                self.columns = out_features
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensions for EinsumDense kernel "
+                    f"shape {shape}"
+                )
+
+            # Create a temporary object that holds a reshaped 2D version
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+            )
+        else:
+            raise TypeError(f"Unsupported layer type for AWQ: {type(layer)}")
+
+        # Initialize activation magnitude accumulator (per-channel max)
+        self.activation_magnitudes = ops.zeros((self.rows,), dtype="float32")
+
+    def update_activation_magnitudes(self, input_batch):
+        """Update per-channel activation magnitude statistics.
+
+        This method tracks the maximum absolute activation value for each
+        input channel across all calibration batches.
+
+        Args:
+            input_batch: Input activations tensor [batch, ..., in_features].
+        """
+        if input_batch is None:
+            raise ValueError("Input tensor cannot be None.")
+        if ops.size(input_batch) == 0:
+            raise ValueError("Input tensor cannot be empty.")
+
+        # Flatten to [batch_samples, in_features]
+        if len(input_batch.shape) > 2:
+            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+
+        x = ops.cast(input_batch, "float32")
+
+        # Compute per-channel max absolute value for this batch
+        batch_max = ops.max(ops.abs(x), axis=0)
+
+        # Update running max
+        self.activation_magnitudes = ops.maximum(
+            self.activation_magnitudes, batch_max
+        )
+        self.num_samples = self.num_samples + int(ops.shape(x)[0])
+
+    def quantize_layer(self):
+        """Perform AWQ quantization on the layer.
+
+        This method:
+        1. Runs the AWQ grid search to find optimal scales
+        2. Quantizes the layer weights
+        3. Updates the layer's quantized variables
+        """
+        from keras.src import quantizers
+
+        weights_matrix = ops.transpose(self.layer.kernel)
+
+        # Perform AWQ quantization
+        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(
+            weights_matrix,
+            self.activation_magnitudes,
+            num_grid_points=self.config.num_grid_points,
+            group_size=self.config.group_size,
+        )
+
+        # Cast to uint8 for storage
+        # quantized is already [out_features, in_features]
+        quantized = ops.cast(quantized, "uint8")
+
+        # Pack to 4-bit along axis 0 (output features)
+        quantized_packed, _, _ = quantizers.pack_int4(
+            quantized, axis=0, dtype="uint8"
+        )
+
+        # Assign to layer variables
+        del self.original_layer._kernel
+        self.original_layer.quantized_kernel.assign(quantized_packed)
+        self.original_layer.kernel_scale.assign(scale)
+        self.original_layer.kernel_zero.assign(zero)
+        self.original_layer.awq_scales.assign(awq_scales)
+        self.original_layer.g_idx.assign(g_idx)
+        self.original_layer.is_awq_calibrated = True
+
+    def free(self):
+        """Free memory used by the quantizer."""
+        del self.activation_magnitudes
+        del self.layer