keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/ops/image.py CHANGED
@@ -565,6 +565,8 @@ class ExtractPatches(Operation):
565
565
  if isinstance(size, int):
566
566
  size = (size, size)
567
567
  self.size = size
568
+ if strides is None:
569
+ strides = size
568
570
  self.strides = strides
569
571
  self.dilation_rate = dilation_rate
570
572
  self.padding = padding
@@ -583,8 +585,6 @@ class ExtractPatches(Operation):
583
585
  def compute_output_spec(self, images):
584
586
  images_shape = list(images.shape)
585
587
  original_ndim = len(images_shape)
586
- if not self.strides:
587
- strides = (self.size[0], self.size[1])
588
588
  if self.data_format == "channels_last":
589
589
  channels_in = images_shape[-1]
590
590
  else:
@@ -597,7 +597,7 @@ class ExtractPatches(Operation):
597
597
  images_shape,
598
598
  filters,
599
599
  kernel_size,
600
- strides=strides,
600
+ strides=self.strides,
601
601
  padding=self.padding,
602
602
  data_format=self.data_format,
603
603
  dilation_rate=self.dilation_rate,
@@ -616,42 +616,98 @@ def extract_patches(
616
616
  padding="valid",
617
617
  data_format=None,
618
618
  ):
619
- """Extracts patches from the image(s).
619
+ """Extracts patches from the image(s) or volume(s).
620
+
621
+ This function supports both 2D and 3D patch extraction based on the
622
+ `size` argument length, similar to how `keras.ops.conv` handles
623
+ different dimensions.
620
624
 
621
625
  Args:
622
- images: Input image or batch of images. Must be 3D or 4D.
623
- size: Patch size int or tuple (patch_height, patch_width)
624
- strides: strides along height and width. If not specified, or
625
- if `None`, it defaults to the same value as `size`.
626
- dilation_rate: This is the input stride, specifying how far two
627
- consecutive patch samples are in the input. For value other than 1,
628
- strides must be 1. NOTE: `strides > 1` is not supported in
629
- conjunction with `dilation_rate > 1`
626
+ images: Input image/volume or batch of images/volumes.
627
+ For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
628
+ For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
629
+ size: Patch size as int or tuple.
630
+ Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
631
+ Length 3 tuple `(patch_depth, patch_height, patch_width)` for
632
+ 3D patches.
633
+ strides: Strides for patch extraction. If not specified, defaults
634
+ to `size` (non-overlapping patches).
635
+ dilation_rate: Dilation rate for patch extraction. Note that
636
+ `dilation_rate > 1` is not supported with `strides > 1`.
630
637
  padding: The type of padding algorithm to use: `"same"` or `"valid"`.
631
638
  data_format: A string specifying the data format of the input tensor.
632
639
  It can be either `"channels_last"` or `"channels_first"`.
633
- `"channels_last"` corresponds to inputs with shape
634
- `(batch, height, width, channels)`, while `"channels_first"`
635
- corresponds to inputs with shape `(batch, channels, height, width)`.
636
- If not specified, the value will default to
637
- `keras.config.image_data_format`.
640
+ If not specified, defaults to `keras.config.image_data_format`.
638
641
 
639
642
  Returns:
640
- Extracted patches 3D (if not batched) or 4D (if batched)
643
+ Extracted patches with shape depending on input and `size`:
644
+ - 2D patches: 3D (unbatched) or 4D (batched)
645
+ - 3D patches: 4D (unbatched) or 5D (batched)
641
646
 
642
647
  Examples:
643
648
 
649
+ >>> # 2D patches from batch of images
644
650
  >>> image = np.random.random(
645
651
  ... (2, 20, 20, 3)
646
- ... ).astype("float32") # batch of 2 RGB images
652
+ ... ).astype("float32")
647
653
  >>> patches = keras.ops.image.extract_patches(image, (5, 5))
648
654
  >>> patches.shape
649
655
  (2, 4, 4, 75)
650
- >>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image
656
+
657
+ >>> # 2D patches from single image
658
+ >>> image = np.random.random((20, 20, 3)).astype("float32")
651
659
  >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
652
660
  >>> patches.shape
653
661
  (18, 18, 27)
662
+
663
+ >>> # 3D patches from batch of volumes
664
+ >>> volumes = np.random.random(
665
+ ... (2, 10, 10, 10, 3)
666
+ ... ).astype("float32")
667
+ >>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
668
+ >>> patches.shape
669
+ (2, 3, 3, 3, 81)
670
+
671
+ >>> # 3D patches from single volume
672
+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
673
+ >>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
674
+ >>> patches.shape
675
+ (3, 3, 3, 81)
654
676
  """
677
+ # Validate size argument
678
+ if not isinstance(size, int):
679
+ if not isinstance(size, (tuple, list)):
680
+ raise TypeError(
681
+ "Invalid `size` argument. Expected an int or a tuple. "
682
+ f"Received: size={size} of type {type(size).__name__}"
683
+ )
684
+ if len(size) not in (2, 3):
685
+ raise ValueError(
686
+ "Invalid `size` argument. Expected a tuple of length 2 or 3. "
687
+ f"Received: size={size} with length {len(size)}"
688
+ )
689
+
690
+ # Determine 2D vs 3D based on size argument
691
+ if not isinstance(size, int) and len(size) == 3:
692
+ # 3D patch extraction
693
+ if any_symbolic_tensors((images,)):
694
+ return ExtractPatches3D(
695
+ size=size,
696
+ strides=strides,
697
+ dilation_rate=dilation_rate,
698
+ padding=padding,
699
+ data_format=data_format,
700
+ ).symbolic_call(images)
701
+ return _extract_patches_3d(
702
+ images,
703
+ size,
704
+ strides,
705
+ dilation_rate,
706
+ padding,
707
+ data_format=data_format,
708
+ )
709
+
710
+ # 2D patch extraction (default)
655
711
  if any_symbolic_tensors((images,)):
656
712
  return ExtractPatches(
657
713
  size=size,
@@ -712,6 +768,187 @@ def _extract_patches(
712
768
  return patches
713
769
 
714
770
 
771
class ExtractPatches3D(Operation):
    """Operation wrapper around `_extract_patches_3d` for volumetric inputs.

    `size` and `strides` are normalized once, at construction time, so that
    `call` and `compute_output_spec` can use them as-is.

    Args:
        size: An int (cubic patch) or a length-3 sequence
            `(patch_depth, patch_height, patch_width)`.
        strides: An int, a length-3 sequence, or `None`. When `None`, the
            (normalized) `size` is used, i.e. non-overlapping patches.
        dilation_rate: Forwarded unchanged to the convolution.
        padding: Forwarded unchanged to the convolution.
        data_format: Standardized with `backend.standardize_data_format`.
        name: Optional operation name.

    Raises:
        TypeError: If `size` is not an int and does not have length 3.
        ValueError: If `strides` is not `None`, not an int, and does not
            have length 3.
    """

    def __init__(
        self,
        size,
        strides=None,
        dilation_rate=1,
        padding="valid",
        data_format=None,
        *,
        name=None,
    ):
        super().__init__(name=name)

        # Expand an int into a cubic patch; otherwise require three entries.
        if isinstance(size, int):
            size = (size,) * 3
        elif len(size) != 3:
            raise TypeError(
                "Invalid `size` argument. Expected an "
                f"int or a tuple of length 3. Received: size={size}"
            )
        self.size = size

        # Missing strides fall back to the patch size itself.
        if strides is None:
            strides = size
        elif isinstance(strides, int):
            strides = (strides,) * 3
        elif len(strides) != 3:
            raise ValueError(f"Invalid `strides` argument. Got: {strides}")
        self.strides = strides

        self.dilation_rate = dilation_rate
        self.padding = padding
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, volumes):
        return _extract_patches_3d(
            volumes,
            self.size,
            strides=self.strides,
            dilation_rate=self.dilation_rate,
            padding=self.padding,
            data_format=self.data_format,
        )

    def compute_output_spec(self, volumes):
        shape = list(volumes.shape)
        is_unbatched = len(shape) == 4

        # Channels are the last axis for "channels_last"; otherwise they sit
        # immediately before the three spatial axes (index -4), which holds
        # for both batched and unbatched inputs.
        channel_axis = -1 if self.data_format == "channels_last" else -4
        channels_in = shape[channel_axis]

        if is_unbatched:
            # Prepend a dummy batch axis so the conv shape helper applies.
            shape = [1, *shape]

        patch_d, patch_h, patch_w = self.size
        out_shape = compute_conv_output_shape(
            shape,
            patch_d * patch_h * patch_w * channels_in,
            (patch_d, patch_h, patch_w),
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

        if is_unbatched:
            out_shape = out_shape[1:]
        return KerasTensor(shape=out_shape, dtype=volumes.dtype)
837
+
838
+
839
def _extract_patches_3d(
    volumes,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Eager 3D patch extraction implemented as a single convolution.

    The kernel is an identity matrix of size `out_dim` reshaped to
    `(patch_d, patch_h, patch_w, channels_in, out_dim)`. Each output channel
    therefore carries weight 1 on exactly one (depth, height, width, channel)
    position of the window. The last output axis is the flattened patch.

    Args:
        volumes: 4D (unbatched) or 5D (batched) input.
        size: An int (cubic patch) or a length-3 sequence.
        strides: An int, a length-3 sequence, or `None` (defaults to `size`).
        dilation_rate: Forwarded to `backend.nn.conv`.
        padding: Forwarded to `backend.nn.conv`.
        data_format: Standardized with `backend.standardize_data_format`.

    Returns:
        The patches tensor. A leading batch axis is added for 4D inputs and
        squeezed away again before returning.

    Raises:
        TypeError: If `size` is neither an int nor a length-3 sequence.
        ValueError: If the resolved strides do not have length 3.
    """
    if isinstance(size, int):
        patch_d = patch_h = patch_w = size
    elif len(size) == 3:
        patch_d, patch_h, patch_w = size
    else:
        raise TypeError(
            "Invalid `size` argument. Expected an "
            f"int or a tuple of length 3. Received: size={size}"
        )

    # Resolve strides: default to `size`, then expand a scalar to 3 axes.
    stride_spec = size if strides is None else strides
    if isinstance(stride_spec, int):
        stride_spec = (stride_spec, stride_spec, stride_spec)
    if len(stride_spec) != 3:
        raise ValueError(f"Invalid `strides` argument. Got: {stride_spec}")

    data_format = backend.standardize_data_format(data_format)
    if data_format == "channels_last":
        channels_in = volumes.shape[-1]
    elif data_format == "channels_first":
        channels_in = volumes.shape[-4]

    out_dim = patch_d * patch_h * patch_w * channels_in
    identity = backend.numpy.eye(out_dim, dtype=volumes.dtype)
    kernel = backend.numpy.reshape(
        identity, (patch_d, patch_h, patch_w, channels_in, out_dim)
    )

    is_unbatched = len(volumes.shape) == 4
    if is_unbatched:
        volumes = backend.numpy.expand_dims(volumes, axis=0)

    patches = backend.nn.conv(
        inputs=volumes,
        kernel=kernel,
        strides=stride_spec,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    return backend.numpy.squeeze(patches, axis=0) if is_unbatched else patches
887
+
888
+
889
@keras_export("keras.ops.image.extract_patches_3d")
def extract_patches_3d(
    volumes,
    size,
    strides=None,
    dilation_rate=1,
    padding="valid",
    data_format=None,
):
    """Extracts 3D patches from a volume or a batch of volumes.

    Args:
        volumes: A single volume (4D) or a batch of volumes (5D).
        size: Patch size, either an int or a tuple
            `(patch_depth, patch_height, patch_width)`.
        strides: Step between patches along depth, height and width. When
            omitted or `None`, it falls back to `size` (non-overlapping
            patches).
        dilation_rate: Input stride, i.e. the spacing between consecutive
            samples within a patch. Combining `dilation_rate > 1` with
            `strides > 1` is not supported on the TensorFlow backend.
        padding: Padding algorithm, either `"same"` or `"valid"`.
        data_format: Either `"channels_last"`, for inputs shaped
            `(batch, depth, height, width, channels)`, or
            `"channels_first"`, for inputs shaped
            `(batch, channels, depth, height, width)`. When unset, it
            falls back to `keras.config.image_data_format()`.

    Returns:
        The extracted patches: 4D for an unbatched input, 5D for a batched
        one.

    Examples:

    >>> import numpy as np
    >>> import keras
    >>> # Batched case
    >>> volumes = np.random.random(
    ...     (2, 10, 10, 10, 3)
    ... ).astype("float32") # batch of 2 volumes
    >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
    >>> patches.shape
    (2, 3, 3, 3, 81)
    >>> # Unbatched case
    >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
    >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
    >>> patches.shape
    (3, 3, 3, 81)
    """
    if not any_symbolic_tensors((volumes,)):
        # Concrete inputs: compute the patches directly.
        return _extract_patches_3d(
            volumes,
            size,
            strides,
            dilation_rate,
            padding,
            data_format=data_format,
        )

    # Symbolic inputs: route through the Operation to get an output spec.
    op = ExtractPatches3D(
        size=size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
    )
    return op.symbolic_call(volumes)
950
+
951
+
715
952
  class MapCoordinates(Operation):
716
953
  def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
717
954
  super().__init__(name=name)
keras/src/ops/linalg.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from keras.src import backend
2
+ from keras.src import tree
2
3
  from keras.src.api_export import keras_export
3
4
  from keras.src.backend import KerasTensor
4
5
  from keras.src.backend import any_symbolic_tensors
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
732
733
  "Expected `a.shape[-1] == b.shape[-1]`. "
733
734
  f"Received: a.shape={a.shape}, b.shape={b.shape}"
734
735
  )
736
+
737
+
738
class JVP(Operation):
    """Operation wrapper around `backend.linalg.jvp`.

    Args:
        has_aux: Whether `fun` returns an `(output, aux)` pair instead of a
            bare output.
        name: Optional operation name.
    """

    def __init__(self, has_aux=False, *, name=None):
        super().__init__(name=name)
        self.has_aux = has_aux

    def call(self, fun, primals, tangents):
        """Evaluates the JVP of `fun` at `primals` in direction `tangents`.

        Args:
            fun: Callable mapping tensors (or nested structures of tensors)
                to a tensor (or nested structure of tensors).
            primals: Point at which the Jacobian of `fun` is evaluated.
            tangents: Direction vectors, structured like `primals`.

        Returns:
            `(primals_out, tangents_out)` when `has_aux` is False, or
            `(primals_out, tangents_out, aux)` when it is True, as produced
            by `backend.linalg.jvp`.
        """
        return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)

    def compute_output_spec(self, fun, primals, tangents):
        spec = backend.compute_output_spec(fun, *primals)
        if self.has_aux:
            primals_out_spec, aux_spec = spec
        else:
            primals_out_spec = spec

        def _same_spec(template):
            # Tangent outputs mirror the primal outputs' shape and dtype.
            return KerasTensor(template.shape, template.dtype)

        tangents_out_spec = tree.map_structure(_same_spec, primals_out_spec)

        result = (primals_out_spec, tangents_out_spec)
        if self.has_aux:
            result += (aux_spec,)
        return result
783
+
784
+
785
@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"])
def jvp(fun, primals, tangents, has_aux=False):
    """Computes a (forward-mode) Jacobian-vector product of `fun`.

    If either `primals` or `tangents` contains a symbolic tensor, the call
    is routed through the `JVP` operation's `symbolic_call`. Otherwise it is
    dispatched directly to `backend.linalg.jvp`.

    Args:
        fun: Function to be differentiated. Its arguments should be arrays,
            scalars, or standard Python containers of arrays or scalars. It
            should return an array, scalar, or standard Python container of
            arrays or scalars.
        primals: The primal values at which the Jacobian of `fun` should be
            evaluated. Should be either a tuple or a list of arguments,
            and its length should be equal to the number of positional
            parameters of `fun`.
        tangents: The tangent vector for which the Jacobian-vector product
            should be evaluated. Should be either a tuple or a list of
            tangents, with the same tree structure and array shapes as
            `primals`.
        has_aux: Optional, bool. Indicates whether `fun` returns a pair where
            the first element is considered the output of the mathematical
            function to be differentiated and the second element is
            auxiliary data. Defaults to `False`.

    Returns:
        If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair,
        where `primals_out` is `fun(*primals)`, and `tangents_out` is the
        Jacobian-vector product of `fun` evaluated at `primals` with
        `tangents`. The `tangents_out` value has the same Python tree
        structure and shapes as `primals_out`.

        If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`)
        tuple where `aux` is the auxiliary data returned by `fun`.

    Example:

    >>> from keras import ops
    >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)
    >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))
    >>> primals
    0.09983342
    >>> tangents
    0.19900084
    """
    if any_symbolic_tensors((primals, tangents)):
        # Symbolic inputs: infer output specs through the `JVP` operation.
        return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents)
    # Concrete inputs: let the active backend compute the JVP directly.
    return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)