keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py CHANGED

@@ -813,16 +813,17 @@ def append(x1, x2, axis=None):
     return tf.concat([x1, x2], axis=axis)


-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
+    if step is None:
+        step = 1
     try:
         out = tf.range(start, stop, delta=step, dtype=dtype)
     except tf.errors.NotFoundError:
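The practical effect of the arange change: a defaulted `step` no longer participates in dtype inference, so integer `start`/`stop` keep an integer result type. A minimal sketch of the intended behavior, assuming the public `keras.ops.arange` wrapper:

    from keras import ops

    # step omitted: only start/stop drive dtype inference
    ops.arange(5)               # [0 1 2 3 4], integer dtype
    ops.arange(2, 10)           # [2 3 4 5 6 7 8 9]
    ops.arange(0.0, 1.0, 0.25)  # explicit float step: [0. 0.25 0.5 0.75]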
@@ -997,6 +998,51 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    from keras.src import backend
+
+    x = convert_to_tensor(x)
+    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
+    new_dtype = tf.as_dtype(
+        backend.standardize_dtype(dtype if dtype else x.dtype)
+    )
+
+    old_itemsize = old_dtype.size
+    new_itemsize = new_dtype.size
+
+    old_shape = list(shape_op(x))
+    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
+    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
+        last_dim_size * old_itemsize % new_itemsize != 0
+    ):
+        raise ValueError(
+            f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
+            f"as dtype {new_dtype} because the total number of bytes "
+            f"is not divisible by the new itemsize."
+        )
+
+    if old_itemsize == new_itemsize:
+        return tf.bitcast(x, type=new_dtype)
+    elif old_itemsize > new_itemsize:
+        ratio = old_itemsize // new_itemsize
+        new_shape = list(shape_op(x))
+        new_shape[-1] *= ratio
+        flat_tensor = tf.reshape(x, [-1])
+        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
+        return tf.reshape(cast_tensor, new_shape)
+    else:
+        ratio = new_itemsize // old_itemsize
+        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
+            raise ValueError(
+                f"Cannot view dtype. Last dimension size ({last_dim_size}) "
+                f"must be divisible by the ratio of new/old item sizes "
+                f"({ratio})."
+            )
+        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
+        reshaped_tensor = tf.reshape(x, intermediate_shape)
+        return tf.bitcast(reshaped_tensor, new_dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
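The new `view` op reinterprets a tensor's underlying bytes as another dtype via `tf.bitcast`, growing or shrinking the last axis when item sizes differ, in the spirit of `numpy.ndarray.view`. A rough usage sketch, assuming the op is exported as `keras.ops.view` (as the `ops/numpy/__init__.py` additions above suggest):

    import numpy as np
    from keras import ops

    x = ops.convert_to_tensor(np.arange(4, dtype="float32"))  # 16 bytes
    ops.view(x, "int32")    # same itemsize: plain bitcast, shape (4,)
    ops.view(x, "uint8")    # smaller itemsize: shape (16,)
    ops.view(x, "float64")  # larger itemsize: shape (2,); last dim must divide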
@@ -1313,11 +1359,7 @@ def deg2rad(x):
 def diag(x, k=0):
     x = convert_to_tensor(x)
     if len(x.shape) == 1:
-        return tf.cond(
-            tf.equal(tf.size(x), 0),
-            lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
-            lambda: tf.linalg.diag(x, k=k),
-        )
+        return tf.linalg.diag(x, k=k)
     elif len(x.shape) == 2:
         return diagonal(x, offset=k)
     else:
@@ -1443,6 +1485,10 @@ def empty(shape, dtype=None):
     return tf.zeros(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return tf.zeros_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1711,6 +1757,14 @@ def isposinf(x):
     return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))


+def isreal(x):
+    x = convert_to_tensor(x)
+    if x.dtype.is_complex:
+        return tf.equal(tf.math.imag(x), 0)
+    else:
+        return tf.ones_like(x, dtype=tf.bool)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1786,6 +1840,23 @@ def lcm(x1, x2):
     return result


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, x1.dtype)
+    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
+    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
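`ldexp` composes a float from mantissa and exponent, x1 * 2**x2, as in NumPy; the final `tf.where` passes infinities and zeros through untouched. A quick numeric check, assuming the `keras.ops.ldexp` export:

    from keras import ops

    ops.ldexp(1.5, 3)              # 1.5 * 2**3 = 12.0
    ops.ldexp([1.0, 2.0], [1, 2])  # [1*2, 2*4] = [2.0, 8.0]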
@@ -1909,6 +1980,22 @@ def logaddexp(x1, x2):
     )


+def logaddexp2(x1, x2):
+    x1 = tf.convert_to_tensor(x1)
+    x2 = tf.convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, dtype)
+    delta = x1 - x2
+    log2 = tf.cast(tf.math.log(2.0), dtype)
+    return tf.where(
+        tf.math.is_nan(delta),
+        x1 + x2,
+        tf.maximum(x1, x2)
+        + tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2,
+    )
+
+
 def logical_and(x1, x2):
     x1 = tf.cast(x1, "bool")
     x2 = tf.cast(x2, "bool")
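`logaddexp2` evaluates log2(2**x1 + 2**x2) in the standard overflow-safe form max(x1, x2) + log1p(2**-|x1 - x2|) / ln 2; the `is_nan` branch propagates NaN (and matching infinities) the way NumPy does. The identity, checked in plain Python:

    import math

    a, b = 3.0, 3.0
    d = abs(a - b)
    # log2(2**3 + 2**3) = log2(16) = 4
    print(max(a, b) + math.log1p(math.exp(-d * math.log(2))) / math.log(2))  # 4.0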
@@ -2233,7 +2320,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
         return gathered_y
     perm = collections.deque(range(ndims))
     perm.rotate(shift_value_static)
-    return tf.transpose(a=gathered_y, perm=perm)
+    return tf.transpose(a=gathered_y, perm=list(perm))


 def quantile(x, q, axis=None, method="linear", keepdims=False):
@@ -2426,6 +2513,17 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
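Unlike `split`, `array_split` accepts a section count that need not divide the axis length: the first `remainder` sections receive one extra element, mirroring NumPy. For example, splitting 10 elements into 3 sections:

    import numpy as np

    # total_size=10, num_splits=3 -> avg_size=3, remainder=1
    # sizes = [4] * 1 + [3] * 2 = [4, 3, 3]
    parts = np.array_split(np.arange(10), 3)
    print([len(p) for p in parts])  # [4, 3, 3]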
@@ -2657,27 +2755,44 @@ def round(x, decimals=0):

 def tile(x, repeats):
     x = convert_to_tensor(x)
-    repeats = tf.reshape(convert_to_tensor(repeats, dtype="int32"), [-1])
-    repeats_size = tf.size(repeats)
-    repeats = tf.pad(
-        repeats,
-        [[tf.maximum(x.shape.rank - repeats_size, 0), 0]],
-        constant_values=1,
-    )
-    x_shape = tf.pad(
-        tf.shape(x),
-        [[tf.maximum(repeats_size - tf.rank(x), 0), 0]],
-        constant_values=1,
-    )
-    x = tf.reshape(x, x_shape)
+
+    # Convert repeats to a list (works for both sequences and 1D tensors)
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    else:
+        repeats = [v for v in repeats]
+
+    # Process list elements: convert concrete scalar tensors to Python ints
+    processed_repeats = []
+    for r in repeats:
+        if hasattr(r, "numpy") and r.shape == ():
+            processed_repeats.append(int(r.numpy()))
+        else:
+            processed_repeats.append(r)
+    repeats = processed_repeats
+
+    # Get x rank
+    x_rank = x.shape.rank
+
+    # Pad repeats if needed
+    if len(repeats) < x_rank:
+        repeats = [1] * (x_rank - len(repeats)) + repeats
+
+    # Add dimensions to x if needed using tf.expand_dims
+    while len(repeats) > x.shape.rank:
+        x = tf.expand_dims(x, 0)
+
     return tf.tile(x, repeats)


 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype not in ("int64", "uint32", "uint64"):
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
+    x = tf.cast(x, dtype)
     x_shape = tf.shape(x)
     x = moveaxis(x, (axis1, axis2), (-2, -1))
     # Mask out the diagonal and reduce.
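The rewritten `tile` reproduces NumPy's broadcasting of `repeats`: too-short repeats are left-padded with ones, and a too-short input gains leading axes. The NumPy reference behavior the new code follows:

    import numpy as np

    print(np.tile(np.ones((2, 3)), 2).shape)  # (2, 6): repeats padded to (1, 2)
    print(np.tile(np.ones(3), (2, 2)).shape)  # (2, 6): x promoted to shape (1, 3)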
@@ -2686,10 +2801,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
         x,
         tf.zeros_like(x),
     )
-
-    if standardize_dtype(x.dtype) == "bool":
-        x = tf.cast(x, "int32")
-    return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
+    return tf.reduce_sum(x, axis=(-2, -1))


 def tri(N, M=None, k=0, dtype=None):
@@ -2905,6 +3017,16 @@ def negative(x):
     return tf.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, tf.float64)
+    x2 = tf.cast(x2, tf.float64)
+    return tf.cast(tf.math.nextafter(x1, x2), dtype)
+
+
 @sparse.elementwise_unary
 def square(x):
     x = convert_to_tensor(x)
@@ -2959,6 +3081,63 @@ def transpose(x, axes=None):
     return tf.transpose(x, perm=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    def _move_axis_to_last(tensor, axis):
+        if axis == -1:
+            return tensor
+        rank = tf.rank(tensor)
+        if axis < 0:
+            axis = rank + axis
+        perm = tf.concat(
+            [
+                tf.range(axis, dtype=tf.int32),
+                tf.range(axis + 1, rank, dtype=tf.int32),
+                tf.constant([axis], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.transpose(tensor, perm=perm)
+
+    y = convert_to_tensor(y)
+    dtype = dtypes.result_type(y.dtype, float)
+    y = tf.cast(y, dtype)
+
+    if x is None:
+        dx_array = tf.cast(dx, dtype)
+    else:
+        x = convert_to_tensor(x, dtype=dtype)
+        dx_array = diff(x, axis=axis)
+        dx_array = _move_axis_to_last(dx_array, axis)
+
+    y = _move_axis_to_last(y, axis)
+
+    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
+    result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
+
+    return result
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+
+    if N is None:
+        N = shape_op(x)[0]
+
+    if increasing:
+        powers = tf.range(N)
+    else:
+        powers = tf.range(N - 1, -1, -1)
+
+    x_exp = tf.expand_dims(x, axis=-1)
+
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    vander = tf.math.pow(
+        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
+    )
+    return tf.cast(vander, result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
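`trapezoid` is the composite trapezoidal rule, summing 0.5 * (y[i] + y[i+1]) * dx along the chosen axis, and `vander` raises each element to descending powers by default. Both checked by hand:

    y, dx = [0.0, 1.0, 4.0], 1.0
    print(sum(0.5 * (y[i] + y[i + 1]) * dx for i in range(len(y) - 1)))  # 3.0

    x, N = [1, 2, 3], 3
    print([[xi ** p for p in range(N - 1, -1, -1)] for xi in x])
    # [[1, 1, 1], [4, 2, 1], [9, 3, 1]]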
@@ -3069,30 +3248,57 @@ def correlate(x1, x2, mode="valid"):
     x1 = tf.cast(x1, dtype)
     x2 = tf.cast(x2, dtype)

-    x1_len, x2_len = int(x1.shape[0]), int(x2.shape[0])
-
-    if mode == "full":
-        full_len = x1_len + x2_len - 1
-
-        x1_pad = (full_len - x1_len) / 2
-        x2_pad = (full_len - x2_len) / 2
-
-        x1 = tf.pad(
-            x1, paddings=[[tf.math.floor(x1_pad), tf.math.ceil(x1_pad)]]
-        )
-        x2 = tf.pad(
-            x2, paddings=[[tf.math.floor(x2_pad), tf.math.ceil(x2_pad)]]
-        )
-
-        x1 = tf.reshape(x1, (1, full_len, 1))
-        x2 = tf.reshape(x2, (full_len, 1, 1))
-
-        return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
-
-    x1 = tf.reshape(x1, (1, x1_len, 1))
-    x2 = tf.reshape(x2, (x2_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
+    def _pack(a, b):
+        # a: input [N] -> [1,N,1];
+        # b: filter [M] -> [M,1,1]
+        return (
+            tf.reshape(a, (1, shape_op(a)[0], 1)),
+            tf.reshape(b, (shape_op(b)[0], 1, 1)),
+        )
+
+    def _full_corr(x1, x2):
+        """Compute 'full' correlation result (length = n + m - 1)."""
+        m = shape_op(x2)[0]
+        pad = (
+            builtins.max(m - 1, 0)
+            if isinstance(m, int)
+            else tf.maximum(m - 1, 0)
+        )
+        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros
+        x1, x2 = _pack(x1, x2)
+        out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
+        return tf.squeeze(out, axis=[0, 2])
+
+    n = shape_op(x1)[0]
+    m = shape_op(x2)[0]
+
+    if mode == "full":
+        return _full_corr(x1, x2)
+    elif mode == "same":
+        # unfortunately we can't leverage 'SAME' padding directly like
+        # we can with "valid"
+        # it works fine for odd-length filters, but for even-length filters
+        # the output is off by 1 compared to numpy, due to how
+        # tf handles centering
+        full_corr = _full_corr(x1, x2)
+        full_len = n + m - 1
+        out_len = (
+            max([n, m])
+            if isinstance(n, int) and isinstance(m, int)
+            else tf.maximum(n, m)
+        )
+        start = (full_len - out_len) // 2
+        return tf.slice(full_corr, [start], [out_len])
+    elif mode == "valid":
+        x1, x2 = _pack(x1, x2)
+        return tf.squeeze(
+            tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
+        )
+    else:
+        raise ValueError(
+            f"Invalid mode: '{mode}'. Mode must be one of:"
+            f" 'full', 'same', 'valid'."
+        )


 def select(condlist, choicelist, default=0):
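The correlate rewrite routes every mode through one 'full' correlation of length n + m - 1: 'same' slices a centered window of length max(n, m), and 'valid' maps directly onto VALID-padded conv1d. Output lengths against NumPy:

    import numpy as np

    a, v = np.arange(5.0), np.ones(2)
    print(np.correlate(a, v, "full").shape)   # (6,): n + m - 1
    print(np.correlate(a, v, "same").shape)   # (5,): max(n, m)
    print(np.correlate(a, v, "valid").shape)  # (4,): max(n, m) - min(n, m) + 1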
@@ -3144,10 +3350,14 @@ def histogram(x, bins=10, range=None):

     x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
     bin_edges = tf.linspace(min_val, max_val, bins + 1)
-    bin_indices = tf.[...]
-    [the remaining four removed lines are not recoverable from this extraction]
+    bin_edges = tf.cast(bin_edges, x.dtype)
+    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
+
+    # tf.math.bincount does not work with XLA in this case. So, we use
+    # `scatter_nd`.
+    bin_counts = tf.scatter_nd(
+        indices=tf.expand_dims(bin_indices, axis=-1),
+        updates=tf.ones_like(bin_indices, dtype=x.dtype),
+        shape=(bins,),
     )
     return bin_counts, bin_edges
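The new binning maps each value to a bucket with `searchsorted` over the interior edges only (so the last bin stays closed on the right), then accumulates counts with `scatter_nd`, which, unlike `tf.math.bincount`, compiles under XLA. The same indexing logic in NumPy terms:

    import numpy as np

    x = np.array([0.1, 0.5, 0.9, 0.5])
    edges = np.linspace(0.0, 1.0, 5)                     # 4 bins
    idx = np.searchsorted(edges[1:-1], x, side="right")  # interior edges only
    counts = np.zeros(4)
    np.add.at(counts, idx, 1)
    print(counts)  # [1. 0. 2. 1.]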
keras/src/backend/torch/core.py CHANGED

@@ -673,7 +673,9 @@ def remat(f):
     """

     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, **kwargs)
+        return torch.utils.checkpoint.checkpoint(
+            f, *args, use_reentrant=False, **kwargs
+        )

     return wrapped
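`use_reentrant=False` opts into PyTorch's non-reentrant activation checkpointing, the recommended variant (calling `checkpoint` without the argument is deprecated and emits a warning). A minimal sketch of what the wrapper does:

    import torch
    from torch.utils.checkpoint import checkpoint

    def f(x):
        return torch.relu(x) ** 2

    x = torch.randn(4, requires_grad=True)
    # f's forward activations are recomputed during backward, saving memory.
    y = checkpoint(f, x, use_reentrant=False).sum()
    y.backward()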
keras/src/backend/torch/nn.py CHANGED

@@ -458,6 +458,94 @@ def average_pool(
     return outputs


+def adaptive_average_pool(inputs, output_size, data_format=None):
+    """Adaptive average pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive average pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    """Adaptive max pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive max pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    outputs = res[0] if isinstance(res, tuple) else res
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
 def conv(
     inputs,
     kernel,
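Torch's adaptive pools are channels-first only, so the wrappers transpose channels_last inputs on the way in and back on the way out; the output spatial size is exactly what was requested, regardless of input size. The underlying torch calls the wrappers delegate to:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 17, 9)          # NCHW
    y = F.adaptive_avg_pool2d(x, (4, 4))  # -> (2, 3, 4, 4), for any input HxW
    z = F.adaptive_max_pool2d(x, (4, 4))  # -> (2, 3, 4, 4)
    print(y.shape, z.shape)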
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)

+    # We only apply the squeeze fix if we are on an MPS device,
+    # as this change breaks tests on other platforms that
+    # expect the original tensor shape to be preserved.
+    if (
+        torch.backends.mps.is_available()
+        and target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch, does reduction of `sum` over all rows,
     # change reduction to `none` to keep dim
     if from_logits:
@@ -1092,3 +1194,26 @@ def dot_product_attention(
         scale=scale,
     )
     return torch.transpose(attention_output, axis1, axis0)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Native PyTorch implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+    return tnn.unfold(
+        input,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
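`unfold`'s output dimensions follow the usual convolution arithmetic: the channel axis becomes C*kH*kW, and L is the number of patch positions, floor((H + 2p - d(kH - 1) - 1) / s) + 1 times the same expression for W. For example:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    blocks = F.unfold(x, kernel_size=2, stride=2)
    # rows = 3 * 2 * 2 = 12; L = (8 / 2) * (8 / 2) = 16
    print(blocks.shape)  # torch.Size([1, 12, 16])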