keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the released artifacts exactly as they appear in that registry.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/numpy.py

@@ -3,6 +3,7 @@ import math
 
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from jax import export as jax_export
 
 from keras.src.backend import config
 from keras.src.backend.common import dtypes
@@ -306,14 +307,20 @@ def append(x1, x2, axis=None):
     return jnp.append(x1, x2, axis=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
+    def get_dtype(x):
+        if hasattr(x, "dtype"):
+            return x.dtype
+        if jax_export.is_symbolic_dim(x):
+            return int
+        return type(x)
+
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [get_dtype(start)]
         if stop is not None:
-            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+            dtypes_to_resolve.append(get_dtype(stop))
+        if step is not None:
+            dtypes_to_resolve.append(get_dtype(step))
         dtype = dtypes.result_type(*dtypes_to_resolve)
         dtype = standardize_dtype(dtype)
     return jnp.arange(start, stop, step=step, dtype=dtype)
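
Note: the nested `get_dtype` helper exists because symbolic dimensions produced by `jax.export` (shape polymorphism) carry no `dtype` attribute and are not ordinary Python scalars, so they are treated as `int` for dtype resolution; `step` now defaults to `None` and only participates when given. A minimal sketch of the resulting behavior (assuming the JAX backend, where `keras.ops.arange` forwards to this function):

    import keras

    print(keras.ops.arange(5).dtype)           # int32: integer bounds resolve to int
    print(keras.ops.arange(0, 1, 0.25).dtype)  # float32: the float step now joins resolution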
@@ -439,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
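
Note: `view` reinterprets the tensor's underlying bytes as another dtype without copying. A quick illustration of the semantics (shown with NumPy, whose `ndarray.view` behaves the same way):

    import numpy as np

    x = np.array([1.0], dtype="float32")
    print(x.view("int32"))  # [1065353216] == 0x3F800000, the IEEE-754 bits of 1.0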
@@ -536,15 +548,18 @@ def clip(x, x_min, x_max):
 
 def concatenate(xs, axis=0):
     bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
-    if bcoo_count:
-        if bcoo_count == len(xs):
-            axis = canonicalize_axis(axis, len(xs[0].shape))
-            return jax_sparse.bcoo_concatenate(xs, dimension=axis)
-        else:
-            xs = [
-                x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
-                for x in xs
-            ]
+    if bcoo_count == len(xs):
+        axis = canonicalize_axis(axis, len(xs[0].shape))
+        return jax_sparse.bcoo_concatenate(xs, dimension=axis)
+    elif bcoo_count:
+        xs = [
+            x.todense()
+            if isinstance(x, jax_sparse.JAXSparse)
+            else convert_to_tensor(x)
+            for x in xs
+        ]
+    else:
+        xs = [convert_to_tensor(x) for x in xs]
     return jnp.concatenate(xs, axis=axis)
 
 
@@ -663,6 +678,10 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -809,6 +828,11 @@ def isposinf(x):
     return jnp.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
 def kron(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -821,6 +845,19 @@ def lcm(x1, x2):
     return jnp.lcm(x1, x2)
 
 
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
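
Note: `ldexp(x1, x2)` computes `x1 * 2.0 ** x2` exactly by adjusting the floating-point exponent, which is why a non-integer `x2` is rejected. A small equivalence check (hypothetical usage; assumes the op is exported as `keras.ops.ldexp`, as the new `keras/ops` entries in the file list suggest):

    import numpy as np
    import keras

    x1 = np.array([0.5, 1.0, 2.0], dtype="float32")
    x2 = np.array([1, 2, 3], dtype="int32")
    print(keras.ops.ldexp(x1, x2))  # [1., 4., 16.] == x1 * 2.0 ** x2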
@@ -888,6 +925,15 @@ def logaddexp(x1, x2):
     return jnp.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return jnp.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
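
Note: `logaddexp2` is the base-2 counterpart of `logaddexp`: it evaluates `log2(2**x1 + 2**x2)` without overflow, and the explicit casts force both operands to a common float dtype first. A quick sanity check (again assuming a `keras.ops.logaddexp2` export):

    import keras

    # Naively computing log2(2**a + 2**b) overflows for large a, b;
    # the fused op stays finite, and logaddexp2(a, a) == a + 1.
    print(keras.ops.logaddexp2(10000.0, 10000.0))  # ~10001.0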
@@ -1071,6 +1117,7 @@ def reshape(x, newshape):
         if None not in output_shape:
             newshape = output_shape
         return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
+    x = convert_to_tensor(x)
     return jnp.reshape(x, newshape)
 
 
@@ -1133,10 +1180,17 @@ def sort(x, axis=-1):
 
 
 def split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
     return jnp.split(x, indices_or_sections, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
+    x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
 
 
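Note: unlike `split`, which raises when an integer section count does not divide the axis evenly, the newly wired `array_split` accepts any count and yields near-equal chunks. The NumPy semantics that `jnp.array_split` follows:

    import numpy as np

    x = np.arange(7)
    print([a.tolist() for a in np.array_split(x, 3)])  # [[0, 1, 2], [3, 4], [5, 6]]
    # np.split(x, 3) would raise: 7 is not evenly divisible by 3.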
@@ -1159,6 +1213,8 @@ def take(x, indices, axis=None):
 
 
 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)
 
 
@@ -1213,14 +1269,7 @@ def tile(x, repeats):
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-    dtype = None
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -1302,6 +1351,12 @@ def negative(x):
     return jnp.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1322,6 +1377,7 @@ def squeeze(x, axis=None):
             axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
         axis = to_tuple_or_list(axis)
         return jax_sparse.bcoo_squeeze(x, dimensions=axis)
+    x = convert_to_tensor(x)
     return jnp.squeeze(x, axis=axis)
 
 
@@ -1340,6 +1396,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
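
Note: `trapezoid` integrates sampled values with the trapezoidal rule. A worked check against a known integral (hypothetical usage via a `keras.ops.trapezoid` export):

    import numpy as np
    import keras

    x = np.linspace(0.0, 1.0, 101, dtype="float32")
    approx = keras.ops.trapezoid(x * x, x=x)
    print(float(approx))  # ~0.33335; the exact integral of x**2 over [0, 1] is 1/3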
keras/src/backend/jax/optimizer.py

@@ -36,13 +36,14 @@ class JaxOptimizer(base_optimizer.BaseOptimizer):
         new_g_accs = jax.lax.cond(
             is_update_step,
             lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
-            lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
+            lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
         )
 
         grads = jax.lax.cond(
             is_update_step,
             lambda: [
-                (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
+                (g + acc_g.value) / steps
+                for g, acc_g in zip(grads, acc_grads)
             ],
             lambda: list(grads),
         )
keras/src/backend/jax/trainer.py

@@ -266,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+        if is_nnx_enabled():
+            step_fn = lambda state, data: type(self).train_step(
+                self, state, data
+            )
+        else:
+            step_fn = self.train_step
         train_step = jit(
-            self.train_step,
+            step_fn,
             donate_argnums=0,
             out_shardings=out_shardings,
         )
@@ -296,8 +302,14 @@ class JAXTrainer(base_trainer.Trainer):
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+        if is_nnx_enabled():
+            step_fn = lambda state, data: type(self).test_step(
+                self, state, data
+            )
+        else:
+            step_fn = self.test_step
         test_step = jit(
-            self.test_step,
+            step_fn,
             donate_argnums=0,
             out_shardings=out_shardings,
         )
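
Note: both trainer hunks rely on a plain Python lookup rule: `type(self).train_step` fetches the function from the class, bypassing whatever wrapper may be installed on the instance attribute (as under NNX), and passes `self` explicitly so `jit` traces an ordinary function of `(state, data)`. A minimal illustration of that rule (generic sketch, not Keras code):

    class Trainer:
        def step(self, x):
            return x + 1

    t = Trainer()
    t.step = lambda x: x + 100  # instance attribute shadows the method
    print(t.step(1))            # 101: goes through the instance-level wrapper
    print(type(t).step(t, 1))   # 2: class lookup bypasses the wrapper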
keras/src/backend/numpy/linalg.py

@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
     a = convert_to_tensor(a)
     b = convert_to_tensor(b)
     return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
keras/src/backend/numpy/nn.py

@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -164,13 +167,14 @@ def celu(x, alpha=1.0):
 
 def glu(x, axis=-1):
     x = convert_to_tensor(x)
+    dtype = x.dtype
     if x.shape[axis] % 2 != 0:
         raise ValueError(
             "axis size must be divisible by 2. "
             f"Received: x.shape={x.shape} with axis={axis}"
         )
     x1, x2 = np.split(x, 2, axis)
-    return x1 * (1 / (1 + np.exp(-x2)))
+    return (x1 * sigmoid(x2)).astype(dtype)
 
 
 def hard_tanh(x):
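
Note: GLU splits the input in two along `axis` and gates the first half with the sigmoid of the second; routing through the backend `sigmoid` and casting back preserves low-precision dtypes such as float16. The identity being computed, stated in plain NumPy:

    import numpy as np

    x = np.random.rand(2, 4).astype("float32")
    x1, x2 = np.split(x, 2, axis=-1)
    glu_out = x1 * (1.0 / (1.0 + np.exp(-x2)))  # x1 * sigmoid(x2)
    print(glu_out.shape)  # (2, 2): only the gated half remains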
@@ -339,6 +343,252 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
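
Note: the implementation leans on a property of adaptive pooling: output window `i` spans `[floor(i*L/out), ceil((i+1)*L/out))`, and at most two distinct window sizes ever occur (`small` and `big = small + 1`), so each axis needs only two strided views plus one gather instead of a per-window loop. The window arithmetic, checked in plain Python:

    L_in, out = 5, 3
    starts = [(i * L_in) // out for i in range(out)]
    ends = [-(-(i + 1) * L_in // out) for i in range(out)]  # ceiling division
    print(list(zip(starts, ends)))                        # [(0, 2), (1, 4), (3, 5)]
    print(sorted({e - s for s, e in zip(starts, ends)}))  # [2, 3]: two sizes only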
@@ -403,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -414,6 +664,14 @@ def conv(
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
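
Note: the new guard replaces a silently empty result with an explicit error. Empty outputs occur when the effective kernel extent exceeds the (padded) input, e.g. a 3x3 kernel on a 2x2 image under valid padding:

    # Output spatial size: (in_size - (k - 1) * d - 1) // s + 1
    in_size, k, d, s = 2, 3, 1, 1
    print((in_size - (k - 1) * d - 1) // s + 1)  # 0: the conv output would be empty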
@@ -1175,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
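
Note: this mirrors the layout of `torch.nn.functional.unfold`: each output column is one flattened `C*kH*kW` patch and `L = oH*oW` is the number of sliding positions. A hand check of the shape and the first column, calling the backend function above directly:

    import numpy as np

    x = np.arange(9, dtype="float32").reshape(1, 1, 3, 3)
    out = unfold(x, kernel_size=2)
    print(out.shape)     # (1, 4, 4): (N, C*kH*kW, L) with oH = oW = 2
    print(out[0, :, 0])  # [0. 1. 3. 4.]: the top-left 2x2 patch, flattened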