keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/layers/core/dense.py
@@ -9,12 +9,11 @@ from keras.src import ops
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
-from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
-from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
-from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.Dense")
@@ -26,7 +25,9 @@ class Dense(Layer):
     where `activation` is the element-wise activation function
     passed as the `activation` argument, `kernel` is a weights matrix
     created by the layer, and `bias` is a bias vector created by the layer
-    (only applicable if `use_bias` is `True`).
+    (only applicable if `use_bias` is `True`). When this layer is
+    followed by a `BatchNormalization` layer, it is recommended to set
+    `use_bias=False` as `BatchNormalization` has its own bias term.

     Note: If the input to the layer has a rank greater than 2, `Dense`
     computes the dot product between the `inputs` and the `kernel` along the
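
Note on the docstring change above: the `use_bias=False` recommendation is easy to apply in practice. A minimal sketch using the public Keras API (layer sizes are arbitrary):

```python
import keras
from keras import layers

# BatchNormalization learns its own offset (beta), so a bias on the
# preceding Dense layer would be redundant.
model = keras.Sequential(
    [
        layers.Input(shape=(16,)),
        layers.Dense(32, use_bias=False),
        layers.BatchNormalization(),
        layers.Activation("relu"),
    ]
)
model.summary()
```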
@@ -93,8 +94,15 @@ class Dense(Layer):
         bias_constraint=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
+        if not isinstance(units, int) or units <= 0:
+            raise ValueError(
+                "Received an invalid value for `units`, expected a positive "
+                f"integer. Received: units={units}"
+            )
+
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
         self.units = units
         self.activation = activations.get(activation)
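
The constructor now validates `units` eagerly instead of failing later in `build()`. A quick sketch of the guarded failure mode:

```python
import keras

try:
    keras.layers.Dense(units=0)
except ValueError as e:
    print(e)  # "Received an invalid value for `units`, ..."
```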
@@ -108,13 +116,18 @@ class Dense(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True

     def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
-            self.quantized_build(kernel_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                kernel_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         if self.quantization_mode not in ("int8", "int4", "gptq"):
             # If the layer is quantized to int8 or int4, `self._kernel` will be
             # added in `self._int8_build` or `_int4_build`. Therefore, we skip
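
`build()` now forwards the stored `quantization_config` into `quantized_build()`. The usual way this path is exercised is the existing `quantize()` API; a minimal sketch with the default config (`config=None`):

```python
import keras

layer = keras.layers.Dense(8)
layer.build((None, 4))
layer.quantize("int8")  # rebuilds variables via quantized_build(...)
print(layer.quantization_mode)  # "int8"
```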
@@ -143,22 +156,47 @@ class Dense(Layer):

     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
-        if (
-            getattr(self, "is_gptq_calibrated", False)
-            and self.quantization_mode == "gptq"
-        ):
-            return self.quantized_kernel
-        kernel = self._kernel
-        if self.quantization_mode == "int4":
-            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_int4 = mode == "int4"
+        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel)
+        if is_gptq and calibrated and gptq_bits != 4:
+            # calibrated GPTQ, not 4-bit, no unpacking needed
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel
+            kernel = getattr(self, "_kernel", None)
+
+        # Handle int4 unpacking cases in one place
+        if is_int4:
+            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+        elif is_gptq and calibrated and gptq_bits == 4:
+            kernel = quantizers.unpack_int4(
+                self.quantized_kernel,
+                orig_len=self.units,
+                axis=0,
+                dtype="uint8",
+            )
+
+        # Apply LoRA once at the end.
         if self.lora_enabled:
-            return kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
                 self.lora_kernel_a, self.lora_kernel_b
            )
+
         return kernel

     def call(self, inputs, training=None):
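
For context on the unpacking branches above: int4 modes store two 4-bit values per byte, so the packed kernel must be expanded before use. A self-contained NumPy sketch of the idea (illustrative only; the real bit layout lives in `quantizers.pack_int4`/`unpack_int4`):

```python
import numpy as np

def pack_int4(v):
    # v: int8 values in [-8, 7], even length; two values per byte
    u = (v & 0x0F).astype(np.uint8)  # two's-complement nibbles
    return (u[0::2] | (u[1::2] << 4)).astype(np.int8)

def unpack_int4(p):
    u = p.view(np.uint8)
    out = np.empty(2 * len(p), dtype=np.int8)
    out[0::2], out[1::2] = u & 0x0F, u >> 4
    return np.where(out > 7, out - 16, out).astype(np.int8)  # sign-extend

v = np.array([-8, 7, 3, -1], dtype=np.int8)
assert (unpack_int4(pack_int4(v)) == v).all()
```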
@@ -236,25 +274,25 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-            if self.bias is not None:
-                store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -263,39 +301,21 @@ class Dense(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
+        # A saved GPTQ quantized model will always be calibrated.
+        self.is_gptq_calibrated = mode == "gptq"

-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-            if self.bias is not None:
-                self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-            if self.bias is not None:
-                targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
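
Both loops above walk `variable_serialization_spec` (defined later in this diff) and key the store positionally, skipping `bias` when it does not exist. A small sketch of the resulting key layout for int8, with the spec order taken from the property below:

```python
spec = {"int8": ["kernel", "bias", "kernel_scale"]}

def key_layout(mode, has_bias=True):
    idx, layout = 0, {}
    for name in spec[mode]:
        if name == "bias" and not has_bias:
            continue  # skipped entirely; later keys shift down
        layout[str(idx)] = name
        idx += 1
    return layout

print(key_layout("int8"))                  # {'0': 'kernel', '1': 'bias', '2': 'kernel_scale'}
print(key_layout("int8", has_bias=False))  # {'0': 'kernel', '1': 'kernel_scale'}
```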
@@ -316,59 +336,51 @@ class Dense(Layer):
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
             "bias_constraint": constraints.serialize(self.bias_constraint),
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)

     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -377,6 +389,7 @@ class Dense(Layer):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
@@ -386,9 +399,9 @@ class Dense(Layer):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -397,8 +410,13 @@
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
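
`QuantizationConfig.activation_quantizer_or_default` is introduced elsewhere in this release (see `keras/src/quantizers/quantization_config.py` in the file list). Assuming the usual "config wins, otherwise default" pattern, the semantics are roughly as follows (hypothetical simplification, not the real implementation):

```python
class QuantizationConfig:
    def __init__(self, weight_quantizer=None, activation_quantizer=None):
        self.weight_quantizer = weight_quantizer
        self.activation_quantizer = activation_quantizer

    @staticmethod
    def activation_quantizer_or_default(config, default):
        # Fall back to `default` when no config is given or the config
        # leaves the activation quantizer unset.
        if config is None or config.activation_quantizer is None:
            return default
        return config.activation_quantizer
```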
@@ -414,23 +432,33 @@
         )

     def _gptq_build(self, kernel_shape, config):
+        from keras.src.quantizers import gptq_core
+
         # Ensures the forward pass uses the original high-precision kernel
         # until calibration has been performed.
         self.is_gptq_calibrated = False
         self.kernel_shape = kernel_shape
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        units = (
+            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]
+        )
+
         self.quantized_kernel = self.add_weight(
             name="kernel",
-            shape=(kernel_shape[1], kernel_shape[0]),
+            shape=(units, kernel_shape[0]),
             initializer="zeros",
             dtype="uint8",
             trainable=False,
         )

-        group_size = self._get_gptq_group_size(config)
-        if group_size == -1:
-            n_groups = 1
-        else:
-            n_groups = math.ceil(self.kernel_shape[0] / group_size)
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = (
+            1
+            if group_size == -1
+            else math.ceil(self.kernel_shape[0] / group_size)
+        )
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
             shape=(self.units, n_groups),
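
Two small shape rules appear here: 4-bit kernels pack two values per uint8 byte, and `group_size == -1` collapses all rows into a single scale group. Worked through in plain Python:

```python
import math

def packed_units(units, weight_bits):
    return (units + 1) // 2 if weight_bits == 4 else units

def n_groups(rows, group_size):
    return 1 if group_size == -1 else math.ceil(rows / group_size)

print(packed_units(7, 4), packed_units(8, 4))  # 4 4
print(n_groups(64, 16), n_groups(64, -1))      # 4 1
```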
@@ -453,18 +481,31 @@
         )

     def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
         if not self.is_gptq_calibrated:
             W = self._kernel
         else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
             W = (
-                ops.transpose(
-                    dequantize_with_sz_map(
-                        self.quantized_kernel,
-                        self.kernel_scale,
-                        self.kernel_zero,
-                        self.g_idx,
-                    )
-                ),
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
             )

         y = ops.matmul(inputs, W)
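
`dequantize_with_sz_map` reconstructs the float kernel from per-group scales and zero points routed through `g_idx`. A NumPy sketch of the assumed semantics (shapes follow `_gptq_build` above; this is an illustration of the idea, not the library routine):

```python
import numpy as np

def dequantize_with_sz_map(q, scale, zero, g_idx):
    # q: (units, in_features) uint8 codes
    # scale, zero: (units, n_groups); g_idx: (in_features,) group per feature
    s = scale[:, g_idx]  # broadcast group params to every input feature
    z = zero[:, g_idx]
    return (q.astype(np.float32) - z) * s

q = np.random.randint(0, 16, size=(3, 8), dtype=np.uint8)
scale = np.random.rand(3, 2).astype(np.float32)
zero = np.full((3, 2), 8.0, dtype=np.float32)
g_idx = np.array([0, 0, 0, 0, 1, 1, 1, 1])
W = dequantize_with_sz_map(q, scale, zero, g_idx).T  # transposed, as in the call
print(W.shape)  # (8, 3)
```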
@@ -474,7 +515,7 @@
             y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -483,8 +524,10 @@
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer()
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
@@ -573,11 +616,15 @@
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
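
The refactor makes de-scaling conditional: with no input quantizer configured, outputs are divided by the kernel scale alone; with one, by the product of both scales. The identity being used, in a NumPy sketch with symmetric per-tensor scales:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype(np.float32)
w = rng.standard_normal((4, 3)).astype(np.float32)

s_x = 127.0 / np.abs(x).max()
s_w = 127.0 / np.abs(w).max()
xq, wq = np.round(x * s_x), np.round(w * s_w)  # int8-range grids

y = (xq @ wq) / (s_x * s_w)  # de-scale by the product of scales
print(np.max(np.abs(y - x @ w)))  # small quantization error
```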
@@ -624,10 +671,15 @@
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -739,30 +791,37 @@ class Dense(Layer):
             x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
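
For callers that pass no config, behavior is unchanged: the default `AbsMaxQuantizer` settings above match the previous hard-coded `abs_max_quantize` calls. A minimal end-to-end sketch via the public API (backend-agnostic in principle):

```python
import numpy as np
import keras

layer = keras.layers.Dense(16)
layer.build((None, 8))
x = np.ones((1, 8), dtype="float32")
y_ref = keras.ops.convert_to_numpy(layer(x))

layer.quantize("int4")  # packs the kernel and adds kernel_scale
y_q = keras.ops.convert_to_numpy(layer(x))
print(np.max(np.abs(y_ref - y_q)))  # small quantization error
```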
@@ -770,12 +829,12 @@ class Dense(Layer):
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "gptq":
-            self.quantized_build(kernel_shape, mode, config)
+            self.quantized_build(kernel_shape, mode, self.quantization_config)
         elif mode == "float8":
             self.quantized_build(kernel_shape, mode)
         else:
@@ -787,7 +846,7 @@

         policy_name = mode
         if mode == "gptq":
-            policy_name = config.dtype_policy_string()
+            policy_name = self.quantization_config.dtype_policy_string()
         policy = dtype_policies.get(
             f"{policy_name}_from_{self.dtype_policy.name}"
         )
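
The resulting policy names follow the `<mode>_from_<previous>` convention, so quantizing a default float32 layer yields a policy like `int8_from_float32`:

```python
import keras

layer = keras.layers.Dense(4)
layer.build((None, 2))
layer.quantize("int8")
print(layer.dtype_policy.name)  # e.g. "int8_from_float32"
```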
@@ -875,43 +934,3 @@ class Dense(Layer):
         else:
             kernel_value = requantized_kernel
         return kernel_value, kernel_scale
-
-    def _get_gptq_group_size(self, config):
-        """Determine the group size for GPTQ quantization.
-
-        The group size can be specified either through the `config` argument
-        or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
-
-        The config argument is usually available when quantizing the layer
-        via the `quantize` method. If the layer was deserialized from a
-        saved model, the group size should be specified in the `dtype_policy`.
-
-        Args:
-            config: An optional configuration object that may contain the
-                `group_size` attribute.
-        Returns:
-            int. The determined group size for GPTQ quantization.
-        Raises:
-            ValueError: If the group size is not specified in either the
-                `config` or the `dtype_policy`.
-        """
-        if config and isinstance(config, GPTQConfig):
-            return config.group_size
-        elif isinstance(self.dtype_policy, GPTQDTypePolicy):
-            return self.dtype_policy.group_size
-        elif isinstance(self.dtype_policy, DTypePolicyMap):
-            policy = self.dtype_policy[self.path]
-            if not isinstance(policy, GPTQDTypePolicy):
-                # This should never happen based on how we set the
-                # quantization mode, but we check just in case.
-                raise ValueError(
-                    "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
-                    f"Got: {type(policy)}"
-                )
-            return policy.group_size
-        else:
-            raise ValueError(
-                "For GPTQ quantization, the group_size must be specified"
-                "either through a `dtype_policy` of type "
-                "`GPTQDTypePolicy` or the `config` argument."
-            )