keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
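The hunks below are taken from keras/src/layers/core/dense.py (entry 79 in the list above). The headline change is a new `quantization_config` constructor argument and two new quantization modes, GPTQ and AWQ, alongside the existing int8/int4/float8 paths. For orientation, a minimal sketch of the existing post-training quantization entry point that these hunks extend (illustrative only; it uses the long-standing `quantize("int8")` call and does not construct the new `QuantizationConfig`, whose full API is not visible in this diff):

    import numpy as np
    import keras

    layer = keras.layers.Dense(16, activation="relu")
    layer.build((None, 8))
    # Post-training quantization: replaces the float kernel with an int8
    # kernel plus a per-output-channel scale (see the `quantize` hunks below).
    layer.quantize("int8")
    y = layer(np.random.rand(2, 8).astype("float32"))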
@@ -1,3 +1,5 @@
+import math
+
 import ml_dtypes

 from keras.src import activations
@@ -9,6 +11,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.Dense")
@@ -20,7 +25,9 @@ class Dense(Layer):
     where `activation` is the element-wise activation function
     passed as the `activation` argument, `kernel` is a weights matrix
     created by the layer, and `bias` is a bias vector created by the layer
-    (only applicable if `use_bias` is `True`).
+    (only applicable if `use_bias` is `True`). When this layer is
+    followed by a `BatchNormalization` layer, it is recommended to set
+    `use_bias=False` as `BatchNormalization` has its own bias term.

     Note: If the input to the layer has a rank greater than 2, `Dense`
     computes the dot product between the `inputs` and the `kernel` along the
@@ -87,8 +94,15 @@ class Dense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
+        if not isinstance(units, int) or units <= 0:
+            raise ValueError(
+                "Received an invalid value for `units`, expected a positive "
+                f"integer. Received: units={units}"
+            )
+
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.units = units
         self.activation = activations.get(activation)
@@ -102,14 +116,19 @@ class Dense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True

     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode not in ("int8", "int4"):
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -137,15 +156,58 @@ class Dense(Layer):

     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
+        is_int4 = mode == "int4"
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+            # Handle int4 unpacking cases in one place
+            if is_int4:
+                kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+            elif is_gptq and gptq_calibrated and gptq_bits == 4:
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+            elif is_awq and awq_calibrated:
+                # AWQ always uses 4-bit quantization
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+
+        # Apply LoRA once at the end.
         if self.lora_enabled:
-            return self._kernel + (
-                self.lora_alpha / self.lora_rank
-            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
-        return self._kernel
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
+            )
+
+        return kernel

     def call(self, inputs, training=None):
         x = ops.matmul(inputs, self.kernel)
@@ -181,6 +243,10 @@ class Dense(Layer):
             raise ValueError(
                 "lora is already enabled. This can only be done once per layer."
             )
+        if self.quantization_mode == "gptq":
+            raise NotImplementedError(
+                "lora is not currently supported with GPTQ quantization."
+            )
         self._tracker.unlock()
         # Determine the correct input dimension for the LoRA A matrix. When
         # the layer has been int4-quantized, `self._kernel` stores a *packed*
@@ -217,26 +283,26 @@ class Dense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
-        target_variables = [kernel_value]
-        if self.use_bias:
-            target_variables.append(self.bias)
-        if self.quantization_mode is not None:
-            if self.quantization_mode in ("int8", "int4"):
-                target_variables.append(kernel_scale)
-            elif self.quantization_mode == "float8":
-                target_variables.append(self.inputs_scale)
-                target_variables.append(self.inputs_amax_history)
-                target_variables.append(self.kernel_scale)
-                target_variables.append(self.kernel_amax_history)
-                target_variables.append(self.outputs_grad_scale)
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
+        # for None/gptq)
+        kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -244,25 +310,23 @@ class Dense(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        target_variables = [self._kernel]
-        if self.use_bias:
-            target_variables.append(self.bias)
-        if self.quantization_mode is not None:
-            if self.quantization_mode in ("int8", "int4"):
-                target_variables.append(self.kernel_scale)
-            elif self.quantization_mode == "float8":
-                target_variables.append(self.inputs_scale)
-                target_variables.append(self.inputs_amax_history)
-                target_variables.append(self.kernel_scale)
-                target_variables.append(self.kernel_amax_history)
-                target_variables.append(self.outputs_grad_scale)
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
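The two hunks above replace the hand-maintained `target_variables` lists with a single per-mode `variable_serialization_spec` (defined further down): variables are written to and read from the store under consecutive stringified integer keys in spec order, and a missing bias is skipped without leaving a gap in the numbering. A standalone sketch of that pattern (plain Python, not Keras code):

    spec = {"int8": ["kernel", "bias", "kernel_scale"]}
    variables = {"kernel": "K", "bias": None, "kernel_scale": "S"}

    store = {}
    idx = 0
    for name in spec["int8"]:
        if name == "bias" and variables[name] is None:
            continue  # skipped entirely; the next variable reuses this index
        store[str(idx)] = variables[name]
        idx += 1

    assert store == {"0": "K", "1": "S"}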
@@ -283,61 +347,97 @@ class Dense(Layer):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
+
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

-    # Quantization-related (int8 and float8) methods
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "float8": [
+                "kernel",
+                "bias",
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
+        }

-    def quantized_build(self, kernel_shape, mode):
+    def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
+        elif mode == "gptq":
+            self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -352,7 +452,182 @@ class Dense(Layer):
             trainable=False,
         )

-    def _int4_build(self, kernel_shape):
+    def _gptq_build(self, kernel_shape, config):
+        from keras.src.quantizers import gptq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_gptq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        units = (
+            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]
+        )
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = (
+            1
+            if group_size == -1
+            else math.ceil(self.kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, n_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, n_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(self.kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
+        if not self.is_gptq_calibrated:
+            W = self._kernel
+        else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
+            W = (
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
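Both `_gptq_build` and `_awq_build` above lean on the same two pieces of integer arithmetic: two 4-bit values are packed into one byte, so the packed dimension is `(n + 1) // 2`, and a `group_size` of -1 collapses the per-group scale/zero tensors to a single group. A small worked example (plain Python, not the Keras helpers):

    import math

    units = 7
    packed = (units + 1) // 2  # 4 bytes hold 7 int4 values
    assert packed == math.ceil(units / 2) == 4

    input_dim, group_size = 768, 128
    n_groups = 1 if group_size == -1 else math.ceil(input_dim / group_size)
    assert n_groups == 6  # one scale/zero pair per group of 128 input rows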
@@ -361,8 +636,10 @@ class Dense(Layer):
        int8 byte.
        """
        # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
         )
        input_dim, output_dim = kernel_shape
        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
@@ -451,11 +728,15 @@ class Dense(Layer):
                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -502,10 +783,15 @@ class Dense(Layer):
                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -617,30 +903,37 @@ class Dense(Layer):
            x = self.activation(x)
        return x

-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode=None, type_check=True, config=None):
        # Prevent quantization of the subclasses
        if type_check and (type(self) is not Dense):
            raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
        kernel_shape = self._kernel.shape
        if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
            )
            kernel_scale = ops.squeeze(kernel_scale, axis=0)
            del self._kernel
            # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
            self._kernel.assign(kernel_value)
            self.kernel_scale.assign(kernel_scale)
        elif mode == "int4":
            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
            )
            kernel_scale = ops.squeeze(kernel_scale, axis=0)
            # 2. Pack two int4 values into a single int8 byte.
@@ -648,10 +941,14 @@ class Dense(Layer):
            del self._kernel
            # Build variables using the original kernel shape; _int4_build will
            # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
            # Assign packed values.
            self._kernel.assign(packed_kernel_value)
            self.kernel_scale.assign(kernel_scale)
+        elif mode == "gptq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
        elif mode == "float8":
            self.quantized_build(kernel_shape, mode)
        else:
@@ -661,7 +958,14 @@ class Dense(Layer):
        if self.dtype_policy.quantization_mode is None:
            from keras.src import dtype_policies  # local import to avoid cycle

-            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            policy_name = mode
+            if mode == "gptq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "awq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
            self.dtype_policy = policy

    def _get_kernel_with_merged_lora(self):
  def _get_kernel_with_merged_lora(self):
@@ -693,7 +997,7 @@ class Dense(Layer):
693
997
  `kernel_scale`: The quantization scale for the merged kernel.
694
998
  This is `None` if the layer is not quantized.
695
999
  """
696
- if self.dtype_policy.quantization_mode is None:
1000
+ if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
697
1001
  return self.kernel, None
698
1002
 
699
1003
  kernel_value = self._kernel
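In the calibrated GPTQ/AWQ forward passes shown earlier, the stored uint8 kernel is unpacked and then dequantized with per-group scales and zero points via `dequantize_with_sz_map`, with `g_idx` assigning each input position to a group; the AWQ path additionally divides by its per-input `awq_scales`. The Keras helper itself is not part of this diff, but grouped zero-point dequantization of that general shape can be sketched in NumPy as follows (hypothetical shapes and values, mirroring the layout suggested by the hunks above):

    import numpy as np

    # q is (units, input_dim); scale/zero are (units, n_groups);
    # g_idx maps each input column of q to its group.
    q = np.array([[3, 12], [0, 15]], dtype=np.uint8)
    scale = np.array([[0.1], [0.2]], dtype=np.float32)
    zero = np.array([[8], [8]], dtype=np.uint8)
    g_idx = np.array([0, 0])

    w = (q.astype(np.float32) - zero[:, g_idx]) * scale[:, g_idx]
    # The Dense call paths then transpose the result to (input_dim, units)
    # before the matmul; the AWQ path also divides by awq_scales.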