keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py
@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )

 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts


+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
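The adaptive pooling helpers above avoid per-window dynamic slicing: every output position uses one of only two window sizes (they differ by at most one), so two stride-1 reduce_window passes plus one gather cover all cases. A minimal NumPy sketch of that index bookkeeping, using illustrative values input_dim=5 and output_size=3 (not code from the wheel):

import numpy as np

x = np.arange(5.0)  # example input along the pooled axis
input_dim, output_size = 5, 3
starts = np.floor(np.arange(output_size) * input_dim / output_size).astype(int)
ends = np.ceil(np.arange(1, output_size + 1) * input_dim / output_size).astype(int)
big = int((ends - starts).max())   # larger window size (3 here)
small = big - 1                    # smaller window size (2 here)
small_len = input_dim - small + 1  # number of stride-1 positions for the small window

# One stride-1 average pool per window size (what the two lax.reduce_window calls produce).
small_pool = np.array([x[i : i + small].mean() for i in range(small_len)])
big_pool = np.array([x[i : i + big].mean() for i in range(input_dim - big + 1)])

# Small windows index into small_pool; big windows index into the big_pool block.
gather = np.where(ends - starts == big, starts + small_len, starts)
print(np.concatenate([small_pool, big_pool])[gather])  # [0.5 2.  3.5]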
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-    return jax.lax.conv_general_dilated(
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result


 def depthwise_conv(
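The new size check turns a silently empty result into an explicit error. For intuition (a sketch, not code from the wheel): with valid padding, a spatial axis of length L with kernel k, stride s, and dilation d yields floor((L - d*(k - 1) - 1) / s) + 1 output positions, which drops to zero as soon as the dilated kernel no longer fits in the input:

def valid_out_len(length, kernel, stride=1, dilation=1):
    # Output length of a "valid" convolution along one spatial axis.
    return (length - dilation * (kernel - 1) - 1) // stride + 1

print(valid_out_len(5, 3))  # 3 -> normal case
print(valid_out_len(2, 3))  # 0 -> conv output is empty; the check above now raises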
@@ -396,6 +804,8 @@ def depthwise_conv(
     feature_group_count = (
         inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
     )
+    kernel = convert_to_tensor(kernel)
+    inputs = convert_to_tensor(inputs)
     kernel = jnp.reshape(
         kernel,
         kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
@@ -499,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
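The list(...) wrapper exists because the very next line calls .insert() on the result. A sketch of the failure mode, assuming a jax version where jnp.meshgrid returns a tuple rather than a list:

import jax.numpy as jnp

grids = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij")
grids = list(grids)                  # tuple -> list, so insert() is available
grids.insert(0, jnp.zeros((2, 3)))   # would raise AttributeError on a tuple
print(len(grids))                    # 3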
@@ -1061,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
         # Only support at least Ampere
         if not check_compute_capability("8.0"):
             raise RuntimeError("Require at least Ampere arch to run")
-        # Check inputs layout
+
+        # Inspect inputs of `check_layout`
         check_layout_params = list(
             inspect.signature(check_layout).parameters.keys()
         )
         for known_param in ("query", "key", "value", "bias", "layout"):
             check_layout_params.remove(known_param)
         # Defaults to `None` when not specified.
-        kwargs = {key: None for key in check_layout_params}
+        check_layout_kwargs = {key: None for key in check_layout_params}
         check_layout(
-            query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
-        )
-        check_is_flash_attention(
             query,
             key,
-            _normalize_layout("BTNH"),
-            cudnn_version,
-            bias is not None,
-            is_training=False,
+            value,
+            bias,
+            layout=_normalize_layout("BTNH"),
+            **check_layout_kwargs,
         )
+
+        # Inspect inputs of `check_is_flash_attention`
+        check_is_flash_attention_params = inspect.signature(
+            check_is_flash_attention
+        ).parameters
+        check_is_flash_attention_kwargs = {
+            "query": query,
+            "key": key,
+            "value": value,
+            "layout": _normalize_layout("BTNH"),
+            "cudnn_version": cudnn_version,
+            "has_bias": bias is not None,
+            "is_training": False,
+        }
+        # Remove unsupported arguments
+        for param in list(check_is_flash_attention_kwargs.keys()):
+            if param not in check_is_flash_attention_params:
+                check_is_flash_attention_kwargs.pop(param)
+        check_is_flash_attention(**check_is_flash_attention_kwargs)
         return True
     except:
         if raise_error:
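Both call sites above use the same compatibility pattern: build the full keyword set, then keep only what the installed jax build actually accepts, so the capability check keeps working as the cuDNN helper signatures change across versions. A generic sketch of the pattern (the helper name is illustrative, not from the wheel):

import inspect

def call_with_supported_kwargs(fn, **kwargs):
    # Drop any keyword the target function does not declare before calling it.
    params = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in kwargs.items() if k in params})

def check_stub(query, key, value):  # stand-in for an older helper signature
    return True

print(call_with_supported_kwargs(check_stub, query=1, key=2, value=3, has_bias=True))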
@@ -1330,25 +1757,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False

     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
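The tracer guard matters because the Splash kernel builds and hashes a concrete mask, while under jit or scan the mask is an abstract value with no data. A small sketch of the distinction (assumption: runnable on any host, no TPU required):

import jax
import jax.numpy as jnp

mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
print(isinstance(mask, jax.core.Tracer))  # False: concrete array, safe to hash

@jax.jit
def f(m):
    # At trace time `m` is a Tracer, so code needing concrete mask values must bail out.
    assert isinstance(m, jax.core.Tracer)
    return m.sum()

print(f(mask))  # 10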
@@ -1394,6 +1828,11 @@

     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
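The new while loop normalizes masks or biases to rank 4 before the grouped broadcast: a rank-3 tensor whose second axis matches the head count gets its missing axis inserted after the heads, and any other rank-deficient tensor gets a leading singleton head axis. A standalone sketch with illustrative shapes (N plays the role of the head count captured by the enclosing function):

import jax.numpy as jnp

N = 8                                         # head count assumed for this sketch
mask = jnp.ones((2, 16, 16), dtype=bool)      # (batch, q_len, kv_len), no head axis
while mask.ndim < 4:
    if mask.ndim == 3 and mask.shape[1] == N:
        mask = jnp.expand_dims(mask, axis=2)  # insert the missing axis after heads
    else:
        mask = jnp.expand_dims(mask, axis=1)  # insert a singleton head axis
print(mask.shape)  # (2, 1, 16, 16)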
@@ -1411,3 +1850,46 @@
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+    Extract sliding local blocks from an **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # input was already padded above
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only supports NCHW
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)
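A usage sketch for the new unfold helper, calling the backend module directly (whether it is also re-exported as a public keras.ops entry point is not shown in this hunk):

import jax.numpy as jnp
from keras.src.backend.jax.nn import unfold

x = jnp.arange(32, dtype="float32").reshape(1, 2, 4, 4)  # (N, C, H, W)
cols = unfold(x, kernel_size=3, stride=1)
print(cols.shape)  # (1, 18, 4): C*kH*kW = 2*3*3 columns for L = 2*2 block positions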