keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082303__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. keras/_tf_keras/keras/ops/__init__.py +1 -0
  2. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  3. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  4. keras/ops/__init__.py +1 -0
  5. keras/ops/numpy/__init__.py +1 -0
  6. keras/quantizers/__init__.py +1 -0
  7. keras/src/applications/convnext.py +20 -20
  8. keras/src/applications/densenet.py +21 -21
  9. keras/src/applications/efficientnet.py +16 -16
  10. keras/src/applications/efficientnet_v2.py +28 -28
  11. keras/src/applications/inception_resnet_v2.py +7 -7
  12. keras/src/applications/inception_v3.py +5 -5
  13. keras/src/applications/mobilenet_v2.py +13 -20
  14. keras/src/applications/mobilenet_v3.py +15 -15
  15. keras/src/applications/nasnet.py +7 -8
  16. keras/src/applications/resnet.py +32 -32
  17. keras/src/applications/xception.py +10 -10
  18. keras/src/backend/common/dtypes.py +8 -3
  19. keras/src/backend/common/variables.py +3 -1
  20. keras/src/backend/jax/export.py +1 -1
  21. keras/src/backend/jax/numpy.py +6 -0
  22. keras/src/backend/jax/trainer.py +1 -1
  23. keras/src/backend/numpy/numpy.py +28 -0
  24. keras/src/backend/openvino/numpy.py +5 -1
  25. keras/src/backend/tensorflow/numpy.py +22 -0
  26. keras/src/backend/tensorflow/trainer.py +19 -1
  27. keras/src/backend/torch/core.py +6 -9
  28. keras/src/backend/torch/nn.py +1 -2
  29. keras/src/backend/torch/numpy.py +16 -0
  30. keras/src/backend/torch/trainer.py +1 -1
  31. keras/src/callbacks/backup_and_restore.py +2 -2
  32. keras/src/callbacks/csv_logger.py +1 -1
  33. keras/src/callbacks/model_checkpoint.py +1 -1
  34. keras/src/callbacks/tensorboard.py +6 -6
  35. keras/src/constraints/constraints.py +9 -7
  36. keras/src/datasets/boston_housing.py +1 -1
  37. keras/src/datasets/california_housing.py +1 -1
  38. keras/src/datasets/cifar10.py +1 -1
  39. keras/src/datasets/cifar100.py +2 -2
  40. keras/src/datasets/imdb.py +2 -2
  41. keras/src/datasets/mnist.py +1 -1
  42. keras/src/datasets/reuters.py +2 -2
  43. keras/src/dtype_policies/dtype_policy.py +1 -1
  44. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  45. keras/src/export/tf2onnx_lib.py +1 -3
  46. keras/src/initializers/constant_initializers.py +9 -5
  47. keras/src/layers/input_spec.py +6 -6
  48. keras/src/layers/layer.py +1 -1
  49. keras/src/layers/preprocessing/category_encoding.py +3 -3
  50. keras/src/layers/preprocessing/data_layer.py +159 -0
  51. keras/src/layers/preprocessing/discretization.py +3 -3
  52. keras/src/layers/preprocessing/feature_space.py +4 -4
  53. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  54. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  55. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  56. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  57. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  61. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  62. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  63. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  67. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  69. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  70. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  71. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  72. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  73. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  74. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  75. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  76. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  77. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  78. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  79. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  80. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  81. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  82. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  83. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  84. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  85. keras/src/layers/preprocessing/normalization.py +5 -2
  86. keras/src/layers/preprocessing/rescaling.py +3 -3
  87. keras/src/layers/rnn/bidirectional.py +4 -4
  88. keras/src/legacy/backend.py +9 -23
  89. keras/src/legacy/preprocessing/image.py +11 -22
  90. keras/src/legacy/preprocessing/text.py +1 -1
  91. keras/src/models/functional.py +2 -2
  92. keras/src/models/model.py +21 -3
  93. keras/src/ops/function.py +1 -1
  94. keras/src/ops/numpy.py +49 -5
  95. keras/src/ops/operation.py +3 -2
  96. keras/src/optimizers/base_optimizer.py +3 -4
  97. keras/src/optimizers/schedules/learning_rate_schedule.py +16 -9
  98. keras/src/quantizers/gptq.py +350 -0
  99. keras/src/quantizers/gptq_config.py +169 -0
  100. keras/src/quantizers/gptq_core.py +335 -0
  101. keras/src/quantizers/gptq_quant.py +133 -0
  102. keras/src/saving/file_editor.py +22 -20
  103. keras/src/saving/object_registration.py +1 -1
  104. keras/src/saving/saving_lib.py +4 -4
  105. keras/src/saving/serialization_lib.py +3 -5
  106. keras/src/trainers/compile_utils.py +1 -1
  107. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  108. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  109. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  110. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  111. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  112. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  113. keras/src/tree/dmtree_impl.py +19 -3
  114. keras/src/tree/optree_impl.py +3 -3
  115. keras/src/tree/tree_api.py +5 -2
  116. keras/src/utils/file_utils.py +13 -5
  117. keras/src/utils/io_utils.py +1 -1
  118. keras/src/utils/model_visualization.py +1 -1
  119. keras/src/utils/progbar.py +5 -5
  120. keras/src/utils/summary_utils.py +4 -4
  121. keras/src/version.py +1 -1
  122. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/METADATA +1 -1
  123. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/RECORD +125 -121
  124. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  125. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/WHEEL +0 -0
  126. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082303.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/trainer.py CHANGED
@@ -1,4 +1,5 @@
 import contextlib
+import functools
 import warnings
 
 import numpy as np
@@ -107,6 +108,21 @@ class TensorFlowTrainer(base_trainer.Trainer):
            y_pred = self(x)
            return y_pred
 
+    def _autoconvert_optionals(self, step_func):
+        # Wrapper converting (nested) TF Optional in input data to None
+        @functools.wraps(step_func)
+        def wrapper(data):
+            converted_data = tree.map_structure(
+                lambda i: (
+                    None if isinstance(i, tf.experimental.Optional) else i
+                ),
+                data,
+            )
+            result = step_func(converted_data)
+            return result
+
+        return wrapper
+
    def _make_function(self, step_function):
        @tf.autograph.experimental.do_not_convert
        def one_step_on_data(data):
@@ -125,6 +141,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
                reduce_retracing=True,
                jit_compile=self.jit_compile,
            )
+        one_step_on_data = self._autoconvert_optionals(one_step_on_data)
 
        @tf.autograph.experimental.do_not_convert
        def multi_step_on_iterator(iterator):
@@ -253,6 +270,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
            one_step_on_data = tf.function(
                one_step_on_data, reduce_retracing=True, jit_compile=True
            )
+        one_step_on_data = self._autoconvert_optionals(one_step_on_data)
 
        @tf.autograph.experimental.do_not_convert
        def one_step_on_data_distributed(data):
@@ -409,7 +427,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
                _use_cached_eval_dataset=True,
            )
            val_logs = {
-                "val_" + name: val for name, val in val_logs.items()
+                f"val_{name}": val for name, val in val_logs.items()
            }
            epoch_logs.update(val_logs)
 
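For context: `tf.data` iterators can represent possibly-missing components (such as absent labels or sample weights) as `tf.experimental.Optional` values, and the new `_autoconvert_optionals` wrapper normalizes those to `None` before each step function runs. A minimal standalone sketch of the same idea (the `demo_step` function is a hypothetical stand-in for a compiled train/test/predict step, not part of the diff):

```python
import tensorflow as tf
from keras import tree

def autoconvert_optionals(step_func):
    # Replace tf.experimental.Optional leaves with None, leaving
    # ordinary tensors untouched.
    def wrapper(data):
        converted = tree.map_structure(
            lambda i: None if isinstance(i, tf.experimental.Optional) else i,
            data,
        )
        return step_func(converted)

    return wrapper

def demo_step(data):
    return data

x = tf.ones((2, 3))
missing_y = tf.experimental.Optional.empty(tf.TensorSpec([], tf.float32))
print(autoconvert_optionals(demo_step)((x, missing_y)))
# (<tf.Tensor: shape=(2, 3), ...>, None)
```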
keras/src/backend/torch/core.py CHANGED
@@ -191,21 +191,18 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
        raise ValueError("`sparse=True` is not supported with torch backend")
    if ragged:
        raise ValueError("`ragged=True` is not supported with torch backend")
-    if isinstance(x, Variable):
-        if dtype is None:
-            return x.value
-        x = x.value
-        return x.to(to_torch_dtype(dtype))
-    if is_tensor(x):
+    if isinstance(x, Variable) or is_tensor(x):
+        if isinstance(x, Variable):
+            x = x.value
        device = get_device()
        if x.device != device:
            if x.is_meta:
                x = torch.empty_like(x, device=device)
            else:
                x = x.to(device)
-        if dtype is None:
-            return x
-        return x.to(to_torch_dtype(dtype))
+        if dtype is not None:
+            x = x.to(to_torch_dtype(dtype))
+        return x
    if dtype is None:
        if isinstance(x, bool):
            return torch.as_tensor(x, dtype=torch.bool, device=get_device())
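The refactor above merges the previously separate `Variable` and tensor branches without changing conversion semantics: device placement still happens first, and the dtype cast is applied only when one is requested. A hedged sketch of the resulting behavior, assuming the torch backend is active (`KERAS_BACKEND=torch`):

```python
import torch
from keras import ops

t = torch.ones(3, dtype=torch.float32)
print(ops.convert_to_tensor(t).dtype)             # torch.float32 (unchanged)
print(ops.convert_to_tensor(t, "float64").dtype)  # torch.float64 (explicit cast)
```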
keras/src/backend/torch/nn.py CHANGED
@@ -9,7 +9,6 @@ from keras.src.backend.torch.core import cast
 from keras.src.backend.torch.core import convert_to_tensor
 from keras.src.backend.torch.core import get_device
 from keras.src.backend.torch.numpy import expand_dims
-from keras.src.backend.torch.numpy import maximum
 from keras.src.backend.torch.numpy import where
 from keras.src.utils.argument_validation import standardize_tuple
 
@@ -668,7 +667,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
    # manual handling for negatives in the input to one_hot by using max(x, 0).
    # The output will have some invalid results, so we set them back to 0 using
    # `where` afterwards.
-    output = tnn.one_hot(maximum(x, 0), num_classes)
+    output = tnn.one_hot(torch.clamp(x, min=0), num_classes)
    output = where(expand_dims(x, axis=-1) >= 0, output, zero)
    output = convert_to_tensor(output, dtype=dtype)
    dims = output.dim()
keras/src/backend/torch/numpy.py CHANGED
@@ -854,6 +854,22 @@ def hstack(xs):
    return torch.hstack(xs)
 
 
+def hypot(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
+        dtype = config.floatx()
+    elif dtype == "int64":
+        dtype = "float64"
+
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+
+    return torch.hypot(x1, x2)
+
+
 def identity(n, dtype=None):
    dtype = to_torch_dtype(dtype or config.floatx())
 
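The new kernel mirrors `numpy.hypot` (elementwise `sqrt(x1**2 + x2**2)`), promoting integer inputs to a float dtype before calling `torch.hypot`. Assuming the op is exposed as `keras.ops.hypot` (consistent with the one-line additions to the `ops/numpy/__init__.py` files listed above), usage would look like:

```python
import numpy as np
from keras import ops

# Hypotenuse of 3-4-5 and 5-12-13 right triangles, computed elementwise.
result = ops.hypot(np.array([3.0, 5.0]), np.array([4.0, 12.0]))
print(ops.convert_to_numpy(result))  # [ 5. 13.]
```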
keras/src/backend/torch/trainer.py CHANGED
@@ -299,7 +299,7 @@ class TorchTrainer(base_trainer.Trainer):
                _use_cached_eval_dataset=True,
            )
            val_logs = {
-                "val_" + name: val for name, val in val_logs.items()
+                f"val_{name}": val for name, val in val_logs.items()
            }
            epoch_logs.update(val_logs)
 
keras/src/callbacks/backup_and_restore.py CHANGED
@@ -99,9 +99,9 @@ class BackupAndRestore(Callback):
        self._training_metadata_path = file_utils.join(
            backup_dir, "training_metadata.json"
        )
-        self._prev_weights_path = self._weights_path + ".bkp"
+        self._prev_weights_path = f"{self._weights_path}.bkp"
        self._prev_training_metadata_path = (
-            self._training_metadata_path + ".bkp"
+            f"{self._training_metadata_path}.bkp"
        )
        if save_freq != "epoch" and not isinstance(save_freq, int):
            raise ValueError(
keras/src/callbacks/csv_logger.py CHANGED
@@ -79,7 +79,7 @@ class CSVLogger(Callback):
                    val_keys_found = True
                    break
            if not val_keys_found and self.keys:
-                self.keys.extend(["val_" + k for k in self.keys])
+                self.keys.extend([f"val_{k}" for k in self.keys])
 
        if not self.writer:
 
keras/src/callbacks/model_checkpoint.py CHANGED
@@ -372,7 +372,7 @@ class ModelCheckpoint(MonitorCallback):
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
-        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"
+        base_name_regex = f"^{re.sub(r'{.*}', r'.*', base_name)}$"
 
        latest_mod_time = 0
        file_path_with_latest_mod_time = None
keras/src/callbacks/tensorboard.py CHANGED
@@ -424,7 +424,7 @@ class TensorBoard(Callback):
            with self._val_writer.as_default():
                for name, value in logs.items():
                    self.summary.scalar(
-                        "evaluation_" + name + "_vs_iterations",
+                        f"evaluation_{name}_vs_iterations",
                        value,
                        step=self.model.optimizer.iterations,
                    )
@@ -460,7 +460,7 @@ class TensorBoard(Callback):
        if isinstance(logs, dict):
            for name, value in logs.items():
                self.summary.scalar(
-                    "batch_" + name, value, step=self._global_train_batch
+                    f"batch_{name}", value, step=self._global_train_batch
                )
 
        if not self._should_trace:
@@ -548,12 +548,12 @@ class TensorBoard(Callback):
        if train_logs:
            with self._train_writer.as_default():
                for name, value in train_logs.items():
-                    self.summary.scalar("epoch_" + name, value, step=epoch)
+                    self.summary.scalar(f"epoch_{name}", value, step=epoch)
        if val_logs:
            with self._val_writer.as_default():
                for name, value in val_logs.items():
                    name = name[4:]  # Remove 'val_' prefix.
-                    self.summary.scalar("epoch_" + name, value, step=epoch)
+                    self.summary.scalar(f"epoch_{name}", value, step=epoch)
 
    def _log_weights(self, epoch):
        """Logs the weights of the Model to TensorBoard."""
@@ -562,14 +562,14 @@ class TensorBoard(Callback):
                for weight in layer.weights:
                    weight_name = weight.name.replace(":", "_")
                    # Add a suffix to prevent summary tag name collision.
-                    histogram_weight_name = weight_name + "/histogram"
+                    histogram_weight_name = f"{weight_name}/histogram"
                    self.summary.histogram(
                        histogram_weight_name, weight, step=epoch
                    )
                    if self.write_images:
                        # Add a suffix to prevent summary tag name
                        # collision.
-                        image_weight_name = weight_name + "/image"
+                        image_weight_name = f"{weight_name}/image"
                        self._log_weight_as_image(
                            weight, image_weight_name, epoch
                        )
keras/src/constraints/constraints.py CHANGED
@@ -110,7 +110,9 @@ class MaxNorm(Constraint):
        w = backend.convert_to_tensor(w)
        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
        desired = ops.clip(norms, 0, self.max_value)
-        return w * (desired / (backend.epsilon() + norms))
+        return ops.cast(w, norms.dtype) * (
+            desired / (backend.epsilon() + norms)
+        )
 
    def get_config(self):
        return {"max_value": self.max_value, "axis": self.axis}
@@ -122,7 +124,7 @@ class NonNeg(Constraint):
 
    def __call__(self, w):
        w = backend.convert_to_tensor(w)
-        return w * ops.cast(ops.greater_equal(w, 0.0), dtype=w.dtype)
+        return ops.multiply(w, ops.greater_equal(w, 0.0))
 
 
 @keras_export(["keras.constraints.UnitNorm", "keras.constraints.unit_norm"])
@@ -148,10 +150,8 @@ class UnitNorm(Constraint):
 
    def __call__(self, w):
        w = backend.convert_to_tensor(w)
-        return w / (
-            backend.epsilon()
-            + ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
-        )
+        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
+        return ops.cast(w, norms.dtype) / (backend.epsilon() + norms)
 
    def get_config(self):
        return {"axis": self.axis}
@@ -202,7 +202,9 @@ class MinMaxNorm(Constraint):
            self.rate * ops.clip(norms, self.min_value, self.max_value)
            + (1 - self.rate) * norms
        )
-        return w * (desired / (backend.epsilon() + norms))
+        return ops.cast(w, norms.dtype) * (
+            desired / (backend.epsilon() + norms)
+        )
 
    def get_config(self):
        return {
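The recurring `ops.cast(w, norms.dtype)` change makes the weight/norm arithmetic dtype-safe; the constraint math itself is unchanged. A quick sanity check of `MaxNorm` semantics (an illustration, not taken from the diff):

```python
import numpy as np
from keras import constraints, ops

# Column norms are 5.0 and 1.0; with max_value=2.0 only the first
# column is rescaled (down to norm 2.0), the second is left alone.
w = np.array([[3.0, 0.0], [4.0, 1.0]])
constrained = constraints.MaxNorm(max_value=2.0, axis=0)(w)
print(ops.convert_to_numpy(constrained))
# approximately [[1.2, 0.0], [1.6, 1.0]]
```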
keras/src/datasets/boston_housing.py CHANGED
@@ -48,7 +48,7 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
    )
    path = get_file(
        path,
-        origin=origin_folder + "boston_housing.npz",
+        origin=f"{origin_folder}boston_housing.npz",
        file_hash=(  # noqa: E501
            "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
        ),
keras/src/datasets/california_housing.py CHANGED
@@ -73,7 +73,7 @@ def load_data(
    )
    path = get_file(
        path,
-        origin=origin_folder + "california_housing.npz",
+        origin=f"{origin_folder}california_housing.npz",
        file_hash=(  # noqa: E501
            "1a2e3a52e0398de6463aebe6f4a8da34fb21fbb6b934cf88c3425e766f2a1a6f"
        ),
keras/src/datasets/cifar10.py CHANGED
@@ -79,7 +79,7 @@ def load_data():
    # batches are within an inner folder
    path = os.path.join(path, "cifar-10-batches-py")
    for i in range(1, 6):
-        fpath = os.path.join(path, "data_batch_" + str(i))
+        fpath = os.path.join(path, f"data_batch_{i}")
        (
            x_train[(i - 1) * 10000 : i * 10000, :, :, :],
            y_train[(i - 1) * 10000 : i * 10000],
keras/src/datasets/cifar100.py CHANGED
@@ -71,10 +71,10 @@ def load_data(label_mode="fine"):
 
    path = os.path.join(path, "cifar-100-python")
    fpath = os.path.join(path, "train")
-    x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")
+    x_train, y_train = load_batch(fpath, label_key=f"{label_mode}_labels")
 
    fpath = os.path.join(path, "test")
-    x_test, y_test = load_batch(fpath, label_key=label_mode + "_labels")
+    x_test, y_test = load_batch(fpath, label_key=f"{label_mode}_labels")
 
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
keras/src/datasets/imdb.py CHANGED
@@ -78,7 +78,7 @@ def load_data(
    )
    path = get_file(
        fname=path,
-        origin=origin_folder + "imdb.npz",
+        origin=f"{origin_folder}imdb.npz",
        file_hash=(  # noqa: E501
            "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
        ),
@@ -181,7 +181,7 @@ def get_word_index(path="imdb_word_index.json"):
    )
    path = get_file(
        fname=path,
-        origin=origin_folder + "imdb_word_index.json",
+        origin=f"{origin_folder}imdb_word_index.json",
        file_hash="bfafd718b763782e994055a2d397834f",
    )
    with open(path) as f:
keras/src/datasets/mnist.py CHANGED
@@ -59,7 +59,7 @@ def load_data(path="mnist.npz"):
    )
    path = get_file(
        fname=path,
-        origin=origin_folder + "mnist.npz",
+        origin=f"{origin_folder}mnist.npz",
        file_hash=(  # noqa: E501
            "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
        ),
keras/src/datasets/reuters.py CHANGED
@@ -87,7 +87,7 @@ def load_data(
    )
    path = get_file(
        fname=path,
-        origin=origin_folder + "reuters.npz",
+        origin=f"{origin_folder}reuters.npz",
        file_hash=(  # noqa: E501
            "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
        ),
@@ -156,7 +156,7 @@ def get_word_index(path="reuters_word_index.json"):
    )
    path = get_file(
        path,
-        origin=origin_folder + "reuters_word_index.json",
+        origin=f"{origin_folder}reuters_word_index.json",
        file_hash="4d44cc38712099c9e383dc6e5f11a921",
    )
    with open(path) as f:
keras/src/dtype_policies/dtype_policy.py CHANGED
@@ -3,7 +3,7 @@ from keras.src import ops
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8", "int4")
+QUANTIZATION_MODES = ("int8", "float8", "int4", "gptq")
 
 
 @keras_export(
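Adding `"gptq"` to `QUANTIZATION_MODES` is what makes the new mode a legal argument wherever quantization modes are validated, such as `Model.quantize()`; the implementation itself lives in the new `keras/src/quantizers/gptq*.py` files listed above. A hedged sketch using an existing mode (how GPTQ is configured is defined in `gptq_config.py`, which is not shown in this section):

```python
import numpy as np
from keras import layers, models

model = models.Sequential([layers.Input((8,)), layers.Dense(4)])
model.quantize("int8")  # "gptq" is newly accepted by this validation tuple
print(model.predict(np.ones((1, 8)), verbose=0).shape)  # (1, 4)
```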
keras/src/dtype_policies/dtype_policy_map.py CHANGED
@@ -74,7 +74,7 @@ class DTypePolicyMap(DTypePolicy, MutableMapping):
 
    @property
    def name(self):
-        return "map_" + self.default_policy._name
+        return f"map_{self.default_policy._name}"
 
    @property
    def default_policy(self):
keras/src/export/tf2onnx_lib.py CHANGED
@@ -157,9 +157,7 @@ def patch_tf2onnx():
            ):
                a = copy.deepcopy(a)
                tensor_name = (
-                    self.name.strip()
-                    + "_"
-                    + str(external_tensor_storage.name_counter)
+                    f"{self.name.strip()}_{external_tensor_storage.name_counter}"
                )
                for c in '~"#%&*:<>?/\\{|}':
                    tensor_name = tensor_name.replace(c, "_")
keras/src/initializers/constant_initializers.py CHANGED
@@ -253,14 +253,18 @@ class STFT(Initializer):
        scaling = ops.sum(ops.abs(win))
 
        _fft_length = (fft_length - 1) * 2
-        freq = (
-            ops.reshape(ops.arange(fft_length, dtype=dtype), (1, 1, fft_length))
-            / _fft_length
+        freq = ops.divide(
+            ops.reshape(
+                ops.arange(fft_length, dtype=dtype), (1, 1, fft_length)
+            ),
+            _fft_length,
        )
        time = ops.reshape(
            ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1)
        )
-        args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype))
+        args = ops.multiply(ops.multiply(-2, time), freq) * ops.arccos(
+            ops.cast(-1, dtype)
+        )
 
        if self.side == "real":
            kernel = ops.cast(ops.cos(args), dtype)
@@ -268,7 +272,7 @@ class STFT(Initializer):
            kernel = ops.cast(ops.sin(args), dtype)
 
        if win is not None:
-            kernel = kernel * win / scaling
+            kernel = ops.divide(ops.multiply(kernel, win), scaling)
        return kernel
 
    def get_config(self):
keras/src/layers/input_spec.py CHANGED
@@ -94,12 +94,12 @@ class InputSpec:
 
    def __repr__(self):
        spec = [
-            ("dtype=" + str(self.dtype)) if self.dtype else "",
-            ("shape=" + str(self.shape)) if self.shape else "",
-            ("ndim=" + str(self.ndim)) if self.ndim else "",
-            ("max_ndim=" + str(self.max_ndim)) if self.max_ndim else "",
-            ("min_ndim=" + str(self.min_ndim)) if self.min_ndim else "",
-            ("axes=" + str(self.axes)) if self.axes else "",
+            (f"dtype={str(self.dtype)}") if self.dtype else "",
+            (f"shape={str(self.shape)}") if self.shape else "",
+            (f"ndim={str(self.ndim)}") if self.ndim else "",
+            (f"max_ndim={str(self.max_ndim)}") if self.max_ndim else "",
+            (f"min_ndim={str(self.min_ndim)}") if self.min_ndim else "",
+            (f"axes={str(self.axes)}") if self.axes else "",
        ]
        return f"InputSpec({', '.join(x for x in spec if x)})"
 
keras/src/layers/layer.py CHANGED
@@ -1337,7 +1337,7 @@ class Layer(BackendLayer, Operation):
        else:
            attr_name = str(attr)
            attr_type = "attribute"
-        msg = " " + msg if msg is not None else ""
+        msg = f" {msg}" if msg is not None else ""
        return NotImplementedError(
            f"Layer {self.__class__.__name__} does not have a `{attr_name}` "
            f"{attr_type} implemented.{msg}"
keras/src/layers/preprocessing/category_encoding.py CHANGED
@@ -1,12 +1,12 @@
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
-from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras.src.layers.preprocessing.data_layer import DataLayer
 from keras.src.utils import backend_utils
 from keras.src.utils import numerical_utils
 
 
 @keras_export("keras.layers.CategoryEncoding")
-class CategoryEncoding(TFDataLayer):
+class CategoryEncoding(DataLayer):
    """A preprocessing layer which encodes integer features.
 
    This layer provides options for condensing data into a categorical encoding
@@ -15,7 +15,7 @@ class CategoryEncoding(TFDataLayer):
    inputs. For integer inputs where the total number of tokens is not known,
    use `keras.layers.IntegerLookup` instead.
 
-    **Note:** This layer is safe to use inside a `tf.data` pipeline
+    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
    (independently of which backend you're using).
 
    Examples:
keras/src/layers/preprocessing/data_layer.py ADDED
@@ -0,0 +1,159 @@
+import keras.src.backend
+from keras.src import tree
+from keras.src.layers.layer import Layer
+from keras.src.random.seed_generator import SeedGenerator
+from keras.src.utils import backend_utils
+from keras.src.utils import jax_utils
+from keras.src.utils import tracking
+
+
+class DataLayer(Layer):
+    """Layer designed for safe use in `tf.data` or `grain` pipeline.
+
+    This layer overrides the `__call__` method to ensure that the correct
+    backend is used and that computation is performed on the CPU.
+
+    The `call()` method in subclasses should use `self.backend` ops. If
+    randomness is needed, define both `seed` and `generator` in `__init__` and
+    retrieve the running seed using `self._get_seed_generator()`. If the layer
+    has weights in `__init__` or `build()`, use `convert_weight()` to ensure
+    they are in the correct backend.
+
+    **Note:** This layer and its subclasses only support a single input tensor.
+
+    Examples:
+
+    **Custom `DataLayer` subclass:**
+
+    ```python
+    from keras.src.layers.preprocessing.data_layer import DataLayer
+    from keras.src.random import SeedGenerator
+
+
+    class BiasedRandomRGBToHSVLayer(DataLayer):
+        def __init__(self, seed=None, **kwargs):
+            super().__init__(**kwargs)
+            self.probability_bias = ops.convert_to_tensor(0.01)
+            self.seed = seed
+            self.generator = SeedGenerator(seed)
+
+        def call(self, inputs):
+            images_shape = self.backend.shape(inputs)
+            batch_size = 1 if len(images_shape) == 3 else images_shape[0]
+            seed = self._get_seed_generator(self.backend._backend)
+
+            probability = self.backend.random.uniform(
+                shape=(batch_size,),
+                minval=0.0,
+                maxval=1.0,
+                seed=seed,
+            )
+            probability = self.backend.numpy.add(
+                probability, self.convert_weight(self.probability_bias)
+            )
+            hsv_images = self.backend.image.rgb_to_hsv(inputs)
+            return self.backend.numpy.where(
+                probability[:, None, None, None] > 0.5,
+                hsv_images,
+                inputs,
+            )
+
+        def compute_output_shape(self, input_shape):
+            return input_shape
+    ```
+
+    **Using as a regular Keras layer:**
+
+    ```python
+    import numpy as np
+
+    x = np.random.uniform(size=(1, 16, 16, 3)).astype("float32")
+    print(BiasedRandomRGBToHSVLayer()(x).shape)  # (1, 16, 16, 3)
+    ```
+
+    **Using in a `tf.data` pipeline:**
+
+    ```python
+    import tensorflow as tf
+
+    tf_ds = tf.data.Dataset.from_tensors(x)
+    tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer())
+    print([x.shape for x in tf_ds])  # [(1, 16, 16, 3)]
+    ```
+
+    **Using in a `grain` pipeline:**
+
+    ```python
+    import grain
+
+    grain_ds = grain.MapDataset.source([x])
+    grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer())
+    print([x.shape for x in grain_ds])  # [(1, 16, 16, 3)]
+    ```
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.backend = backend_utils.DynamicBackend()
+        self._allow_non_tensor_positional_args = True
+
+    def __call__(self, inputs, **kwargs):
+        sample_input = tree.flatten(inputs)[0]
+        if (
+            not isinstance(sample_input, keras.KerasTensor)
+            and backend_utils.in_tf_graph()
+            and not jax_utils.is_in_jax_tracing_scope(sample_input)
+        ):
+            # We're in a TF graph, e.g. a tf.data pipeline.
+            self.backend.set_backend("tensorflow")
+            inputs = tree.map_structure(
+                lambda x: self.backend.convert_to_tensor(
+                    x, dtype=self.compute_dtype
+                ),
+                inputs,
+            )
+            switch_convert_input_args = False
+            if self._convert_input_args:
+                self._convert_input_args = False
+                switch_convert_input_args = True
+            try:
+                outputs = super().__call__(inputs, **kwargs)
+            finally:
+                self.backend.reset()
+                if switch_convert_input_args:
+                    self._convert_input_args = True
+            return outputs
+        elif (
+            not isinstance(sample_input, keras.KerasTensor)
+            and backend_utils.in_grain_data_pipeline()
+        ):
+            # We're in a Grain data pipeline. Force computation and data
+            # placement to CPU.
+            with keras.src.backend.device_scope("cpu"):
+                return super().__call__(inputs, **kwargs)
+        else:
+            return super().__call__(inputs, **kwargs)
+
+    @tracking.no_automatic_dependency_tracking
+    def _get_seed_generator(self, backend=None):
+        if not hasattr(self, "seed") or not hasattr(self, "generator"):
+            raise ValueError(
+                "The `seed` and `generator` variable must be set in the "
+                "`__init__` method before calling `_get_seed_generator()`."
+            )
+        if backend is None or backend == keras.backend.backend():
+            return self.generator
+        if not hasattr(self, "_backend_generators"):
+            self._backend_generators = {}
+        if backend in self._backend_generators:
+            return self._backend_generators[backend]
+        seed_generator = SeedGenerator(self.seed, backend=self.backend)
+        self._backend_generators[backend] = seed_generator
+        return seed_generator
+
+    def convert_weight(self, weight):
+        """Convert the weight if it is from the a different backend."""
+        if self.backend.name == keras.backend.backend():
+            return weight
+        else:
+            weight = keras.ops.convert_to_numpy(weight)
+            return self.backend.convert_to_tensor(weight)
keras/src/layers/preprocessing/discretization.py CHANGED
@@ -2,21 +2,21 @@ import numpy as np
 
 from keras.src import backend
 from keras.src.api_export import keras_export
-from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras.src.layers.preprocessing.data_layer import DataLayer
 from keras.src.utils import argument_validation
 from keras.src.utils import numerical_utils
 from keras.src.utils.module_utils import tensorflow as tf
 
 
 @keras_export("keras.layers.Discretization")
-class Discretization(TFDataLayer):
+class Discretization(DataLayer):
    """A preprocessing layer which buckets continuous features by ranges.
 
    This layer will place each element of its input data into one of several
    contiguous ranges and output an integer index indicating which range each
    element was placed in.
 
-    **Note:** This layer is safe to use inside a `tf.data` pipeline
+    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline
    (independently of which backend you're using).
 
    Input shape:
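Since `Discretization` now derives from `DataLayer`, the `tf.data`/`grain` note above is backed by the backend-switching `__call__` shown earlier. A minimal sketch, assuming the standard `bin_boundaries` constructor argument (not shown in this truncated hunk):

```python
import numpy as np
import tensorflow as tf
from keras import layers

# Bucket values inside a tf.data pipeline; DataLayer switches the layer
# to the TensorFlow backend for the duration of the map call.
layer = layers.Discretization(bin_boundaries=[0.0, 1.0, 2.0])
ds = tf.data.Dataset.from_tensor_slices(np.array([[-0.5, 0.5], [1.5, 2.5]]))
ds = ds.map(layer)
print([x.numpy() for x in ds])  # [array([0, 1]), array([2, 3])]
```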