keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
|
@@ -813,16 +813,17 @@ def append(x1, x2, axis=None):
|
|
|
813
813
|
return tf.concat([x1, x2], axis=axis)
|
|
814
814
|
|
|
815
815
|
|
|
816
|
-
def arange(start, stop=None, step=
|
|
816
|
+
def arange(start, stop=None, step=None, dtype=None):
|
|
817
817
|
if dtype is None:
|
|
818
|
-
dtypes_to_resolve = [
|
|
819
|
-
getattr(start, "dtype", type(start)),
|
|
820
|
-
getattr(step, "dtype", type(step)),
|
|
821
|
-
]
|
|
818
|
+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
|
|
822
819
|
if stop is not None:
|
|
823
820
|
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
|
|
821
|
+
if step is not None:
|
|
822
|
+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
|
|
824
823
|
dtype = dtypes.result_type(*dtypes_to_resolve)
|
|
825
824
|
dtype = standardize_dtype(dtype)
|
|
825
|
+
if step is None:
|
|
826
|
+
step = 1
|
|
826
827
|
try:
|
|
827
828
|
out = tf.range(start, stop, delta=step, dtype=dtype)
|
|
828
829
|
except tf.errors.NotFoundError:
|
|
@@ -997,6 +998,51 @@ def array(x, dtype=None):
|
|
|
997
998
|
return convert_to_tensor(x, dtype=dtype)
|
|
998
999
|
|
|
999
1000
|
|
|
1001
|
+
def view(x, dtype=None):
|
|
1002
|
+
from keras.src import backend
|
|
1003
|
+
|
|
1004
|
+
x = convert_to_tensor(x)
|
|
1005
|
+
old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
|
|
1006
|
+
new_dtype = tf.as_dtype(
|
|
1007
|
+
backend.standardize_dtype(dtype if dtype else x.dtype)
|
|
1008
|
+
)
|
|
1009
|
+
|
|
1010
|
+
old_itemsize = old_dtype.size
|
|
1011
|
+
new_itemsize = new_dtype.size
|
|
1012
|
+
|
|
1013
|
+
old_shape = list(shape_op(x))
|
|
1014
|
+
last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1
|
|
1015
|
+
if (last_dim_size == -1 and old_itemsize != new_itemsize) or (
|
|
1016
|
+
last_dim_size * old_itemsize % new_itemsize != 0
|
|
1017
|
+
):
|
|
1018
|
+
raise ValueError(
|
|
1019
|
+
f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
|
|
1020
|
+
f"as dtype {new_dtype} because the total number of bytes "
|
|
1021
|
+
f"is not divisible by the new itemsize."
|
|
1022
|
+
)
|
|
1023
|
+
|
|
1024
|
+
if old_itemsize == new_itemsize:
|
|
1025
|
+
return tf.bitcast(x, type=new_dtype)
|
|
1026
|
+
elif old_itemsize > new_itemsize:
|
|
1027
|
+
ratio = old_itemsize // new_itemsize
|
|
1028
|
+
new_shape = list(shape_op(x))
|
|
1029
|
+
new_shape[-1] *= ratio
|
|
1030
|
+
flat_tensor = tf.reshape(x, [-1])
|
|
1031
|
+
cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
|
|
1032
|
+
return tf.reshape(cast_tensor, new_shape)
|
|
1033
|
+
else:
|
|
1034
|
+
ratio = new_itemsize // old_itemsize
|
|
1035
|
+
if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
|
|
1036
|
+
raise ValueError(
|
|
1037
|
+
f"Cannot view dtype. Last dimension size ({last_dim_size}) "
|
|
1038
|
+
f"must be divisible by the ratio of new/old item sizes "
|
|
1039
|
+
f"({ratio})."
|
|
1040
|
+
)
|
|
1041
|
+
intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
|
|
1042
|
+
reshaped_tensor = tf.reshape(x, intermediate_shape)
|
|
1043
|
+
return tf.bitcast(reshaped_tensor, new_dtype)
|
|
1044
|
+
|
|
1045
|
+
|
|
1000
1046
|
def average(x, axis=None, weights=None):
|
|
1001
1047
|
x = convert_to_tensor(x)
|
|
1002
1048
|
|
|
@@ -1313,11 +1359,7 @@ def deg2rad(x):
|
|
|
1313
1359
|
def diag(x, k=0):
|
|
1314
1360
|
x = convert_to_tensor(x)
|
|
1315
1361
|
if len(x.shape) == 1:
|
|
1316
|
-
return tf.
|
|
1317
|
-
tf.equal(tf.size(x), 0),
|
|
1318
|
-
lambda: tf.zeros([builtins.abs(k), builtins.abs(k)], dtype=x.dtype),
|
|
1319
|
-
lambda: tf.linalg.diag(x, k=k),
|
|
1320
|
-
)
|
|
1362
|
+
return tf.linalg.diag(x, k=k)
|
|
1321
1363
|
elif len(x.shape) == 2:
|
|
1322
1364
|
return diagonal(x, offset=k)
|
|
1323
1365
|
else:
|
|
@@ -1443,6 +1485,10 @@ def empty(shape, dtype=None):
|
|
|
1443
1485
|
return tf.zeros(shape, dtype=dtype)
|
|
1444
1486
|
|
|
1445
1487
|
|
|
1488
|
+
def empty_like(x, dtype=None):
|
|
1489
|
+
return tf.zeros_like(x, dtype=dtype)
|
|
1490
|
+
|
|
1491
|
+
|
|
1446
1492
|
def equal(x1, x2):
|
|
1447
1493
|
x1 = convert_to_tensor(x1)
|
|
1448
1494
|
x2 = convert_to_tensor(x2)
|
|
@@ -1711,6 +1757,106 @@ def isposinf(x):
|
|
|
1711
1757
|
return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))
|
|
1712
1758
|
|
|
1713
1759
|
|
|
1760
|
+
def isreal(x):
|
|
1761
|
+
x = convert_to_tensor(x)
|
|
1762
|
+
if x.dtype.is_complex:
|
|
1763
|
+
return tf.equal(tf.math.imag(x), 0)
|
|
1764
|
+
else:
|
|
1765
|
+
return tf.ones_like(x, dtype=tf.bool)
|
|
1766
|
+
|
|
1767
|
+
|
|
1768
|
+
def kron(x1, x2):
|
|
1769
|
+
x1 = convert_to_tensor(x1)
|
|
1770
|
+
x2 = convert_to_tensor(x2)
|
|
1771
|
+
|
|
1772
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
|
|
1773
|
+
x1 = tf.cast(x1, dtype)
|
|
1774
|
+
x2 = tf.cast(x2, dtype)
|
|
1775
|
+
|
|
1776
|
+
ndim_x1 = tf.rank(x1)
|
|
1777
|
+
ndim_x2 = tf.rank(x2)
|
|
1778
|
+
|
|
1779
|
+
def expand_front(x, num):
|
|
1780
|
+
for _ in range(num):
|
|
1781
|
+
x = tf.expand_dims(x, axis=0)
|
|
1782
|
+
return x
|
|
1783
|
+
|
|
1784
|
+
x1 = tf.cond(
|
|
1785
|
+
ndim_x1 < ndim_x2,
|
|
1786
|
+
lambda: expand_front(x1, ndim_x2 - ndim_x1),
|
|
1787
|
+
lambda: x1,
|
|
1788
|
+
)
|
|
1789
|
+
x2 = tf.cond(
|
|
1790
|
+
ndim_x2 < ndim_x1,
|
|
1791
|
+
lambda: expand_front(x2, ndim_x1 - ndim_x2),
|
|
1792
|
+
lambda: x2,
|
|
1793
|
+
)
|
|
1794
|
+
|
|
1795
|
+
x1_reshaped = tf.reshape(
|
|
1796
|
+
x1,
|
|
1797
|
+
tf.reshape(
|
|
1798
|
+
tf.stack([tf.shape(x1), tf.ones_like(tf.shape(x1))], axis=1), [-1]
|
|
1799
|
+
),
|
|
1800
|
+
)
|
|
1801
|
+
x2_reshaped = tf.reshape(
|
|
1802
|
+
x2,
|
|
1803
|
+
tf.reshape(
|
|
1804
|
+
tf.stack([tf.ones_like(tf.shape(x2)), tf.shape(x2)], axis=1), [-1]
|
|
1805
|
+
),
|
|
1806
|
+
)
|
|
1807
|
+
|
|
1808
|
+
out = tf.multiply(x1_reshaped, x2_reshaped)
|
|
1809
|
+
out_shape = tf.multiply(tf.shape(x1), tf.shape(x2))
|
|
1810
|
+
out = tf.reshape(out, out_shape)
|
|
1811
|
+
return out
|
|
1812
|
+
|
|
1813
|
+
|
|
1814
|
+
def lcm(x1, x2):
|
|
1815
|
+
x1 = convert_to_tensor(x1)
|
|
1816
|
+
x2 = convert_to_tensor(x2)
|
|
1817
|
+
|
|
1818
|
+
if not (x1.dtype.is_integer and x2.dtype.is_integer):
|
|
1819
|
+
raise TypeError(
|
|
1820
|
+
f"Arguments to lcm must be integers. "
|
|
1821
|
+
f"Received: x1.dtype={x1.dtype.name}, x2.dtype={x2.dtype.name}"
|
|
1822
|
+
)
|
|
1823
|
+
|
|
1824
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
|
|
1825
|
+
x1 = tf.cast(x1, dtype)
|
|
1826
|
+
x2 = tf.cast(x2, dtype)
|
|
1827
|
+
|
|
1828
|
+
if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:
|
|
1829
|
+
x1 = tf.math.abs(x1)
|
|
1830
|
+
x2 = tf.math.abs(x2)
|
|
1831
|
+
|
|
1832
|
+
divisor = gcd(x1, x2)
|
|
1833
|
+
divisor_safe = tf.where(
|
|
1834
|
+
divisor == 0, tf.constant(1, dtype=divisor.dtype), divisor
|
|
1835
|
+
)
|
|
1836
|
+
|
|
1837
|
+
result = x1 * (x2 // divisor_safe)
|
|
1838
|
+
result = tf.where(divisor == 0, tf.zeros_like(result), result)
|
|
1839
|
+
|
|
1840
|
+
return result
|
|
1841
|
+
|
|
1842
|
+
|
|
1843
|
+
def ldexp(x1, x2):
|
|
1844
|
+
x1 = convert_to_tensor(x1)
|
|
1845
|
+
x2 = convert_to_tensor(x2)
|
|
1846
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
1847
|
+
|
|
1848
|
+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
|
|
1849
|
+
raise TypeError(
|
|
1850
|
+
f"ldexp exponent must be an integer type. "
|
|
1851
|
+
f"Received: x2 dtype={x2.dtype}"
|
|
1852
|
+
)
|
|
1853
|
+
|
|
1854
|
+
x1 = tf.cast(x1, dtype)
|
|
1855
|
+
x2 = tf.cast(x2, x1.dtype)
|
|
1856
|
+
result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)
|
|
1857
|
+
return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)
|
|
1858
|
+
|
|
1859
|
+
|
|
1714
1860
|
def less(x1, x2):
|
|
1715
1861
|
x1 = convert_to_tensor(x1)
|
|
1716
1862
|
x2 = convert_to_tensor(x2)
|
|
@@ -1834,6 +1980,22 @@ def logaddexp(x1, x2):
|
|
|
1834
1980
|
)
|
|
1835
1981
|
|
|
1836
1982
|
|
|
1983
|
+
def logaddexp2(x1, x2):
|
|
1984
|
+
x1 = tf.convert_to_tensor(x1)
|
|
1985
|
+
x2 = tf.convert_to_tensor(x2)
|
|
1986
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
1987
|
+
x1 = tf.cast(x1, dtype)
|
|
1988
|
+
x2 = tf.cast(x2, dtype)
|
|
1989
|
+
delta = x1 - x2
|
|
1990
|
+
log2 = tf.cast(tf.math.log(2.0), dtype)
|
|
1991
|
+
return tf.where(
|
|
1992
|
+
tf.math.is_nan(delta),
|
|
1993
|
+
x1 + x2,
|
|
1994
|
+
tf.maximum(x1, x2)
|
|
1995
|
+
+ tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2,
|
|
1996
|
+
)
|
|
1997
|
+
|
|
1998
|
+
|
|
1837
1999
|
def logical_and(x1, x2):
|
|
1838
2000
|
x1 = tf.cast(x1, "bool")
|
|
1839
2001
|
x2 = tf.cast(x2, "bool")
|
|
@@ -1989,7 +2151,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
|
|
|
1989
2151
|
|
|
1990
2152
|
def ndim(x):
|
|
1991
2153
|
x = convert_to_tensor(x)
|
|
1992
|
-
return x.
|
|
2154
|
+
return x.shape.rank
|
|
1993
2155
|
|
|
1994
2156
|
|
|
1995
2157
|
def nonzero(x):
|
|
@@ -2053,6 +2215,13 @@ def prod(x, axis=None, keepdims=False, dtype=None):
|
|
|
2053
2215
|
return tf.reduce_prod(x, axis=axis, keepdims=keepdims)
|
|
2054
2216
|
|
|
2055
2217
|
|
|
2218
|
+
def ptp(x, axis=None, keepdims=False):
|
|
2219
|
+
x = convert_to_tensor(x)
|
|
2220
|
+
return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(
|
|
2221
|
+
x, axis=axis, keepdims=keepdims
|
|
2222
|
+
)
|
|
2223
|
+
|
|
2224
|
+
|
|
2056
2225
|
def _quantile(x, q, axis=None, method="linear", keepdims=False):
|
|
2057
2226
|
# ref: tfp.stats.percentile
|
|
2058
2227
|
# float64 is needed here and below, else we get the wrong index if the array
|
|
@@ -2158,7 +2327,7 @@ def _quantile(x, q, axis=None, method="linear", keepdims=False):
|
|
|
2158
2327
|
return gathered_y
|
|
2159
2328
|
perm = collections.deque(range(ndims))
|
|
2160
2329
|
perm.rotate(shift_value_static)
|
|
2161
|
-
return tf.transpose(a=gathered_y, perm=perm)
|
|
2330
|
+
return tf.transpose(a=gathered_y, perm=list(perm))
|
|
2162
2331
|
|
|
2163
2332
|
|
|
2164
2333
|
def quantile(x, q, axis=None, method="linear", keepdims=False):
|
|
@@ -2256,8 +2425,11 @@ def searchsorted(sorted_sequence, values, side="left"):
|
|
|
2256
2425
|
"to extend it to N-D sequences. Received: "
|
|
2257
2426
|
f"sorted_sequence.shape={sorted_sequence.shape}"
|
|
2258
2427
|
)
|
|
2428
|
+
sequence_len = sorted_sequence.shape[0]
|
|
2259
2429
|
out_type = (
|
|
2260
|
-
"int32"
|
|
2430
|
+
"int32"
|
|
2431
|
+
if sequence_len is not None and sequence_len <= np.iinfo(np.int32).max
|
|
2432
|
+
else "int64"
|
|
2261
2433
|
)
|
|
2262
2434
|
return tf.searchsorted(
|
|
2263
2435
|
sorted_sequence, values, side=side, out_type=out_type
|
|
@@ -2348,6 +2520,17 @@ def split(x, indices_or_sections, axis=0):
|
|
|
2348
2520
|
return tf.split(x, num_or_size_splits, axis=axis)
|
|
2349
2521
|
|
|
2350
2522
|
|
|
2523
|
+
def array_split(x, indices_or_sections, axis=0):
|
|
2524
|
+
x = tf.convert_to_tensor(x)
|
|
2525
|
+
num_splits = indices_or_sections
|
|
2526
|
+
total_size = shape_op(x)[axis]
|
|
2527
|
+
avg_size = total_size // num_splits
|
|
2528
|
+
remainder = total_size % num_splits
|
|
2529
|
+
sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
|
|
2530
|
+
|
|
2531
|
+
return tf.split(x, sizes, axis=axis)
|
|
2532
|
+
|
|
2533
|
+
|
|
2351
2534
|
def stack(x, axis=0):
|
|
2352
2535
|
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
|
|
2353
2536
|
if len(dtype_set) > 1:
|
|
@@ -2579,27 +2762,44 @@ def round(x, decimals=0):
|
|
|
2579
2762
|
|
|
2580
2763
|
def tile(x, repeats):
|
|
2581
2764
|
x = convert_to_tensor(x)
|
|
2582
|
-
|
|
2583
|
-
|
|
2584
|
-
repeats
|
|
2585
|
-
repeats
|
|
2586
|
-
|
|
2587
|
-
|
|
2588
|
-
|
|
2589
|
-
|
|
2590
|
-
|
|
2591
|
-
|
|
2592
|
-
|
|
2593
|
-
|
|
2594
|
-
|
|
2765
|
+
|
|
2766
|
+
# Convert repeats to a list (works for both sequences and 1D tensors)
|
|
2767
|
+
if isinstance(repeats, int):
|
|
2768
|
+
repeats = [repeats]
|
|
2769
|
+
else:
|
|
2770
|
+
repeats = [v for v in repeats]
|
|
2771
|
+
|
|
2772
|
+
# Process list elements: convert concrete scalar tensors to Python ints
|
|
2773
|
+
processed_repeats = []
|
|
2774
|
+
for r in repeats:
|
|
2775
|
+
if hasattr(r, "numpy") and r.shape == ():
|
|
2776
|
+
processed_repeats.append(int(r.numpy()))
|
|
2777
|
+
else:
|
|
2778
|
+
processed_repeats.append(r)
|
|
2779
|
+
repeats = processed_repeats
|
|
2780
|
+
|
|
2781
|
+
# Get x rank
|
|
2782
|
+
x_rank = x.shape.rank
|
|
2783
|
+
|
|
2784
|
+
# Pad repeats if needed
|
|
2785
|
+
if len(repeats) < x_rank:
|
|
2786
|
+
repeats = [1] * (x_rank - len(repeats)) + repeats
|
|
2787
|
+
|
|
2788
|
+
# Add dimensions to x if needed using tf.expand_dims
|
|
2789
|
+
while len(repeats) > x.shape.rank:
|
|
2790
|
+
x = tf.expand_dims(x, 0)
|
|
2791
|
+
|
|
2595
2792
|
return tf.tile(x, repeats)
|
|
2596
2793
|
|
|
2597
2794
|
|
|
2598
2795
|
def trace(x, offset=0, axis1=0, axis2=1):
|
|
2599
2796
|
x = convert_to_tensor(x)
|
|
2600
2797
|
dtype = standardize_dtype(x.dtype)
|
|
2601
|
-
if dtype
|
|
2602
|
-
dtype =
|
|
2798
|
+
if dtype in ("bool", "int8", "int16"):
|
|
2799
|
+
dtype = "int32"
|
|
2800
|
+
elif dtype in ("uint8", "uint16"):
|
|
2801
|
+
dtype = "uint32"
|
|
2802
|
+
x = tf.cast(x, dtype)
|
|
2603
2803
|
x_shape = tf.shape(x)
|
|
2604
2804
|
x = moveaxis(x, (axis1, axis2), (-2, -1))
|
|
2605
2805
|
# Mask out the diagonal and reduce.
|
|
@@ -2608,10 +2808,7 @@ def trace(x, offset=0, axis1=0, axis2=1):
|
|
|
2608
2808
|
x,
|
|
2609
2809
|
tf.zeros_like(x),
|
|
2610
2810
|
)
|
|
2611
|
-
|
|
2612
|
-
if standardize_dtype(x.dtype) == "bool":
|
|
2613
|
-
x = tf.cast(x, "int32")
|
|
2614
|
-
return tf.cast(tf.reduce_sum(x, axis=(-2, -1)), dtype)
|
|
2811
|
+
return tf.reduce_sum(x, axis=(-2, -1))
|
|
2615
2812
|
|
|
2616
2813
|
|
|
2617
2814
|
def tri(N, M=None, k=0, dtype=None):
|
|
@@ -2827,6 +3024,16 @@ def negative(x):
|
|
|
2827
3024
|
return tf.negative(x)
|
|
2828
3025
|
|
|
2829
3026
|
|
|
3027
|
+
def nextafter(x1, x2):
|
|
3028
|
+
x1 = convert_to_tensor(x1)
|
|
3029
|
+
x2 = convert_to_tensor(x2)
|
|
3030
|
+
|
|
3031
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
3032
|
+
x1 = tf.cast(x1, tf.float64)
|
|
3033
|
+
x2 = tf.cast(x2, tf.float64)
|
|
3034
|
+
return tf.cast(tf.math.nextafter(x1, x2), dtype)
|
|
3035
|
+
|
|
3036
|
+
|
|
2830
3037
|
@sparse.elementwise_unary
|
|
2831
3038
|
def square(x):
|
|
2832
3039
|
x = convert_to_tensor(x)
|
|
@@ -2881,6 +3088,63 @@ def transpose(x, axes=None):
|
|
|
2881
3088
|
return tf.transpose(x, perm=axes)
|
|
2882
3089
|
|
|
2883
3090
|
|
|
3091
|
+
def trapezoid(y, x=None, dx=1.0, axis=-1):
|
|
3092
|
+
def _move_axis_to_last(tensor, axis):
|
|
3093
|
+
if axis == -1:
|
|
3094
|
+
return tensor
|
|
3095
|
+
rank = tf.rank(tensor)
|
|
3096
|
+
if axis < 0:
|
|
3097
|
+
axis = rank + axis
|
|
3098
|
+
perm = tf.concat(
|
|
3099
|
+
[
|
|
3100
|
+
tf.range(axis, dtype=tf.int32),
|
|
3101
|
+
tf.range(axis + 1, rank, dtype=tf.int32),
|
|
3102
|
+
tf.constant([axis], dtype=tf.int32),
|
|
3103
|
+
],
|
|
3104
|
+
axis=0,
|
|
3105
|
+
)
|
|
3106
|
+
return tf.transpose(tensor, perm=perm)
|
|
3107
|
+
|
|
3108
|
+
y = convert_to_tensor(y)
|
|
3109
|
+
dtype = dtypes.result_type(y.dtype, float)
|
|
3110
|
+
y = tf.cast(y, dtype)
|
|
3111
|
+
|
|
3112
|
+
if x is None:
|
|
3113
|
+
dx_array = tf.cast(dx, dtype)
|
|
3114
|
+
else:
|
|
3115
|
+
x = convert_to_tensor(x, dtype=dtype)
|
|
3116
|
+
dx_array = diff(x, axis=axis)
|
|
3117
|
+
dx_array = _move_axis_to_last(dx_array, axis)
|
|
3118
|
+
|
|
3119
|
+
y = _move_axis_to_last(y, axis)
|
|
3120
|
+
|
|
3121
|
+
avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])
|
|
3122
|
+
result = tf.reduce_sum(avg_heights * dx_array, axis=-1)
|
|
3123
|
+
|
|
3124
|
+
return result
|
|
3125
|
+
|
|
3126
|
+
|
|
3127
|
+
def vander(x, N=None, increasing=False):
|
|
3128
|
+
x = convert_to_tensor(x)
|
|
3129
|
+
result_dtype = dtypes.result_type(x.dtype)
|
|
3130
|
+
|
|
3131
|
+
if N is None:
|
|
3132
|
+
N = shape_op(x)[0]
|
|
3133
|
+
|
|
3134
|
+
if increasing:
|
|
3135
|
+
powers = tf.range(N)
|
|
3136
|
+
else:
|
|
3137
|
+
powers = tf.range(N - 1, -1, -1)
|
|
3138
|
+
|
|
3139
|
+
x_exp = tf.expand_dims(x, axis=-1)
|
|
3140
|
+
|
|
3141
|
+
compute_dtype = dtypes.result_type(x.dtype, "float32")
|
|
3142
|
+
vander = tf.math.pow(
|
|
3143
|
+
tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
|
|
3144
|
+
)
|
|
3145
|
+
return tf.cast(vander, result_dtype)
|
|
3146
|
+
|
|
3147
|
+
|
|
2884
3148
|
def var(x, axis=None, keepdims=False):
|
|
2885
3149
|
x = convert_to_tensor(x)
|
|
2886
3150
|
compute_dtype = dtypes.result_type(x.dtype, "float32")
|
|
@@ -2991,30 +3255,57 @@ def correlate(x1, x2, mode="valid"):
|
|
|
2991
3255
|
x1 = tf.cast(x1, dtype)
|
|
2992
3256
|
x2 = tf.cast(x2, dtype)
|
|
2993
3257
|
|
|
2994
|
-
|
|
3258
|
+
def _pack(a, b):
|
|
3259
|
+
# a: input [N] -> [1,N,1];
|
|
3260
|
+
# b: filter [M] -> [M,1,1]
|
|
3261
|
+
return (
|
|
3262
|
+
tf.reshape(a, (1, shape_op(a)[0], 1)),
|
|
3263
|
+
tf.reshape(b, (shape_op(b)[0], 1, 1)),
|
|
3264
|
+
)
|
|
2995
3265
|
|
|
2996
|
-
|
|
2997
|
-
|
|
3266
|
+
def _full_corr(x1, x2):
|
|
3267
|
+
"""Compute 'full' correlation result (length = n + m - 1)."""
|
|
3268
|
+
m = shape_op(x2)[0]
|
|
3269
|
+
pad = (
|
|
3270
|
+
builtins.max(m - 1, 0)
|
|
3271
|
+
if isinstance(m, int)
|
|
3272
|
+
else tf.maximum(m - 1, 0)
|
|
3273
|
+
)
|
|
3274
|
+
x1 = tf.pad(x1, [[pad, pad]]) # pad input with zeros
|
|
3275
|
+
x1, x2 = _pack(x1, x2)
|
|
3276
|
+
out = tf.nn.conv1d(x1, x2, stride=1, padding="VALID")
|
|
3277
|
+
return tf.squeeze(out, axis=[0, 2])
|
|
2998
3278
|
|
|
2999
|
-
|
|
3000
|
-
|
|
3279
|
+
n = shape_op(x1)[0]
|
|
3280
|
+
m = shape_op(x2)[0]
|
|
3001
3281
|
|
|
3002
|
-
|
|
3003
|
-
|
|
3282
|
+
if mode == "full":
|
|
3283
|
+
return _full_corr(x1, x2)
|
|
3284
|
+
elif mode == "same":
|
|
3285
|
+
# unfortunately we can't leverage 'SAME' padding directly like
|
|
3286
|
+
# we can with "valid"
|
|
3287
|
+
# it works fine for odd-length filters, but for even-length filters
|
|
3288
|
+
# the output is off by 1 compared to numpy, due to how
|
|
3289
|
+
# tf handles centering
|
|
3290
|
+
full_corr = _full_corr(x1, x2)
|
|
3291
|
+
full_len = n + m - 1
|
|
3292
|
+
out_len = (
|
|
3293
|
+
max([n, m])
|
|
3294
|
+
if isinstance(n, int) and isinstance(m, int)
|
|
3295
|
+
else tf.maximum(n, m)
|
|
3004
3296
|
)
|
|
3005
|
-
|
|
3006
|
-
|
|
3297
|
+
start = (full_len - out_len) // 2
|
|
3298
|
+
return tf.slice(full_corr, [start], [out_len])
|
|
3299
|
+
elif mode == "valid":
|
|
3300
|
+
x1, x2 = _pack(x1, x2)
|
|
3301
|
+
return tf.squeeze(
|
|
3302
|
+
tf.nn.conv1d(x1, x2, stride=1, padding="VALID"), axis=[0, 2]
|
|
3303
|
+
)
|
|
3304
|
+
else:
|
|
3305
|
+
raise ValueError(
|
|
3306
|
+
f"Invalid mode: '{mode}'. Mode must be one of:"
|
|
3307
|
+
f" 'full', 'same', 'valid'."
|
|
3007
3308
|
)
|
|
3008
|
-
|
|
3009
|
-
x1 = tf.reshape(x1, (1, full_len, 1))
|
|
3010
|
-
x2 = tf.reshape(x2, (full_len, 1, 1))
|
|
3011
|
-
|
|
3012
|
-
return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding="SAME"))
|
|
3013
|
-
|
|
3014
|
-
x1 = tf.reshape(x1, (1, x1_len, 1))
|
|
3015
|
-
x2 = tf.reshape(x2, (x2_len, 1, 1))
|
|
3016
|
-
|
|
3017
|
-
return tf.squeeze(tf.nn.conv1d(x1, x2, stride=1, padding=mode.upper()))
|
|
3018
3309
|
|
|
3019
3310
|
|
|
3020
3311
|
def select(condlist, choicelist, default=0):
|
|
@@ -3066,10 +3357,14 @@ def histogram(x, bins=10, range=None):
|
|
|
3066
3357
|
|
|
3067
3358
|
x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))
|
|
3068
3359
|
bin_edges = tf.linspace(min_val, max_val, bins + 1)
|
|
3069
|
-
|
|
3070
|
-
bin_indices = tf.
|
|
3071
|
-
|
|
3072
|
-
|
|
3073
|
-
|
|
3360
|
+
bin_edges = tf.cast(bin_edges, x.dtype)
|
|
3361
|
+
bin_indices = tf.searchsorted(bin_edges[1:-1], x, side="right")
|
|
3362
|
+
|
|
3363
|
+
# tf.math.bincount does not work with XLA in this case. So, we use
|
|
3364
|
+
# `scatter_nd`.
|
|
3365
|
+
bin_counts = tf.scatter_nd(
|
|
3366
|
+
indices=tf.expand_dims(bin_indices, axis=-1),
|
|
3367
|
+
updates=tf.ones_like(bin_indices, dtype=x.dtype),
|
|
3368
|
+
shape=(bins,),
|
|
3074
3369
|
)
|
|
3075
3370
|
return bin_counts, bin_edges
|
|
@@ -68,7 +68,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
|
|
68
68
|
)
|
|
69
69
|
self._loss_tracker.update_state(
|
|
70
70
|
loss_module.unscale_loss_for_distribution(loss),
|
|
71
|
-
sample_weight=tf.shape(
|
|
71
|
+
sample_weight=tf.shape(
|
|
72
|
+
next(i for i in tree.flatten(x) if i is not None)
|
|
73
|
+
)[0],
|
|
72
74
|
)
|
|
73
75
|
if self.optimizer is not None:
|
|
74
76
|
loss = self.optimizer.scale_loss(loss)
|
|
@@ -96,7 +98,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
|
|
96
98
|
)
|
|
97
99
|
self._loss_tracker.update_state(
|
|
98
100
|
loss_module.unscale_loss_for_distribution(loss),
|
|
99
|
-
sample_weight=tf.shape(
|
|
101
|
+
sample_weight=tf.shape(
|
|
102
|
+
next(i for i in tree.flatten(x) if i is not None)
|
|
103
|
+
)[0],
|
|
100
104
|
)
|
|
101
105
|
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
|
102
106
|
|
keras/src/backend/torch/core.py
CHANGED
|
@@ -673,7 +673,9 @@ def remat(f):
|
|
|
673
673
|
"""
|
|
674
674
|
|
|
675
675
|
def wrapped(*args, **kwargs):
|
|
676
|
-
return torch.utils.checkpoint.checkpoint(
|
|
676
|
+
return torch.utils.checkpoint.checkpoint(
|
|
677
|
+
f, *args, use_reentrant=False, **kwargs
|
|
678
|
+
)
|
|
677
679
|
|
|
678
680
|
return wrapped
|
|
679
681
|
|