keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py

@@ -813,16 +813,17 @@ def append(x1, x2, axis=None):
     return tf.concat([x1, x2], axis=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
+    if step is None:
+        step = 1
     try:
         out = tf.range(start, stop, delta=step, dtype=dtype)
     except tf.errors.NotFoundError:
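
Note: with step defaulting to None, the step only participates in dtype inference when it is actually passed, which matches np.arange. A quick NumPy sketch of the behavior the backend mirrors:

import numpy as np

print(np.arange(5).dtype)          # default integer dtype (no step given)
print(np.arange(0, 5, 0.5).dtype)  # float64 once a float step is supplied
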
@@ -997,6 +998,51 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    from keras.src import backend
+
+    x = convert_to_tensor(x)
+    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
+    new_dtype = tf.as_dtype(
+        backend.standardize_dtype(dtype if dtype else x.dtype)
+    )
+
+    old_itemsize = old_dtype.size
+    new_itemsize = new_dtype.size
+
+    old_shape = list(shape_op(x))
+    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
+    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
+        last_dim_size * old_itemsize % new_itemsize != 0
+    ):
+        raise ValueError(
+            f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
+            f"as dtype {new_dtype} because the total number of bytes "
+            f"is not divisible by the new itemsize."
+        )
+
+    if old_itemsize == new_itemsize:
+        return tf.bitcast(x, type=new_dtype)
+    elif old_itemsize > new_itemsize:
+        ratio = old_itemsize // new_itemsize
+        new_shape = list(shape_op(x))
+        new_shape[-1] *= ratio
+        flat_tensor = tf.reshape(x, [-1])
+        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
+        return tf.reshape(cast_tensor, new_shape)
+    else:
+        ratio = new_itemsize // old_itemsize
+        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
+            raise ValueError(
+                f"Cannot view dtype. Last dimension size ({last_dim_size}) "
+                f"must be divisible by the ratio of new/old item sizes "
+                f"({ratio})."
+            )
+        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
+        reshaped_tensor = tf.reshape(x, intermediate_shape)
+        return tf.bitcast(reshaped_tensor, new_dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
 
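
Note: the new view op reinterprets the same underlying bytes under a different dtype via tf.bitcast, reshaping the last axis when the item sizes differ. A minimal NumPy sketch of the semantics it mirrors (np.ndarray.view):

import numpy as np

x = np.array([1, 2, 3], dtype=np.int32)   # 3 elements x 4 bytes
as_bytes = x.view(np.uint8)               # same 12 bytes, shape (12,)
print(as_bytes.shape)                     # (12,)
back = as_bytes.view(np.int32)            # widen again, shape (3,)
print(np.array_equal(back, x))            # True
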
@@ -1313,11 +1359,7 @@ def deg2rad(x):
 def diag(x, k=0):
     x = convert_to_tensor(x)
     if len(x.shape) == 1:
-        return tf.cond(
-            tf.equal(tf.size(x), 0),
-            lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
-            lambda: tf.linalg.diag(x, k=k),
-        )
+        return tf.linalg.diag(x, k=k)
     elif len(x.shape) == 2:
         return diagonal(x, offset=k)
     else:
@@ -1443,6 +1485,10 @@ def empty(shape, dtype=None):
     return tf.zeros(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return tf.zeros_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1711,6 +1757,106 @@ def isposinf(x):
     return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    if x.dtype.is_complex:
+        return tf.equal(tf.math.imag(x), 0)
+    else:
+        return tf.ones_like(x, dtype=tf.bool)
+
+
+def kron(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
+
+    ndim_x1 = tf.rank(x1)
+    ndim_x2 = tf.rank(x2)
+
+    def expand_front(x, num):
+        for _ in range(num):
+            x = tf.expand_dims(x, axis=0)
+        return x
+
+    x1 = tf.cond(
+        ndim_x1 < ndim_x2,
+        lambda: expand_front(x1, ndim_x2 - ndim_x1),
+        lambda: x1,
+    )
+    x2 = tf.cond(
+        ndim_x2 < ndim_x1,
+        lambda: expand_front(x2, ndim_x1 - ndim_x2),
+        lambda: x2,
+    )
+
+    x1_reshaped = tf.reshape(
+        x1,
+        tf.reshape(
+            tf.stack([tf.shape(x1), tf.ones_like(tf.shape(x1))], axis=1), [-1]
+        ),
+    )
+    x2_reshaped = tf.reshape(
+        x2,
+        tf.reshape(
+            tf.stack([tf.ones_like(tf.shape(x2)), tf.shape(x2)], axis=1), [-1]
+        ),
+    )
+
+    out = tf.multiply(x1_reshaped, x2_reshaped)
+    out_shape = tf.multiply(tf.shape(x1), tf.shape(x2))
+    out = tf.reshape(out, out_shape)
+    return out
+
+
+def lcm(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if not (x1.dtype.is_integer and x2.dtype.is_integer):
+        raise TypeError(
+            f"Arguments to lcm must be integers. "
+            f"Received: x1.dtype={x1.dtype.name}, x2.dtype={x2.dtype.name}"
+        )
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
+
+    if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
+        x1 = tf.math.abs(x1)
+        x2 = tf.math.abs(x2)
+
+    divisor = gcd(x1, x2)
+    divisor_safe = tf.where(
+        divisor == 0, tf.constant(1, dtype=divisor.dtype), divisor
+    )
+
+    result = x1 * (x2 // divisor_safe)
+    result = tf.where(divisor == 0, tf.zeros_like(result), result)
+
+    return result
+
+
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, x1.dtype)
+    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
+    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
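
Note: isreal, kron, lcm, and ldexp follow their NumPy counterparts, including the zero-divisor guard in lcm. A minimal NumPy reference sketch of the expected values:

import numpy as np

print(np.kron([1, 2], [10, 20]))              # [10 20 20 40]
print(np.lcm(12, 18))                         # 36
print(np.lcm(0, 5))                           # 0 (the divisor-zero guard covers this)
print(np.ldexp(1.5, 3))                       # 12.0, i.e. 1.5 * 2**3
print(np.isreal(np.array([1 + 0j, 1 + 2j])))  # [ True False]
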
@@ -1834,6 +1980,22 @@ def logaddexp(x1, x2):
     )
 
 
+def logaddexp2(x1, x2):
+    x1 = tf.convert_to_tensor(x1)
+    x2 = tf.convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
+    delta = x1 - x2
+    log2 = tf.cast(tf.math.log(2.0), dtype)
+    return tf.where(
+        tf.math.is_nan(delta),
+        x1 + x2,
+        tf.maximum(x1, x2)
+        + tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2,
+    )
+
+
 def logical_and(x1, x2):
     x1 = tf.cast(x1, "bool")
     x2 = tf.cast(x2, "bool")
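
Note: logaddexp2 computes log2(2**x1 + 2**x2) using the max-plus-log1p form above so large inputs do not overflow. A small NumPy sketch of the identity being used:

import numpy as np

x1, x2 = 3.0, 5.0
direct = np.log2(2.0**x1 + 2.0**x2)
stable = max(x1, x2) + np.log1p(np.exp(-abs(x1 - x2) * np.log(2.0))) / np.log(2.0)
print(direct, stable, np.logaddexp2(x1, x2))  # all ~5.3219
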
@@ -1989,7 +2151,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
 
 def ndim(x):
     x = convert_to_tensor(x)
-    return x.ndim
+    return x.shape.rank
 
 
 def nonzero(x):
@@ -2053,6 +2215,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
+        x, axis=axis, keepdims=keepdims
+    )
+
+
 def _quantile(x, q, axis=None, method="linear", keepdims=False):
     # ref: tfp.stats.percentile
     # float64 is needed here and below, else we get the wrong index if the array
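
Note: ptp is the peak-to-peak range (max - min) along the given axis, as in NumPy:

import numpy as np

x = np.array([[4, 9, 2], [3, 7, 8]])
print(np.ptp(x))          # 7 (9 - 2)
print(np.ptp(x, axis=0))  # [1 2 6]
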
@@ -2158,7 +2327,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
         return gathered_y
     perm = collections.deque(range(ndims))
     perm.rotate(shift_value_static)
-    return tf.transpose(a=gathered_y, perm=perm)
+    return tf.transpose(a=gathered_y, perm=list(perm))
 
 
 def quantile(x, q, axis=None, method="linear", keepdims=False):
@@ -2256,8 +2425,11 @@ def searchsorted(sorted_sequence, values, side="left"):
             "to extend it to N-D sequences. Received: "
             f"sorted_sequence.shape={sorted_sequence.shape}"
         )
+    sequence_len = sorted_sequence.shape[0]
     out_type = (
-        "int32" if len(sorted_sequence) <= np.iinfo(np.int32).max else "int64"
+        "int32"
+        if sequence_len is not None and sequence_len <= np.iinfo(np.int32).max
+        else "int64"
     )
     return tf.searchsorted(
         sorted_sequence, values, side=side, out_type=out_type
@@ -2348,6 +2520,17 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
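
Note: unlike split, array_split accepts a section count that does not evenly divide the axis; the leading chunks get one extra element. The backend version above handles an integer number of sections, matching this NumPy behavior:

import numpy as np

x = np.arange(7)
print([part.tolist() for part in np.array_split(x, 3)])
# [[0, 1, 2], [3, 4], [5, 6]]  -> sizes 3, 2, 2
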
@@ -2579,27 +2762,44 @@ def round(x, decimals=0):
 
 def tile(x, repeats):
     x = convert_to_tensor(x)
-    repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
-    repeats_size = tf.size(repeats)
-    repeats = tf.pad(
-        repeats,
-        [[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
-        constant_values=1,
-    )
-    x_shape = tf.pad(
-        tf.shape(x),
-        [[tf.maximum(repeats_size - x.shape.rank, 0), 0]],
-        constant_values=1,
-    )
-    x = tf.reshape(x, x_shape)
+
+    # Convert repeats to a list (works for both sequences and 1D tensors)
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    else:
+        repeats = [v for v in repeats]
+
+    # Process list elements: convert concrete scalar tensors to Python ints
+    processed_repeats = []
+    for r in repeats:
+        if hasattr(r, "numpy") and r.shape == ():
+            processed_repeats.append(int(r.numpy()))
+        else:
+            processed_repeats.append(r)
+    repeats = processed_repeats
+
+    # Get x rank
+    x_rank = x.shape.rank
+
+    # Pad repeats if needed
+    if len(repeats) < x_rank:
+        repeats = [1] * (x_rank - len(repeats)) + repeats
+
+    # Add dimensions to x if needed using tf.expand_dims
+    while len(repeats) > x.shape.rank:
+        x = tf.expand_dims(x, 0)
+
     return tf.tile(x, repeats)
 
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype not in ("int64", "uint32", "uint64"):
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    x = tf.cast(x, dtype)
     x_shape = tf.shape(x)
     x = moveaxis(x, (axis1, axis2), (-2, -1))
     # Mask out the diagonal and reduce.
@@ -2608,10 +2808,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
         x,
         tf.zeros_like(x),
     )
-    # The output dtype is set to "int32" if the input dtype is "bool"
-    if standardize_dtype(x.dtype) == "bool":
-        x = tf.cast(x, "int32")
-    return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
+    return tf.reduce_sum(x, axis=(-2, -1))
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -2827,6 +3024,16 @@ def negative(x):
     return tf.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, tf.float64)
+    x2 = tf.cast(x2, tf.float64)
+    return tf.cast(tf.math.nextafter(x1, x2), dtype)
+
+
 @sparse.elementwise_unary
 def square(x):
     x = convert_to_tensor(x)
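
Note: nextafter returns the next representable float after x1 in the direction of x2 (computed in float64 above, then cast back). A NumPy sketch:

import numpy as np

print(np.nextafter(1.0, 2.0) - 1.0)  # ~2.22e-16, one ulp above 1.0
print(np.nextafter(1.0, 0.0) < 1.0)  # True: steps toward the second argument
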
@@ -2881,6 +3088,63 @@ def transpose(x, axes=None):
     return tf.transpose(x, perm=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    def _move_axis_to_last(tensor, axis):
+        if axis == -1:
+            return tensor
+        rank = tf.rank(tensor)
+        if axis < 0:
+            axis = rank + axis
+        perm = tf.concat(
+            [
+                tf.range(axis, dtype=tf.int32),
+                tf.range(axis + 1, rank, dtype=tf.int32),
+                tf.constant([axis], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.transpose(tensor, perm=perm)
+
+    y = convert_to_tensor(y)
+    dtype = dtypes.result_type(y.dtype, float)
+    y = tf.cast(y, dtype)
+
+    if x is None:
+        dx_array = tf.cast(dx, dtype)
+    else:
+        x = convert_to_tensor(x, dtype=dtype)
+        dx_array = diff(x, axis=axis)
+        dx_array = _move_axis_to_last(dx_array, axis)
+
+    y = _move_axis_to_last(y, axis)
+
+    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
+    result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
+
+    return result
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+
+    if N is None:
+        N = shape_op(x)[0]
+
+    if increasing:
+        powers = tf.range(N)
+    else:
+        powers = tf.range(N - 1, -1, -1)
+
+    x_exp = tf.expand_dims(x, axis=-1)
+
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    vander = tf.math.pow(
+        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
+    )
+    return tf.cast(vander, result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
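
Note: trapezoid integrates sampled values with the trapezoidal rule and vander builds a Vandermonde matrix, both following NumPy (np.trapezoid is named np.trapz on NumPy < 2.0). A minimal reference sketch:

import numpy as np

# Trapezoidal rule: integrate y = x**2 over [0, 1] with 101 samples.
x = np.linspace(0.0, 1.0, 101)
print(np.trapezoid(x**2, x))  # ~0.33335, close to 1/3

# Vandermonde matrix: decreasing powers by default.
print(np.vander([1, 2, 3], N=3))
# [[1 1 1]
#  [4 2 1]
#  [9 3 1]]
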
@@ -2991,30 +3255,57 @@ def correlate(x1, x2, mode="valid"):
     x1 = tf.cast(x1, dtype)
     x2 = tf.cast(x2, dtype)
 
-    x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0])
+    def _pack(a, b):
+        # a: input [N] -> [1,N,1];
+        # b: filter [M] -> [M,1,1]
+        return (
+            tf.reshape(a, (1, shape_op(a)[0], 1)),
+            tf.reshape(b, (shape_op(b)[0], 1, 1)),
+        )
 
-    if mode == "full":
-        full_len = x1_len + x2_len - 1
+    def _full_corr(x1, x2):
+        """Compute 'full' correlation result (length = n + m - 1)."""
+        m = shape_op(x2)[0]
+        pad = (
+            builtins.max(m - 1, 0)
+            if isinstance(m, int)
+            else tf.maximum(m - 1, 0)
+        )
+        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros
+        x1, x2 = _pack(x1, x2)
+        out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
+        return tf.squeeze(out, axis=[0, 2])
 
-        x1_pad = (full_len - x1_len) / 2
-        x2_pad = (full_len - x2_len) / 2
+    n = shape_op(x1)[0]
+    m = shape_op(x2)[0]
 
-        x1 = tf.pad(
-            x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]]
+    if mode == "full":
+        return _full_corr(x1, x2)
+    elif mode == "same":
+        # unfortunately we can't leverage 'SAME' padding directly like
+        # we can with "valid"
+        # it works fine for odd-length filters, but for even-length filters
+        # the output is off by 1 compared to numpy, due to how
+        # tf handles centering
+        full_corr = _full_corr(x1, x2)
+        full_len = n + m - 1
+        out_len = (
+            max([n, m])
+            if isinstance(n, int) and isinstance(m, int)
+            else tf.maximum(n, m)
         )
-        x2 = tf.pad(
-            x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]]
+        start = (full_len - out_len) // 2
+        return tf.slice(full_corr, [start], [out_len])
+    elif mode == "valid":
+        x1, x2 = _pack(x1, x2)
+        return tf.squeeze(
+            tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
+        )
+    else:
+        raise ValueError(
+            f"Invalid mode: '{mode}'. Mode must be one of:"
+            f" 'full', 'same', 'valid'."
         )
-
-        x1 = tf.reshape(x1, (1, full_len, 1))
-        x2 = tf.reshape(x2, (full_len, 1, 1))
-
-        return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
-
-    x1 = tf.reshape(x1, (1, x1_len, 1))
-    x2 = tf.reshape(x2, (x2_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
 
 
 def select(condlist, choicelist, default=0):
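
Note: the rewritten correlate implements all three NumPy modes explicitly, computing the "full" result by zero-padding the input and slicing it for mode="same". The NumPy values it aims to reproduce:

import numpy as np

a = np.array([1.0, 2.0, 3.0])
v = np.array([0.0, 1.0, 0.5])
print(np.correlate(a, v, mode="valid"))  # [3.5]
print(np.correlate(a, v, mode="same"))   # [2.  3.5 3. ]
print(np.correlate(a, v, mode="full"))   # [0.5 2.  3.5 3.  0. ]
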
@@ -3066,10 +3357,14 @@ def histogram(x, bins=10, range=None):
 
     x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
     bin_edges = tf.linspace(min_val, max_val, bins + 1)
-    bin_edges_list = bin_edges.numpy().tolist()
-    bin_indices = tf.raw_ops.Bucketize(input=x, boundaries=bin_edges_list[1:-1])
-
-    bin_counts = tf.math.bincount(
-        bin_indices, minlength=bins, maxlength=bins, dtype=x.dtype
+    bin_edges = tf.cast(bin_edges, x.dtype)
+    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
+
+    # tf.math.bincount does not work with XLA in this case. So, we use
+    # `scatter_nd`.
+    bin_counts = tf.scatter_nd(
+        indices=tf.expand_dims(bin_indices, axis=-1),
+        updates=tf.ones_like(bin_indices, dtype=x.dtype),
+        shape=(bins,),
     )
     return bin_counts, bin_edges
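
Note: histogram binning now uses tf.searchsorted plus a scatter_nd scatter-add instead of Bucketize/bincount, which keeps the op XLA-compatible. The same binning expressed with NumPy (np.add.at playing the role of the scatter-add):

import numpy as np

x = np.array([0.1, 0.4, 0.4, 0.9])
bins = 4
edges = np.linspace(0.0, 1.0, bins + 1)              # [0.  0.25 0.5  0.75 1. ]
idx = np.searchsorted(edges[1:-1], x, side="right")  # bin index per sample
counts = np.zeros(bins, dtype=x.dtype)
np.add.at(counts, idx, 1.0)                          # scatter-add into the bins
print(counts)                                            # [1. 2. 0. 1.]
print(np.histogram(x, bins=bins, range=(0.0, 1.0))[0])   # [1 2 0 1]
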
keras/src/backend/tensorflow/trainer.py

@@ -68,7 +68,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
             )
             self._loss_tracker.update_state(
                 loss_module.unscale_loss_for_distribution(loss),
-                sample_weight=tf.shape(tree.flatten(x)[0])[0],
+                sample_weight=tf.shape(
+                    next(i for i in tree.flatten(x) if i is not None)
+                )[0],
             )
             if self.optimizer is not None:
                 loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
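
Note: both the train and test steps now take the batch size from the first non-None leaf of the flattened inputs, so input structures containing optional (None) entries no longer break the loss tracker. A plain-Python sketch of the selection logic (the nested structure here is hypothetical):

flat = [None, [[1, 2], [3, 4]], None]            # stand-in for tree.flatten(x)
first = next(i for i in flat if i is not None)   # picks [[1, 2], [3, 4]]
batch_size = len(first)                          # tf.shape(first)[0] in the trainer
print(batch_size)                                # 2
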
keras/src/backend/torch/core.py

@@ -673,7 +673,9 @@ def remat(f):
     """
 
     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, use_reentrant=False)
+        return torch.utils.checkpoint.checkpoint(
+            f, *args, use_reentrant=False, **kwargs
+        )
 
     return wrapped
 
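
Note: keyword arguments passed to a rematerialized function are now forwarded to torch.utils.checkpoint.checkpoint; PyTorch forwards them to the wrapped function only in non-reentrant mode (use_reentrant=False). A minimal PyTorch sketch:

import torch
from torch.utils.checkpoint import checkpoint

def f(x, scale=1.0):
    return torch.sin(x) * scale

x = torch.randn(4, requires_grad=True)
out = checkpoint(f, x, use_reentrant=False, scale=2.0)  # scale reaches f
out.sum().backward()
print(x.grad is not None)  # True
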
keras/src/backend/torch/linalg.py

@@ -80,3 +80,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return torch.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    return torch.func.jvp(fun, primals, tangents, has_aux=has_aux)
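
Note: the new torch backend jvp is a thin wrapper over torch.func.jvp, which returns the primal output together with the Jacobian-vector product. A minimal PyTorch sketch:

import torch

def f(x):
    return x**3

primals = (torch.tensor(2.0),)
tangents = (torch.tensor(1.0),)
out, jvp_out = torch.func.jvp(f, primals, tangents)
print(out.item(), jvp_out.item())  # 8.0 12.0 -> f(2) and f'(2) = 3 * 2**2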