keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py
CHANGED
@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
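The adaptive pooling helpers above implement a "Two-Pool Gather" scheme: for any (input_dim, output_size) pair only two window sizes ever occur (a small size and small + 1), so the kernel runs two uniform reduce_window passes and gathers one slice per output position; compute_adaptive_pooling_window_sizes, imported in the first hunk, presumably returns that (small, big) pair. A small NumPy sketch of the window arithmetic (illustrative helper name, not from the diff) shows why only two sizes appear:

    import numpy as np

    def window_bounds(input_dim, output_size):
        # Adaptive pooling window i covers [floor(i*L/out), ceil((i+1)*L/out)).
        i = np.arange(output_size)
        starts = (i * input_dim) // output_size
        ends = -(-(i + 1) * input_dim // output_size)  # ceiling division
        return starts, ends

    starts, ends = window_bounds(input_dim=5, output_size=3)
    print(starts)          # [0 1 3]
    print(ends - starts)   # [2 3 2]  -> only two window sizes ("small" and "big")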
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@ def conv(
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
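As a quick reference for when that new guard can trigger: with "valid" padding the output length per spatial dimension follows the usual convolution arithmetic, and the result is empty as soon as any dimension is not positive. A self-contained helper (illustrative only, not part of the diff):

    def conv_output_length(input_len, kernel, stride=1, dilation=1, pad=0):
        # Standard convolution arithmetic for one spatial dimension.
        effective_kernel = dilation * (kernel - 1) + 1
        return (input_len + 2 * pad - effective_kernel) // stride + 1

    print(conv_output_length(4, kernel=3))   # 2 -> normal output
    print(conv_output_length(4, kernel=5))   # 0 -> empty output, the guard raises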
@@ -501,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
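The one-line one_hot change appears to exist because the sequence returned by jnp.meshgrid is not guaranteed to be a mutable list across JAX versions, and the very next line calls .insert(...) on it; wrapping it in list() keeps that mutation valid either way. A tiny hedged illustration of the pattern:

    import jax.numpy as jnp

    grids = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij")
    grids = list(grids)  # ensure a mutable list before calling .insert(...)
    grids.insert(0, jnp.zeros((2, 3), dtype="int32"))
    print(len(grids))    # 3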
@@ -1063,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
         # Only support at least Ampere
         if not check_compute_capability("8.0"):
             raise RuntimeError("Require at least Ampere arch to run")
-
+
+        # Inspect inputs of `check_layout`
         check_layout_params = list(
             inspect.signature(check_layout).parameters.keys()
         )
         for known_param in ("query", "key", "value", "bias", "layout"):
             check_layout_params.remove(known_param)
         # Defaults to `None` when not specified.
-
+        check_layout_kwargs = {key: None for key in check_layout_params}
         check_layout(
-            query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
-        )
-        check_is_flash_attention(
             query,
             key,
-
-
-
-
+            value,
+            bias,
+            layout=_normalize_layout("BTNH"),
+            **check_layout_kwargs,
         )
+
+        # Inspect inputs of `check_is_flash_attention`
+        check_is_flash_attention_params = inspect.signature(
+            check_is_flash_attention
+        ).parameters
+        check_is_flash_attention_kwargs = {
+            "query": query,
+            "key": key,
+            "value": value,
+            "layout": _normalize_layout("BTNH"),
+            "cudnn_version": cudnn_version,
+            "has_bias": bias is not None,
+            "is_training": False,
+        }
+        # Remove unsupported arguments
+        for param in list(check_is_flash_attention_kwargs.keys()):
+            if param not in check_is_flash_attention_params:
+                check_is_flash_attention_kwargs.pop(param)
+        check_is_flash_attention(**check_is_flash_attention_kwargs)
         return True
     except:
         if raise_error:
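The rewritten _can_use_flash_attention hunk stays compatible across JAX releases by introspecting the private cuDNN helpers and only passing keyword arguments their current signatures accept. The general pattern, sketched with a hypothetical callee (not the JAX API):

    import inspect

    def call_with_supported_kwargs(fn, **kwargs):
        # Drop any keyword the target function does not declare, so signature
        # drift between library versions does not raise TypeError.
        params = inspect.signature(fn).parameters
        return fn(**{k: v for k, v in kwargs.items() if k in params})

    def check(query, key, value, has_bias=False):  # hypothetical helper
        return has_bias

    print(call_with_supported_kwargs(
        check, query=1, key=2, value=3, has_bias=True, is_training=False
    ))  # True; the unsupported is_training kwarg was silently dropped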
@@ -1332,25 +1757,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
 
-
-
-
-
-
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False
 
     # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
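The new Splash-attention branch declines the TPU kernel when the mask is a jax.core.Tracer, i.e. an abstract value seen inside jit/scan rather than concrete data the kernel can hash. A minimal sketch of that distinction (illustrative only):

    import jax
    import jax.numpy as jnp

    def is_concrete(x):
        return not isinstance(x, jax.core.Tracer)

    mask = jnp.ones((4, 4), dtype=bool)
    print(is_concrete(mask))   # True: a concrete array outside any trace

    @jax.jit
    def f(m):
        # Inside jit, m is a Tracer, so the concrete-mask path is unavailable.
        return is_concrete(m)

    print(f(mask))             # False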
@@ -1396,6 +1828,11 @@ def dot_product_attention(
 
     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
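This hunk promotes lower-rank bias/mask tensors to the 4-D (batch, heads, T, S) layout before the grouped-query reshape: a rank-3 tensor whose second axis equals the head count N gains a new axis at position 2, anything else gains a broadcastable head axis at position 1. A hedged sketch of just that promotion rule (N is an assumed head count):

    import jax.numpy as jnp

    N = 8  # assumed number of query heads

    def promote_to_4d(t, num_heads=N):
        # Mirrors the rank-promotion loop added above (illustrative only).
        while t.ndim < 4:
            if t.ndim == 3 and t.shape[1] == num_heads:
                t = jnp.expand_dims(t, axis=2)   # (B, N, S) -> (B, N, 1, S)
            else:
                t = jnp.expand_dims(t, axis=1)   # (B, T, S) -> (B, 1, T, S)
        return t

    print(promote_to_4d(jnp.ones((2, 16, 16))).shape)  # (2, 1, 16, 16)
    print(promote_to_4d(jnp.ones((2, 8, 16))).shape)   # (2, 8, 1, 16)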
@@ -1413,3 +1850,46 @@ def dot_product_attention(
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # input has been padded above
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only supports 'NCHW'
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)