keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,121 @@
-import random
+import math
+from contextlib import contextmanager

 import numpy as np
 from absl import logging

 from keras.src import ops
 from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
-from keras.src.quantizers.gptq_quant import GPTQQuantization
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer


-def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
+@contextmanager
+def stream_hessians(layers_map, gptq_objects):
     """
-    Prepares and chunks the calibration dataloader, repeating short datasets.
+    Temporarily monkey-patch each target layer's `call` method so
+    that input activations are streamed into the GPTQ instance
+    running Hessian estimate at capture time.
+
+    On `__enter__`: For every (name, layer) in `layers_map`, replaces
+    `layer.call` with a wrapper that:
+        1) extracts the layer input from `*args`/`**kwargs`,
+        2) reshapes it to 2D `[-1, rows]` where
+           `rows = gptq_objects[name].rows`,
+        3) calls `gptq_objects[name].update_hessian_with_batch(x2d)`
+        4) delegates to the original `layer.call` and returns its
+           output.
+
+    On `__exit__`: All original `layer.call` methods are restored even if an
+    exception occurs.
+
+    * Space complexity: O(d**2) per layer (for the Hessian).
+    * No weights are modified; only GPTQ statistics are updated.
+
+    Args:
+        layers_map: Dict[str, Layer]. Mapping from logical layer names to
+            the Keras layers that should be patched during calibration. Keys must
+            match `gptq_objects`.
+        gptq_objects: Dict[str, GPTQ]. Mapping from names to GPTQ instances.
+
+    Yields:
+        None: The patched state is active only within the `with` block. After
+            exit, all layers are unpatched and safe to use normally.
+
+    Example:
+    ```python
+    >>> with stream_hessians(layers_map, gptq_objects):
+    ...     for sample in calibration_inputs:
+    ...         if len(sample.shape) == 2:
+    ...             sample = ops.expand_dims(sample, 0)
+    ...         _ = block(sample)  # hooks update Hessians on-the-fly
+    >>> # <- original layer.call methods restored here
+    ```
     """
-    all_tokens = []
+    original_calls = {}
+
+    def create_hook(name, original_call_func):
+        def hook(*args, **kwargs):
+            inp = args[0] if args else kwargs["inputs"]
+            # Explicitly reshape the input tensor to be 2D, with the
+            # second dimension matching the number of input features
+            # expected by the layer's kernel.
+            # This correctly handles inputs of any dimensionality
+            # (e.g., 3D or 4D).
+            num_features = gptq_objects[name].rows
+            input_2d = ops.reshape(inp, (-1, num_features))
+            gptq_objects[name].update_hessian_with_batch(input_2d)
+            return original_call_func(*args, **kwargs)
+
+        return hook
+
+    try:
+        for name, layer in layers_map.items():
+            original_calls[name] = layer.call
+            layer.call = create_hook(name, layer.call)
+        yield
+    finally:
+        for name, layer in layers_map.items():
+            layer.call = original_calls[name]
+
+
+def get_dataloader(
+    tokenizer,
+    sequence_length,
+    dataset,
+    num_samples=128,
+    *,
+    strategy="strided",
+    seed=42,
+    stride=None,
+    eos_id=None,
+):
+    """
+    Prepares and chunks the calibration dataloader, repeating short datasets.
+    All processing happens on the CPU.

+    Args:
+        tokenizer: The tokenizer to use for text splitting.
+        sequence_length: The length of each input sequence.
+        dataset: The dataset to sample from.
+        num_samples: The number of samples to generate.
+        strategy: The sampling strategy to use. Possible values are
+            1. "strided": Samples are taken at regular intervals.
+            2. "linspace": Samples are taken at evenly spaced intervals.
+            3. "random": Samples are taken at random positions.
+        seed: The random seed for reproducibility. Used only if
+            strategy="random"
+        stride: The stride length for "strided" sampling.
+        eos_id: The end-of-sequence token ID.
+
+    Returns:
+        np.ndarray of shape (num_samples, 1, sequence_length), dtype int32.
+    """
     if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
         raise TypeError(
             "The `dataset` argument must be an iterable (e.g., a list of "
@@ -27,267 +125,184 @@ def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
         )

     dataset_list = list(dataset)
-
     if not dataset_list:
         raise ValueError("Provided dataset is empty.")

+    pieces = []
     if isinstance(dataset_list[0], str):
-        logging.info("(Dataset contains strings, tokenizing now...)")
-        full_text = "\n\n".join(dataset_list)
-        all_tokens = tokenizer.tokenize(full_text)
+        for i, s in enumerate(dataset_list):
+            toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
+            pieces.append(toks)
+            # avoid windows that span document boundaries
+            if eos_id is not None and i < len(dataset_list) - 1:
+                pieces.append(np.array([eos_id], dtype=np.int32))
     else:
-        logging.info("(Dataset is pre-tokenized, concatenating...)")
-        all_tokens = np.concatenate(
-            [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
-        )
-
-    all_tokens = np.array(all_tokens, dtype=np.int32)
+        for s in dataset_list:
+            toks = ops.convert_to_numpy(s).reshape(-1)
+            pieces.append(toks.astype(np.int32, copy=False))
+
+    all_tokens = (
+        pieces[0].astype(np.int32, copy=False)
+        if len(pieces) == 1
+        else np.concatenate(pieces, axis=0).astype(np.int32, copy=False)
+    )

-    # Repeat data if it's too short
     required_tokens = num_samples * sequence_length
-    if len(all_tokens) < required_tokens:
-        logging.info(
-            f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
-            " Repeating data to generate {num_samples} samples."
-        )
-        repeats = -(-required_tokens // len(all_tokens))  # Ceiling division
+    if all_tokens.size < required_tokens:
+        repeats = math.ceil(required_tokens / max(1, all_tokens.size))
         all_tokens = np.tile(all_tokens, repeats)

-    # Chunk the token list into samples
-
-    calibration_samples = []
-    for _ in range(num_samples):
-        # Generate a random starting index
-        start_index = random.randint(0, len(all_tokens) - sequence_length - 1)
-        end_index = start_index + sequence_length
-        sample = all_tokens[start_index:end_index]
-        calibration_samples.append(np.reshape(sample, (1, sequence_length)))
-
-    final_array = np.stack(calibration_samples, axis=0)
-    return final_array
-
+    max_start = all_tokens.size - sequence_length
+    if max_start < 0:
+        raise ValueError(
+            f"Not enough tokens to form one sample of length {sequence_length} "
+            f"(have {all_tokens.size})."
+        )

-def _find_layers_recursive(layer, prefix, found_layers):
-    """
-    Recursively search for Dense and EinsumDense layers and record them.
-    """
-    for sub_layer in layer._layers:
-        # Construct a unique name for the layer based on its hierarchy
-        layer_name = f"{prefix}.{sub_layer.name}"
-        if isinstance(sub_layer, (Dense, EinsumDense)):
-            found_layers[layer_name] = sub_layer
+    # Choose deterministic, well-spread starts by default
+    if strategy == "random":
+        rng = np.random.default_rng(seed)
+        starts = rng.integers(
+            0, max_start + 1, size=num_samples, dtype=np.int64
+        )
+    elif strategy == "linspace":
+        # even coverage with no RNG
+        starts = np.linspace(0, max_start, num_samples, dtype=np.int64)
+    elif strategy == "strided":
+        # stride chosen to cover the space roughly uniformly
+        if stride is None:
+            stride = max(1, (max_start + 1) // num_samples)
+        # offset derived deterministically from seed
+        offset = (
+            (abs(hash(("gptq-calib", seed))) % (max_start + 1))
+            if max_start > 0
+            else 0
+        )
+        starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % (
+            max_start + 1
+        )
+    else:
+        raise ValueError(f"Unknown strategy: {strategy}")

-        # Recurse into nested layers that are not the target types
-        elif hasattr(sub_layer, "_layers") and sub_layer._layers:
-            _find_layers_recursive(sub_layer, layer_name, found_layers)
+    # Gather contiguous windows
+    # sliding_window_view avoids building a big index matrix
+    windows = np.lib.stride_tricks.sliding_window_view(
+        all_tokens, sequence_length
+    )
+    samples = windows[starts]  # (num_samples, sequence_length)
+    return samples.astype(np.int32)[:, None, :]


 def find_layers_in_block(block):
     """
-    A pluggable, generic function to find all Dense and EinsumDense layers
-    within any transformer block by using a recursive search.
+    Finds all Dense and EinsumDense layers in a transformer block.
+
+    Args:
+        block: A Keras layer representing a transformer block.
+    Returns:
+        A dict mapping layer paths to the corresponding Dense or EinsumDense
     """
     found_layers = {}
-    # Start the recursive search from the block itself
-    _find_layers_recursive(block, "block", found_layers)
+    for sub_layer in block._flatten_layers():
+        if len(list(sub_layer._flatten_layers())) == 1:
+            if isinstance(sub_layer, (Dense, EinsumDense)):
+                found_layers[sub_layer.path] = sub_layer
     return found_layers


-def apply_gptq_layerwise(
-    model,
-    dataloader,
-    num_samples,
-    hessian_damping,
-    group_size,
-    symmetric,
-    activation_order,
-    weight_bits,
-):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.

-    This function is designed to work with common transformer architectures,
-    like those provided by KerasHub. It automatically discovers the model's
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.

     The core logic operates as follows:
-    1. It automatically detects the model's structure, identifying the main
-       embedding layer and a sequence of transformer blocks.
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
       block, it uses temporary hooks to capture the input activations of
       each target layer during a forward pass with the calibration data.
-    3. These captured activations are used to compute the Hessian matrix for
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-    4. The GPTQ algorithm is then applied to each layer to find the optimal
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-    5. The output activations from the current block are then used as the
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.

     Args:
-        model: The Keras model instance to be quantized. The function will
-            attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
-        num_samples: (int) The number of samples from the dataloader to use for
-            calibration.
-        hessian_damping: (float) The percentage of dampening to add to the
-            Hessian diagonal for stabilization during inverse calculation.
-            A value of 0.01 is common.
-        group_size: (int) The size of the groups to use for quantization. A
-            value of 128 means that 128 weights will share the same scaling
-            factor. Use -1 for per-channel quantization.
-        symmetric: (bool) If True, symmetric quantization is used. Otherwise,
-            asymmetric quantization is used.
-        activation_order: (bool) If True, reorders the weight columns based on
-            activation magnitude, which can improve quantization accuracy.
-        weight_bits: (int) The number of bits to use for the quantized weights,
-            e.g., 4 for 4-bit quantization.
+        dataloader: An iterable providing calibration data.
+        config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.

     Raises:
         ValueError: If the function cannot automatically find an embedding
             layer or any transformer-like blocks to quantize within the model.
     """
+
+    num_samples = config.num_samples
+
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        backbone = model.backbone
-
-        # Add the check for the 'transformer_layers' attribute.
-        if hasattr(backbone, "transformer_layers"):
-            transformer_blocks = backbone.transformer_layers
-        else:
-            # Raise a specific error if the attribute is missing.
-            raise ValueError(
-                "The model's backbone does not have a 'transformer_layers' "
-                "attribute. Please ensure you are using a standard KerasHub "
-                "transformer model."
-            )
-        # Find the embedding layer by checking for common names or by type.
-        if hasattr(backbone, "token_embedding"):
-            embedding_layer = backbone.token_embedding
-        elif hasattr(backbone, "embedding"):
-            embedding_layer = backbone.embedding
-        else:
-            raise ValueError(
-                "Could not automatically find an embedding layer in the model."
-            )

-    else:
-        logging.info("Detected custom model structure.")
-        for layer in model.layers:
-            # The first Embedding layer found is assumed to be the main one.
-            if isinstance(layer, Embedding) and embedding_layer is None:
-                embedding_layer = layer
-            # A "block" is a container-like layer with its own sub-layers
-            # that we can quantize. This is a heuristic that works for the
-            # test.
-            elif hasattr(layer, "_layers") and layer._layers:
-                transformer_blocks.append(layer)
-
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "Could not automatically find any transformer-like blocks to "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
         )

-    # Initial inputs are the outputs of the token embedding layer
-    inputs = [
-        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
-        for batch in dataloader
-    ]
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
+    num_samples = min(num_samples, len(inputs))
+
     progbar = keras_utils.Progbar(target=len(transformer_blocks))

     for block_idx, block in enumerate(transformer_blocks):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)

+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No Dense or EinsumDense layers found in block {block_idx}. "
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
             gptq_objects = {
-                name: GPTQ(layer) for name, layer in sub_layers_map.items()
+                name: GPTQ(layer, config)
+                for name, layer in sub_layers_map.items()
             }

-            captured_inputs = {name: [] for name in sub_layers_map.keys()}
-            original_calls = {}
-
-            def create_hook(name, original_call_func):
-                """A factory for creating a hook to capture layer inputs."""
-
-                def hook(*args, **kwargs):
-                    if args:
-                        inp = args[0]
-                    else:
-                        inp = kwargs["inputs"]
-                    captured_inputs[name].append(inp)
-                    return original_call_func(*args, **kwargs)
-
-                return hook
-
-            try:
-                for name, layer in sub_layers_map.items():
-                    original_call = layer.call
-                    original_calls[name] = original_call
-                    layer.call = create_hook(name, original_call)
-
-                logging.info(f"Capturing activations for block {block_idx}...")
+            with stream_hessians(sub_layers_map, gptq_objects):
                 for sample_idx in range(num_samples):
                     current_input = inputs[sample_idx]
                     if len(current_input.shape) == 2:
                         current_input = ops.expand_dims(current_input, axis=0)
                     _ = block(current_input)

-            finally:
-                for name, layer in sub_layers_map.items():
-                    if name in original_calls:
-                        layer.call = original_calls[name]
-
-            logging.info(f"Building Hessians for block {block_idx}...")
-            for name, gptq_object in gptq_objects.items():
-                layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
-
-                # Explicitly reshape the input tensor to be 2D, with the second
-                # dimension matching the number of input features expected by
-                # the layer's kernel.
-                # This correctly handles inputs of any dimensionality
-                # (e.g., 3D or 4D).
-                num_features = gptq_object.rows
-                input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
-                gptq_object.update_hessian_with_batch(input_reshaped)
-
-            quantizer = GPTQQuantization(
-                weight_bits,
-                per_channel=True,
-                symmetric=symmetric,
-                group_size=group_size,
-            )
             for name, gptq_object in gptq_objects.items():
                 logging.info(f"Quantizing {name}...")
-                gptq_object.quantizer = quantizer
-                gptq_object.quantize_and_correct_block(
-                    hessian_damping=hessian_damping,
-                    group_size=group_size,
-                    activation_order=activation_order,
-                )
+                gptq_object.quantize_and_correct_layer()
                 gptq_object.free()

-            del gptq_objects, captured_inputs, original_calls
+            del gptq_objects

         if block_idx < len(transformer_blocks) - 1:
             logging.info(f"Generating inputs for block {block_idx + 1}...")
@@ -304,32 +319,130 @@ def apply_gptq_layerwise(
     logging.info("Quantization process complete.")


-def quantize_model(model, config):
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-    Top-level function to quantize a Keras model using GPTQ.
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-    logging.info("Starting GPTQ quantization process...")
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )

-    # Load ALL data needed from the generator/source in a single call.
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
+
+    # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
-    full_dataloader = get_dataloader(
+    dataloader = get_dataloader(
         config.tokenizer,
         config.sequence_length,
         config.dataset,
         num_samples=total_samples_to_request,
     )

-    # Split the materialized data. This works because full_dataloader
+    # Split the materialized data. This works because dataloader
     # is now a NumPy array, which can be sliced and reused.
-    calibration_dataloader = full_dataloader[: config.num_samples]
+    calibration_dataloader = dataloader[: config.num_samples]

     apply_gptq_layerwise(
-        model,
-        calibration_dataloader,  # Use the calibration slice
-        config.num_samples,  # Use the configured number of samples
-        config.hessian_damping,
-        config.group_size,
-        config.symmetric,
-        config.activation_order,
-        config.weight_bits,
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
     )
+
+
+def get_group_size_for_layer(layer, config):
+    """Determine the group size for GPTQ quantization.
+
+    The group size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the group size should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `group_size` attribute.
+    Returns:
+        int. The determined group size for GPTQ quantization.
+    Raises:
+        ValueError: If the group size is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.group_size
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the group_size must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
+
+
+def get_weight_bits_for_layer(layer, config):
+    """Determine the number of weight bits for GPTQ quantization.
+
+    The number of weight bits can be specified either through the `config`
+    argument or through the `dtype_policy` if it is of type
+    `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the weight bits should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `weight_bits` attribute.
+    Returns:
+        int. The determined number of weight bits for GPTQ quantization.
+    Raises:
+        ValueError: If the weight bits is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.weight_bits
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.weight_bits
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.weight_bits
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the weight_bits must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
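
The rewritten `get_dataloader` in the diff above replaces random chunking with deterministic window selection over the concatenated token stream. The standalone NumPy sketch below is not part of the package: the toy token array and parameters are invented for illustration, and the seed-derived offset of the "strided" branch is omitted for brevity. It only mirrors how the "linspace" and "strided" strategies pick window starts and how `sliding_window_view` gathers the `(num_samples, 1, sequence_length)` calibration batch.

```python
import numpy as np

# Toy stand-ins for the concatenated calibration token stream and parameters.
all_tokens = np.arange(1000, dtype=np.int32)  # pretend token ids
sequence_length = 64
num_samples = 8

max_start = all_tokens.size - sequence_length

# "linspace": evenly spaced window starts, no RNG involved.
starts_linspace = np.linspace(0, max_start, num_samples, dtype=np.int64)

# "strided": fixed stride chosen to cover the stream roughly uniformly
# (the real code also adds a seed-derived offset modulo max_start + 1).
stride = max(1, (max_start + 1) // num_samples)
starts_strided = (np.arange(num_samples, dtype=np.int64) * stride) % (max_start + 1)

# Gather contiguous windows without materializing an index matrix.
windows = np.lib.stride_tricks.sliding_window_view(all_tokens, sequence_length)
samples = windows[starts_linspace]              # (num_samples, sequence_length)
batched = samples.astype(np.int32)[:, None, :]  # (num_samples, 1, sequence_length)

print(batched.shape)  # (8, 1, 64)
```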