tf-keras-nightly 2.19.0.dev2024112610__py3-none-any.whl → 2.19.0.dev2024113010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/src/saving/saving_api.py +165 -127
- tf_keras/src/saving/saving_lib.py +1 -11
- {tf_keras_nightly-2.19.0.dev2024112610.dist-info → tf_keras_nightly-2.19.0.dev2024113010.dist-info}/METADATA +1 -1
- {tf_keras_nightly-2.19.0.dev2024112610.dist-info → tf_keras_nightly-2.19.0.dev2024113010.dist-info}/RECORD +7 -7
- {tf_keras_nightly-2.19.0.dev2024112610.dist-info → tf_keras_nightly-2.19.0.dev2024113010.dist-info}/WHEEL +0 -0
- {tf_keras_nightly-2.19.0.dev2024112610.dist-info → tf_keras_nightly-2.19.0.dev2024113010.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
@@ -15,6 +15,7 @@
|
|
15
15
|
"""Public API surface for saving APIs."""
|
16
16
|
|
17
17
|
import os
|
18
|
+
import tempfile
|
18
19
|
import warnings
|
19
20
|
import zipfile
|
20
21
|
|
@@ -33,17 +34,76 @@ except ImportError:
|
|
33
34
|
is_oss = True
|
34
35
|
|
35
36
|
|
36
|
-
|
37
|
-
"""Supports GCS URIs
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
37
|
+
class SupportReadFromRemote:
|
38
|
+
"""Supports GCS URIs and other remote paths via a temporary file.
|
39
|
+
|
40
|
+
This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
|
41
|
+
supports remoted saved model out of the box.
|
42
|
+
"""
|
43
|
+
|
44
|
+
def __init__(self, filepath):
|
45
|
+
save_format = get_save_format(filepath, save_format=None)
|
46
|
+
if (
|
47
|
+
saving_lib.is_remote_path(filepath)
|
48
|
+
and not tf.io.gfile.isdir(filepath)
|
49
|
+
and save_format != "tf"
|
50
|
+
):
|
51
|
+
self.temp_directory = tempfile.TemporaryDirectory()
|
52
|
+
gs_filepath = filepath
|
53
|
+
if not is_oss and str(filepath).startswith("gs://"):
|
54
|
+
gs_filepath = filepath.replace("gs://", "/bigstore/")
|
55
|
+
self.local_filepath = os.path.join(
|
56
|
+
self.temp_directory.name, os.path.basename(filepath)
|
57
|
+
)
|
58
|
+
tf.io.gfile.copy(gs_filepath, self.local_filepath, overwrite=True)
|
59
|
+
else:
|
60
|
+
self.temp_directory = None
|
61
|
+
self.local_filepath = filepath
|
62
|
+
|
63
|
+
def __enter__(self):
|
64
|
+
return self.local_filepath
|
65
|
+
|
66
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
67
|
+
if self.temp_directory is not None:
|
68
|
+
self.temp_directory.cleanup()
|
69
|
+
|
70
|
+
|
71
|
+
class SupportWriteToRemote:
|
72
|
+
"""Supports GCS URIs and other remote paths via a temporary file.
|
73
|
+
|
74
|
+
This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
|
75
|
+
supports remoted saved model out of the box.
|
76
|
+
"""
|
77
|
+
|
78
|
+
def __init__(self, filepath, overwrite=True, save_format=None):
|
79
|
+
save_format = get_save_format(filepath, save_format=save_format)
|
80
|
+
self.overwrite = overwrite
|
81
|
+
if saving_lib.is_remote_path(filepath) and save_format != "tf":
|
82
|
+
self.temp_directory = tempfile.TemporaryDirectory()
|
83
|
+
self.remote_filepath = filepath
|
84
|
+
if not is_oss and str(filepath).startswith("gs://"):
|
85
|
+
self.remote_filepath = self.remote_filepath.replace(
|
86
|
+
"gs://", "/bigstore/"
|
87
|
+
)
|
88
|
+
self.local_filepath = os.path.join(
|
89
|
+
self.temp_directory.name, os.path.basename(filepath)
|
90
|
+
)
|
91
|
+
else:
|
92
|
+
self.temp_directory = None
|
93
|
+
self.remote_filepath = None
|
94
|
+
self.local_filepath = filepath
|
95
|
+
|
96
|
+
def __enter__(self):
|
97
|
+
return self.local_filepath
|
98
|
+
|
99
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
100
|
+
if self.temp_directory is not None:
|
101
|
+
tf.io.gfile.copy(
|
102
|
+
self.local_filepath,
|
103
|
+
self.remote_filepath,
|
104
|
+
overwrite=self.overwrite,
|
105
|
+
)
|
106
|
+
self.temp_directory.cleanup()
|
47
107
|
|
48
108
|
|
49
109
|
@keras_export("keras.saving.save_model", "keras.models.save_model")
|
@@ -131,46 +191,49 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
|
|
131
191
|
when loading the model. See the `custom_objects` argument in
|
132
192
|
`tf.keras.saving.load_model`.
|
133
193
|
"""
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
194
|
+
# Supports remote paths via a temporary file
|
195
|
+
with SupportWriteToRemote(
|
196
|
+
filepath,
|
197
|
+
overwrite=overwrite,
|
198
|
+
save_format=save_format,
|
199
|
+
) as local_filepath:
|
200
|
+
save_format = get_save_format(filepath, save_format)
|
201
|
+
|
202
|
+
# Deprecation warnings
|
203
|
+
if save_format == "h5":
|
204
|
+
warnings.warn(
|
205
|
+
"You are saving your model as an HDF5 file via `model.save()`. "
|
206
|
+
"This file format is considered legacy. "
|
207
|
+
"We recommend using instead the native TF-Keras format, "
|
208
|
+
"e.g. `model.save('my_model.keras')`.",
|
209
|
+
stacklevel=2,
|
210
|
+
)
|
148
211
|
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
212
|
+
if save_format == "keras":
|
213
|
+
# If file exists and should not be overwritten.
|
214
|
+
try:
|
215
|
+
exists = os.path.exists(local_filepath)
|
216
|
+
except TypeError:
|
217
|
+
exists = False
|
218
|
+
if exists and not overwrite:
|
219
|
+
proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
|
220
|
+
if not proceed:
|
221
|
+
return
|
222
|
+
if kwargs:
|
223
|
+
raise ValueError(
|
224
|
+
"The following argument(s) are not supported "
|
225
|
+
f"with the native TF-Keras format: {list(kwargs.keys())}"
|
226
|
+
)
|
227
|
+
saving_lib.save_model(model, local_filepath)
|
228
|
+
else:
|
229
|
+
# Legacy case
|
230
|
+
return legacy_sm_saving_lib.save_model(
|
231
|
+
model,
|
232
|
+
local_filepath,
|
233
|
+
overwrite=overwrite,
|
234
|
+
save_format=save_format,
|
235
|
+
**kwargs,
|
163
236
|
)
|
164
|
-
saving_lib.save_model(model, filepath)
|
165
|
-
else:
|
166
|
-
# Legacy case
|
167
|
-
return legacy_sm_saving_lib.save_model(
|
168
|
-
model,
|
169
|
-
filepath,
|
170
|
-
overwrite=overwrite,
|
171
|
-
save_format=save_format,
|
172
|
-
**kwargs,
|
173
|
-
)
|
174
237
|
|
175
238
|
|
176
239
|
@keras_export("keras.saving.load_model", "keras.models.load_model")
|
@@ -217,94 +280,69 @@ def load_model(
|
|
217
280
|
It is recommended that you use layer attributes to
|
218
281
|
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
219
282
|
"""
|
220
|
-
# Supports
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
):
|
236
|
-
local_path = os.path.join(
|
237
|
-
saving_lib.get_temp_dir(), os.path.basename(filepath)
|
238
|
-
)
|
239
|
-
|
240
|
-
# Copy from remote to temporary local directory
|
241
|
-
tf.io.gfile.copy(filepath, local_path, overwrite=True)
|
242
|
-
|
243
|
-
# Switch filepath to local zipfile for loading model
|
244
|
-
if zipfile.is_zipfile(local_path):
|
245
|
-
filepath = local_path
|
246
|
-
is_keras_zip = True
|
247
|
-
|
248
|
-
if is_keras_zip:
|
249
|
-
if kwargs:
|
250
|
-
raise ValueError(
|
251
|
-
"The following argument(s) are not supported "
|
252
|
-
f"with the native TF-Keras format: {list(kwargs.keys())}"
|
283
|
+
# Supports remote paths via a temporary file
|
284
|
+
with SupportReadFromRemote(filepath) as local_filepath:
|
285
|
+
if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
|
286
|
+
local_filepath
|
287
|
+
):
|
288
|
+
if kwargs:
|
289
|
+
raise ValueError(
|
290
|
+
"The following argument(s) are not supported "
|
291
|
+
f"with the native TF-Keras format: {list(kwargs.keys())}"
|
292
|
+
)
|
293
|
+
return saving_lib.load_model(
|
294
|
+
local_filepath,
|
295
|
+
custom_objects=custom_objects,
|
296
|
+
compile=compile,
|
297
|
+
safe_mode=safe_mode,
|
253
298
|
)
|
254
|
-
|
255
|
-
|
299
|
+
|
300
|
+
# Legacy case.
|
301
|
+
return legacy_sm_saving_lib.load_model(
|
302
|
+
local_filepath,
|
256
303
|
custom_objects=custom_objects,
|
257
304
|
compile=compile,
|
258
|
-
|
305
|
+
**kwargs,
|
259
306
|
)
|
260
307
|
|
261
|
-
# Legacy case.
|
262
|
-
return legacy_sm_saving_lib.load_model(
|
263
|
-
filepath, custom_objects=custom_objects, compile=compile, **kwargs
|
264
|
-
)
|
265
|
-
|
266
308
|
|
267
309
|
def save_weights(model, filepath, overwrite=True, **kwargs):
|
268
|
-
# Supports
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
model, filepath, overwrite=overwrite, **kwargs
|
286
|
-
)
|
310
|
+
# Supports remote paths via a temporary file
|
311
|
+
with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
|
312
|
+
if str(local_filepath).endswith(".weights.h5"):
|
313
|
+
# If file exists and should not be overwritten.
|
314
|
+
try:
|
315
|
+
exists = os.path.exists(local_filepath)
|
316
|
+
except TypeError:
|
317
|
+
exists = False
|
318
|
+
if exists and not overwrite:
|
319
|
+
proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
|
320
|
+
if not proceed:
|
321
|
+
return
|
322
|
+
saving_lib.save_weights_only(model, local_filepath)
|
323
|
+
else:
|
324
|
+
legacy_sm_saving_lib.save_weights(
|
325
|
+
model, local_filepath, overwrite=overwrite, **kwargs
|
326
|
+
)
|
287
327
|
|
288
328
|
|
289
329
|
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
290
|
-
# Supports
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
model, filepath, skip_mismatch=skip_mismatch, **kwargs
|
307
|
-
)
|
330
|
+
# Supports remote paths via a temporary file
|
331
|
+
with SupportReadFromRemote(filepath) as local_filepath:
|
332
|
+
if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
|
333
|
+
local_filepath
|
334
|
+
):
|
335
|
+
saving_lib.load_weights_only(
|
336
|
+
model, local_filepath, skip_mismatch=skip_mismatch
|
337
|
+
)
|
338
|
+
elif str(local_filepath).endswith(".weights.h5"):
|
339
|
+
saving_lib.load_weights_only(
|
340
|
+
model, local_filepath, skip_mismatch=skip_mismatch
|
341
|
+
)
|
342
|
+
else:
|
343
|
+
return legacy_sm_saving_lib.load_weights(
|
344
|
+
model, local_filepath, skip_mismatch=skip_mismatch, **kwargs
|
345
|
+
)
|
308
346
|
|
309
347
|
|
310
348
|
def get_save_format(filepath, save_format):
|
@@ -156,14 +156,8 @@ def save_model(model, filepath, weights_format="h5"):
|
|
156
156
|
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
|
157
157
|
}
|
158
158
|
)
|
159
|
-
# TODO(rameshsampath): Need a better logic for local vs remote path
|
160
|
-
if is_remote_path(filepath):
|
161
|
-
# Remote path. Zip to local memory byte io and copy to remote
|
162
|
-
zip_filepath = io.BytesIO()
|
163
|
-
else:
|
164
|
-
zip_filepath = filepath
|
165
159
|
try:
|
166
|
-
with zipfile.ZipFile(
|
160
|
+
with zipfile.ZipFile(filepath, "w") as zf:
|
167
161
|
with zf.open(_METADATA_FILENAME, "w") as f:
|
168
162
|
f.write(metadata_json.encode())
|
169
163
|
with zf.open(_CONFIG_FILENAME, "w") as f:
|
@@ -195,10 +189,6 @@ def save_model(model, filepath, weights_format="h5"):
|
|
195
189
|
)
|
196
190
|
weights_store.close()
|
197
191
|
asset_store.close()
|
198
|
-
|
199
|
-
if is_remote_path(filepath):
|
200
|
-
with tf.io.gfile.GFile(filepath, "wb") as f:
|
201
|
-
f.write(zip_filepath.getvalue())
|
202
192
|
except Exception as e:
|
203
193
|
raise e
|
204
194
|
finally:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tf_keras/__init__.py,sha256=
|
1
|
+
tf_keras/__init__.py,sha256=TkJSEfpxF7E_c5XsMLzCNEryzHQURhS-XXilfgzLWS0,911
|
2
2
|
tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
|
3
3
|
tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
|
4
4
|
tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
|
@@ -537,8 +537,8 @@ tf_keras/src/preprocessing/text.py,sha256=aomzwE3G2ErwzgL_Dj3ERA_2k7TZclaPMTBTrW
|
|
537
537
|
tf_keras/src/saving/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
538
538
|
tf_keras/src/saving/object_registration.py,sha256=N8aV6eqREYjW2ueQpL3guYHyh5KXuun3DZAlmjfYrTA,7830
|
539
539
|
tf_keras/src/saving/pickle_utils.py,sha256=5GtHzwNWVaYfZ-0zn69-zn2yv3R6JUwzHOOamnjP7r0,2605
|
540
|
-
tf_keras/src/saving/saving_api.py,sha256=
|
541
|
-
tf_keras/src/saving/saving_lib.py,sha256=
|
540
|
+
tf_keras/src/saving/saving_api.py,sha256=q2_-CsVg81HeHnXDkRqpoBwI3Mpp_gQXjBidr5V_RLo,14838
|
541
|
+
tf_keras/src/saving/saving_lib.py,sha256=Rk5rOvxEmCvwUlG3bS0QpOgtURayYmIAZbh2GeTuUOc,24272
|
542
542
|
tf_keras/src/saving/serialization_lib.py,sha256=kX4qf_fRp4LySkH9FU37DMd0AXxiUrXKT-VLR3JPl7w,30152
|
543
543
|
tf_keras/src/saving/legacy/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
544
544
|
tf_keras/src/saving/legacy/hdf5_format.py,sha256=IqFXHN96fuqKwu_akaqTyf9ISRPavP3Ahjydat948O4,42438
|
@@ -606,7 +606,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
|
|
606
606
|
tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
|
607
607
|
tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
|
608
608
|
tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
|
609
|
-
tf_keras_nightly-2.19.0.
|
610
|
-
tf_keras_nightly-2.19.0.
|
611
|
-
tf_keras_nightly-2.19.0.
|
612
|
-
tf_keras_nightly-2.19.0.
|
609
|
+
tf_keras_nightly-2.19.0.dev2024113010.dist-info/METADATA,sha256=0iGYOVdmUKVLmSbD9H2U6BUQFryvRzLQ2Y8yBMNPNdQ,1637
|
610
|
+
tf_keras_nightly-2.19.0.dev2024113010.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
611
|
+
tf_keras_nightly-2.19.0.dev2024113010.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
|
612
|
+
tf_keras_nightly-2.19.0.dev2024113010.dist-info/RECORD,,
|
File without changes
|
File without changes
|