keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/tree/tree_api.py
CHANGED
@@ -1,10 +1,15 @@
 import warnings
 
 from keras.src.api_export import keras_export
+from keras.src.backend.config import backend
 from keras.src.utils.module_utils import dmtree
 from keras.src.utils.module_utils import optree
 
-if optree.available:
+if backend() == "torch":
+    # torchtree_impl is especially used for Torch backend, as it works better
+    # with torch.compile.
+    from keras.src.tree import torchtree_impl as tree_impl
+elif optree.available:
     from keras.src.tree import optree_impl as tree_impl
 elif dmtree.available:
     from keras.src.tree import dmtree_impl as tree_impl
keras/src/utils/backend_utils.py
CHANGED
@@ -3,6 +3,7 @@ import importlib
 import inspect
 import os
 import sys
+import warnings
 
 from keras.src import backend as backend_module
 from keras.src.api_export import keras_export
@@ -124,9 +125,22 @@ def set_backend(backend):
 
     Example:
 
-    ```python
-    keras.config.set_backend("jax")
-    ```
+    >>> import os
+    >>> os.environ["KERAS_BACKEND"] = "tensorflow"
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'tensorflow.python.framework.ops.EagerTensor'>
+    >>>
+    >>> keras.config.set_backend("jax")
+    UserWarning: Using `keras.config.set_backend` is dangerous...
+    >>> del keras, ops
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'jaxlib.xla_extension.ArrayImpl'>
 
     ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
     carefully. Changing the backend will **NOT** convert
@@ -138,7 +152,7 @@ def set_backend(backend):
 
     This includes any function or class instance that uses any Keras
     functionality. All such code needs to be re-executed after calling
-    `set_backend()`.
+    `set_backend()` and re-importing all imported `keras` modules.
     """
     os.environ["KERAS_BACKEND"] = backend
     # Clear module cache.
@@ -159,3 +173,16 @@ def set_backend(backend):
         module_name = module_name[module_name.find("'") + 1 :]
         module_name = module_name[: module_name.find("'")]
         globals()[key] = importlib.import_module(module_name)
+
+    warnings.warn(
+        "Using `keras.config.set_backend` is dangerous and should be done "
+        "carefully. Already-instantiated objects will not be converted. Thus, "
+        "any layers / tensors / etc. already created will no longer be usable "
+        "without errors. It is strongly recommended not to keep around any "
+        "Keras-originated objects instances created before calling "
+        "`set_backend()`. This includes any function or class instance that "
+        "uses any Keras functionality. All such code needs to be re-executed "
+        "after calling `set_backend()` and re-importing all imported `keras` "
+        "modules.",
+        stacklevel=2,
+    )
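Because `set_backend` now emits this `UserWarning` at runtime instead of only documenting the hazard, callers can surface or suppress it with the standard `warnings` machinery. A minimal sketch, assuming the target backend (here JAX) is installed:

    import warnings

    import keras

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        keras.config.set_backend("jax")

    # The warning is raised after the module cache has been cleared.
    assert any("set_backend" in str(w.message) for w in caught)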
keras/src/utils/dataset_utils.py
CHANGED
@@ -6,17 +6,22 @@ from multiprocessing.pool import ThreadPool
 
 import numpy as np
 
+from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
 from keras.src.utils.module_utils import grain
-from keras.src.utils.module_utils import tensorflow as tf
 
 
 @keras_export("keras.utils.split_dataset")
 def split_dataset(
-    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+    dataset,
+    left_size=None,
+    right_size=None,
+    shuffle=False,
+    seed=None,
+    preferred_backend=None,
 ):
     """Splits a dataset into a left half and a right half (e.g. train / test).
 
@@ -37,27 +42,86 @@ def split_dataset(
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
+        preferred_backend: String, specifying which backend
+            (e.g.; "tensorflow", "torch") to use. If `None`, the
+            backend is inferred from the type of `dataset` - if
+            `dataset` is a `tf.data.Dataset`, "tensorflow" backend
+            is used, if `dataset` is a `torch.utils.data.Dataset`,
+            "torch" backend is used, and if `dataset` is a list/tuple/np.array
+            the current Keras backend is used. Defaults to `None`.
 
     Returns:
-        A tuple of two `tf.data.Dataset` objects:
-        the left and right splits.
-
+        A tuple of two dataset objects, the left and right splits. The exact
+        type of the returned objects depends on the `preferred_backend`.
+        For example, with a "tensorflow" backend,
+        `tf.data.Dataset` objects are returned. With a "torch" backend,
+        `torch.utils.data.Dataset` objects are returned.
     Example:
 
     >>> data = np.random.random(size=(1000, 4))
     >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
-    >>> int(left_ds.cardinality())
-    800
-    >>> int(right_ds.cardinality())
-    200
+    >>> # For a tf.data.Dataset, you can use .cardinality()
+    >>> # >>> int(left_ds.cardinality())
+    >>> # 800
+    >>> # For a torch.utils.data.Dataset, you can use len()
+    >>> # >>> len(left_ds)
+    >>> # 800
     """
+    preferred_backend = preferred_backend or _infer_preferred_backend(dataset)
+    if preferred_backend != "torch":
+        return _split_dataset_tf(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+    else:
+        return _split_dataset_torch(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+
+
+def _split_dataset_tf(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `tf.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer, it
+            signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `tf.data.Dataset` objects:
+        the left and right splits.
+    """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     dataset_type_spec = _get_type_spec(dataset)
 
     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset` object "
-            "or a list/tuple of arrays. "
+            "a `tf.data.Dataset` object, or"
+            "a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )
 
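A usage sketch of the new `preferred_backend` argument, based on the signature added above (with array input and no `preferred_backend`, the split instead follows the active Keras backend):

    import numpy as np

    import keras

    data = np.random.random(size=(1000, 4)).astype("float32")

    # Force the torch code path: the splits come back as
    # torch.utils.data.Dataset objects, so len() works on them.
    left_ds, right_ds = keras.utils.split_dataset(
        data, left_size=0.8, preferred_backend="torch"
    )
    print(len(left_ds), len(right_ds))  # 800 200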
@@ -106,6 +170,103 @@ def split_dataset(
     return left_split, right_split
 
 
+def _split_dataset_torch(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `torch.utils.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer, it
+            signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `torch.utils.data.Dataset` objects:
+        the left and right splits.
+    """
+    import torch
+    from torch.utils.data import TensorDataset
+    from torch.utils.data import random_split
+
+    dataset_type_spec = _get_type_spec(dataset)
+    if dataset_type_spec is None:
+        raise TypeError(
+            "The `dataset` argument must be a `torch.utils.data.Dataset`"
+            " object, or a list/tuple of arrays."
+            f" Received: dataset={dataset} of type {type(dataset)}"
+        )
+
+    if not isinstance(dataset, torch.utils.data.Dataset):
+        if dataset_type_spec is np.ndarray:
+            dataset = TensorDataset(torch.from_numpy(dataset))
+        elif dataset_type_spec in (list, tuple):
+            tensors = [torch.from_numpy(x) for x in dataset]
+            dataset = TensorDataset(*tensors)
+        elif is_tf_dataset(dataset):
+            dataset_as_list = _convert_dataset_to_list(
+                dataset, dataset_type_spec
+            )
+            tensors = [
+                torch.from_numpy(np.array(sample))
+                for sample in zip(*dataset_as_list)
+            ]
+            dataset = TensorDataset(*tensors)
+
+    if right_size is None and left_size is None:
+        raise ValueError(
+            "At least one of the `left_size` or `right_size` "
+            "must be specified. "
+            "Received: left_size=None and right_size=None"
+        )
+
+    # Calculate total length and rescale split sizes
+    total_length = len(dataset)
+    left_size, right_size = _rescale_dataset_split_sizes(
+        left_size, right_size, total_length
+    )
+
+    # Shuffle the dataset if required
+    if shuffle:
+        generator = torch.Generator()
+        if seed is not None:
+            generator.manual_seed(seed)
+        else:
+            generator.seed()
+    else:
+        generator = None
+
+    left_split, right_split = random_split(
+        dataset, [left_size, right_size], generator=generator
+    )
+
+    return left_split, right_split
+
+
+def _infer_preferred_backend(dataset):
+    """Infer the backend from the dataset type."""
+    if isinstance(dataset, (list, tuple, np.ndarray)):
+        return backend.backend()
+    if is_tf_dataset(dataset):
+        return "tensorflow"
+    elif is_torch_dataset(dataset):
+        return "torch"
+    else:
+        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
+
+
 def _convert_dataset_to_list(
     dataset,
     dataset_type_spec,
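For reference, the torch path above is a thin wrapper over the public `random_split` API; a self-contained sketch of the same seeded split, assuming torch is installed:

    import torch
    from torch.utils.data import TensorDataset, random_split

    dataset = TensorDataset(torch.arange(10))

    # A seeded generator makes the shuffled split reproducible, mirroring
    # the `seed` handling in _split_dataset_torch.
    generator = torch.Generator().manual_seed(42)
    left, right = random_split(dataset, [8, 2], generator=generator)
    print(len(left), len(right))  # 8 2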
@@ -208,7 +369,7 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         )
 
         return iter(zip(*dataset))
-    elif dataset_type_spec is tf.data.Dataset:
+    elif is_tf_dataset(dataset):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
@@ -242,6 +403,9 @@ def _get_next_sample(
     Yields:
         data_sample: The next sample.
     """
+    from keras.src.trainers.data_adapters.data_adapter_utils import (
+        is_tensorflow_tensor,
+    )
     from keras.src.trainers.data_adapters.data_adapter_utils import (
         is_torch_tensor,
     )
@@ -249,8 +413,10 @@ def _get_next_sample(
     try:
         dataset_iterator = iter(dataset_iterator)
         first_sample = next(dataset_iterator)
-        if isinstance(first_sample, np.ndarray) or is_torch_tensor(
-            first_sample
+        if (
+            isinstance(first_sample, np.ndarray)
+            or is_tensorflow_tensor(first_sample)
+            or is_torch_tensor(first_sample)
         ):
             first_sample_shape = np.array(first_sample).shape
         else:
@@ -291,23 +457,40 @@ def _get_next_sample(
         yield sample
 
 
-def is_torch_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ == "Dataset" and str(
-                parent.__module__
-            ).startswith("torch.utils.data"):
-                return True
-    return False
+def is_tf_dataset(dataset):
+    return _mro_matches(
+        dataset,
+        class_names=("DatasetV2", "Dataset"),
+        module_substrings=(
+            "tensorflow.python.data",  # TF classic
+            "tensorflow.data",  # newer TF paths
+        ),
+    )
 
 
 def is_grain_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ in (
-                "MapDataset",
-                "IterDataset",
-            ) and str(parent.__module__).startswith("grain._src.python"):
+    return _mro_matches(
+        dataset,
+        class_names=("MapDataset", "IterDataset"),
+        module_prefixes=("grain._src.python",),
+    )
+
+
+def is_torch_dataset(dataset):
+    return _mro_matches(dataset, ("Dataset",), ("torch.utils.data",))
+
+
+def _mro_matches(
+    dataset, class_names, module_prefixes=(), module_substrings=()
+):
+    if not hasattr(dataset, "__class__"):
+        return False
+    for parent in dataset.__class__.__mro__:
+        if parent.__name__ in class_names:
+            mod = str(parent.__module__)
+            if any(mod.startswith(pref) for pref in module_prefixes):
+                return True
+            if any(subs in mod for subs in module_substrings):
                 return True
     return False
 
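`_mro_matches` is duck typing by class name plus defining module, which lets the `is_*_dataset` checks run without importing TensorFlow, torch, or grain at module load time. A standalone sketch of the same idea using only the stdlib:

    def mro_matches(obj, class_names, module_prefixes=()):
        # Walk the MRO so subclasses match too, and key on the defining
        # module so an unrelated class that happens to share a name is
        # not a false positive.
        for parent in type(obj).__mro__:
            if parent.__name__ in class_names:
                mod = str(parent.__module__)
                if any(mod.startswith(p) for p in module_prefixes):
                    return True
        return False

    from collections import OrderedDict

    print(mro_matches(OrderedDict(), ("dict",), ("builtins",)))  # True
    print(mro_matches([], ("dict",), ("builtins",)))  # False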
@@ -441,8 +624,10 @@ def _restore_dataset_from_list(
     dataset_as_list, dataset_type_spec, original_dataset
 ):
     """Restore the dataset from the list of arrays."""
-    if dataset_type_spec in [tuple, list] or is_torch_dataset(
-        original_dataset
+    if (
+        dataset_type_spec in [tuple, list]
+        or is_tf_dataset(original_dataset)
+        or is_torch_dataset(original_dataset)
     ):
         # Save structure by taking the first element.
         element_spec = dataset_as_list[0]
@@ -483,7 +668,9 @@ def _get_type_spec(dataset):
         return list
     elif isinstance(dataset, np.ndarray):
         return np.ndarray
-    elif isinstance(dataset, tf.data.Dataset):
+    elif is_tf_dataset(dataset):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         return tf.data.Dataset
     elif is_torch_dataset(dataset):
         from torch.utils.data import Dataset as TorchDataset
@@ -543,6 +730,8 @@ def index_directory(
         order.
     """
     if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         os_module = tf.io.gfile
         path_module = tf.io.gfile
     else:
@@ -647,7 +836,12 @@
 
 
 def iter_valid_files(directory, follow_links, formats):
-    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        io_module = tf.io.gfile
+    else:
+        io_module = os
 
     if not follow_links:
         walk = io_module.walk(directory)
@@ -674,9 +868,12 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
     paths, and `labels` is a list of integer labels corresponding
     to these files.
     """
-    path_module = (
-        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
-    )
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        path_module = tf.io.gfile
+    else:
+        path_module = os.path
 
     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
@@ -746,6 +943,8 @@ def labels_to_dataset_tf(labels, label_mode, num_classes):
     Returns:
         A `tf.data.Dataset` instance.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     label_ds = tf.data.Dataset.from_tensor_slices(labels)
     if label_mode == "binary":
         label_ds = label_ds.map(
keras/src/utils/file_utils.py
CHANGED
@@ -2,6 +2,7 @@ import hashlib
 import os
 import re
 import shutil
+import sys
 import tarfile
 import tempfile
 import urllib
@@ -52,17 +53,32 @@ def is_link_in_dir(info, base):
     return is_path_in_dir(info.linkname, base_dir=tip)
 
 
-def filter_safe_paths(members):
+def filter_safe_zipinfos(members):
     base_dir = resolve_path(".")
     for finfo in members:
         valid_path = False
-        if is_path_in_dir(finfo.name, base_dir):
+        if is_path_in_dir(finfo.filename, base_dir):
             valid_path = True
             yield finfo
-        if finfo.issym() or finfo.islnk():
+        if not valid_path:
+            warnings.warn(
+                "Skipping invalid path during archive extraction: "
+                f"'{finfo.name}'.",
+                stacklevel=2,
+            )
+
+
+def filter_safe_tarinfos(members):
+    base_dir = resolve_path(".")
+    for finfo in members:
+        valid_path = False
+        if finfo.issym() or finfo.islnk():
             if is_link_in_dir(finfo, base_dir):
                 valid_path = True
                 yield finfo
+        elif is_path_in_dir(finfo.name, base_dir):
+            valid_path = True
+            yield finfo
         if not valid_path:
             warnings.warn(
                 "Skipping invalid path during archive extraction: "
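The motivation for the split filters is classic archive path traversal: an entry named like `../escape.txt` would otherwise be written above the extraction root. A minimal stdlib-only sketch of the attack shape the zip filter guards against (the `startswith` test is a crude stand-in for `is_path_in_dir`):

    import io
    import zipfile

    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("ok.txt", "fine")
        zf.writestr("../escape.txt", "path traversal attempt")

    with zipfile.ZipFile(buf) as zf:
        # Drop entries that would resolve outside the extraction root.
        safe = [i for i in zf.infolist() if not i.filename.startswith("../")]
        print([i.filename for i in safe])  # ['ok.txt']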
@@ -71,6 +87,35 @@ def filter_safe_paths(members):
             )
 
 
+def extract_open_archive(archive, path="."):
+    """Extracts an open tar or zip archive to the provided directory.
+
+    This function filters unsafe paths during extraction.
+
+    Args:
+        archive: The archive object, either a `TarFile` or a `ZipFile`.
+        path: Where to extract the archive file.
+    """
+    if isinstance(archive, zipfile.ZipFile):
+        # Zip archive.
+        archive.extractall(
+            path, members=filter_safe_zipinfos(archive.infolist())
+        )
+    else:
+        # Tar archive.
+        extractall_kwargs = {}
+        # The `filter="data"` option was added in Python 3.12. It became the
+        # default starting from Python 3.14. So we only specify it between
+        # those two versions.
+        if sys.version_info >= (3, 12) and sys.version_info < (3, 14):
+            extractall_kwargs = {"filter": "data"}
+        archive.extractall(
+            path,
+            members=filter_safe_tarinfos(archive),
+            **extractall_kwargs,
+        )
+
+
 def extract_archive(file_path, path=".", archive_format="auto"):
     """Extracts an archive if it matches a support format.
 
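The version gate mirrors the stdlib timeline the comment describes: `tarfile.extractall(filter=...)` appeared in Python 3.12 and "data" becomes the default in 3.14, so it only needs to be passed explicitly in between. A standalone sketch of the same gating, with a hypothetical archive name:

    import sys
    import tarfile

    extractall_kwargs = {}
    if (3, 12) <= sys.version_info < (3, 14):
        # The "data" filter rejects absolute names, links pointing
        # outside the tree, device files, etc.
        extractall_kwargs["filter"] = "data"

    with tarfile.open("weights.tar.gz") as archive:  # hypothetical path
        archive.extractall(".", **extractall_kwargs)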
@@ -112,14 +157,7 @@ def extract_archive(file_path, path=".", archive_format="auto"):
     if is_match_fn(file_path):
         with open_fn(file_path) as archive:
             try:
-                if zipfile.is_zipfile(file_path):
-                    # Zip archive.
-                    archive.extractall(path)
-                else:
-                    # Tar archive, perhaps unsafe. Filter paths.
-                    archive.extractall(
-                        path, members=filter_safe_paths(archive)
-                    )
+                extract_open_archive(archive, path)
             except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
                 if os.path.exists(path):
                     if os.path.isfile(path):
keras/src/utils/image_utils.py
CHANGED
@@ -175,12 +175,24 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
         **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
     """
     data_format = backend.standardize_data_format(data_format)
+
+    # Infer format from path if not explicitly provided
+    if file_format is None and isinstance(path, (str, pathlib.Path)):
+        file_format = pathlib.Path(path).suffix[1:].lower()
+
+    # Normalize jpg → jpeg for Pillow compatibility
+    if file_format and file_format.lower() == "jpg":
+        file_format = "jpeg"
+
     img = array_to_img(x, data_format=data_format, scale=scale)
-    if img.mode == "RGBA" and (file_format == "jpeg" or file_format == "jpg"):
+
+    # Handle RGBA → RGB conversion for JPEG
+    if img.mode == "RGBA" and file_format == "jpeg":
         warnings.warn(
-            "The JPG format does not support RGBA images, converting to RGB."
+            "The JPEG format does not support RGBA images, converting to RGB."
         )
         img = img.convert("RGB")
+
     img.save(path, format=file_format, **kwargs)
 
 
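End to end, the change means the encoding now follows the filename even when `file_format` is omitted, and an RGBA array written to a `.jpg` path degrades gracefully. A usage sketch:

    import numpy as np
    from keras.utils import save_img

    rgba = np.random.uniform(size=(8, 8, 4))

    # file_format is inferred from the suffix ("jpg" normalized to "jpeg"),
    # and the RGBA image is converted to RGB with a warning before saving.
    save_img("sample.jpg", rgba)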