keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (52)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +2 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +2 -0
  7. keras/ops/numpy/__init__.py +2 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +10 -0
  11. keras/src/backend/numpy/numpy.py +15 -0
  12. keras/src/backend/openvino/numpy.py +338 -17
  13. keras/src/backend/tensorflow/numpy.py +24 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +26 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +106 -93
  33. keras/src/ops/numpy.py +138 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/orbax_util.py +50 -0
  44. keras/src/saving/saving_api.py +37 -14
  45. keras/src/utils/jax_layer.py +69 -31
  46. keras/src/utils/module_utils.py +11 -0
  47. keras/src/utils/tracking.py +5 -5
  48. keras/src/version.py +1 -1
  49. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  50. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
  51. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  52. {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras/src/quantizers/awq.py
@@ -0,0 +1,361 @@
+"""AWQ (Activation-aware Weight Quantization) algorithm implementation.
+
+AWQ protects salient weights by finding optimal per-channel scales based on
+activation magnitudes, then applies those scales before quantization.
+
+Reference: https://arxiv.org/abs/2306.00978
+"""
+
+import types
+
+from keras.src import ops
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.quantizers.quantizers import compute_quantization_parameters
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.quantizers.quantizers import dequantize_with_zero_point
+from keras.src.quantizers.quantizers import quantize_with_sz_map
+from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+def awq_search_optimal_scales(
+    weights,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Search for optimal AWQ scales using grid search.
+
+    The AWQ algorithm finds scaling factors that protect salient weights.
+    For each channel, we search for an optimal ratio in [0, 1] that minimizes
+    the activation-weighted quantization error.
+
+    The key insight: we MULTIPLY weights by scales before quantization to
+    expand salient weights. This ensures quantization noise is small relative
+    to the expanded weight magnitude. During inference, we divide by scales
+    to restore the original magnitude.
+
+    Scale formula: scales = x_max.pow(ratio).clamp(min=1e-4)
+    Loss function: Activation-weighted MSE (approximates output error)
+
+    Args:
+        weights: Weight tensor [out_features, in_features] (transposed kernel).
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points. Defaults to 20.
+        group_size: Group size for quantization (-1 for per-channel).
+
+    Returns:
+        best_scales: Optimal per-channel scales [in_features].
+    """
+    in_features = ops.shape(weights)[1]
+
+    # Compute per-channel activation magnitudes (x_max)
+    # activations should already be per-channel max magnitudes
+    x_max = ops.cast(activation_magnitudes, "float32")
+    # Avoid zero or very small values
+    x_max = ops.where(ops.less(x_max, 1e-8), ops.ones_like(x_max), x_max)
+
+    best_loss = None
+    best_scales = ops.ones((in_features,), dtype="float32")
+
+    # Grid search over ratio values from 0 to 1
+    for i in range(num_grid_points + 1):
+        ratio = i / num_grid_points
+
+        # Compute scales: x_max^ratio (clipped to avoid numerical issues)
+        if ratio == 0:
+            scales = ops.ones_like(x_max)
+        else:
+            scales = ops.power(x_max, ratio)
+            scales = ops.maximum(scales, 1e-4)
+
+        # Normalize scales to avoid extreme values
+        scale_mean = ops.sqrt(ops.multiply(ops.max(scales), ops.min(scales)))
+        scale_mean = ops.maximum(scale_mean, 1e-8)
+        scales = ops.divide(scales, scale_mean)
+
+        # Apply scales to weights by MULTIPLYING (expand salient weights)
+        # weights_scaled: [out_features, in_features]
+        weights_scaled = ops.multiply(weights, scales)
+
+        if group_size == -1:
+            # Per-channel quantization (no grouping)
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=-1,
+                compute_dtype="float32",
+            )
+
+            # Quantize and dequantize
+            quantized = quantize_with_zero_point(
+                weights_scaled, scale_q, zero_q, maxq
+            )
+            dequantized = dequantize_with_zero_point(quantized, scale_q, zero_q)
+        else:
+            # Grouped quantization - use proper per-row grouping
+            scale_q, zero_q, maxq = compute_quantization_parameters(
+                weights_scaled,
+                bits=4,
+                symmetric=False,
+                per_channel=True,
+                group_size=group_size,
+                compute_dtype="float32",
+            )
+
+            # Compute group indices: maps each input feature to its group
+            g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+            # Quantize and dequantize using group index mapping
+            quantized = quantize_with_sz_map(
+                weights_scaled, scale_q, zero_q, g_idx, maxq
+            )
+            dequantized = dequantize_with_sz_map(
+                quantized, scale_q, zero_q, g_idx
+            )
+
+        # Scale back down by DIVIDING to restore original magnitude
+        reconstructed = ops.divide(dequantized, scales)
+
+        # Compute activation-weighted MSE loss
+        # This approximates the output error: ||W*X - W_hat*X||^2
+        # by weighting each channel's error by x_max^2
+        weight_error = ops.square(ops.subtract(weights, reconstructed))
+        # Weight by activation magnitudes squared (broadcast over out_features)
+        weighted_error = ops.multiply(weight_error, ops.square(x_max))
+        loss = ops.mean(weighted_error)
+
+        # Track best
+        if best_loss is None:
+            best_loss = loss
+            best_scales = scales
+        else:
+            is_better = ops.less(loss, best_loss)
+            if is_better:
+                best_loss = loss
+                best_scales = scales
+
+    return best_scales
+
+
+def awq_quantize_matrix(
+    weights_transpose,
+    activation_magnitudes,
+    *,
+    num_grid_points=20,
+    group_size=-1,
+):
+    """Quantize a weight matrix using AWQ.
+
+    This function performs the complete AWQ quantization process:
+    1. Find optimal per-channel scales via grid search
+    2. Apply scales to weights
+    3. Compute quantization parameters
+    4. Quantize weights
+
+    Args:
+        weights_transpose: Weight matrix [out_features, in_features].
+        activation_magnitudes: Per-channel activation magnitudes [in_features].
+        num_grid_points: Number of grid search points.
+        group_size: Group size for quantization.
+
+    Returns:
+        quantized_weights: Quantized weights [out_features, in_features].
+        scales: Quantization scales [out_features, num_groups].
+        zeros: Zero points [out_features, num_groups].
+        awq_scales: AWQ per-channel scales [in_features].
+        g_idx: Group indices [in_features].
+    """
+    in_features = ops.shape(weights_transpose)[1]
+
+    # Step 1: Find optimal AWQ scales via grid search
+    awq_scales = awq_search_optimal_scales(
+        weights_transpose,
+        activation_magnitudes,
+        num_grid_points=num_grid_points,
+        group_size=group_size,
+    )
+
+    # Step 2: Apply AWQ scales by MULTIPLYING (expand salient weights)
+    # weights_scaled: [out_features, in_features]
+    weights_scaled = ops.multiply(weights_transpose, awq_scales)
+
+    if group_size == -1:
+        # Per-channel quantization (no grouping)
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=-1,
+            compute_dtype="float32",
+        )
+
+        # Quantize
+        quantized = quantize_with_zero_point(
+            weights_scaled, scale_q, zero_q, maxq
+        )
+
+        # Build group indices (all 0s for per-channel)
+        g_idx = ops.zeros((in_features,), dtype="float32")
+    else:
+        # Grouped quantization - use proper per-row grouping
+        scale_q, zero_q, maxq = compute_quantization_parameters(
+            weights_scaled,
+            bits=4,
+            symmetric=False,
+            per_channel=True,
+            group_size=group_size,
+            compute_dtype="float32",
+        )
+
+        # Compute group indices: maps each input feature to its group
+        g_idx = ops.cast(ops.arange(0, in_features) // group_size, "int32")
+
+        # Quantize using group index mapping
+        quantized = quantize_with_sz_map(
+            weights_scaled, scale_q, zero_q, g_idx, maxq
+        )
+
+    # Convert g_idx to float for storage
+    g_idx = ops.cast(g_idx, "float32")
+
+    return quantized, scale_q, zero_q, awq_scales, g_idx
+
+
+class AWQ:
+    """AWQ quantizer for a single layer.
+
+    This class accumulates activation statistics during calibration and
+    performs AWQ quantization on layer weights.
+
+    The AWQ algorithm works by:
+    1. Collecting per-channel maximum activation magnitudes
+    2. Using activation magnitudes to determine weight saliency
+    3. Finding optimal per-channel scales via grid search
+    4. Applying scales before quantization to protect salient weights
+
+    Args:
+        layer: The layer to quantize (Dense or EinsumDense).
+        config: AWQConfig instance with quantization parameters.
+    """
+
+    def __init__(self, layer, config=None):
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        self.original_layer = layer
+        self.config = config or AWQConfig(dataset=None, tokenizer=None)
+        self.num_samples = 0
+
+        # Handle Dense and EinsumDense layers
+        if isinstance(layer, Dense) or (
+            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2
+        ):
+            self.kernel_shape = layer.kernel.shape
+            self.rows = self.kernel_shape[0]  # in_features
+            self.columns = self.kernel_shape[1]  # out_features
+            self.layer = layer
+        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
+            # Handle 3D EinsumDense layers (typically from attention blocks)
+            self.kernel_shape = layer.kernel.shape
+            shape = list(self.kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                self.rows = in_features
+                self.columns = heads * head_dim
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                self.rows = heads * head_dim
+                self.columns = out_features
+            else:
+                raise ValueError(
+                    f"Cannot determine dimensions for EinsumDense kernel "
+                    f"shape {shape}"
+                )
+
+            # Create a temporary object that holds a reshaped 2D version
+            self.layer = types.SimpleNamespace(
+                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+            )
+        else:
+            raise TypeError(f"Unsupported layer type for AWQ: {type(layer)}")
+
+        # Initialize activation magnitude accumulator (per-channel max)
+        self.activation_magnitudes = ops.zeros((self.rows,), dtype="float32")
+
+    def update_activation_magnitudes(self, input_batch):
+        """Update per-channel activation magnitude statistics.
+
+        This method tracks the maximum absolute activation value for each
+        input channel across all calibration batches.
+
+        Args:
+            input_batch: Input activations tensor [batch, ..., in_features].
+        """
+        if input_batch is None:
+            raise ValueError("Input tensor cannot be None.")
+        if ops.size(input_batch) == 0:
+            raise ValueError("Input tensor cannot be empty.")
+
+        # Flatten to [batch_samples, in_features]
+        if len(input_batch.shape) > 2:
+            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
+
+        x = ops.cast(input_batch, "float32")
+
+        # Compute per-channel max absolute value for this batch
+        batch_max = ops.max(ops.abs(x), axis=0)
+
+        # Update running max
+        self.activation_magnitudes = ops.maximum(
+            self.activation_magnitudes, batch_max
+        )
+        self.num_samples = self.num_samples + int(ops.shape(x)[0])
+
+    def quantize_layer(self):
+        """Perform AWQ quantization on the layer.
+
+        This method:
+        1. Runs the AWQ grid search to find optimal scales
+        2. Quantizes the layer weights
+        3. Updates the layer's quantized variables
+        """
+        from keras.src import quantizers
+
+        weights_matrix = ops.transpose(self.layer.kernel)
+
+        # Perform AWQ quantization
+        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(
+            weights_matrix,
+            self.activation_magnitudes,
+            num_grid_points=self.config.num_grid_points,
+            group_size=self.config.group_size,
+        )

+        # Cast to uint8 for storage
+        # quantized is already [out_features, in_features]
+        quantized = ops.cast(quantized, "uint8")
+
+        # Pack to 4-bit along axis 0 (output features)
+        quantized_packed, _, _ = quantizers.pack_int4(
+            quantized, axis=0, dtype="uint8"
+        )
+
+        # Assign to layer variables
+        del self.original_layer._kernel
+        self.original_layer.quantized_kernel.assign(quantized_packed)
+        self.original_layer.kernel_scale.assign(scale)
+        self.original_layer.kernel_zero.assign(zero)
+        self.original_layer.awq_scales.assign(awq_scales)
+        self.original_layer.g_idx.assign(g_idx)
+        self.original_layer.is_awq_calibrated = True
+
+    def free(self):
+        """Free memory used by the quantizer."""
+        del self.activation_magnitudes
+        del self.layer
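
For orientation, here is a minimal, illustrative sketch of driving the new module directly. It is not part of the diff above: it assumes the newer nightly wheel is installed, uses only `awq_quantize_matrix` as defined in the hunk, and the toy shapes and names (`weights_t`, `x_max`) are invented for the example.

```python
# Illustrative sketch only (not part of the diff). Assumes the
# 3.14.0.dev2026012204 nightly, where keras.src.quantizers.awq exists.
import numpy as np

from keras import ops
from keras.src.quantizers.awq import awq_quantize_matrix

rng = np.random.default_rng(0)
# Toy kernel, already transposed to [out_features, in_features] = [8, 16].
weights_t = ops.convert_to_tensor(rng.normal(size=(8, 16)).astype("float32"))
# Per-input-channel max |activation|, as collected during calibration.
x_max = ops.convert_to_tensor(np.abs(rng.normal(size=(16,))).astype("float32"))

quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(
    weights_t, x_max, num_grid_points=20, group_size=4
)
# quantized: [8, 16] 4-bit codes in [0, 15]; awq_scales: [16] per-channel
# AWQ scales; g_idx maps input channels to groups: [0, 0, 0, 0, 1, 1, ...].
```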
--- /dev/null
+++ b/keras/src/quantizers/awq_config.py
@@ -0,0 +1,140 @@
+from keras.src.api_export import keras_export
+from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+@keras_export("keras.quantizers.AWQConfig")
+class AWQConfig(QuantizationConfig):
+    """Configuration class for AWQ (Activation-aware Weight Quantization).
+
+    AWQ is a post-training quantization method that identifies and protects
+    salient weights based on activation magnitudes. It applies per-channel
+    scaling before quantization to minimize accuracy loss.
+
+    Methodology:
+    1. Collects activation statistics from calibration data
+    2. Identifies salient weight channels based on activation magnitudes
+    3. Searches for optimal per-channel scaling factors via grid search
+    4. Applies scaling before quantization to protect important weights
+
+    References:
+    - Original AWQ paper: "AWQ: Activation-aware Weight Quantization for
+      LLM Compression and Acceleration" (https://arxiv.org/abs/2306.00978)
+    - Reference implementation: https://github.com/mit-han-lab/llm-awq
+
+    Args:
+        dataset: The calibration dataset. It can be an iterable that yields
+            strings or pre-tokenized numerical tensors (e.g., a list of
+            strings, a generator, or a NumPy array). This data is used to
+            analyze activation patterns.
+        tokenizer: A tokenizer instance (or a similar callable) that is used
+            to process the `dataset`.
+        weight_bits: The number of bits for weight quantization. AWQ presently
+            only supports 4-bit quantization. Defaults to 4.
+        num_samples: The number of calibration data samples to use from the
+            dataset. Defaults to 128.
+        sequence_length: The sequence length to use for each calibration
+            sample. Defaults to 512.
+        group_size: The size of weight groups to quantize together. A
+            `group_size` of -1 indicates per-channel quantization.
+            Defaults to 128.
+        num_grid_points: The number of grid search points for finding optimal
+            per-channel scales. Higher values may find better scales but
+            take longer. Defaults to 20.
+        quantization_layer_structure: A dictionary defining the model's
+            quantization structure. It should contain:
+            - "pre_block_layers": list of layers to run before the first
+              block (e.g., embedding layer).
+            - "sequential_blocks": list of transformer blocks to quantize
+              sequentially.
+            If not provided, the model must implement
+            `get_quantization_layer_structure`.
+
+    Example:
+    ```python
+    from keras.quantizers import AWQConfig
+
+    # Create configuration for 4-bit AWQ quantization
+    config = AWQConfig(
+        dataset=calibration_data,  # Your calibration dataset
+        tokenizer=your_tokenizer,  # Tokenizer for text data
+        num_samples=128,  # Number of calibration samples
+        sequence_length=512,  # Sequence length for each sample
+        group_size=128,  # Weight grouping for quantization
+        num_grid_points=20,  # Grid search points for scale search
+    )
+
+    # Apply quantization to your model
+    model.quantize("awq", config=config)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        dataset,
+        tokenizer,
+        *,
+        weight_bits: int = 4,
+        num_samples: int = 128,
+        sequence_length: int = 512,
+        group_size: int = 128,
+        num_grid_points: int = 20,
+        quantization_layer_structure: dict = None,
+    ):
+        super().__init__()
+        # AWQ only supports 4-bit quantization
+        if weight_bits != 4:
+            raise ValueError(
+                f"AWQ only supports 4-bit quantization. "
+                f"Received weight_bits={weight_bits}."
+            )
+        if num_samples <= 0:
+            raise ValueError("num_samples must be a positive integer.")
+        if sequence_length <= 0:
+            raise ValueError("sequence_length must be a positive integer.")
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size. Supported values are -1 (per-channel) "
+                f"or a positive integer, but got {group_size}."
+            )
+        if num_grid_points <= 0:
+            raise ValueError("num_grid_points must be a positive integer.")
+
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.weight_bits = weight_bits
+        self.num_samples = num_samples
+        self.sequence_length = sequence_length
+        self.group_size = group_size
+        self.num_grid_points = num_grid_points
+        self.quantization_layer_structure = quantization_layer_structure
+
+    @property
+    def mode(self):
+        return "awq"
+
+    def dtype_policy_string(self):
+        """Returns the dtype policy string for this configuration.
+
+        Returns:
+            A string representing the dtype policy, e.g. "awq/4/128".
+        """
+        return f"awq/{self.weight_bits}/{self.group_size}"
+
+    def get_config(self):
+        return {
+            # Dataset and Tokenizer are only required for one-time
+            # calibration and are not saved in the config.
+            "dataset": None,
+            "tokenizer": None,
+            "weight_bits": self.weight_bits,
+            "num_samples": self.num_samples,
+            "sequence_length": self.sequence_length,
+            "group_size": self.group_size,
+            "num_grid_points": self.num_grid_points,
+            "quantization_layer_structure": self.quantization_layer_structure,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
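
To make the `awq_config.py` hunk concrete, a short hedged sketch of the config surface it adds (again, not part of the diff); `calibration_texts` and `tok` are placeholders for a real calibration dataset and tokenizer.

```python
# Illustrative sketch only (not part of the diff).
from keras.quantizers import AWQConfig

config = AWQConfig(dataset=calibration_texts, tokenizer=tok, group_size=64)
assert config.mode == "awq"
assert config.dtype_policy_string() == "awq/4/64"

# The calibration dataset/tokenizer are deliberately dropped from the
# serialized config, so a round trip restores everything else.
restored = AWQConfig.from_config(config.get_config())
assert restored.group_size == 64 and restored.dataset is None

# Only 4-bit weights are accepted; the following would raise a ValueError:
# AWQConfig(dataset=None, tokenizer=None, weight_bits=8)
```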