keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/models/model.py CHANGED
@@ -2,13 +2,16 @@ import inspect
2
2
  import json
3
3
  import typing
4
4
  import warnings
5
+ from collections.abc import Callable
5
6
 
6
7
  from keras.src import backend
7
8
  from keras.src import utils
8
9
  from keras.src.api_export import keras_export
9
10
  from keras.src.layers.layer import Layer
10
11
  from keras.src.models.variable_mapping import map_saveable_variables
11
- from keras.src.quantizers.gptq_config import GPTQConfig
12
+ from keras.src.quantizers.awq_core import awq_quantize
13
+ from keras.src.quantizers.gptq_core import gptq_quantize
14
+ from keras.src.quantizers.utils import should_quantize_layer
12
15
  from keras.src.saving import saving_api
13
16
  from keras.src.trainers import trainer as base_trainer
14
17
  from keras.src.utils import summary_utils
@@ -421,62 +424,168 @@ class Model(Trainer, base_trainer.Trainer, Layer):
421
424
  **kwargs,
422
425
  )
423
426
 
424
- def quantize(self, mode, config=None, **kwargs):
427
+ def get_quantization_layer_structure(self, mode=None):
428
+ """Returns the quantization structure for the model.
429
+
430
+ This method is intended to be overridden by model authors to provide
431
+ topology information required for structure-aware quantization modes
432
+ like 'gptq'.
433
+
434
+ Args:
435
+ mode: The quantization mode.
436
+
437
+ Returns:
438
+ A dictionary describing the topology, e.g.:
439
+ `{'pre_block_layers': [list], 'sequential_blocks': [list]}`
440
+ or `None` if the mode does not require structure or is not
441
+ supported. `'pre_block_layers'` is a list of layers that
442
+ the inputs should be passed through, before being passed to
443
+ the sequential blocks. For example, inputs to an LLM must
444
+ first be passed through an embedding layer, followed by
445
+ the transformer.
446
+ """
447
+ del mode # Unused.
448
+ return None
449
+
450
+ def quantize(self, mode=None, config=None, filters=None, **kwargs):
425
451
  """Quantize the weights of the model.
426
452
 
427
453
  Note that the model must be built first before calling this method.
428
- `quantize` will recursively call `quantize(mode)` in all layers and
454
+ `quantize` will recursively call `quantize(...)` in all layers and
429
455
  will be skipped if the layer doesn't implement the function.
430
456
 
457
+ This method can be called by passing a `mode` string, which uses the
458
+ default configuration for that mode. Alternatively, a `config` object
459
+ can be passed to customize the behavior of the quantization (e.g. to
460
+ use specific quantizers for weights or activations).
461
+
431
462
  Args:
432
- mode: The mode of the quantization. Only 'int8' is supported at this
433
- time.
434
- """
435
- from keras.src.dtype_policies import QUANTIZATION_MODES
463
+ mode: The mode of the quantization. Supported modes are:
464
+ `"int8"`, `"int4"`, `"float8"`, `"gptq"`. This is
465
+ optional if `config` is provided.
466
+ config: The configuration object specifying additional
467
+ quantization options. This argument allows you to configure
468
+ the weight and activation quantizers. Must be an instance of
469
+ `keras.quantizers.QuantizationConfig`.
470
+ filters: Optional filters to apply to the quantization. Can be a
471
+ regex string, a list of regex strings, or a callable. Only the
472
+ layers which match the filter conditions will be quantized.
473
+ **kwargs: Additional keyword arguments.
436
474
 
437
- if mode == "gptq":
438
- if not isinstance(config, GPTQConfig):
439
- raise ValueError(
440
- "The `config` argument must be of type "
441
- "`keras.quantizers.GPTQConfig`."
442
- )
443
- # The config object's own quantize method drives the process
444
- config.quantize(self)
445
- return
475
+ Example:
446
476
 
447
- # For all other modes, verify that a config object was not passed.
448
- if config is not None:
449
- raise ValueError(
450
- f"The `config` argument is only supported for 'gptq' mode, "
451
- f"but received mode='{mode}'."
452
- )
477
+ Quantize a model to int8 with default configuration:
478
+
479
+ ```python
480
+ # Build the model
481
+ model = keras.Sequential([
482
+ keras.Input(shape=(10,)),
483
+ keras.layers.Dense(10),
484
+ ])
485
+ model.build((None, 10))
453
486
 
487
+ # Quantize with default int8 config
488
+ model.quantize("int8")
489
+ ```
490
+
491
+ Quantize a model to int8 with a custom configuration:
492
+
493
+ ```python
494
+ from keras.quantizers import Int8QuantizationConfig
495
+ from keras.quantizers import AbsMaxQuantizer
496
+
497
+ # Build the model
498
+ model = keras.Sequential([
499
+ keras.Input(shape=(10,)),
500
+ keras.layers.Dense(10),
501
+ ])
502
+ model.build((None, 10))
503
+
504
+ # Create a custom config
505
+ config = Int8QuantizationConfig(
506
+ weight_quantizer=AbsMaxQuantizer(
507
+ axis=0,
508
+ value_range=(-127, 127)
509
+ ),
510
+ activation_quantizer=AbsMaxQuantizer(
511
+ axis=-1,
512
+ value_range=(-127, 127)
513
+ ),
514
+ )
515
+
516
+ # Quantize with custom config
517
+ model.quantize(config=config)
518
+ ```
519
+ """
520
+ # Validate inputs.
454
521
  type_check = kwargs.pop("type_check", True)
455
522
  if kwargs:
456
523
  raise ValueError(
457
524
  "Unrecognized keyword arguments "
458
525
  f"passed to {self.__class__.__name__}: {kwargs}"
459
526
  )
460
- if mode not in QUANTIZATION_MODES:
461
- raise ValueError(
462
- "Invalid quantization mode. "
463
- f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
464
- )
465
- mode_changed = False
527
+
528
+ if filters is not None:
529
+ if not isinstance(filters, (str, Callable, list, tuple)):
530
+ raise ValueError(
531
+ "The `filters` argument must be a regex string, a list of "
532
+ "regex strings, or a callable. Received: "
533
+ f"{type(filters)}"
534
+ )
535
+
536
+ graph_modified = False
466
537
  for layer in self._flatten_layers():
467
- list_of_sublayers = list(layer._flatten_layers())
468
- if len(list_of_sublayers) == 1: # leaves of the model
538
+ # Apply filters
539
+ if not should_quantize_layer(layer, filters):
540
+ continue
541
+
542
+ if len(list(layer._flatten_layers())) == 1:
469
543
  try:
470
- layer.quantize(mode, type_check=type_check)
471
- mode_changed = True
544
+ layer.quantize(mode, type_check=type_check, config=config)
545
+ graph_modified = True
472
546
  except NotImplementedError as e:
473
547
  warnings.warn(str(e))
474
- # We need to set these functions to `None` to remake them for changed
475
- # call function
476
- if mode_changed:
548
+ except AttributeError:
549
+ pass
550
+
551
+ if mode in ["gptq", "awq"]:
552
+ # Resolve model structure.
553
+ # 1. If quantization_layer_structure is provided inside the config,
554
+ # use that.
555
+ structure = config.quantization_layer_structure
556
+ # 2. If no layer structure is provided in the config, try to fetch
557
+ # it using the `get_quantization_layer_structure` hook.
558
+ if structure is None:
559
+ structure = self.get_quantization_layer_structure(mode)
560
+
561
+ if structure is None:
562
+ raise ValueError(
563
+ f"For {mode=}, a valid quantization structure must be "
564
+ "provided either via `config.quantization_layer_structure` "
565
+ "or by overriding "
566
+ "`model.get_quantization_layer_structure(mode)`. The "
567
+ "structure should be a dictionary with keys "
568
+ "'pre_block_layers' and 'sequential_blocks'."
569
+ )
570
+ if mode == "gptq":
571
+ gptq_quantize(config, structure, filters=filters)
572
+ elif mode == "awq":
573
+ awq_quantize(config, structure, filters=filters)
574
+
575
+ # If any layer was changed, we must rebuild the execution functions.
576
+ if graph_modified:
477
577
  self.train_function = None
478
578
  self.test_function = None
479
579
  self.predict_function = None
580
+ self._post_quantize(mode, **kwargs)
581
+
582
+ def _post_quantize(self, mode, **kwargs):
583
+ if backend.backend() == "torch":
584
+ # We need to manually retrack `torch_params`.
585
+ # The reason is that after quantization, the removed variables are
586
+ # still referenced by `torch_params` and cannot be gc.
587
+ for layer in self._flatten_layers():
588
+ layer._track_variables()
480
589
 
481
590
  def build_from_config(self, config):
482
591
  if not config:
@@ -556,8 +665,8 @@ class Model(Trainer, base_trainer.Trainer, Layer):
556
665
  filepath: `str` or `pathlib.Path` object. The path to save the
557
666
  artifact.
558
667
  format: `str`. The export format. Supported values:
559
- `"tf_saved_model"` and `"onnx"`. Defaults to
560
- `"tf_saved_model"`.
668
+ `"tf_saved_model"`, `"onnx"`, `"openvino"`, and `"litert"`.
669
+ Defaults to `"tf_saved_model"`.
561
670
  verbose: `bool`. Whether to print a message during export. Defaults
562
671
  to `None`, which uses the default value set by different
563
672
  backends and formats.
@@ -580,6 +689,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
580
689
  provided, they will be automatically computed.
581
690
  - `opset_version`: Optional `int`. Specific to `format="onnx"`.
582
691
  An integer value that specifies the ONNX opset version.
692
+ - LiteRT-specific options: Optional keyword arguments specific
693
+ to `format="litert"`. These are passed directly to the
694
+ TensorFlow Lite converter and include options like
695
+ `optimizations`, `representative_dataset`,
696
+ `experimental_new_quantizer`, `allow_custom_ops`,
697
+ `enable_select_tf_ops`, etc. See TensorFlow Lite
698
+ documentation for all available options.
583
699
 
584
700
  **Note:** This feature is currently supported only with TensorFlow, JAX
585
701
  and Torch backends.
@@ -614,18 +730,41 @@ class Model(Trainer, base_trainer.Trainer, Layer):
614
730
  }
615
731
  predictions = ort_session.run(None, ort_inputs)
616
732
  ```
733
+
734
+ Here's how to export a LiteRT (TFLite) for inference.
735
+
736
+ ```python
737
+ # Export the model as a LiteRT artifact
738
+ model.export("path/to/location", format="litert")
739
+
740
+ # Load the artifact in a different process/environment
741
+ interpreter = tf.lite.Interpreter(model_path="path/to/location")
742
+ interpreter.allocate_tensors()
743
+ interpreter.set_tensor(
744
+ interpreter.get_input_details()[0]['index'], input_data
745
+ )
746
+ interpreter.invoke()
747
+ output_data = interpreter.get_tensor(
748
+ interpreter.get_output_details()[0]['index']
749
+ )
750
+ ```
617
751
  """
752
+ from keras.src.export import export_litert
618
753
  from keras.src.export import export_onnx
619
754
  from keras.src.export import export_openvino
620
755
  from keras.src.export import export_saved_model
621
756
 
622
- available_formats = ("tf_saved_model", "onnx", "openvino")
757
+ available_formats = ("tf_saved_model", "onnx", "openvino", "litert")
623
758
  if format not in available_formats:
624
759
  raise ValueError(
625
760
  f"Unrecognized format={format}. Supported formats are: "
626
761
  f"{list(available_formats)}."
627
762
  )
628
763
 
764
+ # Check if LiteRT export is available (requires TensorFlow backend)
765
+ if format == "litert" and backend.backend() != "tensorflow":
766
+ raise ImportError("LiteRT export requires TensorFlow backend.")
767
+
629
768
  if format == "tf_saved_model":
630
769
  export_saved_model(
631
770
  self,
@@ -650,6 +789,13 @@ class Model(Trainer, base_trainer.Trainer, Layer):
650
789
  input_signature=input_signature,
651
790
  **kwargs,
652
791
  )
792
+ elif format == "litert":
793
+ export_litert(
794
+ self,
795
+ filepath,
796
+ input_signature=input_signature,
797
+ **kwargs,
798
+ )
653
799
 
654
800
  @classmethod
655
801
  def from_config(cls, config, custom_objects=None):
@@ -850,13 +996,18 @@ class Model(Trainer, base_trainer.Trainer, Layer):
850
996
  self.non_trainable_variables, path_value_dict
851
997
  )
852
998
  elif k == "optimizer_variables":
853
- self._assign_variable_values(
854
- self.optimizer.variables, path_value_dict
855
- )
999
+ if hasattr(self, "optimizer") and self.optimizer is not None:
1000
+ self._assign_variable_values(
1001
+ self.optimizer.variables, path_value_dict
1002
+ )
856
1003
  elif k == "metrics_variables":
857
- self._assign_variable_values(
858
- self.metrics_variables, path_value_dict
859
- )
1004
+ if (
1005
+ hasattr(self, "metrics_variables")
1006
+ and self.metrics_variables
1007
+ ):
1008
+ self._assign_variable_values(
1009
+ self.metrics_variables, path_value_dict
1010
+ )
860
1011
  else:
861
1012
  raise ValueError(f"Unknown variable name: {k}")
862
1013
 
keras/src/ops/image.py CHANGED
@@ -565,6 +565,8 @@ class ExtractPatches(Operation):
565
565
  if isinstance(size, int):
566
566
  size = (size, size)
567
567
  self.size = size
568
+ if strides is None:
569
+ strides = size
568
570
  self.strides = strides
569
571
  self.dilation_rate = dilation_rate
570
572
  self.padding = padding
@@ -583,8 +585,6 @@ class ExtractPatches(Operation):
583
585
  def compute_output_spec(self, images):
584
586
  images_shape = list(images.shape)
585
587
  original_ndim = len(images_shape)
586
- if not self.strides:
587
- strides = (self.size[0], self.size[1])
588
588
  if self.data_format == "channels_last":
589
589
  channels_in = images_shape[-1]
590
590
  else:
@@ -597,7 +597,7 @@ class ExtractPatches(Operation):
597
597
  images_shape,
598
598
  filters,
599
599
  kernel_size,
600
- strides=strides,
600
+ strides=self.strides,
601
601
  padding=self.padding,
602
602
  data_format=self.data_format,
603
603
  dilation_rate=self.dilation_rate,
@@ -616,42 +616,98 @@ def extract_patches(
616
616
  padding="valid",
617
617
  data_format=None,
618
618
  ):
619
- """Extracts patches from the image(s).
619
+ """Extracts patches from the image(s) or volume(s).
620
+
621
+ This function supports both 2D and 3D patch extraction based on the
622
+ `size` argument length, similar to how `keras.ops.conv` handles
623
+ different dimensions.
620
624
 
621
625
  Args:
622
- images: Input image or batch of images. Must be 3D or 4D.
623
- size: Patch size int or tuple (patch_height, patch_width)
624
- strides: strides along height and width. If not specified, or
625
- if `None`, it defaults to the same value as `size`.
626
- dilation_rate: This is the input stride, specifying how far two
627
- consecutive patch samples are in the input. For value other than 1,
628
- strides must be 1. NOTE: `strides > 1` is not supported in
629
- conjunction with `dilation_rate > 1`
626
+ images: Input image/volume or batch of images/volumes.
627
+ For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.
628
+ For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.
629
+ size: Patch size as int or tuple.
630
+ Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.
631
+ Length 3 tuple `(patch_depth, patch_height, patch_width)` for
632
+ 3D patches.
633
+ strides: Strides for patch extraction. If not specified, defaults
634
+ to `size` (non-overlapping patches).
635
+ dilation_rate: Dilation rate for patch extraction. Note that
636
+ `dilation_rate > 1` is not supported with `strides > 1`.
630
637
  padding: The type of padding algorithm to use: `"same"` or `"valid"`.
631
638
  data_format: A string specifying the data format of the input tensor.
632
639
  It can be either `"channels_last"` or `"channels_first"`.
633
- `"channels_last"` corresponds to inputs with shape
634
- `(batch, height, width, channels)`, while `"channels_first"`
635
- corresponds to inputs with shape `(batch, channels, height, width)`.
636
- If not specified, the value will default to
637
- `keras.config.image_data_format`.
640
+ If not specified, defaults to `keras.config.image_data_format`.
638
641
 
639
642
  Returns:
640
- Extracted patches 3D (if not batched) or 4D (if batched)
643
+ Extracted patches with shape depending on input and `size`:
644
+ - 2D patches: 3D (unbatched) or 4D (batched)
645
+ - 3D patches: 4D (unbatched) or 5D (batched)
641
646
 
642
647
  Examples:
643
648
 
649
+ >>> # 2D patches from batch of images
644
650
  >>> image = np.random.random(
645
651
  ... (2, 20, 20, 3)
646
- ... ).astype("float32") # batch of 2 RGB images
652
+ ... ).astype("float32")
647
653
  >>> patches = keras.ops.image.extract_patches(image, (5, 5))
648
654
  >>> patches.shape
649
655
  (2, 4, 4, 75)
650
- >>> image = np.random.random((20, 20, 3)).astype("float32") # 1 RGB image
656
+
657
+ >>> # 2D patches from single image
658
+ >>> image = np.random.random((20, 20, 3)).astype("float32")
651
659
  >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))
652
660
  >>> patches.shape
653
661
  (18, 18, 27)
662
+
663
+ >>> # 3D patches from batch of volumes
664
+ >>> volumes = np.random.random(
665
+ ... (2, 10, 10, 10, 3)
666
+ ... ).astype("float32")
667
+ >>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))
668
+ >>> patches.shape
669
+ (2, 3, 3, 3, 81)
670
+
671
+ >>> # 3D patches from single volume
672
+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32")
673
+ >>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))
674
+ >>> patches.shape
675
+ (3, 3, 3, 81)
654
676
  """
677
+ # Validate size argument
678
+ if not isinstance(size, int):
679
+ if not isinstance(size, (tuple, list)):
680
+ raise TypeError(
681
+ "Invalid `size` argument. Expected an int or a tuple. "
682
+ f"Received: size={size} of type {type(size).__name__}"
683
+ )
684
+ if len(size) not in (2, 3):
685
+ raise ValueError(
686
+ "Invalid `size` argument. Expected a tuple of length 2 or 3. "
687
+ f"Received: size={size} with length {len(size)}"
688
+ )
689
+
690
+ # Determine 2D vs 3D based on size argument
691
+ if not isinstance(size, int) and len(size) == 3:
692
+ # 3D patch extraction
693
+ if any_symbolic_tensors((images,)):
694
+ return ExtractPatches3D(
695
+ size=size,
696
+ strides=strides,
697
+ dilation_rate=dilation_rate,
698
+ padding=padding,
699
+ data_format=data_format,
700
+ ).symbolic_call(images)
701
+ return _extract_patches_3d(
702
+ images,
703
+ size,
704
+ strides,
705
+ dilation_rate,
706
+ padding,
707
+ data_format=data_format,
708
+ )
709
+
710
+ # 2D patch extraction (default)
655
711
  if any_symbolic_tensors((images,)):
656
712
  return ExtractPatches(
657
713
  size=size,
@@ -712,6 +768,187 @@ def _extract_patches(
712
768
  return patches
713
769
 
714
770
 
771
+ class ExtractPatches3D(Operation):
772
+ def __init__(
773
+ self,
774
+ size,
775
+ strides=None,
776
+ dilation_rate=1,
777
+ padding="valid",
778
+ data_format=None,
779
+ *,
780
+ name=None,
781
+ ):
782
+ super().__init__(name=name)
783
+ if isinstance(size, int):
784
+ size = (size, size, size)
785
+ elif len(size) != 3:
786
+ raise TypeError(
787
+ "Invalid `size` argument. Expected an "
788
+ f"int or a tuple of length 3. Received: size={size}"
789
+ )
790
+ self.size = size
791
+ if strides is not None:
792
+ if isinstance(strides, int):
793
+ strides = (strides, strides, strides)
794
+ elif len(strides) != 3:
795
+ raise ValueError(f"Invalid `strides` argument. Got: {strides}")
796
+ else:
797
+ strides = size
798
+ self.strides = strides
799
+ self.dilation_rate = dilation_rate
800
+ self.padding = padding
801
+ self.data_format = backend.standardize_data_format(data_format)
802
+
803
+ def call(self, volumes):
804
+ return _extract_patches_3d(
805
+ volumes,
806
+ self.size,
807
+ self.strides,
808
+ self.dilation_rate,
809
+ self.padding,
810
+ self.data_format,
811
+ )
812
+
813
+ def compute_output_spec(self, volumes):
814
+ volumes_shape = list(volumes.shape)
815
+ original_ndim = len(volumes_shape)
816
+ strides = self.strides
817
+ if self.data_format == "channels_last":
818
+ channels_in = volumes_shape[-1]
819
+ else:
820
+ channels_in = volumes_shape[-4]
821
+ if original_ndim == 4:
822
+ volumes_shape = [1] + volumes_shape
823
+ filters = self.size[0] * self.size[1] * self.size[2] * channels_in
824
+ kernel_size = (self.size[0], self.size[1], self.size[2])
825
+ out_shape = compute_conv_output_shape(
826
+ volumes_shape,
827
+ filters,
828
+ kernel_size,
829
+ strides=strides,
830
+ padding=self.padding,
831
+ data_format=self.data_format,
832
+ dilation_rate=self.dilation_rate,
833
+ )
834
+ if original_ndim == 4:
835
+ out_shape = out_shape[1:]
836
+ return KerasTensor(shape=out_shape, dtype=volumes.dtype)
837
+
838
+
839
+ def _extract_patches_3d(
840
+ volumes,
841
+ size,
842
+ strides=None,
843
+ dilation_rate=1,
844
+ padding="valid",
845
+ data_format=None,
846
+ ):
847
+ if isinstance(size, int):
848
+ patch_d = patch_h = patch_w = size
849
+ elif len(size) == 3:
850
+ patch_d, patch_h, patch_w = size
851
+ else:
852
+ raise TypeError(
853
+ "Invalid `size` argument. Expected an "
854
+ f"int or a tuple of length 3. Received: size={size}"
855
+ )
856
+ if strides is None:
857
+ strides = size
858
+ if isinstance(strides, int):
859
+ strides = (strides, strides, strides)
860
+ if len(strides) != 3:
861
+ raise ValueError(f"Invalid `strides` argument. Got: {strides}")
862
+ data_format = backend.standardize_data_format(data_format)
863
+ if data_format == "channels_last":
864
+ channels_in = volumes.shape[-1]
865
+ elif data_format == "channels_first":
866
+ channels_in = volumes.shape[-4]
867
+ out_dim = patch_d * patch_w * patch_h * channels_in
868
+ kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)
869
+ kernel = backend.numpy.reshape(
870
+ kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)
871
+ )
872
+ _unbatched = False
873
+ if len(volumes.shape) == 4:
874
+ _unbatched = True
875
+ volumes = backend.numpy.expand_dims(volumes, axis=0)
876
+ patches = backend.nn.conv(
877
+ inputs=volumes,
878
+ kernel=kernel,
879
+ strides=strides,
880
+ padding=padding,
881
+ data_format=data_format,
882
+ dilation_rate=dilation_rate,
883
+ )
884
+ if _unbatched:
885
+ patches = backend.numpy.squeeze(patches, axis=0)
886
+ return patches
887
+
888
+
889
+ @keras_export("keras.ops.image.extract_patches_3d")
890
+ def extract_patches_3d(
891
+ volumes,
892
+ size,
893
+ strides=None,
894
+ dilation_rate=1,
895
+ padding="valid",
896
+ data_format=None,
897
+ ):
898
+ """Extracts patches from the volume(s).
899
+
900
+ Args:
901
+ volumes: Input volume or batch of volumes. Must be 4D or 5D.
902
+ size: Patch size int or tuple (patch_depth, patch_height, patch_width)
903
+ strides: strides along depth, height, and width. If not specified, or
904
+ if `None`, it defaults to the same value as `size`.
905
+ dilation_rate: This is the input stride, specifying how far two
906
+ consecutive patch samples are in the input. Note that using
907
+ `dilation_rate > 1` is not supported in conjunction with
908
+ `strides > 1` on the TensorFlow backend.
909
+ padding: The type of padding algorithm to use: `"same"` or `"valid"`.
910
+ data_format: A string specifying the data format of the input tensor.
911
+ It can be either `"channels_last"` or `"channels_first"`.
912
+ `"channels_last"` corresponds to inputs with shape
913
+ `(batch, depth, height, width, channels)`, while `"channels_first"`
914
+ corresponds to inputs with shape
915
+ `(batch, channels, depth, height, width)`. If not specified,
916
+ the value will default to `keras.config.image_data_format()`.
917
+
918
+ Returns:
919
+ Extracted patches 4D (if not batched) or 5D (if batched)
920
+
921
+ Examples:
922
+
923
+ >>> import numpy as np
924
+ >>> import keras
925
+ >>> # Batched case
926
+ >>> volumes = np.random.random(
927
+ ... (2, 10, 10, 10, 3)
928
+ ... ).astype("float32") # batch of 2 volumes
929
+ >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))
930
+ >>> patches.shape
931
+ (2, 3, 3, 3, 81)
932
+ >>> # Unbatched case
933
+ >>> volume = np.random.random((10, 10, 10, 3)).astype("float32") # 1 volume
934
+ >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))
935
+ >>> patches.shape
936
+ (3, 3, 3, 81)
937
+ """
938
+ if any_symbolic_tensors((volumes,)):
939
+ return ExtractPatches3D(
940
+ size=size,
941
+ strides=strides,
942
+ dilation_rate=dilation_rate,
943
+ padding=padding,
944
+ data_format=data_format,
945
+ ).symbolic_call(volumes)
946
+
947
+ return _extract_patches_3d(
948
+ volumes, size, strides, dilation_rate, padding, data_format=data_format
949
+ )
950
+
951
+
715
952
  class MapCoordinates(Operation):
716
953
  def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
717
954
  super().__init__(name=name)