keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/ops/__init__.py +1 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/ops/__init__.py +1 -0
- keras/ops/numpy/__init__.py +1 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/applications/convnext.py +20 -20
- keras/src/applications/densenet.py +21 -21
- keras/src/applications/efficientnet.py +16 -16
- keras/src/applications/efficientnet_v2.py +28 -28
- keras/src/applications/inception_resnet_v2.py +7 -7
- keras/src/applications/inception_v3.py +5 -5
- keras/src/applications/mobilenet_v2.py +13 -20
- keras/src/applications/mobilenet_v3.py +15 -15
- keras/src/applications/nasnet.py +7 -8
- keras/src/applications/resnet.py +32 -32
- keras/src/applications/xception.py +10 -10
- keras/src/backend/common/dtypes.py +8 -3
- keras/src/backend/common/variables.py +3 -1
- keras/src/backend/jax/export.py +1 -1
- keras/src/backend/jax/numpy.py +6 -0
- keras/src/backend/jax/trainer.py +1 -1
- keras/src/backend/numpy/numpy.py +28 -0
- keras/src/backend/openvino/numpy.py +5 -1
- keras/src/backend/tensorflow/numpy.py +22 -0
- keras/src/backend/tensorflow/trainer.py +19 -1
- keras/src/backend/torch/core.py +6 -9
- keras/src/backend/torch/nn.py +1 -2
- keras/src/backend/torch/numpy.py +16 -0
- keras/src/backend/torch/trainer.py +1 -1
- keras/src/callbacks/backup_and_restore.py +2 -2
- keras/src/callbacks/csv_logger.py +1 -1
- keras/src/callbacks/model_checkpoint.py +1 -1
- keras/src/callbacks/tensorboard.py +6 -6
- keras/src/constraints/constraints.py +9 -7
- keras/src/datasets/boston_housing.py +1 -1
- keras/src/datasets/california_housing.py +1 -1
- keras/src/datasets/cifar10.py +1 -1
- keras/src/datasets/cifar100.py +2 -2
- keras/src/datasets/imdb.py +2 -2
- keras/src/datasets/mnist.py +1 -1
- keras/src/datasets/reuters.py +2 -2
- keras/src/dtype_policies/dtype_policy.py +1 -1
- keras/src/dtype_policies/dtype_policy_map.py +1 -1
- keras/src/export/tf2onnx_lib.py +1 -3
- keras/src/initializers/constant_initializers.py +9 -5
- keras/src/layers/input_spec.py +6 -6
- keras/src/layers/layer.py +1 -1
- keras/src/layers/preprocessing/category_encoding.py +3 -3
- keras/src/layers/preprocessing/data_layer.py +159 -0
- keras/src/layers/preprocessing/discretization.py +3 -3
- keras/src/layers/preprocessing/feature_space.py +4 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
- keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
- keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
- keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
- keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
- keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
- keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
- keras/src/layers/preprocessing/normalization.py +5 -2
- keras/src/layers/preprocessing/rescaling.py +3 -3
- keras/src/layers/rnn/bidirectional.py +4 -4
- keras/src/legacy/backend.py +9 -23
- keras/src/legacy/preprocessing/image.py +11 -22
- keras/src/legacy/preprocessing/text.py +1 -1
- keras/src/models/functional.py +2 -2
- keras/src/models/model.py +21 -3
- keras/src/ops/function.py +1 -1
- keras/src/ops/numpy.py +49 -5
- keras/src/ops/operation.py +3 -2
- keras/src/optimizers/base_optimizer.py +3 -4
- keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
- keras/src/quantizers/gptq.py +350 -0
- keras/src/quantizers/gptq_config.py +169 -0
- keras/src/quantizers/gptq_core.py +335 -0
- keras/src/quantizers/gptq_quant.py +133 -0
- keras/src/saving/file_editor.py +22 -20
- keras/src/saving/object_registration.py +1 -1
- keras/src/saving/saving_lib.py +4 -4
- keras/src/saving/serialization_lib.py +3 -5
- keras/src/trainers/compile_utils.py +1 -1
- keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
- keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
- keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
- keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
- keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
- keras/src/tree/dmtree_impl.py +19 -3
- keras/src/tree/optree_impl.py +3 -3
- keras/src/tree/tree_api.py +5 -2
- keras/src/utils/file_utils.py +13 -5
- keras/src/utils/io_utils.py +1 -1
- keras/src/utils/model_visualization.py +1 -1
- keras/src/utils/progbar.py +5 -5
- keras/src/utils/summary_utils.py +4 -4
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
- keras/src/layers/preprocessing/tf_data_layer.py +0 -78
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,335 @@
+import random
+
+import numpy as np
+from absl import logging
+
+from keras.src import ops
+from keras.src import utils as keras_utils
+from keras.src.layers import Dense
+from keras.src.layers import EinsumDense
+from keras.src.layers import Embedding
+from keras.src.quantizers.gptq import GPTQ
+from keras.src.quantizers.gptq_quant import GPTQQuantization
+
+
+def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
+    """
+    Prepares and chunks the calibration dataloader, repeating short datasets.
+    """
+    all_tokens = []
+
+    if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
+        raise TypeError(
+            "The `dataset` argument must be an iterable (e.g., a list of "
+            "strings, a generator, or a NumPy array). Got type: "
+            f"{type(dataset).__name__}. Please pass the loaded dataset "
+            "directly."
+        )
+
+    dataset_list = list(dataset)
+
+    if not dataset_list:
+        raise ValueError("Provided dataset is empty.")
+
+    if isinstance(dataset_list[0], str):
+        logging.info("(Dataset contains strings, tokenizing now...)")
+        full_text = "\n\n".join(dataset_list)
+        all_tokens = tokenizer.tokenize(full_text)
+    else:
+        logging.info("(Dataset is pre-tokenized, concatenating...)")
+        all_tokens = np.concatenate(
+            [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
+        )
+
+    all_tokens = np.array(all_tokens, dtype=np.int32)
+
+    # Repeat data if it's too short
+    required_tokens = num_samples * sequence_length
+    if len(all_tokens) < required_tokens:
+        logging.info(
+            f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
+            " Repeating data to generate {num_samples} samples."
+        )
+        repeats = -(-required_tokens // len(all_tokens)) # Ceiling division
+        all_tokens = np.tile(all_tokens, repeats)
+
+    # Chunk the token list into samples
+
+    calibration_samples = []
+    for _ in range(num_samples):
+        # Generate a random starting index
+        start_index = random.randint(0, len(all_tokens) - sequence_length - 1)
+        end_index = start_index + sequence_length
+        sample = all_tokens[start_index:end_index]
+        calibration_samples.append(np.reshape(sample, (1, sequence_length)))
+
+    final_array = np.stack(calibration_samples, axis=0)
+    return final_array
+
+
+def _find_layers_recursive(layer, prefix, found_layers):
+    """
+    Recursively search for Dense and EinsumDense layers and record them.
+    """
+    for sub_layer in layer._layers:
+        # Construct a unique name for the layer based on its hierarchy
+        layer_name = f"{prefix}.{sub_layer.name}"
+        if isinstance(sub_layer, (Dense, EinsumDense)):
+            found_layers[layer_name] = sub_layer
+
+        # Recurse into nested layers that are not the target types
+        elif hasattr(sub_layer, "_layers") and sub_layer._layers:
+            _find_layers_recursive(sub_layer, layer_name, found_layers)
+
+
+def find_layers_in_block(block):
+    """
+    A pluggable, generic function to find all Dense and EinsumDense layers
+    within any transformer block by using a recursive search.
+    """
+    found_layers = {}
+    # Start the recursive search from the block itself
+    _find_layers_recursive(block, "block", found_layers)
+    return found_layers
+
+
+def apply_gptq_layerwise(
+    model,
+    dataloader,
+    num_samples,
+    hessian_damping,
+    group_size,
+    symmetric,
+    activation_order,
+    weight_bits,
+):
+    """Applies GPTQ quantization layer-by-layer to a Keras model.
+
+    This function is designed to work with common transformer architectures,
+    like those provided by KerasHub. It automatically discovers the model's
+    structure by first looking for the standard format: a `model.backbone`
+    attribute that contains a `transformer_layers` list.
+
+    If a standard backbone is not found, it falls back to a heuristic for
+    custom models, where it assumes the first `keras.layers.Embedding` layer
+    is the input embedding and any subsequent container layers are the
+    transformer blocks to be quantized.
+
+    The core logic operates as follows:
+    1. It automatically detects the model's structure, identifying the main
+       embedding layer and a sequence of transformer blocks.
+    2. It processes the model sequentially, one block at a time. For each
+       block, it uses temporary hooks to capture the input activations of
+       each target layer during a forward pass with the calibration data.
+    3. These captured activations are used to compute the Hessian matrix for
+       each layer's weights.
+    4. The GPTQ algorithm is then applied to each layer to find the optimal
+       quantized weights that minimize the error introduced.
+    5. The output activations from the current block are then used as the
+       input for the next block, ensuring that quantization errors are
+       accounted for throughout the model.
+
+    Args:
+        model: The Keras model instance to be quantized. The function will
+            attempt to automatically discover its structure.
+        dataloader: An iterable providing calibration data. Each item should
+            be a batch of token IDs suitable for the model's embedding layer.
+        num_samples: (int) The number of samples from the dataloader to use for
+            calibration.
+        hessian_damping: (float) The percentage of dampening to add to the
+            Hessian diagonal for stabilization during inverse calculation.
+            A value of 0.01 is common.
+        group_size: (int) The size of the groups to use for quantization. A
+            value of 128 means that 128 weights will share the same scaling
+            factor. Use -1 for per-channel quantization.
+        symmetric: (bool) If True, symmetric quantization is used. Otherwise,
+            asymmetric quantization is used.
+        activation_order: (bool) If True, reorders the weight columns based on
+            activation magnitude, which can improve quantization accuracy.
+        weight_bits: (int) The number of bits to use for the quantized weights,
+            e.g., 4 for 4-bit quantization.
+
+    Raises:
+        ValueError: If the function cannot automatically find an embedding
+            layer or any transformer-like blocks to quantize within the model.
+    """
+    logging.info("Starting model quantization...")
+    embedding_layer = None
+    transformer_blocks = []
+    if hasattr(model, "backbone"):
+        logging.info("Detected KerasHub model structure.")
+        backbone = model.backbone
+
+        # Add the check for the 'transformer_layers' attribute.
+        if hasattr(backbone, "transformer_layers"):
+            transformer_blocks = backbone.transformer_layers
+        else:
+            # Raise a specific error if the attribute is missing.
+            raise ValueError(
+                "The model's backbone does not have a 'transformer_layers' "
+                "attribute. Please ensure you are using a standard KerasHub "
+                "transformer model."
+            )
+        # Find the embedding layer by checking for common names or by type.
+        if hasattr(backbone, "token_embedding"):
+            embedding_layer = backbone.token_embedding
+        elif hasattr(backbone, "embedding"):
+            embedding_layer = backbone.embedding
+        else:
+            raise ValueError(
+                "Could not automatically find an embedding layer in the model."
+            )
+
+    else:
+        logging.info("Detected custom model structure.")
+        for layer in model.layers:
+            # The first Embedding layer found is assumed to be the main one.
+            if isinstance(layer, Embedding) and embedding_layer is None:
+                embedding_layer = layer
+            # A "block" is a container-like layer with its own sub-layers
+            # that we can quantize. This is a heuristic that works for the
+            # test.
+            elif hasattr(layer, "_layers") and layer._layers:
+                transformer_blocks.append(layer)
+
+    if embedding_layer is None:
+        raise ValueError(
+            "Could not automatically find an embedding layer in the model."
+        )
+    if not transformer_blocks:
+        raise ValueError(
+            "Could not automatically find any transformer-like blocks to "
+            "quantize."
+        )
+
+    # Initial inputs are the outputs of the token embedding layer
+    inputs = [
+        embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
+        for batch in dataloader
+    ]
+    progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+    for block_idx, block in enumerate(transformer_blocks):
+        logging.info(f"Quantizing Block {block_idx}")
+        sub_layers_map = find_layers_in_block(block)
+
+        if not sub_layers_map:
+            logging.info(
+                f" No Dense or EinsumDense layers found in block {block_idx}. "
+                "Skipping."
+            )
+        else:
+            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+            gptq_objects = {
+                name: GPTQ(layer) for name, layer in sub_layers_map.items()
+            }
+
+            captured_inputs = {name: [] for name in sub_layers_map.keys()}
+            original_calls = {}
+
+            def create_hook(name, original_call_func):
+                """A factory for creating a hook to capture layer inputs."""
+
+                def hook(*args, **kwargs):
+                    if args:
+                        inp = args[0]
+                    else:
+                        inp = kwargs["inputs"]
+                    captured_inputs[name].append(inp)
+                    return original_call_func(*args, **kwargs)
+
+                return hook
+
+            try:
+                for name, layer in sub_layers_map.items():
+                    original_call = layer.call
+                    original_calls[name] = original_call
+                    layer.call = create_hook(name, original_call)
+
+                logging.info(f"Capturing activations for block {block_idx}...")
+                for sample_idx in range(num_samples):
+                    current_input = inputs[sample_idx]
+                    if len(current_input.shape) == 2:
+                        current_input = ops.expand_dims(current_input, axis=0)
+                    _ = block(current_input)
+
+            finally:
+                for name, layer in sub_layers_map.items():
+                    if name in original_calls:
+                        layer.call = original_calls[name]
+
+            logging.info(f"Building Hessians for block {block_idx}...")
+            for name, gptq_object in gptq_objects.items():
+                layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
+
+                # Explicitly reshape the input tensor to be 2D, with the second
+                # dimension matching the number of input features expected by
+                # the layer's kernel.
+                # This correctly handles inputs of any dimensionality
+                # (e.g., 3D or 4D).
+                num_features = gptq_object.rows
+                input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
+                gptq_object.update_hessian_with_batch(input_reshaped)
+
+            quantizer = GPTQQuantization(
+                weight_bits,
+                per_channel=True,
+                symmetric=symmetric,
+                group_size=group_size,
+            )
+            for name, gptq_object in gptq_objects.items():
+                logging.info(f"Quantizing {name}...")
+                gptq_object.quantizer = quantizer
+                gptq_object.quantize_and_correct_block(
+                    hessian_damping=hessian_damping,
+                    group_size=group_size,
+                    activation_order=activation_order,
+                )
+                gptq_object.free()
+
+            del gptq_objects, captured_inputs, original_calls
+
+        if block_idx < len(transformer_blocks) - 1:
+            logging.info(f"Generating inputs for block {block_idx + 1}...")
+            next_block_inputs = []
+            for sample_idx in range(num_samples):
+                current_input = inputs[sample_idx]
+                if len(current_input.shape) == 2:
+                    current_input = ops.expand_dims(current_input, axis=0)
+                output = block(current_input)[0]
+                next_block_inputs.append(output)
+            inputs = next_block_inputs
+        progbar.update(current=block_idx + 1)
+
+    logging.info("Quantization process complete.")
+
+
+def quantize_model(model, config):
+    """
+    Top-level function to quantize a Keras model using GPTQ.
+    """
+    logging.info("Starting GPTQ quantization process...")
+
+    # Load ALL data needed from the generator/source in a single call.
+    total_samples_to_request = config.num_samples
+    full_dataloader = get_dataloader(
+        config.tokenizer,
+        config.sequence_length,
+        config.dataset,
+        num_samples=total_samples_to_request,
+    )
+
+    # Split the materialized data. This works because full_dataloader
+    # is now a NumPy array, which can be sliced and reused.
+    calibration_dataloader = full_dataloader[: config.num_samples]
+
+    apply_gptq_layerwise(
+        model,
+        calibration_dataloader, # Use the calibration slice
+        config.num_samples, # Use the configured number of samples
+        config.hessian_damping,
+        config.group_size,
+        config.symmetric,
+        config.activation_order,
+        config.weight_bits,
+    )
keras/src/quantizers/gptq_quant.py
ADDED
@@ -0,0 +1,133 @@
+from keras.src import ops
+
+
+def dequantize(input_tensor, scale, zero, maxq):
+    """The core quantization function."""
+    epsilon = ops.cast(1e-8, dtype=scale.dtype)
+    scale = ops.where(ops.equal(scale, 0), epsilon, scale)
+
+    quantized_tensor = ops.divide(input_tensor, scale)
+    quantized_tensor = ops.round(quantized_tensor)
+    q = ops.add(quantized_tensor, zero)
+    q = ops.clip(q, 0, maxq)
+
+    dequantized_tensor = ops.subtract(q, zero)
+    return ops.multiply(scale, dequantized_tensor)
+
+
+class GPTQQuantization:
+    """A class that handles the quantization of weights using GPTQ method.
+
+    This class provides methods to find quantization parameters (scale and zero)
+    for a given tensor and can be used to quantize weights in a GPTQ context.
+
+    Args:
+        weight_bits: (int) The number of bits to quantize to (e.g., 4).
+        per_channel: (bool) A flag indicating whether quantization is
+            applied per-channel (`True`) or per-tensor (`False`).
+            Defaults to `False`.
+        symmetric: (bool) A flag indicating whether symmetric (`True`) or
+            asymmetric (`False`) quantization is used. Defaults to `False`.
+        group_size: (int) The size of weight groups for quantization. A
+            value of -1 indicates that grouping is not used.
+            Defaults to -1.
+    """
+
+    def __init__(
+        self, weight_bits, per_channel=True, symmetric=False, group_size=-1
+    ):
+        self.weight_bits = weight_bits
+        self.maxq = ops.cast(
+            ops.subtract(ops.power(2, weight_bits), 1), "float32"
+        )
+        self.per_channel = per_channel
+        self.symmetric = symmetric
+        self.group_size = group_size
+
+        # These are now determined later by `find_params`
+        self.scale = None
+        self.zero = None
+
+    def find_params(self, input_tensor, weight=False):
+        """Finds quantization parameters (scale and zero) for a given tensor."""
+
+        if input_tensor is None:
+            raise ValueError("Input tensor 'input_tensor' cannot be None.")
+
+        # For weights, we typically expect at least a 2D tensor.
+        if weight and len(input_tensor.shape) < 2:
+            raise ValueError(
+                f"Input weight tensor 'input_tensor' must have a rank of at "
+                f"least 2, but got rank {len(input_tensor.shape)}."
+            )
+
+        if ops.size(input_tensor) == 0:
+            raise ValueError("Input tensor 'input_tensor' cannot be empty.")
+
+        original_shape = input_tensor.shape
+
+        if self.per_channel:
+            if weight:
+                if self.group_size != -1:
+                    input_reshaped = ops.reshape(
+                        input_tensor, [-1, self.group_size]
+                    )
+                else:
+                    input_reshaped = ops.reshape(
+                        input_tensor, [original_shape[0], -1]
+                    )
+        else: # per-tensor
+            input_reshaped = ops.reshape(input_tensor, [1, -1])
+
+        # Find min/max values
+        min_values = ops.min(input_reshaped, axis=1)
+        max_values = ops.max(input_reshaped, axis=1)
+
+        # Apply symmetric quantization logic if enabled
+        if self.symmetric:
+            max_values = ops.maximum(ops.abs(min_values), max_values)
+            min_values = ops.where(
+                ops.less(min_values, 0), ops.negative(max_values), min_values
+            )
+
+        # Ensure range is not zero to avoid division errors
+        zero_range = ops.equal(min_values, max_values)
+        min_values = ops.where(
+            zero_range, ops.subtract(min_values, 1), min_values
+        )
+        max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
+
+        # Calculate scale and zero-point
+        self.scale = ops.divide(ops.subtract(max_values, min_values), self.maxq)
+        if self.symmetric:
+            self.zero = ops.full_like(
+                self.scale, ops.divide(ops.add(self.maxq, 1), 2)
+            )
+        else:
+            self.zero = ops.round(
+                ops.divide(ops.negative(min_values), self.scale)
+            )
+
+        # Ensure scale is non-zero
+        self.scale = ops.where(ops.less_equal(self.scale, 0), 1e-8, self.scale)
+
+        if weight:
+            # Per-channel, non-grouped case: simple reshape is correct.
+            if self.per_channel and self.group_size == -1:
+                self.scale = ops.reshape(self.scale, [-1, 1])
+                self.zero = ops.reshape(self.zero, [-1, 1])
+            elif not self.per_channel:
+                num_rows = original_shape[0]
+                self.scale = ops.tile(
+                    ops.reshape(self.scale, (1, 1)), (num_rows, 1)
+                )
+                self.zero = ops.tile(
+                    ops.reshape(self.zero, (1, 1)), (num_rows, 1)
+                )
+        if self.per_channel:
+            self.scale = ops.reshape(self.scale, [-1, 1])
+            self.zero = ops.reshape(self.zero, [-1, 1])
+
+    def ready(self):
+        """Checks if the quantization parameters have been computed."""
+        return self.scale is not None and self.zero is not None
keras/src/saving/file_editor.py
CHANGED
@@ -1,5 +1,6 @@
 import collections
 import json
+import os.path
 import pprint
 import zipfile

@@ -76,7 +77,7 @@ class KerasFileEditor:
         if filepath.endswith(".keras"):
             zf = zipfile.ZipFile(filepath, "r")
             weights_store = H5IOStore(
-                saving_lib._VARS_FNAME
+                f"{saving_lib._VARS_FNAME}.h5",
                 archive=zf,
                 mode="r",
             )
@@ -143,7 +144,7 @@ class KerasFileEditor:
         ):
             base_inner_path = inner_path
             for ref_key, ref_val in ref_spec.items():
-                inner_path = base_inner_path
+                inner_path = f"{base_inner_path}/{ref_key}"
                 if inner_path in checked_paths:
                     continue

@@ -435,7 +436,7 @@ class KerasFileEditor:
                 _save(
                     weights_dict[name],
                     weights_store,
-                    inner_path=inner_path
+                    inner_path=os.path.join(inner_path, name),
                 )
             else:
                 # e.g. name="0", value=HDF5Dataset
@@ -462,7 +463,7 @@ class KerasFileEditor:

         result = collections.OrderedDict()
         for key in data.keys():
-            inner_path = inner_path
+            inner_path = f"{inner_path}/{key}"
             value = data[key]
             if isinstance(value, h5py.Group):
                 if len(value) == 0:
@@ -506,7 +507,7 @@ class KerasFileEditor:
         self, weights_dict, indent=0, is_first=True, prefix="", inner_path=""
     ):
         for idx, (key, value) in enumerate(weights_dict.items()):
-            inner_path = inner_path
+            inner_path = os.path.join(inner_path, key)
            is_last = idx == len(weights_dict) - 1
            if is_first:
                is_first = False
@@ -556,29 +557,30 @@ class KerasFileEditor:
         html = ""
         for key, value in dictionary.items():
             if isinstance(value, dict) and value:
+                weights_html = _generate_html_weights(
+                    value, margin_left + 20, font_size - 1
+                )
                 html += (
                     f'<details style="margin-left: {margin_left}px;">'
-
-
-
-
-
-
-                    )
-                    + "</details>"
+                    '<summary style="'
+                    f"font-size: {font_size}em; "
+                    "font-weight: bold;"
+                    f'">{key}</summary>'
+                    f"{weights_html}"
+                    "</details>"
                 )
             else:
                 html += (
                     f'<details style="margin-left: {margin_left}px;">'
-
-
-
-
+                    f'<summary style="font-size: {font_size}em;">'
+                    f"{key} : shape={value.shape}"
+                    f", dtype={value.dtype}</summary>"
+                    f"<div style="
                     f'"margin-left: {margin_left}px;'
                     f'"margin-top: {margin_left}px;">'
-
-
-
+                    f"{display_weight(value)}"
+                    "</div>"
+                    "</details>"
                 )
         return html

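The path fixes above affect how `KerasFileEditor` assembles the inner paths of nested weight groups inside a `.keras` archive. A minimal sketch, assuming the public `keras.saving.KerasFileEditor` API and its `summary()` method:

    import keras

    # Build and save a tiny model so there is an archive to inspect.
    model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
    model.save("tiny_model.keras")

    # Open the saved archive and print its nested weight structure; the per-entry
    # inner paths are the ones the fixes above now join correctly.
    editor = keras.saving.KerasFileEditor("tiny_model.keras")
    editor.summary()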
keras/src/saving/object_registration.py
CHANGED
@@ -140,7 +140,7 @@ def register_keras_serializable(package="Custom", name=None):
     def decorator(arg):
         """Registers a class with the Keras serialization framework."""
         class_name = name if name is not None else arg.__name__
-        registered_name = package
+        registered_name = f"{package}>{class_name}"

         if inspect.isclass(arg) and not hasattr(arg, "get_config"):
             raise ValueError(
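The restored `f"{package}>{class_name}"` key above is the name under which the serialization registry stores custom objects. For reference, standard usage of this API (not part of the diff):

    import keras

    @keras.saving.register_keras_serializable(package="MyPackage")
    class Scaler(keras.layers.Layer):
        def __init__(self, factor=2.0, **kwargs):
            super().__init__(**kwargs)
            self.factor = factor

        def call(self, inputs):
            return inputs * self.factor

        def get_config(self):
            return {**super().get_config(), "factor": self.factor}

    # The decorator registers the class under "MyPackage>Scaler".
    print(keras.saving.get_registered_name(Scaler))                          # MyPackage>Scaler
    print(keras.saving.get_registered_object("MyPackage>Scaler") is Scaler)  # True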
keras/src/saving/saving_lib.py
CHANGED
@@ -46,8 +46,8 @@ except ImportError:
 _CONFIG_FILENAME = "config.json"
 _METADATA_FILENAME = "metadata.json"
 _VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5"
-_VARS_FNAME_H5 = _VARS_FNAME
-_VARS_FNAME_NPZ = _VARS_FNAME
+_VARS_FNAME_H5 = f"{_VARS_FNAME}.h5"
+_VARS_FNAME_NPZ = f"{_VARS_FNAME}.npz"
 _ASSETS_DIRNAME = "assets"
 _MEMORY_UPPER_BOUND = 0.5 # 50%

@@ -664,7 +664,7 @@ def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
 def _name_key(name):
     """Make sure that private attributes are visited last."""
     if name.startswith("_"):
-        return "~"
+        return f"~{name}"
     return name


@@ -1288,7 +1288,7 @@ class ShardedH5IOStore(H5IOStore):
         # If not found, check shard map and switch files.
         weight_map = self.sharding_config["weight_map"]
         filenames = weight_map.get(parsed_path) or weight_map.get(
-            "/
+            f"/{parsed_path}/vars"
         )
         if filenames is not None:
             if not isinstance(filenames, list):
keras/src/saving/serialization_lib.py
CHANGED
@@ -778,7 +778,7 @@ def _retrieve_class_or_fn(
     # module name might not match the package structure
     # (e.g. experimental symbols).
     if module == "keras" or module.startswith("keras."):
-        api_name = module
+        api_name = f"{module}.{name}"

         if api_name in LOADING_APIS:
             raise ValueError(
@@ -796,9 +796,7 @@ def _retrieve_class_or_fn(
     # the corresponding function from the identifying string.
     if obj_type == "function" and module == "builtins":
         for mod in BUILTIN_MODULES:
-            obj = api_export.get_symbol_from_name(
-                "keras." + mod + "." + name
-            )
+            obj = api_export.get_symbol_from_name(f"keras.{mod}.{name}")
             if obj is not None:
                 return obj

@@ -807,7 +805,7 @@ def _retrieve_class_or_fn(
     # i.e. "name" instead of "package>name". This allows recent versions
     # of Keras to reload models saved with 3.6 and lower.
     if ">" not in name:
-        separated_name = ">"
+        separated_name = f">{name}"
         for custom_name, custom_object in custom_objects.items():
             if custom_name.endswith(separated_name):
                 return custom_object
keras/src/trainers/compile_utils.py
CHANGED
@@ -659,7 +659,7 @@ class CompileLoss(losses_module.Loss):
         # Add `Mean` metric to the tracker for each loss.
         if len(self._flat_losses) > 1:
             for _loss in self._flat_losses:
-                name = _loss.name
+                name = f"{_loss.name}_loss"
                 self._tracker.add_to_store(
                     "metrics", metrics_module.Mean(name=name)
                 )
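With more than one loss, the `Mean` trackers created above now carry a `_loss` suffix, which is where the per-output loss keys in training logs come from. A small sketch (the exact keys depend on the output and loss names):

    import numpy as np
    import keras

    inputs = keras.Input(shape=(4,))
    out_a = keras.layers.Dense(1, name="out_a")(inputs)
    out_b = keras.layers.Dense(1, name="out_b")(inputs)
    model = keras.Model(inputs, {"out_a": out_a, "out_b": out_b})

    model.compile(optimizer="sgd", loss={"out_a": "mse", "out_b": "mae"})

    x = np.random.rand(16, 4).astype("float32")
    y = {
        "out_a": np.random.rand(16, 1).astype("float32"),
        "out_b": np.random.rand(16, 1).astype("float32"),
    }
    history = model.fit(x, y, epochs=1, verbose=0)

    # Expected to include per-loss entries such as "out_a_loss" and "out_b_loss"
    # alongside the total "loss".
    print(sorted(history.history.keys()))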
keras/src/trainers/data_adapters/array_data_adapter.py
CHANGED
@@ -76,7 +76,9 @@ class ArrayDataAdapter(DataAdapter):
         inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)

         data_adapter_utils.check_data_cardinality(inputs)
-        num_samples = set(
+        num_samples = set(
+            i.shape[0] for i in tree.flatten(inputs) if i is not None
+        ).pop()
         self._num_samples = num_samples
         self._inputs = inputs

@@ -269,7 +271,9 @@ class ArrayDataAdapter(DataAdapter):
             x = convert_to_tensor(x)
             return x

-        return tree.map_structure(
+        return tree.map_structure(
+            slice_and_convert, self.array, none_is_leaf=False
+        )

     def __len__(self):
         return len(self.array[0])
@@ -337,7 +341,9 @@ class ArrayDataAdapter(DataAdapter):
             slice_indices_and_convert_fn = functools.partial(
                 slice_and_convert_fn, indices=indices
             )
-            yield tree.map_structure(
+            yield tree.map_structure(
+                slice_indices_and_convert_fn, inputs, none_is_leaf=False
+            )

     @property
     def num_batches(self):