tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tf_keras/__init__.py +1 -1
  2. tf_keras/protobuf/projector_config_pb2.py +23 -12
  3. tf_keras/protobuf/saved_metadata_pb2.py +21 -10
  4. tf_keras/protobuf/versions_pb2.py +19 -8
  5. tf_keras/src/__init__.py +1 -1
  6. tf_keras/src/backend.py +1 -1
  7. tf_keras/src/datasets/boston_housing.py +14 -5
  8. tf_keras/src/datasets/cifar10.py +9 -1
  9. tf_keras/src/datasets/cifar100.py +7 -1
  10. tf_keras/src/datasets/fashion_mnist.py +16 -4
  11. tf_keras/src/datasets/imdb.py +8 -0
  12. tf_keras/src/datasets/mnist.py +9 -3
  13. tf_keras/src/datasets/reuters.py +8 -0
  14. tf_keras/src/engine/base_layer.py +235 -97
  15. tf_keras/src/engine/base_layer_utils.py +17 -5
  16. tf_keras/src/engine/base_layer_v1.py +12 -3
  17. tf_keras/src/engine/data_adapter.py +35 -19
  18. tf_keras/src/engine/functional.py +36 -15
  19. tf_keras/src/engine/input_layer.py +9 -0
  20. tf_keras/src/engine/input_spec.py +11 -1
  21. tf_keras/src/engine/sequential.py +29 -12
  22. tf_keras/src/layers/activation/softmax.py +26 -11
  23. tf_keras/src/layers/attention/multi_head_attention.py +8 -1
  24. tf_keras/src/layers/core/tf_op_layer.py +4 -0
  25. tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
  26. tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
  27. tf_keras/src/metrics/confusion_metrics.py +51 -4
  28. tf_keras/src/models/sharpness_aware_minimization.py +17 -7
  29. tf_keras/src/preprocessing/sequence.py +2 -2
  30. tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
  31. tf_keras/src/saving/legacy/saving_utils.py +14 -2
  32. tf_keras/src/saving/saving_api.py +18 -5
  33. tf_keras/src/saving/saving_lib.py +1 -1
  34. tf_keras/src/utils/layer_utils.py +45 -3
  35. tf_keras/src/utils/metrics_utils.py +4 -1
  36. tf_keras/src/utils/tf_utils.py +2 -2
  37. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
  38. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
  39. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
  40. tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
  41. tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
  42. tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
  43. tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
  44. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
  45. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
  46. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
  47. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
  48. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
  49. tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
  50. tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
  51. tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
  52. tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
  53. tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
  54. tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
  55. tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
  56. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
  57. tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
  58. tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
  59. tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
  60. tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
  61. tf_keras/src/tests/keras_doctest.py +0 -159
  62. {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
 from tf_keras.src.engine.training import Model
 
 
-__version__ = "2.19.0.dev2024121210"
+__version__ = "2.21.0.dev2025123010"
tf_keras/protobuf/projector_config_pb2.py CHANGED
@@ -1,11 +1,22 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
 # source: tf_keras/protobuf/projector_config.proto
+# Protobuf Python Version: 6.31.1
 """Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC,
+    6,
+    31,
+    1,
+    '',
+    'tf_keras/protobuf/projector_config.proto'
+)
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
 
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
-
-  DESCRIPTOR._options = None
-  _SPRITEMETADATA._serialized_start=78
-  _SPRITEMETADATA._serialized_end=140
-  _EMBEDDINGINFO._serialized_start=143
-  _EMBEDDINGINFO._serialized_end=335
-  _PROJECTORCONFIG._serialized_start=338
-  _PROJECTORCONFIG._serialized_end=485
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+  DESCRIPTOR._loaded_options = None
+  _globals['_SPRITEMETADATA']._serialized_start=78
+  _globals['_SPRITEMETADATA']._serialized_end=140
+  _globals['_EMBEDDINGINFO']._serialized_start=143
+  _globals['_EMBEDDINGINFO']._serialized_end=335
+  _globals['_PROJECTORCONFIG']._serialized_start=338
+  _globals['_PROJECTORCONFIG']._serialized_end=485
 # @@protoc_insertion_point(module_scope)
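The regenerated `_pb2` modules (this one and the two below) now carry a `ValidateProtobufRuntimeVersion` call, so they refuse to import when the installed `protobuf` runtime is older than the 6.31.1 gencode they were built against. A minimal pre-import sketch of the same check, assuming only that `google.protobuf.__version__` reports the runtime version; the `MIN_RUNTIME` constant and the error message are illustrative, not part of the package:

```python
# Illustrative guard mirroring what the generated
# ValidateProtobufRuntimeVersion call enforces at import time.
from google.protobuf import __version__ as protobuf_version

MIN_RUNTIME = (6, 31)  # assumed major/minor of the pinned gencode

installed = tuple(int(part) for part in protobuf_version.split(".")[:2])
if installed < MIN_RUNTIME:
    raise ImportError(
        f"protobuf runtime {protobuf_version} is older than the 6.31.1 "
        "gencode in tf_keras.protobuf; upgrade the protobuf package."
    )

from tf_keras.protobuf import projector_config_pb2  # noqa: E402
```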
tf_keras/protobuf/saved_metadata_pb2.py CHANGED
@@ -1,11 +1,22 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
 # source: tf_keras/protobuf/saved_metadata.proto
+# Protobuf Python Version: 6.31.1
 """Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC,
+    6,
+    31,
+    1,
+    '',
+    'tf_keras/protobuf/saved_metadata.proto'
+)
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
 
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
-
-  DESCRIPTOR._options = None
-  _SAVEDMETADATA._serialized_start=110
-  _SAVEDMETADATA._serialized_end=187
-  _SAVEDOBJECT._serialized_start=190
-  _SAVEDOBJECT._serialized_end=346
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+  DESCRIPTOR._loaded_options = None
+  _globals['_SAVEDMETADATA']._serialized_start=110
+  _globals['_SAVEDMETADATA']._serialized_end=187
+  _globals['_SAVEDOBJECT']._serialized_start=190
+  _globals['_SAVEDOBJECT']._serialized_end=346
 # @@protoc_insertion_point(module_scope)
tf_keras/protobuf/versions_pb2.py CHANGED
@@ -1,11 +1,22 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
 # source: tf_keras/protobuf/versions.proto
+# Protobuf Python Version: 6.31.1
 """Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
 from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC,
+    6,
+    31,
+    1,
+    '',
+    'tf_keras/protobuf/versions.proto'
+)
 # @@protoc_insertion_point(imports)
 
 _sym_db = _symbol_database.Default()
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
 
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
-
-  DESCRIPTOR._options = None
-  _VERSIONDEF._serialized_start=70
-  _VERSIONDEF._serialized_end=145
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+  DESCRIPTOR._loaded_options = None
+  _globals['_VERSIONDEF']._serialized_start=70
+  _globals['_VERSIONDEF']._serialized_end=145
 # @@protoc_insertion_point(module_scope)
tf_keras/src/__init__.py CHANGED
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
 from tensorflow.python import tf2
 from tensorflow.python.util.tf_export import keras_export
 
-__version__ = "2.19.0"
+__version__ = "2.21.0"
 
 keras_export("keras.__version__").export_constant(__name__, "__version__")
 
tf_keras/src/backend.py CHANGED
@@ -2029,7 +2029,7 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
         if user_specified_seed is not None:
             return user_specified_seed
         elif getattr(_SEED_GENERATOR, "generator", None):
-            return _SEED_GENERATOR.generator.randint(1, 1e9)
+            return _SEED_GENERATOR.generator.randint(1, int(1e9))
         else:
             return random.randint(1, int(1e9))
 
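The one-character fix above matters because the seeded branch passed a float upper bound while the unseeded branch already used `int(1e9)`. Assuming the cached generator behaves like Python's `random.Random` (the `else` branch in the same function uses the stdlib `random` module), a quick illustrative repro:

```python
import random

rng = random.Random(42)

# Float bounds were deprecated in Python 3.10 and removed in 3.12,
# where randrange() no longer coerces them to int.
try:
    rng.randint(1, 1e9)
except (TypeError, ValueError):
    print("float bound rejected")  # reached on Python 3.12+

print(rng.randint(1, int(1e9)))  # the fixed call works everywhere
```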
tf_keras/src/datasets/boston_housing.py CHANGED
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Boston housing price regression dataset."""
 
+import os
+
 import numpy as np
 
 from tf_keras.src.utils.data_utils import get_file
@@ -23,7 +25,9 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export("keras.datasets.boston_housing.load_data")
-def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
+def load_data(
+    path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
+):
     """Loads the Boston Housing dataset.
 
     This is a dataset taken from the StatLib library which is maintained at
@@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
     [StatLib website](http://lib.stat.cmu.edu/datasets/boston).
 
     Args:
-        path: path where to cache the dataset locally
-            (relative to `~/.keras/datasets`).
+        path: path where to cache the dataset locally (relative to
+            `~/.keras/datasets`).
         test_split: fraction of the data to reserve as test set.
-        seed: Random seed for shuffling the data
-            before computing the test split.
+        seed: Random seed for shuffling the data before computing the test split.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
 
     Returns:
         Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
        origin=origin_folder + "boston_housing.npz",
         file_hash=(  # noqa: E501
             "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         x = f["x"]
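This pattern repeats across every dataset loader in the release: `load_data` gains an optional `cache_dir`, expands `~`, creates the directory if needed, and forwards it to `get_file`. A usage sketch against the new signature; the cache path is an arbitrary example, and `tf_keras.datasets.boston_housing` is assumed to be the usual public alias for this module:

```python
from tf_keras.datasets import boston_housing

# Cache the download under an explicit directory instead of the
# default ~/.keras/datasets ("/tmp/keras-data" is just an example).
(x_train, y_train), (x_test, y_test) = boston_housing.load_data(
    test_split=0.2,
    seed=113,
    cache_dir="/tmp/keras-data",
)
print(x_train.shape, x_test.shape)
```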
tf_keras/src/datasets/cifar10.py CHANGED
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export("keras.datasets.cifar10.load_data")
-def load_data():
+def load_data(cache_dir=None):
     """Loads the CIFAR10 dataset.
 
     This is a dataset of 50,000 32x32 color training images and 10,000 test
@@ -49,6 +49,10 @@ def load_data():
     | 8 | ship |
     | 9 | truck |
 
+    Args:
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
+
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
 
@@ -78,6 +82,9 @@ def load_data():
     """
     dirname = "cifar-10-batches-py"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         dirname,
         origin=origin,
@@ -85,6 +92,7 @@ def load_data():
         file_hash=(  # noqa: E501
             "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
         ),
+        cache_dir=cache_dir,
     )
 
     num_train_samples = 50000
tf_keras/src/datasets/cifar100.py CHANGED
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export("keras.datasets.cifar100.load_data")
-def load_data(label_mode="fine"):
+def load_data(label_mode="fine", cache_dir=None):
     """Loads the CIFAR100 dataset.
 
     This is a dataset of 50,000 32x32 color training images and
@@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
         label_mode: one of "fine", "coarse". If it is "fine" the category labels
             are the fine-grained labels, if it is "coarse" the output labels are the
             coarse-grained superclasses.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
 
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -75,6 +77,9 @@ def load_data(label_mode="fine"):
 
     dirname = "cifar-100-python"
     origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         dirname,
         origin=origin,
@@ -82,6 +87,7 @@ def load_data(label_mode="fine"):
         file_hash=(  # noqa: E501
             "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
         ),
+        cache_dir=cache_dir,
     )
 
     fpath = os.path.join(path, "train")
tf_keras/src/datasets/fashion_mnist.py CHANGED
@@ -26,7 +26,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export("keras.datasets.fashion_mnist.load_data")
-def load_data():
+def load_data(cache_dir=None):
     """Loads the Fashion-MNIST dataset.
 
     This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
@@ -48,6 +48,10 @@ def load_data():
     | 8 | Bag |
     | 9 | Ankle boot |
 
+    Args:
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
+
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
 
@@ -77,7 +81,6 @@ def load_data():
     The copyright for Fashion-MNIST is held by Zalando SE.
     Fashion-MNIST is licensed under the [MIT license](
     https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
-
     """
     dirname = os.path.join("datasets", "fashion-mnist")
     base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
@@ -87,10 +90,19 @@ def load_data():
         "t10k-labels-idx1-ubyte.gz",
         "t10k-images-idx3-ubyte.gz",
     ]
-
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     paths = []
     for fname in files:
-        paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
+        paths.append(
+            get_file(
+                fname,
+                origin=base + fname,
+                cache_dir=cache_dir,
+                cache_subdir=dirname,
+            )
+        )
 
     with gzip.open(paths[0], "rb") as lbpath:
         y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
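fashion_mnist is the one loader here that combines the new `cache_dir` with an explicit `cache_subdir`, so a custom cache places the four archives under `<cache_dir>/datasets/fashion-mnist/`. A sketch of the resulting `get_file` call for a single file; the origin URL comes from the diff, while the cache path is an arbitrary example:

```python
import os

from tf_keras.src.utils.data_utils import get_file

base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
fname = "train-labels-idx1-ubyte.gz"

# With cache_dir set, the file lands at
# <cache_dir>/datasets/fashion-mnist/<fname>.
path = get_file(
    fname,
    origin=base + fname,
    cache_dir=os.path.expanduser("/tmp/keras-data"),
    cache_subdir=os.path.join("datasets", "fashion-mnist"),
)
print(path)
```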
tf_keras/src/datasets/imdb.py CHANGED
@@ -15,6 +15,7 @@
 """IMDB sentiment classification dataset."""
 
 import json
+import os
 
 import numpy as np
 
@@ -36,6 +37,7 @@ def load_data(
     start_char=1,
     oov_char=2,
     index_from=3,
+    cache_dir=None,
     **kwargs,
 ):
     """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
@@ -73,6 +75,8 @@ def load_data(
             Words that were cut out because of the `num_words` or
             `skip_top` limits will be replaced with this character.
         index_from: int. Index actual words with this index and higher.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
         **kwargs: Used for backwards compatibility.
 
     Returns:
@@ -108,12 +112,16 @@ def load_data(
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "imdb.npz",
         file_hash=(  # noqa: E501
             "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         x_train, labels_train = f["x_train"], f["y_train"]
tf_keras/src/datasets/mnist.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """MNIST handwritten digits dataset."""
+import os
 
 import numpy as np
 
@@ -23,7 +24,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export("keras.datasets.mnist.load_data")
-def load_data(path="mnist.npz"):
+def load_data(path="mnist.npz", cache_dir=None):
     """Loads the MNIST dataset.
 
     This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
@@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
     [MNIST homepage](http://yann.lecun.com/exdb/mnist/).
 
     Args:
-        path: path where to cache the dataset locally
-            (relative to `~/.keras/datasets`).
+        path: path where to cache the dataset locally relative to cache_dir.
+        cache_dir: dir location where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
 
     Returns:
         Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "mnist.npz",
         file_hash=(  # noqa: E501
             "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         x_train, y_train = f["x_train"], f["y_train"]
tf_keras/src/datasets/reuters.py CHANGED
@@ -15,6 +15,7 @@
 """Reuters topic classification dataset."""
 
 import json
+import os
 
 import numpy as np
 
@@ -37,6 +38,7 @@ def load_data(
     start_char=1,
     oov_char=2,
     index_from=3,
+    cache_dir=None,
     **kwargs,
 ):
     """Loads the Reuters newswire classification dataset.
@@ -83,6 +85,8 @@ def load_data(
             Words that were cut out because of the `num_words` or
             `skip_top` limits will be replaced with this character.
         index_from: int. Index actual words with this index and higher.
+        cache_dir: directory where to cache the dataset locally. When None,
+            defaults to `~/.keras/datasets`.
         **kwargs: Used for backwards compatibility.
 
     Returns:
@@ -114,12 +118,16 @@ def load_data(
     origin_folder = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
     )
+    if cache_dir:
+        cache_dir = os.path.expanduser(cache_dir)
+        os.makedirs(cache_dir, exist_ok=True)
     path = get_file(
         path,
         origin=origin_folder + "reuters.npz",
         file_hash=(  # noqa: E501
             "d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
         ),
+        cache_dir=cache_dir,
     )
     with np.load(path, allow_pickle=True) as f:
         xs, labels = f["x"], f["y"]