keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/einsum_dense.py
@@ -1,3 +1,4 @@
+import math
 import re
 import string
 
@@ -5,6 +6,7 @@ import ml_dtypes
 import numpy as np
 
 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -14,6 +16,9 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.EinsumDense")
@@ -131,6 +136,8 @@ class EinsumDense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        gptq_unpacked_column_size=None,
+        quantization_config=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -150,6 +157,8 @@
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.gptq_unpacked_column_size = gptq_unpacked_column_size
+        self.quantization_config = quantization_config
 
     def build(self, input_shape):
         shape_data = _analyze_einsum_string(
@@ -162,12 +171,16 @@
         self.full_output_shape = tuple(full_output_shape)
         self.input_spec = InputSpec(ndim=len(input_shape))
         if self.quantization_mode is not None:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         # Skip creating a duplicate kernel variable when the layer is already
         # quantized to int8 or int4, because `quantized_build` has created the
         # appropriate kernel variable. For other modes (e.g., float8 or no
         # quantization), we still need the floating-point kernel.
-        if self.quantization_mode not in ("int8", "int4"):
+        if self.quantization_mode not in ("int8", "int4", "gptq", "awq"):
             # If the layer is quantized to int8, `self._kernel` will be added
             # in `self._int8_build`. Therefore, we skip it here.
             self._kernel = self.add_weight(
@@ -197,15 +210,62 @@ class EinsumDense(Layer):
 
     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_awq = mode == "awq"
+        is_int4 = mode == "int4"
+        gptq_calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        awq_calibrated = bool(getattr(self, "is_awq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and gptq_calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+        # Handle int4 unpacking cases in one place
+        if is_int4:
+            kernel = quantizers.unpack_int4(
+                kernel,
+                self._orig_length_along_pack_axis,
+                self._int4_pack_axis,
+            )
+        elif is_gptq and gptq_calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.gptq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+        elif is_awq and awq_calibrated:
+            # AWQ always uses 4-bit quantization
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+
+        # Apply LoRA if enabled
         if self.lora_enabled:
-            return self._kernel + (
-                self.lora_alpha / self.lora_rank
-            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
-        return self._kernel
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_kernel_a, self.lora_kernel_b
+            )
+
+        return kernel
 
     def compute_output_shape(self, _):
         return self.full_output_shape
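
For context on the LoRA branch at the end of the new `kernel` property: the returned kernel is the stored (possibly unpacked) kernel plus a scaled low-rank update. Below is a minimal NumPy sketch of that arithmetic only, with made-up shapes standing in for the layer's `lora_kernel_a`/`lora_kernel_b` variables; it is not part of the diff.

    import numpy as np

    # Hypothetical shapes: a (64, 128) kernel with a rank-4 LoRA adapter.
    rng = np.random.default_rng(0)
    kernel = rng.normal(size=(64, 128))
    lora_a = rng.normal(size=(64, 4))    # stands in for lora_kernel_a
    lora_b = rng.normal(size=(4, 128))   # stands in for lora_kernel_b
    lora_rank, lora_alpha = 4, 8

    # Same formula as the property: kernel + (alpha / rank) * (A @ B)
    effective_kernel = kernel + (lora_alpha / lora_rank) * (lora_a @ lora_b)
    print(effective_kernel.shape)  # (64, 128)
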
@@ -239,6 +299,10 @@
             raise ValueError(
                 "lora is already enabled. This can only be done once per layer."
             )
+        if self.quantization_mode == "gptq":
+            raise NotImplementedError(
+                "lora is not currently supported with GPTQ quantization."
+            )
         self._tracker.unlock()
         # Determine the appropriate (unpacked) kernel shape for LoRA.
         if self.quantization_mode == "int4":
@@ -277,26 +341,26 @@
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
-        target_variables = [kernel_value]
-        if self.bias is not None:
-            target_variables.append(self.bias)
-        if self.quantization_mode is not None:
-            if self.quantization_mode in ("int8", "int4"):
-                target_variables.append(kernel_scale)
-            elif self.quantization_mode == "float8":
-                target_variables.append(self.inputs_scale)
-                target_variables.append(self.inputs_amax_history)
-                target_variables.append(self.kernel_scale)
-                target_variables.append(self.kernel_amax_history)
-                target_variables.append(self.outputs_grad_scale)
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
+        # for None/gptq)
+        kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -304,25 +368,23 @@
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        target_variables = [self._kernel]
-        if self.bias is not None:
-            target_variables.append(self.bias)
-        if self.quantization_mode is not None:
-            if self.quantization_mode in ("int8", "int4"):
-                target_variables.append(self.kernel_scale)
-            elif self.quantization_mode == "float8":
-                target_variables.append(self.inputs_scale)
-                target_variables.append(self.inputs_amax_history)
-                target_variables.append(self.kernel_scale)
-                target_variables.append(self.kernel_amax_history)
-                target_variables.append(self.outputs_grad_scale)
-                target_variables.append(self.outputs_grad_amax_history)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # A saved GPTQ/AWQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"
+        self.is_awq_calibrated = mode == "awq"
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -347,64 +409,103 @@
             ),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
+        if self.gptq_unpacked_column_size:
+            config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
+
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
-    # Quantization-related (int8 and float8) methods
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "float8": [
+                "kernel",
+                "bias",
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+            "awq": [
+                "bias",
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "awq_scales",
+                "g_idx",
+            ],
+        }
 
-    def quantized_build(self, kernel_shape, mode):
+    def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
+        elif mode == "gptq":
+            self._gptq_build(kernel_shape, config)
+        elif mode == "awq":
+            self._awq_build(kernel_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, kernel_shape):
+    def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=self._input_reduced_axes
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
@@ -420,7 +521,244 @@
             trainable=False,
         )
 
-    def _int4_build(self, kernel_shape):
+    def _gptq_build(self, kernel_shape, config):
+        """
+        Allocate quantized kernel & params for EinsumDense.
+
+        Args:
+            kernel_shape: tuple/list; the layer's original kernel shape, e.g.
+                [in_features, out_features] or [in_features, heads, head_dim].
+            group_size: int; contiguous input-group size for quantization
+                (=-1 means per-output-channel with no grouping).
+        """
+        from keras.src.quantizers import gptq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_gptq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.gptq_unpacked_column_size = columns
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, n_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, n_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
+        if not self.is_gptq_calibrated:
+            W = self._kernel
+        else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
+            W = (
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.gptq_unpacked_column_size,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+        W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _awq_build(self, kernel_shape, config):
+        """Build variables for AWQ quantization.
+
+        AWQ uses 4-bit quantization with per-channel AWQ scales that protect
+        salient weights based on activation magnitudes.
+        """
+        from keras.src.quantizers import awq_core
+
+        # Ensures the forward pass uses the original high-precision kernel
+        # until calibration has been performed.
+        self.is_awq_calibrated = False
+
+        self.original_kernel_shape = kernel_shape
+        if len(kernel_shape) == 2:
+            rows = kernel_shape[0]
+            columns = kernel_shape[1]
+        elif len(kernel_shape) == 3:
+            shape = list(self.original_kernel_shape)
+            d_model_dim_index = shape.index(max(shape))
+
+            if d_model_dim_index == 0:  # QKV projection case
+                in_features, heads, head_dim = shape
+                rows, columns = (
+                    in_features,
+                    heads * head_dim,
+                )
+            elif d_model_dim_index in [1, 2]:  # Attention Output case
+                heads, head_dim, out_features = shape
+                rows, columns = (
+                    heads * head_dim,
+                    out_features,
+                )
+            else:
+                raise ValueError("Could not determine row/column split.")
+        else:
+            raise ValueError("AWQ quantization only supports 2D or 3D kernels.")
+
+        group_size = awq_core.get_group_size_for_layer(self, config)
+        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)
+
+        self.awq_unpacked_column_size = columns
+
+        # For 4-bit weights, we pack two values per byte.
+        kernel_columns = (columns + 1) // 2
+
+        self._set_quantization_info()
+
+        self.quantized_kernel = self.add_weight(
+            name="kernel",
+            shape=(kernel_columns, rows),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(columns, num_groups),
+            initializer="ones",
+            trainable=False,
+        )
+        self.kernel_zero = self.add_weight(
+            name="zero_point",
+            shape=(columns, num_groups),
+            initializer="zeros",
+            dtype="uint8",
+            trainable=False,
+        )
+
+        # Per-channel AWQ scales from activation magnitudes
+        self.awq_scales = self.add_weight(
+            name="awq_scales",
+            shape=(rows,),
+            initializer="ones",
+            trainable=False,
+        )
+
+        self.g_idx = self.add_weight(
+            name="g_idx",
+            shape=(rows,),
+            initializer="zeros",
+            dtype="float32",
+            trainable=False,
+        )
+
+    def _awq_call(self, inputs, training=False):
+        """Forward pass for AWQ quantized layer."""
+        if not self.is_awq_calibrated:
+            W = self._kernel
+        else:
+            # Unpack 4-bit weights
+            W = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.awq_unpacked_column_size,
+                axis=0,
+                dtype="uint8",
+            )
+            # Dequantize using scale/zero maps
+            W = dequantize_with_sz_map(
+                W,
+                self.kernel_scale,
+                self.kernel_zero,
+                self.g_idx,
+            )
+            W = ops.transpose(W)
+
+            # Apply AWQ scales by dividing to restore original magnitude
+            # (We multiplied by scales before quantization, so divide to undo)
+            # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]
+            # Expand dims for proper broadcasting.
+            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))
+
+        W = ops.reshape(W, self.original_kernel_shape)
+
+        y = ops.einsum(self.equation, inputs, W)
+        if self.bias is not None:
+            y = ops.add(y, self.bias)
+        if self.activation is not None:
+            y = self.activation(y)
+        return y
+
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.
 
@@ -432,9 +770,15 @@
         self._set_quantization_info()
 
         # Quantizer for the inputs (per the reduced axes)
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=self._input_reduced_axes
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(),
+            )
         )
+        # If the config provided a default AbsMaxQuantizer, we need to
+        # override the axis to match the equation's reduction axes.
+        self.quantization_axis = tuple(self._input_reduced_axes)
 
         # Choose the axis to perform int4 packing - use the first reduced axis
         # for the kernel (analogous to the input dimension of a Dense layer).
@@ -556,13 +900,36 @@
                 )
                 return (inputs_grad, None, None)
 
-            inputs, inputs_scale = self.inputs_quantizer(inputs)
-            x = ops.einsum(self.equation, inputs, kernel)
-            # Deal with `inputs_scale`
-            inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
-            # De-scale outputs
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
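
The new weight-only branch in the int8 call path either de-scales the einsum output (default) or, on the torch backend, dequantizes the kernel before the einsum. Both orderings compute the same result when the kernel scale broadcasts per output column; a quick NumPy check of that identity (independent of the Keras code paths, shapes chosen arbitrarily):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 8)).astype("float32")
    q_kernel = rng.integers(-127, 128, size=(8, 4)).astype("int8")
    scale = rng.uniform(0.01, 0.1, size=(1, 4)).astype("float32")  # per column

    descale_after = np.einsum("bi,io->bo", x, q_kernel.astype("float32")) / scale
    dequant_first = np.einsum("bi,io->bo", x, q_kernel.astype("float32") / scale)
    print(np.allclose(descale_after, dequant_first))  # True
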
@@ -632,17 +999,38 @@
                 return (inputs_grad, None, None)
 
             # Quantize inputs per `self.inputs_quantizer`.
-            inputs_q, inputs_scale = self.inputs_quantizer(inputs)
-
-            # Compute einsum on quantized inputs and unpacked int4 kernel.
-            x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-
-            # Align `inputs_scale` axes with the output for correct broadcasting
-            inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
-
-            # De-scale outputs.
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            if self.inputs_quantizer:
+                inputs_q, inputs_scale = self.inputs_quantizer(
+                    inputs, axis=self.quantization_axis
+                )
+                # Align `inputs_scale` axes with the output
+                # for correct broadcasting
+                inputs_scale = self._adjust_scale_for_quant(
+                    inputs_scale, "input"
+                )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            else:
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
@@ -756,30 +1144,40 @@
             x = self.activation(x)
         return x
 
-    def quantize(self, mode, type_check=True):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not EinsumDense):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
-        if mode in ("int8", "int4"):
+        if mode in ("int8", "int4", "gptq", "awq"):
             self._set_quantization_info()
 
         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=self._kernel_reduced_axes,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=self._kernel_reduced_axes,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
            )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
 
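
Hedged usage sketch for the extended `quantize` entry point: the `"int8"` path below uses the long-standing `Model.quantize` API and should work as shown; the new `"gptq"`/`"awq"` modes added in this diff additionally expect a quantization config object (see the new `gptq_config.py` / `awq_config.py` modules in the file list), whose exact constructor arguments are not shown in this hunk, so they are omitted here.

    import keras

    inputs = keras.Input((16,))
    outputs = keras.layers.EinsumDense("ab,bc->ac", output_shape=8)(inputs)
    model = keras.Model(inputs, outputs)

    model.quantize("int8")  # quantizes supported layers in place
    print(model.layers[-1].quantization_mode)  # "int8"
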
@@ -790,7 +1188,7 @@
             )
             kernel_value = packed_kernel_value
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
+        self.quantized_build(kernel_shape, mode, self.quantization_config)
 
         # Assign values to the newly created variables.
         if mode in ("int8", "int4"):
@@ -799,7 +1197,14 @@
 
         # Set new dtype policy
         if self.dtype_policy.quantization_mode is None:
-            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            policy_name = mode
+            if mode == "gptq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            elif mode == "awq":
+                policy_name = self.quantization_config.dtype_policy_string()
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
 
     def _get_kernel_scale_shape(self, kernel_shape):
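
The policy swap at the end of `quantize()` resolves quantized dtype policies by name, in the form `"<mode>_from_<previous policy name>"`. For the built-in modes this is an existing lookup; the GPTQ/AWQ prefix now comes from the config's `dtype_policy_string()`, whose exact format is not shown in this hunk. A minimal check of the existing lookup (assumes the int8 policy name below is available, as it is in current Keras releases):

    import keras

    policy = keras.dtype_policies.get("int8_from_float32")
    print(policy.name, policy.quantization_mode)  # int8_from_float32 int8
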
@@ -860,7 +1265,7 @@
         This is `None` if the layer is not quantized.
         """
         # If not a quantized layer, return the full-precision kernel directly.
-        if self.dtype_policy.quantization_mode is None:
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.kernel, None
 
         # If quantized but LoRA is not enabled, return the original quantized