keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/jax/nn.py CHANGED

```diff
@@ -16,6 +16,9 @@ from jax.experimental.pallas.ops.tpu.splash_attention import (
 )
 
 from keras.src import backend
+from keras.src.backend.common.backend_utils import (
+    compute_adaptive_pooling_window_sizes,
+)
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_padding_args_for_jax,
 )
```
```diff
@@ -289,6 +292,403 @@ def average_pool(
     return pooled / window_counts
 
 
+def _compute_adaptive_pooling_gather_indices(
+    input_dim, output_size, big_window
+):
+    """Compute gather indices for Two-Pool Gather method."""
+    window_starts = jnp.floor(
+        (jnp.arange(output_size) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_ends = jnp.ceil(
+        (jnp.arange(1, output_size + 1) * input_dim) / output_size
+    ).astype(jnp.int32)
+
+    window_sizes = window_ends - window_starts
+    is_big = window_sizes == big_window
+
+    small_window = big_window - 1
+    small_len = input_dim - small_window + 1
+
+    small_indices = window_starts
+    big_indices = window_starts + small_len
+
+    gather = jnp.where(is_big, big_indices, small_indices)
+    return gather.astype(jnp.int32)
+
+
+def _adaptive_average_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), "valid"
+        )
+        / small
+    )
+
+    big_pool = (
+        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), "valid")
+        / big
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size,)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 1))
+
+    n, l, c = inputs.shape
+    out_l = output_size[0]
+
+    small, big = compute_adaptive_pooling_window_sizes(l, out_l)
+    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)
+
+    small_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), "valid"
+    )
+
+    big_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), "valid"
+    )
+
+    combined = jnp.concatenate([small_pool, big_pool], axis=1)
+    out = jnp.take(combined, gather, axis=1)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 2, 1))
+
+    return out
+
+
+def _adaptive_average_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 1))
+
+    n, h, w, c = inputs.shape
+    out_h, out_w = output_size
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_h_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)
+    pooled_h = jnp.take(combined_h, gather_h, axis=1)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)
+    out = jnp.take(combined_w, gather_w, axis=2)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 3, 1, 2))
+
+    return out
+
+
+def _adaptive_average_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = (
+        lax.reduce_window(
+            inputs,
+            0.0,
+            lax.add,
+            (1, small_d, 1, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_d
+    )
+
+    big_d_pool = (
+        lax.reduce_window(
+            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+        )
+        / big_d
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, small_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_h
+    )
+
+    big_h_pool = (
+        lax.reduce_window(
+            pooled_d,
+            0.0,
+            lax.add,
+            (1, 1, big_h, 1, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_h
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, small_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / small_w
+    )
+
+    big_w_pool = (
+        lax.reduce_window(
+            pooled_h,
+            0.0,
+            lax.add,
+            (1, 1, 1, big_w, 1),
+            (1, 1, 1, 1, 1),
+            "valid",
+        )
+        / big_w
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def _adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size, output_size)
+
+    if data_format == "channels_first":
+        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))
+
+    n, d, h, w, c = inputs.shape
+    out_d, out_h, out_w = output_size
+
+    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)
+    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)
+
+    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)
+    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)
+
+    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)
+    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)
+
+    small_d_pool = lax.reduce_window(
+        inputs,
+        -jnp.inf,
+        lax.max,
+        (1, small_d, 1, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_d_pool = lax.reduce_window(
+        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
+    )
+
+    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)
+    pooled_d = jnp.take(combined_d, gather_d, axis=1)
+
+    small_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, small_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_h_pool = lax.reduce_window(
+        pooled_d,
+        -jnp.inf,
+        lax.max,
+        (1, 1, big_h, 1, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)
+    pooled_h = jnp.take(combined_h, gather_h, axis=2)
+
+    small_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, small_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    big_w_pool = lax.reduce_window(
+        pooled_h,
+        -jnp.inf,
+        lax.max,
+        (1, 1, 1, big_w, 1),
+        (1, 1, 1, 1, 1),
+        "valid",
+    )
+
+    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)
+    out = jnp.take(combined_w, gather_w, axis=3)
+
+    if data_format == "channels_first":
+        out = jnp.transpose(out, (0, 4, 1, 2, 3))
+
+    return out
+
+
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_average_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_average_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_average_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_average_pool supports only 1D/2D/3D inputs")
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    data_format = backend.standardize_data_format(data_format)
+    dims = inputs.ndim - 2
+    if dims == 1:
+        return _adaptive_max_pool1d(inputs, output_size, data_format)
+    if dims == 2:
+        return _adaptive_max_pool2d(inputs, output_size, data_format)
+    if dims == 3:
+        return _adaptive_max_pool3d(inputs, output_size, data_format)
+    raise ValueError("adaptive_max_pool supports only 1D/2D/3D inputs")
+
+
 def _convert_to_lax_conv_dimension_numbers(
     num_spatial_dims,
     data_format="channels_last",
```
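Note on the hunk above: the adaptive pooling implementation relies on a "Two-Pool Gather" trick. For any input length and output size, the floor/ceil window bounds used in `_compute_adaptive_pooling_gather_indices` produce windows of only two sizes, `small` and `small + 1`, so one `reduce_window` per size is computed over all valid positions and the correct result is gathered per output step (big-window results sit after the `input_dim - small + 1` small-window positions in the concatenated tensor, hence the `start + small_len` offset). The following is a minimal illustrative sketch of that invariant, not package code; it only reproduces the floor/ceil bounds shown in the diff.

```python
import jax.numpy as jnp

def adaptive_window_bounds(input_dim, output_size):
    # Same floor/ceil window bounds as _compute_adaptive_pooling_gather_indices.
    starts = jnp.floor(jnp.arange(output_size) * input_dim / output_size)
    ends = jnp.ceil(jnp.arange(1, output_size + 1) * input_dim / output_size)
    return starts.astype(jnp.int32), ends.astype(jnp.int32)

starts, ends = adaptive_window_bounds(input_dim=10, output_size=7)
print(starts)        # [0 1 2 4 5 7 8]
print(ends - starts) # [2 2 3 2 3 2 2]  -> only two window sizes ever occur
```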
```diff
@@ -355,7 +755,7 @@ def conv(
     feature_group_count = channels // kernel_in_channels
     kernel = convert_to_tensor(kernel)
     inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
-
+    result = jax.lax.conv_general_dilated(
         inputs,
         kernel,
         strides,
@@ -364,6 +764,14 @@ def conv(
         dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
     )
+    if result.size == 0:
+        raise ValueError(
+            "The convolution operation resulted in an empty output. "
+            "This can happen if the input is too small for the given "
+            "kernel size, strides, dilation rate, and padding mode. "
+            "Please check the input shape and convolution parameters."
+        )
+    return result
 
 
 def depthwise_conv(
```
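The new guard in `conv` turns a silently empty result into an explicit `ValueError`. A minimal sketch of the failure mode it reports, using raw `jax.lax` for illustration only (not package code): a VALID convolution whose input is smaller than the kernel produces a zero-sized output rather than failing.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 2, 2, 1))  # NHWC: a 2x2 image
k = jnp.ones((3, 3, 1, 1))  # HWIO: a 3x3 kernel
y = lax.conv_general_dilated(
    x,
    k,
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
# (1, 0, 0, 1) 0 : the empty-output case that Keras's conv wrapper now rejects
print(y.shape, y.size)
```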
```diff
@@ -396,6 +804,8 @@ def depthwise_conv(
     feature_group_count = (
         inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
     )
+    kernel = convert_to_tensor(kernel)
+    inputs = convert_to_tensor(inputs)
     kernel = jnp.reshape(
         kernel,
         kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
@@ -499,7 +909,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
         values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
         values_count = values.shape[0]
         indices = [jnp.arange(dim) for dim in x.shape]
-        indices = jnp.meshgrid(*indices, indexing="ij")
+        indices = list(jnp.meshgrid(*indices, indexing="ij"))
         indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices
         indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
         indices = jnp.concatenate(indices, axis=1)
```
```diff
@@ -1061,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
         # Only support at least Ampere
         if not check_compute_capability("8.0"):
             raise RuntimeError("Require at least Ampere arch to run")
-
+
+        # Inspect inputs of `check_layout`
         check_layout_params = list(
             inspect.signature(check_layout).parameters.keys()
         )
         for known_param in ("query", "key", "value", "bias", "layout"):
             check_layout_params.remove(known_param)
         # Defaults to `None` when not specified.
-
+        check_layout_kwargs = {key: None for key in check_layout_params}
         check_layout(
-            query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
-        )
-        check_is_flash_attention(
             query,
             key,
-
-
-
-
+            value,
+            bias,
+            layout=_normalize_layout("BTNH"),
+            **check_layout_kwargs,
         )
+
+        # Inspect inputs of `check_is_flash_attention`
+        check_is_flash_attention_params = inspect.signature(
+            check_is_flash_attention
+        ).parameters
+        check_is_flash_attention_kwargs = {
+            "query": query,
+            "key": key,
+            "value": value,
+            "layout": _normalize_layout("BTNH"),
+            "cudnn_version": cudnn_version,
+            "has_bias": bias is not None,
+            "is_training": False,
+        }
+        # Remove unsupported arguments
+        for param in list(check_is_flash_attention_kwargs.keys()):
+            if param not in check_is_flash_attention_params:
+                check_is_flash_attention_kwargs.pop(param)
+        check_is_flash_attention(**check_is_flash_attention_kwargs)
         return True
     except:
         if raise_error:
```
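The rewritten `_can_use_flash_attention` no longer hard-codes the signatures of JAX's private `check_layout` and `check_is_flash_attention` helpers: it builds a kwargs dict and drops any key the installed JAX version does not declare. A generic sketch of that pattern follows; the helper and the two stand-in functions are hypothetical and are not part of the diff.

```python
import inspect

def call_with_supported_kwargs(fn, **kwargs):
    # Keep only the keyword arguments that `fn` actually declares, so the call
    # keeps working when a library changes its signature between versions.
    params = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in kwargs.items() if k in params})

def check_v1(query, key, layout):            # older, shorter signature
    return ("v1", layout)

def check_v2(query, key, layout, has_bias):  # newer signature with extra arg
    return ("v2", has_bias)

print(call_with_supported_kwargs(check_v1, query=1, key=2, layout="BTNH", has_bias=True))
print(call_with_supported_kwargs(check_v2, query=1, key=2, layout="BTNH", has_bias=True))
```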
```diff
@@ -1330,25 +1757,32 @@ def dot_product_attention(
         if custom_mask is None and is_causal:
             custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
 
-
-
-
-
-
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            logging.exception(
-                "Failed to apply Splash kernel for flash attention. "
-                "Falling back to JAX native dot_product_attention."
-            )
+        # Splash attention kernel requires concrete mask values for hashing.
+        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.
+        if isinstance(mask, jax.core.Tracer) or isinstance(
+            custom_mask, jax.core.Tracer
+        ):
             flash_attention = False
+        else:
+            try:
+                output = wrap_flash_attention(
+                    query_tpu_layout,
+                    key_tpu_layout,
+                    value_tpu_layout,
+                    decoder_segment_ids=decoder_segment_ids,
+                    custom_mask=custom_mask,
+                    attn_logits_soft_cap=attn_logits_soft_cap,
+                    head_shards=head_shards,
+                    q_seq_shards=q_seq_shards,
+                )
+                # Transpose output back to Keras layout
+                return jnp.transpose(output, axes=(0, 2, 1, 3))
+            except Exception:
+                logging.exception(
+                    "Failed to apply Splash kernel for flash attention. "
+                    "Falling back to JAX native dot_product_attention."
+                )
+                flash_attention = False
 
         # JAX native dot_product_attention for GPU or fallback for TPU
         if hasattr(jax.nn, "dot_product_attention"):
```
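The new TPU branch skips the Splash kernel when the mask is a JAX tracer, since the kernel needs concrete mask values, and under `jit` or `scan` the mask is abstract. A small illustrative sketch of what that `isinstance` check distinguishes (not package code):

```python
import jax
import jax.numpy as jnp

def is_traced(m):
    # Same kind of check as the diff above: detect an abstract (traced) array.
    return isinstance(m, jax.core.Tracer)

mask = jnp.tril(jnp.ones((4, 4), dtype=bool))
print(is_traced(mask))           # False: concrete array, Splash path usable
print(jax.jit(is_traced)(mask))  # True: inside jit the mask is a tracer
```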
```diff
@@ -1394,6 +1828,11 @@ def dot_product_attention(
 
     def _reshape_to_grouped(t):
         if t is not None:
+            while t.ndim < 4:
+                if t.ndim == 3 and t.shape[1] == N:
+                    t = jnp.expand_dims(t, axis=2)
+                else:
+                    t = jnp.expand_dims(t, axis=1)
             tB, tN, tT, tS = t.shape
             if tN == 1:
                 t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))
@@ -1411,3 +1850,46 @@
     )
     encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
     return jnp.reshape(encoded, output_shape)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """JAX implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+
+    def _pair(x):
+        return (x, x) if isinstance(x, int) else x
+
+    k = _pair(kernel_size)
+    d = _pair(dilation)
+    p = _pair(padding)
+    s = _pair(stride)
+
+    N, C, H, W = input.shape
+
+    # ---- padding ----
+    if any(_ > 0 for _ in p):
+        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))
+
+    patches = lax.conv_general_dilated_patches(
+        input,
+        filter_shape=k,
+        window_strides=s,
+        padding="VALID",  # has padde
+        rhs_dilation=d,
+        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only support 'NCHW'
+    )  # shape: (N, C*kH*kW, oH, oW)
+
+    # ---- reshape -> (N, C*kH*kW, L) ----
+    _, CKK, oH, oW = patches.shape
+    return patches.reshape(N, CKK, oH * oW)
```