tf-keras-nightly 2.17.0.dev2024031909__py3-none-any.whl → 2.19.0.dev2025011410__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/callbacks.py +24 -7
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +10 -4
- tf_keras/src/engine/base_layer_v1.py +10 -4
- tf_keras/src/engine/node.py +8 -3
- tf_keras/src/layers/activation/prelu.py +1 -1
- tf_keras/src/layers/attention/base_dense_attention.py +2 -1
- tf_keras/src/layers/convolutional/base_conv.py +1 -1
- tf_keras/src/layers/convolutional/base_depthwise_conv.py +3 -1
- tf_keras/src/layers/convolutional/base_separable_conv.py +3 -1
- tf_keras/src/layers/convolutional/conv1d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv2d_transpose.py +3 -1
- tf_keras/src/layers/convolutional/conv3d_transpose.py +3 -1
- tf_keras/src/layers/core/dense.py +1 -1
- tf_keras/src/layers/core/embedding.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected1d.py +1 -1
- tf_keras/src/layers/locally_connected/locally_connected2d.py +1 -1
- tf_keras/src/layers/normalization/batch_normalization.py +1 -1
- tf_keras/src/layers/normalization/layer_normalization.py +1 -1
- tf_keras/src/layers/normalization/unit_normalization.py +2 -1
- tf_keras/src/layers/rnn/abstract_rnn_cell.py +1 -1
- tf_keras/src/layers/rnn/base_conv_lstm.py +0 -1
- tf_keras/src/layers/rnn/base_conv_rnn.py +3 -1
- tf_keras/src/layers/rnn/base_rnn.py +1 -1
- tf_keras/src/layers/rnn/base_wrapper.py +1 -1
- tf_keras/src/layers/rnn/bidirectional.py +2 -1
- tf_keras/src/layers/rnn/cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/cudnn_gru.py +6 -3
- tf_keras/src/layers/rnn/cudnn_lstm.py +6 -3
- tf_keras/src/layers/rnn/gru.py +35 -47
- tf_keras/src/layers/rnn/legacy_cell_wrappers.py +3 -3
- tf_keras/src/layers/rnn/legacy_cells.py +20 -25
- tf_keras/src/layers/rnn/lstm.py +35 -50
- tf_keras/src/layers/rnn/simple_rnn.py +0 -1
- tf_keras/src/layers/rnn/stacked_rnn_cells.py +1 -1
- tf_keras/src/layers/rnn/time_distributed.py +0 -1
- tf_keras/src/mixed_precision/autocast_variable.py +12 -6
- tf_keras/src/mixed_precision/test_util.py +6 -5
- tf_keras/src/optimizers/legacy/optimizer_v2.py +9 -2
- tf_keras/src/optimizers/optimizer.py +18 -9
- tf_keras/src/premade_models/linear.py +2 -1
- tf_keras/src/saving/legacy/saved_model/json_utils.py +1 -1
- tf_keras/src/saving/saving_api.py +165 -127
- tf_keras/src/saving/saving_lib.py +1 -11
- tf_keras/src/saving/serialization_lib.py +1 -10
- tf_keras/src/utils/data_utils.py +1 -1
- tf_keras/src/utils/steps_per_execution_tuning.py +1 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- tf_keras/src/utils/timeseries_dataset.py +13 -5
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/RECORD +62 -62
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/WHEEL +1 -1
- {tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/top_level.txt +0 -0
tf_keras/src/mixed_precision/test_util.py
CHANGED
```diff
@@ -171,14 +171,14 @@ class MultiplyLayer(AssertTypeLayer):
             activity_regularizer=self._activity_regularizer, **kwargs
         )
 
-    def build(self,
+    def build(self, input_shape):
        self.v = self.add_weight(
            self._var_name,
            (),
            initializer="ones",
            regularizer=self._regularizer,
        )
-
+        super().build(input_shape)
 
    def call(self, inputs):
        self.assert_input_types(inputs)
```
```diff
@@ -205,7 +205,7 @@ class MultiplyLayerWithoutAutoCast(MultiplyLayer):
 class MultiplyLayerWithoutAutoCast(MultiplyLayer):
     """Same as MultiplyLayer, but does not use AutoCastVariables."""
 
-    def build(self,
+    def build(self, input_shape):
         dtype = self.dtype
         if dtype in ("float16", "bfloat16"):
             dtype = "float32"
```
```diff
@@ -214,10 +214,11 @@ class MultiplyLayerWithoutAutoCast(MultiplyLayer):
             (),
             initializer="ones",
             dtype=dtype,
-
+            autocast=False,
             regularizer=self._regularizer,
         )
-
+        # Call Layer.build() to skip MultiplyLayer.build() which we override.
+        base_layer.Layer.build(self, input_shape)
 
     def call(self, inputs):
         self.assert_input_types(inputs)
```
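Both `build()` fixes above follow the same pattern: accept `input_shape` and finish by calling a specific ancestor's `build()`, so the layer is marked built without re-running an overridden parent implementation. A minimal, framework-free sketch of that idiom (class names here are illustrative, not tf-keras APIs):

```python
class Base:
    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Common bookkeeping every layer needs.
        self.built = True


class Parent(Base):
    def build(self, input_shape):
        # Work we deliberately want to skip in Child.
        raise RuntimeError("should not run for Child")


class Child(Parent):
    def build(self, input_shape):
        # Calling Base.build() directly bypasses Parent.build() while still
        # marking the object as built -- the same trick as
        # `base_layer.Layer.build(self, input_shape)` in the diff.
        Base.build(self, input_shape)


layer = Child()
layer.build((None, 4))
assert layer.built
```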
tf_keras/src/optimizers/legacy/optimizer_v2.py
CHANGED
```diff
@@ -1033,6 +1033,13 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
         slot_dict = self._slots.setdefault(var_key, {})
         weight = slot_dict.get(slot_name, None)
         if weight is None:
+            # Under a mixed precision policy, variables report their "cast"
+            # dtype. However, we want to use the original dtype for slots.
+            if hasattr(var, "true_dtype"):
+                dtype = var.true_dtype
+            else:
+                dtype = var.dtype
+
             if isinstance(initializer, str) or callable(initializer):
                 initializer = initializers.get(initializer)
                 if isinstance(
```
```diff
@@ -1043,7 +1050,7 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
                 else:
                     slot_shape = var.shape
                 initial_value = functools.partial(
-                    initializer, shape=slot_shape, dtype=
+                    initializer, shape=slot_shape, dtype=dtype
                 )
             else:
                 initial_value = initializer
```
```diff
@@ -1064,7 +1071,7 @@ class OptimizerV2(tf.__internal__.tracking.Trackable):
             with strategy.extended.colocate_vars_with(var):
                 weight = tf.Variable(
                     name=f"{var._shared_name}/{slot_name}",
-                    dtype=
+                    dtype=dtype,
                     trainable=False,
                     initial_value=initial_value,
                 )
```
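The recurring `true_dtype` check is the heart of these optimizer fixes: under a `mixed_float16` policy an `AutoCastVariable` can report the cast (compute) dtype via `dtype`, while `true_dtype` preserves the storage dtype the slot should use. A hedged sketch of the selection logic in isolation (`FakeAutoCastVariable` is a stand-in for illustration, not a tf-keras API):

```python
import tensorflow as tf


def slot_dtype(var):
    # Prefer the storage dtype when the variable exposes one; this is the
    # exact fallback pattern added throughout the diff.
    return var.true_dtype if hasattr(var, "true_dtype") else var.dtype


class FakeAutoCastVariable:
    """Stand-in for an AutoCastVariable under mixed_float16."""

    dtype = tf.float16       # what the variable reports inside a cast scope
    true_dtype = tf.float32  # the actual storage dtype


print(slot_dtype(FakeAutoCastVariable()))  # tf.float32
print(slot_dtype(tf.Variable(1.0)))        # tf.float32 (no true_dtype attr)
```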
tf_keras/src/optimizers/optimizer.py
CHANGED
```diff
@@ -498,26 +498,28 @@ class _BaseOptimizer(tf.__internal__.tracking.AutoTrackable):
         Returns:
             An optimizer variable.
         """
+        # Under a mixed precision policy, variables report their "cast"
+        # dtype. However, we want to use the original dtype for slots.
+        if hasattr(model_variable, "true_dtype"):
+            dtype = model_variable.true_dtype
+        else:
+            dtype = model_variable.dtype
         if initial_value is None:
             if shape is None:
                 if model_variable.shape.rank is None:
                     # When the rank is None, we cannot get a concrete
                     # `model_variable.shape`, we use dynamic shape.
-                    initial_value = tf.zeros_like(
-                        model_variable, dtype=model_variable.dtype
-                    )
+                    initial_value = tf.zeros_like(model_variable, dtype=dtype)
                 else:
                     # We cannot always use `zeros_like`, because some cases
                     # the shape exists while values don't.
-                    initial_value = tf.zeros(
-                        model_variable.shape, dtype=model_variable.dtype
-                    )
+                    initial_value = tf.zeros(model_variable.shape, dtype=dtype)
             else:
-                initial_value = tf.zeros(shape, dtype=
+                initial_value = tf.zeros(shape, dtype=dtype)
         variable = tf.Variable(
             initial_value=initial_value,
             name=f"{variable_name}/{model_variable._shared_name}",
-            dtype=
+            dtype=dtype,
             trainable=False,
         )
         # If model_variable is a shard of a ShardedVariable, we should add a
```
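In practice the fix means slot variables keep the float32 storage dtype when training under `mixed_float16`. A hedged end-to-end check using the legacy OptimizerV2 slot API from the hunks above (assumes a TF build where `tf.keras` resolves to tf-keras, e.g. with `TF_USE_LEGACY_KERAS=1`; exact behavior may vary by TF version):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")
layer = tf.keras.layers.Dense(1)
opt = tf.keras.optimizers.legacy.Adam()

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(layer(tf.zeros((2, 4))))
grads = tape.gradient(loss, layer.trainable_variables)
opt.apply_gradients(zip(grads, layer.trainable_variables))

# With the patched OptimizerV2, Adam's slots take the storage dtype of the
# underlying variable (float32), not the float16 "cast" dtype it reports.
print(opt.get_slot(layer.kernel, "m").dtype)  # float32
```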
```diff
@@ -1188,10 +1190,17 @@ class Optimizer(_BaseOptimizer):
                 self._mesh, rank=initial_value.shape.rank
             ),
         )
+        # Under a mixed precision policy, variables report their "cast"
+        # dtype. However, we want to use the original dtype for optimizer
+        # variables.
+        if hasattr(model_variable, "true_dtype"):
+            dtype = model_variable.true_dtype
+        else:
+            dtype = model_variable.dtype
         variable = tf.experimental.dtensor.DVariable(
             initial_value=initial_value,
             name=f"{variable_name}/{model_variable._shared_name}",
-            dtype=
+            dtype=dtype,
             trainable=False,
         )
         self._variables.append(variable)
```
tf_keras/src/premade_models/linear.py
CHANGED
```diff
@@ -156,7 +156,8 @@ class LinearModel(training.Model):
             )
         else:
             self.bias = None
-
+        # Call Layer.build() to skip Model.build() which we override here.
+        base_layer.Layer.build(self, input_shape)
 
     def call(self, inputs):
         result = None
```
tf_keras/src/saving/saving_api.py
CHANGED
```diff
@@ -15,6 +15,7 @@
 """Public API surface for saving APIs."""
 
 import os
+import tempfile
 import warnings
 import zipfile
 
```
```diff
@@ -33,17 +34,76 @@ except ImportError:
     is_oss = True
 
 
-
-"""Supports GCS URIs
-[... 9 removed lines not rendered in the source diff ...]
+class SupportReadFromRemote:
+    """Supports GCS URIs and other remote paths via a temporary file.
+
+    This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
+    supports remoted saved model out of the box.
+    """
+
+    def __init__(self, filepath):
+        save_format = get_save_format(filepath, save_format=None)
+        if (
+            saving_lib.is_remote_path(filepath)
+            and not tf.io.gfile.isdir(filepath)
+            and save_format != "tf"
+        ):
+            self.temp_directory = tempfile.TemporaryDirectory()
+            gs_filepath = filepath
+            if not is_oss and str(filepath).startswith("gs://"):
+                gs_filepath = filepath.replace("gs://", "/bigstore/")
+            self.local_filepath = os.path.join(
+                self.temp_directory.name, os.path.basename(filepath)
+            )
+            tf.io.gfile.copy(gs_filepath, self.local_filepath, overwrite=True)
+        else:
+            self.temp_directory = None
+            self.local_filepath = filepath
+
+    def __enter__(self):
+        return self.local_filepath
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.temp_directory is not None:
+            self.temp_directory.cleanup()
+
+
+class SupportWriteToRemote:
+    """Supports GCS URIs and other remote paths via a temporary file.
+
+    This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
+    supports remoted saved model out of the box.
+    """
+
+    def __init__(self, filepath, overwrite=True, save_format=None):
+        save_format = get_save_format(filepath, save_format=save_format)
+        self.overwrite = overwrite
+        if saving_lib.is_remote_path(filepath) and save_format != "tf":
+            self.temp_directory = tempfile.TemporaryDirectory()
+            self.remote_filepath = filepath
+            if not is_oss and str(filepath).startswith("gs://"):
+                self.remote_filepath = self.remote_filepath.replace(
+                    "gs://", "/bigstore/"
+                )
+            self.local_filepath = os.path.join(
+                self.temp_directory.name, os.path.basename(filepath)
+            )
+        else:
+            self.temp_directory = None
+            self.remote_filepath = None
+            self.local_filepath = filepath
+
+    def __enter__(self):
+        return self.local_filepath
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.temp_directory is not None:
+            tf.io.gfile.copy(
+                self.local_filepath,
+                self.remote_filepath,
+                overwrite=self.overwrite,
+            )
+            self.temp_directory.cleanup()
 
 
 @keras_export("keras.saving.save_model", "keras.models.save_model")
```
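Both helpers implement the same write-through/read-through temp-file idea. A minimal standard-library sketch of the write side (illustrative names; unlike the tf-keras `__exit__`, which copies unconditionally via `tf.io.gfile.copy`, this sketch only publishes on success):

```python
import os
import shutil
import tempfile


class WriteThroughLocalCopy:
    """Write to a local temp file, then publish it to the destination."""

    def __init__(self, destination):
        self.destination = destination
        self.tmp = tempfile.TemporaryDirectory()
        self.local_path = os.path.join(
            self.tmp.name, os.path.basename(destination)
        )

    def __enter__(self):
        return self.local_path

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if exc_type is None:
                shutil.copyfile(self.local_path, self.destination)
        finally:
            self.tmp.cleanup()


with WriteThroughLocalCopy("/tmp/demo.txt") as path:
    with open(path, "w") as f:
        f.write("hello")
```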
```diff
@@ -131,46 +191,49 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
         when loading the model. See the `custom_objects` argument in
         `tf.keras.saving.load_model`.
     """
-[... 14 removed lines not rendered in the source diff ...]
+    # Supports remote paths via a temporary file
+    with SupportWriteToRemote(
+        filepath,
+        overwrite=overwrite,
+        save_format=save_format,
+    ) as local_filepath:
+        save_format = get_save_format(filepath, save_format)
+
+        # Deprecation warnings
+        if save_format == "h5":
+            warnings.warn(
+                "You are saving your model as an HDF5 file via `model.save()`. "
+                "This file format is considered legacy. "
+                "We recommend using instead the native TF-Keras format, "
+                "e.g. `model.save('my_model.keras')`.",
+                stacklevel=2,
+            )
 
-[... 14 removed lines not rendered in the source diff ...]
+        if save_format == "keras":
+            # If file exists and should not be overwritten.
+            try:
+                exists = os.path.exists(local_filepath)
+            except TypeError:
+                exists = False
+            if exists and not overwrite:
+                proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
+                if not proceed:
+                    return
+            if kwargs:
+                raise ValueError(
+                    "The following argument(s) are not supported "
+                    f"with the native TF-Keras format: {list(kwargs.keys())}"
+                )
+            saving_lib.save_model(model, local_filepath)
+        else:
+            # Legacy case
+            return legacy_sm_saving_lib.save_model(
+                model,
+                local_filepath,
+                overwrite=overwrite,
+                save_format=save_format,
+                **kwargs,
             )
-            saving_lib.save_model(model, filepath)
-        else:
-            # Legacy case
-            return legacy_sm_saving_lib.save_model(
-                model,
-                filepath,
-                overwrite=overwrite,
-                save_format=save_format,
-                **kwargs,
-            )
 
 
 @keras_export("keras.saving.load_model", "keras.models.load_model")
```
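The practical effect: `model.save()` with a `.keras` or `.h5` destination now works for remote URIs, because the archive is written locally and copied out by `SupportWriteToRemote` on context exit. A hedged usage sketch (the bucket name is hypothetical and requires GCS credentials at runtime; assumes `tf.keras` resolves to tf-keras):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 4))

# Written to a local temp file, then copied to the bucket on context exit.
model.save("gs://my-bucket/models/my_model.keras")
```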
```diff
@@ -217,94 +280,69 @@ def load_model(
     It is recommended that you use layer attributes to
     access specific variables, e.g. `model.get_layer("dense_1").kernel`.
     """
-    # Supports
-[... 14 removed lines not rendered in the source diff ...]
-    ):
-        local_path = os.path.join(
-            saving_lib.get_temp_dir(), os.path.basename(filepath)
-        )
-
-        # Copy from remote to temporary local directory
-        tf.io.gfile.copy(filepath, local_path, overwrite=True)
-
-        # Switch filepath to local zipfile for loading model
-        if zipfile.is_zipfile(local_path):
-            filepath = local_path
-            is_keras_zip = True
-
-    if is_keras_zip:
-        if kwargs:
-            raise ValueError(
-                "The following argument(s) are not supported "
-                f"with the native TF-Keras format: {list(kwargs.keys())}"
+    # Supports remote paths via a temporary file
+    with SupportReadFromRemote(filepath) as local_filepath:
+        if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
+            local_filepath
+        ):
+            if kwargs:
+                raise ValueError(
+                    "The following argument(s) are not supported "
+                    f"with the native TF-Keras format: {list(kwargs.keys())}"
+                )
+            return saving_lib.load_model(
+                local_filepath,
+                custom_objects=custom_objects,
+                compile=compile,
+                safe_mode=safe_mode,
             )
-
-
+
+        # Legacy case.
+        return legacy_sm_saving_lib.load_model(
+            local_filepath,
             custom_objects=custom_objects,
             compile=compile,
-
+            **kwargs,
         )
 
-    # Legacy case.
-    return legacy_sm_saving_lib.load_model(
-        filepath, custom_objects=custom_objects, compile=compile, **kwargs
-    )
-
 
 def save_weights(model, filepath, overwrite=True, **kwargs):
-    # Supports
-[... 16 removed lines not rendered in the source diff ...]
-        model, filepath, overwrite=overwrite, **kwargs
-    )
+    # Supports remote paths via a temporary file
+    with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
+        if str(local_filepath).endswith(".weights.h5"):
+            # If file exists and should not be overwritten.
+            try:
+                exists = os.path.exists(local_filepath)
+            except TypeError:
+                exists = False
+            if exists and not overwrite:
+                proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
+                if not proceed:
+                    return
+            saving_lib.save_weights_only(model, local_filepath)
+        else:
+            legacy_sm_saving_lib.save_weights(
+                model, local_filepath, overwrite=overwrite, **kwargs
+            )
 
 
 def load_weights(model, filepath, skip_mismatch=False, **kwargs):
-    # Supports
-[... 15 removed lines not rendered in the source diff ...]
-        model, filepath, skip_mismatch=skip_mismatch, **kwargs
-    )
+    # Supports remote paths via a temporary file
+    with SupportReadFromRemote(filepath) as local_filepath:
+        if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
+            local_filepath
+        ):
+            saving_lib.load_weights_only(
+                model, local_filepath, skip_mismatch=skip_mismatch
+            )
+        elif str(local_filepath).endswith(".weights.h5"):
+            saving_lib.load_weights_only(
+                model, local_filepath, skip_mismatch=skip_mismatch
+            )
+        else:
+            return legacy_sm_saving_lib.load_weights(
+                model, local_filepath, skip_mismatch=skip_mismatch, **kwargs
+            )
 
 
 def get_save_format(filepath, save_format):
```
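`save_weights`/`load_weights` now dispatch on the file suffix: `.weights.h5` goes through the native `saving_lib`, a `.keras` zip loads weights out of the archive, and anything else falls back to the legacy path. A hedged round-trip example with local files (assumes `tf.keras` resolves to tf-keras, whose `Model.save_weights`/`load_weights` route through this module):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 4))

model.save_weights("/tmp/demo.weights.h5")  # native saving_lib path
model.load_weights("/tmp/demo.weights.h5")  # suffix routes back to saving_lib
```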
tf_keras/src/saving/saving_lib.py
CHANGED
```diff
@@ -156,14 +156,8 @@ def save_model(model, filepath, weights_format="h5"):
             "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
         }
     )
-    # TODO(rameshsampath): Need a better logic for local vs remote path
-    if is_remote_path(filepath):
-        # Remote path. Zip to local memory byte io and copy to remote
-        zip_filepath = io.BytesIO()
-    else:
-        zip_filepath = filepath
     try:
-        with zipfile.ZipFile(
+        with zipfile.ZipFile(filepath, "w") as zf:
             with zf.open(_METADATA_FILENAME, "w") as f:
                 f.write(metadata_json.encode())
             with zf.open(_CONFIG_FILENAME, "w") as f:
```
```diff
@@ -195,10 +189,6 @@ def save_model(model, filepath, weights_format="h5"):
         )
         weights_store.close()
         asset_store.close()
-
-        if is_remote_path(filepath):
-            with tf.io.gfile.GFile(filepath, "wb") as f:
-                f.write(zip_filepath.getvalue())
     except Exception as e:
         raise e
     finally:
```
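With remote handling hoisted into `saving_api`, `saving_lib.save_model` can assume a local destination and write the zip archive directly, dropping the `io.BytesIO` staging buffer and the trailing `GFile` copy. The direct-write pattern in isolation (file and member names here are illustrative):

```python
import zipfile

# Writing archive members straight to the destination path; no in-memory
# buffer is needed once the caller guarantees the path is local.
with zipfile.ZipFile("/tmp/demo.keras", "w") as zf:
    with zf.open("metadata.json", mode="w") as f:
        f.write(b"{}")
    with zf.open("config.json", mode="w") as f:
        f.write(b"{}")
```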
tf_keras/src/saving/serialization_lib.py
CHANGED
```diff
@@ -378,7 +378,7 @@ def _get_class_or_fn_config(obj):
     """Return the object's config depending on its type."""
     # Functions / lambdas:
     if isinstance(obj, types.FunctionType):
-        return obj
+        return object_registration.get_registered_name(obj)
     # All classes:
     if hasattr(obj, "get_config"):
         config = obj.get_config()
```
```diff
@@ -789,15 +789,6 @@ def _retrieve_class_or_fn(
     if obj is not None:
         return obj
 
-    # Retrieval of registered custom function in a package
-    filtered_dict = {
-        k: v
-        for k, v in custom_objects.items()
-        if k.endswith(full_config["config"])
-    }
-    if filtered_dict:
-        return next(iter(filtered_dict.values()))
-
     # Otherwise, attempt to retrieve the class object given the `module`
     # and `class_name`. Import the module, find the class.
     try:
```
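Taken together, these two serialization changes mean a plain function now serializes to its registered name rather than the raw function object, and the fuzzy suffix-matching fallback during retrieval is gone, so custom functions should be registered (or passed via `custom_objects`) under an exact name. A hedged sketch of the registration round-trip (assumes `tf.keras` resolves to tf-keras):

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="my_package")
def my_activation(x):
    return x * x


# Serialization records the registered name...
print(tf.keras.utils.get_registered_name(my_activation))
# -> "my_package>my_activation"

# ...and deserialization resolves it by exact lookup.
fn = tf.keras.utils.get_registered_object("my_package>my_activation")
assert fn is my_activation
```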
tf_keras/src/utils/data_utils.py
CHANGED
```diff
@@ -1108,7 +1108,7 @@ def pad_sequences(
         maxlen = np.max(lengths)
 
     is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(
-        dtype, np.
+        dtype, np.str_
    )
    if isinstance(value, str) and dtype != object and not is_dtype_str:
        raise ValueError(
```
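The removed operand is truncated in this rendering (plausibly the `np.unicode_` alias, which NumPy 2.0 removed); after the change both operands test against `np.str_`, which is valid on old and new NumPy. A small check of the predicate itself:

```python
import numpy as np

# np.str_ is the surviving alias for unicode string scalar types, so this
# dtype check behaves the same on NumPy 1.x and 2.x.
print(np.issubdtype(np.dtype("U8"), np.str_))     # True
print(np.issubdtype(np.dtype("int32"), np.str_))  # False
```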
tf_keras/src/utils/tf_utils.py
CHANGED
```diff
@@ -78,9 +78,9 @@ def get_random_seed():
         the random seed as an integer.
     """
     if getattr(backend._SEED_GENERATOR, "generator", None):
-        return backend._SEED_GENERATOR.generator.randint(1, 1e9)
+        return backend._SEED_GENERATOR.generator.randint(1, int(1e9))
     else:
-        return random.randint(1, 1e9)
+        return random.randint(1, int(1e9))
 
 
 def is_tensor_or_tensor_list(v):
```
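The `int(1e9)` wrapping is needed because implicit float-to-int conversion in `randrange`/`randint` was deprecated in Python 3.10 and removed in 3.12. Demonstrated with the standard library alone:

```python
import random

print(random.randint(1, int(1e9)))  # fine on all supported versions

try:
    # Float upper bound: raises TypeError on Python 3.12+,
    # only emits a DeprecationWarning on 3.10/3.11.
    random.randint(1, 1e9)
except TypeError as err:
    print("Python 3.12+:", err)
```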
tf_keras/src/utils/timeseries_dataset.py
CHANGED
````diff
@@ -110,16 +110,24 @@ def timeseries_dataset_from_array(
     timesteps to predict the next timestep, you would use:
 
     ```python
-
-
+    data = tf.range(15)
+    sequence_length = 10
+    input_data = data[:]
+    targets = data[sequence_length:]
     dataset = tf.keras.utils.timeseries_dataset_from_array(
-        input_data, targets, sequence_length=
+        input_data, targets, sequence_length=sequence_length
+    )
     for batch in dataset:
         inputs, targets = batch
-
+        # First sequence: steps [0-9]
+        assert np.array_equal(inputs[0], data[:sequence_length])
         # Corresponding target: step 10
-        assert np.array_equal(targets[0], data[
+        assert np.array_equal(targets[0], data[sequence_length])
         break
+    # To view the generated dataset
+    for batch in dataset.as_numpy_iterator():
+        input, label = batch
+        print(f"Input:{input}, target:{label}")
     ```
 
     Example 3: Temporal regression for many-to-many architectures.
````
{tf_keras_nightly-2.17.0.dev2024031909.dist-info → tf_keras_nightly-2.19.0.dev2025011410.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: tf_keras-nightly
-Version: 2.17.0.dev2024031909
+Version: 2.19.0.dev2025011410
 Summary: Deep learning for humans.
 Home-page: https://keras.io/
 Download-URL: https://github.com/keras-team/tf-keras/tags
```
```diff
@@ -26,7 +26,18 @@ Classifier: Topic :: Software Development
 Classifier: Topic :: Software Development :: Libraries
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Requires-Python: >=3.9
-Requires-Dist: tf-nightly
+Requires-Dist: tf-nightly~=2.19.0.dev
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: download-url
+Dynamic: home-page
+Dynamic: keywords
+Dynamic: license
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 TF-Keras is a deep learning API written in Python,
 running on top of the machine learning platform TensorFlow.
```