tf-keras-nightly 2.17.0.dev2024031909__py3-none-any.whl → 2.19.0.dev2025011410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/src/__init__.py +1 -1
  3. tf_keras/src/backend.py +1 -1
  4. tf_keras/src/callbacks.py +24 -7
  5. tf_keras/src/datasets/boston_housing.py +14 -5
  6. tf_keras/src/datasets/cifar10.py +9 -1
  7. tf_keras/src/datasets/cifar100.py +7 -1
  8. tf_keras/src/datasets/fashion_mnist.py +16 -4
  9. tf_keras/src/datasets/imdb.py +8 -0
  10. tf_keras/src/datasets/mnist.py +9 -3
  11. tf_keras/src/datasets/reuters.py +8 -0
  12. tf_keras/src/engine/base_layer.py +10 -4
  13. tf_keras/src/engine/base_layer_v1.py +10 -4
  14. tf_keras/src/engine/node.py +8 -3
  15. tf_keras/src/layers/activation/prelu.py +1 -1
  16. tf_keras/src/layers/attention/base_dense_attention.py +2 -1
  17. tf_keras/src/layers/convolutional/base_conv.py +1 -1
  18. tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
  19. tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
  20. tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
  21. tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
  22. tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
  23. tf_keras/src/layers/core/dense.py +1 -1
  24. tf_keras/src/layers/core/embedding.py +1 -1
  25. tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
  26. tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
  27. tf_keras/src/layers/normalization/batch_normalization.py +1 -1
  28. tf_keras/src/layers/normalization/layer_normalization.py +1 -1
  29. tf_keras/src/layers/normalization/unit_normalization.py +2 -1
  30. tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
  31. tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
  32. tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
  33. tf_keras/src/layers/rnn/base_rnn.py +1 -1
  34. tf_keras/src/layers/rnn/base_wrapper.py +1 -1
  35. tf_keras/src/layers/rnn/bidirectional.py +2 -1
  36. tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
  37. tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
  38. tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
  39. tf_keras/src/layers/rnn/gru.py +35 -47
  40. tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
  41. tf_keras/src/layers/rnn/legacy_cells.py +20 -25
  42. tf_keras/src/layers/rnn/lstm.py +35 -50
  43. tf_keras/src/layers/rnn/simple_rnn.py +0 -1
  44. tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
  45. tf_keras/src/layers/rnn/time_distributed.py +0 -1
  46. tf_keras/src/mixed_precision/autocast_variable.py +12 -6
  47. tf_keras/src/mixed_precision/test_util.py +6 -5
  48. tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
  49. tf_keras/src/optimizers/optimizer.py +18 -9
  50. tf_keras/src/premade_models/linear.py +2 -1
  51. tf_keras/src/saving/legacy/saved_model/json_utils.py +1 -1
  52. tf_keras/src/saving/saving_api.py +165 -127
  53. tf_keras/src/saving/saving_lib.py +1 -11
  54. tf_keras/src/saving/serialization_lib.py +1 -10
  55. tf_keras/src/utils/data_utils.py +1 -1
  56. tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
  57. tf_keras/src/utils/tf_utils.py +2 -2
  58. tf_keras/src/utils/timeseries_dataset.py +13 -5
  59. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA +14 -3
  60. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/RECORD +62 -62
  61. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/WHEEL +1 -1
  62. {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/top_level.txt +0 -0
@@ -171,14 +171,14 @@ class MultiplyLayer(AssertTypeLayer):
171
171
  activity_regularizer=self._activity_regularizer, **kwargs
172
172
  )
173
173
 
174
- def build(self, _):
174
+ def build(self, input_shape):
175
175
  self.v = self.add_weight(
176
176
  self._var_name,
177
177
  (),
178
178
  initializer="ones",
179
179
  regularizer=self._regularizer,
180
180
  )
181
- self.built = True
181
+ super().build(input_shape)
182
182
 
183
183
  def call(self, inputs):
184
184
  self.assert_input_types(inputs)
@@ -205,7 +205,7 @@ class MultiplyLayer(AssertTypeLayer):
205
205
  class MultiplyLayerWithoutAutoCast(MultiplyLayer):
206
206
  """Same as MultiplyLayer, but does not use AutoCastVariables."""
207
207
 
208
- def build(self, _):
208
+ def build(self, input_shape):
209
209
  dtype = self.dtype
210
210
  if dtype in ("float16", "bfloat16"):
211
211
  dtype = "float32"
@@ -214,10 +214,11 @@ class MultiplyLayerWithoutAutoCast(MultiplyLayer):
214
214
  (),
215
215
  initializer="ones",
216
216
  dtype=dtype,
217
- experimental_autocast=False,
217
+ autocast=False,
218
218
  regularizer=self._regularizer,
219
219
  )
220
- self.built = True
220
+ # Call Layer.build() to skip MultiplyLayer.build() which we override.
221
+ base_layer.Layer.build(self, input_shape)
221
222
 
222
223
  def call(self, inputs):
223
224
  self.assert_input_types(inputs)
@@ -1033,6 +1033,13 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
1033
1033
  slot_dict = self._slots.setdefault(var_key, {})
1034
1034
  weight = slot_dict.get(slot_name, None)
1035
1035
  if weight is None:
1036
+ # Under a mixed precision policy, variables report their "cast"
1037
+ # dtype. However, we want to use the original dtype for slots.
1038
+ if hasattr(var, "true_dtype"):
1039
+ dtype = var.true_dtype
1040
+ else:
1041
+ dtype = var.dtype
1042
+
1036
1043
  if isinstance(initializer, str) or callable(initializer):
1037
1044
  initializer = initializers.get(initializer)
1038
1045
  if isinstance(
@@ -1043,7 +1050,7 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
1043
1050
  else:
1044
1051
  slot_shape = var.shape
1045
1052
  initial_value = functools.partial(
1046
- initializer, shape=slot_shape, dtype=var.dtype
1053
+ initializer, shape=slot_shape, dtype=dtype
1047
1054
  )
1048
1055
  else:
1049
1056
  initial_value = initializer
@@ -1064,7 +1071,7 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
1064
1071
  with strategy.extended.colocate_vars_with(var):
1065
1072
  weight = tf.Variable(
1066
1073
  name=f"{var._shared_name}/{slot_name}",
1067
- dtype=var.dtype,
1074
+ dtype=dtype,
1068
1075
  trainable=False,
1069
1076
  initial_value=initial_value,
1070
1077
  )
@@ -498,26 +498,28 @@ class _BaseOptimizer(tf.__internal__.tracking.AutoTrackable):
498
498
  Returns:
499
499
  An optimizer variable.
500
500
  """
501
+ # Under a mixed precision policy, variables report their "cast"
502
+ # dtype. However, we want to use the original dtype for slots.
503
+ if hasattr(model_variable, "true_dtype"):
504
+ dtype = model_variable.true_dtype
505
+ else:
506
+ dtype = model_variable.dtype
501
507
  if initial_value is None:
502
508
  if shape is None:
503
509
  if model_variable.shape.rank is None:
504
510
  # When the rank is None, we cannot get a concrete
505
511
  # `model_variable.shape`, we use dynamic shape.
506
- initial_value = tf.zeros_like(
507
- model_variable, dtype=model_variable.dtype
508
- )
512
+ initial_value = tf.zeros_like(model_variable, dtype=dtype)
509
513
  else:
510
514
  # We cannot always use `zeros_like`, because some cases
511
515
  # the shape exists while values don't.
512
- initial_value = tf.zeros(
513
- model_variable.shape, dtype=model_variable.dtype
514
- )
516
+ initial_value = tf.zeros(model_variable.shape, dtype=dtype)
515
517
  else:
516
- initial_value = tf.zeros(shape, dtype=model_variable.dtype)
518
+ initial_value = tf.zeros(shape, dtype=dtype)
517
519
  variable = tf.Variable(
518
520
  initial_value=initial_value,
519
521
  name=f"{variable_name}/{model_variable._shared_name}",
520
- dtype=model_variable.dtype,
522
+ dtype=dtype,
521
523
  trainable=False,
522
524
  )
523
525
  # If model_variable is a shard of a ShardedVariable, we should add a
@@ -1188,10 +1190,17 @@ class Optimizer(_BaseOptimizer):
1188
1190
  self._mesh, rank=initial_value.shape.rank
1189
1191
  ),
1190
1192
  )
1193
+ # Under a mixed precision policy, variables report their "cast"
1194
+ # dtype. However, we want to use the original dtype for optimizer
1195
+ # variables.
1196
+ if hasattr(model_variable, "true_dtype"):
1197
+ dtype = model_variable.true_dtype
1198
+ else:
1199
+ dtype = model_variable.dtype
1191
1200
  variable = tf.experimental.dtensor.DVariable(
1192
1201
  initial_value=initial_value,
1193
1202
  name=f"{variable_name}/{model_variable._shared_name}",
1194
- dtype=model_variable.dtype,
1203
+ dtype=dtype,
1195
1204
  trainable=False,
1196
1205
  )
1197
1206
  self._variables.append(variable)
@@ -156,7 +156,8 @@ class LinearModel(training.Model):
156
156
  )
157
157
  else:
158
158
  self.bias = None
159
- self.built = True
159
+ # Call Layer.build() to skip Model.build() which we override here.
160
+ base_layer.Layer.build(self, input_shape)
160
161
 
161
162
  def call(self, inputs):
162
163
  result = None
@@ -210,7 +210,7 @@ def get_json_type(obj):
210
210
  return {
211
211
  "class_name": "TypeSpec",
212
212
  "type_spec": type_spec_name,
213
- "serialized": obj._serialize(),
213
+ "serialized": _encode_tuple(obj._serialize()),
214
214
  }
215
215
  except ValueError:
216
216
  raise ValueError(
@@ -15,6 +15,7 @@
15
15
  """Public API surface for saving APIs."""
16
16
 
17
17
  import os
18
+ import tempfile
18
19
  import warnings
19
20
  import zipfile
20
21
 
@@ -33,17 +34,76 @@ except ImportError:
33
34
  is_oss = True
34
35
 
35
36
 
36
- def _support_gcs_uri(filepath, save_format, is_oss):
37
- """Supports GCS URIs through bigstore via a temporary file."""
38
- gs_filepath = None
39
- if str(filepath).startswith("gs://") and save_format != "tf":
40
- gs_filepath = filepath
41
- if not is_oss:
42
- gs_filepath = filepath.replace("gs://", "/bigstore/")
43
- filepath = os.path.join(
44
- saving_lib.get_temp_dir(), os.path.basename(gs_filepath)
45
- )
46
- return gs_filepath, filepath
37
+ class SupportReadFromRemote:
38
+ """Supports GCS URIs and other remote paths via a temporary file.
39
+
40
+ This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
41
+ supports remoted saved model out of the box.
42
+ """
43
+
44
+ def __init__(self, filepath):
45
+ save_format = get_save_format(filepath, save_format=None)
46
+ if (
47
+ saving_lib.is_remote_path(filepath)
48
+ and not tf.io.gfile.isdir(filepath)
49
+ and save_format != "tf"
50
+ ):
51
+ self.temp_directory = tempfile.TemporaryDirectory()
52
+ gs_filepath = filepath
53
+ if not is_oss and str(filepath).startswith("gs://"):
54
+ gs_filepath = filepath.replace("gs://", "/bigstore/")
55
+ self.local_filepath = os.path.join(
56
+ self.temp_directory.name, os.path.basename(filepath)
57
+ )
58
+ tf.io.gfile.copy(gs_filepath, self.local_filepath, overwrite=True)
59
+ else:
60
+ self.temp_directory = None
61
+ self.local_filepath = filepath
62
+
63
+ def __enter__(self):
64
+ return self.local_filepath
65
+
66
+ def __exit__(self, exc_type, exc_value, traceback):
67
+ if self.temp_directory is not None:
68
+ self.temp_directory.cleanup()
69
+
70
+
71
+ class SupportWriteToRemote:
72
+ """Supports GCS URIs and other remote paths via a temporary file.
73
+
74
+ This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
75
+ supports remoted saved model out of the box.
76
+ """
77
+
78
+ def __init__(self, filepath, overwrite=True, save_format=None):
79
+ save_format = get_save_format(filepath, save_format=save_format)
80
+ self.overwrite = overwrite
81
+ if saving_lib.is_remote_path(filepath) and save_format != "tf":
82
+ self.temp_directory = tempfile.TemporaryDirectory()
83
+ self.remote_filepath = filepath
84
+ if not is_oss and str(filepath).startswith("gs://"):
85
+ self.remote_filepath = self.remote_filepath.replace(
86
+ "gs://", "/bigstore/"
87
+ )
88
+ self.local_filepath = os.path.join(
89
+ self.temp_directory.name, os.path.basename(filepath)
90
+ )
91
+ else:
92
+ self.temp_directory = None
93
+ self.remote_filepath = None
94
+ self.local_filepath = filepath
95
+
96
+ def __enter__(self):
97
+ return self.local_filepath
98
+
99
+ def __exit__(self, exc_type, exc_value, traceback):
100
+ if self.temp_directory is not None:
101
+ tf.io.gfile.copy(
102
+ self.local_filepath,
103
+ self.remote_filepath,
104
+ overwrite=self.overwrite,
105
+ )
106
+ self.temp_directory.cleanup()
47
107
 
48
108
 
49
109
  @keras_export("keras.saving.save_model", "keras.models.save_model")
@@ -131,46 +191,49 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
131
191
  when loading the model. See the `custom_objects` argument in
132
192
  `tf.keras.saving.load_model`.
133
193
  """
134
- save_format = get_save_format(filepath, save_format)
135
-
136
- # Supports GCS URIs through bigstore via a temporary file
137
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
138
-
139
- # Deprecation warnings
140
- if save_format == "h5":
141
- warnings.warn(
142
- "You are saving your model as an HDF5 file via `model.save()`. "
143
- "This file format is considered legacy. "
144
- "We recommend using instead the native TF-Keras format, "
145
- "e.g. `model.save('my_model.keras')`.",
146
- stacklevel=2,
147
- )
194
+ # Supports remote paths via a temporary file
195
+ with SupportWriteToRemote(
196
+ filepath,
197
+ overwrite=overwrite,
198
+ save_format=save_format,
199
+ ) as local_filepath:
200
+ save_format = get_save_format(filepath, save_format)
201
+
202
+ # Deprecation warnings
203
+ if save_format == "h5":
204
+ warnings.warn(
205
+ "You are saving your model as an HDF5 file via `model.save()`. "
206
+ "This file format is considered legacy. "
207
+ "We recommend using instead the native TF-Keras format, "
208
+ "e.g. `model.save('my_model.keras')`.",
209
+ stacklevel=2,
210
+ )
148
211
 
149
- if save_format == "keras":
150
- # If file exists and should not be overwritten.
151
- try:
152
- exists = os.path.exists(filepath)
153
- except TypeError:
154
- exists = False
155
- if exists and not overwrite:
156
- proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
157
- if not proceed:
158
- return
159
- if kwargs:
160
- raise ValueError(
161
- "The following argument(s) are not supported "
162
- f"with the native TF-Keras format: {list(kwargs.keys())}"
212
+ if save_format == "keras":
213
+ # If file exists and should not be overwritten.
214
+ try:
215
+ exists = os.path.exists(local_filepath)
216
+ except TypeError:
217
+ exists = False
218
+ if exists and not overwrite:
219
+ proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
220
+ if not proceed:
221
+ return
222
+ if kwargs:
223
+ raise ValueError(
224
+ "The following argument(s) are not supported "
225
+ f"with the native TF-Keras format: {list(kwargs.keys())}"
226
+ )
227
+ saving_lib.save_model(model, local_filepath)
228
+ else:
229
+ # Legacy case
230
+ return legacy_sm_saving_lib.save_model(
231
+ model,
232
+ local_filepath,
233
+ overwrite=overwrite,
234
+ save_format=save_format,
235
+ **kwargs,
163
236
  )
164
- saving_lib.save_model(model, filepath)
165
- else:
166
- # Legacy case
167
- return legacy_sm_saving_lib.save_model(
168
- model,
169
- filepath,
170
- overwrite=overwrite,
171
- save_format=save_format,
172
- **kwargs,
173
- )
174
237
 
175
238
 
176
239
  @keras_export("keras.saving.load_model", "keras.models.load_model")
@@ -217,94 +280,69 @@ def load_model(
217
280
  It is recommended that you use layer attributes to
218
281
  access specific variables, e.g. `model.get_layer("dense_1").kernel`.
219
282
  """
220
- # Supports GCS URIs by copying data to temporary file
221
- save_format = get_save_format(filepath, save_format=None)
222
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
223
- if gs_filepath is not None:
224
- tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)
225
-
226
- is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
227
- filepath
228
- )
229
-
230
- # Support for remote zip files
231
- if (
232
- saving_lib.is_remote_path(filepath)
233
- and not tf.io.gfile.isdir(filepath)
234
- and not is_keras_zip
235
- ):
236
- local_path = os.path.join(
237
- saving_lib.get_temp_dir(), os.path.basename(filepath)
238
- )
239
-
240
- # Copy from remote to temporary local directory
241
- tf.io.gfile.copy(filepath, local_path, overwrite=True)
242
-
243
- # Switch filepath to local zipfile for loading model
244
- if zipfile.is_zipfile(local_path):
245
- filepath = local_path
246
- is_keras_zip = True
247
-
248
- if is_keras_zip:
249
- if kwargs:
250
- raise ValueError(
251
- "The following argument(s) are not supported "
252
- f"with the native TF-Keras format: {list(kwargs.keys())}"
283
+ # Supports remote paths via a temporary file
284
+ with SupportReadFromRemote(filepath) as local_filepath:
285
+ if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
286
+ local_filepath
287
+ ):
288
+ if kwargs:
289
+ raise ValueError(
290
+ "The following argument(s) are not supported "
291
+ f"with the native TF-Keras format: {list(kwargs.keys())}"
292
+ )
293
+ return saving_lib.load_model(
294
+ local_filepath,
295
+ custom_objects=custom_objects,
296
+ compile=compile,
297
+ safe_mode=safe_mode,
253
298
  )
254
- return saving_lib.load_model(
255
- filepath,
299
+
300
+ # Legacy case.
301
+ return legacy_sm_saving_lib.load_model(
302
+ local_filepath,
256
303
  custom_objects=custom_objects,
257
304
  compile=compile,
258
- safe_mode=safe_mode,
305
+ **kwargs,
259
306
  )
260
307
 
261
- # Legacy case.
262
- return legacy_sm_saving_lib.load_model(
263
- filepath, custom_objects=custom_objects, compile=compile, **kwargs
264
- )
265
-
266
308
 
267
309
  def save_weights(model, filepath, overwrite=True, **kwargs):
268
- # Supports GCS URIs through bigstore via a temporary file
269
- save_format = get_save_format(filepath, save_format=None)
270
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
271
-
272
- if str(filepath).endswith(".weights.h5"):
273
- # If file exists and should not be overwritten.
274
- try:
275
- exists = os.path.exists(filepath)
276
- except TypeError:
277
- exists = False
278
- if exists and not overwrite:
279
- proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
280
- if not proceed:
281
- return
282
- saving_lib.save_weights_only(model, filepath)
283
- else:
284
- legacy_sm_saving_lib.save_weights(
285
- model, filepath, overwrite=overwrite, **kwargs
286
- )
310
+ # Supports remote paths via a temporary file
311
+ with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
312
+ if str(local_filepath).endswith(".weights.h5"):
313
+ # If file exists and should not be overwritten.
314
+ try:
315
+ exists = os.path.exists(local_filepath)
316
+ except TypeError:
317
+ exists = False
318
+ if exists and not overwrite:
319
+ proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
320
+ if not proceed:
321
+ return
322
+ saving_lib.save_weights_only(model, local_filepath)
323
+ else:
324
+ legacy_sm_saving_lib.save_weights(
325
+ model, local_filepath, overwrite=overwrite, **kwargs
326
+ )
287
327
 
288
328
 
289
329
  def load_weights(model, filepath, skip_mismatch=False, **kwargs):
290
- # Supports GCS URIs by copying data to temporary file
291
- save_format = get_save_format(filepath, save_format=None)
292
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
293
- if gs_filepath is not None:
294
- tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)
295
-
296
- if str(filepath).endswith(".keras") and zipfile.is_zipfile(filepath):
297
- saving_lib.load_weights_only(
298
- model, filepath, skip_mismatch=skip_mismatch
299
- )
300
- elif str(filepath).endswith(".weights.h5"):
301
- saving_lib.load_weights_only(
302
- model, filepath, skip_mismatch=skip_mismatch
303
- )
304
- else:
305
- return legacy_sm_saving_lib.load_weights(
306
- model, filepath, skip_mismatch=skip_mismatch, **kwargs
307
- )
330
+ # Supports remote paths via a temporary file
331
+ with SupportReadFromRemote(filepath) as local_filepath:
332
+ if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
333
+ local_filepath
334
+ ):
335
+ saving_lib.load_weights_only(
336
+ model, local_filepath, skip_mismatch=skip_mismatch
337
+ )
338
+ elif str(local_filepath).endswith(".weights.h5"):
339
+ saving_lib.load_weights_only(
340
+ model, local_filepath, skip_mismatch=skip_mismatch
341
+ )
342
+ else:
343
+ return legacy_sm_saving_lib.load_weights(
344
+ model, local_filepath, skip_mismatch=skip_mismatch, **kwargs
345
+ )
308
346
 
309
347
 
310
348
  def get_save_format(filepath, save_format):
@@ -156,14 +156,8 @@ def save_model(model, filepath, weights_format="h5"):
156
156
  "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
157
157
  }
158
158
  )
159
- # TODO(rameshsampath): Need a better logic for local vs remote path
160
- if is_remote_path(filepath):
161
- # Remote path. Zip to local memory byte io and copy to remote
162
- zip_filepath = io.BytesIO()
163
- else:
164
- zip_filepath = filepath
165
159
  try:
166
- with zipfile.ZipFile(zip_filepath, "w") as zf:
160
+ with zipfile.ZipFile(filepath, "w") as zf:
167
161
  with zf.open(_METADATA_FILENAME, "w") as f:
168
162
  f.write(metadata_json.encode())
169
163
  with zf.open(_CONFIG_FILENAME, "w") as f:
@@ -195,10 +189,6 @@ def save_model(model, filepath, weights_format="h5"):
195
189
  )
196
190
  weights_store.close()
197
191
  asset_store.close()
198
-
199
- if is_remote_path(filepath):
200
- with tf.io.gfile.GFile(filepath, "wb") as f:
201
- f.write(zip_filepath.getvalue())
202
192
  except Exception as e:
203
193
  raise e
204
194
  finally:
@@ -378,7 +378,7 @@ def _get_class_or_fn_config(obj):
378
378
  """Return the object's config depending on its type."""
379
379
  # Functions / lambdas:
380
380
  if isinstance(obj, types.FunctionType):
381
- return obj.__name__
381
+ return object_registration.get_registered_name(obj)
382
382
  # All classes:
383
383
  if hasattr(obj, "get_config"):
384
384
  config = obj.get_config()
@@ -789,15 +789,6 @@ def _retrieve_class_or_fn(
789
789
  if obj is not None:
790
790
  return obj
791
791
 
792
- # Retrieval of registered custom function in a package
793
- filtered_dict = {
794
- k: v
795
- for k, v in custom_objects.items()
796
- if k.endswith(full_config["config"])
797
- }
798
- if filtered_dict:
799
- return next(iter(filtered_dict.values()))
800
-
801
792
  # Otherwise, attempt to retrieve the class object given the `module`
802
793
  # and `class_name`. Import the module, find the class.
803
794
  try:
@@ -1108,7 +1108,7 @@ def pad_sequences(
1108
1108
  maxlen = np.max(lengths)
1109
1109
 
1110
1110
  is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
1111
- dtype, np.unicode_
1111
+ dtype, np.str_
1112
1112
  )
1113
1113
  if isinstance(value, str) and dtype != object and not is_dtype_str:
1114
1114
  raise ValueError(
@@ -229,7 +229,7 @@ class StepsPerExecutionTuner:
229
229
 
230
230
  if current_spe >= spe_limit:
231
231
  new_spe = current_spe
232
- elif current_spe == 0:
232
+ elif current_spe <= 0:
233
233
  new_spe = self.init_spe
234
234
 
235
235
  self._steps_per_execution.assign(np.round(new_spe))
@@ -78,9 +78,9 @@ def get_random_seed():
78
78
  the random seed as an integer.
79
79
  """
80
80
  if getattr(backend._SEED_GENERATOR, "generator", None):
81
- return backend._SEED_GENERATOR.generator.randint(1, 1e9)
81
+ return backend._SEED_GENERATOR.generator.randint(1, int(1e9))
82
82
  else:
83
- return random.randint(1, 1e9)
83
+ return random.randint(1, int(1e9))
84
84
 
85
85
 
86
86
  def is_tensor_or_tensor_list(v):
@@ -110,16 +110,24 @@ def timeseries_dataset_from_array(
110
110
  timesteps to predict the next timestep, you would use:
111
111
 
112
112
  ```python
113
- input_data = data[:-10]
114
- targets = data[10:]
113
+ data = tf.range(15)
114
+ sequence_length = 10
115
+ input_data = data[:]
116
+ targets = data[sequence_length:]
115
117
  dataset = tf.keras.utils.timeseries_dataset_from_array(
116
- input_data, targets, sequence_length=10)
118
+ input_data, targets, sequence_length=sequence_length
119
+ )
117
120
  for batch in dataset:
118
121
  inputs, targets = batch
119
- assert np.array_equal(inputs[0], data[:10]) # First sequence: steps [0-9]
122
+ # First sequence: steps [0-9]
123
+ assert np.array_equal(inputs[0], data[:sequence_length])
120
124
  # Corresponding target: step 10
121
- assert np.array_equal(targets[0], data[10])
125
+ assert np.array_equal(targets[0], data[sequence_length])
122
126
  break
127
+ # To view the generated dataset
128
+ for batch in dataset.as_numpy_iterator():
129
+ input, label = batch
130
+ print(f"Input:{input}, target:{label}")
123
131
  ```
124
132
 
125
133
  Example 3: Temporal regression for many-to-many architectures.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: tf_keras-nightly
3
- Version: 2.17.0.dev2024031909
3
+ Version: 2.19.0.dev2025011410
4
4
  Summary: Deep learning for humans.
5
5
  Home-page: https://keras.io/
6
6
  Download-URL: https://github.com/keras-team/tf-keras/tags
@@ -26,7 +26,18 @@ Classifier: Topic :: Software Development
26
26
  Classifier: Topic :: Software Development :: Libraries
27
27
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
28
28
  Requires-Python: >=3.9
29
- Requires-Dist: tf-nightly ~=2.17.0.dev
29
+ Requires-Dist: tf-nightly~=2.19.0.dev
30
+ Dynamic: author
31
+ Dynamic: author-email
32
+ Dynamic: classifier
33
+ Dynamic: description
34
+ Dynamic: download-url
35
+ Dynamic: home-page
36
+ Dynamic: keywords
37
+ Dynamic: license
38
+ Dynamic: requires-dist
39
+ Dynamic: requires-python
40
+ Dynamic: summary
30
41
 
31
42
  TF-Keras is a deep learning API written in Python,
32
43
  running on top of the machine learning platform TensorFlow.