keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -446,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)


+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
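Note: a hedged usage sketch for the new `view` op, assuming it is surfaced as `keras.ops.view` (the `keras/ops/numpy/__init__.py` files above each gain new exports). Like `ndarray.view`, it reinterprets the tensor's bytes as another dtype rather than casting values:

    import numpy as np
    from keras import ops

    x = np.array([1.0], dtype="float32")
    # Prints [1065353216], i.e. the IEEE-754 bit pattern 0x3F800000 of 1.0f.
    print(ops.view(x, dtype="int32"))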
@@ -673,6 +678,10 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -819,6 +828,11 @@ def isposinf(x):
    return jnp.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -831,6 +845,19 @@ def lcm(x1, x2):
     return jnp.lcm(x1, x2)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
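Note: a hedged usage sketch, assuming the op is exported as `keras.ops.ldexp`. `ldexp(x1, x2)` computes `x1 * 2**x2`, and the new dtype check rejects non-integer exponents up front:

    import numpy as np
    from keras import ops

    mantissa = np.array([0.5, 0.5, 0.5], dtype="float32")
    exponent = np.array([1, 2, 3], dtype="int32")  # must be an integer dtype
    print(ops.ldexp(mantissa, exponent))  # [1.0, 2.0, 4.0]
    # Passing exponent.astype("float32") would now raise TypeError instead of
    # silently computing on a float exponent.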
@@ -1036,6 +1063,11 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return jnp.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
@@ -1157,6 +1189,11 @@ def split(x, indices_or_sections, axis=0):
     return jnp.split(x, indices_or_sections, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
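Note: unlike `split`, `array_split` follows NumPy in accepting a section count that does not divide the axis length evenly. A hedged sketch, assuming export as `keras.ops.array_split`:

    import numpy as np
    from keras import ops

    parts = ops.array_split(np.arange(7), 3)
    # Uneven sections are allowed: [[0, 1, 2], [3, 4], [5, 6]]
    print([np.asarray(p).tolist() for p in parts])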
@@ -1181,6 +1218,8 @@ def take(x, indices, axis=None):


 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)

@@ -1235,14 +1274,7 @@ def tile(x, repeats):

 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-    dtype = None
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


 def tri(N, M=None, k=0, dtype=None):
@@ -1324,6 +1356,12 @@ def negative(x):
     return jnp.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1363,6 +1401,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
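Note: a hedged usage sketch for the new `trapezoid` op, assuming it is exported as `keras.ops.trapezoid` (mirroring `jnp.trapezoid`/`np.trapezoid`):

    import numpy as np
    from keras import ops

    y = np.array([0.0, 1.0, 2.0, 3.0])  # f(x) = x sampled at unit spacing
    print(ops.trapezoid(y))  # 4.5, the trapezoid-rule integral over [0, 3]
    # A non-uniform grid via x=: segment widths 1, 1, 2 give 0.5 + 1.5 + 5 = 7.0
    print(ops.trapezoid(y, x=np.array([0.0, 1.0, 2.0, 4.0])))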
@@ -266,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).train_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.train_step
             train_step = jit(
-                self.train_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
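Note: both trainer hunks (this one and the `test_step` one below) use the same indirection. When Flax NNX is enabled, `self` is an object NNX's graph machinery tracks, and `type(self).train_step(self, state, data)` routes the call through the plain class function with `self` supplied explicitly, rather than handing `jit` a bound method; `donate_argnums=0` continues to donate `state` either way. A minimal, Keras-independent sketch of the bound-vs-unbound pattern (the `Trainer` class below is illustrative only, not from the diff):

    import jax

    class Trainer:
        def train_step(self, state, data):
            return state + data

    t = Trainer()

    # Bound method: the traced callable closes over `t` itself.
    bound = jax.jit(t.train_step)

    # Unbound pattern from the hunks: the wrapper calls the plain class
    # function and passes `self` explicitly on every invocation.
    unbound = jax.jit(lambda state, data: type(t).train_step(t, state, data))

    print(bound(1.0, 2.0), unbound(1.0, 2.0))  # 3.0 3.0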
@@ -296,8 +302,14 @@ class JAXTrainer(base_trainer.Trainer):
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).test_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.test_step
             test_step = jit(
-                self.test_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax

 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -340,6 +343,252 @@ def average_pool(
     return pooled / window_counts


+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
@@ -404,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -415,6 +664,14 @@
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result


 def depthwise_conv(
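Note: the new guard turns a previously silent failure into an immediate error. For "valid" padding, the output extent is `(in - (k - 1) * d - 1) // s + 1`, which goes non-positive when the (dilated) kernel outgrows the input; a quick check with illustrative numbers:

    in_size, kernel, dilation, stride = 3, 5, 1, 1
    # A 5-wide kernel on a 3-wide input: (3 - 4 - 1) // 1 + 1 = -1, so the
    # result tensor is empty and the ValueError above fires.
    print((in_size - (kernel - 1) * dilation - 1) // stride + 1)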
@@ -1176,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
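Note: assuming the `unfold` defined above is in scope, a quick shape check against its documented `(N, C*kH*kW, L)` contract:

    import numpy as np

    x = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
    cols = unfold(x, kernel_size=3)  # stride 1, no padding: oH = oW = 6
    # L = 6 * 6 = 36 output positions, each a column of C * kH * kW = 27 values.
    print(cols.shape)  # (2, 27, 36)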
@@ -294,6 +294,11 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)


+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
@@ -607,6 +612,10 @@ def empty(shape, dtype=None):
     return np.empty(shape, dtype=dtype)


+def empty_like(x, dtype=None):
+    return np.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     return np.equal(x1, x2)

@@ -745,6 +754,11 @@ def isposinf(x):
     return np.isposinf(x)


+def isreal(x):
+    x = convert_to_tensor(x)
+    return np.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -759,6 +773,19 @@ def lcm(x1, x2):
     return np.lcm(x1, x2).astype(dtype)


+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+    return np.ldexp(x1, x2).astype(dtype)
+
+
 def less(x1, x2):
     return np.less(x1, x2)

@@ -991,6 +1018,10 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


+def ptp(x, axis=None, keepdims=False):
+    return np.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
@@ -1097,6 +1128,11 @@ def split(x, indices_or_sections, axis=0):
     return np.split(x, indices_or_sections, axis=axis)


+def array_split(x, indices_or_sections, axis=0):
+    axis = standardize_axis_for_numpy(axis)
+    return np.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     axis = standardize_axis_for_numpy(axis)
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
@@ -1172,8 +1208,10 @@ def trace(x, offset=0, axis1=0, axis2=1):
     axis2 = standardize_axis_for_numpy(axis2)
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype not in ("int64", "uint32", "uint64"):
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16"):
+        dtype = "int32"
+    elif dtype in ("uint8", "uint16"):
+        dtype = "uint32"
     return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)

@@ -1301,6 +1339,14 @@ def negative(x):
     return np.negative(x)


+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    return np.nextafter(x1, x2).astype(dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
@@ -1329,6 +1375,23 @@ def transpose(x, axes=None):
     return np.transpose(x, axes=axes)


+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    result_dtype = dtypes.result_type(y.dtype, float)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    compute_dtype = dtypes.result_type(x.dtype, config.floatx())
+    x = x.astype(compute_dtype)
+    return np.vander(x, N=N, increasing=increasing).astype(result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     axis = standardize_axis_for_numpy(axis)
     x = convert_to_tensor(x)
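Note: a hedged usage sketch for the new `vander` op, assuming it is exported as `keras.ops.vander` like the other NumPy-parity ops in this release:

    import numpy as np
    from keras import ops

    print(np.asarray(ops.vander(np.array([1, 2, 3]), N=3)))
    # [[1 1 1]
    #  [4 2 1]
    #  [9 3 1]]  (decreasing powers by default, as in np.vander)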
@@ -15,6 +15,7 @@ from keras.src.backend.openvino.core import compute_output_spec
 from keras.src.backend.openvino.core import cond
 from keras.src.backend.openvino.core import convert_to_numpy
 from keras.src.backend.openvino.core import convert_to_tensor
+from keras.src.backend.openvino.core import device_scope
 from keras.src.backend.openvino.core import is_tensor
 from keras.src.backend.openvino.core import random_seed_dtype
 from keras.src.backend.openvino.core import shape
@@ -13,7 +13,6 @@ from openvino import compile_model
 from keras.src import tree
 from keras.src.backend.common import KerasVariable
 from keras.src.backend.common import dtypes
-from keras.src.backend.common import global_state
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.common.dtypes import result_type
 from keras.src.backend.common.keras_tensor import KerasTensor
@@ -530,31 +529,11 @@ def ov_to_keras_type(ov_type):

 @contextlib.contextmanager
 def device_scope(device_name):
-    current_device = _parse_device_input(device_name)
-    global_state.set_global_attribute("openvino_device", current_device)
+    yield


 def get_device():
-    device = global_state.get_global_attribute("openvino_device", None)
-    if device is None:
-        return "CPU"
-    return device
-
-
-def _parse_device_input(device_name):
-    if isinstance(device_name, str):
-        # We support string value like "cpu:0", "gpu:1", and need to convert
-        # "gpu" to "cuda"
-        device_name = device_name.upper()
-        device_type, _ = device_name.split(":")
-        return device_type
-    else:
-        raise ValueError(
-            "Invalid value for argument `device_name`. "
-            "Expected a string like 'gpu:0' or 'cpu'. "
-            f"Received: device_name='{device_name}'"
-        )
-    return device_name
+    return "CPU"


 class Variable(KerasVariable):
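Note: taken together, the OpenVINO hunks make device selection an explicit no-op: `device_scope` now simply yields without recording state, and `get_device` always reports "CPU". The observable behavior, using the import the diff itself re-exports from the backend package:

    from keras.src.backend.openvino.core import device_scope, get_device

    with device_scope("gpu:0"):  # accepted for API compatibility, but ignored
        print(get_device())      # "CPU"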
@@ -56,3 +56,7 @@ def svd(x, full_matrices=True, compute_uv=True):

 def lstsq(a, b, rcond=None):
     raise NotImplementedError("`lstsq` is not supported with openvino backend")
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("`jvp` is not supported with openvino backend")