keras-nightly 3.14.0.dev2026010104__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +2 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +2 -0
- keras/_tf_keras/keras/quantizers/__init__.py +1 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +2 -0
- keras/ops/numpy/__init__.py +2 -0
- keras/quantizers/__init__.py +1 -0
- keras/src/backend/jax/nn.py +26 -9
- keras/src/backend/jax/numpy.py +10 -0
- keras/src/backend/numpy/numpy.py +15 -0
- keras/src/backend/openvino/numpy.py +338 -17
- keras/src/backend/tensorflow/numpy.py +24 -1
- keras/src/backend/tensorflow/rnn.py +17 -7
- keras/src/backend/torch/numpy.py +26 -0
- keras/src/backend/torch/rnn.py +28 -11
- keras/src/callbacks/orbax_checkpoint.py +75 -42
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/layers/core/dense.py +122 -6
- keras/src/layers/core/einsum_dense.py +151 -7
- keras/src/layers/core/embedding.py +1 -1
- keras/src/layers/core/reversible_embedding.py +10 -1
- keras/src/layers/layer.py +5 -0
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/losses/losses.py +24 -0
- keras/src/models/model.py +18 -9
- keras/src/ops/image.py +106 -93
- keras/src/ops/numpy.py +138 -0
- keras/src/quantizers/__init__.py +2 -0
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +1 -2
- keras/src/quantizers/gptq_core.py +1 -1
- keras/src/quantizers/quantization_config.py +14 -0
- keras/src/quantizers/quantizers.py +61 -52
- keras/src/random/seed_generator.py +2 -2
- keras/src/saving/orbax_util.py +50 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/utils/jax_layer.py +69 -31
- keras/src/utils/module_utils.py +11 -0
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +52 -48
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
- {keras_nightly-3.14.0.dev2026010104.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0

keras/src/quantizers/awq_core.py
ADDED

@@ -0,0 +1,217 @@
+"""AWQ core functionality for layer-wise quantization.
+
+This module provides the orchestration logic for applying AWQ quantization
+to transformer models in a layer-by-layer fashion.
+"""
+
+from contextlib import contextmanager
+
+from absl import logging
+
+from keras.src import ops
+from keras.src import utils as keras_utils
+from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
+from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
+from keras.src.quantizers.awq import AWQ
+from keras.src.quantizers.awq_config import AWQConfig
+from keras.src.quantizers.gptq_core import find_layers_in_block
+from keras.src.quantizers.gptq_core import get_dataloader
+from keras.src.quantizers.utils import should_quantize_layer
+
+
+@contextmanager
+def stream_activations(layers_map, awq_objects):
+    """Context manager to capture activations for AWQ calibration.
+
+    Temporarily patches layer.call methods to capture activation statistics
+    for computing per-channel scaling factors.
+
+    Args:
+        layers_map: Dict[str, Layer]. Mapping from layer names to layers.
+        awq_objects: Dict[str, AWQ]. Mapping from names to AWQ instances.
+
+    Yields:
+        None: The patched state is active only within the `with` block.
+    """
+    original_calls = {}
+
+    def create_hook(name, original_call_func):
+        def hook(*args, **kwargs):
+            inp = args[0] if args else kwargs["inputs"]
+            num_features = awq_objects[name].rows
+            input_2d = ops.reshape(inp, (-1, num_features))
+            awq_objects[name].update_activation_magnitudes(input_2d)
+            return original_call_func(*args, **kwargs)
+
+        return hook
+
+    try:
+        for name, layer in layers_map.items():
+            original_calls[name] = layer.call
+            layer.call = create_hook(name, layer.call)
+        yield
+    finally:
+        for name, layer in layers_map.items():
+            layer.call = original_calls[name]
+
+
+def apply_awq_layerwise(dataloader, config, structure, filters=None):
+    """Apply AWQ quantization layer-by-layer to a Keras model.
+
+    This function processes the model sequentially, one block at a time:
+    1. Captures activation statistics through calibration data forward pass
+    2. Uses activation magnitudes to determine weight saliency
+    3. Finds optimal per-channel scales via grid search
+    4. Quantizes weights with AWQ scaling
+
+    Args:
+        dataloader: Calibration data as numpy array.
+        config: AWQConfig instance.
+        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.
+        filters: Optional layer filters.
+    """
+    num_samples = config.num_samples
+    logging.info("Starting AWQ quantization...")
+
+    pre_layers = structure.get("pre_block_layers", [])
+    transformer_blocks = structure.get("sequential_blocks", [])
+
+    if not transformer_blocks:
+        raise ValueError(
+            "No sequential blocks found in the structure to quantize."
+        )
+
+    # Process inputs through pre-block layers (e.g., embedding)
+    inputs = []
+    for batch in dataloader:
+        batch = ops.convert_to_tensor(batch, dtype="int32")
+        for layer in pre_layers:
+            batch = layer(batch)
+        inputs.append(batch)
+
+    num_samples = min(num_samples, len(inputs))
+    progbar = keras_utils.Progbar(target=len(transformer_blocks))
+
+    for block_idx, block in enumerate(transformer_blocks):
+        logging.info(f"Quantizing Block {block_idx}")
+        sub_layers_map = find_layers_in_block(block)
+
+        # Apply filters
+        final_sub_layers_map = {}
+        for name, layer in sub_layers_map.items():
+            if not should_quantize_layer(layer, filters):
+                continue
+            final_sub_layers_map[name] = layer
+
+        sub_layers_map = final_sub_layers_map
+
+        if not sub_layers_map:
+            logging.info(
+                f" No quantizable layers found in block {block_idx}. Skipping."
+            )
+        else:
+            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
+
+            # Create AWQ objects for each layer
+            awq_objects = {
+                name: AWQ(layer, config)
+                for name, layer in sub_layers_map.items()
+            }
+
+            # Capture activation statistics
+            with stream_activations(sub_layers_map, awq_objects):
+                for sample_idx in range(num_samples):
+                    current_input = inputs[sample_idx]
+                    if len(current_input.shape) == 2:
+                        current_input = ops.expand_dims(current_input, axis=0)
+                    _ = block(current_input)
+
+            # Quantize each layer
+            for name, awq_object in awq_objects.items():
+                logging.info(f"Quantizing {name}...")
+                awq_object.quantize_layer()
+                awq_object.free()
+
+            del awq_objects
+
+        # Generate inputs for next block
+        if block_idx < len(transformer_blocks) - 1:
+            logging.info(f"Generating inputs for block {block_idx + 1}...")
+            next_block_inputs = []
+            for sample_idx in range(num_samples):
+                current_input = inputs[sample_idx]
+                if len(current_input.shape) == 2:
+                    current_input = ops.expand_dims(current_input, axis=0)
+                output = block(current_input)[0]
+                next_block_inputs.append(output)
+            inputs = next_block_inputs
+
+        progbar.update(current=block_idx + 1)
+
+    logging.info("AWQ quantization complete.")
+
+
+def awq_quantize(config, quantization_layer_structure, filters=None):
+    """Main entry point for AWQ quantization.
+
+    Args:
+        config: AWQConfig instance.
+        quantization_layer_structure: Model structure dictionary.
+        filters: Optional layer filters.
+    """
+    if config.dataset is None or config.tokenizer is None:
+        raise ValueError(
+            "AWQ quantization requires a dataset and tokenizer. "
+            "Please provide them in the AWQConfig."
+        )
+
+    if quantization_layer_structure is None:
+        raise ValueError(
+            "For 'awq' mode, a valid quantization structure must be provided "
+            "either via `config.quantization_layer_structure` or by overriding "
+            "`model.get_quantization_layer_structure(mode)`. The structure "
+            "should be a dictionary with keys 'pre_block_layers' and "
+            "'sequential_blocks'."
+        )
+
+    # Load calibration data
+    dataloader = get_dataloader(
+        config.tokenizer,
+        config.sequence_length,
+        config.dataset,
+        num_samples=config.num_samples,
+    )
+
+    apply_awq_layerwise(
+        dataloader[: config.num_samples],
+        config,
+        quantization_layer_structure,
+        filters=filters,
+    )
+
+
+def get_group_size_for_layer(layer, config):
+    """Get group size from config or dtype policy.
+
+    Args:
+        layer: The layer to get group size for.
+        config: Optional AWQConfig instance.
+
+    Returns:
+        int: The group size for quantization.
+
+    Raises:
+        ValueError: If group size cannot be determined.
+    """
+    if config and isinstance(config, AWQConfig):
+        return config.group_size
+    elif isinstance(layer.dtype_policy, AWQDTypePolicy):
+        return layer.dtype_policy.group_size
+    elif isinstance(layer.dtype_policy, DTypePolicyMap):
+        policy = layer.dtype_policy[layer.path]
+        if isinstance(policy, AWQDTypePolicy):
+            return policy.group_size
+    raise ValueError(
+        "For AWQ quantization, group_size must be specified "
+        "through AWQConfig or AWQDTypePolicy."
+    )
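
The `stream_activations` helper above works by swapping each target layer's `call` for a hook that records incoming activations before delegating to the original method. Below is a minimal, self-contained sketch of that mechanism using a stub recorder in place of the real `AWQ` object (which exposes `rows` and `update_activation_magnitudes`); the stub class and the Dense layer here are illustrative assumptions, not part of the diff.

```python
import numpy as np
from keras import layers, ops
from keras.src.quantizers.awq_core import stream_activations


class MagnitudeRecorder:
    """Stand-in for AWQ: accumulates mean |activation| per input feature."""

    def __init__(self, num_features):
        self.rows = num_features  # the hook reads `rows` to reshape inputs
        self.magnitudes = np.zeros((num_features,), dtype="float32")

    def update_activation_magnitudes(self, input_2d):
        self.magnitudes += ops.convert_to_numpy(
            ops.mean(ops.abs(input_2d), axis=0)
        )


dense = layers.Dense(8)
dense.build((None, 16))

recorders = {"dense": MagnitudeRecorder(num_features=16)}
with stream_activations({"dense": dense}, recorders):
    # While the context manager is active, every call is observed.
    _ = dense(np.random.rand(4, 16).astype("float32"))

print(recorders["dense"].magnitudes.shape)  # expected: (16,)
```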
keras/src/quantizers/gptq.py
CHANGED

@@ -1,5 +1,4 @@
 import types
-from functools import partial
 
 from keras.src import ops
 from keras.src import quantizers
@@ -466,7 +465,7 @@ class GPTQ:
             group_size=self.config.group_size,
             activation_order=self.config.activation_order,
             order_metric=ops.diagonal(hessian_matrix),
-            compute_scale_zero=
+            compute_scale_zero=self.quantizer.find_params,
         )
         quantized = ops.cast(
             quantized, self.original_layer.quantized_kernel.dtype

keras/src/quantizers/gptq_core.py
CHANGED

@@ -131,7 +131,7 @@ def get_dataloader(
     pieces = []
     if isinstance(dataset_list[0], str):
         for i, s in enumerate(dataset_list):
-            toks =
+            toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)
             pieces.append(toks)
             # avoid windows that span document boundaries
             if eos_id is not None and i < len(dataset_list) - 1:
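
The corrected line flattens whatever the tokenizer returns into a 1D token stream before the pieces are concatenated. A tiny illustration of that flattening step (the batched array below is a stand-in for a real tokenizer's output, which is often shaped `[1, seq_len]`):

```python
import numpy as np
from keras import ops

batched_tokens = np.array([[101, 2023, 2003, 102]])  # stand-in tokenizer output
flat = ops.convert_to_numpy(batched_tokens).reshape(-1)
print(flat.shape)  # (4,)
```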

keras/src/quantizers/quantization_config.py
CHANGED

@@ -182,6 +182,11 @@ def validate_and_resolve_config(mode, config):
                "For GPTQ, you must pass a `GPTQConfig` object in the "
                "`config` argument."
            )
+        elif mode == "awq":
+            raise ValueError(
+                "For AWQ, you must pass an `AWQConfig` object in the "
+                "`config` argument."
+            )
        else:
            if mode is not None:
                raise ValueError(
@@ -220,6 +225,15 @@ def validate_and_resolve_config(mode, config):
                f"`GPTQConfig`. Received: {type(config)}"
            )
 
+    if mode == "awq":
+        from keras.src.quantizers.awq_config import AWQConfig
+
+        if not isinstance(config, AWQConfig):
+            raise ValueError(
+                "Mode 'awq' requires a valid `config` argument of type "
+                f"`AWQConfig`. Received: {type(config)}"
+            )
+
    return config
 
 
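
Together with the new AWQ mode, the validation above means requesting AWQ without a config now fails fast with a clear message. A hedged sketch of that behavior through the public `Model.quantize` entry point (it assumes the mode/config validation runs before any layer is touched, so the tiny model here never needs calibration data):

```python
import keras

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

try:
    model.quantize("awq")  # no AWQConfig supplied
except ValueError as e:
    # expected: "For AWQ, you must pass an `AWQConfig` object in the `config` argument."
    print(e)
```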

keras/src/quantizers/quantizers.py
CHANGED

@@ -653,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
         )
 
     def to_signed(x):
-        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+        """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+        Uses a branchless XOR approach: (x ^ 8) - 8
+        This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+        """
         dtype_x = backend.standardize_dtype(x.dtype)
         eight = ops.cast(8, dtype_x)
-
-        return ops.where(x < eight, x, x - sixteen)
+        return ops.subtract(ops.bitwise_xor(x, eight), eight)
 
     rank = getattr(packed.shape, "rank", None) or len(packed.shape)
     if axis < 0:
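
The new `to_signed` relies on the fact that XOR-ing a 4-bit value with 8 flips its sign bit, so `(x ^ 8) - 8` reinterprets an unsigned nibble as two's-complement int4 without a branch. A quick self-check in plain Python, comparing it against the previous `where(x < 8, x, x - 16)` formulation:

```python
for x in range(16):
    branchless = (x ^ 8) - 8
    reference = x if x < 8 else x - 16  # the old formulation
    assert branchless == reference
    print(f"{x:2d} -> {branchless}")
```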
@@ -748,7 +751,7 @@ class GPTQQuantizer(Quantizer):
         self.zero = None
         self.maxq = None
 
-    def find_params(self, input_tensor
+    def find_params(self, input_tensor):
         """Finds quantization parameters (scale and zero) for a given tensor."""
         self.scale, self.zero, self.maxq = compute_quantization_parameters(
             input_tensor,
@@ -756,7 +759,6 @@ class GPTQQuantizer(Quantizer):
             symmetric=self.symmetric,
             per_channel=self.per_channel,
             group_size=self.group_size,
-            weight=weight,
             compute_dtype=self.compute_dtype,
         )
         return self.scale, self.zero, self.maxq
@@ -793,98 +795,105 @@ def compute_quantization_parameters(
     symmetric=False,
     per_channel=False,
     group_size=-1,
-    weight=False,
     compute_dtype="float32",
 ):
     """
-    Computes the scale and zero-point for
+    Computes the scale and zero-point for quantizing weight tensors.
 
     This function calculates the scale and zero-point required for quantizing
-    a given tensor `x` based on the specified parameters. It supports
-    per-channel, per-tensor, symmetric, and asymmetric quantization
-
+    a given weight tensor `x` based on the specified parameters. It supports
+    grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+    For grouped quantization (per_channel=True, group_size > 0), the output
+    shapes are [out_features, n_groups] where n_groups is the number of groups
+    along the in_features dimension.
 
     Args:
-        x: KerasTensor. The
+        x: KerasTensor. The weight tensor to quantize with shape
+            [out_features, in_features].
         bits: int. The number of bits to quantize to (e.g., 4).
         symmetric: bool. Whether to use symmetric quantization.
         per_channel: bool. Whether to quantize per channel.
-        group_size: int. The group size for quantization.
-
+        group_size: int. The group size for quantization. -1 means no grouping.
+        compute_dtype: str. The dtype for computation. Defaults to "float32".
 
     Returns:
        scale: KerasTensor. The scale tensor for quantization.
        zero: KerasTensor. The zero tensor for quantization.
        maxq: scalar. The maximum quantization value.
    """
+    # Input validation
    if x is None:
        raise ValueError(f"Input tensor {x} cannot be None.")
-
-    # For weights, we typically expect at least a 2D tensor.
-    if weight and len(x.shape) < 2:
+    if len(x.shape) < 2:
        raise ValueError(
            f"Input weight tensor {x} must have a rank of at "
            f"least 2, but got rank {len(x.shape)}."
        )
-
    if ops.size(x) == 0:
        raise ValueError("Input tensor 'x' cannot be empty.")
 
-
-
-    if per_channel:
-        if weight:
-            if group_size != -1:
-                input_reshaped = ops.reshape(x, [-1, group_size])
-            else:
-                input_reshaped = ops.reshape(x, [original_shape[0], -1])
-    else:  # per-tensor
-        input_reshaped = ops.reshape(x, [1, -1])
+    out_features, in_features = x.shape[0], x.shape[1]
 
-    #
-
-
+    # Determine number of groups for quantization
+    if per_channel and group_size > 0:
+        n_groups = (in_features + group_size - 1) // group_size
+    else:
+        n_groups = 1
+
+    # Compute min/max values based on quantization mode
+    if n_groups > 1:
+        # Grouped quantization: output shape [out_features, n_groups]
+        remainder = in_features % group_size
+        if remainder != 0:
+            pad_size = group_size - remainder
+            x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+        x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+        min_values = ops.min(x_grouped, axis=2)
+        max_values = ops.max(x_grouped, axis=2)
+    else:
+        # Per-channel or per-tensor: compute stats along rows
+        reduction_shape = [out_features, -1] if per_channel else [1, -1]
+        x_reshaped = ops.reshape(x, reduction_shape)
+        min_values = ops.min(x_reshaped, axis=1)
+        max_values = ops.max(x_reshaped, axis=1)
 
-    #
+    # Symmetric quantization: make range symmetric around zero
    if symmetric:
-
+        max_abs = ops.maximum(ops.abs(min_values), max_values)
        min_values = ops.where(
-            ops.less(min_values, 0), ops.negative(
+            ops.less(min_values, 0), ops.negative(max_abs), min_values
        )
+        max_values = max_abs
 
-    # Ensure range
+    # Ensure non-zero range to avoid division errors
    zero_range = ops.equal(min_values, max_values)
    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
    max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
 
+    # Compute scale and zero-point
    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-    # Calculate scale and zero-point
    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
    if symmetric:
        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
    else:
        zero = ops.round(ops.divide(ops.negative(min_values), scale))
 
-    #
-
-
-
-    # Per-channel, non-grouped case: simple reshape is correct.
-    if per_channel and group_size == -1:
-        scale = ops.reshape(scale, [-1, 1])
-        zero = ops.reshape(zero, [-1, 1])
-    elif not per_channel:
-        num_rows = original_shape[0]
-        scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-        zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-    if per_channel:
+    # Reshape output to [out_features, n_groups] or [out_features, 1]
+    if n_groups > 1:
+        pass  # Already [out_features, n_groups]
+    elif per_channel:
        scale = ops.reshape(scale, [-1, 1])
        zero = ops.reshape(zero, [-1, 1])
+    else:
+        # Per-tensor: tile single value to [out_features, 1]
+        scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+        zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
 
-
-
-    return scale, zero, maxq
+    return scale, ops.cast(zero, "uint8"), maxq
 
 
 def quantize_with_zero_point(input_tensor, scale, zero, maxq):
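
For the grouped path (`per_channel=True` with a positive `group_size`), the function now pads `in_features` up to a multiple of `group_size`, reshapes to `[out_features, n_groups, group_size]`, and reduces over the last axis. A simplified, self-contained NumPy sketch of that asymmetric case (it omits the equal-min/max guard and uses hypothetical shapes):

```python
import numpy as np


def grouped_scale_zero(w, bits=4, group_size=32):
    """Per-group asymmetric scale/zero for a [out_features, in_features] weight."""
    out_features, in_features = w.shape
    n_groups = (in_features + group_size - 1) // group_size
    pad = n_groups * group_size - in_features
    if pad:
        w = np.pad(w, [(0, 0), (0, pad)])  # zero-pad the last group
    grouped = w.reshape(out_features, n_groups, group_size)
    min_v = grouped.min(axis=2)
    max_v = grouped.max(axis=2)
    maxq = 2**bits - 1
    scale = np.maximum((max_v - min_v) / maxq, 1e-8)
    zero = np.round(-min_v / scale).astype("uint8")
    return scale, zero  # both shaped [out_features, n_groups]


w = np.random.randn(8, 70).astype("float32")
scale, zero = grouped_scale_zero(w)
print(scale.shape, zero.shape)  # (8, 3) (8, 3) with group_size=32
```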

keras/src/random/seed_generator.py
CHANGED

@@ -29,7 +29,7 @@ class SeedGenerator:
     a local `StateGenerator` with either a deterministic or random initial
     state.
 
-    Remark concerning the JAX
+    Remark concerning the JAX backend: Note that the use of a local
     `StateGenerator` as seed argument is required for JIT compilation of
     RNG with the JAX backend, because the use of global state is not
     supported.
@@ -111,7 +111,7 @@ class SeedGenerator:
         return new_seed_value
 
     def get_config(self):
-        return {"seed": self._initial_seed}
+        return {"seed": self._initial_seed, "name": self.name}
 
     @classmethod
    def from_config(cls, config):
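
With the name now included in `get_config`, a `SeedGenerator` survives a config round-trip without losing its identity. A short sketch (it assumes `keras.random.SeedGenerator` accepts a `name` argument, which `from_config` implies):

```python
import keras

gen = keras.random.SeedGenerator(seed=1337, name="dropout_seed")
config = gen.get_config()
print(config)  # expected: {'seed': 1337, 'name': 'dropout_seed'}

restored = keras.random.SeedGenerator.from_config(config)
assert restored.name == "dropout_seed"
```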

keras/src/saving/orbax_util.py
ADDED

@@ -0,0 +1,50 @@
+"""Orbax checkpoint loading functionality."""
+
+import os
+
+from keras.src.utils.module_utils import ocp
+
+
+def is_orbax_checkpoint(filepath):
+    """Check if the given path is an Orbax checkpoint directory.
+
+    This function implements custom detection logic instead of relying on
+    Orbax APIs which may be unreliable in some environments.
+    """
+    if not os.path.exists(filepath) or not os.path.isdir(filepath):
+        return False
+
+    try:
+        # List directory contents
+        contents = os.listdir(filepath)
+
+        # A set is more efficient for membership testing
+        orbax_indicators = {
+            "orbax.checkpoint",
+            "pytree.orbax-checkpoint",
+            "checkpoint_metadata",
+        }
+
+        # Fast check for standard files
+        if not orbax_indicators.isdisjoint(contents):
+            return True
+
+        # Check for step directories or temporary files in a single pass
+        return any(
+            ".orbax-checkpoint-tmp" in item
+            or (item.isdigit() and os.path.isdir(os.path.join(filepath, item)))
+            for item in contents
+        )
+
+    except (OSError, PermissionError):
+        # If we can't read the directory, assume it's not a checkpoint
+        return False
+
+
+def find_latest_orbax_checkpoint(checkpoint_dir):
+    """Find the latest checkpoint in an Orbax checkpoint directory."""
+    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+    latest = checkpointer.latest
+    if latest is None:
+        raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+    return os.path.join(checkpoint_dir, str(latest.step))
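
The detection heuristic treats a directory as an Orbax checkpoint when it contains known marker files or purely numeric step subdirectories. A throwaway-directory sketch of that behavior (no real checkpoint is needed for the step-directory case):

```python
import os
import tempfile

from keras.src.saving.orbax_util import is_orbax_checkpoint

with tempfile.TemporaryDirectory() as root:
    print(is_orbax_checkpoint(root))  # False: empty directory

    os.mkdir(os.path.join(root, "42"))  # numeric name mimics a step directory
    print(is_orbax_checkpoint(root))  # True
```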
keras/src/saving/saving_api.py
CHANGED

@@ -6,13 +6,11 @@ from absl import logging
 from keras.src.api_export import keras_export
 from keras.src.legacy.saving import legacy_h5_format
 from keras.src.saving import saving_lib
+from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
+from keras.src.saving.orbax_util import is_orbax_checkpoint
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
-
-try:
-    import h5py
-except ImportError:
-    h5py = None
+from keras.src.utils.module_utils import h5py
 
 
 @keras_export(["keras.saving.save_model", "keras.models.save_model"])
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
-    x = np.random.random((10, 3))
-    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```
 
    Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` files
+            "Keras 3 only supports V3 `.keras` files, "
            "legacy H5 format files (`.h5` extension). "
            "Note that the legacy SavedModel format is not "
            "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
            objects_to_skip=objects_to_skip,
        )
    elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-        if not h5py:
-            raise ImportError(
-                "Loading a H5 file requires `h5py` to be installed."
-            )
        if objects_to_skip is not None:
            raise ValueError(
                "`objects_to_skip` only supports loading '.weights.h5' files."
                f"Received: {filepath}"
            )
+        if not h5py.available:
+            raise ImportError(
+                "Loading HDF5 files requires the h5py package. "
+                "You can install it via `pip install h5py`"
+            )
        with h5py.File(filepath, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
        legacy_h5_format.load_weights_from_hdf5_group(
            f, model, skip_mismatch
        )
+    elif is_orbax_checkpoint(filepath):
+        # Load weights from Orbax checkpoint
+        from keras.src.utils.module_utils import ocp
+
+        filepath = str(filepath)
+
+        # Determine if this is a root directory or a step directory
+        items = os.listdir(filepath)
+        has_step_subdirs = any(
+            os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
+            for item in items
+        )
+
+        if has_step_subdirs:
+            # It's a root directory, find the latest checkpoint
+            checkpoint_path = find_latest_orbax_checkpoint(filepath)
+        else:
+            # It's a step directory, use it directly
+            checkpoint_path = filepath
+
+        # Load checkpoint
+        loaded_state = ocp.load_pytree(checkpoint_path)
+
+        # Set the model state directly from the loaded state
+        model.set_state_tree(loaded_state)
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras`
-            "files,
+            "Keras 3 only supports V3 `.keras` files, "
+            "`.weights.h5` files, legacy H5 format files "
+            "(`.h5` extension), or Orbax checkpoints."
        )