keras-nightly 3.14.0.dev2026012704-py3-none-any.whl → 3.14.0.dev2026012904-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/ops/__init__.py +1 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/ops/__init__.py +1 -0
- keras/ops/numpy/__init__.py +1 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/backend/jax/numpy.py +5 -0
- keras/src/backend/numpy/numpy.py +5 -0
- keras/src/backend/openvino/numpy.py +6 -0
- keras/src/backend/tensorflow/numpy.py +21 -0
- keras/src/backend/torch/numpy.py +10 -0
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/ops/numpy.py +54 -0
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/file_editor.py +7 -1
- keras/src/saving/saving_api.py +66 -2
- keras/src/saving/saving_lib.py +46 -47
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +34 -34
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0

keras/src/quantizers/quantizers.py
CHANGED

@@ -1,3 +1,5 @@
+import math
+
 import ml_dtypes
 import numpy as np
 
@@ -118,6 +120,190 @@ def abs_max_quantize(
     return outputs, scale
 
 
+@keras_export("keras.quantizers.abs_max_quantize_grouped_with_zero_point")
+def abs_max_quantize_grouped_with_zero_point(
+    inputs,
+    block_size,
+    value_range=(-8, 7),
+    dtype="int8",
+    epsilon=backend.epsilon(),
+    to_numpy=False,
+):
+    """Quantizes a 2D tensor using grouped asymmetric quantization with
+    zero point.
+
+    Groups are formed along axis 0 (the input/contracting dimension).
+    Each group of `block_size` rows gets its own scale factor and zero point
+    per column. This is useful for weight distributions that are not centered
+    around zero.
+
+    Args:
+        inputs: Input tensor to quantize. Shape: `(input_dim, output_dim)`.
+        block_size: Number of elements per group along axis 0.
+        value_range: Tuple of `(min, max)` quantization range.
+        dtype: Data type of quantized output.
+        epsilon: Small value to avoid division by zero.
+        to_numpy: Whether to perform computation in numpy for memory
+            efficiency.
+
+    Returns:
+        A tuple `(quantized_tensor, scale, zero_point)` where:
+        - `quantized_tensor`: Same shape as inputs, dtype=`dtype`.
+        - `scale`: Shape `(n_groups, output_dim)` where
+          `n_groups = ceil(input_dim / block_size)`.
+        - `zero_point`: Shape `(n_groups, output_dim)`, dtype=`int8`.
+
+    Example:
+
+    ```python
+    >>> import numpy as np
+    >>> from keras.quantizers import abs_max_quantize_grouped_with_zero_point
+    >>> kernel = np.random.randn(512, 256).astype("float32")
+    >>> quantized, scale, zero_point = abs_max_quantize_grouped_with_zero_point(
+    ...     kernel, block_size=128, value_range=(-8, 7)
+    ... )
+    >>> quantized.shape
+    (512, 256)
+    >>> scale.shape  # 512 / 128 = 4 groups
+    (4, 256)
+    >>> zero_point.shape
+    (4, 256)
+    ```
+    """
+    if to_numpy:
+        return _abs_max_quantize_grouped_with_zero_point_numpy(
+            inputs, block_size, value_range, dtype, epsilon
+        )
+    return _abs_max_quantize_grouped_with_zero_point_tensor(
+        inputs, block_size, value_range, dtype, epsilon
+    )
+
+
+def _abs_max_quantize_grouped_with_zero_point_numpy(
+    inputs, block_size, value_range, dtype, epsilon
+):
+    """NumPy implementation of grouped asymmetric quantization.
+
+    Uses NumPy for computation to reduce GPU memory usage during
+    model quantization.
+    """
+    original_dtype = backend.standardize_dtype(inputs.dtype)
+    inputs = ops.convert_to_numpy(inputs)
+
+    input_dim, output_dim = inputs.shape
+    n_groups = math.ceil(input_dim / block_size)
+    qmin, qmax = value_range
+
+    # Zero-pad rows so input_dim is divisible by block_size
+    padded_input_dim = n_groups * block_size
+    if padded_input_dim > input_dim:
+        padding = np.zeros(
+            (padded_input_dim - input_dim, output_dim), dtype=inputs.dtype
+        )
+        inputs_padded = np.concatenate([inputs, padding], axis=0)
+    else:
+        inputs_padded = inputs
+
+    inputs_reshaped = inputs_padded.reshape(n_groups, block_size, output_dim)
+
+    # Compute per-group min/max for asymmetric quantization
+    min_val = np.min(inputs_reshaped, axis=1, keepdims=True)
+    max_val = np.max(inputs_reshaped, axis=1, keepdims=True)
+
+    # Scale maps the [min, max] range to [qmin, qmax]
+    scale = np.divide(np.subtract(max_val, min_val) + epsilon, qmax - qmin)
+
+    # Zero point shifts the quantized range to include the original zero
+    zero_point = np.round(np.divide(-min_val, scale)) + qmin
+    zero_point = np.clip(zero_point, qmin, qmax)
+
+    # Quantize: q = round(input / scale) + zero_point
+    outputs = np.round(np.divide(inputs_reshaped, scale)) + zero_point
+    outputs = np.clip(outputs, qmin, qmax)
+    outputs = outputs.astype(dtype)
+
+    # Remove padding and squeeze to (n_groups, output_dim)
+    outputs = outputs.reshape(padded_input_dim, output_dim)[:input_dim, :]
+    scale = np.squeeze(scale, axis=1)
+    zero_point = np.squeeze(zero_point, axis=1).astype("int8")
+
+    return (
+        ops.convert_to_tensor(outputs),
+        ops.convert_to_tensor(scale, dtype=original_dtype),
+        ops.convert_to_tensor(zero_point),
+    )
+
+
+def _abs_max_quantize_grouped_with_zero_point_tensor(
+    inputs, block_size, value_range, dtype, epsilon
+):
+    """Tensor backend implementation of grouped asymmetric quantization."""
+    original_dtype = backend.standardize_dtype(inputs.dtype)
+    inputs = ops.convert_to_tensor(inputs)
+
+    input_shape = ops.shape(inputs)
+    input_dim = input_shape[0]
+    output_dim = input_shape[1]
+    qmin, qmax = value_range
+
+    # Infer bit-width from quantization range (e.g., [-8, 7] -> 4 bits)
+    num_levels = qmax - qmin + 1
+    bits = int(math.log2(num_levels))
+
+    n_groups = int(math.ceil(int(input_dim) / block_size))
+    padded_input_dim = n_groups * block_size
+
+    # Transpose to [out_features, in_features] for
+    # compute_quantization_parameters
+    inputs_t = ops.transpose(inputs)
+
+    # Compute scale and zero point using the unified quantization function
+    scale_t, zero_point_t, _ = compute_quantization_parameters(
+        inputs_t,
+        bits=bits,
+        symmetric=False,
+        per_channel=True,
+        group_size=block_size,
+        compute_dtype=original_dtype,
+        epsilon=epsilon,
+        signed=True,
+    )
+
+    # Transpose results back to (n_groups, output_dim)
+    scale = ops.transpose(scale_t)
+    zero_point = ops.transpose(zero_point_t)
+
+    # Zero-pad rows so input_dim is divisible by block_size
+    pad_size = padded_input_dim - int(input_dim)
+    if pad_size > 0:
+        padding = ops.zeros((pad_size, output_dim), dtype=inputs.dtype)
+        inputs_padded = ops.concatenate([inputs, padding], axis=0)
+    else:
+        inputs_padded = inputs
+
+    inputs_reshaped = ops.reshape(
+        inputs_padded, (n_groups, block_size, output_dim)
+    )
+
+    # Expand scale and zero_point for broadcasting across block_size
+    scale_expanded = ops.expand_dims(scale, axis=1)
+    zero_point_expanded = ops.expand_dims(zero_point, axis=1)
+
+    # Quantize: q = round(input / scale) + zero_point
+    outputs = ops.add(
+        ops.round(ops.divide(inputs_reshaped, scale_expanded)),
+        zero_point_expanded,
+    )
+    outputs = ops.clip(outputs, qmin, qmax)
+    outputs = ops.cast(outputs, dtype)
+
+    # Remove padding
+    outputs = ops.reshape(outputs, (padded_input_dim, output_dim))
+    outputs = outputs[:input_dim, :]
+
+    return outputs, scale, zero_point
+
+
 @keras_export("keras.quantizers.AbsMaxQuantizer")
 class AbsMaxQuantizer(Quantizer):
     def __init__(
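For intuition, here is a standalone NumPy sketch that mirrors the math of the NumPy path above (group-wise min/max, scale, zero point, and the dequantization round trip); the helper names are illustrative, not part of the Keras API:

```python
import numpy as np


def quantize_grouped(x, block_size, qmin=-8, qmax=7, eps=1e-7):
    """Per-group asymmetric quantization along axis 0 (toy version)."""
    n_groups = x.shape[0] // block_size  # assumes exact divisibility
    g = x.reshape(n_groups, block_size, -1)
    lo = g.min(axis=1, keepdims=True)
    hi = g.max(axis=1, keepdims=True)
    scale = (hi - lo + eps) / (qmax - qmin)  # maps [lo, hi] onto [qmin, qmax]
    zero = np.clip(np.round(-lo / scale) + qmin, qmin, qmax)
    q = np.clip(np.round(g / scale) + zero, qmin, qmax)
    return (
        q.reshape(x.shape).astype("int8"),
        scale.squeeze(1),
        zero.squeeze(1).astype("int8"),
    )


def dequantize_grouped(q, scale, zero, block_size):
    """Invert the mapping: x_hat = (q - zero) * scale, per group."""
    n_groups = scale.shape[0]
    g = q.reshape(n_groups, block_size, -1).astype("float32")
    return ((g - zero[:, None, :]) * scale[:, None, :]).reshape(q.shape)


rng = np.random.default_rng(0)
w = rng.normal(size=(256, 64)).astype("float32") + 0.5  # not centered at zero
q, scale, zero = quantize_grouped(w, block_size=64)
w_hat = dequantize_grouped(q, scale, zero, block_size=64)
print(q.dtype, scale.shape, zero.shape)  # int8 (4, 64) (4, 64)
print(float(np.abs(w - w_hat).max()))  # small, roughly scale / 2 per group
```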
@@ -796,6 +982,8 @@ def compute_quantization_parameters(
     per_channel=False,
     group_size=-1,
     compute_dtype="float32",
+    epsilon=0.0,
+    signed=False,
 ):
     """
     Computes the scale and zero-point for quantizing weight tensors.

@@ -816,10 +1004,17 @@ def compute_quantization_parameters(
         per_channel: bool. Whether to quantize per channel.
         group_size: int. The group size for quantization. -1 means no grouping.
         compute_dtype: str. The dtype for computation. Defaults to "float32".
+        epsilon: float. Small value added to (max - min) before computing
+            scale to avoid division by zero. Defaults to 0.0.
+        signed: bool. Whether to use signed quantization range. If True, uses
+            range [-2^(bits-1), 2^(bits-1)-1] (e.g., [-8, 7] for 4-bit).
+            If False, uses range [0, 2^bits-1] (e.g., [0, 15] for 4-bit).
+            Defaults to False.
 
     Returns:
         scale: KerasTensor. The scale tensor for quantization.
-        zero: KerasTensor. The zero tensor for quantization
+        zero: KerasTensor. The zero tensor for quantization (int8 if signed,
+            uint8 if unsigned).
         maxq: scalar. The maximum quantization value.
     """
     # Input validation
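As a quick reference, the two ranges the new `signed` flag selects can be computed with a few lines of plain Python:

```python
def quant_range(bits, signed):
    # Signed: [-2^(bits-1), 2^(bits-1) - 1]; unsigned: [0, 2^bits - 1].
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2**bits - 1


print(quant_range(4, signed=True))   # (-8, 7)
print(quant_range(4, signed=False))  # (0, 15)
print(quant_range(8, signed=True))   # (-128, 127)
```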
@@ -874,13 +1069,31 @@ def compute_quantization_parameters(
 
     # Compute scale and zero-point
     maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+    range_values = ops.subtract(max_values, min_values)
+    if epsilon > 0:
+        range_values = ops.add(range_values, epsilon)
+    scale = ops.divide(range_values, maxq)
     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
 
-    if symmetric:
-        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
+    # Compute zero-point based on signed/unsigned mode
+    if signed:
+        # For signed range [-2^(bits-1), 2^(bits-1)-1], e.g., [-8, 7] for 4-bit
+        qmin = -(2 ** (bits - 1))  # e.g., -8 for 4-bit
+        qmax_signed = 2 ** (bits - 1) - 1  # e.g., 7 for 4-bit
+        if symmetric:
+            zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2) + qmin)
+        else:
+            # zero_signed = round(-min / scale) + qmin
+            zero = ops.add(
+                ops.round(ops.divide(ops.negative(min_values), scale)), qmin
+            )
+        zero = ops.clip(zero, qmin, qmax_signed)
     else:
-        zero = ops.round(ops.divide(ops.negative(min_values), scale))
+        # For unsigned range [0, 2^bits-1], e.g., [0, 15] for 4-bit
+        if symmetric:
+            zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
+        else:
+            zero = ops.round(ops.divide(ops.negative(min_values), scale))
 
     # Reshape output to [out_features, n_groups] or [out_features, 1]
     if n_groups > 1:

@@ -893,7 +1106,8 @@ def compute_quantization_parameters(
         scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
         zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
 
-    return scale, zero, maxq
+    zero_dtype = "int8" if signed else "uint8"
+    return scale, ops.cast(zero, zero_dtype), maxq
@@ -942,51 +1156,67 @@ def dequantize_with_zero_point(input_tensor, scale, zero):
     )
 
 
-def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq):
+def quantize_with_sz_map(
+    weights_matrix, scale, zero, g_idx, maxq, group_axis=-1
+):
     """Quantize the weight matrix from group params.
 
     This function uses the provided scale and zero tensors to quantize the
-    input weights_matrix according to the group indices. It maps each column
-    of the weights_matrix to its corresponding group parameters and performs
-    the quantization operation.
+    input weights_matrix according to the group indices. It maps each position
+    along group_axis of the weights_matrix to its corresponding group
+    parameters and performs the quantization operation.
 
     Args:
-        weights_matrix:
-        scale: Per-group scale tensor
-        zero: Per-group zero-point tensor
-        g_idx:
-
+        weights_matrix: Tensor to quantize.
+        scale: Per-group scale tensor with n_groups along group_axis.
+        zero: Per-group zero-point tensor with n_groups along group_axis.
+        g_idx: 1D integer tensor of length equal to the size of
+            `weights_matrix` along the dimension being quantized. Each
+            element specifies which group index (0 to n_groups-1) that
+            position belongs to. For example, with 128 columns and
+            group_size=32, g_idx would be
+            `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).
         maxq: Scalar (float) representing the maximum integer quantization
             level (e.g., 2^bits - 1).
+        group_axis: The axis in `scale` and `zero` along which to index
+            using `g_idx`. This determines which dimension of the
+            scale/zero tensors contains the per-group values. Default: -1
+            (last axis).
 
     Returns:
         A tensor with the same shape as `weights_matrix` containing the
         quantized weights produced using the provided group parameters.
     """
     groups = ops.cast(g_idx, "int32")
-    scale_cols = ops.take(scale, groups, axis=-1)
-    zero_cols = ops.take(zero, groups, axis=-1)
+    scale_cols = ops.take(scale, groups, axis=group_axis)
+    zero_cols = ops.take(zero, groups, axis=group_axis)
 
     # Quantize elementwise, then cast to int
     return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)
 
 
-def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
+def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx, group_axis=-1):
     """Rebuild a dequantized weight matrix from group params.
 
     This function uses the provided scale and zero tensors to dequantize the
-    input weights_matrix according to the group indices. It maps each column
-    of the weights_matrix to its corresponding group parameters and performs
-    the dequantization operation.
+    input weights_matrix according to the group indices. It maps each position
+    along group_axis of the weights_matrix to its corresponding group
+    parameters and performs the dequantization operation.
 
     Args:
-        weights_matrix:
-        scale: Per-group scale tensor
-        zero: Per-group zero-point tensor
-        g_idx:
-
-
-
+        weights_matrix: Tensor to dequantize.
+        scale: Per-group scale tensor with n_groups along group_axis.
+        zero: Per-group zero-point tensor with n_groups along group_axis.
+        g_idx: 1D integer tensor of length equal to the size of
+            `weights_matrix` along the dimension being dequantized. Each
+            element specifies which group index (0 to n_groups-1) that
+            position belongs to. For example, with 128 columns and
+            group_size=32, g_idx would be
+            `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).
+        group_axis: The axis in `scale` and `zero` along which to index
+            using `g_idx`. This determines which dimension of the
+            scale/zero tensors contains the per-group values. Default: -1
+            (last axis).
 
     Returns:
         A tensor with the same shape as `weights_matrix` containing the

@@ -994,12 +1224,12 @@ def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
     """
     # Map group indices to scales and zeros
     groups = ops.cast(g_idx, "int32")
-    scales_mapped = ops.take(scale, groups, axis=-1)
-    zeros_mapped = ops.take(zero, groups, axis=-1)
+    scales_mapped = ops.take(scale, groups, axis=group_axis)
+    zeros_mapped = ops.take(zero, groups, axis=group_axis)
     zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)
 
-
+    dequantized = ops.multiply(
         ops.subtract(weights_matrix, zeros_mapped), scales_mapped
     )
 
-    return
+    return dequantized

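To make the `g_idx` mapping concrete, here is a small NumPy illustration of how per-group scales and zero points are gathered per column with `take` (toy values, not Keras code):

```python
import numpy as np

group_size, n_groups = 4, 3  # 12 columns total
g_idx = np.repeat(np.arange(n_groups), group_size)
print(g_idx)  # [0 0 0 0 1 1 1 1 2 2 2 2]

# One scale and zero point per (row, group); the last axis holds the groups.
scale = np.array([[0.10, 0.20, 0.40]])  # shape (1, n_groups)
zero = np.array([[8.0, 7.0, 9.0]])      # shape (1, n_groups)

# take(scale, g_idx, axis=group_axis) expands group params to one per column.
scales_mapped = np.take(scale, g_idx, axis=-1)  # shape (1, 12)
zeros_mapped = np.take(zero, g_idx, axis=-1)

q = np.array([[3, 5, 8, 12, 0, 7, 9, 15, 2, 9, 9, 14]], dtype="float32")
dequantized = (q - zeros_mapped) * scales_mapped
print(dequantized.round(2))
```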
keras/src/saving/file_editor.py
CHANGED

@@ -509,9 +509,15 @@ class KerasFileEditor:
         # ------------------------------------------------------
 
         # Skip any objects that are not proper datasets
-        if not
+        if not isinstance(value, h5py.Dataset):
             continue
 
+        if value.external:
+            raise ValueError(
+                "Not allowed: H5 file Dataset with external links: "
+                f"{value.external}"
+            )
+
         shape = value.shape
         dtype = value.dtype
 
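For context on the new `external` check: HDF5 datasets can point at bytes stored in a separate raw file, which `Dataset.external` reports. A small h5py sketch (file names arbitrary):

```python
import h5py

# A dataset whose storage lives in an external raw file; h5py exposes the
# (filename, offset, size) tuples via Dataset.external.
with h5py.File("demo.h5", "w") as f:
    f.create_dataset(
        "external_weights",
        shape=(4,),
        dtype="float32",
        external=[("payload.bin", 0, 16)],  # 4 float32 values = 16 bytes
    )

with h5py.File("demo.h5", "r") as f:
    dset = f["external_weights"]
    print(dset.external)  # [('payload.bin', 0, 16)] -> rejected by the editor
```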
keras/src/saving/saving_api.py
CHANGED

@@ -121,10 +121,11 @@ def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
 
 @keras_export(["keras.saving.load_model", "keras.models.load_model"])
 def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
-    """Loads a model saved via `model.save()`.
+    """Loads a model saved via `model.save()` or from an Orbax checkpoint.
 
     Args:
-        filepath: `str` or `pathlib.Path` object, path to the saved model file.
+        filepath: `str` or `pathlib.Path` object, path to the saved model file
+            or Orbax checkpoint directory.
         custom_objects: Optional dictionary mapping names
             (strings) to custom classes or functions to be
             considered during deserialization.
@@ -195,6 +196,16 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
             compile=compile,
             safe_mode=safe_mode,
         )
+
+    # Check for Orbax checkpoint directory using utility function
+    if is_orbax_checkpoint(filepath):
+        return _load_model_from_orbax_checkpoint(
+            filepath,
+            custom_objects=custom_objects,
+            compile=compile,
+            safe_mode=safe_mode,
+        )
+
     elif str(filepath).endswith(".keras"):
         raise ValueError(
             f"File not found: filepath={filepath}. "
@@ -337,3 +348,56 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
         "`.weights.h5` files, legacy H5 format files "
         "(`.h5` extension), or Orbax checkpoints."
     )
+
+
+def _load_model_from_orbax_checkpoint(
+    filepath, custom_objects=None, compile=True, safe_mode=True
+):
+    """Load a model from an Orbax checkpoint directory."""
+
+    from keras.src.utils.module_utils import ocp
+
+    # Ensure orbax is available
+    ocp.initialize()
+
+    # Find the latest checkpoint step using the utility function
+    checkpoint_path = find_latest_orbax_checkpoint(filepath)
+    step = int(os.path.basename(checkpoint_path))
+
+    # Load the composite state efficiently
+    checkpointer = ocp.training.Checkpointer(directory=filepath)
+    with ocp.Context():
+        composite_state = checkpointer.load_pytree(step)
+
+    # Validate and extract model config
+    if "model_config" not in composite_state:
+        raise ValueError(
+            "Checkpoint does not contain model configuration. "
+            "This checkpoint may have been saved with save_weights_only=True."
+        )
+
+    # Create and build model from config using saving_lib helper
+    # This properly handles shared objects and compile_config
+    model = saving_lib._model_from_config(
+        composite_state["model_config"],
+        custom_objects=custom_objects,
+        compile=compile,
+        safe_mode=safe_mode,
+    )
+
+    # Prepare state tree with only variable keys for set_state_tree
+    variable_keys = [
+        "trainable_variables",
+        "non_trainable_variables",
+        "optimizer_variables",
+        "metrics_variables",
+    ]
+    state_tree = {
+        key: composite_state[key]
+        for key in variable_keys
+        if key in composite_state
+    }
+
+    # Apply the loaded state to the model
+    model.set_state_tree(state_tree)
+    return model
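Together with the `OrbaxCheckpoint` callback touched in this release, the intended round trip looks roughly like the sketch below; the callback's exact constructor arguments are an assumption here, not verified against this build:

```python
import numpy as np
import keras

model = keras.Sequential([keras.Input((8,)), keras.layers.Dense(4)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(32, 8).astype("float32")
y = np.random.rand(32, 4).astype("float32")

# Assumed usage: the callback writes step-numbered Orbax checkpoints into
# ckpt_dir, including the model config that load_model() needs.
ckpt_dir = "/tmp/orbax_ckpt"
model.fit(x, y, epochs=2, callbacks=[keras.callbacks.OrbaxCheckpoint(ckpt_dir)])

# load_model() now recognizes the checkpoint directory and restores the
# config plus trainable/non-trainable/optimizer/metrics variables.
restored = keras.saving.load_model(ckpt_dir)
```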
keras/src/saving/saving_lib.py
CHANGED

@@ -796,7 +796,8 @@ def _load_state(
         try:
             saveable.load_own_variables(weights_store.get(inner_path))
         except Exception as e:
-            failed_saveables.add(id(saveable))
+            if failed_saveables is not None:
+                failed_saveables.add(id(saveable))
             error_msgs[id(saveable)] = saveable, e
             failure = True
         else:

@@ -807,7 +808,8 @@ def _load_state(
         try:
             saveable.load_assets(assets_store.get(inner_path))
         except Exception as e:
-            failed_saveables.add(id(saveable))
+            if failed_saveables is not None:
+                failed_saveables.add(id(saveable))
             error_msgs[id(saveable)] = saveable, e
             failure = True
         else:

@@ -855,7 +857,7 @@ def _load_state(
     if not failure:
         if visited_saveables is not None and newly_failed <= 0:
             visited_saveables.add(id(saveable))
-        if id(saveable) in failed_saveables:
+        if failed_saveables is not None and id(saveable) in failed_saveables:
             failed_saveables.remove(id(saveable))
             error_msgs.pop(id(saveable))
 
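The `failed_saveables is not None` guards matter because some call paths supply no tracking set at all. A minimal sketch of the pattern, with hypothetical helper names:

```python
def _failing_loader():
    raise RuntimeError("boom")


def try_load(loader, failed=None, errors=None):
    # Mirror the guard above: mutate the tracking set only when the caller
    # supplied one (some call sites pass failed_saveables=None).
    errors = {} if errors is None else errors
    try:
        loader()
    except Exception as e:
        if failed is not None:
            failed.add(loader.__name__)
        errors[loader.__name__] = e


failed = set()
try_load(_failing_loader, failed=failed)
try_load(_failing_loader)  # no tracking set; must not raise AttributeError
print(failed)  # {'_failing_loader'}
```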
@@ -1035,6 +1037,25 @@ class H5IOStore:
         # will mistakenly using `__len__` to determine the value.
         return self.h5_file.__bool__()
 
+    def _verify_group(self, group):
+        if not isinstance(group, h5py.Group):
+            raise ValueError(
+                f"Invalid H5 file, expected Group but received {type(group)}"
+            )
+        return group
+
+    def _verify_dataset(self, dataset):
+        if not isinstance(dataset, h5py.Dataset):
+            raise ValueError(
+                f"Invalid H5 file, expected Dataset, received {type(dataset)}"
+            )
+        if dataset.external:
+            raise ValueError(
+                "Not allowed: H5 file Dataset with external links: "
+                f"{dataset.external}"
+            )
+        return dataset
+
     def _get_h5_file(self, path_or_io, mode=None):
         mode = mode or self.mode
         if mode not in ("r", "w", "a"):
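A minimal h5py illustration of the Group/Dataset distinction these helpers enforce (illustrative, not the Keras loader itself):

```python
import h5py
import numpy as np

with h5py.File("model_demo.h5", "w") as f:
    grp = f.create_group("vars")
    grp.create_dataset("kernel", data=np.zeros((2, 2), dtype="float32"))

with h5py.File("model_demo.h5", "r") as f:
    node = f["vars"]
    # A malformed or tampered file could place a Dataset where a Group is
    # expected (or vice versa); isinstance checks catch the mismatch early.
    print(isinstance(node, h5py.Group))              # True
    print(isinstance(node["kernel"], h5py.Dataset))  # True
    print(node["kernel"].external)                   # None: no external links
```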
@@ -1094,15 +1115,19 @@ class H5IOStore:
         self._h5_entry_group = {}  # Defaults to an empty dict if not found.
         if not path:
             if "vars" in self.h5_file:
-                self._h5_entry_group = self.h5_file["vars"]
+                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
         elif path in self.h5_file and "vars" in self.h5_file[path]:
-            self._h5_entry_group = self.h5_file[path]["vars"]
+            self._h5_entry_group = self._verify_group(
+                self._verify_group(self.h5_file[path])["vars"]
+            )
         else:
             # No hit. Fix for 2.13 compatibility.
             if "_layer_checkpoint_dependencies" in self.h5_file:
                 path = path.replace("layers", "_layer_checkpoint_dependencies")
                 if path in self.h5_file and "vars" in self.h5_file[path]:
-                    self._h5_entry_group = self.h5_file[path]["vars"]
+                    self._h5_entry_group = self._verify_group(
+                        self._verify_group(self.h5_file[path])["vars"]
+                    )
         self._h5_entry_initialized = True
         return self
 
@@ -1134,25 +1159,15 @@ class H5IOStore:
     def keys(self):
         return self._h5_entry_group.keys()
 
-    def items(self):
-        return self._h5_entry_group.items()
-
-    def values(self):
-        return self._h5_entry_group.values()
-
     def __getitem__(self, key):
-        value = self._h5_entry_group[key]
+        value = self._verify_dataset(self._h5_entry_group[key])
         if (
             hasattr(value, "attrs")
             and "dtype" in value.attrs
             and value.attrs["dtype"] == "bfloat16"
         ):
             value = np.array(value, dtype=ml_dtypes.bfloat16)
-        elif (
-            hasattr(value, "shape")
-            and hasattr(value, "dtype")
-            and not isinstance(value, np.ndarray)
-        ):
+        elif not isinstance(value, np.ndarray):
             value = np.array(value)
         return value
 
|
|
|
1355
1370
|
self._get_h5_group(self._h5_entry_path)
|
|
1356
1371
|
|
|
1357
1372
|
def _restore_h5_file(self):
|
|
1358
|
-
"""Ensure the current shard is the last one created.
|
|
1359
|
-
|
|
1360
|
-
We use mode="a" to avoid truncating the file during the switching.
|
|
1361
|
-
"""
|
|
1373
|
+
"""Ensure the current shard is the last one created."""
|
|
1362
1374
|
if (
|
|
1363
1375
|
pathlib.Path(self.h5_file.filename).name
|
|
1364
1376
|
!= self.current_shard_path.name
|
|
1365
1377
|
):
|
|
1366
|
-
|
|
1378
|
+
mode = "a" if self.mode == "w" else "r"
|
|
1379
|
+
self._switch_h5_file(self.current_shard_path.name, mode=mode)
|
|
1367
1380
|
|
|
1368
1381
|
# H5 entry level methods.
|
|
1369
1382
|
|
|
@@ -1371,9 +1384,11 @@ class ShardedH5IOStore(H5IOStore):
         """Get the H5 entry group. If it doesn't exist, return an empty dict."""
         try:
             if not path:
-                self._h5_entry_group = self.h5_file["vars"]
+                self._h5_entry_group = self._verify_group(self.h5_file["vars"])
             else:
-                self._h5_entry_group = self.h5_file[path]["vars"]
+                self._h5_entry_group = self._verify_group(
+                    self._verify_group(self.h5_file[path])["vars"]
+                )
             self._h5_entry_initialized = True
         except KeyError:
             self._h5_entry_group = {}
@@ -1392,33 +1407,17 @@ class ShardedH5IOStore(H5IOStore):
         return total_len
 
     def keys(self):
-        keys = list(self._h5_entry_group.keys())
+        keys = []
+        current_shard_keys = list(self._h5_entry_group.keys())
         for filename in self.current_shard_filenames:
             if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            keys += list(self._h5_entry_group.keys())
+                keys += current_shard_keys
+            else:
+                self._switch_h5_file(filename, mode="r")
+                keys += list(self._h5_entry_group.keys())
         self._restore_h5_file()
         return keys
 
-    def items(self):
-        yield from self._h5_entry_group.items()
-        for filename in self.current_shard_filenames:
-            if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            yield from self._h5_entry_group.items()
-        self._restore_h5_file()
-
-    def values(self):
-        yield from self._h5_entry_group.values()
-        for filename in self.current_shard_filenames:
-            if filename == self.current_shard_path.name:
-                continue
-            self._switch_h5_file(filename, mode="r")
-            yield from self._h5_entry_group.values()
-        self._restore_h5_file()
-
     def __getitem__(self, key):
         if key in self._h5_entry_group:
             return super().__getitem__(key)

keras/src/version.py
CHANGED