keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/ops/operation.py
@@ -1,5 +1,4 @@
  import inspect
- import os.path
  import textwrap

  from keras.src import backend
@@ -20,10 +19,10 @@ class Operation(KerasSaveable):
      def __init__(self, name=None):
          if name is None:
              name = auto_name(self.__class__.__name__)
-         if not isinstance(name, str) or os.path.sep in name:
+         if not isinstance(name, str) or "/" in name:
              raise ValueError(
                  "Argument `name` must be a string and "
-                 f"cannot contain character `{os.path.sep}`. "
+                 f"cannot contain character `/`. "
                  f"Received: name={name} (of type {type(name)})"
              )
          self.name = name
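Note: the name check above now rejects the literal character "/" on every platform, instead of the OS-dependent os.path.sep (which is "\" on Windows). A minimal sketch of the resulting behavior, using a standard layer since every Keras Operation/Layer name passes through this constructor:

    import keras

    keras.layers.Dense(4, name="block_1_dense")      # accepted
    try:
        keras.layers.Dense(4, name="block_1/dense")  # "/" now raises on all platforms
    except ValueError as err:
        print(err)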
@@ -130,15 +129,55 @@ class Operation(KerasSaveable):
              vars(instance)["_object__state"] = nnx.object.ObjectState()

          # Generate a config to be returned by default by `get_config()`.
-         arg_names = inspect.getfullargspec(cls.__init__).args
-         kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
+         auto_config = True
+
+         signature = inspect.signature(cls.__init__)
+         argspec = inspect.getfullargspec(cls.__init__)
+
+         try:
+             bound_parameters = signature.bind(None, *args, **kwargs)
+         except TypeError:
+             # Raised by signature.bind when the supplied args and kwargs
+             # do not match the signature.
+             auto_config = False
+
+         if auto_config and any(
+             [
+                 param.kind == inspect.Parameter.POSITIONAL_ONLY
+                 for name, param in signature.parameters.items()
+                 if name != argspec.args[0]
+             ]
+         ):
+             # cls.__init__ takes positional only arguments, which
+             # cannot be restored via cls(**config)
+             auto_config = False
+             # Create variable to show appropriate warning in get_config.
+             instance._auto_config_error_args = True
+
+         if auto_config:
+             # Include default values in the config.
+             bound_parameters.apply_defaults()
+             # Extract all arguments as a dictionary.
+             kwargs = bound_parameters.arguments
+             # Expand variable kwargs argument.
+             kwargs |= kwargs.pop(argspec.varkw, {})
+             # Remove first positional argument, self.
+             kwargs.pop(argspec.args[0])
+             # Remove argument "name", as it is provided by get_config.
+             kwargs.pop("name", None)
+             if argspec.varargs is not None:
+                 # Varargs cannot be meaningfully converted to a dictionary.
+                 varargs = kwargs.pop(argspec.varargs)
+                 if len(varargs) > 0:
+                     auto_config = False
+                     # Store variable to show appropriate warning in get_config.
+                     instance._auto_config_error_args = True

          # For safety, we only rely on auto-configs for a small set of
          # serializable types.
          supported_types = (str, int, float, bool, type(None))
          try:
              flat_arg_values = tree.flatten(kwargs)
-             auto_config = True
              for value in flat_arg_values:
                  if not isinstance(value, supported_types):
                      auto_config = False
@@ -193,30 +232,52 @@ class Operation(KerasSaveable):
              config.pop("name", None)
              return config
          else:
-             raise NotImplementedError(
-                 textwrap.dedent(
-                     f"""
-                     Object {self.__class__.__name__} was created by passing
-                     non-serializable argument values in `__init__()`,
-                     and therefore the object must override `get_config()` in
-                     order to be serializable. Please implement `get_config()`.
-
-                     Example:
-
-                     class CustomLayer(keras.layers.Layer):
-                         def __init__(self, arg1, arg2, **kwargs):
-                             super().__init__(**kwargs)
-                             self.arg1 = arg1
-                             self.arg2 = arg2
-
-                         def get_config(self):
-                             config = super().get_config()
-                             config.update({"arg1": self.arg1,
-                                            "arg2": self.arg2,
-                                            })
-                             return config"""
+             example_str = """
+             class CustomLayer(keras.layers.Layer):
+                 def __init__(self, arg1, arg2, **kwargs):
+                     super().__init__(**kwargs)
+                     self.arg1 = arg1
+                     self.arg2 = arg2
+
+                 def get_config(self):
+                     config = super().get_config()
+                     config.update({
+                         "arg1": self.arg1,
+                         "arg2": self.arg2,
+                     })
+                     return config
+             """
+             if getattr(self, "_auto_config_error_args", False):
+                 raise NotImplementedError(
+                     textwrap.dedent(
+                         f"""
+                         Object {self.__class__.__name__} was created by passing
+                         positional only or variadic positional arguments (e.g.,
+                         `*args`) to `__init__()`, which is not supported by the
+                         automatic config generation. Please remove all positional
+                         only and variadic arguments from `__init__()`
+                         or override `get_config()` and `from_config()` to make
+                         the object serializatble.
+
+                         Example:
+
+                         {example_str}"""
+                     )
+                 )
+             else:
+                 raise NotImplementedError(
+                     textwrap.dedent(
+                         f"""
+                         Object {self.__class__.__name__} was created by passing
+                         non-serializable argument values in `__init__()`,
+                         and therefore the object must override `get_config()` in
+                         order to be serializable. Please implement `get_config()`.
+
+                         Example:
+
+                         {example_str}"""
+                     )
                  )
-             )

      @classmethod
      def from_config(cls, config):
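The auto-config change above swaps the old `getfullargspec` zipping for `Signature.bind`, so constructor defaults are captured too, and positional-only or `*args` constructors are flagged instead of being silently mis-serialized. A standalone sketch of the same mechanism in plain Python (not the Keras internals; the `Example` class is made up):

    import inspect

    class Example:
        def __init__(self, units, activation=None, name=None, **kwargs):
            pass

    args, kwargs = (32,), {"activation": "relu"}
    bound = inspect.signature(Example.__init__).bind(None, *args, **kwargs)
    bound.apply_defaults()              # defaults now land in the config
    config = dict(bound.arguments)
    config |= config.pop("kwargs", {})  # expand the **kwargs mapping
    config.pop("self")                  # placeholder passed for `self`
    config.pop("name", None)            # `name` is re-added by get_config()
    print(config)                       # {'units': 32, 'activation': 'relu'}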
keras/src/ops/operation_utils.py
@@ -378,6 +378,8 @@ def reduce_shape(shape, axis=None, keepdims=False):
      elif isinstance(axis, int):
          axis = (axis,)

+     axis = tuple(canonicalize_axis(a, len(shape)) for a in axis)
+
      if keepdims:
          for ax in axis:
              shape[ax] = 1
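The added `canonicalize_axis` call maps negative axes to their non-negative positions before the reduced shape is computed, so `axis=-1` refers to the last dimension. A small illustration of the intent in plain Python (not the Keras helper itself):

    def canonicalize_axis(axis, ndim):
        # Map a possibly-negative axis into the range [0, ndim).
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for rank {ndim}")
        return axis + ndim if axis < 0 else axis

    shape = [2, 3, 5]
    print(tuple(canonicalize_axis(a, len(shape)) for a in (-1, 0)))  # (2, 0)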
keras/src/optimizers/adafactor.py
@@ -158,33 +158,52 @@ class Adafactor(optimizer.Optimizer):
          rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
          alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
          regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)
-         beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)
+         beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay))

          if len(variable.shape) >= 2:
              # `r` deletes the last dimension of gradient, so it is of shape
              # `gradient.shape[:-1]`.
              self.assign(
                  r,
-                 beta_2_t * r
-                 + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1),
+                 ops.add(
+                     ops.multiply(beta_2_t, r),
+                     ops.multiply(
+                         ops.subtract(1, beta_2_t),
+                         ops.mean(regulated_grad_square, axis=-1),
+                     ),
+                 ),
              )
              # `c` deletes the second last dimension of gradient, so it is of
              # shape `gradient.shape[:-2] + gradient.shape[-1]`.
              self.assign(
                  c,
-                 beta_2_t * c
-                 + (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2),
+                 ops.add(
+                     ops.multiply(beta_2_t, c),
+                     ops.multiply(
+                         ops.subtract(1, beta_2_t),
+                         ops.mean(regulated_grad_square, axis=-2),
+                     ),
+                 ),
              )
              self.assign(
                  v,
-                 ops.expand_dims(
-                     r / ops.mean(r, axis=-1, keepdims=True), axis=-1
-                 )
-                 * ops.expand_dims(c, -2),
+                 ops.multiply(
+                     ops.expand_dims(
+                         ops.divide(r, ops.mean(r, axis=-1, keepdims=True)),
+                         axis=-1,
+                     ),
+                     ops.expand_dims(c, -2),
+                 ),
              )
          else:
              self.assign(
-                 v, beta_2_t * v + (1 - beta_2_t) * regulated_grad_square
+                 v,
+                 ops.add(
+                     ops.multiply(beta_2_t, v),
+                     ops.multiply(
+                         ops.subtract(1, beta_2_t), regulated_grad_square
+                     ),
+                 ),
              )

          u_t = ops.divide(gradient, ops.sqrt(v))
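The Adafactor change is mechanical: Python operators become explicit `ops.*` calls, while the factored second-moment math is unchanged. For reference, the >=2D branch restated in NumPy (a sketch with plain operators, not the Keras implementation):

    import numpy as np

    def factored_second_moment(r, c, grad, beta_2_t, epsilon_1):
        g2 = np.square(grad) + epsilon_1
        r = beta_2_t * r + (1 - beta_2_t) * g2.mean(axis=-1)  # drops last dim
        c = beta_2_t * c + (1 - beta_2_t) * g2.mean(axis=-2)  # drops 2nd-to-last dim
        v = np.expand_dims(r / r.mean(axis=-1, keepdims=True), -1)
        v = v * np.expand_dims(c, -2)                         # rank-1 reconstruction
        return r, c, v

    grad = np.random.randn(4, 3)
    r, c, v = factored_second_moment(np.zeros(4), np.zeros(3), grad, 0.9, 1e-30)
    update = grad / np.sqrt(v)  # u_t, as in the last line of the hunk above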
keras/src/optimizers/base_optimizer.py
@@ -631,6 +631,20 @@ class BaseOptimizer(KerasSaveable):
              g_acc.assign(n_g_acc)

      def stateless_apply(self, optimizer_variables, grads, trainable_variables):
+         """Stateless version of `apply` that returns modified variables.
+
+         Args:
+             optimizer_variables: list of tensors containing the current values
+                 for the optimizer variables. These are native tensors and not
+                 `keras.Variable`s.
+             grads: list of gradients to apply.
+             trainable_variables: list of tensors containing the current values
+                 for the model variables. These are native tensors and not
+                 `keras.Variable`s.
+
+         Returns: A tuple containing two list of tensors, the updated
+             `trainable_variables` and the updated `optimizer_variables`.
+         """
          self._check_super_called()

          if not self.built:
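The newly documented `stateless_apply` is the functional counterpart of `apply`: it consumes and returns plain tensor values instead of mutating `keras.Variable`s. A hedged usage sketch; `model`, `optimizer`, and `grads` are assumed to already exist and the optimizer to be built:

    # Pull current values out of the variables (native tensors, per the docstring).
    optimizer_values = [v.value for v in optimizer.variables]
    trainable_values = [v.value for v in model.trainable_variables]

    trainable_values, optimizer_values = optimizer.stateless_apply(
        optimizer_values, grads, trainable_values
    )

    # The caller decides when to write the returned values back.
    for variable, value in zip(model.trainable_variables, trainable_values):
        variable.assign(value)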
@@ -969,10 +983,15 @@
      ):
          if average is not None:
              not_first_step = ops.not_equal(self.iterations, 0)
-             momentum = (
-                 ops.cast(not_first_step, var.dtype) * self.ema_momentum
+             momentum = ops.multiply(
+                 ops.cast(not_first_step, var.dtype), self.ema_momentum
+             )
+             average.assign(
+                 ops.add(
+                     ops.multiply(momentum, average),
+                     ops.multiply(ops.subtract(1, momentum), var),
+                 )
              )
-             average.assign(momentum * average + (1 - momentum) * var)

      def _overwrite_model_variables_with_average_value(
          self, trainable_variables
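Behaviorally this is still the usual EMA rule; the `not_equal(iterations, 0)` cast just zeroes the momentum on the very first step so the average is seeded with the raw variable value. In plain Python:

    def ema_update(average, var, ema_momentum, iterations):
        momentum = ema_momentum if iterations != 0 else 0.0
        return momentum * average + (1.0 - momentum) * var

    print(ema_update(0.0, 5.0, 0.99, iterations=0))  # 5.0  (first step: seeded)
    print(ema_update(5.0, 3.0, 0.99, iterations=1))  # 4.98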
keras/src/optimizers/loss_scale_optimizer.py
@@ -48,6 +48,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
          inner_optimizer,
          initial_scale=2.0**15,
          dynamic_growth_steps=2000,
+         name=None,
          **kwargs,
      ):
          if not kwargs.pop("dynamic", True):
@@ -56,7 +57,42 @@ class LossScaleOptimizer(optimizer.Optimizer):
                  "Instead, simply set `loss_scale_factor` directly on the "
                  "`inner_optimizer`."
              )
-         super().__init__(learning_rate=0.0, **kwargs)
+
+         # Backwards compatibility code for deserialization.
+         # LossScaleOptimizer used to return all these parameters in `get_config`
+         # from `super.get_config` even though they are all non-functional. We
+         # no longer let user set them, but we have to allow the default values
+         # to be passed during deserialization to support older models.
+         base_optimizer_defaults = {
+             "weight_decay": None,
+             "clipnorm": None,
+             "global_clipnorm": None,
+             "clipvalue": None,
+             "use_ema": False,
+             "ema_momentum": 0.99,
+             "ema_overwrite_frequency": None,
+             "loss_scale_factor": None,
+             "gradient_accumulation_steps": None,
+         }
+         for arg_name, default_value in base_optimizer_defaults.items():
+             if arg_name not in kwargs:
+                 continue
+             arg_value = kwargs.pop(arg_name)
+             if (
+                 default_value is None and arg_value is not None
+             ) or arg_value != default_value:
+                 raise ValueError(
+                     f"LossScaleOptimizer does not support `{arg_name}`. "
+                     f"Instead, set `{arg_name}` on the `inner_optimizer`."
+                 )
+
+         if kwargs:
+             raise ValueError(
+                 "LossScaleOptimizer does not support arguments: "
+                 f"`{'`, `'.join(kwargs.keys())}`."
+             )
+
+         super().__init__(learning_rate=0.0, name=name)
          self.inner_optimizer = inner_optimizer
          self.initial_scale = initial_scale
          self.dynamic_growth_steps = dynamic_growth_steps
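In practice this means base-optimizer options (clipping, EMA, weight decay, ...) now belong on the wrapped optimizer, and passing them to the wrapper raises. A short sketch:

    import keras

    # Clipping, EMA, weight decay, etc. go on the inner optimizer.
    inner = keras.optimizers.SGD(learning_rate=0.01, clipnorm=1.0)
    opt = keras.optimizers.LossScaleOptimizer(
        inner, initial_scale=2.0**15, dynamic_growth_steps=2000
    )

    # keras.optimizers.LossScaleOptimizer(inner, clipnorm=1.0) would now raise:
    #   ValueError: LossScaleOptimizer does not support `clipnorm`. Instead,
    #   set `clipnorm` on the `inner_optimizer`.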
@@ -81,7 +117,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
              name="dynamic_scale",
          )
          self.inner_optimizer.build(var_list)
-         self.built = True
+         super().build(var_list)

      @property
      def variables(self):
@@ -112,7 +148,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
              mapping = list(zip(self.variables, optimizer_variables))
              with backend.StatelessScope(state_mapping=mapping) as scope:
                  self.step_counter.assign(0)
-                 self.dynamic_scale.assign(self.dynamic_scale * 2.0)
+                 self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))
              return [scope.get_current_value(v) for v in self._variables]

          def increment():
@@ -136,7 +172,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
              g
              if g is None or self._overwrite_variable_with_gradient(v)
              else ops.divide(g, scale)
-             for g, v in zip(grads, trainable_variables)
+             for g, v in zip(grads, self._trainable_variables)
          ]
          (
              new_trainable_variables,
@@ -156,7 +192,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
              mapping = list(zip(self.variables, optimizer_variables))
              with backend.StatelessScope(state_mapping=mapping) as scope:
                  self.step_counter.assign(0)
-                 self.dynamic_scale.assign(self.dynamic_scale / 2.0)
+                 self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))
              new_optimizer_variables = []
              for v in self.variables:
                  new_optimizer_variables.append(scope.get_current_value(v))
@@ -190,7 +226,7 @@ class LossScaleOptimizer(optimizer.Optimizer):

          def upscale():
              self.step_counter.assign(0)
-             self.dynamic_scale.assign(self.dynamic_scale * 2.0)
+             self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))

          def increment():
              self.step_counter.assign_add(1)
@@ -205,7 +241,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
      def _stateful_handle_non_finite_grads(self):
          # If any inf or nan in grads, downscale loss and reset counter.
          self.step_counter.assign(0)
-         self.dynamic_scale.assign(self.dynamic_scale / 2.0)
+         self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))

      def _common_apply(self, grads, trainable_variables=None):
          finite = self.check_finite(grads)
@@ -278,25 +314,22 @@ class LossScaleOptimizer(optimizer.Optimizer):

      def scale_loss(self, loss):
          scale = self.dynamic_scale if self.built else self.initial_scale
-         return loss * scale
+         return ops.multiply(loss, scale)

      def finalize_variable_values(self, var_list):
          self.inner_optimizer.finalize_variable_values(var_list)

      def get_config(self):
-         config = super().get_config()
+         # Do not use super().get_config() as only "name" is supported.
          inner_optimizer_config = serialization_lib.serialize_keras_object(
              self.inner_optimizer
          )
-         config.update(
-             {
-                 "inner_optimizer": inner_optimizer_config,
-                 "initial_scale": self.initial_scale,
-                 "dynamic_growth_steps": self.dynamic_growth_steps,
-             }
-         )
-         del config["learning_rate"]
-         return config
+         return {
+             "name": self.name,
+             "inner_optimizer": inner_optimizer_config,
+             "initial_scale": self.initial_scale,
+             "dynamic_growth_steps": self.dynamic_growth_steps,
+         }

      @classmethod
      def from_config(cls, config, custom_objects=None):
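Correspondingly, the serialized config now carries only the wrapper's own state plus the serialized inner optimizer; the inherited base-optimizer keys are gone. A round-trip sketch:

    import keras

    opt = keras.optimizers.LossScaleOptimizer(keras.optimizers.SGD())
    config = opt.get_config()
    # Keys now: "name", "inner_optimizer", "initial_scale" (32768.0 by default),
    # "dynamic_growth_steps" (2000 by default) -- no "learning_rate" to delete.
    restored = keras.optimizers.LossScaleOptimizer.from_config(config)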
keras/src/optimizers/muon.py
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
      The Muon optimizer can use both the Muon update step or the
      AdamW update step based on the following:

-     - For any variable that isn't 2D, 3D or 4D, the AdamW step
+     - For any variable that isn't 2D, the AdamW step
        will be used. This is not configurable.
      - If the argument `exclude_embeddings` (defaults to `True`) is set
        to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
              that takes no arguments and returns the actual value to use.
              The exponential decay rate for the 1st moment estimates. Defaults to
              `0.9`.
-         adam_beta_2: A float value or a constant float tensor, ora callable
+         adam_beta_2: A float value or a constant float tensor, or a callable
              that takes no arguments and returns the actual value to use.
              The exponential decay rate for the 2nd moment estimates. Defaults to
              `0.999`.
+         adam_weight_decay: Float. If set, weight decay is applied when using
+             the Adam optimizer.
          epsilon: A small constant for numerical stability. This is
              "epsilon hat" in the Kingma and Ba paper
              (in the formula just before Section 2.1),
@@ -67,11 +69,15 @@ class Muon(optimizer.Optimizer):
              It is recommended to use the default value
          adam_lr_ratio: Float, the ratio of the learning rate when
              using Adam to the main learning rate.
-             it is recommended to set it to 0.1
+             It is recommended to set it to 1
          momentum: Float, momentum used by internal SGD.
          ns_steps: Integer, number of Newton-Schulz iterations to run.
          nesterov: Boolean, whether to use Nesterov-style momentum
          {{base_optimizer_keyword_args}}
+         rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+             that can enhance the stability of Muon, allowing it to use the
+             same learning rate and weight decay as Adam. Defaults to `0.2`.
+             Set to `None` to disable this feature.
      """

      def __init__(
@@ -79,8 +85,9 @@ class Muon(optimizer.Optimizer):
          learning_rate=0.001,
          adam_beta_1=0.9,
          adam_beta_2=0.999,
+         adam_weight_decay=0.004,
          epsilon=1e-7,
-         weight_decay=0.1,
+         weight_decay=0.004,
          clipnorm=None,
          clipvalue=None,
          global_clipnorm=None,
@@ -95,10 +102,11 @@ class Muon(optimizer.Optimizer):
          muon_a=3.4445,
          muon_b=-4.7750,
          muon_c=2.0315,
-         adam_lr_ratio=0.1,
+         adam_lr_ratio=1,
          momentum=0.95,
-         ns_steps=6,
+         ns_steps=5,
          nesterov=True,
+         rms_rate=0.2,
          **kwargs,
      ):
          super().__init__(
@@ -127,12 +135,13 @@ class Muon(optimizer.Optimizer):
          self.nesterov = nesterov
          self.exclude_embeddings = exclude_embeddings
          self.exclude_layers = exclude_layers or []
+         self.adam_weight_decay = adam_weight_decay
+         self.rms_rate = rms_rate

      def _should_use_adamw(self, variable):
-         # To use it with 4D convolutional filters,
          # it works well to just flatten their last 3 dimensions.
          # any {0,1}-D parameters should all be optimized by adam
-         if not 1 < len(variable.shape) < 4:
+         if len(variable.shape) != 2:
              return True
          if self.exclude_embeddings and "embedding" in variable.path.lower():
              return True
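The routing rule after this change: only plain 2D weight matrices (that are not embeddings and not excluded by name) take the Muon step; everything else, including biases, norms, and higher-rank convolution kernels, falls back to AdamW. Roughly, as a simplified restatement (the `exclude_layers` matching below is an approximation of code not shown in this hunk):

    def uses_adamw(variable, exclude_embeddings=True, exclude_layers=()):
        if len(variable.shape) != 2:                  # 0/1-D and >=3-D -> AdamW
            return True
        if exclude_embeddings and "embedding" in variable.path.lower():
            return True
        if any(keyword in variable.path for keyword in exclude_layers):
            return True                               # approximate name matching
        return False                                  # 2D matrices -> Muon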
@@ -153,52 +162,50 @@ class Muon(optimizer.Optimizer):
          if self.built:
              return
          super().build(var_list)
-         self.adam_momentums = {}
-         self.adam_velocities = {}
-
-         self.muon_momentums = {}
-         self.muon_velocities = {}
+         # Momentums are for both Muon and Adam
+         self.momentums = [None] * len(var_list)
+         # Velocities are just for Adam
+         self.adam_velocities = [None] * len(var_list)

          for var in var_list:
              if not self._overwrite_variable_with_gradient(var):
-                 self.adam_momentums[var.path] = (
+                 self.momentums[self._get_variable_index(var)] = (
                      self.add_variable_from_reference(
                          reference_variable=var, name="momentum"
                      )
                  )
                  if self._should_use_adamw(var):
-                     self.adam_velocities[var.path] = (
+                     self.adam_velocities[self._get_variable_index(var)] = (
                          self.add_variable_from_reference(
                              reference_variable=var, name="velocity"
                          )
                      )

      def update_step(self, gradient, variable, learning_rate):
-         if self._should_use_adamw(variable):
+         variable_index = self._get_variable_index(variable)
+         m = self.momentums[variable_index]
+         v = self.adam_velocities[variable_index]
+
+         # The presence of the velocity tells us that this variable is for Adam
+         if v is not None:
              # It should be noted that lr is one-tenth when using adamw.
              self._adamw_update_step(
-                 gradient, variable, learning_rate * self.adam_lr_ratio
+                 gradient, variable, learning_rate * self.adam_lr_ratio, m, v
              )
          else:
-             self._muon_update_step(gradient, variable, learning_rate)
+             self._muon_update_step(gradient, variable, learning_rate, m)

-     def _muon_update_step(self, gradient, variable, lr):
-         m = self.adam_momentums[variable.path]
+     def _muon_update_step(self, gradient, variable, lr, m):
          self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-         shape = variable.shape
          if self.nesterov:
              g = ops.add(gradient, self.momentum * m)
          else:
              g = m
+         update = self.zeropower_via_newtonschulz5(g, self.ns_steps)

-         self.assign_sub(
-             variable,
-             lr
-             * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-             * max(1, shape[0] / shape[1]) ** 0.5,
-         )
+         self.assign_sub(variable, self.lr_adjust(lr * update))

-     def _adamw_update_step(self, gradient, variable, learning_rate):
+     def _adamw_update_step(self, gradient, variable, learning_rate, m, v):
          """Update step given gradient and the associated model variable."""
          lr = ops.cast(learning_rate, variable.dtype)
          gradient = ops.cast(gradient, variable.dtype)
@@ -210,9 +217,6 @@ class Muon(optimizer.Optimizer):
              ops.cast(self.adam_beta_2, variable.dtype), local_step
          )

-         m = self.adam_momentums[variable.path]
-         v = self.adam_velocities[variable.path]
-
          alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)

          self.assign_add(
@@ -239,6 +243,20 @@ class Muon(optimizer.Optimizer):
          X = ops.transpose(X, temp_order)
          return X

+     def lr_adjust(self, x):
+         """Adjusts learning rate based on the Moonlight implementation.
+         This method enhances the stability of Muon, allowing it to use the same
+         learning rate and weight decay as Adam. For details, see
+         https://arxiv.org/abs/2502.16982.
+         For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+         where `n` and `m` are the dimensions of the matrix.
+         """
+         if self.rms_rate is None:
+             return x
+         # moonlight version
+         # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+         return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
      def zeropower_via_newtonschulz5(self, x, steps: int):
          """We apply the Newton-Schulz iteration to compute matrix G.

@@ -268,6 +286,20 @@ class Muon(optimizer.Optimizer):
          x = self.transpose_last_axis(x)
          return x

+     def _apply_weight_decay(self, variables):
+         for variable in variables:
+             if not self._use_weight_decay(variable):
+                 continue
+             if self._should_use_adamw(variable):
+                 weight_decay_value = self.adam_weight_decay
+             else:
+                 weight_decay_value = self.weight_decay
+             if weight_decay_value is None:
+                 continue
+             wd = ops.cast(weight_decay_value, variable.dtype)
+             lr = ops.cast(self.learning_rate, variable.dtype)
+             variable.assign(variable - variable * wd * lr)
+
      def get_config(self):
          config = super().get_config()
          config.update(
@@ -284,6 +316,8 @@ class Muon(optimizer.Optimizer):
                  "ns_steps": self.ns_steps,
                  "nesterov": self.nesterov,
                  "exclude_embeddings": self.exclude_embeddings,
+                 "adam_weight_decay": self.adam_weight_decay,
+                 "rms_rate": self.rms_rate,
              }
          )
          return config
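The `lr_adjust` method added above implements the Moonlight-style RMS matching: with the default `rms_rate=0.2`, the orthogonalized update for an `n x m` matrix is scaled by `0.2 * sqrt(max(n, m))`, which is what lets Muon share Adam's learning rate and the (now separate) weight decay. A quick numeric check:

    import math

    n, m, rms_rate = 1024, 256, 0.2
    print(rms_rate * math.sqrt(max(n, m)))  # 6.4 -- factor applied to the Muon update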
keras/src/optimizers/schedules/learning_rate_schedule.py
@@ -584,9 +584,10 @@ class CosineDecay(LearningRateSchedule):
      schedule applies a linear increase per optimizer step to our learning rate
      from `initial_learning_rate` to `warmup_target` for a duration of
      `warmup_steps`. Afterwards, it applies a cosine decay function taking our
-     learning rate from `warmup_target` to `alpha` for a duration of
-     `decay_steps`. If `warmup_target` is None we skip warmup and our decay
-     will take our learning rate from `initial_learning_rate` to `alpha`.
+     learning rate from `warmup_target` to `warmup_target * alpha` for a
+     duration of `decay_steps`. If `warmup_target` is None we skip warmup and
+     our decay will take our learning rate from `initial_learning_rate` to
+     `initial_learning_rate * alpha`.
      It requires a `step` value to compute the learning rate. You can
      just pass a backend variable that you increment at each training step.
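The corrected docstring reflects that `alpha` is a fraction of the peak rate rather than an absolute floor. For example, using the constructor arguments named in this docstring:

    import keras

    schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.0,
        decay_steps=10_000,
        alpha=0.1,
        warmup_target=1e-3,
        warmup_steps=1_000,
    )
    # Warmup: 0.0 -> 1e-3 over 1,000 steps.
    # Decay:  1e-3 -> warmup_target * alpha = 1e-4 over the next 10,000 steps.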
 
keras/src/quantizers/__init__.py
@@ -1,6 +1,11 @@
  import inspect

  from keras.src.api_export import keras_export
+ from keras.src.quantizers.awq_config import AWQConfig
+ from keras.src.quantizers.quantization_config import Float8QuantizationConfig
+ from keras.src.quantizers.quantization_config import Int4QuantizationConfig
+ from keras.src.quantizers.quantization_config import Int8QuantizationConfig
+ from keras.src.quantizers.quantization_config import QuantizationConfig
  from keras.src.quantizers.quantizers import AbsMaxQuantizer
  from keras.src.quantizers.quantizers import Quantizer
  from keras.src.quantizers.quantizers import abs_max_quantize
@@ -13,7 +18,15 @@ from keras.src.quantizers.quantizers import unpack_int4
  from keras.src.saving import serialization_lib
  from keras.src.utils.naming import to_snake_case

- ALL_OBJECTS = {Quantizer, AbsMaxQuantizer}
+ ALL_OBJECTS = {
+     Quantizer,
+     AbsMaxQuantizer,
+     QuantizationConfig,
+     Int8QuantizationConfig,
+     Int4QuantizationConfig,
+     Float8QuantizationConfig,
+     AWQConfig,
+ }
  ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
  ALL_OBJECTS_DICT.update(
      {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}
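Registering the new config classes in ALL_OBJECTS means they participate in the same name-based registry (class name plus snake_case alias) that already serves the quantizers. A hedged sketch using the pre-existing AbsMaxQuantizer; the constructor arguments of the new *QuantizationConfig classes are not shown in this diff:

    from keras import quantizers

    q = quantizers.AbsMaxQuantizer(axis=-1)
    cfg = quantizers.serialize(q)          # registry-backed serialization
    restored = quantizers.deserialize(cfg)

    # QuantizationConfig, Int8QuantizationConfig, Int4QuantizationConfig,
    # Float8QuantizationConfig and AWQConfig are now resolvable the same way,
    # under both their class names and their snake_case equivalents.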