keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,7 @@ import ml_dtypes
  import numpy as np

  from keras.src import activations
+ from keras.src import backend
  from keras.src import constraints
  from keras.src import dtype_policies
  from keras.src import initializers
@@ -13,12 +14,11 @@ from keras.src import ops
  from keras.src import quantizers
  from keras.src import regularizers
  from keras.src.api_export import keras_export
- from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
- from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
  from keras.src.layers.input_spec import InputSpec
  from keras.src.layers.layer import Layer
- from keras.src.quantizers.gptq_config import GPTQConfig
+ from keras.src.quantizers.quantization_config import QuantizationConfig
  from keras.src.quantizers.quantizers import dequantize_with_sz_map
+ from keras.src.saving import serialization_lib


  @keras_export("keras.layers.EinsumDense")
@@ -136,6 +136,8 @@ class EinsumDense(Layer):
  bias_constraint=None,
  lora_rank=None,
  lora_alpha=None,
+ gptq_unpacked_column_size=None,
+ quantization_config=None,
  **kwargs,
  ):
  super().__init__(**kwargs)
@@ -155,6 +157,8 @@ class EinsumDense(Layer):
  self.lora_rank = lora_rank
  self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
  self.lora_enabled = False
+ self.gptq_unpacked_column_size = gptq_unpacked_column_size
+ self.quantization_config = quantization_config

  def build(self, input_shape):
  shape_data = _analyze_einsum_string(
@@ -170,6 +174,7 @@ class EinsumDense(Layer):
  self.quantized_build(
  kernel_shape,
  mode=self.quantization_mode,
+ config=self.quantization_config,
  )
  # Skip creating a duplicate kernel variable when the layer is already
  # quantized to int8 or int4, because `quantized_build` has created the
@@ -205,24 +210,51 @@ class EinsumDense(Layer):

  @property
  def kernel(self):
+ from keras.src.quantizers import gptq_core
+
  if not self.built:
  raise AttributeError(
  "You must build the layer before accessing `kernel`."
  )
- if (
- getattr(self, "is_gptq_calibrated", False)
- and self.quantization_mode == "gptq"
- ):
- return self.quantized_kernel
- kernel = self._kernel
- if self.quantization_mode == "int4":
- kernel = quantizers.unpack_int4(
- kernel, self._orig_length_along_pack_axis, self._int4_pack_axis
- )
+
+ mode = self.quantization_mode
+ is_gptq = mode == "gptq"
+ is_int4 = mode == "int4"
+ calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+ gptq_bits = (
+ gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+ )
+
+ # Decide the source tensor first (packed vs already-quantized vs plain
+ # kernel)
+ if is_gptq and calibrated and gptq_bits != 4:
+ # calibrated GPTQ, not 4-bit, no unpacking needed
+ kernel = self.quantized_kernel
+ else:
+ # Start with the stored kernel
+ kernel = getattr(self, "_kernel", None)
+
+ # Handle int4 unpacking cases in one place
+ if is_int4:
+ kernel = quantizers.unpack_int4(
+ kernel,
+ self._orig_length_along_pack_axis,
+ self._int4_pack_axis,
+ )
+ elif is_gptq and calibrated and gptq_bits == 4:
+ kernel = quantizers.unpack_int4(
+ self.quantized_kernel,
+ orig_len=self.gptq_unpacked_column_size,
+ axis=0,
+ dtype="uint8",
+ )
+
+ # Apply LoRA if enabled
  if self.lora_enabled:
- return kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+ kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
  self.lora_kernel_a, self.lora_kernel_b
  )
+
  return kernel

  def compute_output_shape(self, _):
@@ -300,25 +332,25 @@ class EinsumDense(Layer):
  if not self.built:
  return
  mode = self.quantization_mode
- if mode not in self.quantization_variable_spec:
+ if mode not in self.variable_serialization_spec:
  raise self._quantization_mode_error(mode)

  # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
  # for None/gptq)
  kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
- # Save the variables using the name as the key.
- if mode != "gptq":
- store["kernel"] = kernel_value
- if self.bias is not None:
- store["bias"] = self.bias
- for name in self.quantization_variable_spec[mode]:
- if name == "kernel_scale" and mode in ("int4", "int8"):
+ idx = 0
+ for name in self.variable_serialization_spec[mode]:
+ if name == "kernel":
+ store[str(idx)] = kernel_value
+ elif name == "bias" and self.bias is None:
+ continue
+ elif name == "kernel_scale" and mode in ("int4", "int8"):
  # For int4/int8, the merged LoRA scale (if any) comes from
  # `_get_kernel_with_merged_lora()`
- store[name] = merged_kernel_scale
+ store[str(idx)] = merged_kernel_scale
  else:
- store[name] = getattr(self, name)
+ store[str(idx)] = getattr(self, name)
+ idx += 1

  def load_own_variables(self, store):
  if not self.lora_enabled:
@@ -327,39 +359,21 @@ class EinsumDense(Layer):
  if not self.built:
  return
  mode = self.quantization_mode
- if mode not in self.quantization_variable_spec:
+ if mode not in self.variable_serialization_spec:
  raise self._quantization_mode_error(mode)

- # Determine whether to use the legacy loading method.
- if "0" in store:
- return self._legacy_load_own_variables(store)
-
- # Load the variables using the name as the key.
- if mode != "gptq":
- self._kernel.assign(store["kernel"])
- if self.bias is not None:
- self.bias.assign(store["bias"])
- for name in self.quantization_variable_spec[mode]:
- getattr(self, name).assign(store[name])
- if self.lora_enabled:
- self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
- self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
+ # A saved GPTQ quantized model will always be calibrated.
+ self.is_gptq_calibrated = mode == "gptq"

- def _legacy_load_own_variables(self, store):
- # The keys of the `store` will be saved as determined because the
- # default ordering will change after quantization
- mode = self.quantization_mode
- targets = []
- if mode != "gptq":
- targets.append(self._kernel)
- if self.bias is not None:
- targets.append(self.bias)
- targets.extend(
- getattr(self, name)
- for name in self.quantization_variable_spec[mode]
- )
- for i, variable in enumerate(targets):
- variable.assign(store[str(i)])
+ idx = 0
+ for name in self.variable_serialization_spec[mode]:
+ if name == "kernel":
+ self._kernel.assign(store[str(idx)])
+ elif name == "bias" and self.bias is None:
+ continue
+ else:
+ getattr(self, name).assign(store[str(idx)])
+ idx += 1
  if self.lora_enabled:
  self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
  self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -384,59 +398,53 @@ class EinsumDense(Layer):
  ),
  "kernel_constraint": constraints.serialize(self.kernel_constraint),
  "bias_constraint": constraints.serialize(self.bias_constraint),
+ "quantization_config": serialization_lib.serialize_keras_object(
+ self.quantization_config
+ ),
  }
  if self.lora_rank:
  config["lora_rank"] = self.lora_rank
  config["lora_alpha"] = self.lora_alpha
+ if self.gptq_unpacked_column_size:
+ config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
  return {**base_config, **config}

- def _check_load_own_variables(self, store):
- all_vars = self._trainable_variables + self._non_trainable_variables
- if len(store.keys()) != len(all_vars):
- if len(all_vars) == 0 and not self.built:
- raise ValueError(
- f"Layer '{self.name}' was never built "
- "and thus it doesn't have any variables. "
- f"However the weights file lists {len(store.keys())} "
- "variables for this layer.\n"
- "In most cases, this error indicates that either:\n\n"
- "1. The layer is owned by a parent layer that "
- "implements a `build()` method, but calling the "
- "parent's `build()` method did NOT create the state of "
- f"the child layer '{self.name}'. A `build()` method "
- "must create ALL state for the layer, including "
- "the state of any children layers.\n\n"
- "2. You need to implement "
- "the `def build_from_config(self, config)` method "
- f"on layer '{self.name}', to specify how to rebuild "
- "it during loading. "
- "In this case, you might also want to implement the "
- "method that generates the build config at saving time, "
- "`def get_build_config(self)`. "
- "The method `build_from_config()` is meant "
- "to create the state "
- "of the layer (i.e. its variables) upon deserialization.",
- )
- raise ValueError(
- f"Layer '{self.name}' expected {len(all_vars)} variables, "
- "but received "
- f"{len(store.keys())} variables during loading. "
- f"Expected: {[v.name for v in all_vars]}"
+ @classmethod
+ def from_config(cls, config):
+ config = config.copy()
+ config["quantization_config"] = (
+ serialization_lib.deserialize_keras_object(
+ config.get("quantization_config", None)
  )
+ )
+ return super().from_config(config)

  @property
- def quantization_variable_spec(self):
- """Returns a dict mapping quantization modes to variable names.
+ def variable_serialization_spec(self):
+ """Returns a dict mapping quantization modes to variable names in order.

  This spec is used by `save_own_variables` and `load_own_variables` to
- determine which variables should be saved/loaded for each quantization
- mode.
+ determine the correct ordering of variables during serialization for
+ each quantization mode. `None` means no quantization.
  """
  return {
- None: [],
- "int8": ["kernel_scale"],
- "int4": ["kernel_scale"],
+ None: [
+ "kernel",
+ "bias",
+ ],
+ "int8": [
+ "kernel",
+ "bias",
+ "kernel_scale",
+ ],
+ "int4": [
+ "kernel",
+ "bias",
+ "kernel_scale",
+ ],
  "float8": [
+ "kernel",
+ "bias",
  "inputs_scale",
  "inputs_amax_history",
  "kernel_scale",
@@ -445,6 +453,7 @@ class EinsumDense(Layer):
  "outputs_grad_amax_history",
  ],
  "gptq": [
+ "bias",
  "quantized_kernel",
  "kernel_scale",
  "kernel_zero",
@@ -454,9 +463,9 @@ class EinsumDense(Layer):

  def quantized_build(self, kernel_shape, mode, config=None):
  if mode == "int8":
- self._int8_build(kernel_shape)
+ self._int8_build(kernel_shape, config)
  elif mode == "int4":
- self._int4_build(kernel_shape)
+ self._int4_build(kernel_shape, config)
  elif mode == "float8":
  self._float8_build()
  elif mode == "gptq":
@@ -465,11 +474,17 @@ class EinsumDense(Layer):
  raise self._quantization_mode_error(mode)
  self._is_quantized = True

- def _int8_build(self, kernel_shape):
+ def _int8_build(self, kernel_shape, config=None):
  self._set_quantization_info()
- self.inputs_quantizer = quantizers.AbsMaxQuantizer(
- axis=self._input_reduced_axes
+ self.inputs_quantizer = (
+ QuantizationConfig.activation_quantizer_or_default(
+ config,
+ quantizers.AbsMaxQuantizer(),
+ )
  )
+ # If the config provided a default AbsMaxQuantizer, we need to
+ # override the axis to match the equation's reduction axes.
+ self.quantization_axis = tuple(self._input_reduced_axes)
  self._kernel = self.add_weight(
  name="kernel",
  shape=kernel_shape,
@@ -495,6 +510,8 @@ class EinsumDense(Layer):
  group_size: int; contiguous input-group size for quantization
  (=-1 means per-output-channel with no grouping).
  """
+ from keras.src.quantizers import gptq_core
+
  # Ensures the forward pass uses the original high-precision kernel
  # until calibration has been performed.
  self.is_gptq_calibrated = False
@@ -505,12 +522,7 @@ class EinsumDense(Layer):
  columns = kernel_shape[1]
  elif len(kernel_shape) == 3:
  shape = list(self.original_kernel_shape)
- try:
- d_model_dim_index = shape.index(max(shape))
- except ValueError:
- raise TypeError(
- f"Could not determine hidden dimension from shape {shape}"
- )
+ d_model_dim_index = shape.index(max(shape))

  if d_model_dim_index == 0: # QKV projection case
  in_features, heads, head_dim = shape
@@ -527,18 +539,20 @@ class EinsumDense(Layer):
  else:
  raise ValueError("Could not determine row/column split.")

- group_size = self._get_gptq_group_size(config)
- if group_size == -1:
- n_groups = 1
- else:
- n_groups = math.ceil(rows / group_size)
+ group_size = gptq_core.get_group_size_for_layer(self, config)
+ n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)

- if hasattr(self, "_set_quantization_info"):
- self._set_quantization_info()
+ self.gptq_unpacked_column_size = columns
+
+ weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+ # For 4-bit weights, we pack two values per byte.
+ kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns
+
+ self._set_quantization_info()

  self.quantized_kernel = self.add_weight(
  name="kernel",
- shape=(columns, rows),
+ shape=(kernel_columns, rows),
  initializer="zeros",
  dtype="uint8",
  trainable=False,
@@ -567,11 +581,26 @@ class EinsumDense(Layer):
  )

  def _gptq_call(self, inputs, training=False):
+ from keras.src.quantizers import gptq_core
+
  if not self.is_gptq_calibrated:
  W = self._kernel
  else:
+ should_unpack = (
+ gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+ )
+ W = (
+ quantizers.unpack_int4(
+ self.quantized_kernel,
+ orig_len=self.gptq_unpacked_column_size,
+ axis=0,
+ dtype="uint8",
+ )
+ if should_unpack
+ else self.quantized_kernel
+ )
  W = dequantize_with_sz_map(
- self.quantized_kernel,
+ W,
  self.kernel_scale,
  self.kernel_zero,
  self.g_idx,
@@ -587,7 +616,7 @@ class EinsumDense(Layer):
  y = self.activation(y)
  return y

- def _int4_build(self, kernel_shape):
+ def _int4_build(self, kernel_shape, config=None):
  """Build variables for int4 quantization.

  The packed int4 kernel stores two int4 values within a single int8
@@ -599,9 +628,15 @@ class EinsumDense(Layer):
  self._set_quantization_info()

  # Quantizer for the inputs (per the reduced axes)
- self.inputs_quantizer = quantizers.AbsMaxQuantizer(
- axis=self._input_reduced_axes
+ self.inputs_quantizer = (
+ QuantizationConfig.activation_quantizer_or_default(
+ config,
+ quantizers.AbsMaxQuantizer(),
+ )
  )
+ # If the config provided a default AbsMaxQuantizer, we need to
+ # override the axis to match the equation's reduction axes.
+ self.quantization_axis = tuple(self._input_reduced_axes)

  # Choose the axis to perform int4 packing - use the first reduced axis
  # for the kernel (analogous to the input dimension of a Dense layer).
@@ -723,13 +758,36 @@ class EinsumDense(Layer):
  )
  return (inputs_grad, None, None)

- inputs, inputs_scale = self.inputs_quantizer(inputs)
- x = ops.einsum(self.equation, inputs, kernel)
- # Deal with `inputs_scale`
- inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
- # De-scale outputs
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ if self.inputs_quantizer:
+ inputs, inputs_scale = self.inputs_quantizer(
+ inputs, axis=self.quantization_axis
+ )
+ # Align `inputs_scale` axes with the output
+ # for correct broadcasting
+ inputs_scale = self._adjust_scale_for_quant(
+ inputs_scale, "input"
+ )
+ x = ops.einsum(self.equation, inputs, kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ else:
+ # Weight-only quantization: dequantize kernel and use float
+ # einsum. This is a workaround for PyTorch's einsum which
+ # doesn't support mixed-precision inputs (float input,
+ # int8 kernel).
+ if backend.backend() == "torch":
+ kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+ float_kernel = ops.divide(
+ ops.cast(kernel, dtype=self.compute_dtype),
+ kernel_scale,
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ else:
+ x = ops.einsum(self.equation, inputs, kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, kernel_scale)
  return x, grad_fn

  x = einsum_with_inputs_gradient(
@@ -799,17 +857,38 @@ class EinsumDense(Layer):
  return (inputs_grad, None, None)

  # Quantize inputs per `self.inputs_quantizer`.
- inputs_q, inputs_scale = self.inputs_quantizer(inputs)
-
- # Compute einsum on quantized inputs and unpacked int4 kernel.
- x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-
- # Align `inputs_scale` axes with the output for correct broadcasting
- inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
-
- # De-scale outputs.
- x = ops.cast(x, self.compute_dtype)
- x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ if self.inputs_quantizer:
+ inputs_q, inputs_scale = self.inputs_quantizer(
+ inputs, axis=self.quantization_axis
+ )
+ # Align `inputs_scale` axes with the output
+ # for correct broadcasting
+ inputs_scale = self._adjust_scale_for_quant(
+ inputs_scale, "input"
+ )
+ x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+ else:
+ # Weight-only quantization: dequantize kernel and use float
+ # einsum. This is a workaround for PyTorch's einsum which
+ # doesn't support mixed-precision inputs (float input,
+ # int4 kernel).
+ if backend.backend() == "torch":
+ # Align `kernel_scale` to the same layout as
+ # `unpacked_kernel`.
+ kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+ float_kernel = ops.divide(
+ ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+ kernel_scale,
+ )
+ x = ops.einsum(self.equation, inputs, float_kernel)
+ else:
+ x = ops.einsum(self.equation, inputs, unpacked_kernel)
+ # De-scale outputs
+ x = ops.cast(x, self.compute_dtype)
+ x = ops.divide(x, kernel_scale)
  return x, grad_fn

  x = einsum_with_inputs_gradient(
@@ -923,30 +1002,40 @@ class EinsumDense(Layer):
  x = self.activation(x)
  return x

- def quantize(self, mode, type_check=True, config=None):
+ def quantize(self, mode=None, type_check=True, config=None):
  # Prevent quantization of the subclasses
  if type_check and (type(self) is not EinsumDense):
  raise self._not_implemented_error(self.quantize)

+ self.quantization_config = config
+
  kernel_shape = self._kernel.shape
  if mode in ("int8", "int4", "gptq"):
  self._set_quantization_info()

  if mode == "int8":
  # Quantize `self._kernel` to int8 and compute corresponding scale
- kernel_value, kernel_scale = quantizers.abs_max_quantize(
- self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
+ weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+ self.quantization_config,
+ quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+ )
+ kernel_value, kernel_scale = weight_quantizer(
+ self._kernel, to_numpy=True
  )
  kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
  del self._kernel
  elif mode == "int4":
  # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
- kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
- self._kernel,
- axis=self._kernel_reduced_axes,
- value_range=(-8, 7),
- dtype="int8",
- to_numpy=True,
+ weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+ self.quantization_config,
+ quantizers.AbsMaxQuantizer(
+ axis=self._kernel_reduced_axes,
+ value_range=(-8, 7),
+ output_dtype="int8",
+ ),
+ )
+ kernel_value_int4, kernel_scale = weight_quantizer(
+ self._kernel, to_numpy=True
  )
  kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

@@ -957,7 +1046,7 @@ class EinsumDense(Layer):
  )
  kernel_value = packed_kernel_value
  del self._kernel
- self.quantized_build(kernel_shape, mode, config)
+ self.quantized_build(kernel_shape, mode, self.quantization_config)

  # Assign values to the newly created variables.
  if mode in ("int8", "int4"):
@@ -968,7 +1057,7 @@ class EinsumDense(Layer):
  if self.dtype_policy.quantization_mode is None:
  policy_name = mode
  if mode == "gptq":
- policy_name = config.dtype_policy_string()
+ policy_name = self.quantization_config.dtype_policy_string()
  policy = dtype_policies.get(
  f"{policy_name}_from_{self.dtype_policy.name}"
  )
@@ -1165,46 +1254,6 @@ class EinsumDense(Layer):
  self._kernel_reverse_transpose_axes,
  ) = _analyze_quantization_info(self.equation, self.input_spec.ndim)

- def _get_gptq_group_size(self, config):
- """Determine the group size for GPTQ quantization.
-
- The group size can be specified either through the `config` argument
- or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
-
- The config argument is usually available when quantizing the layer
- via the `quantize` method. If the layer was deserialized from a
- saved model, the group size should be specified in the `dtype_policy`.
-
- Args:
- config: An optional configuration object that may contain the
- `group_size` attribute.
- Returns:
- int. The determined group size for GPTQ quantization.
- Raises:
- ValueError: If the group size is not specified in either the
- `config` or the `dtype_policy`.
- """
- if config and isinstance(config, GPTQConfig):
- return config.group_size
- elif isinstance(self.dtype_policy, GPTQDTypePolicy):
- return self.dtype_policy.group_size
- elif isinstance(self.dtype_policy, DTypePolicyMap):
- policy = self.dtype_policy[self.path]
- if not isinstance(policy, GPTQDTypePolicy):
- # This should never happen based on how we set the
- # quantization mode, but we check just in case.
- raise ValueError(
- "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
- f"Got: {type(policy)}"
- )
- return policy.group_size
- else:
- raise ValueError(
- "For GPTQ quantization, the group_size must be specified"
- "either through a `dtype_policy` of type "
- "`GPTQDTypePolicy` or the `config` argument."
- )
-

  def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
  """Parses an einsum string to determine the shapes of the weights.