keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py
@@ -813,16 +813,17 @@ def append(x1, x2, axis=None):
     return tf.concat([x1, x2], axis=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
+    if step is None:
+        step = 1
     try:
         out = tf.range(start, stop, delta=step, dtype=dtype)
     except tf.errors.NotFoundError:
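With `step=None` as the new default, the step only participates in dtype resolution when the caller actually passes one, matching NumPy's promotion rules. For reference:

    import numpy as np

    np.arange(3).dtype       # platform default int; no step in the promotion
    np.arange(0, 1, 0.25)    # float64: an explicit step joins the promotion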
@@ -997,6 +998,51 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    from keras.src import backend
+
+    x = convert_to_tensor(x)
+    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
+    new_dtype = tf.as_dtype(
+        backend.standardize_dtype(dtype if dtype else x.dtype)
+    )
+
+    old_itemsize = old_dtype.size
+    new_itemsize = new_dtype.size
+
+    old_shape = list(shape_op(x))
+    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
+    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
+        last_dim_size * old_itemsize % new_itemsize != 0
+    ):
+        raise ValueError(
+            f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
+            f"as dtype {new_dtype} because the total number of bytes "
+            f"is not divisible by the new itemsize."
+        )
+
+    if old_itemsize == new_itemsize:
+        return tf.bitcast(x, type=new_dtype)
+    elif old_itemsize > new_itemsize:
+        ratio = old_itemsize // new_itemsize
+        new_shape = list(shape_op(x))
+        new_shape[-1] *= ratio
+        flat_tensor = tf.reshape(x, [-1])
+        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
+        return tf.reshape(cast_tensor, new_shape)
+    else:
+        ratio = new_itemsize // old_itemsize
+        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
+            raise ValueError(
+                f"Cannot view dtype. Last dimension size ({last_dim_size}) "
+                f"must be divisible by the ratio of new/old item sizes "
+                f"({ratio})."
+            )
+        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
+        reshaped_tensor = tf.reshape(x, intermediate_shape)
+        return tf.bitcast(reshaped_tensor, new_dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
 
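The new `view` reinterprets the same underlying bytes as a different dtype, in the spirit of `np.ndarray.view`; only the last dimension is rescaled by the itemsize ratio. NumPy reference for the intended semantics:

    import numpy as np

    x = np.array([1.0, 2.0], dtype="float32")  # 8 bytes total
    y = x.view("int16")                        # same bytes reinterpreted
    y.shape                                    # (4,): 2 float32 -> 4 int16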
@@ -1313,11 +1359,7 @@ def deg2rad(x):
 def diag(x, k=0):
     x = convert_to_tensor(x)
     if len(x.shape) == 1:
-        return tf.cond(
-            tf.equal(tf.size(x), 0),
-            lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
-            lambda: tf.linalg.diag(x, k=k),
-        )
+        return tf.linalg.diag(x, k=k)
     elif len(x.shape) == 2:
         return diagonal(x, offset=k)
     else:
@@ -1443,6 +1485,10 @@ def empty(shape, dtype=None):
     return tf.zeros(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return tf.zeros_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1711,6 +1757,14 @@ def isposinf(x):
     return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    if x.dtype.is_complex:
+        return tf.equal(tf.math.imag(x), 0)
+    else:
+        return tf.ones_like(x, dtype=tf.bool)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1786,6 +1840,23 @@ def lcm(x1, x2):
     return result
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, x1.dtype)
+    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
+    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
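`ldexp(x1, x2)` scales by a power of two, `x1 * 2**x2`, with infinities and zeros in `x1` passed through unchanged (hence the final `tf.where`). NumPy reference behavior:

    import numpy as np

    np.ldexp(3.0, 4)     # 48.0, i.e. 3.0 * 2**4
    np.ldexp(np.inf, 1)  # inf: special values pass through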
@@ -1909,6 +1980,22 @@ def logaddexp(x1, x2):
     )
 
 
+def logaddexp2(x1, x2):
+    x1 = tf.convert_to_tensor(x1)
+    x2 = tf.convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
+    delta = x1 - x2
+    log2 = tf.cast(tf.math.log(2.0), dtype)
+    return tf.where(
+        tf.math.is_nan(delta),
+        x1 + x2,
+        tf.maximum(x1, x2)
+        + tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2,
+    )
+
+
 def logical_and(x1, x2):
     x1 = tf.cast(x1, "bool")
     x2 = tf.cast(x2, "bool")
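The branch-free identity used here is logaddexp2(a, b) = max(a, b) + log1p(2**-|a - b|) / ln(2), since log2(1 + t) = log1p(t) / ln(2) and 2**-|d| = exp(-|d| * ln(2)); this avoids overflow for large inputs. A quick sanity check in plain Python:

    import math

    a, b = 3.0, 3.0
    naive = math.log2(2**a + 2**b)                                   # 4.0
    stable = max(a, b) + math.log1p(2 ** -abs(a - b)) / math.log(2)  # 4.0
    assert abs(naive - stable) < 1e-12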
@@ -2233,7 +2320,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
         return gathered_y
     perm = collections.deque(range(ndims))
     perm.rotate(shift_value_static)
-    return tf.transpose(a=gathered_y, perm=perm)
+    return tf.transpose(a=gathered_y, perm=list(perm))
 
 
 def quantile(x, q, axis=None, method="linear", keepdims=False):
@@ -2426,6 +2513,17 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
    if len(dtype_set) > 1:
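Unlike `split`, `array_split` accepts a section count that does not evenly divide the axis length: the first `remainder` pieces each get one extra element. Note that the helper above assumes an integer count for `indices_or_sections`. NumPy reference:

    import numpy as np

    parts = np.array_split(np.arange(10), 3)
    [len(p) for p in parts]  # [4, 3, 3]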
@@ -2657,27 +2755,44 @@ def round(x, decimals=0):
 def tile(x, repeats):
     x = convert_to_tensor(x)
-    repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
-    repeats_size = tf.size(repeats)
-    repeats = tf.pad(
-        repeats,
-        [[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
-        constant_values=1,
-    )
-    x_shape = tf.pad(
-        tf.shape(x),
-        [[tf.maximum(repeats_size - x.shape.rank, 0), 0]],
-        constant_values=1,
-    )
-    x = tf.reshape(x, x_shape)
+
+    # Convert repeats to a list (works for both sequences and 1D tensors)
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    else:
+        repeats = [v for v in repeats]
+
+    # Process list elements: convert concrete scalar tensors to Python ints
+    processed_repeats = []
+    for r in repeats:
+        if hasattr(r, "numpy") and r.shape == ():
+            processed_repeats.append(int(r.numpy()))
+        else:
+            processed_repeats.append(r)
+    repeats = processed_repeats
+
+    # Get x rank
+    x_rank = x.shape.rank
+
+    # Pad repeats if needed
+    if len(repeats) < x_rank:
+        repeats = [1] * (x_rank - len(repeats)) + repeats
+
+    # Add dimensions to x if needed using tf.expand_dims
+    while len(repeats) > x.shape.rank:
+        x = tf.expand_dims(x, 0)
+
     return tf.tile(x, repeats)
 
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype not in ("int64", "uint32", "uint64"):
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    x = tf.cast(x, dtype)
     x_shape = tf.shape(x)
     x = moveaxis(x, (axis1, axis2), (-2, -1))
     # Mask out the diagonal and reduce.
@@ -2686,10 +2801,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
         x,
         tf.zeros_like(x),
     )
-    # The output dtype is set to "int32" if the input dtype is "bool"
-    if standardize_dtype(x.dtype) == "bool":
-        x = tf.cast(x, "int32")
-    return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
+    return tf.reduce_sum(x, axis=(-2, -1))
 
 
 def tri(N, M=None, k=0, dtype=None):
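Both changes in this pair of hunks align the TF backend with NumPy semantics: `tile` now pads `repeats` with leading ones (or left-expands `x`) until the ranks agree, and `trace` widens small integer and bool inputs before the sum instead of casting afterward. For reference:

    import numpy as np

    np.tile(np.array([[1, 2]]), 3).shape      # (1, 6): repeats padded to (1, 3)
    np.tile(np.array([1, 2]), (2, 3)).shape   # (2, 6): x promoted to rank 2
    np.trace(np.eye(3, dtype=np.int8)).dtype  # widened: small ints promote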
@@ -2905,6 +3017,16 @@ def negative(x):
     return tf.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, tf.float64)
+    x2 = tf.cast(x2, tf.float64)
+    return tf.cast(tf.math.nextafter(x1, x2), dtype)
+
+
 @sparse.elementwise_unary
 def square(x):
     x = convert_to_tensor(x)
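`nextafter(x1, x2)` returns the next representable float after `x1` in the direction of `x2`; the computation runs in float64 and is cast back to the resolved dtype. NumPy reference:

    import numpy as np

    np.nextafter(1.0, 2.0)  # 1.0000000000000002, one ulp above 1.0
    np.nextafter(1.0, 0.0)  # 0.9999999999999999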
@@ -2959,6 +3081,63 @@ def transpose(x, axes=None):
     return tf.transpose(x, perm=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    def _move_axis_to_last(tensor, axis):
+        if axis == -1:
+            return tensor
+        rank = tf.rank(tensor)
+        if axis < 0:
+            axis = rank + axis
+        perm = tf.concat(
+            [
+                tf.range(axis, dtype=tf.int32),
+                tf.range(axis + 1, rank, dtype=tf.int32),
+                tf.constant([axis], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.transpose(tensor, perm=perm)
+
+    y = convert_to_tensor(y)
+    dtype = dtypes.result_type(y.dtype, float)
+    y = tf.cast(y, dtype)
+
+    if x is None:
+        dx_array = tf.cast(dx, dtype)
+    else:
+        x = convert_to_tensor(x, dtype=dtype)
+        dx_array = diff(x, axis=axis)
+        dx_array = _move_axis_to_last(dx_array, axis)
+
+    y = _move_axis_to_last(y, axis)
+
+    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
+    result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
+
+    return result
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+
+    if N is None:
+        N = shape_op(x)[0]
+
+    if increasing:
+        powers = tf.range(N)
+    else:
+        powers = tf.range(N - 1, -1, -1)
+
+    x_exp = tf.expand_dims(x, axis=-1)
+
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    vander = tf.math.pow(
+        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
+    )
+    return tf.cast(vander, result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
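`trapezoid` applies the composite trapezoidal rule, summing 0.5 * (y[i] + y[i+1]) * dx[i] along the axis, and `vander` builds the Vandermonde matrix x[:, None] ** powers (descending powers by default). Worked NumPy checks (np.trapezoid is np.trapz before NumPy 2.0):

    import numpy as np

    y = np.array([0.0, 1.0, 4.0])  # f sampled at x = 0, 1, 2
    np.trapezoid(y, dx=1.0)        # 0.5*(0+1) + 0.5*(1+4) = 3.0

    np.vander(np.array([1, 2, 3]), N=3)
    # [[1 1 1]
    #  [4 2 1]
    #  [9 3 1]]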
@@ -3069,30 +3248,57 @@ def correlate(x1, x2, mode="valid"):
     x1 = tf.cast(x1, dtype)
     x2 = tf.cast(x2, dtype)
 
-    x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0])
+    def _pack(a, b):
+        # a: input [N] -> [1,N,1];
+        # b: filter [M] -> [M,1,1]
+        return (
+            tf.reshape(a, (1, shape_op(a)[0], 1)),
+            tf.reshape(b, (shape_op(b)[0], 1, 1)),
+        )
 
-    if mode == "full":
-        full_len = x1_len + x2_len - 1
+    def _full_corr(x1, x2):
+        """Compute 'full' correlation result (length = n + m - 1)."""
+        m = shape_op(x2)[0]
+        pad = (
+            builtins.max(m - 1, 0)
+            if isinstance(m, int)
+            else tf.maximum(m - 1, 0)
+        )
+        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros
+        x1, x2 = _pack(x1, x2)
+        out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
+        return tf.squeeze(out, axis=[0, 2])
 
-        x1_pad = (full_len - x1_len) / 2
-        x2_pad = (full_len - x2_len) / 2
+    n = shape_op(x1)[0]
+    m = shape_op(x2)[0]
 
-        x1 = tf.pad(
-            x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]]
-        )
-        x2 = tf.pad(
-            x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]]
-        )
-
-        x1 = tf.reshape(x1, (1, full_len, 1))
-        x2 = tf.reshape(x2, (full_len, 1, 1))
-
-        return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
-
-    x1 = tf.reshape(x1, (1, x1_len, 1))
-    x2 = tf.reshape(x2, (x2_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
+    if mode == "full":
+        return _full_corr(x1, x2)
+    elif mode == "same":
+        # We can't leverage 'SAME' padding directly like we can with
+        # "valid": it works fine for odd-length filters, but for
+        # even-length filters the output is off by one compared to
+        # NumPy, due to how tf handles centering.
+        full_corr = _full_corr(x1, x2)
+        full_len = n + m - 1
+        out_len = (
+            max([n, m])
+            if isinstance(n, int) and isinstance(m, int)
+            else tf.maximum(n, m)
+        )
+        start = (full_len - out_len) // 2
+        return tf.slice(full_corr, [start], [out_len])
+    elif mode == "valid":
+        x1, x2 = _pack(x1, x2)
+        return tf.squeeze(
+            tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
+        )
+    else:
+        raise ValueError(
+            f"Invalid mode: '{mode}'. Mode must be one of:"
+            f" 'full', 'same', 'valid'."
        )
 
 
 def select(condlist, choicelist, default=0):
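All three modes are now routed through one "full" correlation (input zero-padded by m - 1 on each side, then a VALID conv1d), with "same" slicing out the centered max(n, m) window; this matches NumPy even for even-length filters. Reference output lengths:

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    v = np.array([0.0, 1.0, 0.5])
    np.correlate(a, v, "full").shape   # (5,) = n + m - 1
    np.correlate(a, v, "same").shape   # (3,) = max(n, m)
    np.correlate(a, v, "valid").shape  # (1,) = max(n, m) - min(n, m) + 1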
@@ -3144,10 +3350,14 @@ def histogram(x, bins=10, range=None):
 
     x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
     bin_edges = tf.linspace(min_val, max_val, bins + 1)
-    bin_edges_list = bin_edges.numpy().tolist()
-    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])
-
-    bin_counts = tf.math.bincount(
-        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
+    bin_edges = tf.cast(bin_edges, x.dtype)
+    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
+
+    # tf.math.bincount does not work with XLA in this case. So, we use
+    # `scatter_nd`.
+    bin_counts = tf.scatter_nd(
+        indices=tf.expand_dims(bin_indices, axis=-1),
+        updates=tf.ones_like(bin_indices, dtype=x.dtype),
+        shape=(bins,),
     )
     return bin_counts, bin_edges
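The `searchsorted` + `scatter_nd` formulation stays XLA-compatible: each retained sample scatters a one into its bin (the old `.numpy()` call also failed inside `tf.function`, since it requires an eager tensor). Expected behavior, via NumPy:

    import numpy as np

    counts, edges = np.histogram([0.1, 0.4, 0.6, 0.9], bins=2, range=(0.0, 1.0))
    counts  # [2, 2]
    edges   # [0.0, 0.5, 1.0]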
keras/src/backend/torch/core.py
@@ -673,7 +673,9 @@ def remat(f):
     """
 
     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
+        return torch.utils.checkpoint.checkpoint(
+            f, *args, use_reentrant=False, **kwargs
+        )
 
     return wrapped
 
keras/src/backend/torch/linalg.py
@@ -80,3 +80,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return torch.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    return torch.func.jvp(fun, primals, tangents, has_aux=has_aux)
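`torch.func.jvp` evaluates `fun(*primals)` and its Jacobian-vector product against `tangents` in a single forward-mode pass. Minimal usage sketch:

    import torch

    def f(x):
        return x ** 2

    out, tangent_out = torch.func.jvp(
        f, (torch.tensor(3.0),), (torch.tensor(1.0),)
    )
    # out = 9.0; tangent_out = f'(3) * 1 = 6.0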
keras/src/backend/torch/nn.py
@@ -458,6 +458,94 @@ def average_pool(
     return outputs
 
 
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    """Adaptive average pooling (1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive average pooling must have ndim=3, 4 or 5. "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    """Adaptive max pooling (1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive max pooling must have ndim=3, 4 or 5. "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    outputs = res[0] if isinstance(res, tuple) else res
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
 def conv(
     inputs,
     kernel,
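These helpers presumably back the new `adaptive_average_pooling*` / `adaptive_max_pooling*` layers listed in the file summary above: the caller fixes the output spatial size and the pooling windows are derived from the input size, so any input resolution maps to the same output shape. The underlying PyTorch op, for reference:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 17, 23)  # NCHW
    F.adaptive_avg_pool2d(x, output_size=(4, 4)).shape
    # (2, 3, 4, 4), regardless of the input H and W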
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)
 
+    # We only apply the squeeze fix if we are on an MPS device,
+    # as this change breaks tests on other platforms that
+    # expect the original tensor shape to be preserved.
+    if (
+        torch.backends.mps.is_available()
+        and target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch, does reduction of `sum` over all rows,
     # change reduction to `none` to keep dim
     if from_logits:
@@ -1092,3 +1194,26 @@ def dot_product_attention(
         scale=scale,
     )
     return torch.transpose(attention_output, axis1, axis0)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Native PyTorch implementation of Unfold.
+
+    Extract sliding local blocks from an **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor of shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW).
+        dilation: int or (dH, dW), defaults to 1.
+        padding: int or (pH, pW), defaults to 0.
+        stride: int or (sH, sW), defaults to 1.
+
+    Returns:
+        3-D tensor of shape (N, C*kH*kW, L).
+    """
+    return tnn.unfold(
+        input,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
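Here L is the number of sliding positions: L = out_H * out_W, where out_H = (H + 2*pH - dH*(kH - 1) - 1) // sH + 1, and likewise for W. A quick shape check against the functional op the wrapper delegates to:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    F.unfold(x, kernel_size=2, stride=2).shape
    # (1, 12, 16): C*kH*kW = 12, and a 4x4 grid of 2x2 patches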