keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/saving/file_editor.py
CHANGED
|
@@ -455,6 +455,9 @@ class KerasFileEditor:
|
|
|
455
455
|
def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
|
|
456
456
|
metadata = metadata or {}
|
|
457
457
|
|
|
458
|
+
# ------------------------------------------------------
|
|
459
|
+
# Collect metadata for this HDF5 group
|
|
460
|
+
# ------------------------------------------------------
|
|
458
461
|
object_metadata = {}
|
|
459
462
|
for k, v in data.attrs.items():
|
|
460
463
|
object_metadata[k] = v
|
|
@@ -462,26 +465,98 @@ class KerasFileEditor:
|
|
|
462
465
|
metadata[inner_path] = object_metadata
|
|
463
466
|
|
|
464
467
|
result = collections.OrderedDict()
|
|
468
|
+
|
|
469
|
+
# ------------------------------------------------------
|
|
470
|
+
# Iterate over all keys in this HDF5 group
|
|
471
|
+
# ------------------------------------------------------
|
|
465
472
|
for key in data.keys():
|
|
466
|
-
|
|
473
|
+
# IMPORTANT:
|
|
474
|
+
# Never mutate inner_path; use local variable.
|
|
475
|
+
current_inner_path = f"{inner_path}/{key}"
|
|
467
476
|
value = data[key]
|
|
477
|
+
|
|
478
|
+
# ------------------------------------------------------
|
|
479
|
+
# CASE 1 — HDF5 GROUP → RECURSE
|
|
480
|
+
# ------------------------------------------------------
|
|
468
481
|
if isinstance(value, h5py.Group):
|
|
482
|
+
# Skip empty groups
|
|
469
483
|
if len(value) == 0:
|
|
470
484
|
continue
|
|
485
|
+
|
|
486
|
+
# Skip empty "vars" groups
|
|
471
487
|
if "vars" in value.keys() and len(value["vars"]) == 0:
|
|
472
488
|
continue
|
|
473
489
|
|
|
474
|
-
|
|
490
|
+
# Recurse into "vars" subgroup when present
|
|
475
491
|
if "vars" in value.keys():
|
|
476
492
|
result[key], metadata = self._extract_weights_from_store(
|
|
477
|
-
value["vars"],
|
|
493
|
+
value["vars"],
|
|
494
|
+
metadata=metadata,
|
|
495
|
+
inner_path=current_inner_path,
|
|
478
496
|
)
|
|
479
497
|
else:
|
|
498
|
+
# Recurse normally
|
|
480
499
|
result[key], metadata = self._extract_weights_from_store(
|
|
481
|
-
value,
|
|
500
|
+
value,
|
|
501
|
+
metadata=metadata,
|
|
502
|
+
inner_path=current_inner_path,
|
|
482
503
|
)
|
|
483
|
-
|
|
484
|
-
|
|
504
|
+
|
|
505
|
+
continue # finished processing this key
|
|
506
|
+
|
|
507
|
+
# ------------------------------------------------------
|
|
508
|
+
# CASE 2 — HDF5 DATASET → SAFE LOADING
|
|
509
|
+
# ------------------------------------------------------
|
|
510
|
+
|
|
511
|
+
# Skip any objects that are not proper datasets
|
|
512
|
+
if not hasattr(value, "shape") or not hasattr(value, "dtype"):
|
|
513
|
+
continue
|
|
514
|
+
|
|
515
|
+
shape = value.shape
|
|
516
|
+
dtype = value.dtype
|
|
517
|
+
|
|
518
|
+
# ------------------------------------------------------
|
|
519
|
+
# Validate SHAPE (avoid malformed / malicious metadata)
|
|
520
|
+
# ------------------------------------------------------
|
|
521
|
+
|
|
522
|
+
# No negative dimensions
|
|
523
|
+
if any(dim < 0 for dim in shape):
|
|
524
|
+
raise ValueError(
|
|
525
|
+
"Malformed HDF5 dataset shape encountered in .keras file; "
|
|
526
|
+
"negative dimension detected."
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
# Prevent absurdly high-rank tensors
|
|
530
|
+
if len(shape) > 64:
|
|
531
|
+
raise ValueError(
|
|
532
|
+
"Malformed HDF5 dataset shape encountered in .keras file; "
|
|
533
|
+
"tensor rank exceeds safety limit."
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
# Safe product computation (Python int is unbounded)
|
|
537
|
+
num_elems = int(np.prod(shape))
|
|
538
|
+
|
|
539
|
+
# ------------------------------------------------------
|
|
540
|
+
# Validate TOTAL memory size
|
|
541
|
+
# ------------------------------------------------------
|
|
542
|
+
MAX_BYTES = 1 << 32 # 4 GiB
|
|
543
|
+
|
|
544
|
+
size_bytes = num_elems * dtype.itemsize
|
|
545
|
+
|
|
546
|
+
if size_bytes > MAX_BYTES:
|
|
547
|
+
raise ValueError(
|
|
548
|
+
f"HDF5 dataset too large to load safely "
|
|
549
|
+
f"({size_bytes} bytes; limit is {MAX_BYTES})."
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
# ------------------------------------------------------
|
|
553
|
+
# SAFE — load dataset (guaranteed ≤ 4 GiB)
|
|
554
|
+
# ------------------------------------------------------
|
|
555
|
+
result[key] = value[()]
|
|
556
|
+
|
|
557
|
+
# ------------------------------------------------------
|
|
558
|
+
# Return final tree and metadata
|
|
559
|
+
# ------------------------------------------------------
|
|
485
560
|
return result, metadata
|
|
486
561
|
|
|
487
562
|
def _generate_filepath_info(self, rich_style=False):
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Orbax checkpoint loading functionality."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from keras.src.utils.module_utils import ocp
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def is_orbax_checkpoint(filepath):
|
|
9
|
+
"""Check if the given path is an Orbax checkpoint directory."""
|
|
10
|
+
if not os.path.exists(filepath):
|
|
11
|
+
return False
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
return ocp.is_orbax_checkpoint(filepath)
|
|
15
|
+
except (ImportError, AttributeError):
|
|
16
|
+
# Fallback to check for orbax.checkpoint file if Orbax API not available
|
|
17
|
+
return os.path.isfile(os.path.join(filepath, "orbax.checkpoint"))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def find_latest_orbax_checkpoint(checkpoint_dir):
|
|
21
|
+
"""Find the latest checkpoint in an Orbax checkpoint directory."""
|
|
22
|
+
checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
|
|
23
|
+
latest = checkpointer.latest
|
|
24
|
+
if latest is None:
|
|
25
|
+
raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
|
|
26
|
+
return os.path.join(checkpoint_dir, str(latest.step))
|
keras/src/saving/saving_api.py
CHANGED
|
@@ -6,13 +6,11 @@ from absl import logging
|
|
|
6
6
|
from keras.src.api_export import keras_export
|
|
7
7
|
from keras.src.legacy.saving import legacy_h5_format
|
|
8
8
|
from keras.src.saving import saving_lib
|
|
9
|
+
from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
|
|
10
|
+
from keras.src.saving.orbax_util import is_orbax_checkpoint
|
|
9
11
|
from keras.src.utils import file_utils
|
|
10
12
|
from keras.src.utils import io_utils
|
|
11
|
-
|
|
12
|
-
try:
|
|
13
|
-
import h5py
|
|
14
|
-
except ImportError:
|
|
15
|
-
h5py = None
|
|
13
|
+
from keras.src.utils.module_utils import h5py
|
|
16
14
|
|
|
17
15
|
|
|
18
16
|
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
|
|
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
149
147
|
keras.layers.Softmax()])
|
|
150
148
|
model.save("model.keras")
|
|
151
149
|
loaded_model = keras.saving.load_model("model.keras")
|
|
152
|
-
x = np.random.random((10, 3))
|
|
153
|
-
assert np.allclose(model.predict(x), loaded_model.predict(x))
|
|
154
150
|
```
|
|
155
151
|
|
|
156
152
|
Note that the model variables may have different name values
|
|
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
|
|
|
208
204
|
else:
|
|
209
205
|
raise ValueError(
|
|
210
206
|
f"File format not supported: filepath={filepath}. "
|
|
211
|
-
"Keras 3 only supports V3 `.keras` files
|
|
207
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
212
208
|
"legacy H5 format files (`.h5` extension). "
|
|
213
209
|
"Note that the legacy SavedModel format is not "
|
|
214
210
|
"supported by `load_model()` in Keras 3. In "
|
|
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
288
284
|
objects_to_skip=objects_to_skip,
|
|
289
285
|
)
|
|
290
286
|
elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
|
|
291
|
-
if not h5py:
|
|
292
|
-
raise ImportError(
|
|
293
|
-
"Loading a H5 file requires `h5py` to be installed."
|
|
294
|
-
)
|
|
295
287
|
if objects_to_skip is not None:
|
|
296
288
|
raise ValueError(
|
|
297
289
|
"`objects_to_skip` only supports loading '.weights.h5' files."
|
|
298
290
|
f"Received: {filepath}"
|
|
299
291
|
)
|
|
292
|
+
if not h5py.available:
|
|
293
|
+
raise ImportError(
|
|
294
|
+
"Loading HDF5 files requires the h5py package. "
|
|
295
|
+
"You can install it via `pip install h5py`"
|
|
296
|
+
)
|
|
300
297
|
with h5py.File(filepath, "r") as f:
|
|
301
298
|
if "layer_names" not in f.attrs and "model_weights" in f:
|
|
302
299
|
f = f["model_weights"]
|
|
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
|
|
|
308
305
|
legacy_h5_format.load_weights_from_hdf5_group(
|
|
309
306
|
f, model, skip_mismatch
|
|
310
307
|
)
|
|
308
|
+
elif is_orbax_checkpoint(filepath):
|
|
309
|
+
# Load weights from Orbax checkpoint
|
|
310
|
+
from keras.src.utils.module_utils import ocp
|
|
311
|
+
|
|
312
|
+
filepath = str(filepath)
|
|
313
|
+
|
|
314
|
+
# Determine if this is a root directory or a step directory
|
|
315
|
+
items = os.listdir(filepath)
|
|
316
|
+
has_step_subdirs = any(
|
|
317
|
+
os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
|
|
318
|
+
for item in items
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
if has_step_subdirs:
|
|
322
|
+
# It's a root directory, find the latest checkpoint
|
|
323
|
+
checkpoint_path = find_latest_orbax_checkpoint(filepath)
|
|
324
|
+
else:
|
|
325
|
+
# It's a step directory, use it directly
|
|
326
|
+
checkpoint_path = filepath
|
|
327
|
+
|
|
328
|
+
# Load checkpoint
|
|
329
|
+
loaded_state = ocp.load_pytree(checkpoint_path)
|
|
330
|
+
|
|
331
|
+
# Set the model state directly from the loaded state
|
|
332
|
+
model.set_state_tree(loaded_state)
|
|
311
333
|
else:
|
|
312
334
|
raise ValueError(
|
|
313
335
|
f"File format not supported: filepath={filepath}. "
|
|
314
|
-
"Keras 3 only supports V3 `.keras`
|
|
315
|
-
"files,
|
|
336
|
+
"Keras 3 only supports V3 `.keras` files, "
|
|
337
|
+
"`.weights.h5` files, legacy H5 format files "
|
|
338
|
+
"(`.h5` extension), or Orbax checkpoints."
|
|
316
339
|
)
|
keras/src/saving/saving_lib.py
CHANGED
|
@@ -943,7 +943,7 @@ class DiskIOStore:
|
|
|
943
943
|
if self.archive:
|
|
944
944
|
self.tmp_dir = get_temp_dir()
|
|
945
945
|
if self.mode == "r":
|
|
946
|
-
self.archive
|
|
946
|
+
file_utils.extract_open_archive(self.archive, self.tmp_dir)
|
|
947
947
|
self.working_dir = file_utils.join(
|
|
948
948
|
self.tmp_dir, self.root_path
|
|
949
949
|
).replace("\\", "/")
|
keras/src/testing/__init__.py
CHANGED
keras/src/testing/test_case.py
CHANGED
|
@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
40
40
|
self.addCleanup(lambda: shutil.rmtree(temp_dir))
|
|
41
41
|
return temp_dir
|
|
42
42
|
|
|
43
|
-
def assertAllClose(
|
|
43
|
+
def assertAllClose(
|
|
44
|
+
self,
|
|
45
|
+
x1,
|
|
46
|
+
x2,
|
|
47
|
+
atol=1e-6,
|
|
48
|
+
rtol=1e-6,
|
|
49
|
+
tpu_atol=None,
|
|
50
|
+
tpu_rtol=None,
|
|
51
|
+
msg=None,
|
|
52
|
+
):
|
|
53
|
+
if tpu_atol is not None and uses_tpu():
|
|
54
|
+
atol = tpu_atol
|
|
55
|
+
if tpu_rtol is not None and uses_tpu():
|
|
56
|
+
rtol = tpu_rtol
|
|
44
57
|
if not isinstance(x1, np.ndarray):
|
|
45
58
|
x1 = backend.convert_to_numpy(x1)
|
|
46
59
|
if not isinstance(x2, np.ndarray):
|
|
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
57
70
|
f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
|
|
58
71
|
)
|
|
59
72
|
|
|
60
|
-
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
|
|
73
|
+
def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
|
|
74
|
+
if tpu_decimal is not None and uses_tpu():
|
|
75
|
+
decimal = tpu_decimal
|
|
61
76
|
msg = msg or ""
|
|
62
77
|
if not isinstance(x1, np.ndarray):
|
|
63
78
|
x1 = backend.convert_to_numpy(x1)
|
|
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
195
210
|
run_training_check=True,
|
|
196
211
|
run_mixed_precision_check=True,
|
|
197
212
|
assert_built_after_instantiation=False,
|
|
213
|
+
tpu_atol=None,
|
|
214
|
+
tpu_rtol=None,
|
|
198
215
|
):
|
|
199
216
|
"""Run basic checks on a layer.
|
|
200
217
|
|
|
@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
376
393
|
msg="Unexpected number of torch_params",
|
|
377
394
|
)
|
|
378
395
|
|
|
379
|
-
def run_output_asserts(
|
|
396
|
+
def run_output_asserts(
|
|
397
|
+
layer, output, eager=False, tpu_atol=None, tpu_rtol=None
|
|
398
|
+
):
|
|
380
399
|
if expected_output_shape is not None:
|
|
381
400
|
|
|
382
401
|
def verify_shape(expected_shape, x):
|
|
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
422
441
|
tree.flatten(expected_output), tree.flatten(output)
|
|
423
442
|
):
|
|
424
443
|
self.assertAllClose(
|
|
425
|
-
ref_v,
|
|
444
|
+
ref_v,
|
|
445
|
+
v,
|
|
446
|
+
msg="Unexpected output value",
|
|
447
|
+
tpu_atol=tpu_atol,
|
|
448
|
+
tpu_rtol=tpu_rtol,
|
|
426
449
|
)
|
|
427
450
|
if expected_num_losses is not None:
|
|
428
451
|
self.assertLen(layer.losses, expected_num_losses)
|
|
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
|
|
|
551
574
|
output_data = layer(**input_data, **call_kwargs)
|
|
552
575
|
else:
|
|
553
576
|
output_data = layer(input_data, **call_kwargs)
|
|
554
|
-
run_output_asserts(
|
|
577
|
+
run_output_asserts(
|
|
578
|
+
layer,
|
|
579
|
+
output_data,
|
|
580
|
+
eager=True,
|
|
581
|
+
tpu_atol=tpu_atol,
|
|
582
|
+
tpu_rtol=tpu_rtol,
|
|
583
|
+
)
|
|
555
584
|
|
|
556
585
|
if run_training_check:
|
|
557
586
|
run_training_step(layer, input_data, output_data)
|
|
@@ -621,6 +650,17 @@ def uses_gpu():
|
|
|
621
650
|
return False
|
|
622
651
|
|
|
623
652
|
|
|
653
|
+
def uses_tpu():
|
|
654
|
+
# Condition used to skip tests when using the TPU
|
|
655
|
+
try:
|
|
656
|
+
devices = distribution.list_devices()
|
|
657
|
+
if any(d.startswith("tpu") for d in devices):
|
|
658
|
+
return True
|
|
659
|
+
except AttributeError:
|
|
660
|
+
return False
|
|
661
|
+
return False
|
|
662
|
+
|
|
663
|
+
|
|
624
664
|
def uses_cpu():
|
|
625
665
|
devices = distribution.list_devices()
|
|
626
666
|
if any(d.startswith("cpu") for d in devices):
|
keras/src/utils/backend_utils.py
CHANGED
|
@@ -3,6 +3,7 @@ import importlib
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
|
+
import warnings
|
|
6
7
|
|
|
7
8
|
from keras.src import backend as backend_module
|
|
8
9
|
from keras.src.api_export import keras_export
|
|
@@ -124,9 +125,22 @@ def set_backend(backend):
|
|
|
124
125
|
|
|
125
126
|
Example:
|
|
126
127
|
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
128
|
+
>>> import os
|
|
129
|
+
>>> os.environ["KERAS_BACKEND"] = "tensorflow"
|
|
130
|
+
>>>
|
|
131
|
+
>>> import keras
|
|
132
|
+
>>> from keras import ops
|
|
133
|
+
>>> type(ops.ones(()))
|
|
134
|
+
<class 'tensorflow.python.framework.ops.EagerTensor'>
|
|
135
|
+
>>>
|
|
136
|
+
>>> keras.config.set_backend("jax")
|
|
137
|
+
UserWarning: Using `keras.config.set_backend` is dangerous...
|
|
138
|
+
>>> del keras, ops
|
|
139
|
+
>>>
|
|
140
|
+
>>> import keras
|
|
141
|
+
>>> from keras import ops
|
|
142
|
+
>>> type(ops.ones(()))
|
|
143
|
+
<class 'jaxlib.xla_extension.ArrayImpl'>
|
|
130
144
|
|
|
131
145
|
⚠️ WARNING ⚠️: Using this function is dangerous and should be done
|
|
132
146
|
carefully. Changing the backend will **NOT** convert
|
|
@@ -138,7 +152,7 @@ def set_backend(backend):
|
|
|
138
152
|
|
|
139
153
|
This includes any function or class instance that uses any Keras
|
|
140
154
|
functionality. All such code needs to be re-executed after calling
|
|
141
|
-
`set_backend()
|
|
155
|
+
`set_backend()` and re-importing all imported `keras` modules.
|
|
142
156
|
"""
|
|
143
157
|
os.environ["KERAS_BACKEND"] = backend
|
|
144
158
|
# Clear module cache.
|
|
@@ -159,3 +173,16 @@ def set_backend(backend):
|
|
|
159
173
|
module_name = module_name[module_name.find("'") + 1 :]
|
|
160
174
|
module_name = module_name[: module_name.find("'")]
|
|
161
175
|
globals()[key] = importlib.import_module(module_name)
|
|
176
|
+
|
|
177
|
+
warnings.warn(
|
|
178
|
+
"Using `keras.config.set_backend` is dangerous and should be done "
|
|
179
|
+
"carefully. Already-instantiated objects will not be converted. Thus, "
|
|
180
|
+
"any layers / tensors / etc. already created will no longer be usable "
|
|
181
|
+
"without errors. It is strongly recommended not to keep around any "
|
|
182
|
+
"Keras-originated objects instances created before calling "
|
|
183
|
+
"`set_backend()`. This includes any function or class instance that "
|
|
184
|
+
"uses any Keras functionality. All such code needs to be re-executed "
|
|
185
|
+
"after calling `set_backend()` and re-importing all imported `keras` "
|
|
186
|
+
"modules.",
|
|
187
|
+
stacklevel=2,
|
|
188
|
+
)
|