keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/numpy.py
CHANGED
@@ -3,6 +3,7 @@ import math
 import jax
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from jax import export as jax_export

 from keras.src.backend import config
 from keras.src.backend.common import dtypes
@@ -306,14 +307,20 @@ def append(x1, x2, axis=None):
     return jnp.append(x1, x2, axis=axis)


-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
+    def get_dtype(x):
+        if hasattr(x, "dtype"):
+            return x.dtype
+        if jax_export.is_symbolic_dim(x):
+            return int
+        return type(x)
+
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [get_dtype(start)]
         if stop is not None:
-            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+            dtypes_to_resolve.append(get_dtype(stop))
+        if step is not None:
+            dtypes_to_resolve.append(get_dtype(step))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
     return jnp.arange(start, stop, step=step, dtype=dtype)
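The `get_dtype` helper is what lets `arange` accept symbolic dimensions under JAX shape polymorphism: a symbolic batch size has no `dtype` attribute and is not a plain Python scalar. A minimal sketch of the resolution logic, with illustrative values (the helper body mirrors the diff):

import jax.numpy as jnp
from jax import export as jax_export

def get_dtype(x):
    if hasattr(x, "dtype"):
        return x.dtype
    if jax_export.is_symbolic_dim(x):
        return int
    return type(x)

print(get_dtype(3))                 # <class 'int'>
print(get_dtype(jnp.float32(0.5)))  # float32
b, _ = jax_export.symbolic_shape("b, 8")
print(get_dtype(b))                 # <class 'int'>: symbolic dims count as int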
@@ -439,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)


+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
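For context, `view` reinterprets the tensor's underlying bytes as another dtype, following `ndarray.view` semantics rather than casting values. A quick illustration (the printed integers are the IEEE-754 bit patterns of 1.0 and 2.0):

import jax.numpy as jnp

x = jnp.array([1.0, 2.0], dtype="float32")
print(x.view(jnp.int32))  # [1065353216 1073741824]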
@@ -536,15 +548,18 @@ def clip(x, x_min, x_max):

 def concatenate(xs, axis=0):
     bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
-    if bcoo_count:
-        if bcoo_count == len(xs):
-            axis = canonicalize_axis(axis, len(xs[0].shape))
-            return jax_sparse.bcoo_concatenate(xs, dimension=axis)
-        else:
-            xs = [
-                x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
-                for x in xs
-            ]
+    if bcoo_count == len(xs):
+        axis = canonicalize_axis(axis, len(xs[0].shape))
+        return jax_sparse.bcoo_concatenate(xs, dimension=axis)
+    elif bcoo_count:
+        xs = [
+            x.todense()
+            if isinstance(x, jax_sparse.JAXSparse)
+            else convert_to_tensor(x)
+            for x in xs
+        ]
+    else:
+        xs = [convert_to_tensor(x) for x in xs]
     return jnp.concatenate(xs, axis=axis)


@@ -663,6 +678,10 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -809,6 +828,36 @@ def isposinf(x):
     return jnp.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
+def kron(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.kron(x1, x2)
+
+
+def lcm(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.lcm(x1, x2)
+
+
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
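The integer-dtype check reflects what `ldexp` computes: `x1 * 2**x2` with `x2` as an exponent, so a float exponent would be ill-defined. A one-line sanity check of the underlying op:

import jax.numpy as jnp

print(jnp.ldexp(jnp.float32(1.5), 3))  # 12.0 == 1.5 * 2**3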
@@ -876,6 +925,15 @@ def logaddexp(x1, x2):
     return jnp.logaddexp(x1, x2)


+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return jnp.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
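`logaddexp2` is the base-2 analogue of `logaddexp`, computing `log2(2**x1 + 2**x2)`; the explicit cast to a float result type keeps integer inputs from truncating. For example:

import jax.numpy as jnp

print(jnp.logaddexp2(3.0, 3.0))  # 4.0, since log2(2**3 + 2**3) = log2(16)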
@@ -1005,6 +1063,11 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return jnp.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
@@ -1059,6 +1122,7 @@ def reshape(x, newshape):
         if None not in output_shape:
             newshape = output_shape
         return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
+    x = convert_to_tensor(x)
     return jnp.reshape(x, newshape)


@@ -1121,10 +1185,17 @@ def sort(x, axis=-1):


 def split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
     return jnp.split(x, indices_or_sections, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
+    x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)


@@ -1147,6 +1218,8 @@ def take(x, indices, axis=None):


 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)


@@ -1201,14 +1274,7 @@ def tile(x, repeats):

 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-    dtype = None
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


 def tri(N, M=None, k=0, dtype=None):
@@ -1290,6 +1356,12 @@ def negative(x):
     return jnp.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1310,6 +1382,7 @@ def squeeze(x, axis=None):
             axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
         axis = to_tuple_or_list(axis)
         return jax_sparse.bcoo_squeeze(x, dimensions=axis)
+    x = convert_to_tensor(x)
     return jnp.squeeze(x, axis=axis)


@@ -1328,6 +1401,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
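`trapezoid` integrates sampled values with the trapezoidal rule (the NumPy 2.0 name for `trapz`). A minimal check:

import jax.numpy as jnp

y = jnp.array([0.0, 1.0, 2.0])
print(jnp.trapezoid(y, dx=1.0))  # 2.0 = (0 + 1)/2 + (1 + 2)/2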
keras/src/backend/jax/optimizer.py
CHANGED
@@ -36,13 +36,14 @@ class JaxOptimizer(base_optimizer.BaseOptimizer):
             new_g_accs = jax.lax.cond(
                 is_update_step,
                 lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
-                lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
+                lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
             )

             grads = jax.lax.cond(
                 is_update_step,
                 lambda: [
-                    (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
+                    (g + acc_g.value) / steps
+                    for g, acc_g in zip(grads, acc_grads)
                 ],
                 lambda: list(grads),
             )
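The new `.value` reads suggest the gradient accumulators are now Keras variables rather than raw arrays, so the backing tensor must be fetched before arithmetic with plain JAX values. A hypothetical illustration of that pattern (not the optimizer's actual code):

import keras

acc_g = keras.Variable(initializer="zeros", shape=(2,), dtype="float32")
g = keras.ops.ones((2,))
new_acc = g + acc_g.value  # operate on the backing tensor, not the Variable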
keras/src/backend/jax/trainer.py
CHANGED
@@ -105,7 +105,10 @@ class JAXTrainer(base_trainer.Trainer):
             ]
         ) as scope:
             self._loss_tracker.update_state(
-                unscaled_loss,
+                unscaled_loss,
+                sample_weight=next(
+                    i for i in tree.flatten(x) if i is not None
+                ).shape[0],
             )
             logs = self.compute_metrics(x, y, y_pred, sample_weight)

@@ -263,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).train_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.train_step
             train_step = jit(
-                self.train_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
@@ -293,8 +302,14 @@ class JAXTrainer(base_trainer.Trainer):
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).test_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.test_step
             test_step = jit(
-                self.test_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
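Both hunks above route jitting through a lambda that calls the step method via the class when NNX is enabled, presumably so `jax.jit` traces an ordinary function with `self` passed explicitly rather than an NNX-wrapped bound method. The shape of the pattern in isolation, on a toy class rather than the Keras API:

import jax

class Toy:
    def train_step(self, state, data):
        return state + data

t = Toy()
step_fn = lambda state, data: type(t).train_step(t, state, data)
print(jax.jit(step_fn)(1.0, 2.0))  # 3.0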
keras/src/backend/numpy/linalg.py
CHANGED
@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
keras/src/backend/numpy/nn.py
CHANGED
@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax

 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -164,13 +167,14 @@ def celu(x, alpha=1.0):

 def glu(x, axis=-1):
     x = convert_to_tensor(x)
+    dtype = x.dtype
     if x.shape[axis] % 2 != 0:
         raise ValueError(
             "axis size must be divisible by 2. "
             f"Received: x.shape={x.shape} with axis={axis}"
         )
     x1, x2 = np.split(x, 2, axis)
-    return x1 * (1 / (1 + np.exp(-x2)))
+    return (x1 * sigmoid(x2)).astype(dtype)


 def hard_tanh(x):
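GLU splits the input in two along `axis` and gates the first half with the sigmoid of the second, `glu(x) = x1 * sigmoid(x2)`; capturing `dtype` up front keeps lower-precision inputs at their original dtype. An illustrative computation with plain NumPy:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.array([[1.0, 2.0, 0.0, 0.0]], dtype="float32")
x1, x2 = np.split(x, 2, axis=-1)
print(x1 * sigmoid(x2))  # [[0.5 1. ]] since sigmoid(0) = 0.5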
@@ -339,6 +343,252 @@ def average_pool(
     return pooled / window_counts


+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
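The implementation relies on a standard property of adaptive pooling: the windows take at most two sizes, differing by one, so every position can be pooled once at the "small" size and once at the "big" size via strided views, with a gather picking the right row per output index (`compute_adaptive_pooling_window_sizes` is assumed to return exactly that pair). A small check of the property:

import numpy as np

input_dim, output_size = 8, 3
starts = np.floor(np.arange(output_size) * input_dim / output_size).astype(int)
ends = np.ceil(np.arange(1, output_size + 1) * input_dim / output_size).astype(int)
print([(int(a), int(b)) for a, b in zip(starts, ends)])  # [(0, 3), (2, 6), (5, 8)]
print(sorted({int(s) for s in ends - starts}))           # [3, 4]: small and big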
@@ -403,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -414,6 +664,14 @@ def conv(
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result


 def depthwise_conv(
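A sketch of the arithmetic this check guards against: with "valid" padding the spatial output length is floor((in - dilation*(k-1) - 1) / stride) + 1, which reaches zero when the effective kernel outgrows the input (illustrative numbers):

in_len, kernel, dilation, stride = 4, 3, 2, 1
out_len = (in_len - dilation * (kernel - 1) - 1) // stride + 1
print(out_len)  # 0 -> the conv output would be empty, hence the new error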
@@ -1175,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
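Example usage of the `unfold` defined above, checking the documented output shape (N, C*kH*kW, L) on a toy tensor:

import numpy as np

x = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)
cols = unfold(x, kernel_size=2, stride=2)
print(cols.shape)  # (2, 12, 4): C*kH*kW = 3*2*2 = 12, L = oH*oW = 2*2 = 4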