keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/tensorflow/numpy.py
CHANGED

@@ -998,6 +998,51 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    from keras.src import backend
+
+    x = convert_to_tensor(x)
+    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
+    new_dtype = tf.as_dtype(
+        backend.standardize_dtype(dtype if dtype else x.dtype)
+    )
+
+    old_itemsize = old_dtype.size
+    new_itemsize = new_dtype.size
+
+    old_shape = list(shape_op(x))
+    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
+    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
+        last_dim_size * old_itemsize % new_itemsize != 0
+    ):
+        raise ValueError(
+            f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
+            f"as dtype {new_dtype} because the total number of bytes "
+            f"is not divisible by the new itemsize."
+        )
+
+    if old_itemsize == new_itemsize:
+        return tf.bitcast(x, type=new_dtype)
+    elif old_itemsize > new_itemsize:
+        ratio = old_itemsize // new_itemsize
+        new_shape = list(shape_op(x))
+        new_shape[-1] *= ratio
+        flat_tensor = tf.reshape(x, [-1])
+        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
+        return tf.reshape(cast_tensor, new_shape)
+    else:
+        ratio = new_itemsize // old_itemsize
+        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
+            raise ValueError(
+                f"Cannot view dtype. Last dimension size ({last_dim_size}) "
+                f"must be divisible by the ratio of new/old item sizes "
+                f"({ratio})."
+            )
+        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
+        reshaped_tensor = tf.reshape(x, intermediate_shape)
+        return tf.bitcast(reshaped_tensor, new_dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
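For reference (not part of the diff): the new `view` mirrors NumPy's byte-reinterpreting `ndarray.view`, keeping the same underlying bytes while rescaling the last dimension by the itemsize ratio. A minimal NumPy illustration of the shape rules the `tf.bitcast` branches above reproduce:

    import numpy as np

    # 4 int32 values (4 bytes each) viewed as int16 (2 bytes each):
    # the last dimension doubles, total bytes stay the same.
    a = np.arange(4, dtype=np.int32)
    assert a.view(np.int16).shape == (8,)

    # Viewing as a wider dtype shrinks the last dimension; it must be
    # divisible by the itemsize ratio, or the view is rejected, matching
    # the ValueError in the TensorFlow code above.
    b = np.arange(8, dtype=np.int16)
    assert b.view(np.int32).shape == (4,)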
@@ -1314,11 +1359,7 @@ def deg2rad(x):
 def diag(x, k=0):
     x = convert_to_tensor(x)
     if len(x.shape) == 1:
-        return tf.cond(
-            tf.equal(tf.size(x), 0),
-            lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
-            lambda: tf.linalg.diag(x, k=k),
-        )
+        return tf.linalg.diag(x, k=k)
     elif len(x.shape) == 2:
         return diagonal(x, offset=k)
     else:
@@ -1444,6 +1485,10 @@ def empty(shape, dtype=None):
     return tf.zeros(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return tf.zeros_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1712,6 +1757,14 @@ def isposinf(x):
     return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))


+def isreal(x):
+    x = convert_to_tensor(x)
+    if x.dtype.is_complex:
+        return tf.equal(tf.math.imag(x), 0)
+    else:
+        return tf.ones_like(x, dtype=tf.bool)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1787,6 +1840,23 @@ def lcm(x1, x2):
     return result


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    x1 = tf.cast(x1, dtype)
+    x2 = tf.cast(x2, x1.dtype)
+    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
+    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
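`ldexp(x1, x2)` computes `x1 * 2**x2`, with `inf` and `0` passed through untouched by the `tf.where` guard. The NumPy equivalent, for comparison:

    import numpy as np

    assert np.ldexp(5.0, 3) == 40.0        # 5 * 2**3
    assert np.ldexp(np.inf, 1) == np.inf   # inf is preserved
    assert np.ldexp(0.0, 10) == 0.0        # zero is preserved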
@@ -2081,7 +2151,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):

 def ndim(x):
     x = convert_to_tensor(x)
-    return x.ndim
+    return x.shape.rank


 def nonzero(x):
@@ -2145,6 +2215,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return tf.reduce_prod(x, axis=axis, keepdims=keepdims)


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
+        x, axis=axis, keepdims=keepdims
+    )
+
+
 def _quantile(x, q, axis=None, method="linear", keepdims=False):
     # ref: tfp.stats.percentile
     # float64 is needed here and below, else we get the wrong index if the array
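`ptp` ("peak to peak") is simply `max - min` along an axis, exactly the `tf.reduce_max - tf.reduce_min` composition above. NumPy equivalent:

    import numpy as np

    x = np.array([[1, 9, 4], [7, 2, 8]])
    assert np.ptp(x) == 8                        # 9 - 1 over the whole array
    assert np.ptp(x, axis=1).tolist() == [8, 6]  # per-row ranges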
@@ -2250,7 +2327,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
         return gathered_y
     perm = collections.deque(range(ndims))
     perm.rotate(shift_value_static)
-    return tf.transpose(a=gathered_y, perm=perm)
+    return tf.transpose(a=gathered_y, perm=list(perm))


 def quantile(x, q, axis=None, method="linear", keepdims=False):
@@ -2443,6 +2520,17 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
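Unlike `split`, `array_split` accepts a section count that does not divide the axis evenly: the first `remainder` chunks get one extra element, which is what the `sizes` computation above produces (note this backend sketch handles an integer section count, not explicit split indices). NumPy comparison:

    import numpy as np

    parts = np.array_split(np.arange(10), 3)
    assert [len(p) for p in parts] == [4, 3, 3]  # 10 // 3 = 3, remainder 1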
@@ -2674,27 +2762,44 @@ def round(x, decimals=0):

 def tile(x, repeats):
     x = convert_to_tensor(x)
-    repeats ...
-    repeats ...
+
+    # Convert repeats to a list (works for both sequences and 1D tensors)
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    else:
+        repeats = [v for v in repeats]
+
+    # Process list elements: convert concrete scalar tensors to Python ints
+    processed_repeats = []
+    for r in repeats:
+        if hasattr(r, "numpy") and r.shape == ():
+            processed_repeats.append(int(r.numpy()))
+        else:
+            processed_repeats.append(r)
+    repeats = processed_repeats
+
+    # Get x rank
+    x_rank = x.shape.rank
+
+    # Pad repeats if needed
+    if len(repeats) < x_rank:
+        repeats = [1] * (x_rank - len(repeats)) + repeats
+
+    # Add dimensions to x if needed using tf.expand_dims
+    while len(repeats) > x.shape.rank:
+        x = tf.expand_dims(x, 0)
+
     return tf.tile(x, repeats)


 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype ...
-        dtype = ...
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
     x = tf.cast(x, dtype)
     x_shape = tf.shape(x)
     x = moveaxis(x, (axis1, axis2), (-2, -1))
     # Mask out the diagonal and reduce.
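The rewritten `tile` normalizes `repeats` eagerly: shorter-than-rank repeats are left-padded with ones, and a longer repeats list prepends length-1 dimensions to `x`. The NumPy behavior it matches:

    import numpy as np

    x = np.ones((2, 3))
    assert np.tile(x, 2).shape == (2, 6)             # repeats -> [1, 2]
    assert np.tile(x, (4, 1, 2)).shape == (4, 2, 6)  # x -> (1, 2, 3) first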
@@ -2703,10 +2808,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
         x,
         tf.zeros_like(x),
     )
-
-    if standardize_dtype(x.dtype) == "bool":
-        x = tf.cast(x, "int32")
-    return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
+    return tf.reduce_sum(x, axis=(-2, -1))


 def tri(N, M=None, k=0, dtype=None):
@@ -2922,6 +3024,16 @@ def negative(x):
     return tf.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = tf.cast(x1, tf.float64)
+    x2 = tf.cast(x2, tf.float64)
+    return tf.cast(tf.math.nextafter(x1, x2), dtype)
+
+
 @sparse.elementwise_unary
 def square(x):
     x = convert_to_tensor(x)
@@ -2976,6 +3088,63 @@ def transpose(x, axes=None):
     return tf.transpose(x, perm=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    def _move_axis_to_last(tensor, axis):
+        if axis == -1:
+            return tensor
+        rank = tf.rank(tensor)
+        if axis < 0:
+            axis = rank + axis
+        perm = tf.concat(
+            [
+                tf.range(axis, dtype=tf.int32),
+                tf.range(axis + 1, rank, dtype=tf.int32),
+                tf.constant([axis], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.transpose(tensor, perm=perm)
+
+    y = convert_to_tensor(y)
+    dtype = dtypes.result_type(y.dtype, float)
+    y = tf.cast(y, dtype)
+
+    if x is None:
+        dx_array = tf.cast(dx, dtype)
+    else:
+        x = convert_to_tensor(x, dtype=dtype)
+        dx_array = diff(x, axis=axis)
+        dx_array = _move_axis_to_last(dx_array, axis)
+
+    y = _move_axis_to_last(y, axis)
+
+    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
+    result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
+
+    return result
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+
+    if N is None:
+        N = shape_op(x)[0]
+
+    if increasing:
+        powers = tf.range(N)
+    else:
+        powers = tf.range(N - 1, -1, -1)
+
+    x_exp = tf.expand_dims(x, axis=-1)
+
+    compute_dtype = dtypes.result_type(x.dtype, "float32")
+    vander = tf.math.pow(
+        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
+    )
+    return tf.cast(vander, result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
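`trapezoid` applies the trapezoidal rule, summing `0.5 * (y[i] + y[i+1]) * dx` after moving the integration axis last; `vander` builds decreasing powers `x**(N-1-j)` by default. NumPy equivalents (`np.trapezoid` is the NumPy 2.x name; older versions call it `np.trapz`):

    import numpy as np

    y = np.array([0.0, 1.0, 2.0, 3.0])
    assert np.trapezoid(y, dx=1.0) == 4.5  # integral of y = x over [0, 3]

    v = np.vander(np.array([1, 2, 3]), N=3)
    assert v.tolist() == [[1, 1, 1], [4, 2, 1], [9, 3, 1]]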
@@ -3086,30 +3255,57 @@ def correlate(x1, x2, mode="valid"):
     x1 = tf.cast(x1, dtype)
     x2 = tf.cast(x2, dtype)

-    x1 = tf.reshape(x1, (1, full_len, 1))
-    x2 = tf.reshape(x2, (full_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
-
-    x1 = tf.reshape(x1, (1, x1_len, 1))
-    x2 = tf.reshape(x2, (x2_len, 1, 1))
-
-    return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
+    def _pack(a, b):
+        # a: input [N] -> [1,N,1];
+        # b: filter [M] -> [M,1,1]
+        return (
+            tf.reshape(a, (1, shape_op(a)[0], 1)),
+            tf.reshape(b, (shape_op(b)[0], 1, 1)),
+        )
+
+    def _full_corr(x1, x2):
+        """Compute 'full' correlation result (length = n + m - 1)."""
+        m = shape_op(x2)[0]
+        pad = (
+            builtins.max(m - 1, 0)
+            if isinstance(m, int)
+            else tf.maximum(m - 1, 0)
+        )
+        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros
+        x1, x2 = _pack(x1, x2)
+        out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
+        return tf.squeeze(out, axis=[0, 2])
+
+    n = shape_op(x1)[0]
+    m = shape_op(x2)[0]
+
+    if mode == "full":
+        return _full_corr(x1, x2)
+    elif mode == "same":
+        # unfortunately we can't leverage 'SAME' padding directly like
+        # we can with "valid"
+        # it works fine for odd-length filters, but for even-length filters
+        # the output is off by 1 compared to numpy, due to how
+        # tf handles centering
+        full_corr = _full_corr(x1, x2)
+        full_len = n + m - 1
+        out_len = (
+            max([n, m])
+            if isinstance(n, int) and isinstance(m, int)
+            else tf.maximum(n, m)
+        )
+        start = (full_len - out_len) // 2
+        return tf.slice(full_corr, [start], [out_len])
+    elif mode == "valid":
+        x1, x2 = _pack(x1, x2)
+        return tf.squeeze(
+            tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
+        )
+    else:
+        raise ValueError(
+            f"Invalid mode: '{mode}'. Mode must be one of:"
+            f" 'full', 'same', 'valid'."
+        )


 def select(condlist, choicelist, default=0):
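The rewrite expresses all three correlation modes through a single "full" correlation (zero-padded `tf.nn.conv1d`), then slices. The output lengths match NumPy: with `n = len(x1)` and `m = len(x2)`, "full" gives `n + m - 1`, "same" gives `max(n, m)`, and "valid" gives `max(n, m) - min(n, m) + 1`:

    import numpy as np

    x1, x2 = np.arange(5.0), np.arange(3.0)
    assert len(np.correlate(x1, x2, mode="full")) == 7
    assert len(np.correlate(x1, x2, mode="same")) == 5
    assert len(np.correlate(x1, x2, mode="valid")) == 3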
@@ -3161,10 +3357,14 @@ def histogram(x, bins=10, range=None):

     x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
     bin_edges = tf.linspace(min_val, max_val, bins + 1)
-    bin_indices = tf. ...
+    bin_edges = tf.cast(bin_edges, x.dtype)
+    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
+
+    # tf.math.bincount does not work with XLA in this case. So, we use
+    # `scatter_nd`.
+    bin_counts = tf.scatter_nd(
+        indices=tf.expand_dims(bin_indices, axis=-1),
+        updates=tf.ones_like(bin_indices, dtype=x.dtype),
+        shape=(bins,),
     )
     return bin_counts, bin_edges
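The new binning uses `searchsorted` against the interior edges plus a `scatter_nd` of ones (an XLA-friendly stand-in for `bincount`). The bin semantics it targets are NumPy's: half-open bins `[edge_i, edge_{i+1})`, with the last bin closed:

    import numpy as np

    x = np.array([0.0, 0.5, 1.0, 2.0, 3.0])
    counts, edges = np.histogram(x, bins=3, range=(0.0, 3.0))
    assert edges.tolist() == [0.0, 1.0, 2.0, 3.0]
    assert counts.tolist() == [2, 1, 2]  # 3.0 falls in the closed last bin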
keras/src/backend/torch/core.py
CHANGED

@@ -673,7 +673,9 @@ def remat(f):
     """

     def wrapped(*args, **kwargs):
-        return torch.utils.checkpoint.checkpoint(f, *args, **kwargs)
+        return torch.utils.checkpoint.checkpoint(
+            f, *args, use_reentrant=False, **kwargs
+        )

     return wrapped
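Passing `use_reentrant=False` opts into PyTorch's non-reentrant checkpoint implementation (PyTorch has deprecated calling `checkpoint` without an explicit `use_reentrant`). A standalone sketch of the same call:

    import torch
    from torch.utils.checkpoint import checkpoint

    def f(x):
        # activation-heavy computation; intermediates are recomputed
        # during backward instead of being stored
        return torch.relu(x) ** 2

    x = torch.randn(4, requires_grad=True)
    y = checkpoint(f, x, use_reentrant=False)
    y.sum().backward()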
keras/src/backend/torch/nn.py
CHANGED

@@ -458,6 +458,94 @@ def average_pool(
     return outputs


+def adaptive_average_pool(inputs, output_size, data_format=None):
+    """Adaptive average pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive average pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    """Adaptive max pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive max pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    outputs = res[0] if isinstance(res, tuple) else res
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
 def conv(
     inputs,
     kernel,
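Adaptive pooling fixes the output spatial size and derives the window sizes from the input, so any input resolution maps onto the requested shape; the wrappers above add channels_last handling around the `torch.nn.functional` calls. For example:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 17, 23)  # NCHW, arbitrary spatial size
    y = F.adaptive_avg_pool2d(x, output_size=(4, 4))
    assert y.shape == (2, 3, 4, 4)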
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)

+    # We only apply the squeeze fix if we are on an MPS device,
+    # as this change breaks tests on other platforms that
+    # expect the original tensor shape to be preserved.
+    if (
+        torch.backends.mps.is_available()
+        and target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch, does reduction of `sum` over all rows,
     # change reduction to `none` to keep dim
     if from_logits:
@@ -1092,3 +1194,26 @@ def dot_product_attention(
         scale=scale,
     )
     return torch.transpose(attention_output, axis1, axis0)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Native PyTorch implementation of Unfold.
+
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+    return tnn.unfold(
+        input,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
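The number of patch locations `L` follows the usual convolution arithmetic: per spatial dim, `out = floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`, and `L` is the product over dims while the channel dim becomes `C * kH * kW`:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    patches = F.unfold(x, kernel_size=2, stride=2)
    assert patches.shape == (1, 3 * 2 * 2, 16)  # 4 * 4 patch locations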