keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/gptq_core.py

@@ -1,23 +1,121 @@
-import
+import math
+from contextlib import contextmanager
 
 import numpy as np
 from absl import logging
 
 from keras.src import ops
 from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
-from keras.src.quantizers.
+from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
-
+@contextmanager
+def stream_hessians(layers_map, gptq_objects):
     """
-
+    Temporarily monkey-patch each target layer's `call` method so
+    that input activations are streamed into the GPTQ instance's
+    running Hessian estimate at capture time.
+
+    On `__enter__`: For every (name, layer) in `layers_map`, replaces
+    `layer.call` with a wrapper that:
+        1) extracts the layer input from `*args`/`**kwargs`,
+        2) reshapes it to 2D `[-1, rows]` where
+           `rows = gptq_objects[name].rows`,
+        3) calls `gptq_objects[name].update_hessian_with_batch(x2d)`,
+        4) delegates to the original `layer.call` and returns its
+           output.
+
+    On `__exit__`: All original `layer.call` methods are restored even if an
+    exception occurs.
+
+    * Space complexity: O(d**2) per layer (for the Hessian).
+    * No weights are modified; only GPTQ statistics are updated.
+
+    Args:
+        layers_map: Dict[str, Layer]. Mapping from logical layer names to
+            the Keras layers that should be patched during calibration. Keys must
+            match `gptq_objects`.
+        gptq_objects: Dict[str, GPTQ]. Mapping from names to GPTQ instances.
+
+    Yields:
+        None: The patched state is active only within the `with` block. After
+            exit, all layers are unpatched and safe to use normally.
+
+    Example:
+    ```python
+    >>> with stream_hessians(layers_map, gptq_objects):
+    ...     for sample in calibration_inputs:
+    ...         if len(sample.shape) == 2:
+    ...             sample = ops.expand_dims(sample, 0)
+    ...         _ = block(sample)  # hooks update Hessians on-the-fly
+    >>> # <- original layer.call methods restored here
+    ```
     """
-
+    original_calls = {}
+
+    def create_hook(name, original_call_func):
+        def hook(*args, **kwargs):
+            inp = args[0] if args else kwargs["inputs"]
+            # Explicitly reshape the input tensor to be 2D, with the
+            # second dimension matching the number of input features
+            # expected by the layer's kernel.
+            # This correctly handles inputs of any dimensionality
+            # (e.g., 3D or 4D).
+            num_features = gptq_objects[name].rows
+            input_2d = ops.reshape(inp, (-1, num_features))
+            gptq_objects[name].update_hessian_with_batch(input_2d)
+            return original_call_func(*args, **kwargs)
+
+        return hook
+
+    try:
+        for name, layer in layers_map.items():
+            original_calls[name] = layer.call
+            layer.call = create_hook(name, layer.call)
+        yield
+    finally:
+        for name, layer in layers_map.items():
+            layer.call = original_calls[name]
+
+
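The hook above streams activations into `update_hessian_with_batch`, whose implementation lives in `keras/src/quantizers/gptq.py` and is not part of this hunk. As a rough mental model only, a GPTQ-style running Hessian estimate might look like the following; the class name and the exact update rule here are assumptions, not the package's code:

```python
# Illustrative sketch only: a standard GPTQ-style running estimate of
# H ~= 2 * E[x x^T], roughly what the stream_hessians hook feeds into.
# The actual Keras implementation may differ.
import numpy as np


class RunningHessian:  # hypothetical stand-in for the GPTQ object
    def __init__(self, rows):
        self.rows = rows  # number of input features of the layer
        self.n_samples = 0  # activation rows seen so far
        self.hessian = np.zeros((rows, rows))  # running 2 * E[x x^T]

    def update_hessian_with_batch(self, x2d):
        # x2d: (batch, rows) activations captured by the patched `call`.
        batch = x2d.shape[0]
        total = self.n_samples + batch
        # Rescale the old estimate, then add this batch's contribution.
        self.hessian *= self.n_samples / total
        self.hessian += (2.0 / total) * (x2d.T @ x2d)
        self.n_samples = total


h = RunningHessian(rows=4)
h.update_hessian_with_batch(np.random.default_rng(0).normal(size=(8, 4)))
print(h.hessian.shape)  # (4, 4)
```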
87
|
+
def get_dataloader(
|
|
88
|
+
tokenizer,
|
|
89
|
+
sequence_length,
|
|
90
|
+
dataset,
|
|
91
|
+
num_samples=128,
|
|
92
|
+
*,
|
|
93
|
+
strategy="strided",
|
|
94
|
+
seed=42,
|
|
95
|
+
stride=None,
|
|
96
|
+
eos_id=None,
|
|
97
|
+
):
|
|
98
|
+
"""
|
|
99
|
+
Prepares and chunks the calibration dataloader, repeating short datasets.
|
|
100
|
+
All processing happens on the CPU.
|
|
20
101
|
|
|
102
|
+
Args:
|
|
103
|
+
tokenizer: The tokenizer to use for text splitting.
|
|
104
|
+
sequence_length: The length of each input sequence.
|
|
105
|
+
dataset: The dataset to sample from.
|
|
106
|
+
num_samples: The number of samples to generate.
|
|
107
|
+
strategy: The sampling strategy to use. Possible values are
|
|
108
|
+
1. "strided": Samples are taken at regular intervals.
|
|
109
|
+
2. "linspace": Samples are taken at evenly spaced intervals.
|
|
110
|
+
3. "random": Samples are taken at random positions.
|
|
111
|
+
seed: The random seed for reproducibility. Used only if
|
|
112
|
+
strategy="random"
|
|
113
|
+
stride: The stride length for "strided" sampling.
|
|
114
|
+
eos_id: The end-of-sequence token ID.
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
np.ndarray of shape (num_samples, 1, sequence_length), dtype int32.
|
|
118
|
+
"""
|
|
21
119
|
if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
|
|
22
120
|
raise TypeError(
|
|
23
121
|
"The `dataset` argument must be an iterable (e.g., a list of "
|
|
@@ -27,267 +125,184 @@ def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
|
|
|
27
125
|
)
|
|
28
126
|
|
|
29
127
|
dataset_list = list(dataset)
|
|
30
|
-
|
|
31
128
|
if not dataset_list:
|
|
32
129
|
raise ValueError("Provided dataset is empty.")
|
|
33
130
|
|
|
131
|
+
pieces = []
|
|
34
132
|
if isinstance(dataset_list[0], str):
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
133
|
+
for i, s in enumerate(dataset_list):
|
|
134
|
+
toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
|
|
135
|
+
pieces.append(toks)
|
|
136
|
+
# avoid windows that span document boundaries
|
|
137
|
+
if eos_id is not None and i < len(dataset_list) - 1:
|
|
138
|
+
pieces.append(np.array([eos_id], dtype=np.int32))
|
|
38
139
|
else:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
140
|
+
for s in dataset_list:
|
|
141
|
+
toks = ops.convert_to_numpy(s).reshape(-1)
|
|
142
|
+
pieces.append(toks.astype(np.int32, copy=False))
|
|
143
|
+
|
|
144
|
+
all_tokens = (
|
|
145
|
+
pieces[0].astype(np.int32, copy=False)
|
|
146
|
+
if len(pieces) == 1
|
|
147
|
+
else np.concatenate(pieces, axis=0).astype(np.int32, copy=False)
|
|
148
|
+
)
|
|
45
149
|
|
|
46
|
-
# Repeat data if it's too short
|
|
47
150
|
required_tokens = num_samples * sequence_length
|
|
48
|
-
if
|
|
49
|
-
|
|
50
|
-
f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
|
|
51
|
-
" Repeating data to generate {num_samples} samples."
|
|
52
|
-
)
|
|
53
|
-
repeats = -(-required_tokens // len(all_tokens)) # Ceiling division
|
|
151
|
+
if all_tokens.size < required_tokens:
|
|
152
|
+
repeats = math.ceil(required_tokens / max(1, all_tokens.size))
|
|
54
153
|
all_tokens = np.tile(all_tokens, repeats)
|
|
55
154
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
end_index = start_index + sequence_length
|
|
63
|
-
sample = all_tokens[start_index:end_index]
|
|
64
|
-
calibration_samples.append(np.reshape(sample, (1, sequence_length)))
|
|
65
|
-
|
|
66
|
-
final_array = np.stack(calibration_samples, axis=0)
|
|
67
|
-
return final_array
|
|
68
|
-
|
|
155
|
+
max_start = all_tokens.size - sequence_length
|
|
156
|
+
if max_start < 0:
|
|
157
|
+
raise ValueError(
|
|
158
|
+
f"Not enough tokens to form one sample of length {sequence_length} "
|
|
159
|
+
f"(have {all_tokens.size})."
|
|
160
|
+
)
|
|
69
161
|
|
|
70
|
-
|
|
71
|
-
""
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
162
|
+
# Choose deterministic, well-spread starts by default
|
|
163
|
+
if strategy == "random":
|
|
164
|
+
rng = np.random.default_rng(seed)
|
|
165
|
+
starts = rng.integers(
|
|
166
|
+
0, max_start + 1, size=num_samples, dtype=np.int64
|
|
167
|
+
)
|
|
168
|
+
elif strategy == "linspace":
|
|
169
|
+
# even coverage with no RNG
|
|
170
|
+
starts = np.linspace(0, max_start, num_samples, dtype=np.int64)
|
|
171
|
+
elif strategy == "strided":
|
|
172
|
+
# stride chosen to cover the space roughly uniformly
|
|
173
|
+
if stride is None:
|
|
174
|
+
stride = max(1, (max_start + 1) // num_samples)
|
|
175
|
+
# offset derived deterministically from seed
|
|
176
|
+
offset = (
|
|
177
|
+
(abs(hash(("gptq-calib", seed))) % (max_start + 1))
|
|
178
|
+
if max_start > 0
|
|
179
|
+
else 0
|
|
180
|
+
)
|
|
181
|
+
starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % (
|
|
182
|
+
max_start + 1
|
|
183
|
+
)
|
|
184
|
+
else:
|
|
185
|
+
raise ValueError(f"Unknown strategy: {strategy}")
|
|
79
186
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
187
|
+
# Gather contiguous windows
|
|
188
|
+
# sliding_window_view avoids building a big index matrix
|
|
189
|
+
windows = np.lib.stride_tricks.sliding_window_view(
|
|
190
|
+
all_tokens, sequence_length
|
|
191
|
+
)
|
|
192
|
+
samples = windows[starts] # (num_samples, sequence_length)
|
|
193
|
+
return samples.astype(np.int32)[:, None, :]
|
|
83
194
|
|
|
84
195
|
|
|
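For intuition about the three sampling strategies, a self-contained sketch that mirrors the start-index selection and windowing logic shown above on a toy token stream; it does not call the Keras helper itself:

```python
# Illustrative only: the "random" / "linspace" / "strided" start selection
# described in the diff, applied to a stand-in token array.
import numpy as np

all_tokens = np.arange(1000, dtype=np.int32)  # stand-in for tokenized text
sequence_length, num_samples, seed = 64, 8, 42
max_start = all_tokens.size - sequence_length

starts_random = np.random.default_rng(seed).integers(
    0, max_start + 1, size=num_samples, dtype=np.int64
)
starts_linspace = np.linspace(0, max_start, num_samples, dtype=np.int64)

stride = max(1, (max_start + 1) // num_samples)
offset = abs(hash(("gptq-calib", seed))) % (max_start + 1)
starts_strided = (
    offset + np.arange(num_samples, dtype=np.int64) * stride
) % (max_start + 1)

# Every strategy ends with the same windowing step.
windows = np.lib.stride_tricks.sliding_window_view(all_tokens, sequence_length)
samples = windows[starts_strided].astype(np.int32)[:, None, :]
print(samples.shape)  # (8, 1, 64)
```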
85
196
|
def find_layers_in_block(block):
|
|
86
197
|
"""
|
|
87
|
-
|
|
88
|
-
|
|
198
|
+
Finds all Dense and EinsumDense layers in a transformer block.
|
|
199
|
+
|
|
200
|
+
Args:
|
|
201
|
+
block: A Keras layer representing a transformer block.
|
|
202
|
+
Returns:
|
|
203
|
+
A dict mapping layer paths to the corresponding Dense or EinsumDense
|
|
89
204
|
"""
|
|
90
205
|
found_layers = {}
|
|
91
|
-
|
|
92
|
-
|
|
206
|
+
for sub_layer in block._flatten_layers():
|
|
207
|
+
if len(list(sub_layer._flatten_layers())) == 1:
|
|
208
|
+
if isinstance(sub_layer, (Dense, EinsumDense)):
|
|
209
|
+
found_layers[sub_layer.path] = sub_layer
|
|
93
210
|
return found_layers
|
|
94
211
|
|
|
95
212
|
|
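A minimal, hypothetical analogue of this selection using only public Keras APIs (the real helper walks sub-layers via the internal `_flatten_layers` traversal and keys the result by `layer.path`):

```python
# Illustrative only: collecting quantizable sub-layers by type, analogous to
# what find_layers_in_block does for a transformer block.
import keras
from keras import layers

block = keras.Sequential(
    [
        layers.Dense(64, name="ffn_up"),
        layers.Dropout(0.1, name="drop"),
        layers.Dense(32, name="ffn_down"),
    ]
)
block.build((None, 16))

found = {
    sub.name: sub
    for sub in block.layers
    if isinstance(sub, (layers.Dense, layers.EinsumDense))
}
print(sorted(found))  # ['ffn_down', 'ffn_up']
```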
-def apply_gptq_layerwise(
-    model,
-    dataloader,
-    num_samples,
-    hessian_damping,
-    group_size,
-    symmetric,
-    activation_order,
-    weight_bits,
-):
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function
-
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-
-
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
        block, it uses temporary hooks to capture the input activations of
       each target layer during a forward pass with the calibration data.
-
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.
 
     Args:
-
-
-
-
-
-            calibration.
-        hessian_damping: (float) The percentage of dampening to add to the
-            Hessian diagonal for stabilization during inverse calculation.
-            A value of 0.01 is common.
-        group_size: (int) The size of the groups to use for quantization. A
-            value of 128 means that 128 weights will share the same scaling
-            factor. Use -1 for per-channel quantization.
-        symmetric: (bool) If True, symmetric quantization is used. Otherwise,
-            asymmetric quantization is used.
-        activation_order: (bool) If True, reorders the weight columns based on
-            activation magnitude, which can improve quantization accuracy.
-        weight_bits: (int) The number of bits to use for the quantized weights,
-            e.g., 4 for 4-bit quantization.
+        dataloader: An iterable providing calibration data.
+        config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
             layer or any transformer-like blocks to quantize within the model.
     """
+
+    num_samples = config.num_samples
+
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        backbone = model.backbone
-
-        # Add the check for the 'transformer_layers' attribute.
-        if hasattr(backbone, "transformer_layers"):
-            transformer_blocks = backbone.transformer_layers
-        else:
-            # Raise a specific error if the attribute is missing.
-            raise ValueError(
-                "The model's backbone does not have a 'transformer_layers' "
-                "attribute. Please ensure you are using a standard KerasHub "
-                "transformer model."
-            )
-        # Find the embedding layer by checking for common names or by type.
-        if hasattr(backbone, "token_embedding"):
-            embedding_layer = backbone.token_embedding
-        elif hasattr(backbone, "embedding"):
-            embedding_layer = backbone.embedding
-        else:
-            raise ValueError(
-                "Could not automatically find an embedding layer in the model."
-            )
 
-
-
-
-            # The first Embedding layer found is assumed to be the main one.
-            if isinstance(layer, Embedding) and embedding_layer is None:
-                embedding_layer = layer
-            # A "block" is a container-like layer with its own sub-layers
-            # that we can quantize. This is a heuristic that works for the
-            # test.
-            elif hasattr(layer, "_layers") and layer._layers:
-                transformer_blocks.append(layer)
-
-    if embedding_layer is None:
-        raise ValueError(
-            "Could not automatically find an embedding layer in the model."
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
        )
 
-    # Initial inputs are the outputs of the
-    inputs = [
-
-
-
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
+    num_samples = min(num_samples, len(inputs))
+
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
 
     for block_idx, block in enumerate(transformer_blocks):
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
             gptq_objects = {
-                name: GPTQ(layer
+                name: GPTQ(layer, config)
+                for name, layer in sub_layers_map.items()
             }
 
-
-            original_calls = {}
-
-            def create_hook(name, original_call_func):
-                """A factory for creating a hook to capture layer inputs."""
-
-                def hook(*args, **kwargs):
-                    if args:
-                        inp = args[0]
-                    else:
-                        inp = kwargs["inputs"]
-                    captured_inputs[name].append(inp)
-                    return original_call_func(*args, **kwargs)
-
-                return hook
-
-            try:
-                for name, layer in sub_layers_map.items():
-                    original_call = layer.call
-                    original_calls[name] = original_call
-                    layer.call = create_hook(name, original_call)
-
-                logging.info(f"Capturing activations for block {block_idx}...")
+            with stream_hessians(sub_layers_map, gptq_objects):
                 for sample_idx in range(num_samples):
                     current_input = inputs[sample_idx]
                     if len(current_input.shape) == 2:
                         current_input = ops.expand_dims(current_input, axis=0)
                     _ = block(current_input)
 
-            finally:
-                for name, layer in sub_layers_map.items():
-                    if name in original_calls:
-                        layer.call = original_calls[name]
-
-            logging.info(f"Building Hessians for block {block_idx}...")
-            for name, gptq_object in gptq_objects.items():
-                layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
-
-                # Explicitly reshape the input tensor to be 2D, with the second
-                # dimension matching the number of input features expected by
-                # the layer's kernel.
-                # This correctly handles inputs of any dimensionality
-                # (e.g., 3D or 4D).
-                num_features = gptq_object.rows
-                input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
-                gptq_object.update_hessian_with_batch(input_reshaped)
-
-            quantizer = GPTQQuantization(
-                weight_bits,
-                per_channel=True,
-                symmetric=symmetric,
-                group_size=group_size,
-            )
             for name, gptq_object in gptq_objects.items():
                 logging.info(f"Quantizing {name}...")
-                gptq_object.
-                gptq_object.quantize_and_correct_block(
-                    hessian_damping=hessian_damping,
-                    group_size=group_size,
-                    activation_order=activation_order,
-                )
+                gptq_object.quantize_and_correct_layer()
                 gptq_object.free()
 
-            del gptq_objects
+            del gptq_objects
 
         if block_idx < len(transformer_blocks) - 1:
             logging.info(f"Generating inputs for block {block_idx + 1}...")
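The rewritten `apply_gptq_layerwise` no longer sniffs `model.backbone`; the caller supplies the layer structure. A sketch of what that `structure` argument could look like, with toy stand-in layers (the key names come from the diff above; everything else is illustrative):

```python
# Illustrative only: the shape of the `structure` argument consumed by the new
# apply_gptq_layerwise. The toy layers are hypothetical stand-ins for a real
# model's token embedding and transformer blocks.
import keras
from keras import layers

embedding = layers.Embedding(input_dim=1000, output_dim=32)
block_0 = keras.Sequential([layers.Dense(64, activation="relu"), layers.Dense(32)])
block_1 = keras.Sequential([layers.Dense(64, activation="relu"), layers.Dense(32)])

structure = {
    # Layers run once over each calibration batch to produce the block inputs.
    "pre_block_layers": [embedding],
    # Blocks quantized one after another, each fed the previous block's outputs.
    "sequential_blocks": [block_0, block_1],
}
print([len(structure[k]) for k in ("pre_block_layers", "sequential_blocks")])  # [1, 2]
```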
@@ -304,32 +319,130 @@ def apply_gptq_layerwise(
     logging.info("Quantization process complete.")
 
 
-def
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
 
-
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
+
+    # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
-
+    dataloader = get_dataloader(
         config.tokenizer,
         config.sequence_length,
         config.dataset,
         num_samples=total_samples_to_request,
     )
 
-    # Split the materialized data. This works because
+    # Split the materialized data. This works because dataloader
     # is now a NumPy array, which can be sliced and reused.
-    calibration_dataloader =
+    calibration_dataloader = dataloader[: config.num_samples]
 
     apply_gptq_layerwise(
-
-
-
-
-        config.group_size,
-        config.symmetric,
-        config.activation_order,
-        config.weight_bits,
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
     )
+
+
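A small sketch of the preconditions `gptq_quantize` enforces, exercised with a stand-in config object rather than the real `GPTQConfig` (whose definition lives in `keras/src/quantizers/gptq_config.py` and is not shown in this hunk):

```python
# Illustrative only: the attributes gptq_quantize reads from its config and the
# two validation branches above, using a SimpleNamespace stand-in.
from types import SimpleNamespace

config = SimpleNamespace(
    dataset=["calibration text one", "calibration text two"],
    tokenizer=None,  # deliberately missing, to trip the first check
    sequence_length=128,
    num_samples=32,
)
structure = None  # no layer structure supplied either

if config.dataset is None or config.tokenizer is None:
    print("rejected: GPTQ needs both a dataset and a tokenizer in the config")
elif structure is None:
    print("rejected: a dict with 'pre_block_layers' and 'sequential_blocks' is required")
```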
+def get_group_size_for_layer(layer, config):
+    """Determine the group size for GPTQ quantization.
+
+    The group size can be specified either through the `config` argument
+    or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the group size should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `group_size` attribute.
+    Returns:
+        int. The determined group size for GPTQ quantization.
+    Raises:
+        ValueError: If the group size is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.group_size
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the group_size must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )
+
+
+def get_weight_bits_for_layer(layer, config):
+    """Determine the number of weight bits for GPTQ quantization.
+
+    The number of weight bits can be specified either through the `config`
+    argument or through the `dtype_policy` if it is of type
+    `GPTQDTypePolicy`.
+
+    The config argument is usually available when quantizing the layer
+    via the `quantize` method. If the layer was deserialized from a
+    saved model, the weight bits should be specified in the `dtype_policy`.
+
+    Args:
+        config: An optional configuration object that may contain the
+            `weight_bits` attribute.
+    Returns:
+        int. The determined number of weight bits for GPTQ quantization.
+    Raises:
+        ValueError: If the weight bits is not specified in either the
+            `config` or the `dtype_policy`.
+    """
+    if config and isinstance(config, GPTQConfig):
+        return config.weight_bits
+    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):
+        return layer.dtype_policy.weight_bits
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if not isinstance(policy, GPTQDTypePolicy):
+            # This should never happen based on how we set the
+            # quantization mode, but we check just in case.
+            raise ValueError(
+                "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
+                f"Got: {type(policy)}"
+            )
+        return policy.weight_bits
+    else:
+        raise ValueError(
+            "For GPTQ quantization, the weight_bits must be specified"
+            "either through a `dtype_policy` of type "
+            "`GPTQDTypePolicy` or the `config` argument."
+        )