keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/metrics/confusion_metrics.py CHANGED
@@ -654,7 +654,7 @@ class SensitivitySpecificityBase(Metric):
654
654
  Args:
655
655
  constrained: Over these values the constraint is specified. A rank-1
656
656
  tensor.
657
- dependent: From these values the maximum that satiesfies the
657
+ dependent: From these values the maximum that satisfies the
658
658
  constraint is selected. Values in this tensor and in
659
659
  `constrained` are linked by having the same threshold at each
660
660
  position, hence this tensor must have the same shape.
@@ -664,11 +664,12 @@ class SensitivitySpecificityBase(Metric):
664
664
  Returns:
665
665
  maximal dependent value, if no value satisfies the constraint 0.0.
666
666
  """
667
- feasible = ops.nonzero(predicate(constrained, self.value))
668
- feasible_exists = ops.greater(ops.size(feasible), 0)
669
- max_dependent = ops.max(ops.take(dependent, feasible), initial=0)
670
-
671
- return ops.where(feasible_exists, max_dependent, 0.0)
667
+ feasible = predicate(constrained, self.value)
668
+ # Mask values based on whether they satisfy the constraint and take max.
669
+ return ops.max(
670
+ ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),
671
+ initial=0,
672
+ )
672
673
 
673
674
 
674
675
  @keras_export("keras.metrics.SensitivityAtSpecificity")
keras/src/models/cloning.py CHANGED
@@ -293,10 +293,12 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
293
293
  input_name = ref_input_layer.name
294
294
  input_batch_shape = ref_input_layer.batch_shape
295
295
  input_dtype = ref_input_layer._dtype
296
+ input_optional = ref_input_layer.optional
296
297
  else:
297
298
  input_name = None
298
299
  input_dtype = None
299
300
  input_batch_shape = None
301
+ input_optional = False
300
302
 
301
303
  if input_tensors is not None:
302
304
  if isinstance(input_tensors, (list, tuple)):
@@ -313,6 +315,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
313
315
  inputs = Input(
314
316
  tensor=input_tensors,
315
317
  name=input_name,
318
+ optional=input_optional,
316
319
  )
317
320
  new_layers = [inputs] + new_layers
318
321
  else:
@@ -321,6 +324,7 @@ def _clone_sequential_model(model, clone_function, input_tensors=None):
321
324
  batch_shape=input_batch_shape,
322
325
  dtype=input_dtype,
323
326
  name=input_name,
327
+ optional=input_optional,
324
328
  )
325
329
  new_layers = [inputs] + new_layers
326
330
  cloned_model = Sequential(
keras/src/models/functional.py CHANGED
@@ -254,9 +254,9 @@ class Functional(Function, Model):
254
254
  return converted
255
255
 
256
256
  def _adjust_input_rank(self, flat_inputs):
257
- flat_ref_shapes = [x.shape for x in self._inputs]
258
257
  adjusted = []
259
- for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
258
+ for i, x in enumerate(flat_inputs):
259
+ ref_shape = self._inputs[i].shape
260
260
  if x is None:
261
261
  adjusted.append(x)
262
262
  continue
@@ -273,8 +273,11 @@ class Functional(Function, Model):
273
273
  if ref_shape[-1] == 1:
274
274
  adjusted.append(ops.expand_dims(x, axis=-1))
275
275
  continue
276
+ flat_paths_and_inputs = tree.flatten_with_path(self._inputs_struct)
277
+ path = ".".join(str(p) for p in flat_paths_and_inputs[i][0])
276
278
  raise ValueError(
277
- f"Invalid input shape for input {x}. Expected shape "
279
+ f"Invalid input shape for input {x} with name "
280
+ f"'{self._inputs[i].name}' and path '{path}'. Expected shape "
278
281
  f"{ref_shape}, but input has incompatible shape {x.shape}"
279
282
  )
280
283
  # Add back metadata.
@@ -832,11 +835,16 @@ def clone_graph_nodes(inputs, outputs):
832
835
  kt_id_mapping[id(kt_input)] = kt_input
833
836
  else:
834
837
  # We need to create a new Keras tensor for any intermediate tensor
838
+ original_op = kt_input._keras_history.operation
839
+ optional = False
840
+ if isinstance(original_op, InputLayer):
841
+ optional = original_op.optional
835
842
  cloned_input = Input(
836
843
  batch_shape=kt_input.shape,
837
844
  dtype=kt_input.dtype,
838
845
  sparse=kt_input.sparse,
839
846
  name=f"{kt_input.name}CLONE",
847
+ optional=optional,
840
848
  )
841
849
  cloned_inputs.append(cloned_input)
842
850
  kt_id_mapping[id(kt_input)] = cloned_input
keras/src/models/model.py CHANGED
@@ -2,14 +2,15 @@ import inspect
2
2
  import json
3
3
  import typing
4
4
  import warnings
5
+ from collections.abc import Callable
5
6
 
6
7
  from keras.src import backend
7
8
  from keras.src import utils
8
9
  from keras.src.api_export import keras_export
9
10
  from keras.src.layers.layer import Layer
10
11
  from keras.src.models.variable_mapping import map_saveable_variables
11
- from keras.src.quantizers.gptq_config import GPTQConfig
12
12
  from keras.src.quantizers.gptq_core import gptq_quantize
13
+ from keras.src.quantizers.utils import should_quantize_layer
13
14
  from keras.src.saving import saving_api
14
15
  from keras.src.trainers import trainer as base_trainer
15
16
  from keras.src.utils import summary_utils
@@ -422,19 +423,99 @@ class Model(Trainer, base_trainer.Trainer, Layer):
422
423
  **kwargs,
423
424
  )
424
425
 
425
- def quantize(self, mode, config=None, **kwargs):
426
+ def get_quantization_layer_structure(self, mode=None):
427
+ """Returns the quantization structure for the model.
428
+
429
+ This method is intended to be overridden by model authors to provide
430
+ topology information required for structure-aware quantization modes
431
+ like 'gptq'.
432
+
433
+ Args:
434
+ mode: The quantization mode.
435
+
436
+ Returns:
437
+ A dictionary describing the topology, e.g.:
438
+ `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
439
+ or `None` if the mode does not require structure or is not
440
+ supported. `'pre_block_layers'` is a list of layers that
441
+ the inputs should be passed through, before being passed to
442
+ the sequential blocks. For example, inputs to an LLM must
443
+ first be passed through an embedding layer, followed by
444
+ the transformer.
445
+ """
446
+ del mode # Unused.
447
+ return None
448
+
449
+ def quantize(self, mode=None, config=None, filters=None, **kwargs):
426
450
  """Quantize the weights of the model.
427
451
 
428
452
  Note that the model must be built first before calling this method.
429
- `quantize` will recursively call `quantize(mode)` in all layers and
453
+ `quantize` will recursively call `quantize(...)` in all layers and
430
454
  will be skipped if the layer doesn't implement the function.
431
455
 
456
+ This method can be called by passing a `mode` string, which uses the
457
+ default configuration for that mode. Alternatively, a `config` object
458
+ can be passed to customize the behavior of the quantization (e.g. to
459
+ use specific quantizers for weights or activations).
460
+
432
461
  Args:
433
- mode: The mode of the quantization. Only 'int8' is supported at this
434
- time.
435
- """
436
- from keras.src.dtype_policies import QUANTIZATION_MODES
462
+ mode: The mode of the quantization. Supported modes are:
463
+ `"int8"`, `"int4"`, `"float8"`, `"gptq"`. This is
464
+ optional if `config` is provided.
465
+ config: The configuration object specifying additional
466
+ quantization options. This argument allows to configure
467
+ the weight and activation quantizers. be an instance of
468
+ `keras.quantizers.QuantizationConfig`.
469
+ filters: Optional filters to apply to the quantization. Can be a
470
+ regex string, a list of regex strings, or a callable. Only the
471
+ layers which match the filter conditions will be quantized.
472
+ **kwargs: Additional keyword arguments.
473
+
474
+ Example:
475
+
476
+ Quantize a model to int8 with default configuration:
437
477
 
478
+ ```python
479
+ # Build the model
480
+ model = keras.Sequential([
481
+ keras.Input(shape=(10,)),
482
+ keras.layers.Dense(10),
483
+ ])
484
+ model.build((None, 10))
485
+
486
+ # Quantize with default int8 config
487
+ model.quantize("int8")
488
+ ```
489
+
490
+ Quantize a model to int8 with a custom configuration:
491
+
492
+ ```python
493
+ from keras.quantizers import Int8QuantizationConfig
494
+ from keras.quantizers import AbsMaxQuantizer
495
+
496
+ # Build the model
497
+ model = keras.Sequential([
498
+ keras.Input(shape=(10,)),
499
+ keras.layers.Dense(10),
500
+ ])
501
+ model.build((None, 10))
502
+
503
+ # Create a custom config
504
+ config = Int8QuantizationConfig(
505
+ weight_quantizer=AbsMaxQuantizer(
506
+ axis=0,
507
+ value_range=(-127, 127)
508
+ ),
509
+ activation_quantizer=AbsMaxQuantizer(
510
+ axis=-1,
511
+ value_range=(-127, 127)
512
+ ),
513
+ )
514
+
515
+ # Quantize with custom config
516
+ model.quantize(config=config)
517
+ ```
518
+ """
438
519
  # Validate inputs.
439
520
  type_check = kwargs.pop("type_check", True)
440
521
  if kwargs:
@@ -443,27 +524,20 @@ class Model(Trainer, base_trainer.Trainer, Layer):
443
524
  f"passed to {self.__class__.__name__}: {kwargs}"
444
525
  )
445
526
 
446
- if mode not in QUANTIZATION_MODES:
447
- raise ValueError(
448
- "Invalid quantization mode. "
449
- f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
450
- )
451
-
452
- if mode == "gptq":
453
- if not isinstance(config, GPTQConfig):
527
+ if filters is not None:
528
+ if not isinstance(filters, (str, Callable, list, tuple)):
454
529
  raise ValueError(
455
- "Mode 'gptq' requires a valid `config` argument of type "
456
- f"`GPTQConfig`. Received: {type(config)}"
530
+ "The `filters` argument must be a regex string, a list of "
531
+ "regex strings, or a callable. Received: "
532
+ f"{type(filters)}"
457
533
  )
458
- elif config is not None:
459
- # All other modes must not receive a config
460
- raise ValueError(
461
- f"The `config` argument is only supported for 'gptq' mode, "
462
- f"but received mode='{mode}' and a non-None config."
463
- )
464
534
 
465
535
  graph_modified = False
466
536
  for layer in self._flatten_layers():
537
+ # Apply filters
538
+ if not should_quantize_layer(layer, filters):
539
+ continue
540
+
467
541
  if len(list(layer._flatten_layers())) == 1:
468
542
  try:
469
543
  layer.quantize(mode, type_check=type_check, config=config)
@@ -474,7 +548,25 @@ class Model(Trainer, base_trainer.Trainer, Layer):
474
548
  pass
475
549
 
476
550
  if mode == "gptq":
477
- gptq_quantize(self, config)
551
+ # Resolve model structure.
552
+ # 1. If quantization_layer_structure is provided inside the config,
553
+ # use that.
554
+ structure = config.quantization_layer_structure
555
+ # 2. If no layer structure is provided in the config, try to fetch
556
+ # it using the `get_quantization_layer_structure` hook.
557
+ if structure is None:
558
+ structure = self.get_quantization_layer_structure(mode)
559
+
560
+ if structure is None:
561
+ raise ValueError(
562
+ "For 'gptq' mode, a valid quantization structure must be "
563
+ "provided either via `config.quantization_layer_structure` "
564
+ "or by overriding "
565
+ "`model.get_quantization_layer_structure(mode)`. The "
566
+ "structure should be a dictionary with keys "
567
+ "'pre_block_layers' and 'sequential_blocks'."
568
+ )
569
+ gptq_quantize(config, structure, filters=filters)
478
570
 
479
571
  # If any layer was changed, we must rebuild the execution functions.
480
572
  if graph_modified:
@@ -569,8 +661,8 @@ class Model(Trainer, base_trainer.Trainer, Layer):
569
661
  filepath: `str` or `pathlib.Path` object. The path to save the
570
662
  artifact.
571
663
  format: `str`. The export format. Supported values:
572
- `"tf_saved_model"` and `"onnx"`. Defaults to
573
- `"tf_saved_model"`.
664
+ `"tf_saved_model"`, `"onnx"`, `"openvino"`, and `"litert"`.
665
+ Defaults to `"tf_saved_model"`.
574
666
  verbose: `bool`. Whether to print a message during export. Defaults
575
667
  to `None`, which uses the default value set by different
576
668
  backends and formats.
@@ -593,6 +685,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
593
685
  provided, they will be automatically computed.
594
686
  - `opset_version`: Optional `int`. Specific to `format="onnx"`.
595
687
  An integer value that specifies the ONNX opset version.
688
+ - LiteRT-specific options: Optional keyword arguments specific
689
+ to `format="litert"`. These are passed directly to the
690
+ TensorFlow Lite converter and include options like
691
+ `optimizations`, `representative_dataset`,
692
+ `experimental_new_quantizer`, `allow_custom_ops`,
693
+ `enable_select_tf_ops`, etc. See TensorFlow Lite
694
+ documentation for all available options.
596
695
 
597
696
  **Note:** This feature is currently supported only with TensorFlow, JAX
598
697
  and Torch backends.
@@ -627,18 +726,41 @@ class Model(Trainer, base_trainer.Trainer, Layer):
627
726
  }
628
727
  predictions = ort_session.run(None, ort_inputs)
629
728
  ```
729
+
730
+ Here's how to export a LiteRT (TFLite) for inference.
731
+
732
+ ```python
733
+ # Export the model as a LiteRT artifact
734
+ model.export("path/to/location", format="litert")
735
+
736
+ # Load the artifact in a different process/environment
737
+ interpreter = tf.lite.Interpreter(model_path="path/to/location")
738
+ interpreter.allocate_tensors()
739
+ interpreter.set_tensor(
740
+ interpreter.get_input_details()[0]['index'], input_data
741
+ )
742
+ interpreter.invoke()
743
+ output_data = interpreter.get_tensor(
744
+ interpreter.get_output_details()[0]['index']
745
+ )
746
+ ```
630
747
  """
748
+ from keras.src.export import export_litert
631
749
  from keras.src.export import export_onnx
632
750
  from keras.src.export import export_openvino
633
751
  from keras.src.export import export_saved_model
634
752
 
635
- available_formats = ("tf_saved_model", "onnx", "openvino")
753
+ available_formats = ("tf_saved_model", "onnx", "openvino", "litert")
636
754
  if format not in available_formats:
637
755
  raise ValueError(
638
756
  f"Unrecognized format={format}. Supported formats are: "
639
757
  f"{list(available_formats)}."
640
758
  )
641
759
 
760
+ # Check if LiteRT export is available (requires TensorFlow backend)
761
+ if format == "litert" and backend.backend() != "tensorflow":
762
+ raise ImportError("LiteRT export requires TensorFlow backend.")
763
+
642
764
  if format == "tf_saved_model":
643
765
  export_saved_model(
644
766
  self,
@@ -663,6 +785,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
663
785
  input_signature=input_signature,
664
786
  **kwargs,
665
787
  )
788
+ elif format == "litert":
789
+ export_litert(
790
+ self,
791
+ filepath,
792
+ input_signature=input_signature,
793
+ **kwargs,
794
+ )
666
795
 
667
796
  @classmethod
668
797
  def from_config(cls, config, custom_objects=None):
keras/src/ops/image.py CHANGED
@@ -565,6 +565,8 @@ class ExtractPatches(Operation):
565
565
  if isinstance(size, int):
566
566
  size = (size, size)
567
567
  self.size = size
568
+ if strides is None:
569
+ strides = size
568
570
  self.strides = strides
569
571
  self.dilation_rate = dilation_rate
570
572
  self.padding = padding
@@ -583,8 +585,6 @@ class ExtractPatches(Operation):
583
585
  def compute_output_spec(self, images):
584
586
  images_shape = list(images.shape)
585
587
  original_ndim = len(images_shape)
586
- if not self.strides:
587
- strides = (self.size[0], self.size[1])
588
588
  if self.data_format == "channels_last":
589
589
  channels_in = images_shape[-1]
590
590
  else:
@@ -597,7 +597,7 @@ class ExtractPatches(Operation):
597
597
  images_shape,
598
598
  filters,
599
599
  kernel_size,
600
- strides=strides,
600
+ strides=self.strides,
601
601
  padding=self.padding,
602
602
  data_format=self.data_format,
603
603
  dilation_rate=self.dilation_rate,
@@ -712,6 +712,187 @@ def _extract_patches(
712
712
  return patches
713
713
 
714
714
 
715
+ class ExtractPatches3D(Operation):
716
+ def __init__(
717
+ self,
718
+ size,
719
+ strides=None,
720
+ dilation_rate=1,
721
+ padding="valid",
722
+ data_format=None,
723
+ *,
724
+ name=None,
725
+ ):
726
+ super().__init__(name=name)
727
+ if isinstance(size, int):
728
+ size = (size, size, size)
729
+ elif len(size) != 3:
730
+ raise TypeError(
731
+ "Invalid `size` argument. Expected an "
732
+ f"int or a tuple of length 3. Received: size={size}"
733
+ )
734
+ self.size = size
735
+ if strides is not None:
736
+ if isinstance(strides, int):
737
+ strides = (strides, strides, strides)
738
+ elif len(strides) != 3:
739
+ raise ValueError(f"Invalid `strides` argument. Got: {strides}")
740
+ else:
741
+ strides = size
742
+ self.strides = strides
743
+ self.dilation_rate = dilation_rate
744
+ self.padding = padding
745
+ self.data_format = backend.standardize_data_format(data_format)
746
+
747
+ def call(self, volumes):
748
+ return _extract_patches_3d(
749
+ volumes,
750
+ self.size,
751
+ self.strides,
752
+ self.dilation_rate,
753
+ self.padding,
754
+ self.data_format,
755
+ )
756
+
757
+ def compute_output_spec(self, volumes):
758
+ volumes_shape = list(volumes.shape)
759
+ original_ndim = len(volumes_shape)
760
+ strides = self.strides
761
+ if self.data_format == "channels_last":
762
+ channels_in = volumes_shape[-1]
763
+ else:
764
+ channels_in = volumes_shape[-4]
765
+ if original_ndim == 4:
766
+ volumes_shape = [1] + volumes_shape
767
+ filters = self.size[0] * self.size[1] * self.size[2] * channels_in
768
+ kernel_size = (self.size[0], self.size[1], self.size[2])
769
+ out_shape = compute_conv_output_shape(
770
+ volumes_shape,
771
+ filters,
772
+ kernel_size,
773
+ strides=strides,
774
+ padding=self.padding,
775
+ data_format=self.data_format,
776
+ dilation_rate=self.dilation_rate,
777
+ )
778
+ if original_ndim == 4:
779
+ out_shape = out_shape[1:]
780
+ return KerasTensor(shape=out_shape, dtype=volumes.dtype)
781
+
782
+
783
+ def _extract_patches_3d(
784
+ volumes,
785
+ size,
786
+ strides=None,
787
+ dilation_rate=1,
788
+ padding="valid",
789
+ data_format=None,
790
+ ):
791
+ if isinstance(size, int):
792
+ patch_d = patch_h = patch_w = size
793
+ elif len(size) == 3:
794
+ patch_d, patch_h, patch_w = size
795
+ else:
796
+ raise TypeError(
797
+ "Invalid `size` argument. Expected an "
798
+ f"int or a tuple of length 3. Received: size={size}"
799
+ )
800
+ if strides is None:
801
+ strides = size
802
+ if isinstance(strides, int):
803
+ strides = (strides, strides, strides)
804
+ if len(strides) != 3:
805
+ raise ValueError(f"Invalid `strides` argument. Got: {strides}")
806
+ data_format = backend.standardize_data_format(data_format)
807
+ if data_format == "channels_last":
808
+ channels_in = volumes.shape[-1]
809
+ elif data_format == "channels_first":
810
+ channels_in = volumes.shape[-4]
811
+ out_dim = patch_d * patch_w * patch_h * channels_in
812
+ kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
813
+ kernel = backend.numpy.reshape(
814
+ kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
815
+ )
816
+ _unbatched = False
817
+ if len(volumes.shape) == 4:
818
+ _unbatched = True
819
+ volumes = backend.numpy.expand_dims(volumes, axis=0)
820
+ patches = backend.nn.conv(
821
+ inputs=volumes,
822
+ kernel=kernel,
823
+ strides=strides,
824
+ padding=padding,
825
+ data_format=data_format,
826
+ dilation_rate=dilation_rate,
827
+ )
828
+ if _unbatched:
829
+ patches = backend.numpy.squeeze(patches, axis=0)
830
+ return patches
831
+
832
+
833
+ @keras_export("keras.ops.image.extract_patches_3d")
834
+ def extract_patches_3d(
835
+ volumes,
836
+ size,
837
+ strides=None,
838
+ dilation_rate=1,
839
+ padding="valid",
840
+ data_format=None,
841
+ ):
842
+ """Extracts patches from the volume(s).
843
+
844
+ Args:
845
+ volumes: Input volume or batch of volumes. Must be 4D or 5D.
846
+ size: Patch size int or tuple (patch_depth, patch_height, patch_width)
847
+ strides: strides along depth, height, and width. If not specified, or
848
+ if `None`, it defaults to the same value as `size`.
849
+ dilation_rate: This is the input stride, specifying how far two
850
+ consecutive patch samples are in the input. Note that using
851
+ `dilation_rate > 1` is not supported in conjunction with
852
+ `strides > 1` on the TensorFlow backend.
853
+ padding: The type of padding algorithm to use: `"same"` or `"valid"`.
854
+ data_format: A string specifying the data format of the input tensor.
855
+ It can be either `"channels_last"` or `"channels_first"`.
856
+ `"channels_last"` corresponds to inputs with shape
857
+ `(batch, depth, height, width, channels)`, while `"channels_first"`
858
+ corresponds to inputs with shape
859
+ `(batch, channels, depth, height, width)`. If not specified,
860
+ the value will default to `keras.config.image_data_format()`.
861
+
862
+ Returns:
863
+ Extracted patches 4D (if not batched) or 5D (if batched)
864
+
865
+ Examples:
866
+
867
+ >>> import numpy as np
868
+ >>> import keras
869
+ >>> # Batched case
870
+ >>> volumes = np.random.random(
871
+ ... (2, 10, 10, 10, 3)
872
+ ... ).astype("float32") # batch of 2 volumes
873
+ >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
874
+ >>> patches.shape
875
+ (2, 3, 3, 3, 81)
876
+ >>> # Unbatched case
877
+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
878
+ >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
879
+ >>> patches.shape
880
+ (3, 3, 3, 81)
881
+ """
882
+ if any_symbolic_tensors((volumes,)):
883
+ return ExtractPatches3D(
884
+ size=size,
885
+ strides=strides,
886
+ dilation_rate=dilation_rate,
887
+ padding=padding,
888
+ data_format=data_format,
889
+ ).symbolic_call(volumes)
890
+
891
+ return _extract_patches_3d(
892
+ volumes, size, strides, dilation_rate, padding, data_format=data_format
893
+ )
894
+
895
+
715
896
  class MapCoordinates(Operation):
716
897
  def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
717
898
  super().__init__(name=name)
keras/src/ops/linalg.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from keras.src import backend
2
+ from keras.src import tree
2
3
  from keras.src.api_export import keras_export
3
4
  from keras.src.backend import KerasTensor
4
5
  from keras.src.backend import any_symbolic_tensors
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
732
733
  "Expected `a.shape[-1] == b.shape[-1]`. "
733
734
  f"Received: a.shape={a.shape}, b.shape={b.shape}"
734
735
  )
736
+
737
+
738
+ class JVP(Operation):
739
+ def __init__(self, has_aux=False, *, name=None):
740
+ super().__init__(name=name)
741
+ self.has_aux = has_aux
742
+
743
+ def call(self, fun, primals, tangents):
744
+ """Computes the JVP of `fun` at `primals` along `tangents`.
745
+
746
+ Args:
747
+ fun: A callable that takes tensors (or nested structures) as input
748
+ and returns a tensor (or nested structure) as output.
749
+ primals: Input tensors (or nested structures) at which the Jacobian
750
+ of `fun` is evaluated.
751
+ tangents: Tensors (or nested structures) representing the direction
752
+ vectors for the JVP. Must have the same structure as
753
+ `primals`.
754
+
755
+ Returns:
756
+ If `has_aux` is False:
757
+ A tuple (primals_out, tangents_out) where:
758
+ - primals_out: Output of `fun(*primals)`
759
+ - tangents_out: JVP of `fun` at `primals` along `tangents`
760
+ If `has_aux` is True:
761
+ A tuple (primals_out, tangents_out, aux) where:
762
+ - aux: Auxiliary data returned by `fun`
763
+ """
764
+ return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)
765
+
766
+ def compute_output_spec(self, fun, primals, tangents):
767
+ # Infer primal output spec
768
+ if self.has_aux:
769
+ primals_out_spec, aux_spec = backend.compute_output_spec(
770
+ fun, *primals
771
+ )
772
+ else:
773
+ primals_out_spec = backend.compute_output_spec(fun, *primals)
774
+
775
+ # Tangents output should match primals output in structure and shape
776
+ tangents_out_spec = tree.map_structure(
777
+ lambda x: KerasTensor(x.shape, x.dtype), primals_out_spec
778
+ )
779
+
780
+ if self.has_aux:
781
+ return primals_out_spec, tangents_out_spec, aux_spec
782
+ return primals_out_spec, tangents_out_spec
783
+
784
+
785
+ @keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"])
786
+ def jvp(fun, primals, tangents, has_aux=False):
787
+ """Computes a (forward-mode) Jacobian-vector product of `fun`.
788
+ Args:
789
+ fun: Function to be differentiated. Its arguments should be arrays,
790
+ scalars, or standard Python containers of arrays or scalars. It
791
+ should return an array, scalar, or standard Python container of
792
+ arrays or scalars.
793
+ primals: The primal values at which the Jacobian of `fun` should be
794
+ evaluated. Should be either a tuple or a list of arguments,
795
+ and its length should be equal to the number of positional
796
+ parameters of `fun`.
797
+ tangents: The tangent vector for which the Jacobian-vector product
798
+ should be evaluated. Should be either a tuple or a list of
799
+ tangents, with the same tree structure and array shapes as
800
+ `primals`.
801
+ has_aux: Optional, bool. Indicates whether `fun` returns a pair where
802
+ the first element is considered the output of the mathematical
803
+ function to be differentiated and the second element is
804
+ auxiliary data. Default is False.
805
+
806
+ Returns:
807
+ If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair,
808
+ where `primals_out` is `fun(*primals)`, and `tangents_out` is the
809
+ Jacobian-vector product of `fun` evaluated at `primals` with
810
+ `tangents`. The `tangents_out` value has the same Python tree
811
+ structure and shapes as `primals_out`.
812
+
813
+ If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`)
814
+ tuple where `aux` is the auxiliary data returned by `fun`.
815
+
816
+ Example:
817
+ >>> from keras import ops
818
+ >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)
819
+ >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))
820
+ >>> primals
821
+ 0.09983342
822
+ >>> tangents
823
+ 0.19900084
824
+ """
825
+ if any_symbolic_tensors((primals, tangents)):
826
+ return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents)
827
+ return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)