keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/einsum_dense.py
@@ -6,6 +6,7 @@ import ml_dtypes
  import numpy as np

  from keras.src import activations
+ from keras.src import backend
  from keras.src import constraints
  from keras.src import dtype_policies
  from keras.src import initializers
@@ -15,7 +16,9 @@ from keras.src import regularizers
  from keras.src.api_export import keras_export
  from keras.src.layers.input_spec import InputSpec
  from keras.src.layers.layer import Layer
+ from keras.src.quantizers.quantization_config import QuantizationConfig
  from keras.src.quantizers.quantizers import dequantize_with_sz_map
+ from keras.src.saving import serialization_lib


  @keras_export("keras.layers.EinsumDense")
@@ -134,6 +137,7 @@ class EinsumDense(Layer):
  lora_rank=None,
  lora_alpha=None,
  gptq_unpacked_column_size=None,
+ quantization_config=None,
  **kwargs,
  ):
  super().__init__(**kwargs)
@@ -154,6 +158,7 @@ class EinsumDense(Layer):
  self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
  self.lora_enabled = False
  self.gptq_unpacked_column_size = gptq_unpacked_column_size
+ self.quantization_config = quantization_config

  def build(self, input_shape):
  shape_data = _analyze_einsum_string(
@@ -169,12 +174,13 @@ class EinsumDense(Layer):
  self.quantized_build(
  kernel_shape,
  mode=self.quantization_mode,
+ config=self.quantization_config,
  )
  # Skip creating a duplicate kernel variable when the layer is already
  # quantized to int8 or int4, because `quantized_build` has created the
  # appropriate kernel variable. For other modes (e.g., float8 or no
  # quantization), we still need the floating-point kernel.
- if self.quantization_mode not in ("int8", "int4", "gptq"):
+ if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
  # If the layer is quantized to int8, `self._kernel` will be added
  # in `self._int8_build`. Therefore, we skip it here.
  self._kernel = self.add_weight(
@@ -213,15 +219,17 @@ class EinsumDense(Layer):

  mode = self.quantization_mode
  is_gptq = mode == "gptq"
+ is_awq = mode == "awq"
  is_int4 = mode == "int4"
- calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+ gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+ awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
  gptq_bits = (
  gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
  )

  # Decide the source tensor first (packed vs already-quantized vs plain
  # kernel)
- if is_gptq and calibrated and gptq_bits != 4:
+ if is_gptq and gptq_calibrated and gptq_bits != 4:
  # calibrated GPTQ, not 4-bit, no unpacking needed
  kernel = self.quantized_kernel
  else:
@@ -235,13 +243,21 @@ class EinsumDense(Layer):
  self._orig_length_along_pack_axis,
  self._int4_pack_axis,
  )
- elif is_gptq and calibrated and gptq_bits == 4:
+ elif is_gptq and gptq_calibrated and gptq_bits == 4:
  kernel = quantizers.unpack_int4(
  self.quantized_kernel,
  orig_len=self.gptq_unpacked_column_size,
  axis=0,
  dtype="uint8",
  )
+ elif is_awq and awq_calibrated:
+ # AWQ always uses 4-bit quantization
+ kernel = quantizers.unpack_int4(
+ self.quantized_kernel,
+ orig_len=self.awq_unpacked_column_size,
+ axis=0,
+ dtype="uint8",
+ )

  # Apply LoRA if enabled
  if self.lora_enabled:
@@ -326,25 +342,25 @@ class EinsumDense(Layer):
  if not self.built:
  return
  mode = self.quantization_mode
- if mode not in self.quantization_variable_spec:
+ if mode not in self.variable_serialization_spec:
  raise self._quantization_mode_error(mode)

  # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
  # for None/gptq)
  kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
- # Save the variables using the name as the key.
- if mode != "gptq":
- store["kernel"] = kernel_value
- if self.bias is not None:
- store["bias"] = self.bias
- for name in self.quantization_variable_spec[mode]:
- if name == "kernel_scale" and mode in ("int4", "int8"):
+ idx = 0
+ for name in self.variable_serialization_spec[mode]:
+ if name == "kernel":
+ store[str(idx)] = kernel_value
+ elif name == "bias" and self.bias is None:
+ continue
+ elif name == "kernel_scale" and mode in ("int4", "int8"):
  # For int4/int8, the merged LoRA scale (if any) comes from
  # `_get_kernel_with_merged_lora()`
- store[name] = merged_kernel_scale
+ store[str(idx)] = merged_kernel_scale
  else:
- store[name] = getattr(self, name)
+ store[str(idx)] = getattr(self, name)
+ idx += 1

  def load_own_variables(self, store):
  if not self.lora_enabled:
@@ -353,39 +369,22 @@ class EinsumDense(Layer):
  if not self.built:
  return
  mode = self.quantization_mode
- if mode not in self.quantization_variable_spec:
+ if mode not in self.variable_serialization_spec:
  raise self._quantization_mode_error(mode)

- # Determine whether to use the legacy loading method.
- if "0" in store:
- return self._legacy_load_own_variables(store)
-
- # Load the variables using the name as the key.
- if mode != "gptq":
- self._kernel.assign(store["kernel"])
- if self.bias is not None:
- self.bias.assign(store["bias"])
- for name in self.quantization_variable_spec[mode]:
- getattr(self, name).assign(store[name])
- if self.lora_enabled:
- self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
- self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+ # A saved GPTQ/AWQ quantized model will always be calibrated.
+ self.is_gptq_calibrated = mode == "gptq"
+ self.is_awq_calibrated = mode == "awq"

- def _legacy_load_own_variables(self, store):
- # The keys of the `store` will be saved as determined because the
- # default ordering will change after quantization
- mode = self.quantization_mode
- targets = []
- if mode != "gptq":
- targets.append(self._kernel)
- if self.bias is not None:
- targets.append(self.bias)
- targets.extend(
- getattr(self, name)
- for name in self.quantization_variable_spec[mode]
- )
- for i, variable in enumerate(targets):
- variable.assign(store[str(i)])
+ idx = 0
+ for name in self.variable_serialization_spec[mode]:
+ if name == "kernel":
+ self._kernel.assign(store[str(idx)])
+ elif name == "bias" and self.bias is None:
+ continue
+ else:
+ getattr(self, name).assign(store[str(idx)])
+ idx += 1
  if self.lora_enabled:
  self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
  self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
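The two hunks above switch `save_own_variables`/`load_own_variables` from name-keyed entries to index-keyed entries whose ordering comes from `variable_serialization_spec`, with `bias` skipped when absent. A minimal standalone sketch of the resulting store layout for an int8-quantized layer (names and placeholder values are illustrative, not taken from the package):

# Sketch only: mirrors the save loop in the hunk above.
spec = {"int8": ["kernel", "bias", "kernel_scale"]}

def save_to_store(layer_vars, mode, has_bias=True):
    store = {}
    idx = 0
    for name in spec[mode]:
        if name == "bias" and not has_bias:
            continue  # skipped entirely; later variables shift down one index
        store[str(idx)] = layer_vars[name]
        idx += 1
    return store

# With a bias:    {"0": "K", "1": "B", "2": "S"}
# Without a bias: {"0": "K", "1": "S"}
print(save_to_store({"kernel": "K", "bias": "B", "kernel_scale": "S"}, "int8"))
print(save_to_store({"kernel": "K", "kernel_scale": "S"}, "int8", has_bias=False))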
@@ -410,6 +409,9 @@ class EinsumDense(Layer):
  ),
  "kernel_constraint": constraints.serialize(self.kernel_constraint),
  "bias_constraint": constraints.serialize(self.bias_constraint),
+ "quantization_config": serialization_lib.serialize_keras_object(
+ self.quantization_config
+ ),
  }
  if self.lora_rank:
  config["lora_rank"] = self.lora_rank
@@ -418,53 +420,42 @@ class EinsumDense(Layer):
  config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
  return {**base_config, **config}

- def _check_load_own_variables(self, store):
- all_vars = self._trainable_variables + self._non_trainable_variables
- if len(store.keys()) != len(all_vars):
- if len(all_vars) == 0 and not self.built:
- raise ValueError(
- f"Layer '{self.name}' was never built "
- "and thus it doesn't have any variables. "
- f"However the weights file lists {len(store.keys())} "
- "variables for this layer.\n"
- "In most cases, this error indicates that either:\n\n"
- "1. The layer is owned by a parent layer that "
- "implements a `build()` method, but calling the "
- "parent's `build()` method did NOT create the state of "
- f"the child layer '{self.name}'. A `build()` method "
- "must create ALL state for the layer, including "
- "the state of any children layers.\n\n"
- "2. You need to implement "
- "the `def build_from_config(self, config)` method "
- f"on layer '{self.name}', to specify how to rebuild "
- "it during loading. "
- "In this case, you might also want to implement the "
- "method that generates the build config at saving time, "
- "`def get_build_config(self)`. "
- "The method `build_from_config()` is meant "
- "to create the state "
- "of the layer (i.e. its variables) upon deserialization.",
- )
- raise ValueError(
- f"Layer '{self.name}' expected {len(all_vars)} variables, "
- "but received "
- f"{len(store.keys())} variables during loading. "
- f"Expected: {[v.name for v in all_vars]}"
+ @classmethod
+ def from_config(cls, config):
+ config = config.copy()
+ config["quantization_config"] = (
+ serialization_lib.deserialize_keras_object(
+ config.get("quantization_config", None)
  )
+ )
+ return super().from_config(config)

  @property
- def quantization_variable_spec(self):
- """Returns a dict mapping quantization modes to variable names.
+ def variable_serialization_spec(self):
+ """Returns a dict mapping quantization modes to variable names in order.

  This spec is used by `save_own_variables` and `load_own_variables` to
- determine which variables should be saved/loaded for each quantization
- mode.
+ determine the correct ordering of variables during serialization for
+ each quantization mode. `None` means no quantization.
  """
  return {
- None: [],
- "int8": ["kernel_scale"],
- "int4": ["kernel_scale"],
+ None: [
+ "kernel",
+ "bias",
+ ],
+ "int8": [
+ "kernel",
+ "bias",
+ "kernel_scale",
+ ],
+ "int4": [
+ "kernel",
+ "bias",
+ "kernel_scale",
+ ],
  "float8": [
+ "kernel",
+ "bias",
  "inputs_scale",
  "inputs_amax_history",
  "kernel_scale",
@@ -473,31 +464,48 @@ class EinsumDense(Layer):
  "outputs_grad_amax_history",
  ],
  "gptq": [
+ "bias",
+ "quantized_kernel",
+ "kernel_scale",
+ "kernel_zero",
+ "g_idx",
+ ],
+ "awq": [
+ "bias",
  "quantized_kernel",
  "kernel_scale",
  "kernel_zero",
+ "awq_scales",
  "g_idx",
  ],
  }

  def quantized_build(self, kernel_shape, mode, config=None):
  if mode == "int8":
- self._int8_build(kernel_shape)
+ self._int8_build(kernel_shape, config)
  elif mode == "int4":
- self._int4_build(kernel_shape)
+ self._int4_build(kernel_shape, config)
  elif mode == "float8":
  self._float8_build()
  elif mode == "gptq":
  self._gptq_build(kernel_shape, config)
+ elif mode == "awq":
+ self._awq_build(kernel_shape, config)
  else:
  raise self._quantization_mode_error(mode)
  self._is_quantized = True

- def _int8_build(self, kernel_shape):
+ def _int8_build(self, kernel_shape, config=None):
  self._set_quantization_info()
- self.inputs_quantizer = quantizers.AbsMaxQuantizer(
- axis=self._input_reduced_axes
+ self.inputs_quantizer = (
+ QuantizationConfig.activation_quantizer_or_default(
+ config,
+ quantizers.AbsMaxQuantizer(),
+ )
  )
+ # If the config provided a default AbsMaxQuantizer, we need to
+ # override the axis to match the equation's reduction axes.
+ self.quantization_axis = tuple(self._input_reduced_axes)
  self._kernel = self.add_weight(
  name="kernel",
  shape=kernel_shape,
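`QuantizationConfig.activation_quantizer_or_default` is new in this diff; from its call sites, it appears to pick the activation quantizer from the config when one is given, fall back to the supplied `AbsMaxQuantizer` when there is no config, and come back empty for weight-only configs (which is why the call paths later in this diff guard on `if self.inputs_quantizer:`). A hedged sketch of that dispatch, where the attribute name `activation_quantizer` is an assumption rather than the package's actual field:

# Assumption-laden sketch of the helper's apparent contract, not its source.
def activation_quantizer_or_default(config, default):
    if config is None:
        # No config at all: keep the historical AbsMaxQuantizer behavior.
        return default
    # Hypothetical attribute: a weight-only config would yield None here,
    # so the layer skips input quantization entirely.
    return getattr(config, "activation_quantizer", None)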
@@ -535,12 +543,7 @@ class EinsumDense(Layer):
  columns = kernel_shape[1]
  elif len(kernel_shape) == 3:
  shape = list(self.original_kernel_shape)
- try:
- d_model_dim_index = shape.index(max(shape))
- except ValueError:
- raise TypeError(
- f"Could not determine hidden dimension from shape {shape}"
- )
+ d_model_dim_index = shape.index(max(shape))

  if d_model_dim_index == 0: # QKV projection case
  in_features, heads, head_dim = shape
@@ -566,8 +569,7 @@ class EinsumDense(Layer):
  # For 4-bit weights, we pack two values per byte.
  kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns

- if hasattr(self, "_set_quantization_info"):
- self._set_quantization_info()
+ self._set_quantization_info()

  self.quantized_kernel = self.add_weight(
  name="kernel",
@@ -635,7 +637,128 @@ class EinsumDense(Layer):
  y = self.activation(y)
  return y

- def _int4_build(self, kernel_shape):
+ def _awq_build(self, kernel_shape, config):
+ """Build variables for AWQ quantization.
+
+ AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+ salient weights based on activation magnitudes.
+ """
+ from keras.src.quantizers import awq_core
+
+ # Ensures the forward pass uses the original high-precision kernel
+ # until calibration has been performed.
+ self.is_awq_calibrated = False
+
+ self.original_kernel_shape = kernel_shape
+ if len(kernel_shape) == 2:
+ rows = kernel_shape[0]
+ columns = kernel_shape[1]
+ elif len(kernel_shape) == 3:
+ shape = list(self.original_kernel_shape)
+ d_model_dim_index = shape.index(max(shape))
+
+ if d_model_dim_index == 0: # QKV projection case
+ in_features, heads, head_dim = shape
+ rows, columns = (
+ in_features,
+ heads * head_dim,
+ )
+ elif d_model_dim_index in [1, 2]: # Attention Output case
+ heads, head_dim, out_features = shape
+ rows, columns = (
+ heads * head_dim,
+ out_features,
+ )
+ else:
+ raise ValueError("Could not determine row/column split.")
+ else:
+ raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+ group_size = awq_core.get_group_size_for_layer(self, config)
+ num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+ self.awq_unpacked_column_size = columns
+
+ # For 4-bit weights, we pack two values per byte.
+ kernel_columns = (columns + 1) // 2
+
+ self._set_quantization_info()
+
+ self.quantized_kernel = self.add_weight(
+ name="kernel",
+ shape=(kernel_columns, rows),
+ initializer="zeros",
+ dtype="uint8",
+ trainable=False,
+ )
+
+ self.kernel_scale = self.add_weight(
+ name="kernel_scale",
+ shape=(columns, num_groups),
+ initializer="ones",
+ trainable=False,
+ )
+ self.kernel_zero = self.add_weight(
+ name="zero_point",
+ shape=(columns, num_groups),
+ initializer="zeros",
+ dtype="uint8",
+ trainable=False,
+ )
+
+ # Per-channel AWQ scales from activation magnitudes
+ self.awq_scales = self.add_weight(
+ name="awq_scales",
+ shape=(rows,),
+ initializer="ones",
+ trainable=False,
+ )
+
+ self.g_idx = self.add_weight(
+ name="g_idx",
+ shape=(rows,),
+ initializer="zeros",
+ dtype="float32",
+ trainable=False,
+ )
+
+ def _awq_call(self, inputs, training=False):
+ """Forward pass for AWQ quantized layer."""
+ if not self.is_awq_calibrated:
+ W = self._kernel
+ else:
+ # Unpack 4-bit weights
+ W = quantizers.unpack_int4(
+ self.quantized_kernel,
+ orig_len=self.awq_unpacked_column_size,
+ axis=0,
+ dtype="uint8",
+ )
+ # Dequantize using scale/zero maps
+ W = dequantize_with_sz_map(
+ W,
+ self.kernel_scale,
+ self.kernel_zero,
+ self.g_idx,
+ )
+ W = ops.transpose(W)
+
+ # Apply AWQ scales by dividing to restore original magnitude
+ # (We multiplied by scales before quantization, so divide to undo)
+ # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]
+ # Expand dims for proper broadcasting.
+ W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+ W = ops.reshape(W, self.original_kernel_shape)
+
+ y = ops.einsum(self.equation, inputs, W)
+ if self.bias is not None:
+ y = ops.add(y, self.bias)
+ if self.activation is not None:
+ y = self.activation(y)
+ return y
+
+ def _int4_build(self, kernel_shape, config=None):
  """Build variables for int4 quantization.

  The packed int4 kernel stores two int4 values within a single int8
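To make the `_awq_call` arithmetic above concrete: the unpacked codes are dequantized group-wise via the scale/zero maps (`dequantize_with_sz_map`, with `g_idx` selecting each input feature's group), then divided by `awq_scales` to undo the pre-quantization scaling. A small numpy sketch with made-up toy shapes:

import numpy as np

in_features, out_features, group_size = 8, 4, 4     # toy sizes, not from the package
num_groups = in_features // group_size
q = np.random.randint(0, 16, (out_features, in_features)).astype(np.float32)
scale = np.random.rand(out_features, num_groups).astype(np.float32) + 0.1
zero = np.random.randint(0, 16, (out_features, num_groups)).astype(np.float32)
g_idx = np.arange(in_features) // group_size        # group of each input feature
awq_scales = np.random.rand(in_features).astype(np.float32) + 0.5

# Group-wise dequantization: (q - zero) * scale, looked up per input feature's group
deq = (q - zero[:, g_idx]) * scale[:, g_idx]        # (out_features, in_features)
# Transpose to (in_features, out_features) and undo the AWQ pre-scaling
w = deq.T / awq_scales[:, None]
print(w.shape)  # (8, 4)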
@@ -647,9 +770,15 @@ class EinsumDense(Layer):
  self._set_quantization_info()

  # Quantizer for the inputs (per the reduced axes)
- self.inputs_quantizer = quantizers.AbsMaxQuantizer(
- axis=self._input_reduced_axes
+ self.inputs_quantizer = (
+ QuantizationConfig.activation_quantizer_or_default(
+ config,
+ quantizers.AbsMaxQuantizer(),
+ )
  )
+ # If the config provided a default AbsMaxQuantizer, we need to
+ # override the axis to match the equation's reduction axes.
+ self.quantization_axis = tuple(self._input_reduced_axes)

  # Choose the axis to perform int4 packing - use the first reduced axis
  # for the kernel (analogous to the input dimension of a Dense layer).
@@ -771,13 +900,36 @@ class EinsumDense(Layer):
  )
  return (inputs_grad, None, None)

- inputs, inputs_scale = self.inputs_quantizer(inputs)
- x = ops.einsum(self.equation, inputs, kernel)
- # Deal with `inputs_scale`
- inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
- # De-scale outputs
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ if self.inputs_quantizer:
+ inputs, inputs_scale = self.inputs_quantizer(
+ inputs, axis=self.quantization_axis
+ )
+ # Align `inputs_scale` axes with the output
+ # for correct broadcasting
+ inputs_scale = self._adjust_scale_for_quant(
+ inputs_scale, "input"
+ )
+ x = ops.einsum(self.equation, inputs, kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ else:
+ # Weight-only quantization: dequantize kernel and use float
+ # einsum. This is a workaround for PyTorch's einsum which
+ # doesn't support mixed-precision inputs (float input,
+ # int8 kernel).
+ if backend.backend() == "torch":
+ kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+ float_kernel = ops.divide(
+ ops.cast(kernel, dtype=self.compute_dtype),
+ kernel_scale,
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ else:
+ x = ops.einsum(self.equation, inputs, kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, kernel_scale)
  return x, grad_fn

  x = einsum_with_inputs_gradient(
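The de-scaling in the hunk above works because abs-max quantization is just a per-axis rescaling: if x_q ≈ x * s_x and w_q ≈ w * s_w, then einsum(x_q, w_q) ≈ einsum(x, w) * s_x * s_w, so dividing by the product of scales recovers the float result (and dividing by the kernel scale alone covers the weight-only branch). A tiny numeric check, not using the package's quantizer:

import numpy as np

x = np.random.randn(2, 3).astype(np.float32)
w = np.random.randn(3, 4).astype(np.float32)

s_x = 127.0 / np.abs(x).max()   # per-tensor abs-max scales into the int8 range
s_w = 127.0 / np.abs(w).max()
x_q = np.round(x * s_x)
w_q = np.round(w * s_w)

approx = (x_q @ w_q) / (s_x * s_w)     # integer matmul, then de-scale
print(np.max(np.abs(approx - x @ w)))  # small quantization error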
@@ -847,17 +999,38 @@ class EinsumDense(Layer):
  return (inputs_grad, None, None)

  # Quantize inputs per `self.inputs_quantizer`.
- inputs_q, inputs_scale = self.inputs_quantizer(inputs)
-
- # Compute einsum on quantized inputs and unpacked int4 kernel.
- x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-
- # Align `inputs_scale` axes with the output for correct broadcasting
- inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
-
- # De-scale outputs.
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ if self.inputs_quantizer:
+ inputs_q, inputs_scale = self.inputs_quantizer(
+ inputs, axis=self.quantization_axis
+ )
+ # Align `inputs_scale` axes with the output
+ # for correct broadcasting
+ inputs_scale = self._adjust_scale_for_quant(
+ inputs_scale, "input"
+ )
+ x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ else:
+ # Weight-only quantization: dequantize kernel and use float
+ # einsum. This is a workaround for PyTorch's einsum which
+ # doesn't support mixed-precision inputs (float input,
+ # int4 kernel).
+ if backend.backend() == "torch":
+ # Align `kernel_scale` to the same layout as
+ # `unpacked_kernel`.
+ kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+ float_kernel = ops.divide(
+ ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+ kernel_scale,
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ else:
+ x = ops.einsum(self.equation, inputs, unpacked_kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, kernel_scale)
  return x, grad_fn

  x = einsum_with_inputs_gradient(
@@ -971,30 +1144,40 @@ class EinsumDense(Layer):
  x = self.activation(x)
  return x

- def quantize(self, mode, type_check=True, config=None):
+ def quantize(self, mode=None, type_check=True, config=None):
  # Prevent quantization of the subclasses
  if type_check and (type(self) is not EinsumDense):
  raise self._not_implemented_error(self.quantize)

+ self.quantization_config = config
+
  kernel_shape = self._kernel.shape
- if mode in ("int8", "int4", "gptq"):
+ if mode in ("int8", "int4", "gptq", "awq"):
  self._set_quantization_info()

  if mode == "int8":
  # Quantize `self._kernel` to int8 and compute corresponding scale
- kernel_value, kernel_scale = quantizers.abs_max_quantize(
- self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
+ weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+ self.quantization_config,
+ quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+ )
+ kernel_value, kernel_scale = weight_quantizer(
+ self._kernel, to_numpy=True
  )
  kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
  del self._kernel
  elif mode == "int4":
  # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
- kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
- self._kernel,
- axis=self._kernel_reduced_axes,
- value_range=(-8, 7),
- dtype="int8",
- to_numpy=True,
+ weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+ self.quantization_config,
+ quantizers.AbsMaxQuantizer(
+ axis=self._kernel_reduced_axes,
+ value_range=(-8, 7),
+ output_dtype="int8",
+ ),
+ )
+ kernel_value_int4, kernel_scale = weight_quantizer(
+ self._kernel, to_numpy=True
  )
  kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

@@ -1005,7 +1188,7 @@ class EinsumDense(Layer):
  )
  kernel_value = packed_kernel_value
  del self._kernel
- self.quantized_build(kernel_shape, mode, config)
+ self.quantized_build(kernel_shape, mode, self.quantization_config)

  # Assign values to the newly created variables.
  if mode in ("int8", "int4"):
@@ -1016,7 +1199,9 @@ class EinsumDense(Layer):
  if self.dtype_policy.quantization_mode is None:
  policy_name = mode
  if mode == "gptq":
- policy_name = config.dtype_policy_string()
+ policy_name = self.quantization_config.dtype_policy_string()
+ elif mode == "awq":
+ policy_name = self.quantization_config.dtype_policy_string()
  policy = dtype_policies.get(
  f"{policy_name}_from_{self.dtype_policy.name}"
  )
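The policy names composed above follow the existing `"<mode>_from_<float policy>"` convention for quantized dtype policies (for gptq/awq the `<mode>` part comes from the config's `dtype_policy_string()`). A short sketch of how such a name resolves through the public API, shown for the long-standing int8 case:

import keras

policy = keras.dtype_policies.get("int8_from_float32")
print(policy.name)               # "int8_from_float32"
print(policy.quantization_mode)  # "int8"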
@@ -1080,7 +1265,7 @@ class EinsumDense(Layer):
  This is `None` if the layer is not quantized.
  """
  # If not a quantized layer, return the full-precision kernel directly.
- if self.dtype_policy.quantization_mode in (None, "gptq"):
+ if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
  return self.kernel, None

  # If quantized but LoRA is not enabled, return the original quantized
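Taken together, the changes in this file keep the user-facing quantization workflow the same while adding the new modes and the optional `config` argument. A minimal usage sketch for the unchanged int8 path (the AWQ path additionally requires a calibration dataset and an AWQ config, which are added elsewhere in this release under keras/src/quantizers/):

import numpy as np
import keras

layer = keras.layers.EinsumDense("ab,bc->ac", output_shape=16, bias_axes="c")
layer.build((None, 32))

layer.quantize("int8")  # post-training int8 quantization
y = layer(np.random.randn(4, 32).astype("float32"))
print(y.shape)  # (4, 16)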