tf-keras-nightly 2.19.0.dev2024112610__py3-none-any.whl → 2.19.0.dev2024113010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
27
27
  from tf_keras.src.engine.training import Model
28
28
 
29
29
 
30
- __version__ = "2.19.0.dev2024112610"
30
+ __version__ = "2.19.0.dev2024113010"
tf_keras/src/saving/saving_api.py CHANGED
@@ -15,6 +15,7 @@
15
15
  """Public API surface for saving APIs."""
16
16
 
17
17
  import os
18
+ import tempfile
18
19
  import warnings
19
20
  import zipfile
20
21
 
@@ -33,17 +34,76 @@ except ImportError:
33
34
  is_oss = True
34
35
 
35
36
 
36
- def _support_gcs_uri(filepath, save_format, is_oss):
37
- """Supports GCS URIs through bigstore via a temporary file."""
38
- gs_filepath = None
39
- if str(filepath).startswith("gs://") and save_format != "tf":
40
- gs_filepath = filepath
41
- if not is_oss:
42
- gs_filepath = filepath.replace("gs://", "/bigstore/")
43
- filepath = os.path.join(
44
- saving_lib.get_temp_dir(), os.path.basename(gs_filepath)
45
- )
46
- return gs_filepath, filepath
37
+ class SupportReadFromRemote:
38
+ """Supports GCS URIs and other remote paths via a temporary file.
39
+
40
+ This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
41
+ supports remoted saved model out of the box.
42
+ """
43
+
44
+ def __init__(self, filepath):
45
+ save_format = get_save_format(filepath, save_format=None)
46
+ if (
47
+ saving_lib.is_remote_path(filepath)
48
+ and not tf.io.gfile.isdir(filepath)
49
+ and save_format != "tf"
50
+ ):
51
+ self.temp_directory = tempfile.TemporaryDirectory()
52
+ gs_filepath = filepath
53
+ if not is_oss and str(filepath).startswith("gs://"):
54
+ gs_filepath = filepath.replace("gs://", "/bigstore/")
55
+ self.local_filepath = os.path.join(
56
+ self.temp_directory.name, os.path.basename(filepath)
57
+ )
58
+ tf.io.gfile.copy(gs_filepath, self.local_filepath, overwrite=True)
59
+ else:
60
+ self.temp_directory = None
61
+ self.local_filepath = filepath
62
+
63
+ def __enter__(self):
64
+ return self.local_filepath
65
+
66
+ def __exit__(self, exc_type, exc_value, traceback):
67
+ if self.temp_directory is not None:
68
+ self.temp_directory.cleanup()
69
+
70
+
71
+ class SupportWriteToRemote:
72
+ """Supports GCS URIs and other remote paths via a temporary file.
73
+
74
+ This is used for `.keras` and H5 files on GCS, CNS and CFS. TensorFlow
75
+ supports remoted saved model out of the box.
76
+ """
77
+
78
+ def __init__(self, filepath, overwrite=True, save_format=None):
79
+ save_format = get_save_format(filepath, save_format=save_format)
80
+ self.overwrite = overwrite
81
+ if saving_lib.is_remote_path(filepath) and save_format != "tf":
82
+ self.temp_directory = tempfile.TemporaryDirectory()
83
+ self.remote_filepath = filepath
84
+ if not is_oss and str(filepath).startswith("gs://"):
85
+ self.remote_filepath = self.remote_filepath.replace(
86
+ "gs://", "/bigstore/"
87
+ )
88
+ self.local_filepath = os.path.join(
89
+ self.temp_directory.name, os.path.basename(filepath)
90
+ )
91
+ else:
92
+ self.temp_directory = None
93
+ self.remote_filepath = None
94
+ self.local_filepath = filepath
95
+
96
+ def __enter__(self):
97
+ return self.local_filepath
98
+
99
+ def __exit__(self, exc_type, exc_value, traceback):
100
+ if self.temp_directory is not None:
101
+ tf.io.gfile.copy(
102
+ self.local_filepath,
103
+ self.remote_filepath,
104
+ overwrite=self.overwrite,
105
+ )
106
+ self.temp_directory.cleanup()
47
107
 
48
108
 
49
109
  @keras_export("keras.saving.save_model", "keras.models.save_model")
@@ -131,46 +191,49 @@ def save_model(model, filepath, overwrite=True, save_format=None, **kwargs):
131
191
  when loading the model. See the `custom_objects` argument in
132
192
  `tf.keras.saving.load_model`.
133
193
  """
134
- save_format = get_save_format(filepath, save_format)
135
-
136
- # Supports GCS URIs through bigstore via a temporary file
137
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
138
-
139
- # Deprecation warnings
140
- if save_format == "h5":
141
- warnings.warn(
142
- "You are saving your model as an HDF5 file via `model.save()`. "
143
- "This file format is considered legacy. "
144
- "We recommend using instead the native TF-Keras format, "
145
- "e.g. `model.save('my_model.keras')`.",
146
- stacklevel=2,
147
- )
194
+ # Supports remote paths via a temporary file
195
+ with SupportWriteToRemote(
196
+ filepath,
197
+ overwrite=overwrite,
198
+ save_format=save_format,
199
+ ) as local_filepath:
200
+ save_format = get_save_format(filepath, save_format)
201
+
202
+ # Deprecation warnings
203
+ if save_format == "h5":
204
+ warnings.warn(
205
+ "You are saving your model as an HDF5 file via `model.save()`. "
206
+ "This file format is considered legacy. "
207
+ "We recommend using instead the native TF-Keras format, "
208
+ "e.g. `model.save('my_model.keras')`.",
209
+ stacklevel=2,
210
+ )
148
211
 
149
- if save_format == "keras":
150
- # If file exists and should not be overwritten.
151
- try:
152
- exists = os.path.exists(filepath)
153
- except TypeError:
154
- exists = False
155
- if exists and not overwrite:
156
- proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
157
- if not proceed:
158
- return
159
- if kwargs:
160
- raise ValueError(
161
- "The following argument(s) are not supported "
162
- f"with the native TF-Keras format: {list(kwargs.keys())}"
212
+ if save_format == "keras":
213
+ # If file exists and should not be overwritten.
214
+ try:
215
+ exists = os.path.exists(local_filepath)
216
+ except TypeError:
217
+ exists = False
218
+ if exists and not overwrite:
219
+ proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
220
+ if not proceed:
221
+ return
222
+ if kwargs:
223
+ raise ValueError(
224
+ "The following argument(s) are not supported "
225
+ f"with the native TF-Keras format: {list(kwargs.keys())}"
226
+ )
227
+ saving_lib.save_model(model, local_filepath)
228
+ else:
229
+ # Legacy case
230
+ return legacy_sm_saving_lib.save_model(
231
+ model,
232
+ local_filepath,
233
+ overwrite=overwrite,
234
+ save_format=save_format,
235
+ **kwargs,
163
236
  )
164
- saving_lib.save_model(model, filepath)
165
- else:
166
- # Legacy case
167
- return legacy_sm_saving_lib.save_model(
168
- model,
169
- filepath,
170
- overwrite=overwrite,
171
- save_format=save_format,
172
- **kwargs,
173
- )
174
237
 
175
238
 
176
239
  @keras_export("keras.saving.load_model", "keras.models.load_model")
@@ -217,94 +280,69 @@ def load_model(
217
280
  It is recommended that you use layer attributes to
218
281
  access specific variables, e.g. `model.get_layer("dense_1").kernel`.
219
282
  """
220
- # Supports GCS URIs by copying data to temporary file
221
- save_format = get_save_format(filepath, save_format=None)
222
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
223
- if gs_filepath is not None:
224
- tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)
225
-
226
- is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
227
- filepath
228
- )
229
-
230
- # Support for remote zip files
231
- if (
232
- saving_lib.is_remote_path(filepath)
233
- and not tf.io.gfile.isdir(filepath)
234
- and not is_keras_zip
235
- ):
236
- local_path = os.path.join(
237
- saving_lib.get_temp_dir(), os.path.basename(filepath)
238
- )
239
-
240
- # Copy from remote to temporary local directory
241
- tf.io.gfile.copy(filepath, local_path, overwrite=True)
242
-
243
- # Switch filepath to local zipfile for loading model
244
- if zipfile.is_zipfile(local_path):
245
- filepath = local_path
246
- is_keras_zip = True
247
-
248
- if is_keras_zip:
249
- if kwargs:
250
- raise ValueError(
251
- "The following argument(s) are not supported "
252
- f"with the native TF-Keras format: {list(kwargs.keys())}"
283
+ # Supports remote paths via a temporary file
284
+ with SupportReadFromRemote(filepath) as local_filepath:
285
+ if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
286
+ local_filepath
287
+ ):
288
+ if kwargs:
289
+ raise ValueError(
290
+ "The following argument(s) are not supported "
291
+ f"with the native TF-Keras format: {list(kwargs.keys())}"
292
+ )
293
+ return saving_lib.load_model(
294
+ local_filepath,
295
+ custom_objects=custom_objects,
296
+ compile=compile,
297
+ safe_mode=safe_mode,
253
298
  )
254
- return saving_lib.load_model(
255
- filepath,
299
+
300
+ # Legacy case.
301
+ return legacy_sm_saving_lib.load_model(
302
+ local_filepath,
256
303
  custom_objects=custom_objects,
257
304
  compile=compile,
258
- safe_mode=safe_mode,
305
+ **kwargs,
259
306
  )
260
307
 
261
- # Legacy case.
262
- return legacy_sm_saving_lib.load_model(
263
- filepath, custom_objects=custom_objects, compile=compile, **kwargs
264
- )
265
-
266
308
 
267
309
  def save_weights(model, filepath, overwrite=True, **kwargs):
268
- # Supports GCS URIs through bigstore via a temporary file
269
- save_format = get_save_format(filepath, save_format=None)
270
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
271
-
272
- if str(filepath).endswith(".weights.h5"):
273
- # If file exists and should not be overwritten.
274
- try:
275
- exists = os.path.exists(filepath)
276
- except TypeError:
277
- exists = False
278
- if exists and not overwrite:
279
- proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
280
- if not proceed:
281
- return
282
- saving_lib.save_weights_only(model, filepath)
283
- else:
284
- legacy_sm_saving_lib.save_weights(
285
- model, filepath, overwrite=overwrite, **kwargs
286
- )
310
+ # Supports remote paths via a temporary file
311
+ with SupportWriteToRemote(filepath, overwrite=overwrite) as local_filepath:
312
+ if str(local_filepath).endswith(".weights.h5"):
313
+ # If file exists and should not be overwritten.
314
+ try:
315
+ exists = os.path.exists(local_filepath)
316
+ except TypeError:
317
+ exists = False
318
+ if exists and not overwrite:
319
+ proceed = io_utils.ask_to_proceed_with_overwrite(local_filepath)
320
+ if not proceed:
321
+ return
322
+ saving_lib.save_weights_only(model, local_filepath)
323
+ else:
324
+ legacy_sm_saving_lib.save_weights(
325
+ model, local_filepath, overwrite=overwrite, **kwargs
326
+ )
287
327
 
288
328
 
289
329
  def load_weights(model, filepath, skip_mismatch=False, **kwargs):
290
- # Supports GCS URIs by copying data to temporary file
291
- save_format = get_save_format(filepath, save_format=None)
292
- gs_filepath, filepath = _support_gcs_uri(filepath, save_format, is_oss)
293
- if gs_filepath is not None:
294
- tf.io.gfile.copy(gs_filepath, filepath, overwrite=True)
295
-
296
- if str(filepath).endswith(".keras") and zipfile.is_zipfile(filepath):
297
- saving_lib.load_weights_only(
298
- model, filepath, skip_mismatch=skip_mismatch
299
- )
300
- elif str(filepath).endswith(".weights.h5"):
301
- saving_lib.load_weights_only(
302
- model, filepath, skip_mismatch=skip_mismatch
303
- )
304
- else:
305
- return legacy_sm_saving_lib.load_weights(
306
- model, filepath, skip_mismatch=skip_mismatch, **kwargs
307
- )
330
+ # Supports remote paths via a temporary file
331
+ with SupportReadFromRemote(filepath) as local_filepath:
332
+ if str(local_filepath).endswith(".keras") and zipfile.is_zipfile(
333
+ local_filepath
334
+ ):
335
+ saving_lib.load_weights_only(
336
+ model, local_filepath, skip_mismatch=skip_mismatch
337
+ )
338
+ elif str(local_filepath).endswith(".weights.h5"):
339
+ saving_lib.load_weights_only(
340
+ model, local_filepath, skip_mismatch=skip_mismatch
341
+ )
342
+ else:
343
+ return legacy_sm_saving_lib.load_weights(
344
+ model, local_filepath, skip_mismatch=skip_mismatch, **kwargs
345
+ )
308
346
 
309
347
 
310
348
  def get_save_format(filepath, save_format):
tf_keras/src/saving/saving_lib.py CHANGED
@@ -156,14 +156,8 @@ def save_model(model, filepath, weights_format="h5"):
156
156
  "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
157
157
  }
158
158
  )
159
- # TODO(rameshsampath): Need a better logic for local vs remote path
160
- if is_remote_path(filepath):
161
- # Remote path. Zip to local memory byte io and copy to remote
162
- zip_filepath = io.BytesIO()
163
- else:
164
- zip_filepath = filepath
165
159
  try:
166
- with zipfile.ZipFile(zip_filepath, "w") as zf:
160
+ with zipfile.ZipFile(filepath, "w") as zf:
167
161
  with zf.open(_METADATA_FILENAME, "w") as f:
168
162
  f.write(metadata_json.encode())
169
163
  with zf.open(_CONFIG_FILENAME, "w") as f:
@@ -195,10 +189,6 @@ def save_model(model, filepath, weights_format="h5"):
195
189
  )
196
190
  weights_store.close()
197
191
  asset_store.close()
198
-
199
- if is_remote_path(filepath):
200
- with tf.io.gfile.GFile(filepath, "wb") as f:
201
- f.write(zip_filepath.getvalue())
202
192
  except Exception as e:
203
193
  raise e
204
194
  finally:
tf_keras_nightly-2.19.0.dev2024113010.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tf_keras-nightly
3
- Version: 2.19.0.dev2024112610
3
+ Version: 2.19.0.dev2024113010
4
4
  Summary: Deep learning for humans.
5
5
  Home-page: https://keras.io/
6
6
  Download-URL: https://github.com/keras-team/tf-keras/tags
tf_keras_nightly-2.19.0.dev2024113010.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
1
- tf_keras/__init__.py,sha256=NBr1Tzh6ZB8znc9LhSep9zgEMM1EA7B7QrTjkZ6AKDM,911
1
+ tf_keras/__init__.py,sha256=TkJSEfpxF7E_c5XsMLzCNEryzHQURhS-XXilfgzLWS0,911
2
2
  tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
3
3
  tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
4
4
  tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -537,8 +537,8 @@ tf_keras/src/preprocessing/text.py,sha256=aomzwE3G2ErwzgL_Dj3ERA_2k7TZclaPMTBTrW
537
537
  tf_keras/src/saving/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
538
538
  tf_keras/src/saving/object_registration.py,sha256=N8aV6eqREYjW2ueQpL3guYHyh5KXuun3DZAlmjfYrTA,7830
539
539
  tf_keras/src/saving/pickle_utils.py,sha256=5GtHzwNWVaYfZ-0zn69-zn2yv3R6JUwzHOOamnjP7r0,2605
540
- tf_keras/src/saving/saving_api.py,sha256=OepUlpp79IjEy5NdXr6pMQoWvNVU2vTFDDzFojXsRhs,13152
541
- tf_keras/src/saving/saving_lib.py,sha256=4pgC9vv46losKZ42ytUOGRqxF2PrNuw-o3OQx9b-ylo,24673
540
+ tf_keras/src/saving/saving_api.py,sha256=q2_-CsVg81HeHnXDkRqpoBwI3Mpp_gQXjBidr5V_RLo,14838
541
+ tf_keras/src/saving/saving_lib.py,sha256=Rk5rOvxEmCvwUlG3bS0QpOgtURayYmIAZbh2GeTuUOc,24272
542
542
  tf_keras/src/saving/serialization_lib.py,sha256=kX4qf_fRp4LySkH9FU37DMd0AXxiUrXKT-VLR3JPl7w,30152
543
543
  tf_keras/src/saving/legacy/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
544
544
  tf_keras/src/saving/legacy/hdf5_format.py,sha256=IqFXHN96fuqKwu_akaqTyf9ISRPavP3Ahjydat948O4,42438
@@ -606,7 +606,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
606
606
  tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
607
607
  tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
608
608
  tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
609
- tf_keras_nightly-2.19.0.dev2024112610.dist-info/METADATA,sha256=dqzxkrH3z-LM4Qs7XG7X0diDtwCfWs0GznYSgOtyoG0,1637
610
- tf_keras_nightly-2.19.0.dev2024112610.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
611
- tf_keras_nightly-2.19.0.dev2024112610.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
- tf_keras_nightly-2.19.0.dev2024112610.dist-info/RECORD,,
609
+ tf_keras_nightly-2.19.0.dev2024113010.dist-info/METADATA,sha256=0iGYOVdmUKVLmSbD9H2U6BUQFryvRzLQ2Y8yBMNPNdQ,1637
610
+ tf_keras_nightly-2.19.0.dev2024113010.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
611
+ tf_keras_nightly-2.19.0.dev2024113010.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
+ tf_keras_nightly-2.19.0.dev2024113010.dist-info/RECORD,,