keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py
@@ -11,7 +11,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.Dense")
@@ -23,7 +25,9 @@ class Dense(Layer):
     where `activation` is the element-wise activation function
     passed as the `activation` argument, `kernel` is a weights matrix
     created by the layer, and `bias` is a bias vector created by the layer
-    (only applicable if `use_bias` is `True`).
+    (only applicable if `use_bias` is `True`). When this layer is
+    followed by a `BatchNormalization` layer, it is recommended to set
+    `use_bias=False` as `BatchNormalization` has its own bias term.
 
     Note: If the input to the layer has a rank greater than 2, `Dense`
     computes the dot product between the `inputs` and the `kernel` along the
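The `use_bias` note added above corresponds to the usual composition pattern; a minimal sketch with the public Keras API (layer sizes are arbitrary):

    import keras

    # Dense feeds BatchNormalization, whose learned beta offset makes the
    # Dense bias redundant, hence use_bias=False as the docstring suggests.
    model = keras.Sequential([
        keras.layers.Input(shape=(16,)),
        keras.layers.Dense(32, use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.Activation("relu"),
    ])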
@@ -90,8 +94,15 @@ class Dense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
+        if not isinstance(units, int) or units <= 0:
+            raise ValueError(
+                "Received an invalid value for `units`, expected a positive "
+                f"integer. Received: units={units}"
+            )
+
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.units = units
         self.activation = activations.get(activation)
@@ -105,14 +116,19 @@ class Dense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True
 
     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode not in ("int8", "int4", "gptq"):
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
             # it here.
@@ -149,15 +165,17 @@ class Dense(Layer):
 
         mode = self.quantization_mode
         is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
         is_int4 = mode == "int4"
-        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
         gptq_bits = (
             gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
         )
 
         # Decide the source tensor first (packed vs already-quantized vs plain
         # kernel)
-        if is_gptq and calibrated and gptq_bits != 4:
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
             # calibrated GPTQ, not 4-bit, no unpacking needed
             kernel = self.quantized_kernel
         else:
@@ -167,7 +185,15 @@ class Dense(Layer):
         # Handle int4 unpacking cases in one place
         if is_int4:
             kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
-        elif is_gptq and calibrated and gptq_bits == 4:
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
             kernel = quantizers.unpack_int4(
                 self.quantized_kernel,
                 orig_len=self.units,
@@ -258,25 +284,25 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-        if self.bias is not None:
-            store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -285,39 +311,22 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
 
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
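For illustration, with the new index-based serialization an int8-quantized `Dense(4)` built on 8 input features would save its variables under stringified indices in the order given by `variable_serialization_spec["int8"]`; a sketch with hypothetical placeholder arrays:

    import numpy as np

    kernel_value = np.zeros((8, 4), dtype="int8")   # quantized kernel
    bias_value = np.zeros((4,), dtype="float32")    # float bias
    kernel_scale = np.ones((4,), dtype="float32")   # dequantization scale

    # Order follows ["kernel", "bias", "kernel_scale"]; keys are str indices.
    store = {"0": kernel_value, "1": bias_value, "2": kernel_scale}

    # A bias-less layer skips "bias" without advancing the index, so
    # "kernel_scale" would be stored under "1" instead.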
@@ -338,59 +347,51 @@ class Dense(Layer):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
 
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -399,28 +400,44 @@ class Dense(Layer):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
+                "awq_scales",
                 "g_idx",
             ],
         }
 
     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
             self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -519,7 +536,98 @@ class Dense(Layer):
             y = self.activation(y)
         return y
 
-    def _int4_build(self, kernel_shape):
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+        self.kernel_shape = kernel_shape
+
+        # For 4-bit weights, we pack two values per byte.
+        units = (kernel_shape[1] + 1) // 2
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(units, kernel_shape[0]),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = (
+            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)
+        )
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="kernel_zero",
+            shape=(self.units, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(kernel_shape[0],),
+            initializer="ones",
+            trainable=False,
+        )
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(kernel_shape[0],),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
+            )
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, units]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        y = ops.matmul(inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
         `kernel_shape` is the *original* float32 kernel shape
@@ -528,8 +636,10 @@ class Dense(Layer):
        int8 byte.
        """
        # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
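The `(input_dim + 1) // 2` packing above reflects two int4 values sharing one byte; a generic nibble-packing sketch (not necessarily the exact layout `quantizers.pack_int4` uses):

    import numpy as np

    def pack_pair(lo, hi):
        # Two values in [-8, 7] stored as 4-bit two's complement nibbles,
        # low nibble first, which is why ceil(input_dim / 2) bytes suffice.
        return np.uint8((lo & 0x0F) | ((hi & 0x0F) << 4))

    def unpack_pair(byte):
        def to_signed(v):
            return v - 16 if v >= 8 else v
        return to_signed(int(byte & 0x0F)), to_signed(int((byte >> 4) & 0x0F))

    assert unpack_pair(pack_pair(-3, 7)) == (-3, 7)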
@@ -618,11 +728,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)
 
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn
 
         x = matmul_with_inputs_gradient(
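The `output_scale` change above folds both scale factors into one divisor; a standalone NumPy sketch of why dividing by `inputs_scale * kernel_scale` recovers the float matmul (shapes are illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 8)).astype("float32")
    w = rng.standard_normal((8, 4)).astype("float32")

    # Symmetric abs-max scales: quantized = round(value * scale).
    s_x = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per-row inputs
    s_w = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)   # per-column kernel
    x_q = np.round(x * s_x)
    w_q = np.round(w * s_w)

    # Integer matmul, then divide once by the combined output scale.
    y = (x_q @ w_q) / (s_x * s_w)
    print(np.max(np.abs(y - x @ w)))  # small quantization error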
@@ -669,10 +783,15 @@ class Dense(Layer):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)
 
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn
 
         x = matmul_with_inputs_gradient(
@@ -784,30 +903,37 @@ class Dense(Layer):
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
@@ -815,12 +941,14 @@ class Dense(Layer):
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
-            self.quantized_build(kernel_shape, mode, config)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
+        elif mode == "awq":
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -832,7 +960,9 @@ class Dense(Layer):
 
         policy_name = mode
         if mode == "gptq":
-            policy_name = config.dtype_policy_string()
+            policy_name = self.quantization_config.dtype_policy_string()
+        elif mode == "awq":
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
@@ -867,7 +997,7 @@ class Dense(Layer):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         kernel_value = self._kernel
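As a usage sketch (not part of the diff), the `quantize()` entry point these changes extend performs post-training quantization in place; the new `config=`/`"awq"` path added above is nightly-specific and omitted here:

    import numpy as np
    import keras

    model = keras.Sequential([
        keras.layers.Input(shape=(16,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(4),
    ])
    # Replaces float kernels with int8 kernels plus per-channel scales.
    model.quantize("int8")
    print(model.predict(np.random.rand(2, 16)).shape)  # (2, 4)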