keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/openvino/random.py
@@ -22,21 +22,14 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
 
 def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
     dtype = dtype or floatx()
-    ov_type = OPENVINO_DTYPES[dtype]
-    seed = draw_seed(seed)
-    if isinstance(seed, OpenVINOKerasTensor):
-        seed1, seed2 = convert_to_numpy(seed)
+    seed_val = draw_seed(seed)
+    if isinstance(seed_val, OpenVINOKerasTensor):
+        seed_data = convert_to_numpy(seed_val)
     else:
-        seed1, seed2 = draw_seed(seed).data
-    minval_const = ov_opset.constant(minval, dtype=dtype)
-    maxval_const = ov_opset.constant(maxval, dtype=dtype)
-    if isinstance(shape, tuple):
-        shape = list(shape)
-    output_shape_const = ov_opset.constant(shape, dtype=Type.i32)
-    random_uniform = ov_opset.random_uniform(
-        output_shape_const, minval_const, maxval_const, ov_type, seed1, seed2
-    )
-    return OpenVINOKerasTensor(random_uniform.output(0))
+        seed_data = seed_val.data
+    rng = np.random.default_rng(seed_data)
+    random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)
+    return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))
 
 
 def categorical(logits, num_samples, dtype="int64", seed=None):
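Note on the hunk above: the OpenVINO backend previously emitted a random_uniform
op into the graph; it now samples eagerly with NumPy and bakes the result into
the graph as a constant. A minimal sketch of the new behavior (plain NumPy, not
the backend code; the seed handling via draw_seed() is assumed):

import numpy as np

# seed_data stands in for the value produced by draw_seed() in the diff
rng = np.random.default_rng(1337)
values = rng.uniform(0.0, 1.0, size=(2, 3)).astype("float32")
# ov_opset.constant(values) then wraps these fixed values into the graph,
# so a compiled OpenVINO model returns the same sample on every execution.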
keras/src/backend/tensorflow/layer.py
@@ -13,7 +13,6 @@ class TFLayer(KerasAutoTrackable):
         self._saved_model_arg_spec = None
         self._tracked = []
 
-    @tf.__internal__.tracking.no_automatic_dependency_tracking
     def _set_save_spec(self, inputs, args=None, kwargs=None):
         """Defines the save spec so that serialization can trace layer calls.
 
@@ -45,6 +44,7 @@ class TFLayer(KerasAutoTrackable):
             kwargs_spec,
         )
 
+    @tf.__internal__.tracking.no_automatic_dependency_tracking
     def _trackable_children(self, save_type="checkpoint", **kwargs):
         if save_type == "savedmodel":
             # SavedModel needs to ignore the execution functions.
@@ -62,17 +62,51 @@ class TFLayer(KerasAutoTrackable):
             self.test_function = test_function
             self.predict_function = predict_function
 
-            for tracked_attr in self._tracked:
-                tracked_item = getattr(self, tracked_attr)
-                if isinstance(tracked_item, tracking.TrackedList):
-                    children[tracked_attr] = list(tracked_item)
-                if isinstance(tracked_item, tracking.TrackedDict):
-                    children[tracked_attr] = dict(tracked_item)
-                if isinstance(tracked_item, tracking.TrackedSet):
-                    children[tracked_attr] = list(tracked_item)
+            # Convert Keras tracked collections to plain Python structures
+            # without creating TensorFlow trackable dependencies
+            self._convert_tracked_collections(children)
 
         return children
 
+    def _convert_tracked_collections(self, children):
+        """Convert TrackedList/Dict/Set to plain Python structures."""
+        for tracked_attr in self._tracked:
+            tracked_item = getattr(self, tracked_attr)
+            if isinstance(tracked_item, tracking.TrackedList):
+                children[tracked_attr] = list(tracked_item)
+            if isinstance(tracked_item, tracking.TrackedDict):
+                children[tracked_attr] = dict(tracked_item)
+            if isinstance(tracked_item, tracking.TrackedSet):
+                children[tracked_attr] = list(tracked_item)
+
+    def _get_save_spec(self, dynamic_batch=True):
+        """Compatibility shim for TensorFlow saving utilities.
+
+        TensorFlow's SavedModel / TFLite export paths (e.g.,
+        tf.lite.TFLiteConverter.from_keras_model) expect a `_get_save_spec`
+        method on models. This method generates TensorSpec objects
+        describing the model's input signature.
+
+        Args:
+            dynamic_batch: whether to set the batch dimension to `None`.
+
+        Returns:
+            A TensorSpec, list or dict mirroring the model inputs, or
+            `None` when specs cannot be inferred.
+        """
+        # Lazy import to avoid circular dependency
+        from keras.src.export.export_utils import make_tf_tensor_spec
+
+        # Fall back to building specs from `self.inputs`
+        inputs = getattr(self, "inputs", None)
+        if inputs is None:
+            return None
+
+        return tree.map_structure(
+            lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),
+            inputs,
+        )
+
     @property
     def _default_save_signature(self):
         """For SavedModel support: returns the default serving signature."""
keras/src/backend/tensorflow/linalg.py
@@ -244,3 +244,27 @@ def lstsq(a, b, rcond=None):
     if b_orig_ndim == 1:
         x = tf.reshape(x, [-1])
     return x
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    primal_flat = tf.nest.flatten(primals)
+    tangent_flat = tf.nest.flatten(tangents)
+
+    tangent_flat = [
+        tf.cast(t, p.dtype) for t, p in zip(tangent_flat, primal_flat)
+    ]
+
+    with tf.autodiff.ForwardAccumulator(primal_flat, tangent_flat) as acc:
+        if has_aux:
+            primals_out, aux = fun(*primals)
+        else:
+            primals_out = fun(*primals)
+
+    primals_out_flat = tf.nest.flatten(primals_out)
+    tangents_out_flat = [acc.jvp(po) for po in primals_out_flat]
+
+    tangents_out = tf.nest.pack_sequence_as(primals_out, tangents_out_flat)
+
+    if has_aux:
+        return primals_out, tangents_out, aux
+    return primals_out, tangents_out
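For reference, the forward-mode pattern this new jvp helper wraps, following
the documented tf.autodiff.ForwardAccumulator API (self-contained sketch,
independent of the diff):

import tensorflow as tf

primal = tf.constant(3.0)
tangent = tf.constant(1.0)  # direction along which to differentiate

with tf.autodiff.ForwardAccumulator(primal, tangent) as acc:
    y = primal * primal  # f(x) = x**2

print(acc.jvp(y))  # 6.0, i.e. df/dx at x=3 scaled by the tangent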
keras/src/backend/tensorflow/nn.py
@@ -4,6 +4,9 @@ import warnings
 import tensorflow as tf
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
@@ -268,6 +271,486 @@ def average_pool(
     return outputs
 
 
+def _compute_static_gather_indices(
+    input_dim, output_size, small_window, big_window
+):
+    """Compute gather indices for Two-Pool Gather method (corrected)."""
+    window_starts = tf.cast(
+        tf.floor(
+            tf.cast(tf.range(output_size), tf.float32)
+            * tf.cast(input_dim, tf.float32)
+            / tf.cast(output_size, tf.float32)
+        ),
+        tf.int32,
+    )
+    window_ends = tf.cast(
+        tf.math.ceil(
+            tf.cast(tf.range(1, output_size + 1), tf.float32)
+            * tf.cast(input_dim, tf.float32)
+            / tf.cast(output_size, tf.float32)
+        ),
+        tf.int32,
+    )
+
+    window_ends = tf.minimum(window_ends, input_dim)
+    window_starts = tf.minimum(window_starts, input_dim - 1)
+
+    window_sizes = window_ends - window_starts
+    is_big_window = tf.equal(window_sizes, big_window)
+
+    small_pool_len = max(1, input_dim - small_window + 1)
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather_indices = tf.where(is_big_window, big_indices, small_indices)
+    return tf.cast(gather_indices, tf.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 1))
+
+    static_shape = inputs.shape.as_list()
+    l_static = static_shape[1]
+    out_l = output_size[0]
+
+    if l_static is None:
+        raise ValueError(
+            "Input length must be statically known for adaptive pooling"
+        )
+
+    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)
+    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)
+
+    small_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(small_l,),
+        pooling_type="AVG",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+    big_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(big_l,),
+        pooling_type="AVG",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+
+    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)
+    pooled_l = tf.gather(combined_l, gather_l, axis=1)
+
+    if data_format == "channels_first":
+        pooled_l = tf.transpose(pooled_l, (0, 2, 1))
+    return pooled_l
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 1))
+
+    static_shape = inputs.shape.as_list()
+    l_static = static_shape[1]
+    out_l = output_size[0]
+
+    if l_static is None:
+        raise ValueError(
+            "Input length must be statically known for adaptive pooling"
+        )
+
+    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)
+    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)
+
+    small_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(small_l,),
+        pooling_type="MAX",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+    big_pool_l = tf.nn.pool(
+        inputs,
+        window_shape=(big_l,),
+        pooling_type="MAX",
+        strides=(1,),
+        padding="VALID",
+        data_format="NWC",
+    )
+
+    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)
+    pooled_l = tf.gather(combined_l, gather_l, axis=1)
+
+    if data_format == "channels_first":
+        pooled_l = tf.transpose(pooled_l, (0, 2, 1))
+    return pooled_l
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 1))
+
+    static_shape = inputs.shape.as_list()
+    h_static = static_shape[1]
+    w_static = static_shape[2]
+    out_h, out_w = output_size
+
+    if h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(small_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(big_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)
+    pooled_h = tf.gather(combined_h, gather_h, axis=1)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, small_w),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, big_w),
+        pooling_type="AVG",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)
+    pooled_w = tf.gather(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))
+
+    return pooled_w
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    """Adaptive Max Pooling 2D using Two-Pool Gather method."""
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 1))
+
+    static_shape = inputs.shape.as_list()
+    h_static = static_shape[1]
+    w_static = static_shape[2]
+    out_h, out_w = output_size
+
+    if h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(small_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        inputs,
+        window_shape=(big_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)
+    pooled_h = tf.gather(combined_h, gather_h, axis=1)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, small_w),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, big_w),
+        pooling_type="MAX",
+        strides=(1, 1),
+        padding="VALID",
+        data_format="NHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)
+    pooled_w = tf.gather(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))
+
+    return pooled_w
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
+
+    static_shape = inputs.shape.as_list()
+    d_static = static_shape[1]
+    h_static = static_shape[2]
+    w_static = static_shape[3]
+    out_d, out_h, out_w = output_size
+
+    if d_static is None or h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d)
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(small_d, 1, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(big_d, 1, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)
+    pooled_d = tf.gather(combined_d, gather_d, axis=1)
+
+    small_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, small_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, big_h, 1),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)
+    pooled_h = tf.gather(combined_h, gather_h, axis=2)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, small_w),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, big_w),
+        pooling_type="AVG",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=3)
+    pooled_w = tf.gather(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))
+
+    return pooled_w
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    """Adaptive Max Pooling 3D using Two-Pool Gather method."""
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))
+
+    static_shape = inputs.shape.as_list()
+    d_static = static_shape[1]
+    h_static = static_shape[2]
+    w_static = static_shape[3]
+    out_d, out_h, out_w = output_size
+
+    if d_static is None or h_static is None or w_static is None:
+        raise ValueError(
+            "Input spatial dimensions must be "
+            "statically known for adaptive pooling"
+        )
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d)
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)
+
+    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)
+    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)
+    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)
+
+    small_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(small_d, 1, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_d = tf.nn.pool(
+        inputs,
+        window_shape=(big_d, 1, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)
+    pooled_d = tf.gather(combined_d, gather_d, axis=1)
+
+    small_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, small_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_h = tf.nn.pool(
+        pooled_d,
+        window_shape=(1, big_h, 1),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)
+    pooled_h = tf.gather(combined_h, gather_h, axis=2)
+
+    small_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, small_w),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+    big_pool_w = tf.nn.pool(
+        pooled_h,
+        window_shape=(1, 1, big_w),
+        pooling_type="MAX",
+        strides=(1, 1, 1),
+        padding="VALID",
+        data_format="NDHWC",
+    )
+
+    combined_w = tf.concat([small_pool_w, big_pool_w], axis=3)
+    pooled_w = tf.gather(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))
+
+    return pooled_w
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    ndims = len(inputs.shape) - 2
+    if ndims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    elif ndims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    elif ndims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    else:
+        raise ValueError(
+            "adaptive_average_pool supports 1D, 2D, or 3D inputs only."
+        )
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    ndims = len(inputs.shape) - 2
+    if ndims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    elif ndims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    elif ndims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    else:
+        raise ValueError(
+            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
+        )
+
+
 def _convert_data_format(data_format, ndim):
     if data_format == "channels_last":
         if ndim == 3:
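The adaptive pooling hunk above uses a "Two-Pool Gather" strategy: the
start/end formulas in _compute_static_gather_indices produce at most two
distinct window sizes, so the backend runs one stride-1 pool per size and
gathers the correct rows from the concatenated results. A plain-NumPy
illustration of that invariant (illustrative only;
compute_adaptive_pooling_window_sizes is assumed to return exactly these two
sizes):

import numpy as np

input_dim, output_size = 5, 3
i = np.arange(output_size)
starts = np.floor(i * input_dim / output_size).astype(int)
ends = np.ceil((i + 1) * input_dim / output_size).astype(int)
print(starts, ends, ends - starts)
# [0 1 3] [2 4 5] [2 3 2]  -> window sizes take only two values, 2 and 3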
@@ -310,7 +793,7 @@ def conv(
 ):
     def _conv():
         tf_data_format = _convert_data_format(data_format, len(inputs.shape))
-        return tf.nn.convolution(
+        result = tf.nn.convolution(
             inputs,
             kernel,
             strides,
@@ -318,6 +801,20 @@ def conv(
             data_format=tf_data_format,
             dilations=dilation_rate,
         )
+        result_shape = result.shape
+        if (
+            result_shape.is_fully_defined()
+            and math.prod(result_shape.as_list()) == 0
+        ):
+            raise ValueError(
+                "The convolution operation resulted in an empty output. "
+                "Output shape:"
+                f" {result_shape}. This can happen if the input is too small "
+                "for the given kernel size, strides, dilation rate, and "
+                "padding mode. Please check the input shape and convolution "
+                "parameters."
+            )
+        return result
 
     # Certain ops are are broken in Tensorflow on CPU only.
     # We can work around by compiling the op with XLA.
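The guard added above turns a silently empty convolution result into an
explicit ValueError. The emptiness test is just a product over the static
shape (sketch below; the hunk relies on math already being imported in nn.py):

import math

shape = (1, 0, 10, 3)  # fully defined, but one axis collapsed to zero
print(math.prod(shape) == 0)  # True -> the new guard would raise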
@@ -1077,3 +1574,50 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Tensorflow implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+    k = (
+        (kernel_size, kernel_size)
+        if isinstance(kernel_size, int)
+        else kernel_size
+    )
+    d = (dilation, dilation) if isinstance(dilation, int) else dilation
+    p = (padding, padding) if isinstance(padding, int) else padding
+    s = (stride, stride) if isinstance(stride, int) else stride
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = tf.pad(input, [[0, 0], [0, 0], [p[0], p[0]], [p[1], p[1]]])
+    x = tf.transpose(input, [0, 2, 3, 1])  # (N, H, W, C)
+    patches = tf.image.extract_patches(
+        images=x,
+        sizes=[1, k[0], k[1], 1],
+        strides=[1, s[0], s[1], 1],
+        rates=[1, d[0], d[1], 1],
+        padding="VALID",
+    )  # (N, nH, nW, kH*kW*C)
+
+    N, nH, nW, D = patches.shape
+    patches = tf.reshape(
+        patches, [N, nH, nW, k[0], k[1], C]
+    )  # (N, nH, nW, kH, kW, C)
+    patches = tf.transpose(
+        patches, [0, 5, 3, 4, 1, 2]
+    )  # (N, C, kH, kW, nH, nW)
+    patches = tf.reshape(patches, [N, C * k[0] * k[1], nH * nW])
+    return patches
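A hedged shape check for the NCHW contract documented above (assumes eager
TensorFlow and the unfold function defined in this hunk):

import tensorflow as tf

x = tf.reshape(tf.range(32, dtype=tf.float32), (1, 2, 4, 4))  # (N, C, H, W)
cols = unfold(x, kernel_size=2, stride=2)
print(cols.shape)  # (1, 8, 4): C*kH*kW = 2*2*2 = 8, L = 2*2 patch positions

The (N, C*kH*kW, L) layout matches torch.nn.functional.unfold's output
contract.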