keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/numpy.py
CHANGED

@@ -446,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)


+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
@@ -673,6 +678,10 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -819,6 +828,11 @@ def isposinf(x):
     return jnp.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -831,6 +845,19 @@ def lcm(x1, x2):
     return jnp.lcm(x1, x2)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -1036,6 +1063,11 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return jnp.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
@@ -1157,6 +1189,11 @@ def split(x, indices_or_sections, axis=0):
     return jnp.split(x, indices_or_sections, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
@@ -1181,6 +1218,8 @@ def take(x, indices, axis=None):


 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)


@@ -1235,14 +1274,7 @@ def tile(x, repeats):

 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


 def tri(N, M=None, k=0, dtype=None):
@@ -1324,6 +1356,12 @@ def negative(x):
     return jnp.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1363,6 +1401,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
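All of the new entries above are thin wrappers over the corresponding `jax.numpy` functions; the only extra behavior is the integer-dtype guard in `ldexp`. A minimal sketch of that contract, using `jax.numpy` directly rather than the Keras wrapper (illustrative values):

    import jax.numpy as jnp

    # ldexp computes x1 * 2**x2; the new Keras wrapper additionally
    # raises a TypeError when the exponent is not an integer dtype.
    mantissa = jnp.array([0.5, 1.5])
    exponent = jnp.array([3, 4])  # integer dtype, as the wrapper requires
    print(jnp.ldexp(mantissa, exponent))  # [ 4. 24.]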
keras/src/backend/jax/trainer.py
CHANGED

@@ -266,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).train_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.train_step
             train_step = jit(
-                self.train_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
@@ -296,8 +302,14 @@ class JAXTrainer(base_trainer.Trainer):
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).test_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.test_step
             test_step = jit(
-                self.test_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
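With NNX enabled, `jit` now receives a lambda that calls the method through the class, `type(self).train_step(self, ...)`, instead of the bound method `self.train_step`. In plain Python the two call forms return the same result; the difference only matters for how the callable is wrapped when NNX instruments instance methods. A quick illustration of the equivalence (hypothetical `Toy` class, not Keras code):

    class Toy:
        def step(self, state, data):
            return state + data

    t = Toy()
    # Bound call and explicit unbound call are equivalent in plain Python.
    assert t.step(1, 2) == type(t).step(t, 1, 2) == 3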
keras/src/backend/numpy/linalg.py
CHANGED

@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
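The NumPy backend declines to implement `jvp`. For reference, the contract the stub refers to is the Jacobian-vector product as implemented by `jax.jvp`; an illustrative sketch using JAX (not part of this diff):

    import jax

    def f(x):
        return x * x

    # jvp evaluates f at `primals` and the directional derivative of f
    # along `tangents` in a single forward pass.
    y, dy = jax.jvp(f, (3.0,), (1.0,))
    print(y, dy)  # 9.0 6.0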
keras/src/backend/numpy/nn.py
CHANGED

@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax

 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -340,6 +343,252 @@ def average_pool(
     return pooled / window_counts


+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
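The adaptive pooling code above uses the standard two-window decomposition: for input length `l` and output length `out_l`, every output position covers a window of one of two sizes (`small` and `big = small + 1`), so each window size is pooled once over a strided view and the right result is gathered per output position. A usage sketch of the new entry point, assuming the functions above are in scope:

    import numpy as np

    x = np.arange(2 * 5 * 3, dtype="float32").reshape(2, 5, 3)  # (N, L, C)
    y = adaptive_average_pool(x, output_size=3, data_format="channels_last")
    print(y.shape)  # (2, 3, 3); windows cover rows [0:2], [1:4], [3:5]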
@@ -404,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -415,6 +664,14 @@ def conv(
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result


 def depthwise_conv(
@@ -1176,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
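The new `unfold` mirrors the im2col semantics of `torch.nn.functional.unfold` for NCHW inputs, per its docstring. A quick shape check under those semantics (illustrative values):

    import numpy as np

    x = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)  # (N, C, H, W)
    cols = unfold(x, kernel_size=2, stride=2)
    print(cols.shape)     # (1, 4, 4): C*kH*kW = 4 rows, L = 2*2 = 4 blocks
    print(cols[0, :, 0])  # first 2x2 block in (C, kH, kW) order: [0. 1. 4. 5.]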
keras/src/backend/numpy/numpy.py
CHANGED

@@ -294,6 +294,11 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
@@ -607,6 +612,10 @@ def empty(shape, dtype=None):
     return np.empty(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return np.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     return np.equal(x1, x2)

@@ -745,6 +754,11 @@ def isposinf(x):
     return np.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return np.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -759,6 +773,19 @@ def lcm(x1, x2):
     return np.lcm(x1, x2).astype(dtype)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+    return np.ldexp(x1, x2).astype(dtype)
+
+
 def less(x1, x2):
     return np.less(x1, x2)

@@ -991,6 +1018,10 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


+def ptp(x, axis=None, keepdims=False):
+    return np.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
@@ -1097,6 +1128,11 @@ def split(x, indices_or_sections, axis=0):
     return np.split(x, indices_or_sections, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    axis = standardize_axis_for_numpy(axis)
+    return np.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     axis = standardize_axis_for_numpy(axis)
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
@@ -1172,8 +1208,10 @@ def trace(x, offset=0, axis1=0, axis2=1):
     axis2 = standardize_axis_for_numpy(axis2)
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype
-        dtype =
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
     return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)


@@ -1301,6 +1339,14 @@ def negative(x):
     return np.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    return np.nextafter(x1, x2).astype(dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
@@ -1329,6 +1375,23 @@ def transpose(x, axes=None):
     return np.transpose(x, axes=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    result_dtype = dtypes.result_type(y.dtype, float)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    compute_dtype = dtypes.result_type(x.dtype, config.floatx())
+    x = x.astype(compute_dtype)
+    return np.vander(x, N=N, increasing=increasing).astype(result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
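Note that `np.trapezoid` is the NumPy 2.x name for the trapezoidal integrator (NumPy 2.0 renamed `np.trapz`), so the `trapezoid` wrapper above assumes a NumPy 2 runtime. A small sketch of the two underlying functions (illustrative values):

    import numpy as np

    y = np.array([0.0, 1.0, 2.0, 3.0])
    print(np.trapezoid(y, dx=0.5))  # 2.25: trapezoidal rule, uniform spacing

    print(np.vander(np.array([1, 2, 3]), N=3, increasing=True))
    # [[1 1 1]
    #  [1 2 4]
    #  [1 3 9]]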
keras/src/backend/openvino/__init__.py
CHANGED

@@ -15,6 +15,7 @@ from keras.src.backend.openvino.core import compute_output_spec
 from keras.src.backend.openvino.core import cond
 from keras.src.backend.openvino.core import convert_to_numpy
 from keras.src.backend.openvino.core import convert_to_tensor
+from keras.src.backend.openvino.core import device_scope
 from keras.src.backend.openvino.core import is_tensor
 from keras.src.backend.openvino.core import random_seed_dtype
 from keras.src.backend.openvino.core import shape
keras/src/backend/openvino/core.py
CHANGED

@@ -13,7 +13,6 @@ from openvino import compile_model
 from keras.src import tree
 from keras.src.backend.common import KerasVariable
 from keras.src.backend.common import dtypes
-from keras.src.backend.common import global_state
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.common.dtypes import result_type
 from keras.src.backend.common.keras_tensor import KerasTensor
@@ -530,31 +529,11 @@ def ov_to_keras_type(ov_type):

 @contextlib.contextmanager
 def device_scope(device_name):
-    current_device = _parse_device_input(device_name)
-    global_state.set_global_attribute("openvino_device", current_device)
+    yield


 def get_device():
-    device = global_state.get_global_attribute("openvino_device")
-    if device is None:
-        return "CPU"
-    return device
-
-
-def _parse_device_input(device_name):
-    if isinstance(device_name, str):
-        # We support string value like "cpu:0", "gpu:1", and need to convert
-        # "gpu" to "cuda"
-        device_name = device_name.upper()
-        device_type, _ = device_name.split(":")
-        return device_type
-    else:
-        raise ValueError(
-            "Invalid value for argument `device_name`. "
-            "Expected a string like 'gpu:0' or 'cpu'. "
-            f"Received: device_name='{device_name}'"
-        )
-    return device_name
+    return "CPU"


 class Variable(KerasVariable):
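The net effect is that device handling on the OpenVINO backend becomes static: `device_scope` is now a no-op context manager and `get_device()` always reports "CPU". Behaviorally (sketch, using the two functions as defined above):

    # The scope is accepted but no longer records a device.
    with device_scope("GPU:0"):
        print(get_device())  # "CPU"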
keras/src/backend/openvino/linalg.py
CHANGED

@@ -56,3 +56,7 @@ def svd(x, full_matrices=True, compute_uv=True):

 def lstsq(a, b, rcond=None):
     raise NotImplementedError("`lstsq` is not supported with openvino backend")
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("`jvp` is not supported with openvino backend")