keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py
CHANGED
@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )

 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts


+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
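Note on the block above: for adaptive pooling, every output position uses one of only two window sizes (the floor and the ceiling of `input_dim / output_size`), so the new helpers run one fixed-window `lax.reduce_window` per size and then gather the right candidate for each position (the "Two-Pool Gather" method named in the docstring). A minimal sketch of the index math, assuming `compute_adaptive_pooling_window_sizes` returns exactly those two sizes; the 8-to-3 pooling below is only illustrative:

```python
import jax.numpy as jnp

input_dim, output_size = 8, 3  # illustrative sizes, not from the diff
starts = jnp.floor(
    jnp.arange(output_size) * input_dim / output_size
).astype(jnp.int32)
ends = jnp.ceil(
    jnp.arange(1, output_size + 1) * input_dim / output_size
).astype(jnp.int32)
sizes = ends - starts                            # [3 4 3]: only two sizes occur
small, big = int(sizes.min()), int(sizes.max())  # 3 and 4
small_len = input_dim - small + 1                # 6 valid positions for the small pool
gather = jnp.where(sizes == big, starts + small_len, starts)
print(starts, ends, gather)                      # [0 2 5] [3 6 8] [0 8 5]
```

Index 8 points past the six small-pool entries into the big pool at offset 2, reproducing the adaptive windows [0:3), [2:6), [5:8) without any per-position dynamic slicing.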
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-    return jax.lax.conv_general_dilated(
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@ def conv(
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result


 def depthwise_conv(
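Note on the new guard in `conv`: with "valid"-style padding the spatial output length is `floor((in - dilation * (kernel - 1) - 1) / stride) + 1`, which reaches zero as soon as the dilated kernel no longer fits in the (padded) input, and the added check turns that empty result into an explicit error. A small arithmetic sketch with illustrative numbers:

```python
# Illustrative numbers only: a dilated 3-tap kernel has effective extent 5,
# which does not fit in an input of length 4, so the output is empty.
in_len, kernel, dilation, stride = 4, 3, 2, 1
effective = dilation * (kernel - 1) + 1        # 5
out_len = max(0, (in_len - effective) // stride + 1)
print(effective, out_len)                      # 5 0 -> result.size == 0
```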
@@ -396,6 +804,8 @@ def depthwise_conv(
     feature_group_count = (
         inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
     )
+    kernel = convert_to_tensor(kernel)
+    inputs = convert_to_tensor(inputs)
     kernel = jnp.reshape(
         kernel,
         kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
@@ -499,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
@@ -1330,25 +1740,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))

-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False

     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
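Note on the fallback above: per the added comments, the Splash attention kernel needs concrete mask values for hashing, and such values do not exist while JAX is tracing, so traced masks now route to the native path instead of failing inside the kernel. A minimal sketch (not Keras code) of how an argument shows up as a `Tracer` under `jax.jit`:

```python
import jax
import jax.numpy as jnp

def is_traced(mask):
    # Concrete arrays fail this check; values being traced by jit/scan pass it.
    return isinstance(mask, jax.core.Tracer)

mask = jnp.ones((2, 2), dtype=bool)
print(is_traced(mask))            # False: concrete array
print(jax.jit(is_traced)(mask))   # True: abstract tracer during tracing
```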
@@ -1394,6 +1811,11 @@ def dot_product_attention(

     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
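Note on the `_reshape_to_grouped` change: the added loop pads a lower-rank bias or mask up to the 4-D `(batch, heads, q_len, kv_len)` layout before the grouped-query reshape, keeping an existing head axis and inserting broadcast axes otherwise. A short illustration with assumed shapes (8 heads, batch 2, sequence length 16; none of these numbers come from the diff):

```python
import jax.numpy as jnp

N = 8                            # assumed head count
per_head = jnp.ones((2, N, 16))  # (batch, heads, kv_len): head axis is kept
shared = jnp.ones((2, 16, 16))   # (batch, q_len, kv_len): broadcast over heads

per_head = jnp.expand_dims(per_head, axis=2)  # -> (2, 8, 1, 16)
shared = jnp.expand_dims(shared, axis=1)      # -> (2, 1, 16, 16)
print(per_head.shape, shared.shape)
```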
@@ -1411,3 +1833,46 @@ def dot_product_attention(
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # has padded
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only support 'NCHW'
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)