keras-nightly 3.14.0.dev2025122704-py3-none-any.whl → 3.14.0.dev2026012204-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +3 -0
- keras/ops/numpy/__init__.py +3 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +16 -0
- keras/src/backend/numpy/numpy.py +23 -0
- keras/src/backend/openvino/numpy.py +369 -16
- keras/src/backend/tensorflow/numpy.py +34 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +36 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +109 -96
- keras/src/ops/numpy.py +181 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
keras/src/quantizers/awq_core.py
ADDED
@@ -0,0 +1,217 @@
+"""AWQ core functionality for layer-wise quantization.
+
+This module provides the orchestration logic for applying AWQ quantization
+to transformer models in a layer-by-layer fashion.
+"""
+
+from contextlib import contextmanager
+
+from absl import logging
+
+from keras.src import ops
+from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+from keras.src.quantizers.awq import AWQ
+from keras.src.quantizers.awq_config import AWQConfig
+from keras.src.quantizers.gptq_core import find_layers_in_block
+from keras.src.quantizers.gptq_core import get_dataloader
+from keras.src.quantizers.utils import should_quantize_layer
+
+
+@contextmanager
+def stream_activations(layers_map, awq_objects):
+    """Context manager to capture activations for AWQ calibration.
+
+    Temporarily patches layer.call methods to capture activation statistics
+    for computing per-channel scaling factors.
+
+    Args:
+        layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+        awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+    Yields:
+        None: The patched state is active only within the `with` block.
+    """
+    original_calls = {}
+
+    def create_hook(name, original_call_func):
+        def hook(*args, **kwargs):
+            inp = args[0] if args else kwargs["inputs"]
+            num_features = awq_objects[name].rows
+            input_2d = ops.reshape(inp, (-1, num_features))
+            awq_objects[name].update_activation_magnitudes(input_2d)
+            return original_call_func(*args, **kwargs)
+
+        return hook
+
+    try:
+        for name, layer in layers_map.items():
+            original_calls[name] = layer.call
+            layer.call = create_hook(name, layer.call)
+        yield
+    finally:
+        for name, layer in layers_map.items():
+            layer.call = original_calls[name]
+
+
+def apply_awq_layerwise(dataloader, config, structure, filters=None):
+    """Apply AWQ quantization layer-by-layer to a Keras model.
+
+    This function processes the model sequentially, one block at a time:
+    1. Captures activation statistics through calibration data forward pass
+    2. Uses activation magnitudes to determine weight saliency
+    3. Finds optimal per-channel scales via grid search
+    4. Quantizes weights with AWQ scaling
+
+    Args:
+        dataloader: Calibration data as numpy array.
+        config: AWQConfig instance.
+        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+        filters: Optional layer filters.
+    """
+    num_samples = config.num_samples
+    logging.info("Starting AWQ quantization...")
+
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
+    if not transformer_blocks:
+        raise ValueError(
+            "No sequential blocks found in the structure to quantize."
+        )
+
+    # Process inputs through pre-block layers (e.g., embedding)
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
+    num_samples = min(num_samples, len(inputs))
+    progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+    for block_idx, block in enumerate(transformer_blocks):
+        logging.info(f"Quantizing Block {block_idx}")
+        sub_layers_map = find_layers_in_block(block)
+
+        # Apply filters
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
+        if not sub_layers_map:
+            logging.info(
+                f"No quantizable layers found in block {block_idx}. Skipping."
+            )
+        else:
+            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+            # Create AWQ objects for each layer
+            awq_objects = {
+                name: AWQ(layer, config)
+                for name, layer in sub_layers_map.items()
+            }
+
+            # Capture activation statistics
+            with stream_activations(sub_layers_map, awq_objects):
+                for sample_idx in range(num_samples):
+                    current_input = inputs[sample_idx]
+                    if len(current_input.shape) == 2:
+                        current_input = ops.expand_dims(current_input, axis=0)
+                    _ = block(current_input)
+
+            # Quantize each layer
+            for name, awq_object in awq_objects.items():
+                logging.info(f"Quantizing {name}...")
+                awq_object.quantize_layer()
+                awq_object.free()
+
+            del awq_objects
+
+        # Generate inputs for next block
+        if block_idx < len(transformer_blocks) - 1:
+            logging.info(f"Generating inputs for block {block_idx + 1}...")
+            next_block_inputs = []
+            for sample_idx in range(num_samples):
+                current_input = inputs[sample_idx]
+                if len(current_input.shape) == 2:
+                    current_input = ops.expand_dims(current_input, axis=0)
+                output = block(current_input)[0]
+                next_block_inputs.append(output)
+            inputs = next_block_inputs
+
+        progbar.update(current=block_idx + 1)
+
+    logging.info("AWQ quantization complete.")
+
+
+def awq_quantize(config, quantization_layer_structure, filters=None):
+    """Main entry point for AWQ quantization.
+
+    Args:
+        config: AWQConfig instance.
+        quantization_layer_structure: Model structure dictionary.
+        filters: Optional layer filters.
+    """
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "AWQ quantization requires a dataset and tokenizer. "
+            "Please provide them in the AWQConfig."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'awq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
+
+    # Load calibration data
+    dataloader = get_dataloader(
+        config.tokenizer,
+        config.sequence_length,
+        config.dataset,
+        num_samples=config.num_samples,
+    )
+
+    apply_awq_layerwise(
+        dataloader[: config.num_samples],
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
+
+
+def get_group_size_for_layer(layer, config):
+    """Get group size from config or dtype policy.
+
+    Args:
+        layer: The layer to get group size for.
+        config: Optional AWQConfig instance.
+
+    Returns:
+        int: The group size for quantization.
+
+    Raises:
+        ValueError: If group size cannot be determined.
+    """
+    if config and isinstance(config, AWQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if isinstance(policy, AWQDTypePolicy):
+            return policy.group_size
+    raise ValueError(
+        "For AWQ quantization, group_size must be specified "
+        "through AWQConfig or AWQDTypePolicy."
+    )
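The `stream_activations` pattern above works because Keras 3 layers dispatch through `self.call`, so an instance-level override intercepts every forward pass. Below is a runnable toy version of the same patch-and-restore idea, not the diff's code: the AWQ statistics collection is simplified to a per-feature mean of absolute activations.

from contextlib import contextmanager

import keras
import numpy as np

@contextmanager
def record_inputs(layer, store):
    # Patch the instance attribute; the bound method stays on the class.
    original = layer.call

    def hook(*args, **kwargs):
        inp = args[0] if args else kwargs["inputs"]
        store.append(np.abs(keras.ops.convert_to_numpy(inp)).mean(axis=0))
        return original(*args, **kwargs)

    layer.call = hook
    try:
        yield
    finally:
        layer.call = original  # always restore, even on error

layer = keras.layers.Dense(4)
seen = []
with record_inputs(layer, seen):
    layer(np.random.randn(2, 3).astype("float32"))
print(seen[0].shape)  # (3,): per-feature mean |activation|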
keras/src/quantizers/gptq.py
CHANGED
@@ -1,5 +1,4 @@
 import types
-from functools import partial

 from keras.src import ops
 from keras.src import quantizers
@@ -466,7 +465,7 @@ class GPTQ:
             group_size=self.config.group_size,
             activation_order=self.config.activation_order,
             order_metric=ops.diagonal(hessian_matrix),
-            compute_scale_zero=
+            compute_scale_zero=self.quantizer.find_params,
         )
         quantized = ops.cast(
             quantized, self.original_layer.quantized_kernel.dtype
keras/src/quantizers/gptq_core.py
CHANGED
@@ -131,7 +131,7 @@ def get_dataloader(
     pieces = []
     if isinstance(dataset_list[0], str):
        for i, s in enumerate(dataset_list):
-            toks =
+            toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
             pieces.append(toks)
             # avoid windows that span document boundaries
             if eos_id is not None and i < len(dataset_list) - 1:
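The changed line materializes the tokenizer output as a flat 1-D numpy array before the pieces are concatenated into one token stream. A small sketch of why the flattening matters, with a hypothetical `fake_tokenize` standing in for `tokenizer.tokenize`, which may return a tensor with a leading batch dimension depending on the tokenizer:

import numpy as np

def fake_tokenize(s):
    # Hypothetical stand-in for tokenizer.tokenize(s); note the extra
    # leading batch dimension that .reshape(-1) removes.
    return np.array([[ord(c) % 101 for c in s]])

pieces = []
for s in ["hello", "world!"]:
    toks = np.asarray(fake_tokenize(s)).reshape(-1)  # -> shape (n,)
    pieces.append(toks)

# Flat pieces concatenate into one contiguous token stream.
print(np.concatenate(pieces).shape)  # (11,)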
keras/src/quantizers/quantization_config.py
CHANGED
@@ -182,6 +182,11 @@ def validate_and_resolve_config(mode, config):
             "For GPTQ, you must pass a `GPTQConfig` object in the "
             "`config` argument."
         )
+    elif mode == "awq":
+        raise ValueError(
+            "For AWQ, you must pass an `AWQConfig` object in the "
+            "`config` argument."
+        )
     else:
         if mode is not None:
             raise ValueError(
@@ -220,6 +225,15 @@ def validate_and_resolve_config(mode, config):
             f"`GPTQConfig`. Received: {type(config)}"
         )

+    if mode == "awq":
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        if not isinstance(config, AWQConfig):
+            raise ValueError(
+                "Mode 'awq' requires a valid `config` argument of type "
+                f"`AWQConfig`. Received: {type(config)}"
+            )
+
     return config
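With these checks in place, requesting AWQ without a matching config fails fast. A hedged usage sketch, assuming the `Model.quantize(mode, config=...)` entry point that the GPTQ path already uses, and with `texts`, `tokenizer`, and `model` as hypothetical inputs:

import keras
from keras.src.quantizers.awq_config import AWQConfig

# Hypothetical calibration inputs; AWQConfig fields follow the diff
# (dataset, tokenizer, num_samples, sequence_length, group_size).
config = AWQConfig(dataset=texts, tokenizer=tokenizer)

model.quantize("awq", config=config)  # presumably runs awq_quantize
model.quantize("awq")                 # ValueError: For AWQ, you must pass
                                      # an `AWQConfig` object in the
                                      # `config` argument.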
keras/src/quantizers/quantizers.py
CHANGED
@@ -653,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
     )

     def to_signed(x):
-        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+        Uses a branchless XOR approach: (x ^ 8) - 8
+        This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+        """
         dtype_x = backend.standardize_dtype(x.dtype)
         eight = ops.cast(8, dtype_x)
-        sixteen = ops.cast(16, dtype_x)
-        return ops.where(x < eight, x, x - sixteen)
+        return ops.subtract(ops.bitwise_xor(x, eight), eight)

     rank = getattr(packed.shape, "rank", None) or len(packed.shape)
     if axis < 0:
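A quick numpy check, not part of the diff, that the new branchless form agrees with the old conditional form on all sixteen nibble values:

import numpy as np

nibbles = np.arange(16, dtype=np.int8)
xor_form = (nibbles ^ 8) - 8                               # new: (x ^ 8) - 8
where_form = np.where(nibbles < 8, nibbles, nibbles - 16)  # old: where-based
assert (xor_form == where_form).all()
print(xor_form.tolist())
# [0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1]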
@@ -748,7 +751,7 @@ class GPTQQuantizer(Quantizer):
         self.zero = None
         self.maxq = None

-    def find_params(self, input_tensor
+    def find_params(self, input_tensor):
         """Finds quantization parameters (scale and zero) for a given tensor."""
         self.scale, self.zero, self.maxq = compute_quantization_parameters(
             input_tensor,
@@ -756,7 +759,6 @@ class GPTQQuantizer(Quantizer):
             symmetric=self.symmetric,
             per_channel=self.per_channel,
             group_size=self.group_size,
-            weight=weight,
             compute_dtype=self.compute_dtype,
         )
         return self.scale, self.zero, self.maxq
@@ -793,98 +795,105 @@ def compute_quantization_parameters(
     symmetric=False,
     per_channel=False,
     group_size=-1,
-    weight=False,
     compute_dtype="float32",
 ):
     """
-    Computes the scale and zero-point for
+    Computes the scale and zero-point for quantizing weight tensors.

     This function calculates the scale and zero-point required for quantizing
-    a given tensor `x` based on the specified parameters. It supports
-    per-channel, per-tensor, symmetric, and asymmetric quantization
-
+    a given weight tensor `x` based on the specified parameters. It supports
+    grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+    For grouped quantization (per_channel=True, group_size > 0), the output
+    shapes are [out_features, n_groups] where n_groups is the number of groups
+    along the in_features dimension.

     Args:
-        x: KerasTensor. The
+        x: KerasTensor. The weight tensor to quantize with shape
+            [out_features, in_features].
         bits: int. The number of bits to quantize to (e.g., 4).
         symmetric: bool. Whether to use symmetric quantization.
         per_channel: bool. Whether to quantize per channel.
-        group_size: int. The group size for quantization.
-
+        group_size: int. The group size for quantization. -1 means no grouping.
+        compute_dtype: str. The dtype for computation. Defaults to "float32".

     Returns:
         scale: KerasTensor. The scale tensor for quantization.
         zero: KerasTensor. The zero tensor for quantization.
         maxq: scalar. The maximum quantization value.
     """
+    # Input validation
     if x is None:
         raise ValueError(f"Input tensor {x} cannot be None.")
-
-    # For weights, we typically expect at least a 2D tensor.
-    if weight and len(x.shape) < 2:
+    if len(x.shape) < 2:
         raise ValueError(
             f"Input weight tensor {x} must have a rank of at "
             f"least 2, but got rank {len(x.shape)}."
         )
-
     if ops.size(x) == 0:
         raise ValueError("Input tensor 'x' cannot be empty.")

-
-
-    if per_channel:
-        if weight:
-            if group_size != -1:
-                input_reshaped = ops.reshape(x, [-1, group_size])
-            else:
-                input_reshaped = ops.reshape(x, [original_shape[0], -1])
-    else:  # per-tensor
-        input_reshaped = ops.reshape(x, [1, -1])
+    out_features, in_features = x.shape[0], x.shape[1]

-    #
-
-
+    # Determine number of groups for quantization
+    if per_channel and group_size > 0:
+        n_groups = (in_features + group_size - 1) // group_size
+    else:
+        n_groups = 1
+
+    # Compute min/max values based on quantization mode
+    if n_groups > 1:
+        # Grouped quantization: output shape [out_features, n_groups]
+        remainder = in_features % group_size
+        if remainder != 0:
+            pad_size = group_size - remainder
+            x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+        x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+        min_values = ops.min(x_grouped, axis=2)
+        max_values = ops.max(x_grouped, axis=2)
+    else:
+        # Per-channel or per-tensor: compute stats along rows
+        reduction_shape = [out_features, -1] if per_channel else [1, -1]
+        x_reshaped = ops.reshape(x, reduction_shape)
+        min_values = ops.min(x_reshaped, axis=1)
+        max_values = ops.max(x_reshaped, axis=1)

-    #
+    # Symmetric quantization: make range symmetric around zero
     if symmetric:
-
+        max_abs = ops.maximum(ops.abs(min_values), max_values)
         min_values = ops.where(
-            ops.less(min_values, 0), ops.negative(
+            ops.less(min_values, 0), ops.negative(max_abs), min_values
         )
+        max_values = max_abs

-    # Ensure range
+    # Ensure non-zero range to avoid division errors
     zero_range = ops.equal(min_values, max_values)
     min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
     max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)

+    # Compute scale and zero-point
     maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-    # Calculate scale and zero-point
     scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
     if symmetric:
         zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
     else:
         zero = ops.round(ops.divide(ops.negative(min_values), scale))

-    #
-
-
-
-    # Per-channel, non-grouped case: simple reshape is correct.
-    if per_channel and group_size == -1:
-        scale = ops.reshape(scale, [-1, 1])
-        zero = ops.reshape(zero, [-1, 1])
-    elif not per_channel:
-        num_rows = original_shape[0]
-        scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-        zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-    if per_channel:
+    # Reshape output to [out_features, n_groups] or [out_features, 1]
+    if n_groups > 1:
+        pass  # Already [out_features, n_groups]
+    elif per_channel:
         scale = ops.reshape(scale, [-1, 1])
         zero = ops.reshape(zero, [-1, 1])
+    else:
+        # Per-tensor: tile single value to [out_features, 1]
+        scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+        zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))

-
-
-    return scale, zero, maxq
+    return scale, ops.cast(zero, "uint8"), maxq


 def quantize_with_zero_point(input_tensor, scale, zero, maxq):
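To make the new grouped path concrete, here is a small numpy sketch of the same ceil-division, padding, and reduction steps. Assumptions: asymmetric mode, per_channel=True, group_size=4; the resulting scale/zero shapes match the docstring's [out_features, n_groups].

import numpy as np

bits, group_size = 4, 4
x = np.random.randn(8, 10).astype("float32")  # [out_features, in_features]
out_f, in_f = x.shape

n_groups = (in_f + group_size - 1) // group_size  # ceil division -> 3
pad = n_groups * group_size - in_f
x_p = np.pad(x, [(0, 0), (0, pad)])               # zero-pad the last group
g = x_p.reshape(out_f, n_groups, group_size)

min_v, max_v = g.min(axis=2), g.max(axis=2)
maxq = 2**bits - 1
scale = np.maximum((max_v - min_v) / maxq, 1e-8)
# Clipping added here for sketch robustness before the uint8 cast.
zero = np.clip(np.round(-min_v / scale), 0, maxq).astype("uint8")
print(scale.shape, zero.shape)  # (8, 3) (8, 3)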
keras/src/random/seed_generator.py
CHANGED
@@ -29,7 +29,7 @@ class SeedGenerator:
     a local `StateGenerator` with either a deterministic or random initial
     state.

-    Remark concerning the JAX
+    Remark concerning the JAX backend: Note that the use of a local
     `StateGenerator` as seed argument is required for JIT compilation of
     RNG with the JAX backend, because the use of global state is not
     supported.
@@ -111,7 +111,7 @@ class SeedGenerator:
         return new_seed_value

     def get_config(self):
-        return {"seed": self._initial_seed}
+        return {"seed": self._initial_seed, "name": self.name}

     @classmethod
     def from_config(cls, config):
keras/src/saving/file_editor.py
CHANGED
@@ -455,6 +455,9 @@ class KerasFileEditor:
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}

+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
         metadata[inner_path] = object_metadata

         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = inner_path + "/" + key
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue

+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"],
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value,
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-
-
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]

+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata

     def _generate_filepath_info(self, rich_style=False):
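The size guard multiplies the element count by the dtype's byte width and refuses anything over 4 GiB before the dataset is read into memory. A standalone sketch of the same checks, not the editor's code:

import numpy as np

MAX_BYTES = 1 << 32  # 4 GiB, matching the diff's limit

def check_dataset_safe(shape, dtype):
    # Reject negative dims, excessive rank, or oversized payloads.
    if any(dim < 0 for dim in shape):
        raise ValueError("negative dimension detected")
    if len(shape) > 64:
        raise ValueError("tensor rank exceeds safety limit")
    size_bytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    if size_bytes > MAX_BYTES:
        raise ValueError(f"{size_bytes} bytes exceeds {MAX_BYTES}")
    return size_bytes

print(check_dataset_safe((1024, 1024), "float32"))  # 4194304
try:
    check_dataset_safe((1 << 20, 1 << 20), "float32")  # ~4 TiB payload
except ValueError as e:
    print("rejected:", e)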
keras/src/saving/orbax_util.py
ADDED
@@ -0,0 +1,50 @@
+"""Orbax checkpoint loading functionality."""
+
+import os
+
+from keras.src.utils.module_utils import ocp
+
+
+def is_orbax_checkpoint(filepath):
+    """Check if the given path is an Orbax checkpoint directory.
+
+    This function implements custom detection logic instead of relying on
+    Orbax APIs which may be unreliable in some environments.
+    """
+    if not os.path.exists(filepath) or not os.path.isdir(filepath):
+        return False
+
+    try:
+        # List directory contents
+        contents = os.listdir(filepath)
+
+        # A set is more efficient for membership testing
+        orbax_indicators = {
+            "orbax.checkpoint",
+            "pytree.orbax-checkpoint",
+            "checkpoint_metadata",
+        }
+
+        # Fast check for standard files
+        if not orbax_indicators.isdisjoint(contents):
+            return True
+
+        # Check for step directories or temporary files in a single pass
+        return any(
+            ".orbax-checkpoint-tmp" in item
+            or (item.isdigit() and os.path.isdir(os.path.join(filepath, item)))
+            for item in contents
+        )
+
+    except (OSError, PermissionError):
+        # If we can't read the directory, assume it's not a checkpoint
+        return False
+
+
+def find_latest_orbax_checkpoint(checkpoint_dir):
+    """Find the latest checkpoint in an Orbax checkpoint directory."""
+    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+    latest = checkpointer.latest
+    if latest is None:
+        raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+    return os.path.join(checkpoint_dir, str(latest.step))
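A quick sketch of the detection heuristic, importing from the private module this diff adds: a directory containing one of the indicator entries counts as an Orbax checkpoint, an empty one does not.

import os
import tempfile

from keras.src.saving.orbax_util import is_orbax_checkpoint

with tempfile.TemporaryDirectory() as d:
    print(is_orbax_checkpoint(d))   # False: no indicator entries yet
    open(os.path.join(d, "checkpoint_metadata"), "w").close()
    print(is_orbax_checkpoint(d))   # True: indicator file present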