keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares two publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
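The Two-Pool Gather trick above relies on adaptive pooling only ever needing two window sizes that differ by one. A small NumPy sketch of the index bookkeeping (illustrative only; it assumes `compute_adaptive_pooling_window_sizes` returns the smaller and larger of those two sizes):

import numpy as np

# Worked example: pool a length-5 axis down to 3 outputs.
input_dim, output_size = 5, 3
starts = np.floor(np.arange(output_size) * input_dim / output_size).astype(int)
ends = np.ceil(np.arange(1, output_size + 1) * input_dim / output_size).astype(int)
sizes = ends - starts              # [2, 3, 2] -> only two window sizes occur
big = int(sizes.max())             # 3
small = big - 1                    # 2
small_len = input_dim - small + 1  # 4 valid positions for the small window

# Outputs that use the big window read from the big pool, which is
# concatenated after the small pool, hence the `+ small_len` offset.
gather = np.where(sizes == big, starts + small_len, starts)
print(gather)                      # [0 5 3]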
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
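For comparison, a naive per-window reference for 1D adaptive average pooling (a plain-NumPy sketch, channels_last layout assumed); the Two-Pool Gather version above should produce the same values while staying fully vectorized and XLA-friendly:

import math
import numpy as np

def naive_adaptive_avg_pool1d(x, out_l):
    # x: (N, L, C); output: (N, out_l, C)
    n, l, c = x.shape
    out = np.empty((n, out_l, c), dtype=x.dtype)
    for i in range(out_l):
        start = math.floor(i * l / out_l)
        end = math.ceil((i + 1) * l / out_l)
        out[:, i, :] = x[:, start:end, :].mean(axis=1)
    return out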
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-    return jax.lax.conv_general_dilated(
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@ def conv(
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
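The new guard turns a silent zero-sized result into an immediate error. A minimal sketch of the failure mode it catches (hypothetical shapes): with "valid" padding, a kernel larger than the input yields an output with a zero-length spatial dimension.

import jax.numpy as jnp
from keras import ops

x = jnp.ones((1, 2, 2, 3))  # (N, H, W, C): spatially smaller than the kernel
k = jnp.ones((3, 3, 3, 8))  # (kH, kW, C_in, C_out)
ops.conv(x, k, strides=1, padding="valid")  # now raises ValueError instead of
                                            # returning a (1, 0, 0, 8) tensor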
@@ -501,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
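The `list(...)` wrapper matters because the very next line calls `.insert(...)`, which only exists on mutable lists; if `meshgrid` hands back a tuple (as some array-API implementations do), the old code would fail with an AttributeError. A trivial illustration:

import jax.numpy as jnp

grids = (jnp.arange(2), jnp.arange(3))  # stand-in for a tuple return value
indices = list(grids)                   # copy into a mutable list
indices.insert(0, jnp.zeros((5,)))      # .insert() is now safe to call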
@@ -1063,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
         # Only support at least Ampere
         if not check_compute_capability("8.0"):
             raise RuntimeError("Require at least Ampere arch to run")
-        # Check inputs layout
+
+        # Inspect inputs of `check_layout`
         check_layout_params = list(
             inspect.signature(check_layout).parameters.keys()
         )
         for known_param in ("query", "key", "value", "bias", "layout"):
             check_layout_params.remove(known_param)
         # Defaults to `None` when not specified.
-        kwargs = {key: None for key in check_layout_params}
+        check_layout_kwargs = {key: None for key in check_layout_params}
         check_layout(
-            query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
-        )
-        check_is_flash_attention(
             query,
             key,
-            _normalize_layout("BTNH"),
-            cudnn_version,
-            bias is not None,
-            is_training=False,
+            value,
+            bias,
+            layout=_normalize_layout("BTNH"),
+            **check_layout_kwargs,
         )
+
+        # Inspect inputs of `check_is_flash_attention`
+        check_is_flash_attention_params = inspect.signature(
+            check_is_flash_attention
+        ).parameters
+        check_is_flash_attention_kwargs = {
+            "query": query,
+            "key": key,
+            "value": value,
+            "layout": _normalize_layout("BTNH"),
+            "cudnn_version": cudnn_version,
+            "has_bias": bias is not None,
+            "is_training": False,
+        }
+        # Remove unsupported arguments
+        for param in list(check_is_flash_attention_kwargs.keys()):
+            if param not in check_is_flash_attention_params:
+                check_is_flash_attention_kwargs.pop(param)
+        check_is_flash_attention(**check_is_flash_attention_kwargs)
         return True
     except:
         if raise_error:
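The pattern used above, building the full keyword dictionary and then dropping anything the installed JAX version's helper does not accept, keeps the code compatible across cuDNN-attention API changes. A generic sketch of the same idea (the helper names below are hypothetical):

import inspect

def call_with_supported_kwargs(fn, **kwargs):
    # Keep only the keyword arguments that `fn` actually declares.
    supported = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in kwargs.items() if k in supported})

def newer_api(query, key, value, has_bias=False):  # hypothetical helper
    return (query, key, value, has_bias)

# `is_training` is silently dropped because `newer_api` does not accept it.
call_with_supported_kwargs(newer_api, query=1, key=2, value=3, is_training=False)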
@@ -1332,25 +1757,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
 
-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False
 
     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
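The Tracer check exists because under `jit`, `scan`, or `vmap` the mask has no concrete values, and the Splash kernel's mask construction needs concrete data it can hash. A minimal sketch of the distinction:

import jax
import jax.numpy as jnp

def mask_is_concrete(mask):
    return not isinstance(mask, jax.core.Tracer)

mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
print(mask_is_concrete(mask))           # True: eager value, Splash path usable
print(jax.jit(mask_is_concrete)(mask))  # False: traced, fall back to native path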
@@ -1396,6 +1828,11 @@ def dot_product_attention(
 
     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
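The added loop promotes lower-rank bias/mask inputs to the 4D (batch, heads, q_len, kv_len) layout that the grouped reshape expects, e.g. a per-batch (B, T, S) mask gains a broadcastable head axis. A small shape sketch (illustrative values):

import jax.numpy as jnp

mask = jnp.ones((2, 5, 7), dtype=bool)  # (B, T, S): no head axis
mask4d = jnp.expand_dims(mask, axis=1)  # (B, 1, T, S): broadcasts over heads
print(mask4d.shape)                     # (2, 1, 5, 7)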
@@ -1413,3 +1850,46 @@ def dot_product_attention(
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # input already padded above
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only supports NCHW
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)
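A usage sketch for the new `unfold` helper (it mirrors the shape contract of PyTorch's `torch.nn.functional.unfold`): for a 4x4 single-channel image with 2x2 patches and stride 2, the result holds four flattened patches of four values each.

import jax.numpy as jnp

x = jnp.arange(16, dtype=jnp.float32).reshape(1, 1, 4, 4)  # (N, C, H, W)
cols = unfold(x, kernel_size=2, stride=2)
print(cols.shape)  # (1, 4, 4) == (N, C*kH*kW, L) with L = oH*oW = 4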