keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/optimizers/muon.py

@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:
 
-    - For any variable that isn't 2D, 3D or 4D, the AdamW step
+    - For any variable that isn't 2D, the AdamW step
       will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
       to `True`, the AdamW step will be used.
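The routing rule above only shows up in user code through which update each variable receives. A minimal usage sketch, assuming `keras.optimizers.Muon` is the public export of this class:

import keras

# Kernel matrices (2-D) take the Muon update; biases and any other
# non-2-D variables fall back to the AdamW update automatically.
model = keras.Sequential(
    [
        keras.Input(shape=(16,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1),
    ]
)
model.compile(optimizer=keras.optimizers.Muon(learning_rate=1e-3), loss="mse")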
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
         epsilon: A small constant for numerical stability. This is
             "epsilon hat" in the Kingma and Ba paper
             (in the formula just before Section 2.1),
@@ -67,11 +69,15 @@ class Muon(optimizer.Optimizer):
            It is recommended to use the default value
        adam_lr_ratio: Float, the ratio of the learning rate when
            using Adam to the main learning rate.
-           it is recommended to set it to 0.1
+           It is recommended to set it to 1
        momentum: Float, momentum used by internal SGD.
        ns_steps: Integer, number of Newton-Schulz iterations to run.
        nesterov: Boolean, whether to use Nesterov-style momentum
        {{base_optimizer_keyword_args}}
+       rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+           that can enhance the stability of Muon, allowing it to use the
+           same learning rate and weight decay as Adam. Defaults to `0.2`.
+           Set to `None` to disable this feature.
    """
 
    def __init__(
@@ -79,8 +85,9 @@ class Muon(optimizer.Optimizer):
         learning_rate=0.001,
         adam_beta_1=0.9,
         adam_beta_2=0.999,
+        adam_weight_decay=0.004,
         epsilon=1e-7,
-        weight_decay=0.1,
+        weight_decay=0.004,
         clipnorm=None,
         clipvalue=None,
         global_clipnorm=None,
@@ -95,10 +102,11 @@ class Muon(optimizer.Optimizer):
         muon_a=3.4445,
         muon_b=-4.7750,
         muon_c=2.0315,
-        adam_lr_ratio=0.1,
+        adam_lr_ratio=1,
         momentum=0.95,
-        ns_steps=6,
+        ns_steps=5,
         nesterov=True,
+        rms_rate=0.2,
         **kwargs,
     ):
         super().__init__(
@@ -127,12 +135,13 @@ class Muon(optimizer.Optimizer):
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate
 
     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
         # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        if len(variable.shape) != 2:
             return True
         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True
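For reference, the routing rule changed above can be read as the following standalone sketch (illustrative only, not the Keras implementation; the trailing `exclude_layers` check is an assumption based on the constructor argument, since the hunk ends before it):

def uses_adamw(shape, path="", exclude_embeddings=True, exclude_layers=()):
    # Only rank-2 variables take the Muon step; everything else uses AdamW.
    if len(shape) != 2:
        return True
    if exclude_embeddings and "embedding" in path.lower():
        return True
    return any(keyword in path for keyword in exclude_layers)

print(uses_adamw((128,)))           # True: 1-D bias -> AdamW
print(uses_adamw((64, 128)))        # False: 2-D kernel -> Muon
print(uses_adamw((3, 3, 64, 128)))  # True: 4-D conv kernel -> AdamW (new behavior)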
@@ -153,52 +162,50 @@ class Muon(optimizer.Optimizer):
         if self.built:
             return
         super().build(var_list)
-        self.adam_momentums = {}
-        self.adam_velocities = {}
-
-        self.muon_momentums = {}
-        self.muon_velocities = {}
+        # Momentums are for both Muon and Adam
+        self.momentums = [None] * len(var_list)
+        # Velocities are just for Adam
+        self.adam_velocities = [None] * len(var_list)
 
         for var in var_list:
             if not self._overwrite_variable_with_gradient(var):
-                self.adam_momentums[var.path] = (
+                self.momentums[self._get_variable_index(var)] = (
                     self.add_variable_from_reference(
                         reference_variable=var, name="momentum"
                     )
                 )
                 if self._should_use_adamw(var):
-                    self.adam_velocities[var.path] = (
+                    self.adam_velocities[self._get_variable_index(var)] = (
                         self.add_variable_from_reference(
                             reference_variable=var, name="velocity"
                         )
                     )
 
     def update_step(self, gradient, variable, learning_rate):
-        if self._should_use_adamw(variable):
+        variable_index = self._get_variable_index(variable)
+        m = self.momentums[variable_index]
+        v = self.adam_velocities[variable_index]
+
+        # The presence of the velocity tells us that this variable is for Adam
+        if v is not None:
             # It should be noted that lr is one-tenth when using adamw.
             self._adamw_update_step(
-                gradient, variable, learning_rate * self.adam_lr_ratio
+                gradient, variable, learning_rate * self.adam_lr_ratio, m, v
             )
         else:
-            self._muon_update_step(gradient, variable, learning_rate)
+            self._muon_update_step(gradient, variable, learning_rate, m)
 
-    def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+    def _muon_update_step(self, gradient, variable, lr, m):
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
         else:
             g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
 
-        self.assign_sub(
-            variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
-        )
+        self.assign_sub(variable, self.lr_adjust(lr * update))
 
-    def _adamw_update_step(self, gradient, variable, learning_rate):
+    def _adamw_update_step(self, gradient, variable, learning_rate, m, v):
         """Update step given gradient and the associated model variable."""
         lr = ops.cast(learning_rate, variable.dtype)
         gradient = ops.cast(gradient, variable.dtype)
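The Muon path orthogonalizes the momentum-adjusted gradient with `zeropower_via_newtonschulz5` before applying it. A NumPy sketch of that iteration, using the `muon_a`, `muon_b`, `muon_c` defaults from `__init__` (illustrative only; the Keras method works on `keras.ops` tensors and routes >2-D inputs through `transpose_last_axis`):

import numpy as np

def newton_schulz5(g, steps=5, a=3.4445, b=-4.7750, c=2.0315, eps=1e-7):
    # Normalize so the singular values start below 1, then run the quintic
    # iteration X <- aX + (bA + cA^2)X with A = X X^T. The result is an
    # approximately semi-orthogonal matrix with the same singular vectors.
    x = g / (np.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        a_mat = x @ x.T
        x = a * x + (b * a_mat + c * a_mat @ a_mat) @ x
    return x.T if transposed else x

g = np.random.randn(64, 128)
o = newton_schulz5(g)
print(np.linalg.svd(o, compute_uv=False)[:4])  # singular values pushed toward ~1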
@@ -210,9 +217,6 @@ class Muon(optimizer.Optimizer):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
         )
 
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
-
         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 
         self.assign_add(
@@ -239,6 +243,20 @@ class Muon(optimizer.Optimizer):
         X = ops.transpose(X, temp_order)
         return X
 
+    def lr_adjust(self, x):
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the same
+        learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
     def zeropower_via_newtonschulz5(self, x, steps: int):
         """We apply the Newton-Schulz iteration to compute matrix G.
 
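A quick numeric check of the adjustment added above: for a 2048x1024 kernel with the default `rms_rate=0.2`, the Muon update is multiplied by roughly 9, which is what lets it share Adam's learning rate and weight decay settings.

import math

n, m, rms_rate = 2048, 1024, 0.2
print(math.sqrt(max(n, m)) * rms_rate)  # ~9.05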
@@ -268,6 +286,20 @@ class Muon(optimizer.Optimizer):
         x = self.transpose_last_axis(x)
         return x
 
+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
+
     def get_config(self):
         config = super().get_config()
         config.update(
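The override above applies decoupled decay with a per-path rate: `adam_weight_decay` for variables on the AdamW path and `weight_decay` for Muon-updated matrices. A one-line worked example of the update it performs:

lr, adam_weight_decay, w = 1e-3, 0.004, 0.5
w = w - w * adam_weight_decay * lr  # shrink by a factor of (1 - 4e-6)
print(w)  # 0.499998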
@@ -284,6 +316,8 @@ class Muon(optimizer.Optimizer):
                 "ns_steps": self.ns_steps,
                 "nesterov": self.nesterov,
                 "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
             }
         )
         return config
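With the two new keys in `get_config`, the optimizer should round-trip through serialization. A hedged sketch, again assuming the public `keras.optimizers.Muon` export:

import keras

opt = keras.optimizers.Muon(adam_weight_decay=0.004, rms_rate=0.2)
cfg = opt.get_config()
restored = keras.optimizers.Muon.from_config(cfg)
print(restored.adam_weight_decay, restored.rms_rate)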
keras/src/optimizers/schedules/learning_rate_schedule.py

@@ -584,9 +584,10 @@ class CosineDecay(LearningRateSchedule):
     schedule applies a linear increase per optimizer step to our learning rate
     from `initial_learning_rate` to `warmup_target` for a duration of
     `warmup_steps`. Afterwards, it applies a cosine decay function taking our
-    learning rate from `warmup_target` to `alpha` for a duration of
-    `decay_steps`. If `warmup_target` is None we skip warmup and our decay
-    will take our learning rate from `initial_learning_rate` to `alpha`.
+    learning rate from `warmup_target` to `warmup_target * alpha` for a
+    duration of `decay_steps`. If `warmup_target` is None we skip warmup and
+    our decay will take our learning rate from `initial_learning_rate` to
+    `initial_learning_rate * alpha`.
 
     It requires a `step` value to compute the learning rate. You can
     just pass a backend variable that you increment at each training step.
 
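The corrected docstring can be checked numerically: with a warmup target, the schedule bottoms out at `warmup_target * alpha` once warmup plus decay steps have elapsed.

import keras

schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=1000,
    alpha=0.1,
    warmup_target=1e-3,
    warmup_steps=100,
)
print(float(schedule(1100)))  # ~1e-4, i.e. warmup_target * alpha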
keras/src/quantizers/__init__.py

@@ -1,6 +1,11 @@
 import inspect
 
 from keras.src.api_export import keras_export
+from keras.src.quantizers.awq_config import AWQConfig
+from keras.src.quantizers.quantization_config import Float8QuantizationConfig
+from keras.src.quantizers.quantization_config import Int4QuantizationConfig
+from keras.src.quantizers.quantization_config import Int8QuantizationConfig
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize
@@ -13,7 +18,15 @@ from keras.src.quantizers.quantizers import unpack_int4
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
 
-ALL_OBJECTS = {Quantizer, AbsMaxQuantizer}
+ALL_OBJECTS = {
+    Quantizer,
+    AbsMaxQuantizer,
+    QuantizationConfig,
+    Int8QuantizationConfig,
+    Int4QuantizationConfig,
+    Float8QuantizationConfig,
+    AWQConfig,
+}
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(
     {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}
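Registering the config classes in `ALL_OBJECTS_DICT` (by class name and snake_case alias) makes them resolvable through the module's serialization helpers. A hedged sketch, assuming the usual `keras.quantizers.serialize`/`deserialize` pair is exported as in other Keras modules:

import keras

# AbsMaxQuantizer was already registered; the new config classes join the
# same lookup table and round-trip the same way.
q = keras.quantizers.AbsMaxQuantizer(axis=-1)
cfg = keras.quantizers.serialize(q)
print(type(keras.quantizers.deserialize(cfg)).__name__)  # AbsMaxQuantizer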
keras/src/quantizers/awq.py (new file)

@@ -0,0 +1,361 @@
+"""AWQ (Activation-aware Weight Quantization) algorithm implementation.
+
+AWQ protects salient weights by finding optimal per-channel scales based on
+activation magnitudes, then applies those scales before quantization.
+
+Reference: https://arxiv.org/abs/2306.00978
+"""
+
+import types
+
+from keras.src import ops
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_sz_map
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def awq_search_optimal_scales(
+    weights,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Search for optimal AWQ scales using grid search.
+
+    The AWQ algorithm finds scaling factors that protect salient weights.
+    For each channel, we search for an optimal ratio in [0, 1] that minimizes
+    the activation-weighted quantization error.
+
+    The key insight: we MULTIPLY weights by scales before quantization to
+    expand salient weights. This ensures quantization noise is small relative
+    to the expanded weight magnitude. During inference, we divide by scales
+    to restore the original magnitude.
+
+    Scale formula: scales = x_max.pow(ratio).clamp(min=1e-4)
+    Loss function: Activation-weighted MSE (approximates output error)
+
+    Args:
+        weights: Weight tensor [out_features, in_features] (transposed kernel).
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points. Defaults to 20.
+        group_size: Group size for quantization (-1 for per-channel).
+
+    Returns:
+        best_scales: Optimal per-channel scales [in_features].
+    """
+    in_features = ops.shape(weights)[1]
+
+    # Compute per-channel activation magnitudes (x_max)
+    # activations should already be per-channel max magnitudes
+    x_max = ops.cast(activation_magnitudes, "float32")
+    # Avoid zero or very small values
+    x_max = ops.where(ops.less(x_max, 1e-8), ops.ones_like(x_max), x_max)
+
+    best_loss = None
+    best_scales = ops.ones((in_features,), dtype="float32")
+
+    # Grid search over ratio values from 0 to 1
+    for i in range(num_grid_points + 1):
+        ratio = i / num_grid_points
+
+        # Compute scales: x_max^ratio (clipped to avoid numerical issues)
+        if ratio == 0:
+            scales = ops.ones_like(x_max)
+        else:
+            scales = ops.power(x_max, ratio)
+        scales = ops.maximum(scales, 1e-4)
+
+        # Normalize scales to avoid extreme values
+        scale_mean = ops.sqrt(ops.multiply(ops.max(scales), ops.min(scales)))
+        scale_mean = ops.maximum(scale_mean, 1e-8)
+        scales = ops.divide(scales, scale_mean)
+
+        # Apply scales to weights by MULTIPLYING (expand salient weights)
+        # weights_scaled: [out_features, in_features]
+        weights_scaled = ops.multiply(weights, scales)
+
+        if group_size == -1:
+            # Per-channel quantization (no grouping)
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=-1,
+                compute_dtype="float32",
+            )
+
+            # Quantize and dequantize
+            quantized = quantize_with_zero_point(
+                weights_scaled, scale_q, zero_q, maxq
+            )
+            dequantized = dequantize_with_zero_point(quantized, scale_q, zero_q)
+        else:
+            # Grouped quantization - use proper per-row grouping
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=group_size,
+                compute_dtype="float32",
+            )
+
+            # Compute group indices: maps each input feature to its group
+            g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+            # Quantize and dequantize using group index mapping
+            quantized = quantize_with_sz_map(
+                weights_scaled, scale_q, zero_q, g_idx, maxq
+            )
+            dequantized = dequantize_with_sz_map(
+                quantized, scale_q, zero_q, g_idx
+            )
+
+        # Scale back down by DIVIDING to restore original magnitude
+        reconstructed = ops.divide(dequantized, scales)
+
+        # Compute activation-weighted MSE loss
+        # This approximates the output error: ||W*X - W_hat*X||^2
+        # by weighting each channel's error by x_max^2
+        weight_error = ops.square(ops.subtract(weights, reconstructed))
+        # Weight by activation magnitudes squared (broadcast over out_features)
+        weighted_error = ops.multiply(weight_error, ops.square(x_max))
+        loss = ops.mean(weighted_error)
+
+        # Track best
+        if best_loss is None:
+            best_loss = loss
+            best_scales = scales
+        else:
+            is_better = ops.less(loss, best_loss)
+            if is_better:
+                best_loss = loss
+                best_scales = scales
+
+    return best_scales
+
+
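The search above is easier to see end to end in a small self-contained NumPy sketch (illustrative only; the real implementation uses `keras.ops` and the shared quantizer helpers): scale each input channel by `x_max**ratio`, fake-quantize to 4 bits, undo the scaling, and keep the ratio with the lowest activation-weighted error.

import numpy as np

def fake_quant_int4(w):
    # Per-output-channel asymmetric 4-bit round trip (stand-in for the
    # compute/quantize/dequantize helpers used above).
    lo = w.min(axis=1, keepdims=True)
    hi = w.max(axis=1, keepdims=True)
    scale = np.maximum((hi - lo) / 15.0, 1e-8)
    zero = np.round(-lo / scale)
    q = np.clip(np.round(w / scale) + zero, 0, 15)
    return (q - zero) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 16))          # [out_features, in_features]
x_max = np.abs(rng.normal(size=16))   # per-channel activation magnitudes

best_loss, best_scales = None, np.ones(16)
for ratio in np.linspace(0.0, 1.0, 21):
    s = np.maximum(x_max**ratio, 1e-4)
    s = s / np.sqrt(s.max() * s.min())      # normalize the scale range
    w_hat = fake_quant_int4(w * s) / s      # expand, quantize, restore
    loss = np.mean((w - w_hat) ** 2 * x_max**2)
    if best_loss is None or loss < best_loss:
        best_loss, best_scales = loss, s
print(best_loss)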
+def awq_quantize_matrix(
+    weights_transpose,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Quantize a weight matrix using AWQ.
+
+    This function performs the complete AWQ quantization process:
+    1. Find optimal per-channel scales via grid search
+    2. Apply scales to weights
+    3. Compute quantization parameters
+    4. Quantize weights
+
+    Args:
+        weights_transpose: Weight matrix [out_features, in_features].
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points.
+        group_size: Group size for quantization.
+
+    Returns:
+        quantized_weights: Quantized weights [out_features, in_features].
+        scales: Quantization scales [out_features, num_groups].
+        zeros: Zero points [out_features, num_groups].
+        awq_scales: AWQ per-channel scales [in_features].
+        g_idx: Group indices [in_features].
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Step 1: Find optimal AWQ scales via grid search
+    awq_scales = awq_search_optimal_scales(
+        weights_transpose,
+        activation_magnitudes,
+        num_grid_points=num_grid_points,
+        group_size=group_size,
+    )
+
+    # Step 2: Apply AWQ scales by MULTIPLYING (expand salient weights)
+    # weights_scaled: [out_features, in_features]
+    weights_scaled = ops.multiply(weights_transpose, awq_scales)
+
+    if group_size == -1:
+        # Per-channel quantization (no grouping)
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=-1,
+            compute_dtype="float32",
+        )
+
+        # Quantize
+        quantized = quantize_with_zero_point(
+            weights_scaled, scale_q, zero_q, maxq
+        )
+
+        # Build group indices (all 0s for per-channel)
+        g_idx = ops.zeros((in_features,), dtype="float32")
+    else:
+        # Grouped quantization - use proper per-row grouping
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=group_size,
+            compute_dtype="float32",
+        )
+
+        # Compute group indices: maps each input feature to its group
+        g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+        # Quantize using group index mapping
+        quantized = quantize_with_sz_map(
+            weights_scaled, scale_q, zero_q, g_idx, maxq
+        )
+
+    # Convert g_idx to float for storage
+    g_idx = ops.cast(g_idx, "float32")
+
+    return quantized, scale_q, zero_q, awq_scales, g_idx
+
+
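The `g_idx` mapping used in the grouped branch above is just integer division of the channel index by the group size:

import numpy as np

in_features, group_size = 10, 4
print(np.arange(in_features) // group_size)  # [0 0 0 0 1 1 1 1 2 2]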
+class AWQ:
+    """AWQ quantizer for a single layer.
+
+    This class accumulates activation statistics during calibration and
+    performs AWQ quantization on layer weights.
+
+    The AWQ algorithm works by:
+    1. Collecting per-channel maximum activation magnitudes
+    2. Using activation magnitudes to determine weight saliency
+    3. Finding optimal per-channel scales via grid search
+    4. Applying scales before quantization to protect salient weights
+
+    Args:
+        layer: The layer to quantize (Dense or EinsumDense).
+        config: AWQConfig instance with quantization parameters.
+    """
+
+    def __init__(self, layer, config=None):
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        self.original_layer = layer
+        self.config = config or AWQConfig(dataset=None, tokenizer=None)
+        self.num_samples = 0
+
+        # Handle Dense and EinsumDense layers
+        if isinstance(layer, Dense) or (
+            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+        ):
+            self.kernel_shape = layer.kernel.shape
+            self.rows = self.kernel_shape[0]  # in_features
+            self.columns = self.kernel_shape[1]  # out_features
+            self.layer = layer
+        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+            # Handle 3D EinsumDense layers (typically from attention blocks)
+            self.kernel_shape = layer.kernel.shape
+            shape = list(self.kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                self.rows = in_features
+                self.columns = heads * head_dim
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                self.rows = heads * head_dim
+                self.columns = out_features
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensions for EinsumDense kernel "
+                    f"shape {shape}"
+                )
+
+            # Create a temporary object that holds a reshaped 2D version
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+            )
+        else:
+            raise TypeError(f"Unsupported layer type for AWQ: {type(layer)}")
+
+        # Initialize activation magnitude accumulator (per-channel max)
+        self.activation_magnitudes = ops.zeros((self.rows,), dtype="float32")
+
+    def update_activation_magnitudes(self, input_batch):
+        """Update per-channel activation magnitude statistics.
+
+        This method tracks the maximum absolute activation value for each
+        input channel across all calibration batches.
+
+        Args:
+            input_batch: Input activations tensor [batch, ..., in_features].
+        """
+        if input_batch is None:
+            raise ValueError("Input tensor cannot be None.")
+        if ops.size(input_batch) == 0:
+            raise ValueError("Input tensor cannot be empty.")
+
+        # Flatten to [batch_samples, in_features]
+        if len(input_batch.shape) > 2:
+            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+
+        x = ops.cast(input_batch, "float32")
+
+        # Compute per-channel max absolute value for this batch
+        batch_max = ops.max(ops.abs(x), axis=0)
+
+        # Update running max
+        self.activation_magnitudes = ops.maximum(
+            self.activation_magnitudes, batch_max
+        )
+        self.num_samples = self.num_samples + int(ops.shape(x)[0])
+
+    def quantize_layer(self):
+        """Perform AWQ quantization on the layer.
+
+        This method:
+        1. Runs the AWQ grid search to find optimal scales
+        2. Quantizes the layer weights
+        3. Updates the layer's quantized variables
+        """
+        from keras.src import quantizers
+
+        weights_matrix = ops.transpose(self.layer.kernel)
+
+        # Perform AWQ quantization
+        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(
+            weights_matrix,
+            self.activation_magnitudes,
+            num_grid_points=self.config.num_grid_points,
+            group_size=self.config.group_size,
+        )
+
+        # Cast to uint8 for storage
+        # quantized is already [out_features, in_features]
+        quantized = ops.cast(quantized, "uint8")
+
+        # Pack to 4-bit along axis 0 (output features)
+        quantized_packed, _, _ = quantizers.pack_int4(
+            quantized, axis=0, dtype="uint8"
+        )
+
+        # Assign to layer variables
+        del self.original_layer._kernel
+        self.original_layer.quantized_kernel.assign(quantized_packed)
+        self.original_layer.kernel_scale.assign(scale)
+        self.original_layer.kernel_zero.assign(zero)
+        self.original_layer.awq_scales.assign(awq_scales)
+        self.original_layer.g_idx.assign(g_idx)
+        self.original_layer.is_awq_calibrated = True
+
+    def free(self):
+        """Free memory used by the quantizer."""
+        del self.activation_magnitudes
+        del self.layer
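A hedged calibration sketch built only on the class above; `layer` and `calibration_batches` are placeholders for the caller's objects, and the layer is assumed to already expose the AWQ variables that `quantize_layer()` assigns to (`quantized_kernel`, `kernel_scale`, `kernel_zero`, `awq_scales`, `g_idx`):

from keras.src.quantizers.awq import AWQ
from keras.src.quantizers.awq_config import AWQConfig

awq = AWQ(layer, config=AWQConfig(dataset=None, tokenizer=None))
for batch in calibration_batches:       # activations fed to `layer` during calibration
    awq.update_activation_magnitudes(batch)
awq.quantize_layer()                    # writes the packed int4 kernel and its scales
awq.free()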