keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/ops/image.py
CHANGED
|
@@ -565,6 +565,8 @@ class ExtractPatches(Operation):
|
|
|
565
565
|
if isinstance(size, int):
|
|
566
566
|
size = (size, size)
|
|
567
567
|
self.size = size
|
|
568
|
+
if strides is None:
|
|
569
|
+
strides = size
|
|
568
570
|
self.strides = strides
|
|
569
571
|
self.dilation_rate = dilation_rate
|
|
570
572
|
self.padding = padding
|
|
@@ -583,8 +585,6 @@ class ExtractPatches(Operation):
|
|
|
583
585
|
def compute_output_spec(self, images):
|
|
584
586
|
images_shape = list(images.shape)
|
|
585
587
|
original_ndim = len(images_shape)
|
|
586
|
-
if not self.strides:
|
|
587
|
-
strides = (self.size[0], self.size[1])
|
|
588
588
|
if self.data_format == "channels_last":
|
|
589
589
|
channels_in = images_shape[-1]
|
|
590
590
|
else:
|
|
@@ -597,7 +597,7 @@ class ExtractPatches(Operation):
|
|
|
597
597
|
images_shape,
|
|
598
598
|
filters,
|
|
599
599
|
kernel_size,
|
|
600
|
-
strides=strides,
|
|
600
|
+
strides=self.strides,
|
|
601
601
|
padding=self.padding,
|
|
602
602
|
data_format=self.data_format,
|
|
603
603
|
dilation_rate=self.dilation_rate,
|
|
@@ -616,42 +616,98 @@ def extract_patches(
|
|
|
616
616
|
padding="valid",
|
|
617
617
|
data_format=None,
|
|
618
618
|
):
|
|
619
|
-
"""Extracts patches from the image(s).
|
|
619
|
+
"""Extracts patches from the image(s) or volume(s).
|
|
620
|
+
|
|
621
|
+
This function supports both 2D and 3D patch extraction based on the
|
|
622
|
+
`size` argument length, similar to how `keras.ops.conv` handles
|
|
623
|
+
different dimensions.
|
|
620
624
|
|
|
621
625
|
Args:
|
|
622
|
-
images: Input image or batch of images.
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
626
|
+
images: Input image/volume or batch of images/volumes.
|
|
627
|
+
For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
|
|
628
|
+
For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
|
|
629
|
+
size: Patch size as int or tuple.
|
|
630
|
+
Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
|
|
631
|
+
Length 3 tuple `(patch_depth, patch_height, patch_width)` for
|
|
632
|
+
3D patches.
|
|
633
|
+
strides: Strides for patch extraction. If not specified, defaults
|
|
634
|
+
to `size` (non-overlapping patches).
|
|
635
|
+
dilation_rate: Dilation rate for patch extraction. Note that
|
|
636
|
+
`dilation_rate > 1` is not supported with `strides > 1`.
|
|
630
637
|
padding: The type of padding algorithm to use: `"same"` or `"valid"`.
|
|
631
638
|
data_format: A string specifying the data format of the input tensor.
|
|
632
639
|
It can be either `"channels_last"` or `"channels_first"`.
|
|
633
|
-
|
|
634
|
-
`(batch, height, width, channels)`, while `"channels_first"`
|
|
635
|
-
corresponds to inputs with shape `(batch, channels, height, width)`.
|
|
636
|
-
If not specified, the value will default to
|
|
637
|
-
`keras.config.image_data_format`.
|
|
640
|
+
If not specified, defaults to `keras.config.image_data_format`.
|
|
638
641
|
|
|
639
642
|
Returns:
|
|
640
|
-
Extracted patches
|
|
643
|
+
Extracted patches with shape depending on input and `size`:
|
|
644
|
+
- 2D patches: 3D (unbatched) or 4D (batched)
|
|
645
|
+
- 3D patches: 4D (unbatched) or 5D (batched)
|
|
641
646
|
|
|
642
647
|
Examples:
|
|
643
648
|
|
|
649
|
+
>>> # 2D patches from batch of images
|
|
644
650
|
>>> image = np.random.random(
|
|
645
651
|
... (2, 20, 20, 3)
|
|
646
|
-
... ).astype("float32")
|
|
652
|
+
... ).astype("float32")
|
|
647
653
|
>>> patches = keras.ops.image.extract_patches(image, (5, 5))
|
|
648
654
|
>>> patches.shape
|
|
649
655
|
(2, 4, 4, 75)
|
|
650
|
-
|
|
656
|
+
|
|
657
|
+
>>> # 2D patches from single image
|
|
658
|
+
>>> image = np.random.random((20, 20, 3)).astype("float32")
|
|
651
659
|
>>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
|
|
652
660
|
>>> patches.shape
|
|
653
661
|
(18, 18, 27)
|
|
662
|
+
|
|
663
|
+
>>> # 3D patches from batch of volumes
|
|
664
|
+
>>> volumes = np.random.random(
|
|
665
|
+
... (2, 10, 10, 10, 3)
|
|
666
|
+
... ).astype("float32")
|
|
667
|
+
>>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
|
|
668
|
+
>>> patches.shape
|
|
669
|
+
(2, 3, 3, 3, 81)
|
|
670
|
+
|
|
671
|
+
>>> # 3D patches from single volume
|
|
672
|
+
>>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
|
|
673
|
+
>>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
|
|
674
|
+
>>> patches.shape
|
|
675
|
+
(3, 3, 3, 81)
|
|
654
676
|
"""
|
|
677
|
+
# Validate size argument
|
|
678
|
+
if not isinstance(size, int):
|
|
679
|
+
if not isinstance(size, (tuple, list)):
|
|
680
|
+
raise TypeError(
|
|
681
|
+
"Invalid `size` argument. Expected an int or a tuple. "
|
|
682
|
+
f"Received: size={size} of type {type(size).__name__}"
|
|
683
|
+
)
|
|
684
|
+
if len(size) not in (2, 3):
|
|
685
|
+
raise ValueError(
|
|
686
|
+
"Invalid `size` argument. Expected a tuple of length 2 or 3. "
|
|
687
|
+
f"Received: size={size} with length {len(size)}"
|
|
688
|
+
)
|
|
689
|
+
|
|
690
|
+
# Determine 2D vs 3D based on size argument
|
|
691
|
+
if not isinstance(size, int) and len(size) == 3:
|
|
692
|
+
# 3D patch extraction
|
|
693
|
+
if any_symbolic_tensors((images,)):
|
|
694
|
+
return ExtractPatches3D(
|
|
695
|
+
size=size,
|
|
696
|
+
strides=strides,
|
|
697
|
+
dilation_rate=dilation_rate,
|
|
698
|
+
padding=padding,
|
|
699
|
+
data_format=data_format,
|
|
700
|
+
).symbolic_call(images)
|
|
701
|
+
return _extract_patches_3d(
|
|
702
|
+
images,
|
|
703
|
+
size,
|
|
704
|
+
strides,
|
|
705
|
+
dilation_rate,
|
|
706
|
+
padding,
|
|
707
|
+
data_format=data_format,
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
# 2D patch extraction (default)
|
|
655
711
|
if any_symbolic_tensors((images,)):
|
|
656
712
|
return ExtractPatches(
|
|
657
713
|
size=size,
|
|
@@ -712,6 +768,187 @@ def _extract_patches(
|
|
|
712
768
|
return patches
|
|
713
769
|
|
|
714
770
|
|
|
771
|
+
class ExtractPatches3D(Operation):
    """Symbolic operation that extracts 3D patches from volume(s).

    The actual computation is delegated to `_extract_patches_3d`;
    `compute_output_spec` infers the output shape by treating patch
    extraction as a 3D convolution whose number of filters equals
    `patch_depth * patch_height * patch_width * channels`.
    """

    def __init__(
        self,
        size,
        strides=None,
        dilation_rate=1,
        padding="valid",
        data_format=None,
        *,
        name=None,
    ):
        super().__init__(name=name)
        if isinstance(size, int):
            size = (size, size, size)
        elif len(size) != 3:
            raise TypeError(
                "Invalid `size` argument. Expected an "
                f"int or a tuple of length 3. Received: size={size}"
            )
        self.size = size
        if strides is not None:
            if isinstance(strides, int):
                strides = (strides, strides, strides)
            elif len(strides) != 3:
                raise ValueError(f"Invalid `strides` argument. Got: {strides}")
        else:
            # Default to non-overlapping patches: stride equals patch size.
            strides = size
        self.strides = strides
        self.dilation_rate = dilation_rate
        self.padding = padding
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, volumes):
        return _extract_patches_3d(
            volumes,
            self.size,
            self.strides,
            self.dilation_rate,
            self.padding,
            self.data_format,
        )

    def compute_output_spec(self, volumes):
        """Infer the output spec of patch extraction on `volumes`.

        Raises:
            ValueError: If `volumes` is not rank 4 (single volume) or
                rank 5 (batch of volumes).
        """
        volumes_shape = list(volumes.shape)
        original_ndim = len(volumes_shape)
        # Fail early with a clear message instead of an IndexError or a
        # silently wrong conv-shape computation for unsupported ranks.
        if original_ndim not in (4, 5):
            raise ValueError(
                "Invalid `volumes` rank. Expected a 4D (unbatched) or 5D "
                f"(batched) input. Received: shape={volumes.shape}"
            )
        if self.data_format == "channels_last":
            channels_in = volumes_shape[-1]
        else:
            channels_in = volumes_shape[-4]
        if original_ndim == 4:
            # Add a dummy batch dimension for the conv shape computation.
            volumes_shape = [1] + volumes_shape
        filters = self.size[0] * self.size[1] * self.size[2] * channels_in
        kernel_size = (self.size[0], self.size[1], self.size[2])
        out_shape = compute_conv_output_shape(
            volumes_shape,
            filters,
            kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if original_ndim == 4:
            # Drop the dummy batch dimension added above.
            out_shape = out_shape[1:]
        return KerasTensor(shape=out_shape, dtype=volumes.dtype)
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
def _extract_patches_3d(
    volumes,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extract 3D patches by convolving with a one-hot (identity) kernel.

    Each output channel of the convolution copies exactly one element of
    the `(patch_d, patch_h, patch_w, channels_in)` window. As a result, the
    output channel axis holds the flattened patch contents.

    Args:
        volumes: 4D (unbatched) or 5D (batched) input.
        size: Int, or tuple of length 3 `(patch_d, patch_h, patch_w)`.
        strides: Int, tuple of length 3, or `None` (defaults to `size`).
        dilation_rate: Forwarded to `backend.nn.conv`.
        padding: Forwarded to `backend.nn.conv`.
        data_format: `"channels_last"`, `"channels_first"` or `None`.

    Returns:
        The patches tensor, with the same batchedness as the input.

    Raises:
        TypeError: If `size` is neither an int nor of length 3.
        ValueError: If `strides` is not of length 3, or if `volumes` is not
            rank 4 or 5.
    """
    if isinstance(size, int):
        patch_d = patch_h = patch_w = size
    elif len(size) == 3:
        patch_d, patch_h, patch_w = size
    else:
        raise TypeError(
            "Invalid `size` argument. Expected an "
            f"int or a tuple of length 3. Received: size={size}"
        )
    if strides is None:
        strides = size
    if isinstance(strides, int):
        strides = (strides, strides, strides)
    if len(strides) != 3:
        raise ValueError(f"Invalid `strides` argument. Got: {strides}")
    data_format = backend.standardize_data_format(data_format)

    ndim = len(volumes.shape)
    # Validate rank up front; otherwise an unsupported rank surfaces as an
    # opaque error from inside the backend convolution.
    if ndim not in (4, 5):
        raise ValueError(
            "Invalid `volumes` rank. Expected a 4D (unbatched) or 5D "
            f"(batched) input. Received: shape={volumes.shape}"
        )
    # `standardize_data_format` only returns these two values.
    if data_format == "channels_last":
        channels_in = volumes.shape[-1]
    else:
        channels_in = volumes.shape[-4]

    out_dim = patch_d * patch_h * patch_w * channels_in
    # Identity matrix reshaped into a conv kernel: one output channel per
    # element of the patch window.
    kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
    kernel = backend.numpy.reshape(
        kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
    )
    unbatched = ndim == 4
    if unbatched:
        volumes = backend.numpy.expand_dims(volumes, axis=0)
    patches = backend.nn.conv(
        inputs=volumes,
        kernel=kernel,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    if unbatched:
        patches = backend.numpy.squeeze(patches, axis=0)
    return patches
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
@keras_export("keras.ops.image.extract_patches_3d")
def extract_patches_3d(
    volumes,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extracts 3D patches from a volume or a batch of volumes.

    Each output position holds the flattened contents of one
    `(patch_depth, patch_height, patch_width)` window across all input
    channels. The output channel axis therefore has size
    `patch_depth * patch_height * patch_width * channels`.

    Args:
        volumes: A single volume (4D) or a batch of volumes (5D).
        size: Patch size. Either an int (used for depth, height and width)
            or a tuple `(patch_depth, patch_height, patch_width)`.
        strides: Strides along depth, height and width. When omitted or
            `None`, defaults to `size`, which yields non-overlapping patches.
        dilation_rate: The input stride, i.e. how far apart two
            consecutive patch samples are in the input. Combining
            `dilation_rate > 1` with `strides > 1` is not supported on the
            TensorFlow backend.
        padding: Padding algorithm, either `"same"` or `"valid"`.
        data_format: Either `"channels_last"`, for inputs shaped
            `(batch, depth, height, width, channels)`, or
            `"channels_first"`, for inputs shaped
            `(batch, channels, depth, height, width)`. Defaults to
            `keras.config.image_data_format()`.

    Returns:
        The extracted patches: 4D for an unbatched input, 5D for a batched
        input.

    Examples:

    >>> import numpy as np
    >>> import keras
    >>> # Batched case
    >>> volumes = np.random.random(
    ...     (2, 10, 10, 10, 3)
    ... ).astype("float32") # batch of 2 volumes
    >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
    >>> patches.shape
    (2, 3, 3, 3, 81)
    >>> # Unbatched case
    >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
    >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
    >>> patches.shape
    (3, 3, 3, 81)
    """
    # Eager inputs: compute directly.
    if not any_symbolic_tensors((volumes,)):
        return _extract_patches_3d(
            volumes,
            size,
            strides,
            dilation_rate,
            padding,
            data_format=data_format,
        )
    # Symbolic inputs: build the op so the output spec can be inferred.
    op = ExtractPatches3D(
        size=size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
    )
    return op.symbolic_call(volumes)
|
|
950
|
+
|
|
951
|
+
|
|
715
952
|
class MapCoordinates(Operation):
|
|
716
953
|
def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
|
|
717
954
|
super().__init__(name=name)
|
keras/src/ops/linalg.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from keras.src import backend
|
|
2
|
+
from keras.src import tree
|
|
2
3
|
from keras.src.api_export import keras_export
|
|
3
4
|
from keras.src.backend import KerasTensor
|
|
4
5
|
from keras.src.backend import any_symbolic_tensors
|
|
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
|
|
|
732
733
|
"Expected `a.shape[-1] == b.shape[-1]`. "
|
|
733
734
|
f"Received: a.shape={a.shape}, b.shape={b.shape}"
|
|
734
735
|
)
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
class JVP(Operation):
    """Operation computing a forward-mode Jacobian-vector product.

    Eager execution delegates to `backend.linalg.jvp`. Symbolic execution
    infers output specs by tracing `fun` with `backend.compute_output_spec`.
    """

    def __init__(self, has_aux=False, *, name=None):
        super().__init__(name=name)
        self.has_aux = has_aux

    def call(self, fun, primals, tangents):
        """Computes the JVP of `fun` at `primals` along `tangents`.

        Args:
            fun: Callable taking tensors (or nested structures) and returning
                a tensor (or nested structure).
            primals: Tensors (or nested structures) at which the Jacobian of
                `fun` is evaluated.
            tangents: Direction vectors for the JVP, with the same structure
                as `primals`.

        Returns:
            `(primals_out, tangents_out)` when `has_aux` is False, where
            `primals_out` is `fun(*primals)` and `tangents_out` is the JVP.
            `(primals_out, tangents_out, aux)` when `has_aux` is True, where
            `aux` is the auxiliary data returned by `fun`.
        """
        return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)

    def compute_output_spec(self, fun, primals, tangents):
        # Trace `fun` once; with `has_aux` it yields an (output, aux) pair.
        traced = backend.compute_output_spec(fun, *primals)
        if self.has_aux:
            primals_out_spec, aux_spec = traced
        else:
            primals_out_spec = traced

        def _spec_like(spec):
            # Each tangent output mirrors its primal output's shape/dtype.
            return KerasTensor(spec.shape, spec.dtype)

        tangents_out_spec = tree.map_structure(_spec_like, primals_out_spec)
        if not self.has_aux:
            return primals_out_spec, tangents_out_spec
        return primals_out_spec, tangents_out_spec, aux_spec
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"])
def jvp(fun, primals, tangents, has_aux=False):
    """Computes a forward-mode Jacobian-vector product of `fun`.

    Args:
        fun: Function to differentiate. It accepts arrays, scalars, or
            standard Python containers of them. It returns an array, a
            scalar, or a standard Python container of them.
        primals: Tuple or list of primal values at which the Jacobian of
            `fun` is evaluated. Its length must match the number of
            positional parameters of `fun`.
        tangents: Tuple or list of tangent vectors with the same tree
            structure and array shapes as `primals`.
        has_aux: Optional bool, defaults to False. When True, `fun` must
            return a pair whose first element is the output to
            differentiate and whose second element is auxiliary data.

    Returns:
        When `has_aux` is False, a `(primals_out, tangents_out)` pair.
        `primals_out` is `fun(*primals)`, and `tangents_out` is the
        Jacobian-vector product of `fun` at `primals` with `tangents`.
        `tangents_out` has the same Python tree structure and shapes as
        `primals_out`.

        When `has_aux` is True, a `(primals_out, tangents_out, aux)` tuple,
        where `aux` is the auxiliary data returned by `fun`.

    Example:
    >>> from keras import ops
    >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)
    >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))
    >>> primals
    0.09983342
    >>> tangents
    0.19900084
    """
    # Concrete inputs: run the backend implementation directly.
    if not any_symbolic_tensors((primals, tangents)):
        return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)
    # Symbolic inputs: route through the Operation for spec inference.
    op = JVP(has_aux=has_aux)
    return op.symbolic_call(fun, primals, tangents)
|