keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py
@@ -998,6 +998,51 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    from keras.src import backend
+
+    x = convert_to_tensor(x)
+    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
+    new_dtype = tf.as_dtype(
+        backend.standardize_dtype(dtype if dtype else x.dtype)
+    )
+
+    old_itemsize = old_dtype.size
+    new_itemsize = new_dtype.size
+
+    old_shape = list(shape_op(x))
+    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
+    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
+        last_dim_size * old_itemsize % new_itemsize != 0
+    ):
+        raise ValueError(
+            f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
+            f"as dtype {new_dtype} because the total number of bytes "
+            f"is not divisible by the new itemsize."
+        )
+
+    if old_itemsize == new_itemsize:
+        return tf.bitcast(x, type=new_dtype)
+    elif old_itemsize > new_itemsize:
+        ratio = old_itemsize // new_itemsize
+        new_shape = list(shape_op(x))
+        new_shape[-1] *= ratio
+        flat_tensor = tf.reshape(x, [-1])
+        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
+        return tf.reshape(cast_tensor, new_shape)
+    else:
+        ratio = new_itemsize // old_itemsize
+        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
+            raise ValueError(
+                f"Cannot view dtype. Last dimension size ({last_dim_size}) "
+                f"must be divisible by the ratio of new/old item sizes "
+                f"({ratio})."
+            )
+        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
+        reshaped_tensor = tf.reshape(x, intermediate_shape)
+        return tf.bitcast(reshaped_tensor, new_dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
 
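
For context, a minimal standalone sketch (not part of the diff) of the bitcast trick the new `view` builds on: with equal itemsizes, `tf.bitcast` reinterprets the bytes in place; with a smaller target itemsize it appends an axis of size old_itemsize // new_itemsize, which `view` folds into the last dimension to match numpy's `ndarray.view`:

    import numpy as np
    import tensorflow as tf

    x = tf.constant([0, 1, 2, 3], dtype=tf.int32)

    # Same itemsize (int32 -> float32): a pure bitcast, shape unchanged.
    same = tf.bitcast(x, tf.float32)

    # Smaller itemsize (int32 -> int16): (4,) int32 -> (4, 2) int16, then
    # the appended axis is folded into the last dimension -> (8,).
    smaller = tf.reshape(tf.bitcast(x, tf.int16), [-1])
    np.testing.assert_array_equal(
        smaller.numpy(), np.arange(4, dtype=np.int32).view(np.int16)
    )
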
@@ -1314,11 +1359,7 @@ def deg2rad(x):
 def diag(x, k=0):
     x = convert_to_tensor(x)
     if len(x.shape) == 1:
-        return tf.cond(
-            tf.equal(tf.size(x), 0),
-            lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
-            lambda: tf.linalg.diag(x, k=k),
-        )
+        return tf.linalg.diag(x, k=k)
     elif len(x.shape) == 2:
         return diagonal(x, offset=k)
     else:
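
For reference, the `tf.linalg.diag` behavior the simplified branch relies on (illustrative, not from the diff): with a 1-D input and a nonzero offset, the op pads to a square matrix of size len(x) + |k|, matching `np.diag`:

    import tensorflow as tf

    print(tf.linalg.diag(tf.constant([1.0, 2.0]), k=1))
    # [[0. 1. 0.]
    #  [0. 0. 2.]
    #  [0. 0. 0.]]
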
@@ -1444,6 +1485,10 @@ def empty(shape, dtype=None):
     return tf.zeros(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return tf.zeros_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1712,6 +1757,14 @@ def isposinf(x):
     return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    if x.dtype.is_complex:
+        return tf.equal(tf.math.imag(x), 0)
+    else:
+        return tf.ones_like(x, dtype=tf.bool)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
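
A quick illustration of the `isreal` logic above (illustrative values, not from the diff): complex inputs are "real" wherever the imaginary part is zero, and all other dtypes are real everywhere:

    import tensorflow as tf

    z = tf.constant([1 + 0j, 2 + 3j], dtype=tf.complex64)
    print(tf.equal(tf.math.imag(z), 0).numpy())    # [ True False]

    r = tf.constant([1.0, 2.0])
    print(tf.ones_like(r, dtype=tf.bool).numpy())  # [ True  True]
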
@@ -1787,6 +1840,23 @@ def lcm(x1, x2):
     return result
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, x1.dtype)
+    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
+    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
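
A parity check for the formula above, ldexp(x1, x2) == x1 * 2**x2 (illustrative, not from the diff):

    import numpy as np
    import tensorflow as tf

    x1 = tf.constant([0.5, 1.0, 1.5])
    x2 = tf.constant([1, 2, 3], dtype=tf.int32)
    result = x1 * tf.pow(tf.constant(2.0), tf.cast(x2, tf.float32))
    np.testing.assert_allclose(
        result.numpy(), np.ldexp([0.5, 1.0, 1.5], [1, 2, 3])
    )  # both give [1., 4., 12.]
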
@@ -2081,7 +2151,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
 
 def ndim(x):
     x = convert_to_tensor(x)
-    return x.ndim
+    return x.shape.rank
 
 
 def nonzero(x):
@@ -2145,6 +2215,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
+        x, axis=axis, keepdims=keepdims
+    )
+
+
 def _quantile(x, q, axis=None, method="linear", keepdims=False):
     # ref: tfp.stats.percentile
     # float64 is needed here and below, else we get the wrong index if the array
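
`ptp` ("peak to peak") is simply the range max - min along the axis; a quick parity check (not from the diff):

    import numpy as np
    import tensorflow as tf

    x = tf.constant([[1.0, 5.0], [2.0, 3.0]])
    tf_ptp = tf.reduce_max(x, axis=0) - tf.reduce_min(x, axis=0)
    print(tf_ptp.numpy())             # [1. 2.]
    print(np.ptp(x.numpy(), axis=0))  # [1. 2.]
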
@@ -2250,7 +2327,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
         return gathered_y
     perm = collections.deque(range(ndims))
     perm.rotate(shift_value_static)
-    return tf.transpose(a=gathered_y, perm=perm)
+    return tf.transpose(a=gathered_y, perm=list(perm))
 
 
 def quantile(x, q, axis=None, method="linear", keepdims=False):
@@ -2443,6 +2520,17 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
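
Unlike `split`, `array_split` accepts a section count that does not divide the axis evenly: the first total_size % num_splits sections each get one extra element. For example (not from the diff):

    import numpy as np

    # 10 elements into 3 sections -> sizes [4, 3, 3]
    print([len(s) for s in np.array_split(np.arange(10), 3)])
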
@@ -2674,27 +2762,44 @@ def round(x, decimals=0):
 
 def tile(x, repeats):
     x = convert_to_tensor(x)
-    repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
-    repeats_size = tf.size(repeats)
-    repeats = tf.pad(
-        repeats,
-        [[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
-        constant_values=1,
-    )
-    x_shape = tf.pad(
-        tf.shape(x),
-        [[tf.maximum(repeats_size - x.shape.rank, 0), 0]],
-        constant_values=1,
-    )
-    x = tf.reshape(x, x_shape)
+
+    # Convert repeats to a list (works for both sequences and 1D tensors)
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    else:
+        repeats = [v for v in repeats]
+
+    # Process list elements: convert concrete scalar tensors to Python ints
+    processed_repeats = []
+    for r in repeats:
+        if hasattr(r, "numpy") and r.shape == ():
+            processed_repeats.append(int(r.numpy()))
+        else:
+            processed_repeats.append(r)
+    repeats = processed_repeats
+
+    # Get x rank
+    x_rank = x.shape.rank
+
+    # Pad repeats if needed
+    if len(repeats) < x_rank:
+        repeats = [1] * (x_rank - len(repeats)) + repeats
+
+    # Add dimensions to x if needed using tf.expand_dims
+    while len(repeats) > x.shape.rank:
+        x = tf.expand_dims(x, 0)
+
     return tf.tile(x, repeats)
 
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype not in ("int64", "uint32", "uint64"):
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    x = tf.cast(x, dtype)
     x_shape = tf.shape(x)
     x = moveaxis(x, (axis1, axis2), (-2, -1))
     # Mask out the diagonal and reduce.
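
A parity check for the `tile` rewrite above: when len(repeats) exceeds the input rank, numpy (and now this backend) promotes the input by prepending axes before tiling (illustrative, not from the diff):

    import numpy as np
    import tensorflow as tf

    x = tf.constant([[1, 2], [3, 4]])
    y = tf.tile(tf.expand_dims(x, 0), [2, 1, 1])  # rank promoted 2 -> 3
    np.testing.assert_array_equal(y.numpy(), np.tile(x.numpy(), (2, 1, 1)))
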
@@ -2703,10 +2808,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
         x,
         tf.zeros_like(x),
     )
-    # The output dtype is set to "int32" if the input dtype is "bool"
-    if standardize_dtype(x.dtype) == "bool":
-        x = tf.cast(x, "int32")
-    return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
+    return tf.reduce_sum(x, axis=(-2, -1))
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -2922,6 +3024,16 @@ def negative(x):
     return tf.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, tf.float64)
+    x2 = tf.cast(x2, tf.float64)
+    return tf.cast(tf.math.nextafter(x1, x2), dtype)
+
+
 @sparse.elementwise_unary
 def square(x):
     x = convert_to_tensor(x)
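
`nextafter` returns the next representable float after x1 in the direction of x2; the implementation above computes in float64 and casts back to the promoted result dtype. For example (not from the diff):

    import numpy as np
    import tensorflow as tf

    one = tf.constant(1.0, tf.float64)
    two = tf.constant(2.0, tf.float64)
    print(tf.math.nextafter(one, two).numpy())  # 1.0000000000000002
    print(np.nextafter(1.0, 2.0))               # same value
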
@@ -2976,6 +3088,63 @@ def transpose(x, axes=None):
     return tf.transpose(x, perm=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    def _move_axis_to_last(tensor, axis):
+        if axis == -1:
+            return tensor
+        rank = tf.rank(tensor)
+        if axis < 0:
+            axis = rank + axis
+        perm = tf.concat(
+            [
+                tf.range(axis, dtype=tf.int32),
+                tf.range(axis + 1, rank, dtype=tf.int32),
+                tf.constant([axis], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.transpose(tensor, perm=perm)
+
+    y = convert_to_tensor(y)
+    dtype = dtypes.result_type(y.dtype, float)
+    y = tf.cast(y, dtype)
+
+    if x is None:
+        dx_array = tf.cast(dx, dtype)
+    else:
+        x = convert_to_tensor(x, dtype=dtype)
+        dx_array = diff(x, axis=axis)
+        dx_array = _move_axis_to_last(dx_array, axis)
+
+    y = _move_axis_to_last(y, axis)
+
+    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
+    result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
+
+    return result
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+
+    if N is None:
+        N = shape_op(x)[0]
+
+    if increasing:
+        powers = tf.range(N)
+    else:
+        powers = tf.range(N - 1, -1, -1)
+
+    x_exp = tf.expand_dims(x, axis=-1)
+
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    vander = tf.math.pow(
+        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
+    )
+    return tf.cast(vander, result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
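
Two quick illustrations of the numerics above (not from the diff): the trapezoidal rule averages adjacent heights, weights them by the sample spacing, and sums along the integration axis; `vander` raises the input to decreasing powers by default, following numpy:

    import numpy as np

    x = np.linspace(0.0, 1.0, 101)
    y = x**2
    print(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))  # ~0.33335 (exact: 1/3)

    print(np.vander(np.array([1, 2, 3]), N=3))
    # [[1 1 1]
    #  [4 2 1]
    #  [9 3 1]]
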
@@ -3086,30 +3255,57 @@ def correlate(x1, x2, mode="valid"):
     x1 = tf.cast(x1, dtype)
     x2 = tf.cast(x2, dtype)
 
-    x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0])
+    def _pack(a, b):
+        # a: input [N] -> [1, N, 1];
+        # b: filter [M] -> [M, 1, 1]
+        return (
+            tf.reshape(a, (1, shape_op(a)[0], 1)),
+            tf.reshape(b, (shape_op(b)[0], 1, 1)),
+        )
 
-    if mode == "full":
-        full_len = x1_len + x2_len - 1
+    def _full_corr(x1, x2):
+        """Compute 'full' correlation result (length = n + m - 1)."""
+        m = shape_op(x2)[0]
+        pad = (
+            builtins.max(m - 1, 0)
+            if isinstance(m, int)
+            else tf.maximum(m - 1, 0)
+        )
+        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros
+        x1, x2 = _pack(x1, x2)
+        out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
+        return tf.squeeze(out, axis=[0, 2])
 
-        x1_pad = (full_len - x1_len) / 2
-        x2_pad = (full_len - x2_len) / 2
+    n = shape_op(x1)[0]
+    m = shape_op(x2)[0]
 
-        x1 = tf.pad(
-            x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]]
+    if mode == "full":
+        return _full_corr(x1, x2)
+    elif mode == "same":
+        # Unfortunately we can't leverage 'SAME' padding directly like
+        # we can with "valid": it works fine for odd-length filters,
+        # but for even-length filters the output is off by 1 compared
+        # to numpy, due to how tf handles centering. So we take the
+        # centered slice of the "full" result instead.
+        full_corr = _full_corr(x1, x2)
+        full_len = n + m - 1
+        out_len = (
+            max([n, m])
+            if isinstance(n, int) and isinstance(m, int)
+            else tf.maximum(n, m)
         )
-        x2 = tf.pad(
-            x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]]
+        start = (full_len - out_len) // 2
+        return tf.slice(full_corr, [start], [out_len])
+    elif mode == "valid":
+        x1, x2 = _pack(x1, x2)
+        return tf.squeeze(
+            tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
+        )
+    else:
+        raise ValueError(
+            f"Invalid mode: '{mode}'. Mode must be one of:"
+            f" 'full', 'same', 'valid'."
         )
-
-        x1 = tf.reshape(x1, (1, full_len, 1))
-        x2 = tf.reshape(x2, (full_len, 1, 1))
-
-        return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
-
-    x1 = tf.reshape(x1, (1, x1_len, 1))
-    x2 = tf.reshape(x2, (x2_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
 
 
 def select(condlist, choicelist, default=0):
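
The three modes mirror numpy's `correlate`: "full" has length n + m - 1, "same" is its centered slice of length max(n, m), and "valid" keeps only positions with complete overlap. For example (not from the diff):

    import numpy as np

    a = np.array([1.0, 2.0, 3.0])
    v = np.array([0.0, 1.0, 0.5])
    for mode in ("full", "same", "valid"):
        print(mode, np.correlate(a, v, mode=mode))
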
@@ -3161,10 +3357,14 @@ def histogram(x, bins=10, range=None):
 
     x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
     bin_edges = tf.linspace(min_val, max_val, bins + 1)
-    bin_edges_list = bin_edges.numpy().tolist()
-    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])
-
-    bin_counts = tf.math.bincount(
-        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
+    bin_edges = tf.cast(bin_edges, x.dtype)
+    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
+
+    # tf.math.bincount does not work with XLA in this case. So, we use
+    # `scatter_nd`.
+    bin_counts = tf.scatter_nd(
+        indices=tf.expand_dims(bin_indices, axis=-1),
+        updates=tf.ones_like(bin_indices, dtype=x.dtype),
+        shape=(bins,),
     )
     return bin_counts, bin_edges
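
A minimal sketch of the XLA-friendly counting used above (illustrative values, not from the diff): `tf.searchsorted` assigns each value a bin index against the interior edges, and `tf.scatter_nd` sums the duplicate indices into per-bin counts:

    import tensorflow as tf

    x = tf.constant([0.1, 0.4, 0.6, 0.9])
    edges = tf.linspace(0.0, 1.0, 5)  # 4 bins
    idx = tf.searchsorted(edges[1:-1], x, side="right")
    counts = tf.scatter_nd(
        tf.expand_dims(idx, -1), tf.ones_like(idx), shape=(4,)
    )
    print(counts.numpy())  # [1 1 1 1]
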
keras/src/backend/torch/core.py
@@ -673,7 +673,9 @@ def remat(f):
     """
 
     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
+        return torch.utils.checkpoint.checkpoint(
+            f, *args, use_reentrant=False, **kwargs
+        )
 
     return wrapped
 
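
The fix forwards keyword arguments to the checkpointed function, which `torch.utils.checkpoint.checkpoint` supports when use_reentrant=False. A minimal sketch (not from the diff):

    import torch
    from torch.utils.checkpoint import checkpoint

    def f(x, scale=1.0):
        return torch.sin(x) * scale

    x = torch.randn(4, requires_grad=True)
    y = checkpoint(f, x, use_reentrant=False, scale=2.0)  # kwargs reach f
    y.sum().backward()
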
keras/src/backend/torch/linalg.py
@@ -80,3 +80,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return torch.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    return torch.func.jvp(fun, primals, tangents, has_aux=has_aux)
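
`torch.func.jvp` computes a forward-mode Jacobian-vector product, returning (f(primals), J @ tangents). For example (not from the diff):

    import torch

    x = torch.tensor([0.0, 1.0])
    v = torch.ones_like(x)
    out, tangent_out = torch.func.jvp(torch.sin, (x,), (v,))
    print(out)          # sin(x)
    print(tangent_out)  # cos(x) * v
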
keras/src/backend/torch/nn.py
@@ -458,6 +458,94 @@ def average_pool(
     return outputs
 
 
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    """Adaptive average pooling (1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive average pooling must have ndim=3, 4 or 5. "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    """Adaptive max pooling (1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive max pooling must have ndim=3, 4 or 5. "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    outputs = res[0] if isinstance(res, tuple) else res
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
 def conv(
     inputs,
     kernel,
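
Adaptive pooling picks window sizes so the output has exactly the requested spatial shape, regardless of the input size. A quick illustration in PyTorch's channels-first layout (not from the diff):

    import torch
    import torch.nn.functional as tnn

    x = torch.arange(16.0).reshape(1, 1, 4, 4)  # (N, C, H, W)
    print(tnn.adaptive_avg_pool2d(x, output_size=(2, 2)).squeeze())
    # tensor([[ 2.5000,  4.5000],
    #         [10.5000, 12.5000]])
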
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)
 
+    # We only apply the squeeze fix if we are on an MPS device,
+    # as this change breaks tests on other platforms that
+    # expect the original tensor shape to be preserved.
+    if (
+        torch.backends.mps.is_available()
+        and target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch, does reduction of `sum` over all rows,
     # change reduction to `none` to keep dim
     if from_logits:
@@ -1092,3 +1194,26 @@ def dot_product_attention(
         scale=scale,
     )
     return torch.transpose(attention_output, axis1, axis0)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Native PyTorch implementation of Unfold.
+    Extracts sliding local blocks from an **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW).
+        dilation: int or (dH, dW), default 1.
+        padding: int or (pH, pW), default 0.
+        stride: int or (sH, sW), default 1.
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L).
+    """
+    return tnn.unfold(
+        input,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
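
A shape check for `unfold` (not from the diff): each output column is one flattened kH x kW patch across all channels:

    import torch
    import torch.nn.functional as tnn

    x = torch.randn(1, 3, 8, 8)  # (N, C, H, W)
    blocks = tnn.unfold(x, kernel_size=2, stride=2)
    print(blocks.shape)  # torch.Size([1, 12, 16]) == (N, C*kH*kW, L)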