keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/numpy.py
@@ -3,6 +3,7 @@ import math
 
 import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
+from jax import export as jax_export
 
 from keras.src.backend import config
 from keras.src.backend.common import dtypes
@@ -306,14 +307,20 @@ def append(x1, x2, axis=None):
     return jnp.append(x1, x2, axis=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
+    def get_dtype(x):
+        if hasattr(x, "dtype"):
+            return x.dtype
+        if jax_export.is_symbolic_dim(x):
+            return int
+        return type(x)
+
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [get_dtype(start)]
         if stop is not None:
-            dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+            dtypes_to_resolve.append(get_dtype(stop))
+        if step is not None:
+            dtypes_to_resolve.append(get_dtype(step))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = standardize_dtype(dtype)
     return jnp.arange(start, stop, step=step, dtype=dtype)
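The reworked `arange` only folds `step` into dtype resolution when it is actually passed, and treats JAX symbolic dimensions (from `jax.export` shape polymorphism) as Python `int`. Illustrative dtype behavior, assuming the standard Keras dtype-resolution rules (not part of the diff):

import keras  # with KERAS_BACKEND=jax

keras.ops.arange(5)           # dtype resolved from type(5) -> "int32"
keras.ops.arange(0, 1, 0.25)  # a float step now participates -> "float32"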
@@ -439,6 +446,11 @@ def array(x, dtype=None):
     return jnp.array(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
@@ -536,15 +548,18 @@ def clip(x, x_min, x_max):
 
 def concatenate(xs, axis=0):
     bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
-    if bcoo_count:
-        if bcoo_count == len(xs):
-            axis = canonicalize_axis(axis, len(xs[0].shape))
-            return jax_sparse.bcoo_concatenate(xs, dimension=axis)
-        else:
-            xs = [
-                x.todense() if isinstance(x, jax_sparse.JAXSparse) else x
-                for x in xs
-            ]
+    if bcoo_count == len(xs):
+        axis = canonicalize_axis(axis, len(xs[0].shape))
+        return jax_sparse.bcoo_concatenate(xs, dimension=axis)
+    elif bcoo_count:
+        xs = [
+            x.todense()
+            if isinstance(x, jax_sparse.JAXSparse)
+            else convert_to_tensor(x)
+            for x in xs
+        ]
+    else:
+        xs = [convert_to_tensor(x) for x in xs]
     return jnp.concatenate(xs, axis=axis)
 
 
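After this change, `concatenate` has three explicit paths: an all-`BCOO` list stays sparse via `bcoo_concatenate`, a mixed list densifies its sparse members, and a plain list is now converted up front with `convert_to_tensor`. An illustrative sketch (not part of the diff; `BCOO.fromdense` is standard JAX sparse API):

import jax.numpy as jnp
import jax.experimental.sparse as jax_sparse

a = jax_sparse.BCOO.fromdense(jnp.eye(2))
b = jnp.ones((2, 2))
# [a, a]         -> stays sparse (bcoo_concatenate)
# [a, b]         -> a is densified, result is dense
# [[1.0], [2.0]] -> plain lists now go through convert_to_tensor first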
@@ -663,6 +678,10 @@ def empty(shape, dtype=None):
     return jnp.empty(shape, dtype=dtype)
 
 
+def empty_like(x, dtype=None):
+    return jnp.empty_like(x, dtype=dtype)
+
+
 def equal(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
@@ -809,6 +828,36 @@ def isposinf(x):
     return jnp.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return jnp.isreal(x)
+
+
+def kron(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.kron(x1, x2)
+
+
+def lcm(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.lcm(x1, x2)
+
+
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return jnp.ldexp(x1, x2)
+
+
 def less(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
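`ldexp` scales a float by an integer power of two (`x1 * 2**x2`), and the new guard rejects non-integer exponents before dispatching to `jnp.ldexp`. A worked example (hedged: assumes the op is also exported under `keras.ops`, consistent with the `ops/numpy` export additions in the file list):

import keras

keras.ops.ldexp(1.5, 3)  # -> 12.0, since 1.5 * 2**3 == 12.0
# keras.ops.ldexp(1.5, 0.5) raises TypeError: the exponent must be integer-typed.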
@@ -876,6 +925,15 @@ def logaddexp(x1, x2):
     return jnp.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return jnp.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
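`logaddexp2` is the base-2 analogue of `logaddexp`: it computes `log2(2**x1 + 2**x2)` without overflow for large inputs, after upcasting both arguments to a shared float dtype. A worked example (hedged: assumes the op is exported under `keras.ops`):

import keras

keras.ops.logaddexp2(3.0, 3.0)  # -> 4.0, since 2**3 + 2**3 == 2**4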
@@ -1005,6 +1063,11 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    return jnp.ptp(x, axis=axis, keepdims=keepdims)
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
@@ -1059,6 +1122,7 @@ def reshape(x, newshape):
         if None not in output_shape:
             newshape = output_shape
         return jax_sparse.bcoo_reshape(x, new_sizes=newshape)
+    x = convert_to_tensor(x)
     return jnp.reshape(x, newshape)
 
 
@@ -1121,10 +1185,17 @@ def sort(x, axis=-1):
 
 
 def split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
     return jnp.split(x, indices_or_sections, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
+    x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
 
 
@@ -1147,6 +1218,8 @@ def take(x, indices, axis=None):
 
 
 def take_along_axis(x, indices, axis=None):
+    x = convert_to_tensor(x)
+    indices = convert_to_tensor(indices, sparse=False)
     return jnp.take_along_axis(x, indices, axis=axis)
 
 
@@ -1201,14 +1274,7 @@ def tile(x, repeats):
 
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
-    dtype = None
-    # TODO: Remove the condition of uint8 and uint16 once we have jax>=0.4.27
-    # for both CPU & GPU environments.
-    # uint8 and uint16 will be casted to uint32 when jax>=0.4.27 but to int32
-    # otherwise.
-    if standardize_dtype(x.dtype) in ("bool", "uint8", "uint16"):
-        dtype = "int32"
-    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)
+    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)
 
 
 def tri(N, M=None, k=0, dtype=None):
@@ -1290,6 +1356,12 @@ def negative(x):
     return jnp.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return jnp.nextafter(x1, x2)
+
+
 @sparse.elementwise_unary(linear=False)
 def square(x):
     x = convert_to_tensor(x)
@@ -1310,6 +1382,7 @@ def squeeze(x, axis=None):
         axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
         axis = to_tuple_or_list(axis)
         return jax_sparse.bcoo_squeeze(x, dimensions=axis)
+    x = convert_to_tensor(x)
     return jnp.squeeze(x, axis=axis)
 
 
@@ -1328,6 +1401,19 @@ def transpose(x, axes=None):
     return jnp.transpose(x, axes=axes)
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if x is not None:
+        x = convert_to_tensor(x)
+    dx = convert_to_tensor(dx)
+    return jnp.trapezoid(y, x, dx=dx, axis=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    return jnp.vander(x, N=N, increasing=increasing)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     # `jnp.var` does not handle low precision (e.g., float16) overflow
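`trapezoid` integrates sampled values with the trapezoidal rule and `vander` builds Vandermonde matrices, both thin wrappers over their `jnp` counterparts. Quick sanity checks (hedged: assumes both ops are also exported under `keras.ops`, consistent with the `ops/numpy` export additions in the file list):

import keras

keras.ops.trapezoid([0.0, 1.0, 2.0])  # -> 2.0 (trapezoidal rule with dx=1.0)
keras.ops.vander([1.0, 2.0], N=3)     # -> [[1., 1., 1.], [4., 2., 1.]] (powers descending)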
keras/src/backend/jax/optimizer.py
@@ -36,13 +36,14 @@ class JaxOptimizer(base_optimizer.BaseOptimizer):
         new_g_accs = jax.lax.cond(
             is_update_step,
             lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],
-            lambda: [g + acc_g for g, acc_g in zip(grads, acc_grads)],
+            lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],
         )
 
         grads = jax.lax.cond(
             is_update_step,
             lambda: [
-                (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)
+                (g + acc_g.value) / steps
+                for g, acc_g in zip(grads, acc_grads)
             ],
             lambda: list(grads),
         )
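The `.value` reads reflect that the gradient accumulators are Keras variables rather than raw arrays, so each `lax.cond` branch must pull out the underlying tensor explicitly. A minimal sketch of the pattern (assumes the JAX backend; the variable here is hypothetical, not the optimizer's internals):

import jax
import jax.numpy as jnp
import keras  # with KERAS_BACKEND=jax

acc = keras.Variable(jnp.zeros(()))  # stands in for one accumulator slot
out = jax.lax.cond(
    False,                    # is_update_step
    lambda: jnp.zeros(()),    # reset the accumulator
    lambda: 1.0 + acc.value,  # accumulate: read the raw array via .value
)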
keras/src/backend/jax/trainer.py
@@ -105,7 +105,10 @@ class JAXTrainer(base_trainer.Trainer):
             ]
         ) as scope:
             self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+                unscaled_loss,
+                sample_weight=next(
+                    i for i in tree.flatten(x) if i is not None
+                ).shape[0],
             )
             logs = self.compute_metrics(x, y, y_pred, sample_weight)
 
@@ -263,8 +266,14 @@ class JAXTrainer(base_trainer.Trainer):
         if distribution_lib.distribution() is not None:
             state_shardings = self._get_state_sharding_spec()
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).train_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.train_step
             train_step = jit(
-                self.train_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
@@ -293,8 +302,14 @@ class JAXTrainer(base_trainer.Trainer):
                 metrics_shardings,
             )
             out_shardings = (None, state_shardings)
+            if is_nnx_enabled():
+                step_fn = lambda state, data: type(self).test_step(
+                    self, state, data
+                )
+            else:
+                step_fn = self.test_step
             test_step = jit(
-                self.test_step,
+                step_fn,
                 donate_argnums=0,
                 out_shardings=out_shardings,
             )
keras/src/backend/numpy/linalg.py
@@ -96,3 +96,7 @@ def lstsq(a, b, rcond=None):
    a = convert_to_tensor(a)
    b = convert_to_tensor(b)
    return np.linalg.lstsq(a, b, rcond=rcond)[0]
+
+
+def jvp(fun, primals, tangents, has_aux=False):
+    raise NotImplementedError("JVP is not supported by the Numpy backend.")
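The NumPy backend declines forward-mode autodiff outright since it has no tracing machinery. For comparison, the JAX backend can express the same operation directly with `jax.jvp` (standard JAX API; example values are illustrative):

import jax

primals_out, tangents_out = jax.jvp(lambda t: t * t, (3.0,), (1.0,))
# primals_out == 9.0, tangents_out == 6.0 (derivative of t**2 at t=3)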
keras/src/backend/numpy/nn.py
@@ -3,6 +3,9 @@ import numpy as np
 from jax import lax
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -164,13 +167,14 @@ def celu(x, alpha=1.0):
 
 def glu(x, axis=-1):
     x = convert_to_tensor(x)
+    dtype = x.dtype
     if x.shape[axis] % 2 != 0:
         raise ValueError(
             "axis size must be divisible by 2. "
             f"Received: x.shape={x.shape} with axis={axis}"
         )
     x1, x2 = np.split(x, 2, axis)
-    return x1 * (1 / (1 + np.exp(-x2)))
+    return (x1 * sigmoid(x2)).astype(dtype)
 
 
 def hard_tanh(x):
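GLU splits the input in half along `axis` and gates the first half with the sigmoid of the second; the rewrite routes through the backend `sigmoid` and casts back so the output dtype matches the input. A worked example (illustrative, not part of the diff):

import numpy as np

x = np.array([1.0, 2.0, 0.0, 0.0], dtype="float16")
# Halves: x1 = [1, 2], x2 = [0, 0]; sigmoid(0) == 0.5,
# so glu(x) == [0.5, 1.0] and the float16 dtype is preserved.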
@@ -339,6 +343,252 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    window_starts = np.floor(
+        (np.arange(output_size) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_ends = np.ceil(
+        (np.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(np.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_pool_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_pool_len
+
+    gather = np.where(is_big, big_indices, small_indices)
+    return gather.astype(np.int32)
+
+
+def _strided_view_1d(x, window_size):
+    n, l, c = x.shape
+    out = l - window_size + 1
+
+    strides = x.strides
+    shape = (n, out, window_size, c)
+    new_strides = (strides[0], strides[1], strides[1], strides[2])
+
+    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)
+
+
+def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    sv_small = _strided_view_1d(inputs, small)
+    small_pool = (
+        np.mean(sv_small, axis=2)
+        if mode == "average"
+        else np.max(sv_small, axis=2)
+    )
+
+    sv_big = _strided_view_1d(inputs, big)
+    big_pool = (
+        np.mean(sv_big, axis=2) if mode == "average" else np.max(sv_big, axis=2)
+    )
+
+    combined = np.concatenate([small_pool, big_pool], axis=1)
+    out = combined[:, gather, :]
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :]
+
+    pooled_h = pooled_h.reshape(n, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)
+
+    sv_small_d = _strided_view_1d(x_d, small_d)
+    small_pool_d = (
+        np.mean(sv_small_d, axis=2)
+        if mode == "average"
+        else np.max(sv_small_d, axis=2)
+    )
+
+    sv_big_d = _strided_view_1d(x_d, big_d)
+    big_pool_d = (
+        np.mean(sv_big_d, axis=2)
+        if mode == "average"
+        else np.max(sv_big_d, axis=2)
+    )
+
+    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)
+    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)
+    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)
+
+    sv_small_h = _strided_view_1d(x_h, small_h)
+    small_pool_h = (
+        np.mean(sv_small_h, axis=2)
+        if mode == "average"
+        else np.max(sv_small_h, axis=2)
+    )
+
+    sv_big_h = _strided_view_1d(x_h, big_h)
+    big_pool_h = (
+        np.mean(sv_big_h, axis=2)
+        if mode == "average"
+        else np.max(sv_big_h, axis=2)
+    )
+
+    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)
+    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)
+    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    x_w = pooled_h.reshape(n * out_d * out_h, w, c)
+
+    sv_small_w = _strided_view_1d(x_w, small_w)
+    small_pool_w = (
+        np.mean(sv_small_w, axis=2)
+        if mode == "average"
+        else np.max(sv_small_w, axis=2)
+    )
+
+    sv_big_w = _strided_view_1d(x_w, big_w)
+    big_pool_w = (
+        np.mean(sv_big_w, axis=2)
+        if mode == "average"
+        else np.max(sv_big_w, axis=2)
+    )
+
+    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)
+    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)
+
+    if data_format == "channels_first":
+        out = np.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 2:
+        return _adaptive_pool2d_impl(
+            inputs, output_size, "average", data_format
+        )
+    if dims == 3:
+        return _adaptive_pool3d_impl(
+            inputs, output_size, "average", data_format
+        )
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_pool1d_impl(inputs, output_size, "max", data_format)
+    if dims == 2:
+        return _adaptive_pool2d_impl(inputs, output_size, "max", data_format)
+    if dims == 3:
+        return _adaptive_pool3d_impl(inputs, output_size, "max", data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
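The adaptive pooling implementation pools every position with both candidate window sizes (via a zero-copy `as_strided` view), concatenates the two pooled sequences, and gathers the correct candidate per output index; the 2D and 3D variants reduce to repeated 1D passes over each spatial axis. A hedged usage sketch of the new module-level function (shape math follows the code above):

import numpy as np

x = np.arange(2 * 5 * 3, dtype="float32").reshape(2, 5, 3)  # (N, L=5, C)
y = adaptive_average_pool(x, 3)  # -> shape (2, 3, 3)
# Output i averages x over [floor(i*5/3) : ceil((i+1)*5/3)],
# i.e. windows [0:2], [1:4], [3:5] with sizes 2, 3, 2.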
@@ -403,7 +653,7 @@ def conv(
             f"kernel in_channels {kernel_in_channels}. "
         )
     feature_group_count = channels // kernel_in_channels
-    return np.array(
+    result = np.array(
         jax.lax.conv_general_dilated(
             inputs,
             kernel if is_tensor(kernel) else kernel.numpy(),
@@ -414,6 +664,14 @@ def conv(
             feature_group_count=feature_group_count,
         )
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
@@ -1175,3 +1433,56 @@ def dot_product_attention(
     return _dot_product_attention_xla(
         query, key, value, bias, mask, is_causal, scale
     )
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """NumPy implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = np.pad(
+            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
+        )
+
+    # ---- spatial size ----
+    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
+    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
+
+    i0 = np.arange(0, oH) * s[0]
+    j0 = np.arange(0, oW) * s[1]
+    i, j = np.meshgrid(i0, j0, indexing="ij")  # shape (oH, oW)
+    i = i.reshape(-1)
+    j = j.reshape(-1)
+
+    # ---- flatten patches ----
+    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
+    for idx in range(k[0]):
+        for jdx in range(k[1]):
+            patches[:, :, idx, jdx, :] = input[
+                :, :, i + idx * d[0], j + jdx * d[1]
+            ]
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    return patches.reshape(N, C * k[0] * k[1], -1)
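A shape check for `unfold` under the docstring's contract: with unit stride and no padding or dilation, each spatial output size is `o = (size - k) // 1 + 1` and `L = oH * oW`. Hedged sketch (illustrative, not part of the diff):

import numpy as np

x = np.random.rand(1, 3, 4, 4)         # NCHW input
out = unfold(x, kernel_size=2)         # oH = oW = (4 - 2) // 1 + 1 = 3
assert out.shape == (1, 3 * 2 * 2, 9)  # (N, C*kH*kW, L) with L = 3 * 3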