tf-keras-nightly 2.19.0.dev2024121210__py3-none-any.whl → 2.21.0.dev2025123010__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tf_keras/__init__.py +1 -1
- tf_keras/protobuf/projector_config_pb2.py +23 -12
- tf_keras/protobuf/saved_metadata_pb2.py +21 -10
- tf_keras/protobuf/versions_pb2.py +19 -8
- tf_keras/src/__init__.py +1 -1
- tf_keras/src/backend.py +1 -1
- tf_keras/src/datasets/boston_housing.py +14 -5
- tf_keras/src/datasets/cifar10.py +9 -1
- tf_keras/src/datasets/cifar100.py +7 -1
- tf_keras/src/datasets/fashion_mnist.py +16 -4
- tf_keras/src/datasets/imdb.py +8 -0
- tf_keras/src/datasets/mnist.py +9 -3
- tf_keras/src/datasets/reuters.py +8 -0
- tf_keras/src/engine/base_layer.py +235 -97
- tf_keras/src/engine/base_layer_utils.py +17 -5
- tf_keras/src/engine/base_layer_v1.py +12 -3
- tf_keras/src/engine/data_adapter.py +35 -19
- tf_keras/src/engine/functional.py +36 -15
- tf_keras/src/engine/input_layer.py +9 -0
- tf_keras/src/engine/input_spec.py +11 -1
- tf_keras/src/engine/sequential.py +29 -12
- tf_keras/src/layers/activation/softmax.py +26 -11
- tf_keras/src/layers/attention/multi_head_attention.py +8 -1
- tf_keras/src/layers/core/tf_op_layer.py +4 -0
- tf_keras/src/layers/normalization/spectral_normalization.py +29 -22
- tf_keras/src/layers/rnn/cell_wrappers.py +13 -1
- tf_keras/src/metrics/confusion_metrics.py +51 -4
- tf_keras/src/models/sharpness_aware_minimization.py +17 -7
- tf_keras/src/preprocessing/sequence.py +2 -2
- tf_keras/src/saving/legacy/saved_model/save_impl.py +28 -12
- tf_keras/src/saving/legacy/saving_utils.py +14 -2
- tf_keras/src/saving/saving_api.py +18 -5
- tf_keras/src/saving/saving_lib.py +1 -1
- tf_keras/src/utils/layer_utils.py +45 -3
- tf_keras/src/utils/metrics_utils.py +4 -1
- tf_keras/src/utils/tf_utils.py +2 -2
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/METADATA +14 -3
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/RECORD +40 -62
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/WHEEL +1 -1
- tf_keras/src/layers/preprocessing/benchmarks/bucketized_column_dense_benchmark.py +0 -85
- tf_keras/src/layers/preprocessing/benchmarks/category_encoding_benchmark.py +0 -84
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_dense_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_hash_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_dense_benchmark.py +0 -110
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_file_varlen_benchmark.py +0 -103
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_dense_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_dense_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_indicator_varlen_benchmark.py +0 -96
- tf_keras/src/layers/preprocessing/benchmarks/category_vocab_list_varlen_benchmark.py +0 -87
- tf_keras/src/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py +0 -109
- tf_keras/src/layers/preprocessing/benchmarks/embedding_dense_benchmark.py +0 -86
- tf_keras/src/layers/preprocessing/benchmarks/embedding_varlen_benchmark.py +0 -89
- tf_keras/src/layers/preprocessing/benchmarks/hashed_crossing_benchmark.py +0 -90
- tf_keras/src/layers/preprocessing/benchmarks/hashing_benchmark.py +0 -105
- tf_keras/src/layers/preprocessing/benchmarks/image_preproc_benchmark.py +0 -159
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_adapt_benchmark.py +0 -135
- tf_keras/src/layers/preprocessing/benchmarks/index_lookup_forward_benchmark.py +0 -144
- tf_keras/src/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py +0 -124
- tf_keras/src/layers/preprocessing/benchmarks/weighted_embedding_varlen_benchmark.py +0 -99
- tf_keras/src/saving/legacy/saved_model/create_test_saved_model.py +0 -37
- tf_keras/src/tests/keras_doctest.py +0 -159
- {tf_keras_nightly-2.19.0.dev2024121210.dist-info → tf_keras_nightly-2.21.0.dev2025123010.dist-info}/top_level.txt +0 -0
tf_keras/__init__.py
CHANGED
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/projector_config.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/projector_config.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,15 +26,15 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(tf_keras/protobuf/projector_config.proto\x12 third_party.py.tf_keras.protobuf\">\n\x0eSpriteMetadata\x12\x12\n\nimage_path\x18\x01 \x01(\t\x12\x18\n\x10single_image_dim\x18\x02 \x03(\r\"\xc0\x01\n\rEmbeddingInfo\x12\x13\n\x0btensor_name\x18\x01 \x01(\t\x12\x15\n\rmetadata_path\x18\x02 \x01(\t\x12\x16\n\x0e\x62ookmarks_path\x18\x03 \x01(\t\x12\x14\n\x0ctensor_shape\x18\x04 \x03(\r\x12@\n\x06sprite\x18\x05 \x01(\x0b\x32\x30.third_party.py.tf_keras.protobuf.SpriteMetadata\x12\x13\n\x0btensor_path\x18\x06 \x01(\t\"\x93\x01\n\x0fProjectorConfig\x12\x1d\n\x15model_checkpoint_path\x18\x01 \x01(\t\x12\x43\n\nembeddings\x18\x02 \x03(\x0b\x32/.third_party.py.tf_keras.protobuf.EmbeddingInfo\x12\x1c\n\x14model_checkpoint_dir\x18\x03 \x01(\tb\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_SPRITEMETADATA._serialized_start=78
|
|
24
|
-
_SPRITEMETADATA._serialized_end=140
|
|
25
|
-
_EMBEDDINGINFO._serialized_start=143
|
|
26
|
-
_EMBEDDINGINFO._serialized_end=335
|
|
27
|
-
_PROJECTORCONFIG._serialized_start=338
|
|
28
|
-
_PROJECTORCONFIG._serialized_end=485
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.projector_config_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_SPRITEMETADATA']._serialized_start=78
|
|
35
|
+
_globals['_SPRITEMETADATA']._serialized_end=140
|
|
36
|
+
_globals['_EMBEDDINGINFO']._serialized_start=143
|
|
37
|
+
_globals['_EMBEDDINGINFO']._serialized_end=335
|
|
38
|
+
_globals['_PROJECTORCONFIG']._serialized_start=338
|
|
39
|
+
_globals['_PROJECTORCONFIG']._serialized_end=485
|
|
29
40
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/saved_metadata.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/saved_metadata.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -16,13 +27,13 @@ from tf_keras.protobuf import versions_pb2 as tf__keras_dot_protobuf_dot_version
|
|
|
16
27
|
|
|
17
28
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&tf_keras/protobuf/saved_metadata.proto\x12 third_party.py.tf_keras.protobuf\x1a tf_keras/protobuf/versions.proto\"M\n\rSavedMetadata\x12<\n\x05nodes\x18\x01 \x03(\x0b\x32-.third_party.py.tf_keras.protobuf.SavedObject\"\x9c\x01\n\x0bSavedObject\x12\x0f\n\x07node_id\x18\x02 \x01(\x05\x12\x11\n\tnode_path\x18\x03 \x01(\t\x12\x12\n\nidentifier\x18\x04 \x01(\t\x12\x10\n\x08metadata\x18\x05 \x01(\t\x12=\n\x07version\x18\x06 \x01(\x0b\x32,.third_party.py.tf_keras.protobuf.VersionDefJ\x04\x08\x01\x10\x02\x62\x06proto3')
|
|
18
29
|
|
|
19
|
-
|
|
20
|
-
_builder.
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
DESCRIPTOR.
|
|
24
|
-
_SAVEDMETADATA._serialized_start=110
|
|
25
|
-
_SAVEDMETADATA._serialized_end=187
|
|
26
|
-
_SAVEDOBJECT._serialized_start=190
|
|
27
|
-
_SAVEDOBJECT._serialized_end=346
|
|
30
|
+
_globals = globals()
|
|
31
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
32
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.saved_metadata_pb2', _globals)
|
|
33
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
34
|
+
DESCRIPTOR._loaded_options = None
|
|
35
|
+
_globals['_SAVEDMETADATA']._serialized_start=110
|
|
36
|
+
_globals['_SAVEDMETADATA']._serialized_end=187
|
|
37
|
+
_globals['_SAVEDOBJECT']._serialized_start=190
|
|
38
|
+
_globals['_SAVEDOBJECT']._serialized_end=346
|
|
28
39
|
# @@protoc_insertion_point(module_scope)
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
3
4
|
# source: tf_keras/protobuf/versions.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
4
6
|
"""Generated protocol buffer code."""
|
|
5
|
-
from google.protobuf.internal import builder as _builder
|
|
6
7
|
from google.protobuf import descriptor as _descriptor
|
|
7
8
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
8
10
|
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'tf_keras/protobuf/versions.proto'
|
|
19
|
+
)
|
|
9
20
|
# @@protoc_insertion_point(imports)
|
|
10
21
|
|
|
11
22
|
_sym_db = _symbol_database.Default()
|
|
@@ -15,11 +26,11 @@ _sym_db = _symbol_database.Default()
|
|
|
15
26
|
|
|
16
27
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tf_keras/protobuf/versions.proto\x12 third_party.py.tf_keras.protobuf\"K\n\nVersionDef\x12\x10\n\x08producer\x18\x01 \x01(\x05\x12\x14\n\x0cmin_consumer\x18\x02 \x01(\x05\x12\x15\n\rbad_consumers\x18\x03 \x03(\x05\x62\x06proto3')
|
|
17
28
|
|
|
18
|
-
|
|
19
|
-
_builder.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
DESCRIPTOR.
|
|
23
|
-
_VERSIONDEF._serialized_start=70
|
|
24
|
-
_VERSIONDEF._serialized_end=145
|
|
29
|
+
_globals = globals()
|
|
30
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
31
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tf_keras.protobuf.versions_pb2', _globals)
|
|
32
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
33
|
+
DESCRIPTOR._loaded_options = None
|
|
34
|
+
_globals['_VERSIONDEF']._serialized_start=70
|
|
35
|
+
_globals['_VERSIONDEF']._serialized_end=145
|
|
25
36
|
# @@protoc_insertion_point(module_scope)
|
tf_keras/src/__init__.py
CHANGED
|
@@ -35,7 +35,7 @@ from tf_keras.src.testing_infra import test_utils
|
|
|
35
35
|
from tensorflow.python import tf2
|
|
36
36
|
from tensorflow.python.util.tf_export import keras_export
|
|
37
37
|
|
|
38
|
-
__version__ = "2.
|
|
38
|
+
__version__ = "2.21.0"
|
|
39
39
|
|
|
40
40
|
keras_export("keras.__version__").export_constant(__name__, "__version__")
|
|
41
41
|
|
tf_keras/src/backend.py
CHANGED
|
@@ -2029,7 +2029,7 @@ class RandomGenerator(tf.__internal__.tracking.AutoTrackable):
|
|
|
2029
2029
|
if user_specified_seed is not None:
|
|
2030
2030
|
return user_specified_seed
|
|
2031
2031
|
elif getattr(_SEED_GENERATOR, "generator", None):
|
|
2032
|
-
return _SEED_GENERATOR.generator.randint(1, 1e9)
|
|
2032
|
+
return _SEED_GENERATOR.generator.randint(1, int(1e9))
|
|
2033
2033
|
else:
|
|
2034
2034
|
return random.randint(1, int(1e9))
|
|
2035
2035
|
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Boston housing price regression dataset."""
|
|
16
16
|
|
|
17
|
+
import os
|
|
18
|
+
|
|
17
19
|
import numpy as np
|
|
18
20
|
|
|
19
21
|
from tf_keras.src.utils.data_utils import get_file
|
|
@@ -23,7 +25,9 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
@keras_export("keras.datasets.boston_housing.load_data")
|
|
26
|
-
def load_data(
|
|
28
|
+
def load_data(
|
|
29
|
+
path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
|
|
30
|
+
):
|
|
27
31
|
"""Loads the Boston Housing dataset.
|
|
28
32
|
|
|
29
33
|
This is a dataset taken from the StatLib library which is maintained at
|
|
@@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
|
|
|
43
47
|
[StatLib website](http://lib.stat.cmu.edu/datasets/boston).
|
|
44
48
|
|
|
45
49
|
Args:
|
|
46
|
-
path: path where to cache the dataset locally
|
|
47
|
-
|
|
50
|
+
path: path where to cache the dataset locally (relative to
|
|
51
|
+
`~/.keras/datasets`).
|
|
48
52
|
test_split: fraction of the data to reserve as test set.
|
|
49
|
-
seed: Random seed for shuffling the data
|
|
50
|
-
|
|
53
|
+
seed: Random seed for shuffling the data before computing the test split.
|
|
54
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
55
|
+
defaults to `~/.keras/datasets`.
|
|
51
56
|
|
|
52
57
|
Returns:
|
|
53
58
|
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
|
|
@@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
|
|
|
64
69
|
origin_folder = (
|
|
65
70
|
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
|
|
66
71
|
)
|
|
72
|
+
if cache_dir:
|
|
73
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
74
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
67
75
|
path = get_file(
|
|
68
76
|
path,
|
|
69
77
|
origin=origin_folder + "boston_housing.npz",
|
|
70
78
|
file_hash=( # noqa: E501
|
|
71
79
|
"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
|
|
72
80
|
),
|
|
81
|
+
cache_dir=cache_dir,
|
|
73
82
|
)
|
|
74
83
|
with np.load(path, allow_pickle=True) as f:
|
|
75
84
|
x = f["x"]
|
tf_keras/src/datasets/cifar10.py
CHANGED
|
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
@keras_export("keras.datasets.cifar10.load_data")
|
|
30
|
-
def load_data():
|
|
30
|
+
def load_data(cache_dir=None):
|
|
31
31
|
"""Loads the CIFAR10 dataset.
|
|
32
32
|
|
|
33
33
|
This is a dataset of 50,000 32x32 color training images and 10,000 test
|
|
@@ -49,6 +49,10 @@ def load_data():
|
|
|
49
49
|
| 8 | ship |
|
|
50
50
|
| 9 | truck |
|
|
51
51
|
|
|
52
|
+
Args:
|
|
53
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
54
|
+
defaults to `~/.keras/datasets`.
|
|
55
|
+
|
|
52
56
|
Returns:
|
|
53
57
|
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
|
|
54
58
|
|
|
@@ -78,6 +82,9 @@ def load_data():
|
|
|
78
82
|
"""
|
|
79
83
|
dirname = "cifar-10-batches-py"
|
|
80
84
|
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
|
85
|
+
if cache_dir:
|
|
86
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
87
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
81
88
|
path = get_file(
|
|
82
89
|
dirname,
|
|
83
90
|
origin=origin,
|
|
@@ -85,6 +92,7 @@ def load_data():
|
|
|
85
92
|
file_hash=( # noqa: E501
|
|
86
93
|
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
|
|
87
94
|
),
|
|
95
|
+
cache_dir=cache_dir,
|
|
88
96
|
)
|
|
89
97
|
|
|
90
98
|
num_train_samples = 50000
|
|
@@ -27,7 +27,7 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
@keras_export("keras.datasets.cifar100.load_data")
|
|
30
|
-
def load_data(label_mode="fine"):
|
|
30
|
+
def load_data(label_mode="fine", cache_dir=None):
|
|
31
31
|
"""Loads the CIFAR100 dataset.
|
|
32
32
|
|
|
33
33
|
This is a dataset of 50,000 32x32 color training images and
|
|
@@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
|
|
|
39
39
|
label_mode: one of "fine", "coarse". If it is "fine" the category labels
|
|
40
40
|
are the fine-grained labels, if it is "coarse" the output labels are the
|
|
41
41
|
coarse-grained superclasses.
|
|
42
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
43
|
+
defaults to `~/.keras/datasets`.
|
|
42
44
|
|
|
43
45
|
Returns:
|
|
44
46
|
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
|
|
@@ -75,6 +77,9 @@ def load_data(label_mode="fine"):
|
|
|
75
77
|
|
|
76
78
|
dirname = "cifar-100-python"
|
|
77
79
|
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
|
80
|
+
if cache_dir:
|
|
81
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
82
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
78
83
|
path = get_file(
|
|
79
84
|
dirname,
|
|
80
85
|
origin=origin,
|
|
@@ -82,6 +87,7 @@ def load_data(label_mode="fine"):
|
|
|
82
87
|
file_hash=( # noqa: E501
|
|
83
88
|
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
|
|
84
89
|
),
|
|
90
|
+
cache_dir=cache_dir,
|
|
85
91
|
)
|
|
86
92
|
|
|
87
93
|
fpath = os.path.join(path, "train")
|
|
@@ -26,7 +26,7 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
@keras_export("keras.datasets.fashion_mnist.load_data")
|
|
29
|
-
def load_data():
|
|
29
|
+
def load_data(cache_dir=None):
|
|
30
30
|
"""Loads the Fashion-MNIST dataset.
|
|
31
31
|
|
|
32
32
|
This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
|
|
@@ -48,6 +48,10 @@ def load_data():
|
|
|
48
48
|
| 8 | Bag |
|
|
49
49
|
| 9 | Ankle boot |
|
|
50
50
|
|
|
51
|
+
Args:
|
|
52
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
53
|
+
defaults to `~/.keras/datasets`.
|
|
54
|
+
|
|
51
55
|
Returns:
|
|
52
56
|
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
|
|
53
57
|
|
|
@@ -77,7 +81,6 @@ def load_data():
|
|
|
77
81
|
The copyright for Fashion-MNIST is held by Zalando SE.
|
|
78
82
|
Fashion-MNIST is licensed under the [MIT license](
|
|
79
83
|
https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
|
|
80
|
-
|
|
81
84
|
"""
|
|
82
85
|
dirname = os.path.join("datasets", "fashion-mnist")
|
|
83
86
|
base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
|
|
@@ -87,10 +90,19 @@ def load_data():
|
|
|
87
90
|
"t10k-labels-idx1-ubyte.gz",
|
|
88
91
|
"t10k-images-idx3-ubyte.gz",
|
|
89
92
|
]
|
|
90
|
-
|
|
93
|
+
if cache_dir:
|
|
94
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
95
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
91
96
|
paths = []
|
|
92
97
|
for fname in files:
|
|
93
|
-
paths.append(
|
|
98
|
+
paths.append(
|
|
99
|
+
get_file(
|
|
100
|
+
fname,
|
|
101
|
+
origin=base + fname,
|
|
102
|
+
cache_dir=cache_dir,
|
|
103
|
+
cache_subdir=dirname,
|
|
104
|
+
)
|
|
105
|
+
)
|
|
94
106
|
|
|
95
107
|
with gzip.open(paths[0], "rb") as lbpath:
|
|
96
108
|
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
|
tf_keras/src/datasets/imdb.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""IMDB sentiment classification dataset."""
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
+
import os
|
|
18
19
|
|
|
19
20
|
import numpy as np
|
|
20
21
|
|
|
@@ -36,6 +37,7 @@ def load_data(
|
|
|
36
37
|
start_char=1,
|
|
37
38
|
oov_char=2,
|
|
38
39
|
index_from=3,
|
|
40
|
+
cache_dir=None,
|
|
39
41
|
**kwargs,
|
|
40
42
|
):
|
|
41
43
|
"""Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
|
|
@@ -73,6 +75,8 @@ def load_data(
|
|
|
73
75
|
Words that were cut out because of the `num_words` or
|
|
74
76
|
`skip_top` limits will be replaced with this character.
|
|
75
77
|
index_from: int. Index actual words with this index and higher.
|
|
78
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
79
|
+
defaults to `~/.keras/datasets`.
|
|
76
80
|
**kwargs: Used for backwards compatibility.
|
|
77
81
|
|
|
78
82
|
Returns:
|
|
@@ -108,12 +112,16 @@ def load_data(
|
|
|
108
112
|
origin_folder = (
|
|
109
113
|
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
|
|
110
114
|
)
|
|
115
|
+
if cache_dir:
|
|
116
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
117
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
111
118
|
path = get_file(
|
|
112
119
|
path,
|
|
113
120
|
origin=origin_folder + "imdb.npz",
|
|
114
121
|
file_hash=( # noqa: E501
|
|
115
122
|
"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
|
|
116
123
|
),
|
|
124
|
+
cache_dir=cache_dir,
|
|
117
125
|
)
|
|
118
126
|
with np.load(path, allow_pickle=True) as f:
|
|
119
127
|
x_train, labels_train = f["x_train"], f["y_train"]
|
tf_keras/src/datasets/mnist.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""MNIST handwritten digits dataset."""
|
|
16
|
+
import os
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
|
|
@@ -23,7 +24,7 @@ from tensorflow.python.util.tf_export import keras_export
|
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
@keras_export("keras.datasets.mnist.load_data")
|
|
26
|
-
def load_data(path="mnist.npz"):
|
|
27
|
+
def load_data(path="mnist.npz", cache_dir=None):
|
|
27
28
|
"""Loads the MNIST dataset.
|
|
28
29
|
|
|
29
30
|
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
|
|
@@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
|
|
|
32
33
|
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
|
|
33
34
|
|
|
34
35
|
Args:
|
|
35
|
-
path: path where to cache the dataset locally
|
|
36
|
-
|
|
36
|
+
path: path where to cache the dataset locally relative to cache_dir.
|
|
37
|
+
cache_dir: dir location where to cache the dataset locally. When None,
|
|
38
|
+
defaults to `~/.keras/datasets`.
|
|
37
39
|
|
|
38
40
|
Returns:
|
|
39
41
|
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
|
|
@@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
|
|
|
72
74
|
origin_folder = (
|
|
73
75
|
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
|
|
74
76
|
)
|
|
77
|
+
if cache_dir:
|
|
78
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
79
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
75
80
|
path = get_file(
|
|
76
81
|
path,
|
|
77
82
|
origin=origin_folder + "mnist.npz",
|
|
78
83
|
file_hash=( # noqa: E501
|
|
79
84
|
"731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
|
|
80
85
|
),
|
|
86
|
+
cache_dir=cache_dir,
|
|
81
87
|
)
|
|
82
88
|
with np.load(path, allow_pickle=True) as f:
|
|
83
89
|
x_train, y_train = f["x_train"], f["y_train"]
|
tf_keras/src/datasets/reuters.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
"""Reuters topic classification dataset."""
|
|
16
16
|
|
|
17
17
|
import json
|
|
18
|
+
import os
|
|
18
19
|
|
|
19
20
|
import numpy as np
|
|
20
21
|
|
|
@@ -37,6 +38,7 @@ def load_data(
|
|
|
37
38
|
start_char=1,
|
|
38
39
|
oov_char=2,
|
|
39
40
|
index_from=3,
|
|
41
|
+
cache_dir=None,
|
|
40
42
|
**kwargs,
|
|
41
43
|
):
|
|
42
44
|
"""Loads the Reuters newswire classification dataset.
|
|
@@ -83,6 +85,8 @@ def load_data(
|
|
|
83
85
|
Words that were cut out because of the `num_words` or
|
|
84
86
|
`skip_top` limits will be replaced with this character.
|
|
85
87
|
index_from: int. Index actual words with this index and higher.
|
|
88
|
+
cache_dir: directory where to cache the dataset locally. When None,
|
|
89
|
+
defaults to `~/.keras/datasets`.
|
|
86
90
|
**kwargs: Used for backwards compatibility.
|
|
87
91
|
|
|
88
92
|
Returns:
|
|
@@ -114,12 +118,16 @@ def load_data(
|
|
|
114
118
|
origin_folder = (
|
|
115
119
|
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
|
|
116
120
|
)
|
|
121
|
+
if cache_dir:
|
|
122
|
+
cache_dir = os.path.expanduser(cache_dir)
|
|
123
|
+
os.makedirs(cache_dir, exist_ok=True)
|
|
117
124
|
path = get_file(
|
|
118
125
|
path,
|
|
119
126
|
origin=origin_folder + "reuters.npz",
|
|
120
127
|
file_hash=( # noqa: E501
|
|
121
128
|
"d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916"
|
|
122
129
|
),
|
|
130
|
+
cache_dir=cache_dir,
|
|
123
131
|
)
|
|
124
132
|
with np.load(path, allow_pickle=True) as f:
|
|
125
133
|
xs, labels = f["x"], f["y"]
|