keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/ops/operation.py
CHANGED
@@ -1,5 +1,4 @@
 import inspect
-import os.path
 import textwrap
 
 from keras.src import backend
@@ -20,10 +19,10 @@ class Operation(KerasSaveable):
     def __init__(self, name=None):
         if name is None:
             name = auto_name(self.__class__.__name__)
-        if not isinstance(name, str) or
+        if not isinstance(name, str) or "/" in name:
             raise ValueError(
                 "Argument `name` must be a string and "
-                f"cannot contain character
+                f"cannot contain character `/`. "
                 f"Received: name={name} (of type {type(name)})"
             )
         self.name = name
@@ -130,15 +129,55 @@ class Operation(KerasSaveable):
             vars(instance)["_object__state"] = nnx.object.ObjectState()
 
         # Generate a config to be returned by default by `get_config()`.
-        [2 removed lines not captured in this rendering]
+        auto_config = True
+
+        signature = inspect.signature(cls.__init__)
+        argspec = inspect.getfullargspec(cls.__init__)
+
+        try:
+            bound_parameters = signature.bind(None, *args, **kwargs)
+        except TypeError:
+            # Raised by signature.bind when the supplied args and kwargs
+            # do not match the signature.
+            auto_config = False
+
+        if auto_config and any(
+            [
+                param.kind == inspect.Parameter.POSITIONAL_ONLY
+                for name, param in signature.parameters.items()
+                if name != argspec.args[0]
+            ]
+        ):
+            # cls.__init__ takes positional only arguments, which
+            # cannot be restored via cls(**config)
+            auto_config = False
+            # Create variable to show appropriate warning in get_config.
+            instance._auto_config_error_args = True
+
+        if auto_config:
+            # Include default values in the config.
+            bound_parameters.apply_defaults()
+            # Extract all arguments as a dictionary.
+            kwargs = bound_parameters.arguments
+            # Expand variable kwargs argument.
+            kwargs |= kwargs.pop(argspec.varkw, {})
+            # Remove first positional argument, self.
+            kwargs.pop(argspec.args[0])
+            # Remove argument "name", as it is provided by get_config.
+            kwargs.pop("name", None)
+            if argspec.varargs is not None:
+                # Varargs cannot be meaningfully converted to a dictionary.
+                varargs = kwargs.pop(argspec.varargs)
+                if len(varargs) > 0:
+                    auto_config = False
+                    # Store variable to show appropriate warning in get_config.
+                    instance._auto_config_error_args = True
 
         # For safety, we only rely on auto-configs for a small set of
         # serializable types.
         supported_types = (str, int, float, bool, type(None))
         try:
             flat_arg_values = tree.flatten(kwargs)
-            auto_config = True
             for value in flat_arg_values:
                 if not isinstance(value, supported_types):
                     auto_config = False
@@ -193,30 +232,52 @@ class Operation(KerasSaveable):
             config.pop("name", None)
             return config
         else:
-            [22 removed lines not captured in this rendering]
+            example_str = """
+            class CustomLayer(keras.layers.Layer):
+                def __init__(self, arg1, arg2, **kwargs):
+                    super().__init__(**kwargs)
+                    self.arg1 = arg1
+                    self.arg2 = arg2
+
+                def get_config(self):
+                    config = super().get_config()
+                    config.update({
+                        "arg1": self.arg1,
+                        "arg2": self.arg2,
+                    })
+                    return config
+            """
+            if getattr(self, "_auto_config_error_args", False):
+                raise NotImplementedError(
+                    textwrap.dedent(
+                        f"""
+                        Object {self.__class__.__name__} was created by passing
+                        positional only or variadic positional arguments (e.g.,
+                        `*args`) to `__init__()`, which is not supported by the
+                        automatic config generation. Please remove all positional
+                        only and variadic arguments from `__init__()`
+                        or override `get_config()` and `from_config()` to make
+                        the object serializatble.
+
+                        Example:
+
+                        {example_str}"""
+                    )
+                )
+            else:
+                raise NotImplementedError(
+                    textwrap.dedent(
+                        f"""
+                        Object {self.__class__.__name__} was created by passing
+                        non-serializable argument values in `__init__()`,
+                        and therefore the object must override `get_config()` in
+                        order to be serializable. Please implement `get_config()`.
+
+                        Example:
+
+                        {example_str}"""
+                    )
                 )
-            )
 
     @classmethod
     def from_config(cls, config):
keras/src/ops/operation_utils.py
CHANGED
[hunks for this file were not captured in this rendering]

keras/src/optimizers/adafactor.py
CHANGED
@@ -158,33 +158,52 @@ class Adafactor(optimizer.Optimizer):
         rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
         alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
         regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)
-        beta_2_t = 1
+        beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay))
 
         if len(variable.shape) >= 2:
             # `r` deletes the last dimension of gradient, so it is of shape
             # `gradient.shape[:-1]`.
             self.assign(
                 r,
-                [2 removed lines not captured in this rendering]
+                ops.add(
+                    ops.multiply(beta_2_t, r),
+                    ops.multiply(
+                        ops.subtract(1, beta_2_t),
+                        ops.mean(regulated_grad_square, axis=-1),
+                    ),
+                ),
             )
             # `c` deletes the second last dimension of gradient, so it is of
             # shape `gradient.shape[:-2] + gradient.shape[-1]`.
             self.assign(
                 c,
-                [2 removed lines not captured in this rendering]
+                ops.add(
+                    ops.multiply(beta_2_t, c),
+                    ops.multiply(
+                        ops.subtract(1, beta_2_t),
+                        ops.mean(regulated_grad_square, axis=-2),
+                    ),
+                ),
             )
             self.assign(
                 v,
-                ops.
-                [3 removed lines not captured in this rendering]
+                ops.multiply(
+                    ops.expand_dims(
+                        ops.divide(r, ops.mean(r, axis=-1, keepdims=True)),
+                        axis=-1,
+                    ),
+                    ops.expand_dims(c, -2),
+                ),
             )
         else:
             self.assign(
-                v,
+                v,
+                ops.add(
+                    ops.multiply(beta_2_t, v),
+                    ops.multiply(
+                        ops.subtract(1, beta_2_t), regulated_grad_square
+                    ),
+                ),
             )
 
         u_t = ops.divide(gradient, ops.sqrt(v))
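
To make the factored second-moment update above concrete, here is a NumPy sketch of the same arithmetic. The shapes, seed, and the `beta_2_t` value are illustrative; the real optimizer derives `beta_2_t` from the step count and writes the results back through `self.assign`.

import numpy as np

# Illustrative shapes and values; the real code operates on Keras variables.
rng = np.random.default_rng(0)
gradient = rng.normal(size=(8, 4)).astype("float32")
r = np.zeros(8, dtype="float32")   # row statistic, shape gradient.shape[:-1]
c = np.zeros(4, dtype="float32")   # column statistic, drops the second-to-last dim
epsilon_1 = 1e-30
beta_2_t = 0.95  # in the hunk: 1 - local_step ** beta_2_decay

regulated_grad_square = np.square(gradient) + epsilon_1
r = beta_2_t * r + (1 - beta_2_t) * regulated_grad_square.mean(axis=-1)
c = beta_2_t * c + (1 - beta_2_t) * regulated_grad_square.mean(axis=-2)
# Rank-1 reconstruction of the full second moment from the two factors.
v = (r / r.mean(axis=-1, keepdims=True))[..., None] * c[None, :]
u_t = gradient / np.sqrt(v)
print(u_t.shape)  # (8, 4)

Storing only `r` and `c` instead of the full matrix `v` is what gives Adafactor its sub-linear memory footprint for large weight matrices.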
keras/src/optimizers/base_optimizer.py
CHANGED
@@ -631,6 +631,20 @@ class BaseOptimizer(KerasSaveable):
             g_acc.assign(n_g_acc)
 
     def stateless_apply(self, optimizer_variables, grads, trainable_variables):
+        """Stateless version of `apply` that returns modified variables.
+
+        Args:
+            optimizer_variables: list of tensors containing the current values
+                for the optimizer variables. These are native tensors and not
+                `keras.Variable`s.
+            grads: list of gradients to apply.
+            trainable_variables: list of tensors containing the current values
+                for the model variables. These are native tensors and not
+                `keras.Variable`s.
+
+        Returns: A tuple containing two list of tensors, the updated
+            `trainable_variables` and the updated `optimizer_variables`.
+        """
         self._check_super_called()
 
         if not self.built:
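
The newly documented `stateless_apply` can be exercised directly. A minimal usage sketch, assuming an `SGD` inner update and illustrative values; the printed numbers follow from the learning rate chosen here, not from anything in the diff.

import keras
import numpy as np

optimizer = keras.optimizers.SGD(learning_rate=0.1)
w = keras.Variable(np.ones((3,), dtype="float32"))
optimizer.build([w])

grads = [np.full((3,), 0.5, dtype="float32")]
trainable = [w.value]                          # native tensors, not Variables
opt_vars = [v.value for v in optimizer.variables]

new_trainable, new_opt_vars = optimizer.stateless_apply(
    opt_vars, grads, trainable
)
print(np.asarray(new_trainable[0]))  # approximately [0.95 0.95 0.95]

Nothing is mutated in place: the caller is responsible for feeding the returned `new_trainable` and `new_opt_vars` back into the next step, which is what makes this path usable under JAX-style functional training loops.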
@@ -969,10 +983,15 @@ class BaseOptimizer(KerasSaveable):
     ):
         if average is not None:
             not_first_step = ops.not_equal(self.iterations, 0)
-            momentum = (
-                ops.cast(not_first_step, var.dtype)
+            momentum = ops.multiply(
+                ops.cast(not_first_step, var.dtype), self.ema_momentum
+            )
+            average.assign(
+                ops.add(
+                    ops.multiply(momentum, average),
+                    ops.multiply(ops.subtract(1, momentum), var),
+                )
             )
-            average.assign(momentum * average + (1 - momentum) * var)
 
     def _overwrite_model_variables_with_average_value(
         self, trainable_variables
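
The hunk above only rewrites the EMA update with explicit `ops` calls; the arithmetic itself is unchanged. A NumPy sketch of that arithmetic, with illustrative values:

import numpy as np

# On the first step (iterations == 0) the momentum factor is zeroed, so the
# average is simply overwritten with the variable; afterwards it is the usual
# exponential moving average.
ema_momentum = 0.99
var = np.array([1.0, 2.0], dtype="float32")
average = np.zeros_like(var)

for iterations in range(3):
    not_first_step = float(iterations != 0)
    momentum = not_first_step * ema_momentum
    average = momentum * average + (1.0 - momentum) * var

print(average)  # stays at [1. 2.] because `var` is constant here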
keras/src/optimizers/loss_scale_optimizer.py
CHANGED
@@ -48,6 +48,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
         inner_optimizer,
         initial_scale=2.0**15,
         dynamic_growth_steps=2000,
+        name=None,
         **kwargs,
     ):
         if not kwargs.pop("dynamic", True):
@@ -56,7 +57,42 @@ class LossScaleOptimizer(optimizer.Optimizer):
                 "Instead, simply set `loss_scale_factor` directly on the "
                 "`inner_optimizer`."
             )
-        [removed line not captured in this rendering]
+
+        # Backwards compatibility code for deserialization.
+        # LossScaleOptimizer used to return all these parameters in `get_config`
+        # from `super.get_config` even though they are all non-functional. We
+        # no longer let user set them, but we have to allow the default values
+        # to be passed during deserialization to support older models.
+        base_optimizer_defaults = {
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+        }
+        for arg_name, default_value in base_optimizer_defaults.items():
+            if arg_name not in kwargs:
+                continue
+            arg_value = kwargs.pop(arg_name)
+            if (
+                default_value is None and arg_value is not None
+            ) or arg_value != default_value:
+                raise ValueError(
+                    f"LossScaleOptimizer does not support `{arg_name}`. "
+                    f"Instead, set `{arg_name}` on the `inner_optimizer`."
+                )
+
+        if kwargs:
+            raise ValueError(
+                "LossScaleOptimizer does not support arguments: "
+                f"`{'`, `'.join(kwargs.keys())}`."
+            )
+
+        super().__init__(learning_rate=0.0, name=name)
         self.inner_optimizer = inner_optimizer
         self.initial_scale = initial_scale
         self.dynamic_growth_steps = dynamic_growth_steps
@@ -81,7 +117,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
             name="dynamic_scale",
         )
         self.inner_optimizer.build(var_list)
-        [removed line not captured in this rendering]
+        super().build(var_list)
 
     @property
     def variables(self):
@@ -112,7 +148,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
             mapping = list(zip(self.variables, optimizer_variables))
             with backend.StatelessScope(state_mapping=mapping) as scope:
                 self.step_counter.assign(0)
-                self.dynamic_scale.assign(self.dynamic_scale
+                self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))
             return [scope.get_current_value(v) for v in self._variables]
 
         def increment():
@@ -136,7 +172,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
             g
             if g is None or self._overwrite_variable_with_gradient(v)
            else ops.divide(g, scale)
-            for g, v in zip(grads,
+            for g, v in zip(grads, self._trainable_variables)
        ]
        (
            new_trainable_variables,
@@ -156,7 +192,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
            mapping = list(zip(self.variables, optimizer_variables))
            with backend.StatelessScope(state_mapping=mapping) as scope:
                self.step_counter.assign(0)
-                self.dynamic_scale.assign(self.dynamic_scale
+                self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))
            new_optimizer_variables = []
            for v in self.variables:
                new_optimizer_variables.append(scope.get_current_value(v))
@@ -190,7 +226,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
 
        def upscale():
            self.step_counter.assign(0)
-            self.dynamic_scale.assign(self.dynamic_scale
+            self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))
 
        def increment():
            self.step_counter.assign_add(1)
@@ -205,7 +241,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
    def _stateful_handle_non_finite_grads(self):
        # If any inf or nan in grads, downscale loss and reset counter.
        self.step_counter.assign(0)
-        self.dynamic_scale.assign(self.dynamic_scale
+        self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))
 
    def _common_apply(self, grads, trainable_variables=None):
        finite = self.check_finite(grads)
@@ -278,25 +314,22 @@ class LossScaleOptimizer(optimizer.Optimizer):
 
    def scale_loss(self, loss):
        scale = self.dynamic_scale if self.built else self.initial_scale
-        return loss
+        return ops.multiply(loss, scale)
 
    def finalize_variable_values(self, var_list):
        self.inner_optimizer.finalize_variable_values(var_list)
 
    def get_config(self):
-        [removed line not captured in this rendering]
+        # Do not use super().get_config() as only "name" is supported.
        inner_optimizer_config = serialization_lib.serialize_keras_object(
            self.inner_optimizer
        )
-        [6 removed lines not captured in this rendering]
-        )
-        del config["learning_rate"]
-        return config
+        return {
+            "name": self.name,
+            "inner_optimizer": inner_optimizer_config,
+            "initial_scale": self.initial_scale,
+            "dynamic_growth_steps": self.dynamic_growth_steps,
+        }
 
    @classmethod
    def from_config(cls, config, custom_objects=None):
keras/src/optimizers/muon.py
CHANGED
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:
 
-    - For any variable that isn't 2D,
+    - For any variable that isn't 2D, the AdamW step
       will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
       to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
            that takes no arguments and returns the actual value to use.
            The exponential decay rate for the 1st moment estimates. Defaults to
            `0.9`.
-        adam_beta_2: A float value or a constant float tensor,
+        adam_beta_2: A float value or a constant float tensor, or a callable
            that takes no arguments and returns the actual value to use.
            The exponential decay rate for the 2nd moment estimates. Defaults to
            `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
        epsilon: A small constant for numerical stability. This is
            "epsilon hat" in the Kingma and Ba paper
            (in the formula just before Section 2.1),
@@ -67,11 +69,15 @@ class Muon(optimizer.Optimizer):
            It is recommended to use the default value
        adam_lr_ratio: Float, the ratio of the learning rate when
            using Adam to the main learning rate.
-            [removed line not captured in this rendering]
+            It is recommended to set it to 1
        momentum: Float, momentum used by internal SGD.
        ns_steps: Integer, number of Newton-Schulz iterations to run.
        nesterov: Boolean, whether to use Nesterov-style momentum
        {{base_optimizer_keyword_args}}
+        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+            that can enhance the stability of Muon, allowing it to use the
+            same learning rate and weight decay as Adam. Defaults to `0.2`.
+            Set to `None` to disable this feature.
    """
 
    def __init__(
@@ -79,8 +85,9 @@ class Muon(optimizer.Optimizer):
        learning_rate=0.001,
        adam_beta_1=0.9,
        adam_beta_2=0.999,
+        adam_weight_decay=0.004,
        epsilon=1e-7,
-        weight_decay=0.
+        weight_decay=0.004,
        clipnorm=None,
        clipvalue=None,
        global_clipnorm=None,
@@ -95,10 +102,11 @@ class Muon(optimizer.Optimizer):
        muon_a=3.4445,
        muon_b=-4.7750,
        muon_c=2.0315,
-        adam_lr_ratio=
+        adam_lr_ratio=1,
        momentum=0.95,
-        ns_steps=
+        ns_steps=5,
        nesterov=True,
+        rms_rate=0.2,
        **kwargs,
    ):
        super().__init__(
@@ -127,12 +135,13 @@ class Muon(optimizer.Optimizer):
        self.nesterov = nesterov
        self.exclude_embeddings = exclude_embeddings
        self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate
 
    def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
        # it works well to just flatten their last 3 dimensions.
        # any {0,1}-D parameters should all be optimized by adam
-        if
+        if len(variable.shape) != 2:
            return True
        if self.exclude_embeddings and "embedding" in variable.path.lower():
            return True
@@ -153,52 +162,50 @@ class Muon(optimizer.Optimizer):
        if self.built:
            return
        super().build(var_list)
-        [removed line not captured in this rendering]
-        self.
-        [removed line not captured in this rendering]
-        self.
-        self.muon_velocities = {}
+        # Momentums are for both Muon and Adam
+        self.momentums = [None] * len(var_list)
+        # Velocities are just for Adam
+        self.adam_velocities = [None] * len(var_list)
 
        for var in var_list:
            if not self._overwrite_variable_with_gradient(var):
-                self.
+                self.momentums[self._get_variable_index(var)] = (
                    self.add_variable_from_reference(
                        reference_variable=var, name="momentum"
                    )
                )
                if self._should_use_adamw(var):
-                    self.adam_velocities[var
+                    self.adam_velocities[self._get_variable_index(var)] = (
                        self.add_variable_from_reference(
                            reference_variable=var, name="velocity"
                        )
                    )
 
    def update_step(self, gradient, variable, learning_rate):
-        [removed line not captured in this rendering]
+        variable_index = self._get_variable_index(variable)
+        m = self.momentums[variable_index]
+        v = self.adam_velocities[variable_index]
+
+        # The presence of the velocity tells us that this variable is for Adam
+        if v is not None:
            # It should be noted that lr is one-tenth when using adamw.
            self._adamw_update_step(
-                gradient, variable, learning_rate * self.adam_lr_ratio
+                gradient, variable, learning_rate * self.adam_lr_ratio, m, v
            )
        else:
-            self._muon_update_step(gradient, variable, learning_rate)
+            self._muon_update_step(gradient, variable, learning_rate, m)
 
-    def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+    def _muon_update_step(self, gradient, variable, lr, m):
        self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
        if self.nesterov:
            g = ops.add(gradient, self.momentum * m)
        else:
            g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
 
-        self.assign_sub(
-            variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
-        )
+        self.assign_sub(variable, self.lr_adjust(lr * update))
 
-    def _adamw_update_step(self, gradient, variable, learning_rate):
+    def _adamw_update_step(self, gradient, variable, learning_rate, m, v):
        """Update step given gradient and the associated model variable."""
        lr = ops.cast(learning_rate, variable.dtype)
        gradient = ops.cast(gradient, variable.dtype)
@@ -210,9 +217,6 @@ class Muon(optimizer.Optimizer):
            ops.cast(self.adam_beta_2, variable.dtype), local_step
        )
 
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
-
        alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 
        self.assign_add(
@@ -239,6 +243,20 @@ class Muon(optimizer.Optimizer):
        X = ops.transpose(X, temp_order)
        return X
 
+    def lr_adjust(self, x):
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the same
+        learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
    def zeropower_via_newtonschulz5(self, x, steps: int):
        """We apply the Newton-Schulz iteration to compute matrix G.
 
@@ -268,6 +286,20 @@ class Muon(optimizer.Optimizer):
        x = self.transpose_last_axis(x)
        return x
 
+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
+
    def get_config(self):
        config = super().get_config()
        config.update(
@@ -284,6 +316,8 @@ class Muon(optimizer.Optimizer):
                "ns_steps": self.ns_steps,
                "nesterov": self.nesterov,
                "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
            }
        )
        return config
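
To illustrate the effect of the new `rms_rate` scaling in `lr_adjust`, here is a small NumPy sketch; the matrix shape, seed, and learning rate are illustrative, and `update` stands in for the output of `zeropower_via_newtonschulz5`.

import numpy as np

rms_rate = 0.2
lr = 0.02
n, m = 1024, 256

# `update` stands in for the orthogonalized gradient returned by the
# Newton-Schulz iteration in the hunk above.
rng = np.random.default_rng(0)
update = rng.normal(size=(n, m)).astype("float32")

# After orthogonalization, the update for an (n, m) matrix is rescaled by
# sqrt(max(n, m)) * rms_rate; with rms_rate=None the scaling is skipped.
adjusted = lr * update * np.sqrt(max(n, m)) * rms_rate
print(adjusted.shape, float(np.sqrt(max(n, m)) * rms_rate))  # (1024, 256) 6.4

This replaces the previous `max(1, n / m) ** 0.5` factor and, per the linked Moonlight write-up, lets Muon reuse Adam-style learning rates and weight decay.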
keras/src/optimizers/schedules/learning_rate_schedule.py
CHANGED
@@ -584,9 +584,10 @@ class CosineDecay(LearningRateSchedule):
     schedule applies a linear increase per optimizer step to our learning rate
     from `initial_learning_rate` to `warmup_target` for a duration of
     `warmup_steps`. Afterwards, it applies a cosine decay function taking our
-    learning rate from `warmup_target` to `alpha` for a
-    `decay_steps`. If `warmup_target` is None we skip warmup and
-    will take our learning rate from `initial_learning_rate` to
+    learning rate from `warmup_target` to `warmup_target * alpha` for a
+    duration of `decay_steps`. If `warmup_target` is None we skip warmup and
+    our decay will take our learning rate from `initial_learning_rate` to
+    `initial_learning_rate * alpha`.
     It requires a `step` value to compute the learning rate. You can
     just pass a backend variable that you increment at each training step.
 
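
The corrected wording can be checked against the schedule itself. A small sketch with illustrative argument values, assuming the standard `keras.optimizers.schedules.CosineDecay` API; the commented outputs follow from the docstring's formula.

import keras

schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=1000,
    alpha=0.1,
    warmup_target=1.0,
    warmup_steps=100,
)
print(float(schedule(0)))     # ~0.0  (start of linear warmup)
print(float(schedule(100)))   # ~1.0  (end of warmup, start of cosine decay)
print(float(schedule(1100)))  # ~0.1  (warmup_target * alpha)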
keras/src/quantizers/__init__.py
CHANGED
@@ -1,6 +1,11 @@
 import inspect
 
 from keras.src.api_export import keras_export
+from keras.src.quantizers.awq_config import AWQConfig
+from keras.src.quantizers.quantization_config import Float8QuantizationConfig
+from keras.src.quantizers.quantization_config import Int4QuantizationConfig
+from keras.src.quantizers.quantization_config import Int8QuantizationConfig
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize
@@ -13,7 +18,15 @@ from keras.src.quantizers.quantizers import unpack_int4
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
 
-ALL_OBJECTS = {
+ALL_OBJECTS = {
+    Quantizer,
+    AbsMaxQuantizer,
+    QuantizationConfig,
+    Int8QuantizationConfig,
+    Int4QuantizationConfig,
+    Float8QuantizationConfig,
+    AWQConfig,
+}
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 ALL_OBJECTS_DICT.update(
     {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}