keras-nightly 3.12.0.dev2025090203__py3-none-any.whl → 3.12.0.dev2025090403__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/src/backend/jax/layer.py +3 -1
- keras/src/models/model.py +2 -2
- keras/src/quantizers/gptq.py +284 -192
- keras/src/quantizers/gptq_config.py +3 -13
- keras/src/quantizers/gptq_core.py +211 -158
- keras/src/quantizers/quantizers.py +200 -0
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/RECORD +11 -12
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.12.0.dev2025090403.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
|
-
import
|
1
|
+
import math
|
2
|
+
from contextlib import contextmanager
|
2
3
|
|
3
4
|
import numpy as np
|
4
5
|
from absl import logging
|
@@ -9,15 +10,109 @@ from keras.src.layers import Dense
|
|
9
10
|
from keras.src.layers import EinsumDense
|
10
11
|
from keras.src.layers import Embedding
|
11
12
|
from keras.src.quantizers.gptq import GPTQ
|
12
|
-
from keras.src.quantizers.gptq_quant import GPTQQuantization
|
13
13
|
|
14
14
|
|
15
|
-
|
15
|
+
@contextmanager
def stream_hessians(layers_map, gptq_objects):
    """Stream each target layer's inputs into its GPTQ Hessian estimate.

    While the context is active, every layer in `layers_map` has its `call`
    attribute replaced by a wrapper that:

    1. takes the layer input from the first positional argument, or from the
       `inputs` keyword argument when no positional arguments are given;
    2. reshapes it to 2D `[-1, rows]`, where
       `rows = gptq_objects[name].rows`, so inputs of any rank are accepted;
    3. calls `gptq_objects[name].update_hessian_with_batch(x2d)`;
    4. delegates to the original `call` and returns its result.

    Inputs are folded into the GPTQ instance's running Hessian estimate as
    soon as they are seen instead of being buffered. No layer weights are
    modified here.

    On exit, the original `call` of every layer that was actually patched is
    restored. This also happens when the `with` body raises, or when
    patching itself fails part-way through.

    Args:
        layers_map: Dict[str, Layer]. Layers to patch during calibration,
            keyed by name. Every key must also be present in `gptq_objects`.
        gptq_objects: Dict[str, GPTQ]. GPTQ instances keyed by the same names
            as `layers_map`.

    Yields:
        None. The patched state is only active inside the `with` block.

    Example:

    ```python
    with stream_hessians(layers_map, gptq_objects):
        for sample in calibration_inputs:
            if len(sample.shape) == 2:
                sample = ops.expand_dims(sample, 0)
            _ = block(sample)  # hooks update Hessians on the fly
    # Original `layer.call` methods are restored here.
    ```
    """
    # (layer, original_call) pairs, recorded only once a patch has succeeded.
    patched = []

    def create_hook(name, original_call_func):
        def hook(*args, **kwargs):
            inp = args[0] if args else kwargs["inputs"]
            # Flatten every leading axis so the second dimension matches the
            # number of input features expected by the layer's kernel. This
            # handles inputs of any rank (e.g. 3D or 4D).
            num_features = gptq_objects[name].rows
            input_2d = ops.reshape(inp, (-1, num_features))
            gptq_objects[name].update_hessian_with_batch(input_2d)
            return original_call_func(*args, **kwargs)

        return hook

    try:
        for name, layer in layers_map.items():
            original_call = layer.call
            layer.call = create_hook(name, original_call)
            patched.append((layer, original_call))
        yield
    finally:
        # Only undo patches that were applied. If patching failed part-way,
        # looking up unpatched layers would raise a KeyError that hides the
        # original error.
        for layer, original_call in patched:
            layer.call = original_call
|
82
|
+
|
83
|
+
|
84
|
+
def get_dataloader(
|
85
|
+
tokenizer,
|
86
|
+
sequence_length,
|
87
|
+
dataset,
|
88
|
+
num_samples=128,
|
89
|
+
*,
|
90
|
+
strategy="strided",
|
91
|
+
seed=42,
|
92
|
+
stride=None,
|
93
|
+
eos_id=None,
|
94
|
+
):
|
18
95
|
"""
|
19
|
-
|
96
|
+
Prepares and chunks the calibration dataloader, repeating short datasets.
|
97
|
+
All processing happens on the CPU.
|
20
98
|
|
99
|
+
Args:
|
100
|
+
tokenizer: The tokenizer to use for text splitting.
|
101
|
+
sequence_length: The length of each input sequence.
|
102
|
+
dataset: The dataset to sample from.
|
103
|
+
num_samples: The number of samples to generate.
|
104
|
+
strategy: The sampling strategy to use. Possible values are
|
105
|
+
1. "strided": Samples are taken at regular intervals.
|
106
|
+
2. "linspace": Samples are taken at evenly spaced intervals.
|
107
|
+
3. "random": Samples are taken at random positions.
|
108
|
+
seed: The random seed for reproducibility. Used only if
|
109
|
+
strategy="random"
|
110
|
+
stride: The stride length for "strided" sampling.
|
111
|
+
eos_id: The end-of-sequence token ID.
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
np.ndarray of shape (num_samples, 1, sequence_length), dtype int32.
|
115
|
+
"""
|
21
116
|
if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
|
22
117
|
raise TypeError(
|
23
118
|
"The `dataset` argument must be an iterable (e.g., a list of "
|
@@ -27,44 +122,72 @@ def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
|
|
27
122
|
)
|
28
123
|
|
29
124
|
dataset_list = list(dataset)
|
30
|
-
|
31
125
|
if not dataset_list:
|
32
126
|
raise ValueError("Provided dataset is empty.")
|
33
127
|
|
128
|
+
pieces = []
|
34
129
|
if isinstance(dataset_list[0], str):
|
35
|
-
|
36
|
-
|
37
|
-
|
130
|
+
for i, s in enumerate(dataset_list):
|
131
|
+
toks = np.asarray(tokenizer.tokenize(s)).reshape(-1)
|
132
|
+
pieces.append(toks)
|
133
|
+
# avoid windows that span document boundaries
|
134
|
+
if eos_id is not None and i < len(dataset_list) - 1:
|
135
|
+
pieces.append(np.array([eos_id], dtype=np.int32))
|
38
136
|
else:
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
137
|
+
for s in dataset_list:
|
138
|
+
toks = ops.convert_to_numpy(s).reshape(-1)
|
139
|
+
pieces.append(toks.astype(np.int32, copy=False))
|
140
|
+
|
141
|
+
all_tokens = (
|
142
|
+
pieces[0].astype(np.int32, copy=False)
|
143
|
+
if len(pieces) == 1
|
144
|
+
else np.concatenate(pieces, axis=0).astype(np.int32, copy=False)
|
145
|
+
)
|
45
146
|
|
46
|
-
# Repeat data if it's too short
|
47
147
|
required_tokens = num_samples * sequence_length
|
48
|
-
if
|
49
|
-
|
50
|
-
f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
|
51
|
-
" Repeating data to generate {num_samples} samples."
|
52
|
-
)
|
53
|
-
repeats = -(-required_tokens // len(all_tokens)) # Ceiling division
|
148
|
+
if all_tokens.size < required_tokens:
|
149
|
+
repeats = math.ceil(required_tokens / max(1, all_tokens.size))
|
54
150
|
all_tokens = np.tile(all_tokens, repeats)
|
55
151
|
|
56
|
-
|
152
|
+
max_start = all_tokens.size - sequence_length
|
153
|
+
if max_start < 0:
|
154
|
+
raise ValueError(
|
155
|
+
f"Not enough tokens to form one sample of length {sequence_length} "
|
156
|
+
f"(have {all_tokens.size})."
|
157
|
+
)
|
57
158
|
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
159
|
+
# Choose deterministic, well-spread starts by default
|
160
|
+
if strategy == "random":
|
161
|
+
rng = np.random.default_rng(seed)
|
162
|
+
starts = rng.integers(
|
163
|
+
0, max_start + 1, size=num_samples, dtype=np.int64
|
164
|
+
)
|
165
|
+
elif strategy == "linspace":
|
166
|
+
# even coverage with no RNG
|
167
|
+
starts = np.linspace(0, max_start, num_samples, dtype=np.int64)
|
168
|
+
elif strategy == "strided":
|
169
|
+
# stride chosen to cover the space roughly uniformly
|
170
|
+
if stride is None:
|
171
|
+
stride = max(1, (max_start + 1) // num_samples)
|
172
|
+
# offset derived deterministically from seed
|
173
|
+
offset = (
|
174
|
+
(abs(hash(("gptq-calib", seed))) % (max_start + 1))
|
175
|
+
if max_start > 0
|
176
|
+
else 0
|
177
|
+
)
|
178
|
+
starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % (
|
179
|
+
max_start + 1
|
180
|
+
)
|
181
|
+
else:
|
182
|
+
raise ValueError(f"Unknown strategy: {strategy}")
|
65
183
|
|
66
|
-
|
67
|
-
|
184
|
+
# Gather contiguous windows
|
185
|
+
# sliding_window_view avoids building a big index matrix
|
186
|
+
windows = np.lib.stride_tricks.sliding_window_view(
|
187
|
+
all_tokens, sequence_length
|
188
|
+
)
|
189
|
+
samples = windows[starts] # (num_samples, sequence_length)
|
190
|
+
return samples.astype(np.int32)[:, None, :]
|
68
191
|
|
69
192
|
|
70
193
|
def _find_layers_recursive(layer, prefix, found_layers):
|
@@ -82,6 +205,41 @@ def _find_layers_recursive(layer, prefix, found_layers):
|
|
82
205
|
_find_layers_recursive(sub_layer, layer_name, found_layers)
|
83
206
|
|
84
207
|
|
208
|
+
def _get_backbone_layers(model):
|
209
|
+
"""Extract embedding and transformer layers from a KerasHub model."""
|
210
|
+
backbone = model.backbone
|
211
|
+
if not hasattr(backbone, "transformer_layers"):
|
212
|
+
raise ValueError(
|
213
|
+
"The model's backbone does not have a 'transformer_layers' "
|
214
|
+
"attribute. Please ensure you are using a standard KerasHub "
|
215
|
+
"transformer model."
|
216
|
+
)
|
217
|
+
transformer_blocks = backbone.transformer_layers
|
218
|
+
|
219
|
+
if hasattr(backbone, "token_embedding"):
|
220
|
+
embedding_layer = backbone.token_embedding
|
221
|
+
elif hasattr(backbone, "embedding"):
|
222
|
+
embedding_layer = backbone.embedding
|
223
|
+
else:
|
224
|
+
raise ValueError(
|
225
|
+
"Could not automatically find an embedding layer in the model."
|
226
|
+
)
|
227
|
+
return embedding_layer, transformer_blocks
|
228
|
+
|
229
|
+
|
230
|
+
def _get_custom_layers(model):
|
231
|
+
"""Heuristic for extracting embedding + transformer blocks from a custom
|
232
|
+
model."""
|
233
|
+
embedding_layer = None
|
234
|
+
transformer_blocks = []
|
235
|
+
for layer in model.layers:
|
236
|
+
if isinstance(layer, Embedding) and embedding_layer is None:
|
237
|
+
embedding_layer = layer
|
238
|
+
elif getattr(layer, "_layers", None): # container-like block
|
239
|
+
transformer_blocks.append(layer)
|
240
|
+
return embedding_layer, transformer_blocks
|
241
|
+
|
242
|
+
|
85
243
|
def find_layers_in_block(block):
|
86
244
|
"""
|
87
245
|
A pluggable, generic function to find all Dense and EinsumDense layers
|
@@ -93,16 +251,7 @@ def find_layers_in_block(block):
|
|
93
251
|
return found_layers
|
94
252
|
|
95
253
|
|
96
|
-
def apply_gptq_layerwise(
|
97
|
-
model,
|
98
|
-
dataloader,
|
99
|
-
num_samples,
|
100
|
-
hessian_damping,
|
101
|
-
group_size,
|
102
|
-
symmetric,
|
103
|
-
activation_order,
|
104
|
-
weight_bits,
|
105
|
-
):
|
254
|
+
def apply_gptq_layerwise(model, dataloader, config):
|
106
255
|
"""Applies GPTQ quantization layer-by-layer to a Keras model.
|
107
256
|
|
108
257
|
This function is designed to work with common transformer architectures,
|
@@ -134,63 +283,24 @@ def apply_gptq_layerwise(
|
|
134
283
|
attempt to automatically discover its structure.
|
135
284
|
dataloader: An iterable providing calibration data. Each item should
|
136
285
|
be a batch of token IDs suitable for the model's embedding layer.
|
137
|
-
|
138
|
-
calibration.
|
139
|
-
hessian_damping: (float) The percentage of dampening to add to the
|
140
|
-
Hessian diagonal for stabilization during inverse calculation.
|
141
|
-
A value of 0.01 is common.
|
142
|
-
group_size: (int) The size of the groups to use for quantization. A
|
143
|
-
value of 128 means that 128 weights will share the same scaling
|
144
|
-
factor. Use -1 for per-channel quantization.
|
145
|
-
symmetric: (bool) If True, symmetric quantization is used. Otherwise,
|
146
|
-
asymmetric quantization is used.
|
147
|
-
activation_order: (bool) If True, reorders the weight columns based on
|
148
|
-
activation magnitude, which can improve quantization accuracy.
|
149
|
-
weight_bits: (int) The number of bits to use for the quantized weights,
|
150
|
-
e.g., 4 for 4-bit quantization.
|
286
|
+
config: A GPTQConfiguration object.
|
151
287
|
|
152
288
|
Raises:
|
153
289
|
ValueError: If the function cannot automatically find an embedding
|
154
290
|
layer or any transformer-like blocks to quantize within the model.
|
155
291
|
"""
|
292
|
+
|
293
|
+
num_samples = config.num_samples
|
294
|
+
|
156
295
|
logging.info("Starting model quantization...")
|
157
296
|
embedding_layer = None
|
158
297
|
transformer_blocks = []
|
159
298
|
if hasattr(model, "backbone"):
|
160
299
|
logging.info("Detected KerasHub model structure.")
|
161
|
-
|
162
|
-
|
163
|
-
# Add the check for the 'transformer_layers' attribute.
|
164
|
-
if hasattr(backbone, "transformer_layers"):
|
165
|
-
transformer_blocks = backbone.transformer_layers
|
166
|
-
else:
|
167
|
-
# Raise a specific error if the attribute is missing.
|
168
|
-
raise ValueError(
|
169
|
-
"The model's backbone does not have a 'transformer_layers' "
|
170
|
-
"attribute. Please ensure you are using a standard KerasHub "
|
171
|
-
"transformer model."
|
172
|
-
)
|
173
|
-
# Find the embedding layer by checking for common names or by type.
|
174
|
-
if hasattr(backbone, "token_embedding"):
|
175
|
-
embedding_layer = backbone.token_embedding
|
176
|
-
elif hasattr(backbone, "embedding"):
|
177
|
-
embedding_layer = backbone.embedding
|
178
|
-
else:
|
179
|
-
raise ValueError(
|
180
|
-
"Could not automatically find an embedding layer in the model."
|
181
|
-
)
|
182
|
-
|
300
|
+
embedding_layer, transformer_blocks = _get_backbone_layers(model)
|
183
301
|
else:
|
184
302
|
logging.info("Detected custom model structure.")
|
185
|
-
|
186
|
-
# The first Embedding layer found is assumed to be the main one.
|
187
|
-
if isinstance(layer, Embedding) and embedding_layer is None:
|
188
|
-
embedding_layer = layer
|
189
|
-
# A "block" is a container-like layer with its own sub-layers
|
190
|
-
# that we can quantize. This is a heuristic that works for the
|
191
|
-
# test.
|
192
|
-
elif hasattr(layer, "_layers") and layer._layers:
|
193
|
-
transformer_blocks.append(layer)
|
303
|
+
embedding_layer, transformer_blocks = _get_custom_layers(model)
|
194
304
|
|
195
305
|
if embedding_layer is None:
|
196
306
|
raise ValueError(
|
@@ -207,6 +317,8 @@ def apply_gptq_layerwise(
|
|
207
317
|
embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
|
208
318
|
for batch in dataloader
|
209
319
|
]
|
320
|
+
num_samples = min(num_samples, len(inputs))
|
321
|
+
|
210
322
|
progbar = keras_utils.Progbar(target=len(transformer_blocks))
|
211
323
|
|
212
324
|
for block_idx, block in enumerate(transformer_blocks):
|
@@ -221,73 +333,23 @@ def apply_gptq_layerwise(
|
|
221
333
|
else:
|
222
334
|
logging.info(f"Found layers: {list(sub_layers_map.keys())}")
|
223
335
|
gptq_objects = {
|
224
|
-
name: GPTQ(layer
|
336
|
+
name: GPTQ(layer, config)
|
337
|
+
for name, layer in sub_layers_map.items()
|
225
338
|
}
|
226
339
|
|
227
|
-
|
228
|
-
original_calls = {}
|
229
|
-
|
230
|
-
def create_hook(name, original_call_func):
|
231
|
-
"""A factory for creating a hook to capture layer inputs."""
|
232
|
-
|
233
|
-
def hook(*args, **kwargs):
|
234
|
-
if args:
|
235
|
-
inp = args[0]
|
236
|
-
else:
|
237
|
-
inp = kwargs["inputs"]
|
238
|
-
captured_inputs[name].append(inp)
|
239
|
-
return original_call_func(*args, **kwargs)
|
240
|
-
|
241
|
-
return hook
|
242
|
-
|
243
|
-
try:
|
244
|
-
for name, layer in sub_layers_map.items():
|
245
|
-
original_call = layer.call
|
246
|
-
original_calls[name] = original_call
|
247
|
-
layer.call = create_hook(name, original_call)
|
248
|
-
|
249
|
-
logging.info(f"Capturing activations for block {block_idx}...")
|
340
|
+
with stream_hessians(sub_layers_map, gptq_objects):
|
250
341
|
for sample_idx in range(num_samples):
|
251
342
|
current_input = inputs[sample_idx]
|
252
343
|
if len(current_input.shape) == 2:
|
253
344
|
current_input = ops.expand_dims(current_input, axis=0)
|
254
345
|
_ = block(current_input)
|
255
346
|
|
256
|
-
finally:
|
257
|
-
for name, layer in sub_layers_map.items():
|
258
|
-
if name in original_calls:
|
259
|
-
layer.call = original_calls[name]
|
260
|
-
|
261
|
-
logging.info(f"Building Hessians for block {block_idx}...")
|
262
|
-
for name, gptq_object in gptq_objects.items():
|
263
|
-
layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
|
264
|
-
|
265
|
-
# Explicitly reshape the input tensor to be 2D, with the second
|
266
|
-
# dimension matching the number of input features expected by
|
267
|
-
# the layer's kernel.
|
268
|
-
# This correctly handles inputs of any dimensionality
|
269
|
-
# (e.g., 3D or 4D).
|
270
|
-
num_features = gptq_object.rows
|
271
|
-
input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
|
272
|
-
gptq_object.update_hessian_with_batch(input_reshaped)
|
273
|
-
|
274
|
-
quantizer = GPTQQuantization(
|
275
|
-
weight_bits,
|
276
|
-
per_channel=True,
|
277
|
-
symmetric=symmetric,
|
278
|
-
group_size=group_size,
|
279
|
-
)
|
280
347
|
for name, gptq_object in gptq_objects.items():
|
281
348
|
logging.info(f"Quantizing {name}...")
|
282
|
-
gptq_object.
|
283
|
-
gptq_object.quantize_and_correct_block(
|
284
|
-
hessian_damping=hessian_damping,
|
285
|
-
group_size=group_size,
|
286
|
-
activation_order=activation_order,
|
287
|
-
)
|
349
|
+
gptq_object.quantize_and_correct_layer()
|
288
350
|
gptq_object.free()
|
289
351
|
|
290
|
-
del gptq_objects
|
352
|
+
del gptq_objects
|
291
353
|
|
292
354
|
if block_idx < len(transformer_blocks) - 1:
|
293
355
|
logging.info(f"Generating inputs for block {block_idx + 1}...")
|
@@ -304,32 +366,23 @@ def apply_gptq_layerwise(
|
|
304
366
|
logging.info("Quantization process complete.")
|
305
367
|
|
306
368
|
|
307
|
-
def
|
369
|
+
def gptq_quantize(model, config):
    """Top-level entry point: quantize a Keras `model` with GPTQ.

    Builds the calibration set with `get_dataloader` from
    `config.tokenizer`, `config.sequence_length`, `config.dataset` and
    `config.num_samples`. It keeps the first `config.num_samples` entries
    and passes them, together with `config`, to `apply_gptq_layerwise`.
    """
    logging.info("Starting GPTQ quantization process...")

    # Materialize every calibration sample from the source in one call.
    calibration_data = get_dataloader(
        config.tokenizer,
        config.sequence_length,
        config.dataset,
        num_samples=config.num_samples,
    )

    # The loader returns a NumPy array, so it can be sliced and reused.
    apply_gptq_layerwise(
        model, calibration_data[: config.num_samples], config
    )
|
@@ -9,6 +9,7 @@ from keras.src.backend import any_symbolic_tensors
|
|
9
9
|
from keras.src.backend.common.backend_utils import canonicalize_axis
|
10
10
|
from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
|
11
11
|
from keras.src.ops.operation import Operation
|
12
|
+
from keras.src.quantizers.gptq_config import GPTQConfig
|
12
13
|
|
13
14
|
"""Int8-related classes and methods"""
|
14
15
|
|
@@ -634,3 +635,202 @@ def unpack_int4(packed, orig_len, axis=0):
|
|
634
635
|
unpacked = ops.transpose(unpacked, inv_perm)
|
635
636
|
|
636
637
|
return unpacked # dtype is int8
|
638
|
+
|
639
|
+
|
640
|
+
class GPTQQuantizer(Quantizer):
    """Finds and stores GPTQ quantization parameters (scale and zero-point).

    Settings are read from a `GPTQConfig`. `find_params` computes the
    parameters with `compute_quantization_parameters` and caches the result
    on the instance as `scale`, `zero` and `maxq`.

    Args:
        config: GPTQConfig or None. Supplies the following settings:
            `weight_bits` is the number of bits to quantize to (e.g. 4).
            `per_channel` selects per-channel rather than per-tensor
            parameters. `symmetric` selects symmetric rather than asymmetric
            quantization. `group_size` is the size of weight groups, where -1
            disables grouping. When None (the default),
            `GPTQConfig(tokenizer=None, dataset=None)` is used, i.e. that
            class's default settings.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            # Built per call. A `GPTQConfig(...)` default argument would be
            # evaluated once at import time and shared by every instance.
            config = GPTQConfig(tokenizer=None, dataset=None)
        self.weight_bits = config.weight_bits
        self.per_channel = config.per_channel
        self.symmetric = config.symmetric
        self.group_size = config.group_size

        # Populated by `find_params`.
        self.scale = None
        self.zero = None
        self.maxq = None

    def find_params(self, input_tensor, weight=False):
        """Computes and caches quantization parameters for `input_tensor`.

        Args:
            input_tensor: Tensor to compute parameters for.
            weight: bool. Whether `input_tensor` is a weight tensor. This is
                forwarded to `compute_quantization_parameters`.

        Returns:
            Tuple `(scale, zero, maxq)`, which is also stored on the instance.
        """
        self.scale, self.zero, self.maxq = compute_quantization_parameters(
            input_tensor,
            bits=self.weight_bits,
            symmetric=self.symmetric,
            per_channel=self.per_channel,
            group_size=self.group_size,
            weight=weight,
        )
        return self.scale, self.zero, self.maxq

    def ready(self):
        """Returns True once `scale`, `zero` and `maxq` have all been set."""
        return (
            self.scale is not None
            and self.zero is not None
            and self.maxq is not None
        )

    def get_config(self):
        """Returns a dict with the quantization settings of this quantizer."""
        # NOTE(review): the `Quantizer` base `get_config` may raise
        # NotImplementedError. Fall back to an empty dict so this quantizer
        # stays serializable either way.
        try:
            config = super().get_config()
        except NotImplementedError:
            config = {}
        config.update(
            {
                "weight_bits": self.weight_bits,
                "per_channel": self.per_channel,
                "symmetric": self.symmetric,
                "group_size": self.group_size,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuilds a quantizer from a dict produced by `get_config`."""
        gptq = GPTQConfig(
            tokenizer=None,
            dataset=None,
            weight_bits=config["weight_bits"],
            per_channel=config["per_channel"],
            symmetric=config["symmetric"],
            group_size=config["group_size"],
        )
        return cls(gptq)
|
713
|
+
|
714
|
+
|
715
|
+
def compute_quantization_parameters(
    x, *, bits, symmetric=False, per_channel=False, group_size=-1, weight=False
):
    """Computes the scale and zero-point for quantization.

    Min/max statistics are gathered per row of a 2D view of `x`:

    * Per-tensor (`per_channel=False`): a single row holding every element.
    * Per-channel weight, `group_size == -1`: one row per index of the first
      axis (`x` viewed as `[x.shape[0], -1]`).
    * Per-channel weight, grouped: one row per consecutive run of
      `group_size` elements (`x` viewed as `[-1, group_size]`). The size of
      `x` must therefore be divisible by `group_size`.
    * Per-channel non-weight tensor: one row per index of the last axis
      (channels-last). Statistics are reduced over all other axes.

    In symmetric mode the range of each row is `[-amax, amax]`, with
    `amax = max(|min|, max)`. This is consistent with the fixed mid-grid
    zero-point `(maxq + 1) / 2`.

    Args:
        x: KerasTensor. The input tensor to quantize.
        bits: int. The number of bits to quantize to (e.g., 4). Quantized
            levels lie in `[0, 2**bits - 1]`.
        symmetric: bool. Whether to use symmetric quantization.
        per_channel: bool. Whether to quantize per channel.
        group_size: int. The group size for quantization; -1 disables
            grouping. Only used for per-channel weights.
        weight: bool. Whether the input tensor is a weight tensor.

    Returns:
        Tuple `(scale, zero, maxq)`:
        * `scale`, `zero` for weights: shape `[num_rows, 1]`. Per-tensor
          parameters are tiled to `[x.shape[0], 1]`.
        * `scale`, `zero` for per-channel non-weight tensors: shape
          `[x.shape[-1]]`, so they broadcast against a channels-last `x`.
        * `scale`, `zero` for per-tensor non-weight tensors: shape `[1]`.
        * `maxq`: a float32 scalar equal to `2**bits - 1`.

    Raises:
        ValueError: If `x` is None or empty, or if it is a weight of rank < 2.
    """
    if x is None:
        raise ValueError("Input tensor 'x' cannot be None.")

    # For weights, we expect at least a 2D tensor.
    if weight and len(x.shape) < 2:
        raise ValueError(
            f"Input weight tensor {x} must have a rank of at "
            f"least 2, but got rank {len(x.shape)}."
        )

    if ops.size(x) == 0:
        raise ValueError("Input tensor 'x' cannot be empty.")

    original_shape = x.shape

    # Build a 2D view whose rows are the units sharing one (scale, zero).
    if not per_channel:
        input_reshaped = ops.reshape(x, [1, -1])
    elif weight:
        if group_size != -1:
            input_reshaped = ops.reshape(x, [-1, group_size])
        else:
            input_reshaped = ops.reshape(x, [original_shape[0], -1])
    else:
        # Non-weight tensors are treated as channels-last: one row per index
        # of the last axis, reduced over every other axis.
        num_channels = original_shape[-1] if len(original_shape) else 1
        input_reshaped = ops.transpose(ops.reshape(x, [-1, num_channels]))

    # Find min/max values
    min_values = ops.min(input_reshaped, axis=1)
    max_values = ops.max(input_reshaped, axis=1)

    if symmetric:
        # The zero-point is pinned to the middle of the grid, so the range
        # must be centered on 0. This holds even for rows that are entirely
        # positive; otherwise every value would clip to the top code.
        max_values = ops.maximum(ops.abs(min_values), max_values)
        min_values = ops.negative(max_values)

    # Widen degenerate (constant) rows so the scale below is non-zero.
    zero_range = ops.equal(min_values, max_values)
    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
    max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)

    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), "float32")

    # Calculate scale and zero-point
    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
    if symmetric:
        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
    else:
        zero = ops.round(ops.divide(ops.negative(min_values), scale))

    # Ensure scale is strictly positive.
    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)

    if weight:
        if per_channel:
            # One (scale, zero) pair per row: per channel or per group.
            scale = ops.reshape(scale, [-1, 1])
            zero = ops.reshape(zero, [-1, 1])
        else:
            # Repeat the single per-tensor pair for every weight row.
            num_rows = original_shape[0]
            scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
            zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))

    return scale, zero, maxq
|
795
|
+
|
796
|
+
|
797
|
+
def quantize_with_zero_point(input_tensor, scale, zero, maxq):
    """Map a float tensor onto the integer grid `[0, maxq]`.

    Each element is divided by `scale`, shifted by `zero`, rounded, and
    finally clipped to `[0, maxq]`. The result keeps a floating dtype; only
    its values are integral. `scale` and `zero` may be per-tensor,
    per-channel or per-group, as long as they broadcast against
    `input_tensor`.

    Args:
        input_tensor: KerasTensor. The input tensor to quantize.
        scale: KerasTensor. The scale tensor for quantization.
        zero: KerasTensor. The zero tensor for quantization.
        maxq: KerasTensor. The maximum quantization value.

    Returns:
        KerasTensor. The quantized tensor.
    """
    # Any exactly-zero scale is replaced by a tiny epsilon of the same dtype
    # so the division below stays finite.
    tiny = ops.cast(1e-8, dtype=scale.dtype)
    divisor = ops.where(ops.equal(scale, 0), tiny, scale)

    shifted = ops.add(ops.divide(input_tensor, divisor), zero)
    return ops.clip(ops.round(shifted), 0, maxq)
|
822
|
+
|
823
|
+
|
824
|
+
def dequantize_with_zero_point(input_tensor, scale, zero):
    """Map quantized levels back to floats as `scale * (q - zero)`.

    Args:
        input_tensor: KerasTensor. The quantized tensor to dequantize.
        scale: KerasTensor. The scale tensor for dequantization.
        zero: KerasTensor. The zero tensor for dequantization.

    Returns:
        KerasTensor. The dequantized tensor.
    """
    centered = ops.subtract(input_tensor, zero)
    return ops.multiply(scale, centered)
|