keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. keras/_tf_keras/keras/ops/__init__.py +1 -0
  2. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  3. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  4. keras/ops/__init__.py +1 -0
  5. keras/ops/numpy/__init__.py +1 -0
  6. keras/quantizers/__init__.py +1 -0
  7. keras/src/applications/convnext.py +20 -20
  8. keras/src/applications/densenet.py +21 -21
  9. keras/src/applications/efficientnet.py +16 -16
  10. keras/src/applications/efficientnet_v2.py +28 -28
  11. keras/src/applications/inception_resnet_v2.py +7 -7
  12. keras/src/applications/inception_v3.py +5 -5
  13. keras/src/applications/mobilenet_v2.py +13 -20
  14. keras/src/applications/mobilenet_v3.py +15 -15
  15. keras/src/applications/nasnet.py +7 -8
  16. keras/src/applications/resnet.py +32 -32
  17. keras/src/applications/xception.py +10 -10
  18. keras/src/backend/common/dtypes.py +8 -3
  19. keras/src/backend/common/variables.py +3 -1
  20. keras/src/backend/jax/export.py +1 -1
  21. keras/src/backend/jax/numpy.py +6 -0
  22. keras/src/backend/jax/trainer.py +1 -1
  23. keras/src/backend/numpy/numpy.py +28 -0
  24. keras/src/backend/openvino/numpy.py +5 -1
  25. keras/src/backend/tensorflow/numpy.py +22 -0
  26. keras/src/backend/tensorflow/trainer.py +19 -1
  27. keras/src/backend/torch/core.py +6 -9
  28. keras/src/backend/torch/nn.py +1 -2
  29. keras/src/backend/torch/numpy.py +16 -0
  30. keras/src/backend/torch/trainer.py +1 -1
  31. keras/src/callbacks/backup_and_restore.py +2 -2
  32. keras/src/callbacks/csv_logger.py +1 -1
  33. keras/src/callbacks/model_checkpoint.py +1 -1
  34. keras/src/callbacks/tensorboard.py +6 -6
  35. keras/src/constraints/constraints.py +9 -7
  36. keras/src/datasets/boston_housing.py +1 -1
  37. keras/src/datasets/california_housing.py +1 -1
  38. keras/src/datasets/cifar10.py +1 -1
  39. keras/src/datasets/cifar100.py +2 -2
  40. keras/src/datasets/imdb.py +2 -2
  41. keras/src/datasets/mnist.py +1 -1
  42. keras/src/datasets/reuters.py +2 -2
  43. keras/src/dtype_policies/dtype_policy.py +1 -1
  44. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  45. keras/src/export/tf2onnx_lib.py +1 -3
  46. keras/src/initializers/constant_initializers.py +9 -5
  47. keras/src/layers/input_spec.py +6 -6
  48. keras/src/layers/layer.py +1 -1
  49. keras/src/layers/preprocessing/category_encoding.py +3 -3
  50. keras/src/layers/preprocessing/data_layer.py +159 -0
  51. keras/src/layers/preprocessing/discretization.py +3 -3
  52. keras/src/layers/preprocessing/feature_space.py +4 -4
  53. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  54. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  55. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  56. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  57. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  61. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  62. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  63. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  67. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  69. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  70. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  71. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  72. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  73. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  74. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  75. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  76. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  77. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  78. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  79. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  80. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  81. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  82. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  83. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  84. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  85. keras/src/layers/preprocessing/normalization.py +5 -2
  86. keras/src/layers/preprocessing/rescaling.py +3 -3
  87. keras/src/layers/rnn/bidirectional.py +4 -4
  88. keras/src/legacy/backend.py +9 -23
  89. keras/src/legacy/preprocessing/image.py +11 -22
  90. keras/src/legacy/preprocessing/text.py +1 -1
  91. keras/src/models/functional.py +2 -2
  92. keras/src/models/model.py +21 -3
  93. keras/src/ops/function.py +1 -1
  94. keras/src/ops/numpy.py +49 -5
  95. keras/src/ops/operation.py +3 -2
  96. keras/src/optimizers/base_optimizer.py +3 -4
  97. keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
  98. keras/src/quantizers/gptq.py +350 -0
  99. keras/src/quantizers/gptq_config.py +169 -0
  100. keras/src/quantizers/gptq_core.py +335 -0
  101. keras/src/quantizers/gptq_quant.py +133 -0
  102. keras/src/saving/file_editor.py +22 -20
  103. keras/src/saving/object_registration.py +1 -1
  104. keras/src/saving/saving_lib.py +4 -4
  105. keras/src/saving/serialization_lib.py +3 -5
  106. keras/src/trainers/compile_utils.py +1 -1
  107. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  108. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  109. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  110. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  111. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  112. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  113. keras/src/tree/dmtree_impl.py +19 -3
  114. keras/src/tree/optree_impl.py +3 -3
  115. keras/src/tree/tree_api.py +5 -2
  116. keras/src/utils/file_utils.py +13 -5
  117. keras/src/utils/io_utils.py +1 -1
  118. keras/src/utils/model_visualization.py +1 -1
  119. keras/src/utils/progbar.py +5 -5
  120. keras/src/utils/summary_utils.py +4 -4
  121. keras/src/version.py +1 -1
  122. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
  123. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
  124. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  125. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
  126. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0
@@ -68,11 +68,7 @@ def batch_dot(x, y, axes=None):
68
68
  raise ValueError(
69
69
  "Cannot do batch_dot on inputs "
70
70
  "with rank < 2. "
71
- "Received inputs with tf.shapes "
72
- + str(x_shape)
73
- + " and "
74
- + str(y_shape)
75
- + "."
71
+ f"Received inputs with tf.shapes {x_shape} and {y_shape}."
76
72
  )
77
73
 
78
74
  x_batch_size = x_shape[0]
@@ -84,10 +80,7 @@ def batch_dot(x, y, axes=None):
84
80
  "Cannot do batch_dot on inputs "
85
81
  "with different batch sizes. "
86
82
  "Received inputs with tf.shapes "
87
- + str(x_shape)
88
- + " and "
89
- + str(y_shape)
90
- + "."
83
+ f"{x_shape} and {y_shape}."
91
84
  )
92
85
  if isinstance(axes, int):
93
86
  axes = [axes, axes]
@@ -101,9 +94,8 @@ def batch_dot(x, y, axes=None):
101
94
  if py_any(isinstance(a, (list, tuple)) for a in axes):
102
95
  raise ValueError(
103
96
  "Multiple target dimensions are not supported. "
104
- + "Expected: None, int, (int, int), "
105
- + "Provided: "
106
- + str(axes)
97
+ "Expected: None, int, (int, int), "
98
+ f"Provided: {axes}"
107
99
  )
108
100
 
109
101
  # if tuple, convert to list.
@@ -130,12 +122,8 @@ def batch_dot(x, y, axes=None):
130
122
  if d1 is not None and d2 is not None and d1 != d2:
131
123
  raise ValueError(
132
124
  "Cannot do batch_dot on inputs with tf.shapes "
133
- + str(x_shape)
134
- + " and "
135
- + str(y_shape)
136
- + " with axes="
137
- + str(axes)
138
- + ". x.shape[%d] != y.shape[%d] (%d != %d)."
125
+ f"{x_shape} and {y_shape} with axes={axes}. "
126
+ "x.shape[%d] != y.shape[%d] (%d != %d)."
139
127
  % (axes[0], axes[1], d1, d2)
140
128
  )
141
129
 
@@ -1129,7 +1117,7 @@ def pool2d(
1129
1117
  x, pool_size, strides, padding=padding, data_format=tf_data_format
1130
1118
  )
1131
1119
  else:
1132
- raise ValueError("Invalid pooling mode: " + str(pool_mode))
1120
+ raise ValueError(f"Invalid pooling mode: {str(pool_mode)}")
1133
1121
 
1134
1122
  if data_format == "channels_first" and tf_data_format == "NHWC":
1135
1123
  x = tf.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
@@ -1169,7 +1157,7 @@ def pool3d(
1169
1157
  x, pool_size, strides, padding=padding, data_format=tf_data_format
1170
1158
  )
1171
1159
  else:
1172
- raise ValueError("Invalid pooling mode: " + str(pool_mode))
1160
+ raise ValueError(f"Invalid pooling mode: {str(pool_mode)}")
1173
1161
 
1174
1162
  if data_format == "channels_first" and tf_data_format == "NDHWC":
1175
1163
  x = tf.transpose(x, (0, 4, 1, 2, 3))
@@ -2150,9 +2138,7 @@ def switch(condition, then_expression, else_expression):
2150
2138
  "Rank of `condition` should be less than or"
2151
2139
  " equal to rank of `then_expression` and "
2152
2140
  "`else_expression`. ndim(condition)="
2153
- + str(cond_ndim)
2154
- + ", ndim(then_expression)="
2155
- + str(expr_ndim)
2141
+ f"{cond_ndim}, ndim(then_expression)={expr_ndim}"
2156
2142
  )
2157
2143
  if cond_ndim > 1:
2158
2144
  ndim_diff = expr_ndim - cond_ndim
@@ -617,17 +617,12 @@ class NumpyArrayIterator(Iterator):
617
617
  channels_axis = 3 if data_format == "channels_last" else 1
618
618
  if self.x.shape[channels_axis] not in {1, 3, 4}:
619
619
  warnings.warn(
620
- 'NumpyArrayIterator is set to use the data format convention "'
621
- + data_format
622
- + '" (channels on axis '
623
- + str(channels_axis)
624
- + "), i.e. expected either 1, 3, or 4 channels on axis "
625
- + str(channels_axis)
626
- + ". However, it was passed an array with shape "
627
- + str(self.x.shape)
628
- + " ("
629
- + str(self.x.shape[channels_axis])
630
- + " channels)."
620
+ f"NumpyArrayIterator is set to use the data format convention"
621
+ f' "{data_format}" (channels on axis {channels_axis})'
622
+ ", i.e. expected either 1, 3, or 4 channels "
623
+ f"on axis {channels_axis}. "
624
+ f"However, it was passed an array with shape {self.x.shape}"
625
+ f" ({self.x.shape[channels_axis]} channels)."
631
626
  )
632
627
  if y is not None:
633
628
  self.y = np.asarray(y)
@@ -1494,17 +1489,11 @@ class ImageDataGenerator:
1494
1489
  if x.shape[self.channel_axis] not in {1, 3, 4}:
1495
1490
  warnings.warn(
1496
1491
  "Expected input to be images (as Numpy array) "
1497
- 'following the data format convention "'
1498
- + self.data_format
1499
- + '" (channels on axis '
1500
- + str(self.channel_axis)
1501
- + "), i.e. expected either 1, 3 or 4 channels on axis "
1502
- + str(self.channel_axis)
1503
- + ". However, it was passed an array with shape "
1504
- + str(x.shape)
1505
- + " ("
1506
- + str(x.shape[self.channel_axis])
1507
- + " channels)."
1492
+ f'following the data format convention "{self.data_format}'
1493
+ f'" (channels on axis {self.channel_axis})'
1494
+ ", i.e. expected either 1, 3 or 4 channels on axis "
1495
+ f"{self.channel_axis}. However, it was passed an array with"
1496
+ f" shape {x.shape} ({x.shape[self.channel_axis]} channels)."
1508
1497
  )
1509
1498
 
1510
1499
  if seed is not None:
@@ -102,7 +102,7 @@ class Tokenizer:
102
102
  num_words = kwargs.pop("nb_words")
103
103
  document_count = kwargs.pop("document_count", 0)
104
104
  if kwargs:
105
- raise TypeError("Unrecognized keyword arguments: " + str(kwargs))
105
+ raise TypeError(f"Unrecognized keyword arguments: {str(kwargs)}")
106
106
 
107
107
  self.word_counts = collections.OrderedDict()
108
108
  self.word_docs = collections.defaultdict(int)
@@ -773,7 +773,7 @@ def is_input_keras_tensor(x):
773
773
 
774
774
  def clone_single_keras_tensor(x):
775
775
  return backend.KerasTensor(
776
- shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=x.name + "_clone"
776
+ shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=f"{x.name}_clone"
777
777
  )
778
778
 
779
779
 
@@ -836,7 +836,7 @@ def clone_graph_nodes(inputs, outputs):
836
836
  batch_shape=kt_input.shape,
837
837
  dtype=kt_input.dtype,
838
838
  sparse=kt_input.sparse,
839
- name=kt_input.name + "CLONE",
839
+ name=f"{kt_input.name}CLONE",
840
840
  )
841
841
  cloned_inputs.append(cloned_input)
842
842
  kt_id_mapping[id(kt_input)] = cloned_input
keras/src/models/model.py CHANGED
@@ -8,6 +8,7 @@ from keras.src import utils
8
8
  from keras.src.api_export import keras_export
9
9
  from keras.src.layers.layer import Layer
10
10
  from keras.src.models.variable_mapping import map_saveable_variables
11
+ from keras.src.quantizers.gptq_config import GPTQConfig
11
12
  from keras.src.saving import saving_api
12
13
  from keras.src.trainers import trainer as base_trainer
13
14
  from keras.src.utils import summary_utils
@@ -420,7 +421,7 @@ class Model(Trainer, base_trainer.Trainer, Layer):
420
421
  **kwargs,
421
422
  )
422
423
 
423
- def quantize(self, mode, **kwargs):
424
+ def quantize(self, mode, config=None, **kwargs):
424
425
  """Quantize the weights of the model.
425
426
 
426
427
  Note that the model must be built first before calling this method.
@@ -433,6 +434,23 @@ class Model(Trainer, base_trainer.Trainer, Layer):
433
434
  """
434
435
  from keras.src.dtype_policies import QUANTIZATION_MODES
435
436
 
437
+ if mode == "gptq":
438
+ if not isinstance(config, GPTQConfig):
439
+ raise ValueError(
440
+ "The `config` argument must be of type "
441
+ "`keras.quantizers.GPTQConfig`."
442
+ )
443
+ # The config object's own quantize method drives the process
444
+ config.quantize(self)
445
+ return
446
+
447
+ # For all other modes, verify that a config object was not passed.
448
+ if config is not None:
449
+ raise ValueError(
450
+ f"The `config` argument is only supported for 'gptq' mode, "
451
+ f"but received mode='{mode}'."
452
+ )
453
+
436
454
  type_check = kwargs.pop("type_check", True)
437
455
  if kwargs:
438
456
  raise ValueError(
@@ -854,9 +872,9 @@ class Model(Trainer, base_trainer.Trainer, Layer):
854
872
  def _flatten(current_dict, prefix=""):
855
873
  for key, value in current_dict.items():
856
874
  if isinstance(value, dict):
857
- _flatten(value, prefix + key + "/")
875
+ _flatten(value, f"{prefix}{key}/")
858
876
  else:
859
- flat_dict[prefix + key] = value
877
+ flat_dict[f"{prefix}{key}"] = value
860
878
 
861
879
  _flatten(nested_dict)
862
880
  return flat_dict
keras/src/ops/function.py CHANGED
@@ -244,7 +244,7 @@ class Function(Operation):
244
244
 
245
245
 
246
246
  def make_node_key(op, node_index):
247
- return str(id(op)) + "_ib-" + str(node_index)
247
+ return f"{id(op)}_ib-{node_index}"
248
248
 
249
249
 
250
250
  def map_graph(inputs, outputs):
keras/src/ops/numpy.py CHANGED
@@ -2872,7 +2872,7 @@ class Einsum(Operation):
2872
2872
  kept_dims = sorted(kept_dims)
2873
2873
 
2874
2874
  if output_spec is None:
2875
- target_broadcast_spec = "..." + "".join(kept_dims)
2875
+ target_broadcast_spec = f"...{''.join(kept_dims)}"
2876
2876
  else:
2877
2877
  target_broadcast_spec = output_spec
2878
2878
 
@@ -2894,18 +2894,18 @@ class Einsum(Operation):
2894
2894
  )
2895
2895
  for size, s in zip(x_shape, split_spec[0]):
2896
2896
  # Replace the letter with the right shape.
2897
- expanded_shape = expanded_shape.replace(s, str(size) + " ")
2897
+ expanded_shape = expanded_shape.replace(s, f"{str(size)} ")
2898
2898
  expanded_shape = expanded_shape.replace("...", "")
2899
2899
  else:
2900
2900
  # In this case, the input spec has "...", e.g., "i...j", "i...",
2901
2901
  # or "...j".
2902
2902
  for i in range(len(split_spec[0])):
2903
2903
  expanded_shape = expanded_shape.replace(
2904
- split_spec[0][i], str(x_shape[i]) + " "
2904
+ split_spec[0][i], f"{x_shape[i]} "
2905
2905
  )
2906
2906
  for i in range(len(split_spec[1])):
2907
2907
  expanded_shape = expanded_shape.replace(
2908
- split_spec[1][-i - 1], str(x_shape[-i - 1]) + " "
2908
+ split_spec[1][-i - 1], f"{x_shape[-i - 1]} "
2909
2909
  )
2910
2910
  # Shape matched by "..." will be inserted to the position of
2911
2911
  # "...".
@@ -2919,7 +2919,7 @@ class Einsum(Operation):
2919
2919
  wildcard_shape_start_index:wildcard_shape_end_index
2920
2920
  ]
2921
2921
  wildcard_shape_str = (
2922
- " ".join([str(size) for size in wildcard_shape]) + " "
2922
+ f"{' '.join([str(size) for size in wildcard_shape])} "
2923
2923
  )
2924
2924
  expanded_shape = expanded_shape.replace(
2925
2925
  "...", wildcard_shape_str
@@ -3514,6 +3514,50 @@ def hstack(xs):
3514
3514
  return backend.numpy.hstack(xs)
3515
3515
 
3516
3516
 
3517
+ class Hypot(Operation):
3518
+ def call(self, x1, x2):
3519
+ return backend.numpy.hypot(x1, x2)
3520
+
3521
+ def compute_output_spec(self, x1, x2):
3522
+ dtype = dtypes.result_type(x1.dtype, x2.dtype)
3523
+ if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
3524
+ dtype = backend.floatx()
3525
+ elif dtype == "int64":
3526
+ dtype = "float64"
3527
+ return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype)
3528
+
3529
+
3530
+ @keras_export(["keras.ops.hypot", "keras.ops.numpy.hypot"])
3531
+ def hypot(x1, x2):
3532
+ """Element-wise hypotenuse of right triangles with legs `x1` and `x2`.
3533
+
3534
+ This is equivalent to computing `sqrt(x1**2 + x2**2)` element-wise,
3535
+ with shape determined by broadcasting.
3536
+
3537
+ Args:
3538
+ x1: A tensor, representing the first leg of the right triangle.
3539
+ x2: A tensor, representing the second leg of the right triangle.
3540
+
3541
+ Returns:
3542
+ A tensor with a shape determined by broadcasting `x1` and `x2`.
3543
+
3544
+ Example:
3545
+ >>> x1 = keras.ops.convert_to_tensor([3.0, 4.0, 5.0])
3546
+ >>> x2 = keras.ops.convert_to_tensor([4.0, 3.0, 12.0])
3547
+ >>> keras.ops.hypot(x1, x2)
3548
+ array([5., 5., 13.], dtype=float32)
3549
+
3550
+ >>> x1 = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
3551
+ >>> x2 = keras.ops.convert_to_tensor([1, 1])
3552
+ >>> keras.ops.hypot(x1, x2)
3553
+ array([[1.41421356 2.23606798],
3554
+ [3.16227766 4.12310563]], dtype=float32)
3555
+ """
3556
+ if any_symbolic_tensors((x1, x2)):
3557
+ return Hypot().symbolic_call(x1, x2)
3558
+ return backend.numpy.hypot(x1, x2)
3559
+
3560
+
3517
3561
  @keras_export(["keras.ops.identity", "keras.ops.numpy.identity"])
3518
3562
  def identity(n, dtype=None):
3519
3563
  """Return the identity tensor.
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import os.path
2
3
  import textwrap
3
4
 
4
5
  from keras.src import backend
@@ -19,10 +20,10 @@ class Operation(KerasSaveable):
19
20
  def __init__(self, name=None):
20
21
  if name is None:
21
22
  name = auto_name(self.__class__.__name__)
22
- if not isinstance(name, str) or "/" in name:
23
+ if not isinstance(name, str) or os.path.sep in name:
23
24
  raise ValueError(
24
25
  "Argument `name` must be a string and "
25
- "cannot contain character `/`. "
26
+ f"cannot contain character `{os.path.sep}`. "
26
27
  f"Received: name={name} (of type {type(name)})"
27
28
  )
28
29
  self.name = name
@@ -310,13 +310,12 @@ class BaseOptimizer(KerasSaveable):
310
310
  """
311
311
  name = name or "var"
312
312
  if hasattr(reference_variable, "path"):
313
- name = reference_variable.path.replace("/", "_") + "_" + name
313
+ name = f"{reference_variable.path.replace('/', '_')}_{name}"
314
314
  else:
315
- name = (
315
+ sanitised_ref_name = (
316
316
  str(reference_variable.name).replace("/", "_").replace(":", "_")
317
- + "_"
318
- + name
319
317
  )
318
+ name = f"{sanitised_ref_name}_{name}"
320
319
  return self.add_variable(
321
320
  shape=reference_variable.shape,
322
321
  initializer=initializer,
@@ -692,9 +692,11 @@ class CosineDecay(LearningRateSchedule):
692
692
 
693
693
  def _decay_function(self, step, decay_steps, decay_from_lr, dtype):
694
694
  with ops.name_scope(self.name):
695
- completed_fraction = step / decay_steps
695
+ completed_fraction = ops.divide(step, decay_steps)
696
696
  pi = ops.array(math.pi, dtype=dtype)
697
- cosine_decayed = 0.5 * (1.0 + ops.cos(pi * completed_fraction))
697
+ cosine_decayed = 0.5 * (
698
+ 1.0 + ops.cos(ops.multiply(pi, completed_fraction))
699
+ )
698
700
  decayed = (1 - self.alpha) * cosine_decayed + self.alpha
699
701
  return ops.multiply(decay_from_lr, decayed)
700
702
 
@@ -866,10 +868,13 @@ class CosineDecayRestarts(LearningRateSchedule):
866
868
  / ops.log(t_mul)
867
869
  )
868
870
 
869
- sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
870
- completed_fraction = (
871
- completed_fraction - sum_r
872
- ) / t_mul**i_restart
871
+ sum_r = ops.divide(
872
+ 1.0 - ops.power(t_mul, i_restart), (1.0 - t_mul)
873
+ )
874
+ completed_fraction = ops.divide(
875
+ ops.subtract(completed_fraction, sum_r),
876
+ ops.power(t_mul, i_restart),
877
+ )
873
878
 
874
879
  else:
875
880
  i_restart = ops.floor(completed_fraction)
@@ -883,18 +888,20 @@ class CosineDecayRestarts(LearningRateSchedule):
883
888
  lambda: compute_step(completed_fraction, geometric=True),
884
889
  )
885
890
 
886
- m_fac = m_mul**i_restart
891
+ m_fac = ops.power(m_mul, i_restart)
887
892
  cosine_decayed = (
888
893
  0.5
889
894
  * m_fac
890
895
  * (
891
896
  1.0
892
897
  + ops.cos(
893
- ops.array(math.pi, dtype=dtype) * completed_fraction
898
+ ops.multiply(
899
+ ops.array(math.pi, dtype=dtype), completed_fraction
900
+ )
894
901
  )
895
902
  )
896
903
  )
897
- decayed = (1 - alpha) * cosine_decayed + alpha
904
+ decayed = ops.add(ops.multiply((1 - alpha), cosine_decayed), alpha)
898
905
 
899
906
  return ops.multiply(initial_learning_rate, decayed)
900
907