keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py

@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
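
Implementation note on the additions above: adaptive-pooling windows follow start = floor(i * L / out) and end = ceil((i + 1) * L / out), so at most two window sizes ever occur (small, and big = small + 1). The code therefore runs one "valid" pool per window size and gathers the correct result for each output position. A minimal NumPy sketch of this Two-Pool Gather idea, separate from the diff (it inlines the window-size computation that compute_adaptive_pooling_window_sizes provides in the package):

import numpy as np

def two_pool_gather_1d(x, out_l):
    # Adaptive-pooling windows: start = floor(i*L/out), end = ceil((i+1)*L/out).
    l = x.shape[0]
    starts = np.floor(np.arange(out_l) * l / out_l).astype(int)
    ends = np.ceil((np.arange(out_l) + 1) * l / out_l).astype(int)
    sizes = ends - starts
    small, big = sizes.min(), sizes.max()  # at most two sizes occur

    # One "valid" sliding mean per window size.
    small_pool = np.array([x[i:i + small].mean() for i in range(l - small + 1)])
    big_pool = np.array([x[i:i + big].mean() for i in range(l - big + 1)])

    # Gather from the concatenation; big-window indices are offset by
    # len(small_pool), which matches small_len in the diff above.
    combined = np.concatenate([small_pool, big_pool])
    gather = np.where(sizes == big, starts + len(small_pool), starts)
    return combined[gather]

x = np.arange(5.0)               # L=5, out=3 -> windows [0:2], [1:4], [3:5]
print(two_pool_gather_1d(x, 3))  # [0.5 2.  3.5]
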
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-    return jax.lax.conv_general_dilated(
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@ def conv(
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
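
The new size check above surfaces a silent failure mode: with "valid"-style padding, a kernel larger than the input yields a zero-element result rather than an error. A small sketch of the triggering condition, using plain jax.lax:

import jax.numpy as jnp
from jax import lax

# A 4x4 "valid" convolution with a 5x5 kernel: out_dim = (4 - 5) + 1 = 0,
# so XLA returns a zero-element array instead of raising.
x = jnp.ones((1, 4, 4, 1))   # NHWC
k = jnp.ones((5, 5, 1, 1))   # HWIO
out = lax.conv_general_dilated(
    x, k, window_strides=(1, 1), padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(out.shape, out.size)   # (1, 0, 0, 1) 0 -> now surfaced as a ValueError
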
@@ -396,6 +804,8 @@ def depthwise_conv(
     feature_group_count = (
         inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
     )
+    kernel = convert_to_tensor(kernel)
+    inputs = convert_to_tensor(inputs)
     kernel = jnp.reshape(
         kernel,
         kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
@@ -499,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
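
Context for the one-line change above: depending on the JAX/NumPy version in use, jnp.meshgrid can return an immutable tuple rather than a list, and tuples have no insert method; copying into a list is safe either way. An illustration with arbitrary shapes:

import jax.numpy as jnp

grids = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij")
# If grids is a tuple, grids.insert(...) raises AttributeError;
# copying into a list first works on every version.
grids = list(grids)
grids.insert(0, jnp.zeros((2, 3), dtype="int32"))
print([g.shape for g in grids])  # [(2, 3), (2, 3), (2, 3)]
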
@@ -1330,25 +1740,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
 
-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False
 
     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
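
The new guard above relies on the fact that inside jit/scan, array arguments are abstract jax.core.Tracer instances with no concrete values, so they cannot be hashed into a splash-attention mask spec. A minimal demonstration of that distinction, independent of the diff:

import jax
import jax.numpy as jnp

def is_concrete(x):
    # Tracers stand in for values during jit/scan tracing; only concrete
    # arrays can be inspected (and hashed) at trace time.
    return not isinstance(x, jax.core.Tracer)

mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
print(is_concrete(mask))   # True: built eagerly

@jax.jit
def f(m):
    return is_concrete(m)  # evaluated while tracing, where m is a Tracer

print(f(mask))             # False
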
@@ -1394,6 +1811,11 @@ def dot_product_attention(
 
     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
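
The added while-loop pads lower-rank masks up to the 4-D (batch, heads, q_len, kv_len) layout the grouped-query path expects. A standalone sketch of the same rule, with hypothetical shapes:

import jax.numpy as jnp

def pad_mask_rank(t, num_heads):
    # Same rule as the diff: a (B, N, S) mask gains a query-length axis,
    # anything else gains a leading axis, until rank 4 is reached.
    while t.ndim < 4:
        if t.ndim == 3 and t.shape[1] == num_heads:
            t = jnp.expand_dims(t, axis=2)   # (B, N, S) -> (B, N, 1, S)
        else:
            t = jnp.expand_dims(t, axis=1)   # (B, T, S) -> (B, 1, T, S)
    return t

B, N, T, S = 2, 8, 5, 7  # hypothetical batch, heads, query len, key len
print(pad_mask_rank(jnp.ones((B, N, S), dtype=bool), N).shape)  # (2, 8, 1, 7)
print(pad_mask_rank(jnp.ones((B, T, S), dtype=bool), N).shape)  # (2, 1, 5, 7)
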
@@ -1411,3 +1833,46 @@ def dot_product_attention(
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+
+    Extract sliding local blocks from an **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor of shape (N, C, H, W). **Required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor of shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # input was already padded above
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only supports NCHW
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)
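
A quick shape check for the new unfold, mirroring the torch.nn.functional.unfold contract L = oH * oW with oH = (H + 2*pH - dH*(kH - 1) - 1) // sH + 1 (a usage sketch; assumes the unfold defined above is in scope):

import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4 * 4, dtype="float32").reshape(2, 3, 4, 4)  # N,C,H,W

# 2x2 kernel, stride 1: oH = oW = (4 - 2) // 1 + 1 = 3, so L = 9, C*kH*kW = 12.
print(unfold(x, kernel_size=2).shape)                        # (2, 12, 9)

# padding=1, stride=2: oH = oW = (4 + 2 - 2) // 2 + 1 = 3, so L is still 9.
print(unfold(x, kernel_size=2, padding=1, stride=2).shape)   # (2, 12, 9)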