keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/awq_config.py
@@ -0,0 +1,140 @@
+from keras.src.api_export import keras_export
+from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+@keras_export("keras.quantizers.AWQConfig")
+class AWQConfig(QuantizationConfig):
+    """Configuration class for AWQ (Activation-aware Weight Quantization).
+
+    AWQ is a post-training quantization method that identifies and protects
+    salient weights based on activation magnitudes. It applies per-channel
+    scaling before quantization to minimize accuracy loss.
+
+    Methodology:
+    1. Collects activation statistics from calibration data
+    2. Identifies salient weight channels based on activation magnitudes
+    3. Searches for optimal per-channel scaling factors via grid search
+    4. Applies scaling before quantization to protect important weights
+
+    References:
+    - Original AWQ paper: "AWQ: Activation-aware Weight Quantization for
+      LLM Compression and Acceleration" (https://arxiv.org/abs/2306.00978)
+    - Reference implementation: https://github.com/mit-han-lab/llm-awq
+
+    Args:
+        dataset: The calibration dataset. It can be an iterable that yields
+            strings or pre-tokenized numerical tensors (e.g., a list of
+            strings, a generator, or a NumPy array). This data is used to
+            analyze activation patterns.
+        tokenizer: A tokenizer instance (or a similar callable) that is used
+            to process the `dataset`.
+        weight_bits: The number of bits for weight quantization. AWQ presently
+            only supports 4-bit quantization. Defaults to 4.
+        num_samples: The number of calibration data samples to use from the
+            dataset. Defaults to 128.
+        sequence_length: The sequence length to use for each calibration
+            sample. Defaults to 512.
+        group_size: The size of weight groups to quantize together. A
+            `group_size` of -1 indicates per-channel quantization.
+            Defaults to 128.
+        num_grid_points: The number of grid search points for finding optimal
+            per-channel scales. Higher values may find better scales but
+            take longer. Defaults to 20.
+        quantization_layer_structure: A dictionary defining the model's
+            quantization structure. It should contain:
+            - "pre_block_layers": list of layers to run before the first
+              block (e.g., embedding layer).
+            - "sequential_blocks": list of transformer blocks to quantize
+              sequentially.
+            If not provided, the model must implement
+            `get_quantization_layer_structure`.
+
+    Example:
+    ```python
+    from keras.quantizers import AWQConfig
+
+    # Create configuration for 4-bit AWQ quantization
+    config = AWQConfig(
+        dataset=calibration_data,  # Your calibration dataset
+        tokenizer=your_tokenizer,  # Tokenizer for text data
+        num_samples=128,  # Number of calibration samples
+        sequence_length=512,  # Sequence length for each sample
+        group_size=128,  # Weight grouping for quantization
+        num_grid_points=20,  # Grid search points for scale search
+    )
+
+    # Apply quantization to your model
+    model.quantize("awq", config=config)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        dataset,
+        tokenizer,
+        *,
+        weight_bits: int = 4,
+        num_samples: int = 128,
+        sequence_length: int = 512,
+        group_size: int = 128,
+        num_grid_points: int = 20,
+        quantization_layer_structure: dict = None,
+    ):
+        super().__init__()
+        # AWQ only supports 4-bit quantization
+        if weight_bits != 4:
+            raise ValueError(
+                f"AWQ only supports 4-bit quantization. "
+                f"Received weight_bits={weight_bits}."
+            )
+        if num_samples <= 0:
+            raise ValueError("num_samples must be a positive integer.")
+        if sequence_length <= 0:
+            raise ValueError("sequence_length must be a positive integer.")
+        if group_size < -1 or group_size == 0:
+            raise ValueError(
+                "Invalid group_size. Supported values are -1 (per-channel) "
+                f"or a positive integer, but got {group_size}."
+            )
+        if num_grid_points <= 0:
+            raise ValueError("num_grid_points must be a positive integer.")
+
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.weight_bits = weight_bits
+        self.num_samples = num_samples
+        self.sequence_length = sequence_length
+        self.group_size = group_size
+        self.num_grid_points = num_grid_points
+        self.quantization_layer_structure = quantization_layer_structure
+
+    @property
+    def mode(self):
+        return "awq"
+
+    def dtype_policy_string(self):
+        """Returns the dtype policy string for this configuration.
+
+        Returns:
+            A string representing the dtype policy, e.g. "awq/4/128".
+        """
+        return f"awq/{self.weight_bits}/{self.group_size}"
+
+    def get_config(self):
+        return {
+            # Dataset and Tokenizer are only required for one-time
+            # calibration and are not saved in the config.
+            "dataset": None,
+            "tokenizer": None,
+            "weight_bits": self.weight_bits,
+            "num_samples": self.num_samples,
+            "sequence_length": self.sequence_length,
+            "group_size": self.group_size,
+            "num_grid_points": self.num_grid_points,
+            "quantization_layer_structure": self.quantization_layer_structure,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
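Note on the new `quantization_layer_structure` argument: the docstring above only describes it in prose. The sketch below shows one plausible way to build that dictionary and pass it to `model.quantize`; the `backbone`, `token_embedding`, and `transformer_layers` attributes are illustrative assumptions (they mirror the KerasHub layout that the removed GPTQ helpers used), not something this diff guarantees for every model, and `model`, `calibration_data`, and `your_tokenizer` are placeholders as in the docstring example.

```python
from keras.quantizers import AWQConfig

# Hypothetical layout: any model can supply its own equivalents of these.
structure = {
    # Layers run once before the first quantized block (e.g. the embedding).
    "pre_block_layers": [model.backbone.token_embedding],
    # Transformer blocks to be quantized one after another.
    "sequential_blocks": list(model.backbone.transformer_layers),
}

config = AWQConfig(
    dataset=calibration_data,   # iterable of strings or pre-tokenized tensors
    tokenizer=your_tokenizer,
    group_size=128,
    quantization_layer_structure=structure,
)
model.quantize("awq", config=config)
```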
keras/src/quantizers/awq_core.py
@@ -0,0 +1,217 @@
+"""AWQ core functionality for layer-wise quantization.
+
+This module provides the orchestration logic for applying AWQ quantization
+to transformer models in a layer-by-layer fashion.
+"""
+
+from contextlib import contextmanager
+
+from absl import logging
+
+from keras.src import ops
+from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+from keras.src.quantizers.awq import AWQ
+from keras.src.quantizers.awq_config import AWQConfig
+from keras.src.quantizers.gptq_core import find_layers_in_block
+from keras.src.quantizers.gptq_core import get_dataloader
+from keras.src.quantizers.utils import should_quantize_layer
+
+
+@contextmanager
+def stream_activations(layers_map, awq_objects):
+    """Context manager to capture activations for AWQ calibration.
+
+    Temporarily patches layer.call methods to capture activation statistics
+    for computing per-channel scaling factors.
+
+    Args:
+        layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+        awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+    Yields:
+        None: The patched state is active only within the `with` block.
+    """
+    original_calls = {}
+
+    def create_hook(name, original_call_func):
+        def hook(*args, **kwargs):
+            inp = args[0] if args else kwargs["inputs"]
+            num_features = awq_objects[name].rows
+            input_2d = ops.reshape(inp, (-1, num_features))
+            awq_objects[name].update_activation_magnitudes(input_2d)
+            return original_call_func(*args, **kwargs)
+
+        return hook
+
+    try:
+        for name, layer in layers_map.items():
+            original_calls[name] = layer.call
+            layer.call = create_hook(name, layer.call)
+        yield
+    finally:
+        for name, layer in layers_map.items():
+            layer.call = original_calls[name]
+
+
+def apply_awq_layerwise(dataloader, config, structure, filters=None):
+    """Apply AWQ quantization layer-by-layer to a Keras model.
+
+    This function processes the model sequentially, one block at a time:
+    1. Captures activation statistics through calibration data forward pass
+    2. Uses activation magnitudes to determine weight saliency
+    3. Finds optimal per-channel scales via grid search
+    4. Quantizes weights with AWQ scaling
+
+    Args:
+        dataloader: Calibration data as numpy array.
+        config: AWQConfig instance.
+        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+        filters: Optional layer filters.
+    """
+    num_samples = config.num_samples
+    logging.info("Starting AWQ quantization...")
+
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
+    if not transformer_blocks:
+        raise ValueError(
+            "No sequential blocks found in the structure to quantize."
+        )
+
+    # Process inputs through pre-block layers (e.g., embedding)
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
+    num_samples = min(num_samples, len(inputs))
+    progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+    for block_idx, block in enumerate(transformer_blocks):
+        logging.info(f"Quantizing Block {block_idx}")
+        sub_layers_map = find_layers_in_block(block)
+
+        # Apply filters
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
+        if not sub_layers_map:
+            logging.info(
+                f" No quantizable layers found in block {block_idx}. Skipping."
+            )
+        else:
+            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+            # Create AWQ objects for each layer
+            awq_objects = {
+                name: AWQ(layer, config)
+                for name, layer in sub_layers_map.items()
+            }
+
+            # Capture activation statistics
+            with stream_activations(sub_layers_map, awq_objects):
+                for sample_idx in range(num_samples):
+                    current_input = inputs[sample_idx]
+                    if len(current_input.shape) == 2:
+                        current_input = ops.expand_dims(current_input, axis=0)
+                    _ = block(current_input)
+
+            # Quantize each layer
+            for name, awq_object in awq_objects.items():
+                logging.info(f"Quantizing {name}...")
+                awq_object.quantize_layer()
+                awq_object.free()
+
+            del awq_objects
+
+        # Generate inputs for next block
+        if block_idx < len(transformer_blocks) - 1:
+            logging.info(f"Generating inputs for block {block_idx + 1}...")
+            next_block_inputs = []
+            for sample_idx in range(num_samples):
+                current_input = inputs[sample_idx]
+                if len(current_input.shape) == 2:
+                    current_input = ops.expand_dims(current_input, axis=0)
+                output = block(current_input)[0]
+                next_block_inputs.append(output)
+            inputs = next_block_inputs
+
+        progbar.update(current=block_idx + 1)
+
+    logging.info("AWQ quantization complete.")
+
+
+def awq_quantize(config, quantization_layer_structure, filters=None):
+    """Main entry point for AWQ quantization.
+
+    Args:
+        config: AWQConfig instance.
+        quantization_layer_structure: Model structure dictionary.
+        filters: Optional layer filters.
+    """
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "AWQ quantization requires a dataset and tokenizer. "
+            "Please provide them in the AWQConfig."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'awq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
+
+    # Load calibration data
+    dataloader = get_dataloader(
+        config.tokenizer,
+        config.sequence_length,
+        config.dataset,
+        num_samples=config.num_samples,
+    )
+
+    apply_awq_layerwise(
+        dataloader[: config.num_samples],
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
+
+
+def get_group_size_for_layer(layer, config):
+    """Get group size from config or dtype policy.
+
+    Args:
+        layer: The layer to get group size for.
+        config: Optional AWQConfig instance.
+
+    Returns:
+        int: The group size for quantization.
+
+    Raises:
+        ValueError: If group size cannot be determined.
+    """
+    if config and isinstance(config, AWQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if isinstance(policy, AWQDTypePolicy):
+            return policy.group_size
+    raise ValueError(
+        "For AWQ quantization, group_size must be specified "
+        "through AWQConfig or AWQDTypePolicy."
+    )
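`stream_activations` above works by temporarily replacing each layer's `call` attribute and restoring it in a `finally` block. The following self-contained sketch shows the same patching pattern in isolation; it is not the library's code, just an illustration of the mechanism, and it uses `keras.ops.convert_to_numpy` (a public Keras 3 op) to materialize the captured activations.

```python
from contextlib import contextmanager

import numpy as np
import keras


@contextmanager
def capture_inputs(layer, store):
    """Temporarily wrap `layer.call` to record the inputs it receives."""
    original_call = layer.call

    def hooked_call(*args, **kwargs):
        inp = args[0] if args else kwargs["inputs"]
        store.append(keras.ops.convert_to_numpy(inp))  # record the activation
        return original_call(*args, **kwargs)          # then run the real layer

    layer.call = hooked_call
    try:
        yield
    finally:
        layer.call = original_call  # always restore the original method


layer = keras.layers.Dense(4)
layer.build((2, 3))  # build up front so only the real forward pass is hooked

captured = []
with capture_inputs(layer, captured):
    layer(np.ones((2, 3), dtype="float32"))
# `captured` now holds the calibration inputs seen by the layer.
print(len(captured), captured[0].shape)
```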
keras/src/quantizers/gptq.py
CHANGED
@@ -1,5 +1,4 @@
 import types
-from functools import partial
 
 from keras.src import ops
 from keras.src import quantizers
@@ -296,12 +295,7 @@ class GPTQ:
         # For EinsumDense, we determine the effective 2D dimensions.
         self.kernel_shape = layer.kernel.shape
         shape = list(self.kernel_shape)
-        try:
-            d_model_dim_index = shape.index(max(shape))
-        except ValueError:
-            raise TypeError(
-                f"Could not determine hidden dimension from shape {shape}"
-            )
+        d_model_dim_index = shape.index(max(shape))
 
         if d_model_dim_index == 0:  # QKV projection case
             in_features, heads, head_dim = shape
@@ -471,7 +465,7 @@ class GPTQ:
             group_size=self.config.group_size,
             activation_order=self.config.activation_order,
             order_metric=ops.diagonal(hessian_matrix),
-            compute_scale_zero=
+            compute_scale_zero=self.quantizer.find_params,
         )
         quantized = ops.cast(
             quantized, self.original_layer.quantized_kernel.dtype
keras/src/quantizers/gptq_config.py
CHANGED
@@ -1,8 +1,9 @@
 from keras.src.api_export import keras_export
+from keras.src.quantizers.quantization_config import QuantizationConfig
 
 
 @keras_export("keras.quantizers.GPTQConfig")
-class GPTQConfig:
+class GPTQConfig(QuantizationConfig):
     """Configuration class for the GPTQ (Gradient-based Post-Training
     Quantization) algorithm.
 
@@ -131,6 +132,12 @@ class GPTQConfig:
         activation_order: (bool, optional) If `True`, reorders weight columns
             based on activation magnitude, which can improve quantization
             accuracy. Defaults to `False`.
+        quantization_layer_structure: (dict, optional) A dictionary defining the
+            model's quantization structure. It should contain:
+            - "pre_block_layers": list of layers to run before the first block.
+            - "sequential_blocks": list of blocks to be quantized sequentially.
+            If not provided, the model must implement
+            `get_quantization_layer_structure`.
     """
 
     def __init__(
@@ -146,7 +153,9 @@ class GPTQConfig:
         group_size: int = 128,
         symmetric: bool = False,
         activation_order: bool = False,
+        quantization_layer_structure: dict = None,
     ):
+        super().__init__()
         if weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
                 f"Unsupported weight_bits {weight_bits}. "
@@ -174,6 +183,32 @@ class GPTQConfig:
         self.group_size = group_size
         self.symmetric = symmetric
         self.activation_order = activation_order
+        self.quantization_layer_structure = quantization_layer_structure
+
+    def get_config(self):
+        return {
+            # Dataset and Tokenizer are only required for a one-time
+            # calibration and are not saved in the config.
+            "dataset": None,
+            "tokenizer": None,
+            "weight_bits": self.weight_bits,
+            "num_samples": self.num_samples,
+            "per_channel": self.per_channel,
+            "sequence_length": self.sequence_length,
+            "hessian_damping": self.hessian_damping,
+            "group_size": self.group_size,
+            "symmetric": self.symmetric,
+            "activation_order": self.activation_order,
+            "quantization_layer_structure": self.quantization_layer_structure,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+    @property
+    def mode(self):
+        return "gptq"
 
     def dtype_policy_string(self):
         """Returns the dtype policy string for this configuration.
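The `get_config`/`from_config` pair added above makes `GPTQConfig` round-trippable, but deliberately writes `dataset` and `tokenizer` out as `None` (they are only needed for the one-time calibration). A minimal sketch of that round trip, assuming `calibration_data` and `your_tokenizer` exist as in the class docstring:

```python
from keras.quantizers import GPTQConfig

config = GPTQConfig(
    dataset=calibration_data,
    tokenizer=your_tokenizer,
    weight_bits=4,
    group_size=128,
)

# Serialize: calibration-only fields are intentionally dropped (stored as None).
serialized = config.get_config()
assert serialized["dataset"] is None and serialized["tokenizer"] is None

# Rebuild from the dict: hyperparameters survive, but a dataset/tokenizer must
# be supplied again before running another calibration pass.
restored = GPTQConfig.from_config(serialized)
assert restored.weight_bits == 4 and restored.group_size == 128
assert restored.mode == "gptq"
```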
keras/src/quantizers/gptq_core.py
CHANGED
@@ -10,9 +10,9 @@ from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers import Dense
 from keras.src.layers import EinsumDense
-from keras.src.layers import Embedding
 from keras.src.quantizers.gptq import GPTQ
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.utils import should_quantize_layer
 
 
 @contextmanager
@@ -131,7 +131,7 @@ def get_dataloader(
     pieces = []
     if isinstance(dataset_list[0], str):
         for i, s in enumerate(dataset_list):
-            toks =
+            toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
             pieces.append(toks)
             # avoid windows that span document boundaries
             if eos_id is not None and i < len(dataset_list) - 1:
@@ -193,38 +193,6 @@
     return samples.astype(np.int32)[:, None, :]
 
 
-def _get_backbone_layers(model):
-    """Extract embedding and transformer layers from a KerasHub model."""
-    backbone = model.backbone
-    if not hasattr(backbone, "transformer_layers"):
-        raise ValueError(
-            "The model's backbone does not have a 'transformer_layers' "
-            "attribute. Please ensure you are using a standard KerasHub "
-            "transformer model."
-        )
-    transformer_blocks = backbone.transformer_layers
-
-    embedding_layer = None
-    if hasattr(backbone, "token_embedding"):
-        embedding_layer = backbone.token_embedding
-    elif hasattr(backbone, "embedding"):
-        embedding_layer = backbone.embedding
-    return embedding_layer, transformer_blocks
-
-
-def _get_custom_layers(model):
-    """Heuristic for extracting embedding + transformer blocks from a custom
-    model."""
-    embedding_layer = None
-    transformer_blocks = []
-    for layer in model.layers:
-        if isinstance(layer, Embedding) and embedding_layer is None:
-            embedding_layer = layer
-        elif getattr(layer, "_layers", None):  # container-like block
-            transformer_blocks.append(layer)
-    return embedding_layer, transformer_blocks
-
-
 def find_layers_in_block(block):
     """
     Finds all Dense and EinsumDense layers in a transformer block.
@@ -242,39 +210,31 @@ def find_layers_in_block(block):
     return found_layers
 
 
-def apply_gptq_layerwise(
+def apply_gptq_layerwise(dataloader, config, structure, filters=None):
     """Applies GPTQ quantization layer-by-layer to a Keras model.
 
-    This function
-
-    structure by first looking for the standard format: a `model.backbone`
-    attribute that contains a `transformer_layers` list.
-
-    If a standard backbone is not found, it falls back to a heuristic for
-    custom models, where it assumes the first `keras.layers.Embedding` layer
-    is the input embedding and any subsequent container layers are the
-    transformer blocks to be quantized.
+    This function uses the provided `structure` to identify pre-quantization
+    layers and sequential blocks.
 
     The core logic operates as follows:
-
-
-    2. It processes the model sequentially, one block at a time. For each
+
+    1. It processes the model sequentially, one block at a time. For each
        block, it uses temporary hooks to capture the input activations of
        each target layer during a forward pass with the calibration data.
-
+    2. These captured activations are used to compute the Hessian matrix for
       each layer's weights.
-
+    3. The GPTQ algorithm is then applied to each layer to find the optimal
       quantized weights that minimize the error introduced.
-
+    4. The output activations from the current block are then used as the
       input for the next block, ensuring that quantization errors are
       accounted for throughout the model.
 
     Args:
-
-        attempt to automatically discover its structure.
-        dataloader: An iterable providing calibration data. Each item should
-            be a batch of token IDs suitable for the model's embedding layer.
+        dataloader: An iterable providing calibration data.
         config: A GPTQConfiguration object.
+        structure: A dictionary with keys "pre_block_layers" and
+            "sequential_blocks".
+        filters: Optional filters to exclude layers from quantization.
 
     Raises:
         ValueError: If the function cannot automatically find an embedding
@@ -284,30 +244,23 @@ def apply_gptq_layerwise(model, dataloader, config):
     num_samples = config.num_samples
 
     logging.info("Starting model quantization...")
-    embedding_layer = None
-    transformer_blocks = []
-    if hasattr(model, "backbone"):
-        logging.info("Detected KerasHub model structure.")
-        embedding_layer, transformer_blocks = _get_backbone_layers(model)
-    else:
-        logging.info("Detected custom model structure.")
-        embedding_layer, transformer_blocks = _get_custom_layers(model)
 
-
-
-
-        )
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
     if not transformer_blocks:
         raise ValueError(
-            "
-            "quantize."
+            "No sequential blocks found in the provided structure to quantize."
        )
 
-    # Initial inputs are the outputs of the
-    inputs = [
-
-
-
+    # Initial inputs are the outputs of the pre-block layers
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
     num_samples = min(num_samples, len(inputs))
 
     progbar = keras_utils.Progbar(target=len(transformer_blocks))
@@ -316,10 +269,19 @@
         logging.info(f"Quantizing Block {block_idx}")
         sub_layers_map = find_layers_in_block(block)
 
+        # Filter out layers that are not quantized with GPTQ
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
         if not sub_layers_map:
             logging.info(
-                f" No
-                "Skipping."
+                f" No quantizable layers found in block {block_idx}. Skipping."
             )
         else:
             logging.info(f"Found layers: {list(sub_layers_map.keys())}")
@@ -357,11 +319,30 @@ def apply_gptq_layerwise(model, dataloader, config):
     logging.info("Quantization process complete.")
 
 
-def gptq_quantize(
+def gptq_quantize(config, quantization_layer_structure, filters=None):
     """
-
+    Quantizes the model using GPTQ.
+
+    Args:
+        config: The GPTQ configuration.
+        quantization_layer_structure: A dictionary describing the model's layer
+            structure for quantization.
+        filters: Optional filters to exclude layers from quantization.
     """
-
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "GPTQ quantization requires a dataset and a tokenizer. "
+            "Please provide them in the `GPTQConfig`."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'gptq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
 
     # Load all data needed from the generator/source in a single call.
     total_samples_to_request = config.num_samples
@@ -376,7 +357,12 @@ def gptq_quantize(model, config):
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = dataloader[: config.num_samples]
 
-    apply_gptq_layerwise(
+    apply_gptq_layerwise(
+        calibration_dataloader,
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
 
 
 def get_group_size_for_layer(layer, config):
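Both `gptq_quantize` and `awq_quantize` now refuse to guess the model layout (the `_get_backbone_layers`/`_get_custom_layers` heuristics were removed); the structure must come from the config or from `model.get_quantization_layer_structure(mode)`, per the error messages above. Below is a hedged sketch of the override path with a toy model whose layer layout is entirely hypothetical; only the method name and the returned dictionary keys come from this diff.

```python
import keras


class TinyLM(keras.Model):
    """Toy model used only to illustrate get_quantization_layer_structure."""

    def __init__(self, vocab_size=100, dim=32, num_blocks=2):
        super().__init__()
        self.token_embedding = keras.layers.Embedding(vocab_size, dim)
        # Stand-ins for transformer blocks: container blocks whose Dense /
        # EinsumDense sublayers should be quantized sequentially.
        self.blocks = [
            keras.Sequential([keras.layers.Dense(dim, activation="relu")])
            for _ in range(num_blocks)
        ]

    def call(self, inputs):
        x = self.token_embedding(inputs)
        for block in self.blocks:
            x = block(x)
        return x

    def get_quantization_layer_structure(self, mode):
        # Same dictionary shape that gptq_quantize / awq_quantize expect.
        return {
            "pre_block_layers": [self.token_embedding],
            "sequential_blocks": list(self.blocks),
        }
```

With such an override in place, a call like `model.quantize("gptq", config=GPTQConfig(dataset=..., tokenizer=...))` can resolve the structure without setting `quantization_layer_structure` on the config, which is the fallback the error messages describe.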