keras-nightly 3.12.0.dev2025083103 (py3-none-any.whl) → 3.14.0.dev2026011604 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/saving/file_editor.py
@@ -455,6 +455,9 @@ class KerasFileEditor:
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}
 
+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
         metadata[inner_path] = object_metadata
 
         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-            inner_path = f"{inner_path}/{key}"
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue
 
-            if hasattr(value, "keys"):
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"], metadata=metadata, inner_path=inner_path
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value, metadata=metadata, inner_path=inner_path
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-            else:
-                result[key] = value[()]
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata
 
     def _generate_filepath_info(self, rich_style=False):
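
For orientation, the new logic validates an HDF5 dataset's declared shape and total byte size before materializing it, so a crafted .keras file cannot trigger an enormous allocation. Below is a minimal standalone sketch of the same checks; validate_and_load is a hypothetical helper, not Keras API, and it uses math.prod because Python ints cannot overflow, whereas np.prod reduces in a fixed-width dtype:

    import math

    import h5py
    import numpy as np

    MAX_BYTES = 1 << 32  # 4 GiB, the same limit used in the diff above


    def validate_and_load(dataset):
        """Apply the diff's shape/size checks to a single h5py.Dataset."""
        shape = dataset.shape
        if any(dim < 0 for dim in shape):
            raise ValueError("negative dimension detected")
        if len(shape) > 64:
            raise ValueError("tensor rank exceeds safety limit")
        size_bytes = math.prod(shape) * dataset.dtype.itemsize
        if size_bytes > MAX_BYTES:
            raise ValueError(f"{size_bytes} bytes exceeds limit {MAX_BYTES}")
        return dataset[()]  # only now materialize the array


    with h5py.File("demo.h5", "w") as f:
        f.create_dataset("kernel", data=np.zeros((4, 4), dtype="float32"))

    with h5py.File("demo.h5", "r") as f:
        print(validate_and_load(f["kernel"]).shape)  # (4, 4)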
keras/src/saving/orbax_util.py (new file)
@@ -0,0 +1,26 @@
+"""Orbax checkpoint loading functionality."""
+
+import os
+
+from keras.src.utils.module_utils import ocp
+
+
+def is_orbax_checkpoint(filepath):
+    """Check if the given path is an Orbax checkpoint directory."""
+    if not os.path.exists(filepath):
+        return False
+
+    try:
+        return ocp.is_orbax_checkpoint(filepath)
+    except (ImportError, AttributeError):
+        # Fall back to checking for an orbax.checkpoint file if the Orbax API is not available
+        return os.path.isfile(os.path.join(filepath, "orbax.checkpoint"))
+
+
+def find_latest_orbax_checkpoint(checkpoint_dir):
+    """Find the latest checkpoint in an Orbax checkpoint directory."""
+    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+    latest = checkpointer.latest
+    if latest is None:
+        raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+    return os.path.join(checkpoint_dir, str(latest.step))
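
A brief usage sketch of the two new helpers. It assumes orbax-checkpoint is installed (the `ocp` symbol above is Keras's lazy wrapper around it), and the checkpoint path here is hypothetical:

    from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
    from keras.src.saving.orbax_util import is_orbax_checkpoint

    ckpt_root = "/tmp/run1/checkpoints"  # hypothetical directory

    if is_orbax_checkpoint(ckpt_root):
        # Resolves "<root>/<highest step>" via ocp.training.Checkpointer.
        step_dir = find_latest_orbax_checkpoint(ckpt_root)
        print("Would restore from:", step_dir)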
keras/src/saving/saving_api.py
@@ -6,13 +6,11 @@ from absl import logging
 from keras.src.api_export import keras_export
 from keras.src.legacy.saving import legacy_h5_format
 from keras.src.saving import saving_lib
+from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
+from keras.src.saving.orbax_util import is_orbax_checkpoint
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
-
-try:
-    import h5py
-except ImportError:
-    h5py = None
+from keras.src.utils.module_utils import h5py
 
 
 @keras_export(["keras.saving.save_model", "keras.models.save_model"])
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
         keras.layers.Softmax()])
     model.save("model.keras")
     loaded_model = keras.saving.load_model("model.keras")
-    x = np.random.random((10, 3))
-    assert np.allclose(model.predict(x), loaded_model.predict(x))
     ```
 
     Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` files and "
+            "Keras 3 only supports V3 `.keras` files, "
            "legacy H5 format files (`.h5` extension). "
             "Note that the legacy SavedModel format is not "
             "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             objects_to_skip=objects_to_skip,
         )
     elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-        if not h5py:
-            raise ImportError(
-                "Loading a H5 file requires `h5py` to be installed."
-            )
         if objects_to_skip is not None:
             raise ValueError(
                 "`objects_to_skip` only supports loading '.weights.h5' files."
                 f"Received: {filepath}"
             )
+        if not h5py.available:
+            raise ImportError(
+                "Loading HDF5 files requires the h5py package. "
+                "You can install it via `pip install h5py`"
+            )
         with h5py.File(filepath, "r") as f:
             if "layer_names" not in f.attrs and "model_weights" in f:
                 f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             legacy_h5_format.load_weights_from_hdf5_group(
                 f, model, skip_mismatch
             )
+    elif is_orbax_checkpoint(filepath):
+        # Load weights from Orbax checkpoint
+        from keras.src.utils.module_utils import ocp
+
+        filepath = str(filepath)
+
+        # Determine if this is a root directory or a step directory
+        items = os.listdir(filepath)
+        has_step_subdirs = any(
+            os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
+            for item in items
+        )
+
+        if has_step_subdirs:
+            # It's a root directory, find the latest checkpoint
+            checkpoint_path = find_latest_orbax_checkpoint(filepath)
+        else:
+            # It's a step directory, use it directly
+            checkpoint_path = filepath
+
+        # Load checkpoint
+        loaded_state = ocp.load_pytree(checkpoint_path)
+
+        # Set the model state directly from the loaded state
+        model.set_state_tree(loaded_state)
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` and `.weights.h5` "
-            "files, or legacy V1/V2 `.h5` files."
+            "Keras 3 only supports V3 `.keras` files, "
+            "`.weights.h5` files, legacy H5 format files "
+            "(`.h5` extension), or Orbax checkpoints."
         )
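
End to end, `load_weights` now accepts either layout. A hedged sketch; the paths and the numbered-step layout are assumptions drawn from the `item.isdigit()` probe above, and orbax-checkpoint must be installed:

    import keras

    model = keras.Sequential([keras.layers.Dense(4)])
    model.build((None, 8))

    # Root directory containing numbered step subdirectories ("0", "100",
    # ...): the latest step is resolved automatically.
    model.load_weights("/tmp/run1/checkpoints")

    # A specific step directory is used as-is.
    model.load_weights("/tmp/run1/checkpoints/100")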
keras/src/saving/saving_lib.py
@@ -943,7 +943,7 @@ class DiskIOStore:
         if self.archive:
             self.tmp_dir = get_temp_dir()
             if self.mode == "r":
-                self.archive.extractall(path=self.tmp_dir)
+                file_utils.extract_open_archive(self.archive, self.tmp_dir)
             self.working_dir = file_utils.join(
                 self.tmp_dir, self.root_path
             ).replace("\\", "/")
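
`extract_open_archive` is added in `file_utils.py` (see the file list). Replacing a bare `extractall` is typically done to reject zip-slip entries; the following is a minimal sketch of such a guard, an assumption about intent rather than Keras's actual implementation:

    import os
    import zipfile


    def safe_extract(archive: zipfile.ZipFile, dest: str) -> None:
        dest = os.path.realpath(dest)
        for name in archive.namelist():
            target = os.path.realpath(os.path.join(dest, name))
            # Reject entries that escape dest via ".." or absolute paths.
            if os.path.commonpath([dest, target]) != dest:
                raise ValueError(f"Unsafe archive entry: {name!r}")
        archive.extractall(path=dest)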
keras/src/testing/__init__.py
@@ -3,3 +3,4 @@ from keras.src.testing.test_case import jax_uses_gpu
 from keras.src.testing.test_case import tensorflow_uses_gpu
 from keras.src.testing.test_case import torch_uses_gpu
 from keras.src.testing.test_case import uses_gpu
+from keras.src.testing.test_case import uses_tpu
keras/src/testing/test_case.py
@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         self.addCleanup(lambda: shutil.rmtree(temp_dir))
         return temp_dir
 
-    def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
+    def assertAllClose(
+        self,
+        x1,
+        x2,
+        atol=1e-6,
+        rtol=1e-6,
+        tpu_atol=None,
+        tpu_rtol=None,
+        msg=None,
+    ):
+        if tpu_atol is not None and uses_tpu():
+            atol = tpu_atol
+        if tpu_rtol is not None and uses_tpu():
+            rtol = tpu_rtol
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
         if not isinstance(x2, np.ndarray):
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
         )
 
-    def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
+    def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
+        if tpu_decimal is not None and uses_tpu():
+            decimal = tpu_decimal
         msg = msg or ""
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         run_training_check=True,
         run_mixed_precision_check=True,
         assert_built_after_instantiation=False,
+        tpu_atol=None,
+        tpu_rtol=None,
     ):
         """Run basic checks on a layer.
 
@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                 msg="Unexpected number of torch_params",
             )
 
-        def run_output_asserts(layer, output, eager=False):
+        def run_output_asserts(
+            layer, output, eager=False, tpu_atol=None, tpu_rtol=None
+        ):
             if expected_output_shape is not None:
 
                 def verify_shape(expected_shape, x):
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                     tree.flatten(expected_output), tree.flatten(output)
                 ):
                     self.assertAllClose(
-                        ref_v, v, msg="Unexpected output value"
+                        ref_v,
+                        v,
+                        msg="Unexpected output value",
+                        tpu_atol=tpu_atol,
+                        tpu_rtol=tpu_rtol,
                     )
             if expected_num_losses is not None:
                 self.assertLen(layer.losses, expected_num_losses)
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                 output_data = layer(**input_data, **call_kwargs)
             else:
                 output_data = layer(input_data, **call_kwargs)
-            run_output_asserts(layer, output_data, eager=True)
+            run_output_asserts(
+                layer,
+                output_data,
+                eager=True,
+                tpu_atol=tpu_atol,
+                tpu_rtol=tpu_rtol,
+            )
 
         if run_training_check:
             run_training_step(layer, input_data, output_data)
@@ -621,6 +650,17 @@ def uses_gpu():
     return False
 
 
+def uses_tpu():
+    # Condition used to skip tests when using the TPU
+    try:
+        devices = distribution.list_devices()
+        if any(d.startswith("tpu") for d in devices):
+            return True
+    except AttributeError:
+        return False
+    return False
+
+
 def uses_cpu():
     devices = distribution.list_devices()
     if any(d.startswith("cpu") for d in devices):
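
In a test, the new arguments read as follows (a hypothetical test case with illustrative tolerances; the call signature comes straight from the diff). The looser values apply only when `uses_tpu()` detects a TPU device:

    from keras.src import testing


    class MyKernelTest(testing.TestCase):
        def test_outputs_close(self):
            self.assertAllClose(
                [1.0, 2.0],
                [1.0, 2.0],
                atol=1e-6,      # used on CPU/GPU
                rtol=1e-6,
                tpu_atol=1e-3,  # overrides atol on TPU only
                tpu_rtol=1e-3,
            )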
keras/src/trainers/compile_utils.py
@@ -148,6 +148,7 @@ class CompileMetrics(metrics_module.Metric):
         self.built = False
         self.name = "compile_metrics"
         self.output_names = output_names
+        self._resolved_output_names = None
 
     @property
     def metrics(self):
@@ -175,10 +176,16 @@ class CompileMetrics(metrics_module.Metric):
 
     def build(self, y_true, y_pred):
         num_outputs = 1  # default
-        if self.output_names:
+        # Resolve output names. If y_pred is a dict, prefer its keys.
+        if isinstance(y_pred, dict):
+            keys = sorted(list(y_pred.keys()))
+            if self.output_names and set(self.output_names) == set(keys):
+                # If there is a perfect match, use the user-provided order.
+                output_names = self.output_names
+            else:
+                output_names = keys
+        elif self.output_names:
             output_names = self.output_names
-        elif isinstance(y_pred, dict):
-            output_names = sorted(list(y_pred.keys()))
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -187,6 +194,7 @@ class CompileMetrics(metrics_module.Metric):
                 output_names = None
         else:
             output_names = None
+        self._resolved_output_names = output_names
         if output_names:
             num_outputs = len(output_names)
 
@@ -316,9 +324,10 @@ class CompileMetrics(metrics_module.Metric):
         return flat_metrics
 
     def _flatten_y(self, y):
-        if isinstance(y, dict) and self.output_names:
+        names = self._resolved_output_names
+        if isinstance(y, dict) and names:
             result = []
-            for name in self.output_names:
+            for name in names:
                 if name in y:
                     result.append(y[name])
             return result
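
The resolution rule in `build` can be shown standalone: with dict predictions, the sorted dict keys win unless the user-supplied names are exactly the same set, in which case the user order is kept. A small illustration mirroring the logic above:

    y_pred = {"b": [0.2], "a": [0.9]}  # dict outputs keyed by name
    user_names = ["a", "b"]            # e.g. self.output_names

    keys = sorted(y_pred.keys())       # ["a", "b"]
    if user_names and set(user_names) == set(keys):
        resolved = user_names          # perfect set match: keep user order
    else:
        resolved = keys                # otherwise trust the dict keys
    print(resolved)                    # ["a", "b"]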
@@ -690,17 +699,34 @@ class CompileLoss(losses_module.Loss):
         return self.call(y_true, y_pred, sample_weight)
 
     def call(self, y_true, y_pred, sample_weight=None):
+        def resolve_path(path, object):
+            for _path in path:
+                object = object[_path]
+            return object
+
         if not tree.is_nested(y_true) and not tree.is_nested(y_pred):
             # Fast path: single output case / no loss-tracking metric.
             if not self.built:
                 self.build(y_true, y_pred)
-            _, loss_fn, loss_weight, _ = self._flat_losses[0]
-            loss_value = ops.cast(
-                loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
-            )
-            if loss_weight is not None:
-                loss_value = ops.multiply(loss_value, loss_weight)
-            return loss_value
+            # Although we are in the fast path, we still need to iterate
+            # through the losses to prevent the torch compiler from failing.
+            loss_values = []
+            for path, loss_fn, loss_weight, _ in self._flat_losses:
+                y_t, y_p = (
+                    resolve_path(path, y_true),
+                    resolve_path(path, y_pred),
+                )
+                if sample_weight is not None and tree.is_nested(sample_weight):
+                    _sample_weight = resolve_path(path, sample_weight)
+                else:
+                    _sample_weight = sample_weight
+                value = ops.cast(
+                    loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
+                )
+                if loss_weight is not None:
+                    value = ops.multiply(value, loss_weight)
+                loss_values.append(value)
+            return loss_values[0]
 
         try:
             tree.assert_same_structure(y_pred, y_true)
@@ -779,11 +805,6 @@
         # Iterate all losses in flat form.
         loss_values = []
 
-        def resolve_path(path, object):
-            for _path in path:
-                object = object[_path]
-            return object
-
         for (path, loss_fn, loss_weight, _), metric in zip(
             self._flat_losses, metrics
         ):
keras/src/trainers/data_adapters/grain_dataset_adapter.py
@@ -5,13 +5,9 @@ import numpy as np
 from keras.src import tree
 from keras.src.trainers.data_adapters import data_adapter_utils
 from keras.src.trainers.data_adapters.data_adapter import DataAdapter
+from keras.src.utils.module_utils import grain
 from keras.src.utils.module_utils import tensorflow as tf
 
-try:
-    import grain
-except ImportError:
-    grain = None
-
 
 class GrainDatasetAdapter(DataAdapter):
     """Adapter that handles `grain.DataLoader`, `grain.MapDataset` and
keras/src/tree/torchtree_impl.py (new file)
@@ -0,0 +1,215 @@
+from collections import defaultdict
+
+from torch.utils import _pytree as torch_tree
+
+
+def register_tree_node_class(cls):
+    torch_tree.register_pytree_node(
+        cls,
+        flatten_fn=lambda x: x.torchtree_flatten(),
+        unflatten_fn=cls.torchtree_unflatten,
+        serialized_type_name=f"{cls.__name__}",
+        flatten_with_keys_fn=lambda x: x.torchtree_flatten_with_keys(),
+    )
+    return cls
+
+
+def _tree_is_leaf(tree, is_leaf=None):
+    if is_leaf is not None and is_leaf(tree):
+        return True
+    return torch_tree._get_node_type(tree) not in torch_tree.SUPPORTED_NODES
+
+
+def _dict_to_ordered_dict(structure):
+    # We need to sort dict and defaultdict to ensure a deterministic order
+    # that is consistent with other tree implementations.
+    def func(x):
+        if type(x) is dict:
+            return {k: x[k] for k in sorted(x.keys())}
+        elif type(x) is defaultdict:
+            return defaultdict(
+                x.default_factory,
+                {k: x[k] for k in sorted(x.keys())},
+            )
+        return None
+
+    def traverse_children():
+        children, treedef = torch_tree.tree_flatten(
+            structure,
+            is_leaf=lambda x: x is not structure,
+        )
+        if treedef.num_nodes == 1 and treedef.num_leaves == 1:
+            return structure
+        else:
+            return torch_tree.tree_unflatten(
+                [_dict_to_ordered_dict(c) for c in children],
+                treedef,
+            )
+
+    ret = func(structure)
+    if ret is None:
+        return traverse_children()
+    if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
+        return None
+    return ret
+
+
+def is_nested(structure):
+    return not _tree_is_leaf(structure)
+
+
+def traverse(func, structure, top_down=True):
+    def traverse_children():
+        children, treedef = torch_tree.tree_flatten(
+            structure,
+            is_leaf=lambda x: x is not structure,
+        )
+        if treedef.num_nodes == 1 and treedef.num_leaves == 1:
+            return structure
+        else:
+            return torch_tree.tree_unflatten(
+                [traverse(func, c, top_down=top_down) for c in children],
+                treedef,
+            )
+
+    structure = _dict_to_ordered_dict(structure)
+    if top_down:
+        ret = func(structure)
+        if ret is None:
+            return traverse_children()
+    else:
+        traversed_structure = traverse_children()
+        ret = func(traversed_structure)
+        if ret is None:
+            return traversed_structure
+    # Detect MAP_TO_NONE without tree_api import to avoid circular import.
+    if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
+        return None
+    return ret
+
+
+def flatten(structure):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    leaves, _ = torch_tree.tree_flatten(structure)
+    return leaves
+
+
+def flatten_with_path(structure):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    leaves_with_path, _ = torch_tree.tree_flatten_with_path(structure)
+    results = []
+    fields = []
+    for key, leaf in leaves_with_path:
+        for k in key:
+            if isinstance(k, torch_tree.GetAttrKey) and k.name not in fields:
+                fields.append(k.name)
+    fields = sorted(fields)
+    field_to_idx = {f: i for i, f in enumerate(fields)}
+    for key, leaf in leaves_with_path:
+        # Convert to a tuple of keys.
+        path = []
+        for k in key:
+            if isinstance(k, torch_tree.SequenceKey):
+                path.append(k.idx)
+            elif isinstance(k, torch_tree.MappingKey):
+                path.append(k.key)
+            elif isinstance(k, torch_tree.GetAttrKey):
+                path.append(field_to_idx[k.name])
+        results.append((tuple(path), leaf))
+    return results
+
+
+def map_structure(func, *structures, none_is_leaf=True):
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    map_func = func
+    if not none_is_leaf:
+
+        def func_skipping_none(*args):
+            # Check if the reference entry (first one) is None
+            if args[0] is None:
+                if not all(s is None for s in args):
+                    raise ValueError(
+                        "Structure mismatch: some arguments are None, others "
+                        f"are not. Received arguments: {args}."
+                    )
+                return None
+            return func(*args)
+
+        map_func = func_skipping_none
+
+    return torch_tree.tree_map(map_func, *structures)
+
+
+def map_structure_up_to(shallow_structure, func, *structures):
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    # Add check that `shallow_structure` really is the shallowest.
+    # Also only call `func` on `structures` and not `shallow_structure`.
+    def func_with_check_without_shallow_structure(shallow, *args):
+        if not _tree_is_leaf(shallow):
+            raise ValueError("Structures don't have the same nested structure.")
+        return func(*args)
+
+    return torch_tree.tree_map(
+        func_with_check_without_shallow_structure,
+        shallow_structure,
+        *structures,
+    )
+
+
+def assert_same_structure(a, b):
+    def check(a_leaf, b_leaf):
+        if not _tree_is_leaf(a_leaf) or not _tree_is_leaf(b_leaf):
+            raise ValueError("Structures don't have the same nested structure.")
+        return None
+
+    torch_tree.tree_map(check, a, b)
+
+
+def assert_same_paths(a, b):
+    a_paths = set([path for path, _ in flatten_with_path(a)])
+    b_paths = set([path for path, _ in flatten_with_path(b)])
+
+    if a_paths != b_paths:
+        msg = "`a` and `b` don't have the same paths."
+        a_diff = a_paths.difference(b_paths)
+        if a_diff:
+            msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
+        b_diff = b_paths.difference(a_paths)
+        if b_diff:
+            msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
+        raise ValueError(msg)
+
+
+def pack_sequence_as(structure, flat_sequence):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    _, treespec = torch_tree.tree_flatten(structure)
+    return torch_tree.tree_unflatten(flat_sequence, treespec)
+
+
+def lists_to_tuples(structure):
+    def list_to_tuple(instance):
+        return tuple(instance) if isinstance(instance, list) else None
+
+    return traverse(list_to_tuple, structure, top_down=False)
+
+
+def map_shape_structure(func, structure):
+    def is_shape_tuple(x):
+        return isinstance(x, (list, tuple)) and all(
+            isinstance(e, (int, type(None))) for e in x
+        )
+
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    return torch_tree.tree_map(func, structure, is_leaf=is_shape_tuple)
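
A quick check of the ordering contract, runnable when the torch backend is installed (the import path follows the file list above). torch's `_pytree` flattens dicts in insertion order, which is why the module sorts dict keys first; with a flat dict the effect is easy to see:

    from keras.src.tree import torchtree_impl as tree

    structure = {"b": 2, "c": 3, "a": 1}

    # Keys are sorted before flattening, so leaf order is deterministic.
    print(tree.flatten(structure))
    # [1, 2, 3]

    print(tree.flatten_with_path(structure))
    # [(('a',), 1), (('b',), 2), (('c',), 3)]

    doubled = tree.map_structure(lambda x: x * 2, structure)
    print(doubled["a"], doubled["c"])
    # 2 6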