keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/numpy.py
CHANGED

@@ -3,6 +3,7 @@ import math
 
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from jax import export as jax_export
 
 from keras.src.backend import config
 from keras.src.backend.common import dtypes
@@ -306,14 +307,20 @@ def append(x1, x2, axis=None):
     return jnp.append(x1, x2, axis=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
+    def get_dtype(x):
+        if hasattr(x, "dtype"):
+            return x.dtype
+        if jax_export.is_symbolic_dim(x):
+            return int
+        return type(x)
+
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [get_dtype(start)]
         if stop is not None:
-            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+            dtypes_to_resolve.append(get_dtype(stop))
+        if step is not None:
+            dtypes_to_resolve.append(get_dtype(step))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
     return jnp.arange(start, stop, step=step, dtype=dtype)
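A note on the arange change above: dtype inference now goes through the local get_dtype helper so that JAX symbolic dimensions (shape-polymorphic sizes from the jax.export API, which carry no .dtype attribute) resolve as Python ints instead of falling through to type(x). A minimal sketch of the distinction, assuming a recent JAX where jax.export.symbolic_shape is available:

    from jax import export as jax_export

    b, = jax_export.symbolic_shape("b")   # symbolic batch dimension
    print(hasattr(b, "dtype"))            # False: not an array-like
    print(jax_export.is_symbolic_dim(b))  # True: arange treats it as int
    print(jax_export.is_symbolic_dim(3))  # False: plain ints use type(3)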
@@ -439,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
@@ -536,15 +548,18 @@ def clip(x, x_min, x_max):
 
 def concatenate(xs, axis=0):
     bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
-    if bcoo_count:
-        if bcoo_count == len(xs):
-            axis = canonicalize_axis(axis, len(xs[0].shape))
-            return jax_sparse.bcoo_concatenate(xs, dimension=axis)
-        else:
-            xs = [
-                x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
-                for x in xs
-            ]
+    if bcoo_count == len(xs):
+        axis = canonicalize_axis(axis, len(xs[0].shape))
+        return jax_sparse.bcoo_concatenate(xs, dimension=axis)
+    elif bcoo_count:
+        xs = [
+            x.todense()
+            if isinstance(x, jax_sparse.JAXSparse)
+            else convert_to_tensor(x)
+            for x in xs
+        ]
+    else:
+        xs = [convert_to_tensor(x) for x in xs]
     return jnp.concatenate(xs, axis=axis)
 
 
@@ -663,6 +678,11 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -809,6 +828,11 @@ def isposinf(x):
     return jnp.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -821,6 +845,19 @@ def lcm(x1, x2):
     return jnp.lcm(x1, x2)
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -888,6 +925,15 @@ def logaddexp(x1, x2):
     return jnp.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return jnp.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1071,6 +1117,7 @@ def reshape(x, newshape):
         if None not in output_shape:
             newshape = output_shape
         return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
+    x = convert_to_tensor(x)
     return jnp.reshape(x, newshape)
 
 
@@ -1133,10 +1180,17 @@ def sort(x, axis=-1):
 
 
 def split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
     return jnp.split(x, indices_or_sections, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
+    x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
 
 
@@ -1159,6 +1213,8 @@ def take(x, indices, axis=None):
 
 
 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)
 
 
@@ -1213,14 +1269,7 @@ def tile(x, repeats):
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-    dtype = None
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -1302,6 +1351,12 @@ def negative(x):
     return jnp.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1322,6 +1377,7 @@ def squeeze(x, axis=None):
         axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
         axis = to_tuple_or_list(axis)
         return jax_sparse.bcoo_squeeze(x, dimensions=axis)
+    x = convert_to_tensor(x)
     return jnp.squeeze(x, axis=axis)
 
 
@@ -1340,6 +1396,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
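Most of the backend additions above (view, empty_like, isreal, ldexp, logaddexp2, nextafter, trapezoid, vander, array_split) back new public ops, matching the keras/ops/numpy/__init__.py entries in the file list. A quick smoke test, assuming the new ops are exported under keras.ops in the usual way:

    import numpy as np
    import keras

    x = np.array([1.0, 2.0, 3.0])
    print(keras.ops.logaddexp2(x, x))               # log2(2**x + 2**x) = x + 1
    print(keras.ops.ldexp(x, np.array([1, 2, 3])))  # x * 2**e -> [2. 8. 24.]
    print(keras.ops.vander(x, N=3))                 # 3x3 Vandermonde matrix
    print(keras.ops.array_split(x, 2))              # uneven splits are allowed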
keras/src/backend/jax/optimizer.py
CHANGED

@@ -36,13 +36,14 @@ class JaxOptimizer(base_optimizer.BaseOptimizer):
             new_g_accs = jax.lax.cond(
                 is_update_step,
                 lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
-                lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
+                lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
             )
 
             grads = jax.lax.cond(
                 is_update_step,
                 lambda: [
-                    (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
+                    (g + acc_g.value) / steps
+                    for g, acc_g in zip(grads, acc_grads)
                 ],
                 lambda: list(grads),
             )
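The two .value accesses above reflect that the gradient accumulators are Keras variables rather than raw JAX arrays, so the lax.cond branch lambdas must unwrap them explicitly. A minimal standalone sketch of the same pattern (assumes KERAS_BACKEND=jax; the shapes are illustrative):

    import jax
    import jax.numpy as jnp
    import keras

    acc = keras.Variable(jnp.ones((2,)))  # stands in for an accumulator
    g = jnp.full((2,), 0.5)               # stands in for an incoming gradient
    out = jax.lax.cond(
        False,                            # "not an update step"
        lambda: jnp.zeros_like(g),
        lambda: g + acc.value,            # read the underlying array
    )
    print(out)  # [1.5 1.5]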
keras/src/backend/jax/trainer.py
CHANGED

@@ -266,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).train_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.train_step
             train_step = jit(
-                self.train_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
@@ -296,8 +302,14 @@
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).test_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.test_step
             test_step = jit(
-                self.test_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
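The is_nnx_enabled branches above swap the jitted callable from a bound method to a wrapper that calls the step function through the class and passes self explicitly, which Flax-NNX-aware transforms handle better than a bound-method closure. The dispatch itself is plain Python; a minimal sketch of the equivalence:

    class Trainer:
        def train_step(self, state, data):
            return state, data

    t = Trainer()
    bound = t.train_step                  # `self` baked into the callable
    unbound = lambda state, data: type(t).train_step(t, state, data)
    assert bound(1, 2) == unbound(1, 2)   # same call, different plumbing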
keras/src/backend/numpy/linalg.py
CHANGED

@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
keras/src/backend/numpy/nn.py
CHANGED

@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -164,13 +167,14 @@ def celu(x, alpha=1.0):
 
 def glu(x, axis=-1):
     x = convert_to_tensor(x)
+    dtype = x.dtype
     if x.shape[axis] % 2 != 0:
         raise ValueError(
             "axis size must be divisible by 2. "
             f"Received: x.shape={x.shape} with axis={axis}"
         )
     x1, x2 = np.split(x, 2, axis)
-    return x1 * (1 / (1 + np.exp(-x2)))
+    return (x1 * sigmoid(x2)).astype(dtype)
 
 
 def hard_tanh(x):
@@ -339,6 +343,252 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
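The adaptive pooling code above relies on the fact that adaptive windows only ever take two consecutive widths (big = small + 1), so each width can be pooled once over a strided sliding view and a per-position gather then picks the right result. A sketch of the window arithmetic that compute_adaptive_pooling_window_sizes (imported from backend_utils above) must agree with, inlined here as an illustration:

    import numpy as np

    def window_sizes(input_dim, output_size):
        # Window i spans [floor(i*L/out), ceil((i+1)*L/out)).
        starts = np.floor(np.arange(output_size) * input_dim / output_size)
        ends = np.ceil(np.arange(1, output_size + 1) * input_dim / output_size)
        return (ends - starts).astype(int)

    print(window_sizes(5, 3))   # [2 3 2]   -> small=2, big=3
    print(window_sizes(10, 4))  # [3 3 3 3] -> every window is "big"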
@@ -403,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -414,6 +664,14 @@ def conv(
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
@@ -1175,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
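The new unfold mirrors torch.nn.functional.unfold semantics: NCHW input in, an (N, C*kH*kW, L) column matrix out, with patches enumerated left to right, top to bottom. A tiny worked example as an illustration, using the unfold defined above:

    import numpy as np

    # 1x1x3x3 image, 2x2 patches, stride 1 -> L = 4 columns of length 4.
    x = np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3)
    cols = unfold(x, kernel_size=2)
    print(cols.shape)     # (1, 4, 4)
    print(cols[0, :, 0])  # top-left patch: [0. 1. 3. 4.]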