keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/saving/file_editor.py
CHANGED
@@ -455,6 +455,9 @@ class KerasFileEditor:
     def _extract_weights_from_store(self, data, metadata=None, inner_path=""):
         metadata = metadata or {}
 
+        # ------------------------------------------------------
+        # Collect metadata for this HDF5 group
+        # ------------------------------------------------------
         object_metadata = {}
         for k, v in data.attrs.items():
             object_metadata[k] = v
@@ -462,26 +465,98 @@ class KerasFileEditor:
         metadata[inner_path] = object_metadata
 
         result = collections.OrderedDict()
+
+        # ------------------------------------------------------
+        # Iterate over all keys in this HDF5 group
+        # ------------------------------------------------------
         for key in data.keys():
-
+            # IMPORTANT:
+            # Never mutate inner_path; use local variable.
+            current_inner_path = f"{inner_path}/{key}"
             value = data[key]
+
+            # ------------------------------------------------------
+            # CASE 1 — HDF5 GROUP → RECURSE
+            # ------------------------------------------------------
             if isinstance(value, h5py.Group):
+                # Skip empty groups
                 if len(value) == 0:
                     continue
+
+                # Skip empty "vars" groups
                 if "vars" in value.keys() and len(value["vars"]) == 0:
                     continue
 
-
+                # Recurse into "vars" subgroup when present
                 if "vars" in value.keys():
                     result[key], metadata = self._extract_weights_from_store(
-                        value["vars"],
+                        value["vars"],
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
                 else:
+                    # Recurse normally
                     result[key], metadata = self._extract_weights_from_store(
-                        value,
+                        value,
+                        metadata=metadata,
+                        inner_path=current_inner_path,
                     )
-
-
+
+                continue  # finished processing this key
+
+            # ------------------------------------------------------
+            # CASE 2 — HDF5 DATASET → SAFE LOADING
+            # ------------------------------------------------------
+
+            # Skip any objects that are not proper datasets
+            if not hasattr(value, "shape") or not hasattr(value, "dtype"):
+                continue
+
+            shape = value.shape
+            dtype = value.dtype
+
+            # ------------------------------------------------------
+            # Validate SHAPE (avoid malformed / malicious metadata)
+            # ------------------------------------------------------
+
+            # No negative dimensions
+            if any(dim < 0 for dim in shape):
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "negative dimension detected."
+                )
+
+            # Prevent absurdly high-rank tensors
+            if len(shape) > 64:
+                raise ValueError(
+                    "Malformed HDF5 dataset shape encountered in .keras file; "
+                    "tensor rank exceeds safety limit."
+                )
+
+            # Safe product computation (Python int is unbounded)
+            num_elems = int(np.prod(shape))
+
+            # ------------------------------------------------------
+            # Validate TOTAL memory size
+            # ------------------------------------------------------
+            MAX_BYTES = 1 << 32  # 4 GiB
+
+            size_bytes = num_elems * dtype.itemsize
+
+            if size_bytes > MAX_BYTES:
+                raise ValueError(
+                    f"HDF5 dataset too large to load safely "
+                    f"({size_bytes} bytes; limit is {MAX_BYTES})."
+                )
+
+            # ------------------------------------------------------
+            # SAFE — load dataset (guaranteed ≤ 4 GiB)
+            # ------------------------------------------------------
+            result[key] = value[()]
+
+        # ------------------------------------------------------
+        # Return final tree and metadata
+        # ------------------------------------------------------
         return result, metadata
 
     def _generate_filepath_info(self, rich_style=False):
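
Note: the size guard above rejects oversized tensors before any bytes are read from disk. A minimal standalone sketch of the same arithmetic (the function name is illustrative, not part of the Keras API):

import numpy as np

MAX_BYTES = 1 << 32  # 4 GiB, the same limit as in the patch above

def check_dataset_size(shape, itemsize):
    # Total element count, as computed in the patch above.
    num_elems = int(np.prod(shape))
    size_bytes = num_elems * itemsize
    if size_bytes > MAX_BYTES:
        raise ValueError(f"Dataset too large to load safely: {size_bytes} bytes")
    return size_bytes

# A float32 tensor of shape (50000, 50000) is 10**10 bytes (~9.3 GiB),
# so check_dataset_size((50000, 50000), 4) raises ValueError.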
keras/src/saving/orbax_util.py
ADDED
@@ -0,0 +1,26 @@
+"""Orbax checkpoint loading functionality."""
+
+import os
+
+from keras.src.utils.module_utils import ocp
+
+
+def is_orbax_checkpoint(filepath):
+    """Check if the given path is an Orbax checkpoint directory."""
+    if not os.path.exists(filepath):
+        return False
+
+    try:
+        return ocp.is_orbax_checkpoint(filepath)
+    except (ImportError, AttributeError):
+        # Fallback to check for orbax.checkpoint file if Orbax API not available
+        return os.path.isfile(os.path.join(filepath, "orbax.checkpoint"))
+
+
+def find_latest_orbax_checkpoint(checkpoint_dir):
+    """Find the latest checkpoint in an Orbax checkpoint directory."""
+    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)
+    latest = checkpointer.latest
+    if latest is None:
+        raise ValueError(f"No valid checkpoints found in {checkpoint_dir}")
+    return os.path.join(checkpoint_dir, str(latest.step))
keras/src/saving/saving_api.py
CHANGED
@@ -6,13 +6,11 @@ from absl import logging
 from keras.src.api_export import keras_export
 from keras.src.legacy.saving import legacy_h5_format
 from keras.src.saving import saving_lib
+from keras.src.saving.orbax_util import find_latest_orbax_checkpoint
+from keras.src.saving.orbax_util import is_orbax_checkpoint
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
-
-try:
-    import h5py
-except ImportError:
-    h5py = None
+from keras.src.utils.module_utils import h5py
 
 
 @keras_export(["keras.saving.save_model", "keras.models.save_model"])
@@ -149,8 +147,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
         keras.layers.Softmax()])
     model.save("model.keras")
     loaded_model = keras.saving.load_model("model.keras")
-    x = np.random.random((10, 3))
-    assert np.allclose(model.predict(x), loaded_model.predict(x))
     ```
 
     Note that the model variables may have different name values
@@ -208,7 +204,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` files and "
+            "Keras 3 only supports V3 `.keras` files, "
             "legacy H5 format files (`.h5` extension). "
             "Note that the legacy SavedModel format is not "
             "supported by `load_model()` in Keras 3. In "
@@ -288,15 +284,16 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             objects_to_skip=objects_to_skip,
         )
     elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
-        if not h5py:
-            raise ImportError(
-                "Loading a H5 file requires `h5py` to be installed."
-            )
         if objects_to_skip is not None:
             raise ValueError(
                 "`objects_to_skip` only supports loading '.weights.h5' files."
                 f"Received: {filepath}"
             )
+        if not h5py.available:
+            raise ImportError(
+                "Loading HDF5 files requires the h5py package. "
+                "You can install it via `pip install h5py`"
+            )
         with h5py.File(filepath, "r") as f:
             if "layer_names" not in f.attrs and "model_weights" in f:
                 f = f["model_weights"]
@@ -308,9 +305,35 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             legacy_h5_format.load_weights_from_hdf5_group(
                 f, model, skip_mismatch
             )
+    elif is_orbax_checkpoint(filepath):
+        # Load weights from Orbax checkpoint
+        from keras.src.utils.module_utils import ocp
+
+        filepath = str(filepath)
+
+        # Determine if this is a root directory or a step directory
+        items = os.listdir(filepath)
+        has_step_subdirs = any(
+            os.path.isdir(os.path.join(filepath, item)) and item.isdigit()
+            for item in items
+        )
+
+        if has_step_subdirs:
+            # It's a root directory, find the latest checkpoint
+            checkpoint_path = find_latest_orbax_checkpoint(filepath)
+        else:
+            # It's a step directory, use it directly
+            checkpoint_path = filepath
+
+        # Load checkpoint
+        loaded_state = ocp.load_pytree(checkpoint_path)
+
+        # Set the model state directly from the loaded state
+        model.set_state_tree(loaded_state)
     else:
         raise ValueError(
             f"File format not supported: filepath={filepath}. "
-            "Keras 3 only supports V3 `.keras` and `.weights.h5` "
-            "files, or legacy H5 format files (`.h5` extension)."
+            "Keras 3 only supports V3 `.keras` files, "
+            "`.weights.h5` files, legacy H5 format files "
+            "(`.h5` extension), or Orbax checkpoints."
        )
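
Note: with the new branch in place, `load_weights` accepts either an Orbax root directory (containing numbered step subdirectories) or a single step directory. A hedged usage sketch; the paths are hypothetical:

import keras

model = keras.Sequential([keras.layers.Dense(4)])
model.build((None, 8))

model.load_weights("/tmp/ckpts")     # root directory: latest step is resolved
model.load_weights("/tmp/ckpts/42")  # explicit step directory used directly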
keras/src/saving/saving_lib.py
CHANGED
@@ -943,7 +943,7 @@ class DiskIOStore:
         if self.archive:
             self.tmp_dir = get_temp_dir()
             if self.mode == "r":
-                self.archive.extractall(self.tmp_dir)
+                file_utils.extract_open_archive(self.archive, self.tmp_dir)
             self.working_dir = file_utils.join(
                 self.tmp_dir, self.root_path
             ).replace("\\", "/")
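
Note: `file_utils.extract_open_archive` is new in this release; judging only from this call site, it unpacks an already-open archive into the temporary directory. A rough sketch of equivalent behavior for a zip-backed archive (an assumption, not the actual helper):

import zipfile

def extract_open_archive(archive, path):
    # Assumes `archive` is an open zipfile.ZipFile; the real helper may also
    # handle other archive types or perform extra validation.
    if isinstance(archive, zipfile.ZipFile):
        archive.extractall(path)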
keras/src/testing/__init__.py
CHANGED
keras/src/testing/test_case.py
CHANGED
@@ -40,7 +40,20 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         self.addCleanup(lambda: shutil.rmtree(temp_dir))
         return temp_dir
 
-    def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
+    def assertAllClose(
+        self,
+        x1,
+        x2,
+        atol=1e-6,
+        rtol=1e-6,
+        tpu_atol=None,
+        tpu_rtol=None,
+        msg=None,
+    ):
+        if tpu_atol is not None and uses_tpu():
+            atol = tpu_atol
+        if tpu_rtol is not None and uses_tpu():
+            rtol = tpu_rtol
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
         if not isinstance(x2, np.ndarray):
@@ -57,7 +70,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             f"The two values are close at all elements. \n{msg}.\nValues: {x1}"
         )
 
-    def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
+    def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):
+        if tpu_decimal is not None and uses_tpu():
+            decimal = tpu_decimal
         msg = msg or ""
         if not isinstance(x1, np.ndarray):
             x1 = backend.convert_to_numpy(x1)
@@ -195,6 +210,8 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
         run_training_check=True,
         run_mixed_precision_check=True,
         assert_built_after_instantiation=False,
+        tpu_atol=None,
+        tpu_rtol=None,
     ):
         """Run basic checks on a layer.
 
@@ -376,7 +393,9 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
             msg="Unexpected number of torch_params",
         )
 
-        def run_output_asserts(layer, output, eager=False):
+        def run_output_asserts(
+            layer, output, eager=False, tpu_atol=None, tpu_rtol=None
+        ):
             if expected_output_shape is not None:
 
                 def verify_shape(expected_shape, x):
@@ -422,7 +441,11 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                 tree.flatten(expected_output), tree.flatten(output)
             ):
                 self.assertAllClose(
-                    ref_v, v, msg="Unexpected output value"
+                    ref_v,
+                    v,
+                    msg="Unexpected output value",
+                    tpu_atol=tpu_atol,
+                    tpu_rtol=tpu_rtol,
                 )
             if expected_num_losses is not None:
                 self.assertLen(layer.losses, expected_num_losses)
@@ -551,7 +574,13 @@ class TestCase(parameterized.TestCase, unittest.TestCase):
                 output_data = layer(**input_data, **call_kwargs)
             else:
                 output_data = layer(input_data, **call_kwargs)
-            run_output_asserts(layer, output_data, eager=True)
+            run_output_asserts(
+                layer,
+                output_data,
+                eager=True,
+                tpu_atol=tpu_atol,
+                tpu_rtol=tpu_rtol,
+            )
 
             if run_training_check:
                 run_training_step(layer, input_data, output_data)
@@ -621,6 +650,17 @@ def uses_gpu():
     return False
 
 
+def uses_tpu():
+    # Condition used to skip tests when using the TPU
+    try:
+        devices = distribution.list_devices()
+        if any(d.startswith("tpu") for d in devices):
+            return True
+    except AttributeError:
+        return False
+    return False
+
+
 def uses_cpu():
     devices = distribution.list_devices()
     if any(d.startswith("cpu") for d in devices):
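
Note: the new `tpu_atol`/`tpu_rtol`/`tpu_decimal` arguments loosen tolerances only when `uses_tpu()` reports a TPU device. A hedged sketch of how a test might use them (values are illustrative):

from keras.src import testing

class MyOpTest(testing.TestCase):
    def test_output_close(self):
        self.assertAllClose(
            [1.0, 2.0],
            [1.0, 2.0],
            atol=1e-6,      # default tolerance on CPU/GPU
            tpu_atol=1e-3,  # applied instead when running on TPU
        )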
keras/src/trainers/compile_utils.py
CHANGED
@@ -148,6 +148,7 @@ class CompileMetrics(metrics_module.Metric):
         self.built = False
         self.name = "compile_metrics"
         self.output_names = output_names
+        self._resolved_output_names = None
 
     @property
     def metrics(self):
@@ -175,10 +176,16 @@ class CompileMetrics(metrics_module.Metric):
 
     def build(self, y_true, y_pred):
         num_outputs = 1  # default
-        if self.output_names:
+        # Resolve output names. If y_pred is a dict, prefer its keys.
+        if isinstance(y_pred, dict):
+            keys = sorted(list(y_pred.keys()))
+            if self.output_names and set(self.output_names) == set(keys):
+                # If there is a perfect match, use the user-provided order.
+                output_names = self.output_names
+            else:
+                output_names = keys
+        elif self.output_names:
             output_names = self.output_names
-        elif isinstance(y_pred, dict):
-            output_names = sorted(list(y_pred.keys()))
         elif isinstance(y_pred, (list, tuple)):
             num_outputs = len(y_pred)
             if all(hasattr(x, "_keras_history") for x in y_pred):
@@ -187,6 +194,7 @@ class CompileMetrics(metrics_module.Metric):
             output_names = None
         else:
             output_names = None
+        self._resolved_output_names = output_names
        if output_names:
             num_outputs = len(output_names)
 
@@ -316,9 +324,10 @@ class CompileMetrics(metrics_module.Metric):
         return flat_metrics
 
     def _flatten_y(self, y):
-        if isinstance(y, dict) and self.output_names:
+        names = self._resolved_output_names
+        if isinstance(y, dict) and names:
             result = []
-            for name in self.output_names:
+            for name in names:
                 if name in y:
                     result.append(y[name])
             return result
@@ -690,17 +699,34 @@ class CompileLoss(losses_module.Loss):
             return self.call(y_true, y_pred, sample_weight)
 
     def call(self, y_true, y_pred, sample_weight=None):
+        def resolve_path(path, object):
+            for _path in path:
+                object = object[_path]
+            return object
+
         if not tree.is_nested(y_true) and not tree.is_nested(y_pred):
             # Fast path: single output case / no loss-tracking metric.
             if not self.built:
                 self.build(y_true, y_pred)
-            _, loss_fn, loss_weight, _ = self._flat_losses[0]
-            loss_value = ops.cast(
-                loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
-            )
-            if loss_weight is not None:
-                loss_value = ops.multiply(loss_value, loss_weight)
-            return loss_value
+            # Although we are in the fast path, we still need to iterate
+            # through the losses to prevent the torch compiler from failing.
+            loss_values = []
+            for path, loss_fn, loss_weight, _ in self._flat_losses:
+                y_t, y_p = (
+                    resolve_path(path, y_true),
+                    resolve_path(path, y_pred),
+                )
+                if sample_weight is not None and tree.is_nested(sample_weight):
+                    _sample_weight = resolve_path(path, sample_weight)
+                else:
+                    _sample_weight = sample_weight
+                value = ops.cast(
+                    loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
+                )
+                if loss_weight is not None:
+                    value = ops.multiply(value, loss_weight)
+                loss_values.append(value)
+            return loss_values[0]
 
         try:
             tree.assert_same_structure(y_pred, y_true)
@@ -779,11 +805,6 @@ class CompileLoss(losses_module.Loss):
         # Iterate all losses in flat form.
         loss_values = []
 
-        def resolve_path(path, object):
-            for _path in path:
-                object = object[_path]
-            return object
-
         for (path, loss_fn, loss_weight, _), metric in zip(
             self._flat_losses, metrics
         ):
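
Note: after the `build` change, dict predictions win over stale `output_names` unless the two sets match exactly, in which case the user-provided order is kept. The resolution behaves roughly like this standalone sketch (not the Keras API):

def resolve_names(output_names, y_pred_keys):
    keys = sorted(y_pred_keys)
    if output_names and set(output_names) == set(keys):
        return output_names  # perfect match: keep user-provided order
    return keys

# resolve_names(["b", "a"], {"a", "b"}) -> ["b", "a"]
# resolve_names(["x"], {"a", "b"})      -> ["a", "b"]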
keras/src/trainers/data_adapters/grain_dataset_adapter.py
CHANGED
@@ -5,13 +5,9 @@ import numpy as np
 from keras.src import tree
 from keras.src.trainers.data_adapters import data_adapter_utils
 from keras.src.trainers.data_adapters.data_adapter import DataAdapter
+from keras.src.utils.module_utils import grain
 from keras.src.utils.module_utils import tensorflow as tf
 
-try:
-    import grain
-except ImportError:
-    grain = None
-
 
 class GrainDatasetAdapter(DataAdapter):
     """Adapter that handles `grain.DataLoader`, `grain.MapDataset` and
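
Note: this file and `saving_api.py` both replace ad-hoc `try: import ...` blocks with the lazy modules from `keras.src.utils.module_utils`, which defer the import to first attribute access and expose an `available` flag. A minimal sketch of that pattern (the real `LazyModule` in `module_utils.py` has more features, such as custom install messages):

import importlib.util

class LazyModule:
    def __init__(self, name):
        self.name = name
        self._module = None

    @property
    def available(self):
        # True when the package can be imported, without importing it yet.
        return importlib.util.find_spec(self.name) is not None

    def __getattr__(self, attr):
        if self._module is None:
            self._module = importlib.import_module(self.name)
        return getattr(self._module, attr)

grain = LazyModule("grain")  # defining this never raises ImportError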
keras/src/tree/torchtree_impl.py
ADDED
@@ -0,0 +1,215 @@
+from collections import defaultdict
+
+from torch.utils import _pytree as torch_tree
+
+
+def register_tree_node_class(cls):
+    torch_tree.register_pytree_node(
+        cls,
+        flatten_fn=lambda x: x.torchtree_flatten(),
+        unflatten_fn=cls.torchtree_unflatten,
+        serialized_type_name=f"{cls.__name__}",
+        flatten_with_keys_fn=lambda x: x.torchtree_flatten_with_keys(),
+    )
+    return cls
+
+
+def _tree_is_leaf(tree, is_leaf=None):
+    if is_leaf is not None and is_leaf(tree):
+        return True
+    return torch_tree._get_node_type(tree) not in torch_tree.SUPPORTED_NODES
+
+
+def _dict_to_ordered_dict(structure):
+    # We need to sort dict and defaultdict to ensure a deterministic order
+    # that is consistent with other tree implementations.
+    def func(x):
+        if type(x) is dict:
+            return {k: x[k] for k in sorted(x.keys())}
+        elif type(x) is defaultdict:
+            return defaultdict(
+                x.default_factory,
+                {k: x[k] for k in sorted(x.keys())},
+            )
+        return None
+
+    def traverse_children():
+        children, treedef = torch_tree.tree_flatten(
+            structure,
+            is_leaf=lambda x: x is not structure,
+        )
+        if treedef.num_nodes == 1 and treedef.num_leaves == 1:
+            return structure
+        else:
+            return torch_tree.tree_unflatten(
+                [_dict_to_ordered_dict(c) for c in children],
+                treedef,
+            )
+
+    ret = func(structure)
+    if ret is None:
+        return traverse_children()
+    if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
+        return None
+    return ret
+
+
+def is_nested(structure):
+    return not _tree_is_leaf(structure)
+
+
+def traverse(func, structure, top_down=True):
+    def traverse_children():
+        children, treedef = torch_tree.tree_flatten(
+            structure,
+            is_leaf=lambda x: x is not structure,
+        )
+        if treedef.num_nodes == 1 and treedef.num_leaves == 1:
+            return structure
+        else:
+            return torch_tree.tree_unflatten(
+                [traverse(func, c, top_down=top_down) for c in children],
+                treedef,
+            )
+
+    structure = _dict_to_ordered_dict(structure)
+    if top_down:
+        ret = func(structure)
+        if ret is None:
+            return traverse_children()
+    else:
+        traversed_structure = traverse_children()
+        ret = func(traversed_structure)
+        if ret is None:
+            return traversed_structure
+    # Detect MAP_TO_NONE without tree_api import to avoid circular import.
+    if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
+        return None
+    return ret
+
+
+def flatten(structure):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    leaves, _ = torch_tree.tree_flatten(structure)
+    return leaves
+
+
+def flatten_with_path(structure):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    leaves_with_path, _ = torch_tree.tree_flatten_with_path(structure)
+    results = []
+    fields = []
+    for key, leaf in leaves_with_path:
+        for k in key:
+            if isinstance(k, torch_tree.GetAttrKey) and k.name not in fields:
+                fields.append(k.name)
+    fields = sorted(fields)
+    field_to_idx = {f: i for i, f in enumerate(fields)}
+    for key, leaf in leaves_with_path:
+        # Convert to a tuple of keys.
+        path = []
+        for k in key:
+            if isinstance(k, torch_tree.SequenceKey):
+                path.append(k.idx)
+            elif isinstance(k, torch_tree.MappingKey):
+                path.append(k.key)
+            elif isinstance(k, torch_tree.GetAttrKey):
+                path.append(field_to_idx[k.name])
+        results.append((tuple(path), leaf))
+    return results
+
+
+def map_structure(func, *structures, none_is_leaf=True):
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    map_func = func
+    if not none_is_leaf:
+
+        def func_skipping_none(*args):
+            # Check if the reference entry (first one) is None
+            if args[0] is None:
+                if not all(s is None for s in args):
+                    raise ValueError(
+                        "Structure mismatch: some arguments are None, others "
+                        f"are not. Received arguments: {args}."
+                    )
+                return None
+            return func(*args)
+
+        map_func = func_skipping_none
+
+    return torch_tree.tree_map(map_func, *structures)
+
+
+def map_structure_up_to(shallow_structure, func, *structures):
+    if not structures:
+        raise ValueError("Must provide at least one structure")
+
+    # Add check that `shallow_structure` really is the shallowest.
+    # Also only call `func` on `structures` and not `shallow_structure`.
+    def func_with_check_without_shallow_structure(shallow, *args):
+        if not _tree_is_leaf(shallow):
+            raise ValueError("Structures don't have the same nested structure.")
+        return func(*args)
+
+    return torch_tree.tree_map(
+        func_with_check_without_shallow_structure,
+        shallow_structure,
+        *structures,
+    )
+
+
+def assert_same_structure(a, b):
+    def check(a_leaf, b_leaf):
+        if not _tree_is_leaf(a_leaf) or not _tree_is_leaf(b_leaf):
+            raise ValueError("Structures don't have the same nested structure.")
+        return None
+
+    torch_tree.tree_map(check, a, b)
+
+
+def assert_same_paths(a, b):
+    a_paths = set([path for path, _ in flatten_with_path(a)])
+    b_paths = set([path for path, _ in flatten_with_path(b)])
+
+    if a_paths != b_paths:
+        msg = "`a` and `b` don't have the same paths."
+        a_diff = a_paths.difference(b_paths)
+        if a_diff:
+            msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
+        b_diff = b_paths.difference(a_paths)
+        if b_diff:
+            msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
+        raise ValueError(msg)
+
+
+def pack_sequence_as(structure, flat_sequence):
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    _, treespec = torch_tree.tree_flatten(structure)
+    return torch_tree.tree_unflatten(flat_sequence, treespec)
+
+
+def lists_to_tuples(structure):
+    def list_to_tuple(instance):
+        return tuple(instance) if isinstance(instance, list) else None
+
+    return traverse(list_to_tuple, structure, top_down=False)
+
+
+def map_shape_structure(func, structure):
+    def is_shape_tuple(x):
+        return isinstance(x, (list, tuple)) and all(
+            isinstance(e, (int, type(None))) for e in x
+        )
+
+    # We need to first sort dicts to ensure a deterministic order that is
+    # consistent with other tree implementations.
+    structure = _dict_to_ordered_dict(structure)
+    return torch_tree.tree_map(func, structure, is_leaf=is_shape_tuple)