keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/ops/linalg.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from keras.src import backend
2
+ from keras.src import tree
2
3
  from keras.src.api_export import keras_export
3
4
  from keras.src.backend import KerasTensor
4
5
  from keras.src.backend import any_symbolic_tensors
@@ -732,3 +733,95 @@ def _assert_a_b_compat(a, b):
732
733
  "Expected `a.shape[-1] == b.shape[-1]`. "
733
734
  f"Received: a.shape={a.shape}, b.shape={b.shape}"
734
735
  )
736
+
737
+
738
class JVP(Operation):
    """Symbolic wrapper for forward-mode Jacobian-vector products."""

    def __init__(self, has_aux=False, *, name=None):
        super().__init__(name=name)
        self.has_aux = has_aux

    def call(self, fun, primals, tangents):
        """Evaluate `fun` at `primals` and its JVP along `tangents`.

        Args:
            fun: A callable taking tensors (or nested structures) and
                returning a tensor (or nested structure).
            primals: Input tensors (or nested structures) at which the
                Jacobian of `fun` is evaluated.
            tangents: Direction vectors for the JVP, with the same
                structure as `primals`.

        Returns:
            `(primals_out, tangents_out)` when `has_aux` is `False`, where
            `primals_out` is `fun(*primals)` and `tangents_out` is the JVP
            of `fun` at `primals` along `tangents`. When `has_aux` is
            `True`, `(primals_out, tangents_out, aux)` where `aux` is the
            auxiliary data returned by `fun`.
        """
        return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)

    def compute_output_spec(self, fun, primals, tangents):
        # Trace `fun` symbolically to obtain the primal output spec (and
        # the auxiliary spec when `fun` returns an `(output, aux)` pair).
        if self.has_aux:
            primals_out, aux_out = backend.compute_output_spec(fun, *primals)
        else:
            primals_out = backend.compute_output_spec(fun, *primals)
            aux_out = None

        # Tangent outputs mirror the primal outputs in structure, shape
        # and dtype, so they are derived rather than traced again.
        tangents_out = tree.map_structure(
            lambda t: KerasTensor(t.shape, t.dtype), primals_out
        )

        if self.has_aux:
            return primals_out, tangents_out, aux_out
        return primals_out, tangents_out
783
+
784
+
785
@keras_export(["keras.ops.jvp", "keras.ops.linalg.jvp"])
def jvp(fun, primals, tangents, has_aux=False):
    """Computes a (forward-mode) Jacobian-vector product of `fun`.

    Args:
        fun: Function to be differentiated. Its arguments should be arrays,
            scalars, or standard Python containers of arrays or scalars,
            and it should return the same kinds of values.
        primals: The primal values at which the Jacobian of `fun` should be
            evaluated. Either a tuple or a list of arguments, with length
            equal to the number of positional parameters of `fun`.
        tangents: The tangent vectors for which the Jacobian-vector product
            should be evaluated. Either a tuple or a list of tangents, with
            the same tree structure and array shapes as `primals`.
        has_aux: Optional, bool. Indicates whether `fun` returns a pair
            where the first element is considered the output of the
            mathematical function to be differentiated and the second
            element is auxiliary data. Default is False.

    Returns:
        If `has_aux` is False, a (`primals_out`, `tangents_out`) pair,
        where `primals_out` is `fun(*primals)`, and `tangents_out` is the
        Jacobian-vector product of `fun` evaluated at `primals` with
        `tangents`. The `tangents_out` value has the same Python tree
        structure and shapes as `primals_out`. If `has_aux` is True, a
        (`primals_out`, `tangents_out`, `aux`) tuple where `aux` is the
        auxiliary data returned by `fun`.

    Example:
    >>> from keras import ops
    >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)
    >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))
    >>> primals
    0.09983342
    >>> tangents
    0.19900084
    """
    # Defer to the symbolic op when tracing; otherwise dispatch straight
    # to the active backend implementation.
    if any_symbolic_tensors((primals, tangents)):
        op = JVP(has_aux=has_aux)
        return op.symbolic_call(fun, primals, tangents)
    return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)
keras/src/ops/nn.py CHANGED
@@ -6,6 +6,7 @@ from keras.src import backend
6
6
  from keras.src.api_export import keras_export
7
7
  from keras.src.backend import KerasTensor
8
8
  from keras.src.backend import any_symbolic_tensors
9
+ from keras.src.backend import config
9
10
  from keras.src.backend import standardize_data_format
10
11
  from keras.src.backend.common.backend_utils import (
11
12
  compute_conv_transpose_output_shape,
@@ -704,7 +705,15 @@ class Glu(Operation):
704
705
  return backend.nn.glu(x, axis=self.axis)
705
706
 
706
707
  def compute_output_spec(self, x):
707
- return KerasTensor(x.shape, dtype=x.dtype)
708
+ output_shape = list(x.shape)
709
+ if output_shape[self.axis] is not None:
710
+ if output_shape[self.axis] % 2 != 0:
711
+ raise ValueError(
712
+ "axis size must be divisible by 2. "
713
+ f"Received: x.shape={x.shape} with axis={self.axis}"
714
+ )
715
+ output_shape[self.axis] = output_shape[self.axis] // 2
716
+ return KerasTensor(output_shape, dtype=x.dtype)
708
717
 
709
718
 
710
719
  @keras_export(["keras.ops.glu", "keras.ops.nn.glu"])
@@ -1154,6 +1163,87 @@ def max_pool(
1154
1163
  return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format)
1155
1164
 
1156
1165
 
1166
class AdaptiveMaxPool(Operation):
    """Adaptive max pooling operation.

    Pools `inputs` down to the spatial dimensions given by `output_size`,
    deferring the actual computation to the active backend.
    """

    def __init__(self, output_size, data_format=None, *, name=None):
        super().__init__(name=name)
        self.output_size = output_size
        self.data_format = data_format

    def call(self, inputs):
        return backend.nn.adaptive_max_pool(
            inputs, output_size=self.output_size, data_format=self.data_format
        )

    def compute_output_spec(self, inputs):
        # `output_size` may be a single int (same size for every spatial
        # dimension, per the public op's contract) or a tuple/list of
        # per-dimension sizes. Normalize to a tuple so the shape arithmetic
        # below works in both cases instead of raising a TypeError on ints.
        if isinstance(self.output_size, int):
            spatial_dims = (self.output_size,) * (len(inputs.shape) - 2)
        else:
            spatial_dims = tuple(self.output_size)
        if self.data_format == "channels_last":
            # Strip the trailing spatial axes AND the channel axis (hence
            # `len(spatial_dims) + 1`, not `len(spatial_dims)`), then append
            # the new spatial dims followed by the unchanged channel axis.
            # Slicing off only `len(spatial_dims)` axes would leave a stale
            # spatial axis behind and yield a rank-5 shape for rank-4 input.
            output_shape = (
                inputs.shape[: -(len(spatial_dims) + 1)]
                + spatial_dims
                + (inputs.shape[-1],)
            )
        else:
            # channels_first: keep (batch, channels), replace spatial dims.
            output_shape = (inputs.shape[0], inputs.shape[1]) + spatial_dims
        return backend.KerasTensor(output_shape, dtype=inputs.dtype)
1191
+
1192
+
1193
@keras_export(["keras.ops.adaptive_max_pool", "keras.ops.nn.adaptive_max_pool"])
def adaptive_max_pool(
    inputs,
    output_size,
    data_format=None,
):
    """Adaptive max pooling operation.

    Automatically derives the pooling kernel size and stride needed to
    reduce the input to the requested `output_size`, so the output spatial
    shape is fixed regardless of the input's spatial shape. Commonly used
    in models like ResNet for global feature extraction.

    Args:
        inputs: Tensor of rank 4. Input tensor of shape:
            - If `data_format="channels_last"`:
                `(batch_size, height, width, channels)`.
            - If `data_format="channels_first"`:
                `(batch_size, channels, height, width)`.
        output_size: Integer or tuple/list of 2 integers, specifying the
            target output spatial dimensions `(output_height, output_width)`.
            If a single integer is provided, the same value is used for both
            dimensions.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Defaults to the value found in your Keras config file at
            `~/.keras/keras.json`. If never set, defaults to
            `"channels_last"`.

    Returns:
        A tensor of rank 4 representing the adaptive max pooled result.

    Example:

    >>> x = np.random.rand(2, 64, 64, 3)
    >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32))
    >>> y.shape
    (2, 32, 32, 3)

    >>> # Works with any input size
    >>> x = np.random.rand(2, 100, 80, 3)
    >>> y = keras.ops.adaptive_max_pool(x, output_size=7)
    >>> y.shape
    (2, 7, 7, 3)
    """
    # Resolve the data format once so eager and symbolic paths agree.
    if data_format is None:
        data_format = config.image_data_format()

    if any_symbolic_tensors((inputs,)):
        op = AdaptiveMaxPool(output_size, data_format)
        return op.symbolic_call(inputs)

    return backend.nn.adaptive_max_pool(
        inputs, output_size=output_size, data_format=data_format
    )
1245
+
1246
+
1157
1247
  class AveragePool(Operation):
1158
1248
  def __init__(
1159
1249
  self,
@@ -1249,6 +1339,92 @@ def average_pool(
1249
1339
  )
1250
1340
 
1251
1341
 
1342
class AdaptiveAveragePool(Operation):
    """Adaptive average pooling operation.

    Pools `inputs` down to the spatial dimensions given by `output_size`,
    deferring the actual computation to the active backend.
    """

    def __init__(self, output_size, data_format=None, *, name=None):
        super().__init__(name=name)
        self.output_size = output_size
        self.data_format = data_format

    def call(self, inputs):
        return backend.nn.adaptive_average_pool(
            inputs, output_size=self.output_size, data_format=self.data_format
        )

    def compute_output_spec(self, inputs):
        # `output_size` may be a single int (same size for every spatial
        # dimension, per the public op's contract) or a tuple/list of
        # per-dimension sizes. Normalize to a tuple so the shape arithmetic
        # below works in both cases instead of raising a TypeError on ints.
        if isinstance(self.output_size, int):
            spatial_dims = (self.output_size,) * (len(inputs.shape) - 2)
        else:
            spatial_dims = tuple(self.output_size)
        if self.data_format == "channels_last":
            # Strip the trailing spatial axes AND the channel axis (hence
            # `len(spatial_dims) + 1`, not `len(spatial_dims)`), then append
            # the new spatial dims followed by the unchanged channel axis.
            # Slicing off only `len(spatial_dims)` axes would leave a stale
            # spatial axis behind and yield a rank-5 shape for rank-4 input.
            output_shape = (
                inputs.shape[: -(len(spatial_dims) + 1)]
                + spatial_dims
                + (inputs.shape[-1],)
            )
        else:
            # channels_first: keep (batch, channels), replace spatial dims.
            output_shape = (inputs.shape[0], inputs.shape[1]) + spatial_dims
        return backend.KerasTensor(output_shape, dtype=inputs.dtype)
1367
+
1368
+
1369
@keras_export(
    ["keras.ops.adaptive_average_pool", "keras.ops.nn.adaptive_average_pool"]
)
def adaptive_average_pool(
    inputs,
    output_size,
    data_format=None,
):
    """Adaptive average pooling operation.

    Automatically derives the pooling kernel size and stride needed to
    reduce the input to the requested `output_size`, so the output spatial
    shape is fixed regardless of the input's spatial shape. Commonly used
    in models like ResNet for global feature extraction.

    Args:
        inputs: Tensor of rank 4. Input tensor of shape:
            - If `data_format="channels_last"`:
                `(batch_size, height, width, channels)`.
            - If `data_format="channels_first"`:
                `(batch_size, channels, height, width)`.
        output_size: Integer or tuple/list of 2 integers, specifying the
            target output spatial dimensions `(output_height, output_width)`.
            If a single integer is provided, the same value is used for both
            dimensions.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            Defaults to the value found in your Keras config file at
            `~/.keras/keras.json`. If never set, defaults to
            `"channels_last"`.

    Returns:
        A tensor of rank 4 representing the adaptive average pooled result.

    Example:

    >>> x = np.random.rand(2, 64, 64, 3)
    >>> y = keras.ops.adaptive_average_pool(x, output_size=(32, 32))
    >>> y.shape
    (2, 32, 32, 3)

    >>> # Works with any input size
    >>> x = np.random.rand(2, 100, 80, 3)
    >>> y = keras.ops.adaptive_average_pool(x, output_size=7)
    >>> y.shape
    (2, 7, 7, 3)
    """
    # Resolve the data format once so eager and symbolic paths agree.
    if data_format is None:
        data_format = config.image_data_format()

    if any_symbolic_tensors((inputs,)):
        op = AdaptiveAveragePool(output_size, data_format)
        return op.symbolic_call(inputs)

    return backend.nn.adaptive_average_pool(
        inputs, output_size=output_size, data_format=data_format
    )
1426
+
1427
+
1252
1428
  class Conv(Operation):
1253
1429
  def __init__(
1254
1430
  self,
@@ -1435,7 +1611,7 @@ def depthwise_conv(
1435
1611
  """
1436
1612
  data_format = standardize_data_format(data_format)
1437
1613
  padding = padding.lower()
1438
- if any_symbolic_tensors((inputs,)):
1614
+ if any_symbolic_tensors((inputs, kernel)):
1439
1615
  return DepthwiseConv(
1440
1616
  strides, padding, data_format, dilation_rate
1441
1617
  ).symbolic_call(inputs, kernel)
@@ -3047,3 +3223,93 @@ def _polar(abs_, angle):
3047
3223
  result = backend.math._get_complex_tensor_from_tuple((real, imaginary))
3048
3224
 
3049
3225
  return result
3226
+
3227
+
3228
class Unfold(Operation):
    """Symbolic op for `unfold` (im2col sliding-block extraction)."""

    def __init__(
        self, kernel_size, dilation=1, padding=0, stride=1, *, name=None
    ):
        super().__init__(name=name)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def compute_output_spec(self, x):
        batch, channels, height, width = x.shape

        def as_pair(value):
            # Accept either a scalar int or an explicit (h, w) pair.
            return (value, value) if isinstance(value, int) else value

        kernel_h, kernel_w = as_pair(self.kernel_size)
        dilation_h, dilation_w = as_pair(self.dilation)
        pad_h, pad_w = as_pair(self.padding)
        stride_h, stride_w = as_pair(self.stride)

        def num_windows(size, kernel, dilation, pad, stride):
            # Standard convolution-style output-length formula.
            return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

        out_h = num_windows(height, kernel_h, dilation_h, pad_h, stride_h)
        out_w = num_windows(width, kernel_w, dilation_w, pad_w, stride_w)
        # (N, C * kH * kW, L): each column is one flattened patch.
        return KerasTensor(
            shape=(batch, channels * kernel_h * kernel_w, out_h * out_w),
            dtype=x.dtype,
        )

    def call(self, x):
        return _unfold(
            x, self.kernel_size, self.dilation, self.padding, self.stride
        )
3260
+
3261
+
3262
@keras_export(["keras.ops.unfold", "keras.ops.nn.unfold"])
def unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Extract sliding local blocks from a 4-D input (batched image).

    This operation is known as **im2col** when used with convolution.
    It rearranges the image into overlapping or non-overlapping patches
    and returns a tensor whose *depth* (last axis) contains the flattened
    patches.

    Args:
        x: A 4-D tensor of shape `(N, C, H, W)` (**channels-first** format).
        kernel_size: int or tuple of two ints, the size of the sliding
            window `(kH, kW)`. If a single int is given, it is used for
            both dimensions.
        dilation: int or tuple of two ints, the spacing between kernel
            points (a.k.a. **dilation** or **atrous** convolution).
            Default: 1.
        padding: int or tuple of two ints, the amount of zero-padding to
            apply to both spatial dimensions. Default: 0.
        stride: int or tuple of two ints, the step size of the sliding
            window. Default: 1.

    Returns:
        A 3-D tensor of shape `(N, C * kH * kW, L)` where
        `L = num_patches_H * num_patches_W` is the total number of patches
        extracted.

    Example:

    >>> x = keras.ops.ones((1, 2, 4, 4))
    >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2)
    >>> patches.shape
    (1, 8, 4)

    """
    # Validate rank up front (for both eager and symbolic tensors) so the
    # error message points at the user's input, not a backend failure.
    input_shape = x.shape
    if len(input_shape) != 4:
        raise ValueError(
            f"Input must be a 4D tensor. Received: input.shape={input_shape}"
        )
    if any_symbolic_tensors((x,)):
        op = Unfold(kernel_size, dilation, padding, stride)
        return op.symbolic_call(x)
    return _unfold(x, kernel_size, dilation, padding, stride)
3305
+
3306
+
3307
def _unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Dispatch `unfold` to the active backend implementation."""
    kwargs = {
        "kernel_size": kernel_size,
        "dilation": dilation,
        "padding": padding,
        "stride": stride,
    }
    return backend.nn.unfold(x, **kwargs)