keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/awq_config.py
@@ -0,0 +1,140 @@
+ from keras.src.api_export import keras_export
+ from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+ @keras_export("keras.quantizers.AWQConfig")
+ class AWQConfig(QuantizationConfig):
+     """Configuration class for AWQ (Activation-aware Weight Quantization).
+
+     AWQ is a post-training quantization method that identifies and protects
+     salient weights based on activation magnitudes. It applies per-channel
+     scaling before quantization to minimize accuracy loss.
+
+     Methodology:
+     1. Collects activation statistics from calibration data
+     2. Identifies salient weight channels based on activation magnitudes
+     3. Searches for optimal per-channel scaling factors via grid search
+     4. Applies scaling before quantization to protect important weights
+
+     References:
+     - Original AWQ paper: "AWQ: Activation-aware Weight Quantization for
+       LLM Compression and Acceleration" (https://arxiv.org/abs/2306.00978)
+     - Reference implementation: https://github.com/mit-han-lab/llm-awq
+
+     Args:
+         dataset: The calibration dataset. It can be an iterable that yields
+             strings or pre-tokenized numerical tensors (e.g., a list of
+             strings, a generator, or a NumPy array). This data is used to
+             analyze activation patterns.
+         tokenizer: A tokenizer instance (or a similar callable) that is used
+             to process the `dataset`.
+         weight_bits: The number of bits for weight quantization. AWQ presently
+             only supports 4-bit quantization. Defaults to 4.
+         num_samples: The number of calibration data samples to use from the
+             dataset. Defaults to 128.
+         sequence_length: The sequence length to use for each calibration
+             sample. Defaults to 512.
+         group_size: The size of weight groups to quantize together. A
+             `group_size` of -1 indicates per-channel quantization.
+             Defaults to 128.
+         num_grid_points: The number of grid search points for finding optimal
+             per-channel scales. Higher values may find better scales but
+             take longer. Defaults to 20.
+         quantization_layer_structure: A dictionary defining the model's
+             quantization structure. It should contain:
+             - "pre_block_layers": list of layers to run before the first
+               block (e.g., embedding layer).
+             - "sequential_blocks": list of transformer blocks to quantize
+               sequentially.
+             If not provided, the model must implement
+             `get_quantization_layer_structure`.
+
+     Example:
+     ```python
+     from keras.quantizers import AWQConfig
+
+     # Create configuration for 4-bit AWQ quantization
+     config = AWQConfig(
+         dataset=calibration_data,  # Your calibration dataset
+         tokenizer=your_tokenizer,  # Tokenizer for text data
+         num_samples=128,  # Number of calibration samples
+         sequence_length=512,  # Sequence length for each sample
+         group_size=128,  # Weight grouping for quantization
+         num_grid_points=20,  # Grid search points for scale search
+     )
+
+     # Apply quantization to your model
+     model.quantize("awq", config=config)
+     ```
+
+     """
+
+     def __init__(
+         self,
+         dataset,
+         tokenizer,
+         *,
+         weight_bits: int = 4,
+         num_samples: int = 128,
+         sequence_length: int = 512,
+         group_size: int = 128,
+         num_grid_points: int = 20,
+         quantization_layer_structure: dict = None,
+     ):
+         super().__init__()
+         # AWQ only supports 4-bit quantization
+         if weight_bits != 4:
+             raise ValueError(
+                 f"AWQ only supports 4-bit quantization. "
+                 f"Received weight_bits={weight_bits}."
+             )
+         if num_samples <= 0:
+             raise ValueError("num_samples must be a positive integer.")
+         if sequence_length <= 0:
+             raise ValueError("sequence_length must be a positive integer.")
+         if group_size < -1 or group_size == 0:
+             raise ValueError(
+                 "Invalid group_size. Supported values are -1 (per-channel) "
+                 f"or a positive integer, but got {group_size}."
+             )
+         if num_grid_points <= 0:
+             raise ValueError("num_grid_points must be a positive integer.")
+
+         self.dataset = dataset
+         self.tokenizer = tokenizer
+         self.weight_bits = weight_bits
+         self.num_samples = num_samples
+         self.sequence_length = sequence_length
+         self.group_size = group_size
+         self.num_grid_points = num_grid_points
+         self.quantization_layer_structure = quantization_layer_structure
+
+     @property
+     def mode(self):
+         return "awq"
+
+     def dtype_policy_string(self):
+         """Returns the dtype policy string for this configuration.
+
+         Returns:
+             A string representing the dtype policy, e.g. "awq/4/128".
+         """
+         return f"awq/{self.weight_bits}/{self.group_size}"
+
+     def get_config(self):
+         return {
+             # Dataset and Tokenizer are only required for one-time
+             # calibration and are not saved in the config.
+             "dataset": None,
+             "tokenizer": None,
+             "weight_bits": self.weight_bits,
+             "num_samples": self.num_samples,
+             "sequence_length": self.sequence_length,
+             "group_size": self.group_size,
+             "num_grid_points": self.num_grid_points,
+             "quantization_layer_structure": self.quantization_layer_structure,
+         }
+
+     @classmethod
+     def from_config(cls, config):
+         return cls(**config)
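
Note: the new `AWQConfig` is serializable via `get_config()`/`from_config()` (the calibration `dataset` and `tokenizer` are deliberately dropped), and `dtype_policy_string()` encodes the bit width and group size. A minimal sketch of that round trip, assuming a nightly build that includes this file; the dataset and tokenizer values are placeholders:

```python
from keras.quantizers import AWQConfig

# Placeholder calibration inputs; any iterable of strings and a callable
# tokenizer suffice just to construct the config object.
config = AWQConfig(dataset=["some calibration text"], tokenizer=str.split)

print(config.mode)                   # "awq"
print(config.dtype_policy_string())  # "awq/4/128"

# get_config() zeroes out dataset/tokenizer (they are only needed once, for
# calibration), so the restored config carries the hyperparameters only.
restored = AWQConfig.from_config(config.get_config())
print(restored.group_size)           # 128
```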
keras/src/quantizers/awq_core.py
@@ -0,0 +1,217 @@
+ """AWQ core functionality for layer-wise quantization.
+
+ This module provides the orchestration logic for applying AWQ quantization
+ to transformer models in a layer-by-layer fashion.
+ """
+
+ from contextlib import contextmanager
+
+ from absl import logging
+
+ from keras.src import ops
+ from keras.src import utils as keras_utils
+ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+ from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+ from keras.src.quantizers.awq import AWQ
+ from keras.src.quantizers.awq_config import AWQConfig
+ from keras.src.quantizers.gptq_core import find_layers_in_block
+ from keras.src.quantizers.gptq_core import get_dataloader
+ from keras.src.quantizers.utils import should_quantize_layer
+
+
+ @contextmanager
+ def stream_activations(layers_map, awq_objects):
+     """Context manager to capture activations for AWQ calibration.
+
+     Temporarily patches layer.call methods to capture activation statistics
+     for computing per-channel scaling factors.
+
+     Args:
+         layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+         awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+     Yields:
+         None: The patched state is active only within the `with` block.
+     """
+     original_calls = {}
+
+     def create_hook(name, original_call_func):
+         def hook(*args, **kwargs):
+             inp = args[0] if args else kwargs["inputs"]
+             num_features = awq_objects[name].rows
+             input_2d = ops.reshape(inp, (-1, num_features))
+             awq_objects[name].update_activation_magnitudes(input_2d)
+             return original_call_func(*args, **kwargs)
+
+         return hook
+
+     try:
+         for name, layer in layers_map.items():
+             original_calls[name] = layer.call
+             layer.call = create_hook(name, layer.call)
+         yield
+     finally:
+         for name, layer in layers_map.items():
+             layer.call = original_calls[name]
+
+
+ def apply_awq_layerwise(dataloader, config, structure, filters=None):
+     """Apply AWQ quantization layer-by-layer to a Keras model.
+
+     This function processes the model sequentially, one block at a time:
+     1. Captures activation statistics through calibration data forward pass
+     2. Uses activation magnitudes to determine weight saliency
+     3. Finds optimal per-channel scales via grid search
+     4. Quantizes weights with AWQ scaling
+
+     Args:
+         dataloader: Calibration data as numpy array.
+         config: AWQConfig instance.
+         structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+         filters: Optional layer filters.
+     """
+     num_samples = config.num_samples
+     logging.info("Starting AWQ quantization...")
+
+     pre_layers = structure.get("pre_block_layers", [])
+     transformer_blocks = structure.get("sequential_blocks", [])
+
+     if not transformer_blocks:
+         raise ValueError(
+             "No sequential blocks found in the structure to quantize."
+         )
+
+     # Process inputs through pre-block layers (e.g., embedding)
+     inputs = []
+     for batch in dataloader:
+         batch = ops.convert_to_tensor(batch, dtype="int32")
+         for layer in pre_layers:
+             batch = layer(batch)
+         inputs.append(batch)
+
+     num_samples = min(num_samples, len(inputs))
+     progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+     for block_idx, block in enumerate(transformer_blocks):
+         logging.info(f"Quantizing Block {block_idx}")
+         sub_layers_map = find_layers_in_block(block)
+
+         # Apply filters
+         final_sub_layers_map = {}
+         for name, layer in sub_layers_map.items():
+             if not should_quantize_layer(layer, filters):
+                 continue
+             final_sub_layers_map[name] = layer
+
+         sub_layers_map = final_sub_layers_map
+
+         if not sub_layers_map:
+             logging.info(
+                 f" No quantizable layers found in block {block_idx}. Skipping."
+             )
+         else:
+             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+             # Create AWQ objects for each layer
+             awq_objects = {
+                 name: AWQ(layer, config)
+                 for name, layer in sub_layers_map.items()
+             }
+
+             # Capture activation statistics
+             with stream_activations(sub_layers_map, awq_objects):
+                 for sample_idx in range(num_samples):
+                     current_input = inputs[sample_idx]
+                     if len(current_input.shape) == 2:
+                         current_input = ops.expand_dims(current_input, axis=0)
+                     _ = block(current_input)
+
+             # Quantize each layer
+             for name, awq_object in awq_objects.items():
+                 logging.info(f"Quantizing {name}...")
+                 awq_object.quantize_layer()
+                 awq_object.free()
+
+             del awq_objects
+
+         # Generate inputs for next block
+         if block_idx < len(transformer_blocks) - 1:
+             logging.info(f"Generating inputs for block {block_idx + 1}...")
+             next_block_inputs = []
+             for sample_idx in range(num_samples):
+                 current_input = inputs[sample_idx]
+                 if len(current_input.shape) == 2:
+                     current_input = ops.expand_dims(current_input, axis=0)
+                 output = block(current_input)[0]
+                 next_block_inputs.append(output)
+             inputs = next_block_inputs
+
+         progbar.update(current=block_idx + 1)
+
+     logging.info("AWQ quantization complete.")
+
+
+ def awq_quantize(config, quantization_layer_structure, filters=None):
+     """Main entry point for AWQ quantization.
+
+     Args:
+         config: AWQConfig instance.
+         quantization_layer_structure: Model structure dictionary.
+         filters: Optional layer filters.
+     """
+     if config.dataset is None or config.tokenizer is None:
+         raise ValueError(
+             "AWQ quantization requires a dataset and tokenizer. "
+             "Please provide them in the AWQConfig."
+         )
+
+     if quantization_layer_structure is None:
+         raise ValueError(
+             "For 'awq' mode, a valid quantization structure must be provided "
+             "either via `config.quantization_layer_structure` or by overriding "
+             "`model.get_quantization_layer_structure(mode)`. The structure "
+             "should be a dictionary with keys 'pre_block_layers' and "
+             "'sequential_blocks'."
+         )
+
+     # Load calibration data
+     dataloader = get_dataloader(
+         config.tokenizer,
+         config.sequence_length,
+         config.dataset,
+         num_samples=config.num_samples,
+     )
+
+     apply_awq_layerwise(
+         dataloader[: config.num_samples],
+         config,
+         quantization_layer_structure,
+         filters=filters,
+     )
+
+
+ def get_group_size_for_layer(layer, config):
+     """Get group size from config or dtype policy.
+
+     Args:
+         layer: The layer to get group size for.
+         config: Optional AWQConfig instance.
+
+     Returns:
+         int: The group size for quantization.
+
+     Raises:
+         ValueError: If group size cannot be determined.
+     """
+     if config and isinstance(config, AWQConfig):
+         return config.group_size
+     elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+         return layer.dtype_policy.group_size
+     elif isinstance(layer.dtype_policy, DTypePolicyMap):
+         policy = layer.dtype_policy[layer.path]
+         if isinstance(policy, AWQDTypePolicy):
+             return policy.group_size
+     raise ValueError(
+         "For AWQ quantization, group_size must be specified "
+         "through AWQConfig or AWQDTypePolicy."
+     )
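
Note: `awq_quantize` (and the reworked `gptq_quantize` further below) now consumes an explicit structure dictionary instead of auto-discovering the model layout. A sketch of the expected shape of that dictionary, assuming a KerasHub-style model; `backbone`, `token_embedding`, `transformer_layers`, `calibration_texts`, `tokenizer`, and `model` are illustrative placeholders, while the dictionary keys and `model.quantize("awq", config=...)` come from the code above:

```python
from keras.quantizers import AWQConfig

# Placeholder attributes; substitute your own model's embedding and blocks.
structure = {
    # Layers run once on the raw token IDs before the first block
    # (typically the token embedding).
    "pre_block_layers": [backbone.token_embedding],
    # Blocks quantized sequentially; each block's outputs feed the next.
    "sequential_blocks": list(backbone.transformer_layers),
}

config = AWQConfig(
    dataset=calibration_texts,
    tokenizer=tokenizer,
    quantization_layer_structure=structure,
)
model.quantize("awq", config=config)
```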
keras/src/quantizers/gptq.py
@@ -1,5 +1,4 @@
  import types
- from functools import partial

  from keras.src import ops
  from keras.src import quantizers
@@ -296,12 +295,7 @@ class GPTQ:
              # For EinsumDense, we determine the effective 2D dimensions.
              self.kernel_shape = layer.kernel.shape
              shape = list(self.kernel_shape)
-             try:
-                 d_model_dim_index = shape.index(max(shape))
-             except ValueError:
-                 raise TypeError(
-                     f"Could not determine hidden dimension from shape {shape}"
-                 )
+             d_model_dim_index = shape.index(max(shape))

              if d_model_dim_index == 0:  # QKV projection case
                  in_features, heads, head_dim = shape
@@ -471,7 +465,7 @@ class GPTQ:
              group_size=self.config.group_size,
              activation_order=self.config.activation_order,
              order_metric=ops.diagonal(hessian_matrix),
-             compute_scale_zero=partial(self.quantizer.find_params, weight=True),
+             compute_scale_zero=self.quantizer.find_params,
          )
          quantized = ops.cast(
              quantized, self.original_layer.quantized_kernel.dtype
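
Note: the removed `try/except` was effectively unreachable, since for any non-empty shape list `max(shape)` is itself an element of the list and `list.index` cannot raise `ValueError`; with `compute_scale_zero` now passing `self.quantizer.find_params` directly, the `functools.partial` import also becomes unused. A tiny illustration, with the shape value being a made-up example:

```python
# max(shape) is always a member of a non-empty shape list, so .index()
# succeeds; the except ValueError branch could never fire here.
shape = [512, 8, 64]  # e.g. an EinsumDense QKV-projection kernel shape
d_model_dim_index = shape.index(max(shape))
print(d_model_dim_index)  # 0
```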
keras/src/quantizers/gptq_config.py
@@ -1,8 +1,9 @@
  from keras.src.api_export import keras_export
+ from keras.src.quantizers.quantization_config import QuantizationConfig


  @keras_export("keras.quantizers.GPTQConfig")
- class GPTQConfig:
+ class GPTQConfig(QuantizationConfig):
      """Configuration class for the GPTQ (Gradient-based Post-Training
      Quantization) algorithm.

@@ -131,6 +132,12 @@ class GPTQConfig:
          activation_order: (bool, optional) If `True`, reorders weight columns
              based on activation magnitude, which can improve quantization
              accuracy. Defaults to `False`.
+         quantization_layer_structure: (dict, optional) A dictionary defining the
+             model's quantization structure. It should contain:
+             - "pre_block_layers": list of layers to run before the first block.
+             - "sequential_blocks": list of blocks to be quantized sequentially.
+             If not provided, the model must implement
+             `get_quantization_layer_structure`.
      """

      def __init__(
@@ -146,7 +153,9 @@ class GPTQConfig:
          group_size: int = 128,
          symmetric: bool = False,
          activation_order: bool = False,
+         quantization_layer_structure: dict = None,
      ):
+         super().__init__()
          if weight_bits not in [2, 3, 4, 8]:
              raise ValueError(
                  f"Unsupported weight_bits {weight_bits}. "
@@ -174,6 +183,32 @@ class GPTQConfig:
          self.group_size = group_size
          self.symmetric = symmetric
          self.activation_order = activation_order
+         self.quantization_layer_structure = quantization_layer_structure
+
+     def get_config(self):
+         return {
+             # Dataset and Tokenizer are only required for a one-time
+             # calibration and are not saved in the config.
+             "dataset": None,
+             "tokenizer": None,
+             "weight_bits": self.weight_bits,
+             "num_samples": self.num_samples,
+             "per_channel": self.per_channel,
+             "sequence_length": self.sequence_length,
+             "hessian_damping": self.hessian_damping,
+             "group_size": self.group_size,
+             "symmetric": self.symmetric,
+             "activation_order": self.activation_order,
+             "quantization_layer_structure": self.quantization_layer_structure,
+         }
+
+     @classmethod
+     def from_config(cls, config):
+         return cls(**config)
+
+     @property
+     def mode(self):
+         return "gptq"

      def dtype_policy_string(self):
          """Returns the dtype policy string for this configuration.
keras/src/quantizers/gptq_core.py
@@ -10,9 +10,9 @@ from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
  from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
  from keras.src.layers import Dense
  from keras.src.layers import EinsumDense
- from keras.src.layers import Embedding
  from keras.src.quantizers.gptq import GPTQ
  from keras.src.quantizers.gptq_config import GPTQConfig
+ from keras.src.quantizers.utils import should_quantize_layer


  @contextmanager
@@ -131,7 +131,7 @@ def get_dataloader(
      pieces = []
      if isinstance(dataset_list[0], str):
          for i, s in enumerate(dataset_list):
-             toks = np.asarray(tokenizer.tokenize(s)).reshape(-1)
+             toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
              pieces.append(toks)
              # avoid windows that span document boundaries
              if eos_id is not None and i < len(dataset_list) - 1:
@@ -193,38 +193,6 @@ def get_dataloader(
      return samples.astype(np.int32)[:, None, :]


- def _get_backbone_layers(model):
-     """Extract embedding and transformer layers from a KerasHub model."""
-     backbone = model.backbone
-     if not hasattr(backbone, "transformer_layers"):
-         raise ValueError(
-             "The model's backbone does not have a 'transformer_layers' "
-             "attribute. Please ensure you are using a standard KerasHub "
-             "transformer model."
-         )
-     transformer_blocks = backbone.transformer_layers
-
-     embedding_layer = None
-     if hasattr(backbone, "token_embedding"):
-         embedding_layer = backbone.token_embedding
-     elif hasattr(backbone, "embedding"):
-         embedding_layer = backbone.embedding
-     return embedding_layer, transformer_blocks
-
-
- def _get_custom_layers(model):
-     """Heuristic for extracting embedding + transformer blocks from a custom
-     model."""
-     embedding_layer = None
-     transformer_blocks = []
-     for layer in model.layers:
-         if isinstance(layer, Embedding) and embedding_layer is None:
-             embedding_layer = layer
-         elif getattr(layer, "_layers", None):  # container-like block
-             transformer_blocks.append(layer)
-     return embedding_layer, transformer_blocks
-
-
  def find_layers_in_block(block):
      """
      Finds all Dense and EinsumDense layers in a transformer block.
@@ -242,39 +210,31 @@ def find_layers_in_block(block):
      return found_layers


- def apply_gptq_layerwise(model, dataloader, config):
+ def apply_gptq_layerwise(dataloader, config, structure, filters=None):
      """Applies GPTQ quantization layer-by-layer to a Keras model.

-     This function is designed to work with common transformer architectures,
-     like those provided by KerasHub. It automatically discovers the model's
-     structure by first looking for the standard format: a `model.backbone`
-     attribute that contains a `transformer_layers` list.
-
-     If a standard backbone is not found, it falls back to a heuristic for
-     custom models, where it assumes the first `keras.layers.Embedding` layer
-     is the input embedding and any subsequent container layers are the
-     transformer blocks to be quantized.
+     This function uses the provided `structure` to identify pre-quantization
+     layers and sequential blocks.

      The core logic operates as follows:
-     1. It automatically detects the model's structure, identifying the main
-        embedding layer and a sequence of transformer blocks.
-     2. It processes the model sequentially, one block at a time. For each
+
+     1. It processes the model sequentially, one block at a time. For each
         block, it uses temporary hooks to capture the input activations of
         each target layer during a forward pass with the calibration data.
-     3. These captured activations are used to compute the Hessian matrix for
+     2. These captured activations are used to compute the Hessian matrix for
         each layer's weights.
-     4. The GPTQ algorithm is then applied to each layer to find the optimal
+     3. The GPTQ algorithm is then applied to each layer to find the optimal
         quantized weights that minimize the error introduced.
-     5. The output activations from the current block are then used as the
+     4. The output activations from the current block are then used as the
         input for the next block, ensuring that quantization errors are
         accounted for throughout the model.

      Args:
-         model: The Keras model instance to be quantized. The function will
-             attempt to automatically discover its structure.
-         dataloader: An iterable providing calibration data. Each item should
-             be a batch of token IDs suitable for the model's embedding layer.
+         dataloader: An iterable providing calibration data.
          config: A GPTQConfiguration object.
+         structure: A dictionary with keys "pre_block_layers" and
+             "sequential_blocks".
+         filters: Optional filters to exclude layers from quantization.

      Raises:
          ValueError: If the function cannot automatically find an embedding
@@ -284,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
      num_samples = config.num_samples

      logging.info("Starting model quantization...")
-     embedding_layer = None
-     transformer_blocks = []
-     if hasattr(model, "backbone"):
-         logging.info("Detected KerasHub model structure.")
-         embedding_layer, transformer_blocks = _get_backbone_layers(model)
-     else:
-         logging.info("Detected custom model structure.")
-         embedding_layer, transformer_blocks = _get_custom_layers(model)

-     if embedding_layer is None:
-         raise ValueError(
-             "Could not automatically find an embedding layer in the model."
-         )
+     pre_layers = structure.get("pre_block_layers", [])
+     transformer_blocks = structure.get("sequential_blocks", [])
+
      if not transformer_blocks:
          raise ValueError(
-             "Could not automatically find any transformer-like blocks to "
-             "quantize."
+             "No sequential blocks found in the provided structure to quantize."
          )

-     # Initial inputs are the outputs of the token embedding layer
-     inputs = [
-         embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-         for batch in dataloader
-     ]
+     # Initial inputs are the outputs of the pre-block layers
+     inputs = []
+     for batch in dataloader:
+         batch = ops.convert_to_tensor(batch, dtype="int32")
+         for layer in pre_layers:
+             batch = layer(batch)
+         inputs.append(batch)
+
      num_samples = min(num_samples, len(inputs))

      progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +269,19 @@ def apply_gptq_layerwise(model, dataloader, config):
          logging.info(f"Quantizing Block {block_idx}")
          sub_layers_map = find_layers_in_block(block)

+         # Filter out layers that are not quantized with GPTQ
+         final_sub_layers_map = {}
+         for name, layer in sub_layers_map.items():
+             if not should_quantize_layer(layer, filters):
+                 continue
+
+             final_sub_layers_map[name] = layer
+
+         sub_layers_map = final_sub_layers_map
+
          if not sub_layers_map:
              logging.info(
-                 f" No Dense or EinsumDense layers found in block {block_idx}. "
-                 "Skipping."
+                 f" No quantizable layers found in block {block_idx}. Skipping."
              )
          else:
              logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -357,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
      logging.info("Quantization process complete.")


- def gptq_quantize(model, config):
+ def gptq_quantize(config, quantization_layer_structure, filters=None):
      """
-     Top-level function to quantize a Keras model using GPTQ.
+     Quantizes the model using GPTQ.
+
+     Args:
+         config: The GPTQ configuration.
+         quantization_layer_structure: A dictionary describing the model's layer
+             structure for quantization.
+         filters: Optional filters to exclude layers from quantization.
      """
-     logging.info("Starting GPTQ quantization process...")
+     if config.dataset is None or config.tokenizer is None:
+         raise ValueError(
+             "GPTQ quantization requires a dataset and a tokenizer. "
+             "Please provide them in the `GPTQConfig`."
+         )
+
+     if quantization_layer_structure is None:
+         raise ValueError(
+             "For 'gptq' mode, a valid quantization structure must be provided "
+             "either via `config.quantization_layer_structure` or by overriding "
+             "`model.get_quantization_layer_structure(mode)`. The structure "
+             "should be a dictionary with keys 'pre_block_layers' and "
+             "'sequential_blocks'."
+         )

      # Load all data needed from the generator/source in a single call.
      total_samples_to_request = config.num_samples
@@ -376,7 +357,12 @@ def gptq_quantize(model, config):
      # is now a NumPy array, which can be sliced and reused.
      calibration_dataloader = dataloader[: config.num_samples]

-     apply_gptq_layerwise(model, calibration_dataloader, config)
+     apply_gptq_layerwise(
+         calibration_dataloader,
+         config,
+         quantization_layer_structure,
+         filters=filters,
+     )


  def get_group_size_for_layer(layer, config):
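
Note: with the auto-discovery helpers (`_get_backbone_layers`, `_get_custom_layers`) removed, a model that is not given an explicit `quantization_layer_structure` in the config is expected to describe its own layout. A sketch of the hook named in the error messages above; the exact call site is not shown in this diff, and the `token_embedding`/`transformer_blocks` attributes are illustrative placeholders:

```python
import keras


class MyTransformer(keras.Model):
    ...  # layers built elsewhere; attribute names below are placeholders

    def get_quantization_layer_structure(self, mode):
        # Expected to return the same dictionary format used by
        # apply_gptq_layerwise / apply_awq_layerwise when no structure is
        # supplied via the config.
        return {
            "pre_block_layers": [self.token_embedding],
            "sequential_blocks": list(self.transformer_blocks),
        }
```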