keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""AWQ core functionality for layer-wise quantization.
|
|
2
|
+
|
|
3
|
+
This module provides the orchestration logic for applying AWQ quantization
|
|
4
|
+
to transformer models in a layer-by-layer fashion.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
|
|
9
|
+
from absl import logging
|
|
10
|
+
|
|
11
|
+
from keras.src import ops
|
|
12
|
+
from keras.src import utils as keras_utils
|
|
13
|
+
from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
|
|
14
|
+
from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
|
|
15
|
+
from keras.src.quantizers.awq import AWQ
|
|
16
|
+
from keras.src.quantizers.awq_config import AWQConfig
|
|
17
|
+
from keras.src.quantizers.gptq_core import find_layers_in_block
|
|
18
|
+
from keras.src.quantizers.gptq_core import get_dataloader
|
|
19
|
+
from keras.src.quantizers.utils import should_quantize_layer
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@contextmanager
def stream_activations(layers_map, awq_objects):
    """Context manager to capture activations for AWQ calibration.

    Temporarily patches each layer's `call` method with a hook that records
    activation statistics (per-channel magnitudes) before delegating to the
    original implementation. The patches are removed on exit, even if the
    body raises.

    Args:
        layers_map: Dict[str, Layer]. Mapping from layer names to layers.
        awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.

    Yields:
        None: The patched state is active only within the `with` block.
    """
    original_calls = {}

    def create_hook(name, original_call_func):
        # Capture `name` and the unpatched call per-layer (avoids the
        # late-binding closure pitfall of referencing loop variables).
        def hook(*args, **kwargs):
            inp = args[0] if args else kwargs["inputs"]
            # Flatten all leading dims so every row is one activation vector
            # of length `rows` (the layer's input feature count).
            num_features = awq_objects[name].rows
            input_2d = ops.reshape(inp, (-1, num_features))
            awq_objects[name].update_activation_magnitudes(input_2d)
            return original_call_func(*args, **kwargs)

        return hook

    try:
        for name, layer in layers_map.items():
            original_calls[name] = layer.call
            layer.call = create_hook(name, layer.call)
        yield
    finally:
        # Restore only the layers that were actually patched. Iterating
        # `layers_map` here instead would raise KeyError for layers that
        # were never reached if the patching loop failed partway, masking
        # the original exception.
        for name, original_call in original_calls.items():
            layers_map[name].call = original_call
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def apply_awq_layerwise(dataloader, config, structure, filters=None):
    """Apply AWQ quantization layer-by-layer to a Keras model.

    This function processes the model sequentially, one block at a time:
    1. Captures activation statistics through calibration data forward pass
    2. Uses activation magnitudes to determine weight saliency
    3. Finds optimal per-channel scales via grid search
    4. Quantizes weights with AWQ scaling

    Args:
        dataloader: Calibration data as numpy array.
        config: AWQConfig instance.
        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
        filters: Optional layer filters.

    Raises:
        ValueError: If `structure` contains no sequential blocks.
    """
    num_samples = config.num_samples
    logging.info("Starting AWQ quantization...")

    pre_layers = structure.get("pre_block_layers", [])
    transformer_blocks = structure.get("sequential_blocks", [])

    if not transformer_blocks:
        raise ValueError(
            "No sequential blocks found in the structure to quantize."
        )

    # Process inputs through pre-block layers (e.g., embedding)
    inputs = []
    for batch in dataloader:
        # NOTE(review): int32 cast assumes token-id inputs — confirm
        # against the dataloader produced by `get_dataloader`.
        batch = ops.convert_to_tensor(batch, dtype="int32")
        for layer in pre_layers:
            batch = layer(batch)
        inputs.append(batch)

    # Cap the sample count at what the dataloader actually yielded.
    num_samples = min(num_samples, len(inputs))
    progbar = keras_utils.Progbar(target=len(transformer_blocks))

    for block_idx, block in enumerate(transformer_blocks):
        logging.info(f"Quantizing Block {block_idx}")
        sub_layers_map = find_layers_in_block(block)

        # Apply filters
        final_sub_layers_map = {}
        for name, layer in sub_layers_map.items():
            if not should_quantize_layer(layer, filters):
                continue
            final_sub_layers_map[name] = layer

        sub_layers_map = final_sub_layers_map

        if not sub_layers_map:
            logging.info(
                f" No quantizable layers found in block {block_idx}. Skipping."
            )
        else:
            logging.info(f"Found layers: {list(sub_layers_map.keys())}")

            # Create AWQ objects for each layer
            awq_objects = {
                name: AWQ(layer, config)
                for name, layer in sub_layers_map.items()
            }

            # Capture activation statistics
            with stream_activations(sub_layers_map, awq_objects):
                for sample_idx in range(num_samples):
                    current_input = inputs[sample_idx]
                    # Promote rank-2 samples to a batch of one.
                    if len(current_input.shape) == 2:
                        current_input = ops.expand_dims(current_input, axis=0)
                    _ = block(current_input)

            # Quantize each layer
            for name, awq_object in awq_objects.items():
                logging.info(f"Quantizing {name}...")
                awq_object.quantize_layer()
                awq_object.free()

            # Release per-layer calibration state before the next block.
            del awq_objects

        # Generate inputs for next block
        # NOTE(review): this runs even for skipped blocks so downstream
        # blocks still receive propagated activations — confirm intended.
        if block_idx < len(transformer_blocks) - 1:
            logging.info(f"Generating inputs for block {block_idx + 1}...")
            next_block_inputs = []
            for sample_idx in range(num_samples):
                current_input = inputs[sample_idx]
                if len(current_input.shape) == 2:
                    current_input = ops.expand_dims(current_input, axis=0)
                # NOTE(review): `[0]` presumably selects the hidden states
                # from a tuple-valued block output — verify against the
                # block's return convention.
                output = block(current_input)[0]
                next_block_inputs.append(output)
            inputs = next_block_inputs

        progbar.update(current=block_idx + 1)

    logging.info("AWQ quantization complete.")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def awq_quantize(config, quantization_layer_structure, filters=None):
    """Main entry point for AWQ quantization.

    Validates the configuration, builds the calibration dataloader, and
    delegates the block-by-block work to `apply_awq_layerwise`.

    Args:
        config: AWQConfig instance.
        quantization_layer_structure: Model structure dictionary.
        filters: Optional layer filters.

    Raises:
        ValueError: If calibration data/tokenizer or the layer structure
            is missing.
    """
    # Both calibration inputs are mandatory for AWQ.
    missing_calibration = config.dataset is None or config.tokenizer is None
    if missing_calibration:
        raise ValueError(
            "AWQ quantization requires a dataset and tokenizer. "
            "Please provide them in the AWQConfig."
        )

    if quantization_layer_structure is None:
        raise ValueError(
            "For 'awq' mode, a valid quantization structure must be provided "
            "either via `config.quantization_layer_structure` or by overriding "
            "`model.get_quantization_layer_structure(mode)`. The structure "
            "should be a dictionary with keys 'pre_block_layers' and "
            "'sequential_blocks'."
        )

    # Load calibration data
    calibration_batches = get_dataloader(
        config.tokenizer,
        config.sequence_length,
        config.dataset,
        num_samples=config.num_samples,
    )

    # Cap at the configured sample count before handing off.
    capped_batches = calibration_batches[: config.num_samples]
    apply_awq_layerwise(
        capped_batches,
        config,
        quantization_layer_structure,
        filters=filters,
    )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_group_size_for_layer(layer, config):
    """Get group size from config or dtype policy.

    Resolution order: an explicit `AWQConfig` wins; otherwise the layer's
    own dtype policy is consulted, including a `DTypePolicyMap` lookup by
    the layer's path.

    Args:
        layer: The layer to get group size for.
        config: Optional AWQConfig instance.

    Returns:
        int: The group size for quantization.

    Raises:
        ValueError: If group size cannot be determined.
    """
    # An explicit config takes precedence over any policy on the layer.
    if config and isinstance(config, AWQConfig):
        return config.group_size

    policy = layer.dtype_policy
    if isinstance(policy, AWQDTypePolicy):
        return policy.group_size
    if isinstance(policy, DTypePolicyMap):
        mapped_policy = policy[layer.path]
        if isinstance(mapped_policy, AWQDTypePolicy):
            return mapped_policy.group_size

    raise ValueError(
        "For AWQ quantization, group_size must be specified "
        "through AWQConfig or AWQDTypePolicy."
    )
|