keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
--- a/keras/src/saving/file_editor.py
+++ b/keras/src/saving/file_editor.py
@@ -455,6 +455,9 @@ class KerasFileEditor:
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}
 
+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
         metadata[inner_path] = object_metadata
 
         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue
 
-            if hasattr(value, "keys"):
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata
 
     def _generate_filepath_info(self, rich_style=False):
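The dataset-size guard above is plain arithmetic over the declared shape and dtype. A minimal standalone sketch of the same check (the helper name `dataset_size_ok` is hypothetical, not part of the diff):

```python
import math

import numpy as np

MAX_BYTES = 1 << 32  # 4 GiB, the same cap as in the hunk above


def dataset_size_ok(shape, dtype):
    """Hypothetical helper mirroring the validation in the diff."""
    if any(dim < 0 for dim in shape):
        return False  # malformed: negative dimension
    if len(shape) > 64:
        return False  # malformed: absurdly high rank
    num_elems = math.prod(shape)  # Python int, cannot overflow
    return num_elems * np.dtype(dtype).itemsize <= MAX_BYTES


assert dataset_size_ok((1024, 1024, 1024), "float32")      # exactly 4 GiB
assert not dataset_size_ok((1024, 1024, 1024), "float64")  # 8 GiB: rejected
```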
--- /dev/null
+++ b/keras/src/saving/orbax_util.py
@@ -0,0 +1,26 @@
+"""Orbax checkpoint loading functionality."""
+
+import os
+
+from keras.src.utils.module_utils import ocp
+
+
+def is_orbax_checkpoint(filepath):
+    """Check if the given path is an Orbax checkpoint directory."""
+    if not os.path.exists(filepath):
+        return False
+
+    try:
+        return ocp.is_orbax_checkpoint(filepath)
+    except (ImportError, AttributeError):
+        # Fallback to check for orbax.checkpoint file if Orbax API not available
+        return os.path.isfile(os.path.join(filepath, "orbax.checkpoint"))
+
+
+def find_latest_orbax_checkpoint(checkpoint_dir):
+    """Find the latest checkpoint in an Orbax checkpoint directory."""
+    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+    latest = checkpointer.latest
+    if latest is None:
+        raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+    return os.path.join(checkpoint_dir, str(latest.step))
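A usage sketch for the two new helpers; the checkpoint directory is hypothetical, and the `orbax-checkpoint` package must be installed for the primary code path:

```python
from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
from keras.src.saving.orbax_util import is_orbax_checkpoint

ckpt_root = "/tmp/my_model_ckpts"  # hypothetical directory of step subdirs

if is_orbax_checkpoint(ckpt_root):
    # Resolves e.g. "/tmp/my_model_ckpts/42" for the newest saved step.
    latest_step_dir = find_latest_orbax_checkpoint(ckpt_root)
    print("latest checkpoint:", latest_step_dir)
```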
--- a/keras/src/saving/saving_api.py
+++ b/keras/src/saving/saving_api.py
@@ -6,13 +6,11 @@ from absl import logging
 from keras.src.api_export import keras_export
 from keras.src.legacy.saving import legacy_h5_format
 from keras.src.saving import saving_lib
+from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
+from keras.src.saving.orbax_util import is_orbax_checkpoint
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
-
-try:
-    import h5py
-except ImportError:
-    h5py = None
+from keras.src.utils.module_utils import h5py
 
 
 @keras_export(["keras.saving.save_model", "keras.models.save_model"])
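The `try/except ImportError` dance is replaced by the lazy `h5py` proxy from `keras.src.utils.module_utils` (that file's changes are not shown in this excerpt). As an assumption about how such a proxy behaves, a rough sketch:

```python
import importlib
import importlib.util


class LazyModule:
    """Sketch (assumption) of a lazy-import proxy with an `available` flag."""

    def __init__(self, name):
        self.name = name
        self.module = None

    @property
    def available(self):
        # Probe for the distribution without importing the whole package.
        return importlib.util.find_spec(self.name) is not None

    def __getattr__(self, attr):
        if self.module is None:
            self.module = importlib.import_module(self.name)
        return getattr(self.module, attr)


h5py = LazyModule("h5py")  # `h5py.available` is safe even when not installed
```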
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
         keras.layers.Softmax()])
     model.save("model.keras")
     loaded_model = keras.saving.load_model("model.keras")
-    x = np.random.random((10, 3))
-    assert np.allclose(model.predict(x), loaded_model.predict(x))
     ```
 
     Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` files and "
+            "Keras 3 only supports V3 `.keras` files, "
             "legacy H5 format files (`.h5` extension). "
             "Note that the legacy SavedModel format is not "
             "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             objects_to_skip=objects_to_skip,
         )
     elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-        if not h5py:
-            raise ImportError(
-                "Loading a H5 file requires `h5py` to be installed."
-            )
         if objects_to_skip is not None:
             raise ValueError(
                 "`objects_to_skip` only supports loading '.weights.h5' files."
                 f"Received: {filepath}"
             )
+        if not h5py.available:
+            raise ImportError(
+                "Loading HDF5 files requires the h5py package. "
+                "You can install it via `pip install h5py`"
+            )
         with h5py.File(filepath, "r") as f:
             if "layer_names" not in f.attrs and "model_weights" in f:
                 f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             legacy_h5_format.load_weights_from_hdf5_group(
                 f, model, skip_mismatch
             )
+    elif is_orbax_checkpoint(filepath):
+        # Load weights from Orbax checkpoint
+        from keras.src.utils.module_utils import ocp
+
+        filepath = str(filepath)
+
+        # Determine if this is a root directory or a step directory
+        items = os.listdir(filepath)
+        has_step_subdirs = any(
+            os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
+            for item in items
+        )
+
+        if has_step_subdirs:
+            # It's a root directory, find the latest checkpoint
+            checkpoint_path = find_latest_orbax_checkpoint(filepath)
+        else:
+            # It's a step directory, use it directly
+            checkpoint_path = filepath
+
+        # Load checkpoint
+        loaded_state = ocp.load_pytree(checkpoint_path)
+
+        # Set the model state directly from the loaded state
+        model.set_state_tree(loaded_state)
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` and `.weights.h5` "
-            "files, or legacy V1/V2 `.h5` files."
+            "Keras 3 only supports V3 `.keras` files, "
+            "`.weights.h5` files, legacy H5 format files "
+            "(`.h5` extension), or Orbax checkpoints."
         )
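With the branch above, `Model.load_weights` accepts either an Orbax root directory or a single step directory. A hedged sketch (paths hypothetical; requires `orbax-checkpoint`):

```python
import keras

model = keras.Sequential([keras.layers.Dense(4)])
model.build((None, 8))

# Root directory containing numeric step subdirs ("0", "100", ...):
model.load_weights("/tmp/ckpts")      # resolves and loads the latest step

# Or point at one step directory explicitly:
model.load_weights("/tmp/ckpts/100")
```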
--- a/keras/src/saving/saving_lib.py
+++ b/keras/src/saving/saving_lib.py
@@ -943,7 +943,7 @@ class DiskIOStore:
         if self.archive:
             self.tmp_dir = get_temp_dir()
             if self.mode == "r":
-                self.archive.extractall(path=self.tmp_dir)
+                file_utils.extract_open_archive(self.archive, self.tmp_dir)
             self.working_dir = file_utils.join(
                 self.tmp_dir, self.root_path
             ).replace("\\", "/")
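`file_utils.extract_open_archive` itself is outside this excerpt (see `file_utils.py +49 -11` in the list above). As an assumption, replacing a bare `extractall` usually means guarding against path traversal; a sketch of what such a helper might do:

```python
import os


def extract_open_archive(archive, dest):
    """Sketch (assumption): safe extraction for an open Zip/Tar archive."""
    dest = os.path.realpath(dest)
    # zipfile.ZipFile exposes namelist(); tarfile.TarFile exposes getnames().
    names = (
        archive.namelist() if hasattr(archive, "namelist")
        else archive.getnames()
    )
    for name in names:
        target = os.path.realpath(os.path.join(dest, name))
        if os.path.commonpath([dest, target]) != dest:
            raise ValueError(f"Unsafe path in archive: {name!r}")
    archive.extractall(path=dest)
```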
--- a/keras/src/testing/__init__.py
+++ b/keras/src/testing/__init__.py
@@ -3,3 +3,4 @@ from keras.src.testing.test_case import jax_uses_gpu
 from keras.src.testing.test_case import tensorflow_uses_gpu
 from keras.src.testing.test_case import torch_uses_gpu
 from keras.src.testing.test_case import uses_gpu
+from keras.src.testing.test_case import uses_tpu
--- a/keras/src/testing/test_case.py
+++ b/keras/src/testing/test_case.py
@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         self.addCleanup(lambda: shutil.rmtree(temp_dir))
         return temp_dir
 
-    def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
+    def assertAllClose(
+        self,
+        x1,
+        x2,
+        atol=1e-6,
+        rtol=1e-6,
+        tpu_atol=None,
+        tpu_rtol=None,
+        msg=None,
+    ):
+        if tpu_atol is not None and uses_tpu():
+            atol = tpu_atol
+        if tpu_rtol is not None and uses_tpu():
+            rtol = tpu_rtol
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
         if not isinstance(x2, np.ndarray):
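With `tpu_atol`/`tpu_rtol`, a single assertion can carry looser TPU tolerances without branching in the test body. An illustrative sketch (test class and values hypothetical):

```python
from keras.src import testing


class MyOpCorrectnessTest(testing.TestCase):
    def test_values_close(self):
        # On CPU/GPU the default atol/rtol of 1e-6 apply; on TPU,
        # assertAllClose substitutes the looser tpu_* values below.
        self.assertAllClose(
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 3.0],
            tpu_atol=1e-3,
            tpu_rtol=1e-3,
        )
```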
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
         )
 
-    def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
+    def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
+        if tpu_decimal is not None and uses_tpu():
+            decimal = tpu_decimal
         msg = msg or ""
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         run_training_check=True,
         run_mixed_precision_check=True,
         assert_built_after_instantiation=False,
+        tpu_atol=None,
+        tpu_rtol=None,
     ):
         """Run basic checks on a layer.
 
@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             msg="Unexpected number of torch_params",
         )
 
-        def run_output_asserts(layer, output, eager=False):
+        def run_output_asserts(
+            layer, output, eager=False, tpu_atol=None, tpu_rtol=None
+        ):
             if expected_output_shape is not None:
 
                 def verify_shape(expected_shape, x):
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                     tree.flatten(expected_output), tree.flatten(output)
                 ):
                     self.assertAllClose(
-                        ref_v, v, msg="Unexpected output value"
+                        ref_v,
+                        v,
+                        msg="Unexpected output value",
+                        tpu_atol=tpu_atol,
+                        tpu_rtol=tpu_rtol,
                     )
             if expected_num_losses is not None:
                 self.assertLen(layer.losses, expected_num_losses)
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             output_data = layer(**input_data, **call_kwargs)
         else:
             output_data = layer(input_data, **call_kwargs)
-        run_output_asserts(layer, output_data, eager=True)
+        run_output_asserts(
+            layer,
+            output_data,
+            eager=True,
+            tpu_atol=tpu_atol,
+            tpu_rtol=tpu_rtol,
+        )
 
         if run_training_check:
             run_training_step(layer, input_data, output_data)
@@ -621,6 +650,17 @@ def uses_gpu():
     return False
 
 
+def uses_tpu():
+    # Condition used to skip tests when using the TPU
+    try:
+        devices = distribution.list_devices()
+        if any(d.startswith("tpu") for d in devices):
+            return True
+    except AttributeError:
+        return False
+    return False
+
+
 def uses_cpu():
     devices = distribution.list_devices()
     if any(d.startswith("cpu") for d in devices):
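`uses_tpu()` mirrors the existing `uses_gpu()`/`uses_cpu()` helpers and is re-exported from `keras.src.testing` (see the `__init__.py` hunk above), so TPU-conditional skips can presumably follow the established pattern; a sketch:

```python
import pytest

from keras.src import testing


@pytest.mark.skipif(
    testing.uses_tpu(),
    reason="This check relies on exact numerics not guaranteed on TPU.",
)
def test_exact_bit_pattern():
    ...
```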
--- a/keras/src/utils/backend_utils.py
+++ b/keras/src/utils/backend_utils.py
@@ -3,6 +3,7 @@ import importlib
 import inspect
 import os
 import sys
+import warnings
 
 from keras.src import backend as backend_module
 from keras.src.api_export import keras_export
@@ -124,9 +125,22 @@ def set_backend(backend):
 
     Example:
 
-    ```python
-    keras.config.set_backend("jax")
-    ```
+    >>> import os
+    >>> os.environ["KERAS_BACKEND"] = "tensorflow"
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'tensorflow.python.framework.ops.EagerTensor'>
+    >>>
+    >>> keras.config.set_backend("jax")
+    UserWarning: Using `keras.config.set_backend` is dangerous...
+    >>> del keras, ops
+    >>>
+    >>> import keras
+    >>> from keras import ops
+    >>> type(ops.ones(()))
+    <class 'jaxlib.xla_extension.ArrayImpl'>
 
     ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
     carefully. Changing the backend will **NOT** convert
@@ -138,7 +152,7 @@ def set_backend(backend):
 
     This includes any function or class instance that uses any Keras
     functionality. All such code needs to be re-executed after calling
-    `set_backend()`.
+    `set_backend()` and re-importing all imported `keras` modules.
     """
     os.environ["KERAS_BACKEND"] = backend
     # Clear module cache.
@@ -159,3 +173,16 @@ def set_backend(backend):
         module_name = module_name[module_name.find("'") + 1 :]
         module_name = module_name[: module_name.find("'")]
         globals()[key] = importlib.import_module(module_name)
+
+    warnings.warn(
+        "Using `keras.config.set_backend` is dangerous and should be done "
+        "carefully. Already-instantiated objects will not be converted. Thus, "
+        "any layers / tensors / etc. already created will no longer be usable "
+        "without errors. It is strongly recommended not to keep around any "
+        "Keras-originated objects instances created before calling "
+        "`set_backend()`. This includes any function or class instance that "
+        "uses any Keras functionality. All such code needs to be re-executed "
+        "after calling `set_backend()` and re-importing all imported `keras` "
+        "modules.",
+        stacklevel=2,
+    )
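Because the warning is raised with `stacklevel=2`, it is attributed to the caller of `set_backend`. A sketch of capturing it in a test, assuming the message text stays stable:

```python
import warnings

import keras


def test_set_backend_warns():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        keras.config.set_backend("jax")
    assert any(
        w.category is UserWarning and "set_backend" in str(w.message)
        for w in caught
    )
```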