keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/torch/nn.py
CHANGED
|
@@ -458,6 +458,94 @@ def average_pool(
|
|
|
458
458
|
return outputs
|
|
459
459
|
|
|
460
460
|
|
|
461
|
+
def adaptive_average_pool(inputs, output_size, data_format=None):
|
|
462
|
+
"""Adaptive average pooling(1D/2D/3D) with channels_last support."""
|
|
463
|
+
inputs = convert_to_tensor(inputs)
|
|
464
|
+
num_spatial_dims = inputs.ndim - 2
|
|
465
|
+
|
|
466
|
+
data_format = backend.standardize_data_format(data_format)
|
|
467
|
+
orig_format = data_format
|
|
468
|
+
if data_format == "channels_last":
|
|
469
|
+
inputs = _transpose_spatial_inputs(inputs)
|
|
470
|
+
|
|
471
|
+
if isinstance(output_size, int):
|
|
472
|
+
torch_output_size = (
|
|
473
|
+
output_size
|
|
474
|
+
if num_spatial_dims == 1
|
|
475
|
+
else (output_size,) * num_spatial_dims
|
|
476
|
+
)
|
|
477
|
+
else:
|
|
478
|
+
torch_output_size = standardize_tuple(
|
|
479
|
+
output_size, num_spatial_dims, "output_size"
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
if get_device() == "meta":
|
|
483
|
+
inputs = torch.empty(
|
|
484
|
+
size=inputs.shape, dtype=inputs.dtype, device="cpu"
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
if num_spatial_dims == 1:
|
|
488
|
+
outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
|
|
489
|
+
elif num_spatial_dims == 2:
|
|
490
|
+
outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
|
|
491
|
+
elif num_spatial_dims == 3:
|
|
492
|
+
outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
|
|
493
|
+
else:
|
|
494
|
+
raise ValueError(
|
|
495
|
+
"Inputs to adaptive average pooling must have ndim=3, 4 or 5, "
|
|
496
|
+
f"Received input shape: {inputs.shape}."
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
if orig_format == "channels_last":
|
|
500
|
+
outputs = _transpose_spatial_outputs(outputs)
|
|
501
|
+
return outputs
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
def adaptive_max_pool(inputs, output_size, data_format=None):
|
|
505
|
+
"""Adaptive max pooling(1D/2D/3D) with channels_last support."""
|
|
506
|
+
inputs = convert_to_tensor(inputs)
|
|
507
|
+
num_spatial_dims = inputs.ndim - 2
|
|
508
|
+
|
|
509
|
+
data_format = backend.standardize_data_format(data_format)
|
|
510
|
+
orig_format = data_format
|
|
511
|
+
if data_format == "channels_last":
|
|
512
|
+
inputs = _transpose_spatial_inputs(inputs)
|
|
513
|
+
|
|
514
|
+
if isinstance(output_size, int):
|
|
515
|
+
torch_output_size = (
|
|
516
|
+
output_size
|
|
517
|
+
if num_spatial_dims == 1
|
|
518
|
+
else (output_size,) * num_spatial_dims
|
|
519
|
+
)
|
|
520
|
+
else:
|
|
521
|
+
torch_output_size = standardize_tuple(
|
|
522
|
+
output_size, num_spatial_dims, "output_size"
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
if get_device() == "meta":
|
|
526
|
+
inputs = torch.empty(
|
|
527
|
+
size=inputs.shape, dtype=inputs.dtype, device="cpu"
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
if num_spatial_dims == 1:
|
|
531
|
+
res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
|
|
532
|
+
elif num_spatial_dims == 2:
|
|
533
|
+
res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
|
|
534
|
+
elif num_spatial_dims == 3:
|
|
535
|
+
res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
|
|
536
|
+
else:
|
|
537
|
+
raise ValueError(
|
|
538
|
+
"Inputs to adaptive max pooling must have ndim=3, 4 or 5, "
|
|
539
|
+
f"Received input shape: {inputs.shape}."
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
outputs = res[0] if isinstance(res, tuple) else res
|
|
543
|
+
|
|
544
|
+
if orig_format == "channels_last":
|
|
545
|
+
outputs = _transpose_spatial_outputs(outputs)
|
|
546
|
+
return outputs
|
|
547
|
+
|
|
548
|
+
|
|
461
549
|
def conv(
|
|
462
550
|
inputs,
|
|
463
551
|
kernel,
|
|
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
|
|
|
755
843
|
target = convert_to_tensor(target)
|
|
756
844
|
output = convert_to_tensor(output)
|
|
757
845
|
|
|
846
|
+
# We only apply the squeeze fix if we are on an MPS device,
|
|
847
|
+
# as this change breaks tests on other platforms that
|
|
848
|
+
# expect the original tensor shape to be preserved.
|
|
849
|
+
if (
|
|
850
|
+
torch.backends.mps.is_available()
|
|
851
|
+
and target.ndim > 1
|
|
852
|
+
and output.ndim == target.ndim
|
|
853
|
+
and target.shape[-1] == 1
|
|
854
|
+
and output.shape[-1] == 1
|
|
855
|
+
):
|
|
856
|
+
target = torch.squeeze(target, -1).contiguous()
|
|
857
|
+
output = torch.squeeze(output, -1).contiguous()
|
|
858
|
+
|
|
758
859
|
if target.shape != output.shape:
|
|
759
860
|
raise ValueError(
|
|
760
861
|
"Arguments `target` and `output` must have the same shape. "
|
|
761
862
|
"Received: "
|
|
762
863
|
f"target.shape={target.shape}, output.shape={output.shape}"
|
|
763
864
|
)
|
|
865
|
+
|
|
764
866
|
# By default, PyTorch, does reduction of `sum` over all rows,
|
|
765
867
|
# change reduction to `none` to keep dim
|
|
766
868
|
if from_logits:
|
|
@@ -1092,3 +1194,26 @@ def dot_product_attention(
|
|
|
1092
1194
|
scale=scale,
|
|
1093
1195
|
)
|
|
1094
1196
|
return torch.transpose(attention_output, axis1, axis0)
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
|
|
1200
|
+
"""Native PyTorch implementation of Unfold.
|
|
1201
|
+
Extract sliding local blocks from a **NCHW** batched image tensor.
|
|
1202
|
+
|
|
1203
|
+
Args:
|
|
1204
|
+
input: 4-D tensor, shape (N, C, H, W) **required**.
|
|
1205
|
+
kernel_size: int or (kH, kW)
|
|
1206
|
+
dilation: int or (dH, dW), default 1
|
|
1207
|
+
padding: int or (pH, pW), default 0
|
|
1208
|
+
stride: int or (sH, sW), default 1
|
|
1209
|
+
|
|
1210
|
+
Returns:
|
|
1211
|
+
3-D tensor, shape (N, C*kH*kW, L)
|
|
1212
|
+
"""
|
|
1213
|
+
return tnn.unfold(
|
|
1214
|
+
input,
|
|
1215
|
+
kernel_size=kernel_size,
|
|
1216
|
+
dilation=dilation,
|
|
1217
|
+
padding=padding,
|
|
1218
|
+
stride=stride,
|
|
1219
|
+
)
|
keras/src/backend/torch/numpy.py
CHANGED
|
@@ -313,18 +313,19 @@ def append(x1, x2, axis=None):
|
|
|
313
313
|
return torch.cat((x1, x2), dim=axis)
|
|
314
314
|
|
|
315
315
|
|
|
316
|
-
def arange(start, stop=None, step=
|
|
316
|
+
def arange(start, stop=None, step=None, dtype=None):
|
|
317
317
|
if dtype is None:
|
|
318
|
-
dtypes_to_resolve = [
|
|
319
|
-
getattr(start, "dtype", type(start)),
|
|
320
|
-
getattr(step, "dtype", type(step)),
|
|
321
|
-
]
|
|
318
|
+
dtypes_to_resolve = [getattr(start, "dtype", type(start))]
|
|
322
319
|
if stop is not None:
|
|
323
320
|
dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
|
|
321
|
+
if step is not None:
|
|
322
|
+
dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
|
|
324
323
|
dtype = dtypes.result_type(*dtypes_to_resolve)
|
|
325
324
|
dtype = to_torch_dtype(dtype)
|
|
326
325
|
if stop is None:
|
|
327
|
-
|
|
326
|
+
start, stop = 0, start
|
|
327
|
+
if step is None:
|
|
328
|
+
step = 1
|
|
328
329
|
return torch.arange(
|
|
329
330
|
start, stop, step=step, dtype=dtype, device=get_device()
|
|
330
331
|
)
|
|
@@ -410,6 +411,12 @@ def array(x, dtype=None):
|
|
|
410
411
|
return convert_to_tensor(x, dtype=dtype)
|
|
411
412
|
|
|
412
413
|
|
|
414
|
+
def view(x, dtype=None):
|
|
415
|
+
dtype = to_torch_dtype(dtype)
|
|
416
|
+
x = convert_to_tensor(x)
|
|
417
|
+
return x.view(dtype=dtype)
|
|
418
|
+
|
|
419
|
+
|
|
413
420
|
def average(x, axis=None, weights=None):
|
|
414
421
|
x = convert_to_tensor(x)
|
|
415
422
|
dtypes_to_resolve = [x.dtype, float]
|
|
@@ -763,6 +770,12 @@ def empty(shape, dtype=None):
|
|
|
763
770
|
return torch.empty(size=shape, dtype=dtype, device=get_device())
|
|
764
771
|
|
|
765
772
|
|
|
773
|
+
def empty_like(x, dtype=None):
|
|
774
|
+
x = convert_to_tensor(x)
|
|
775
|
+
dtype = to_torch_dtype(dtype or x.dtype)
|
|
776
|
+
return torch.empty_like(x, dtype=dtype, device=get_device())
|
|
777
|
+
|
|
778
|
+
|
|
766
779
|
def equal(x1, x2):
|
|
767
780
|
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
|
768
781
|
return torch.eq(x1, x2)
|
|
@@ -945,6 +958,37 @@ def isposinf(x):
|
|
|
945
958
|
return torch.isposinf(x)
|
|
946
959
|
|
|
947
960
|
|
|
961
|
+
def isreal(x):
|
|
962
|
+
x = convert_to_tensor(x)
|
|
963
|
+
return torch.isreal(x)
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def kron(x1, x2):
|
|
967
|
+
x1 = convert_to_tensor(x1)
|
|
968
|
+
x2 = convert_to_tensor(x2)
|
|
969
|
+
return torch.kron(x1, x2)
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
def lcm(x1, x2):
|
|
973
|
+
x1 = convert_to_tensor(x1)
|
|
974
|
+
x2 = convert_to_tensor(x2)
|
|
975
|
+
return torch.lcm(x1, x2)
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
def ldexp(x1, x2):
|
|
979
|
+
x1 = convert_to_tensor(x1)
|
|
980
|
+
x2 = convert_to_tensor(x2)
|
|
981
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
982
|
+
|
|
983
|
+
if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
|
|
984
|
+
raise TypeError(
|
|
985
|
+
f"ldexp exponent must be an integer type. "
|
|
986
|
+
f"Received: x2 dtype={x2.dtype}"
|
|
987
|
+
)
|
|
988
|
+
|
|
989
|
+
return cast(torch.ldexp(x1, x2), dtype)
|
|
990
|
+
|
|
991
|
+
|
|
948
992
|
def less(x1, x2):
|
|
949
993
|
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
|
950
994
|
return torch.less(x1, x2)
|
|
@@ -1041,6 +1085,15 @@ def logaddexp(x1, x2):
|
|
|
1041
1085
|
return torch.logaddexp(x1, x2)
|
|
1042
1086
|
|
|
1043
1087
|
|
|
1088
|
+
def logaddexp2(x1, x2):
|
|
1089
|
+
x1 = convert_to_tensor(x1)
|
|
1090
|
+
x2 = convert_to_tensor(x2)
|
|
1091
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
1092
|
+
x1 = cast(x1, dtype)
|
|
1093
|
+
x2 = cast(x2, dtype)
|
|
1094
|
+
return torch.logaddexp2(x1, x2)
|
|
1095
|
+
|
|
1096
|
+
|
|
1044
1097
|
def logical_and(x1, x2):
|
|
1045
1098
|
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
|
|
1046
1099
|
return torch.logical_and(x1, x2)
|
|
@@ -1329,6 +1382,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
|
|
|
1329
1382
|
return x
|
|
1330
1383
|
|
|
1331
1384
|
|
|
1385
|
+
def ptp(x, axis=None, keepdims=False):
|
|
1386
|
+
x = convert_to_tensor(x)
|
|
1387
|
+
if axis is None:
|
|
1388
|
+
return x.max() - x.min()
|
|
1389
|
+
elif axis == ():
|
|
1390
|
+
return torch.zeros_like(x)
|
|
1391
|
+
else:
|
|
1392
|
+
return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
|
|
1393
|
+
x, dim=axis, keepdim=keepdims
|
|
1394
|
+
)
|
|
1395
|
+
|
|
1396
|
+
|
|
1332
1397
|
def quantile(x, q, axis=None, method="linear", keepdims=False):
|
|
1333
1398
|
x = convert_to_tensor(x)
|
|
1334
1399
|
q = convert_to_tensor(q)
|
|
@@ -1434,7 +1499,7 @@ def searchsorted(sorted_sequence, values, side="left"):
|
|
|
1434
1499
|
"to extend it to N-D sequences. Received: "
|
|
1435
1500
|
f"sorted_sequence.shape={sorted_sequence.shape}"
|
|
1436
1501
|
)
|
|
1437
|
-
out_int32 =
|
|
1502
|
+
out_int32 = sorted_sequence.shape[0] <= np.iinfo(np.int32).max
|
|
1438
1503
|
return torch.searchsorted(
|
|
1439
1504
|
sorted_sequence, values, side=side, out_int32=out_int32
|
|
1440
1505
|
)
|
|
@@ -1506,6 +1571,12 @@ def split(x, indices_or_sections, axis=0):
|
|
|
1506
1571
|
return list(out)
|
|
1507
1572
|
|
|
1508
1573
|
|
|
1574
|
+
def array_split(x, indices_or_sections, axis=0):
|
|
1575
|
+
x = convert_to_tensor(x)
|
|
1576
|
+
out = torch.tensor_split(x, indices_or_sections, dim=axis)
|
|
1577
|
+
return list(out)
|
|
1578
|
+
|
|
1579
|
+
|
|
1509
1580
|
def stack(x, axis=0):
|
|
1510
1581
|
x = [convert_to_tensor(elem) for elem in x]
|
|
1511
1582
|
return torch.stack(x, dim=axis)
|
|
@@ -1619,8 +1690,9 @@ def tile(x, repeats):
|
|
|
1619
1690
|
def trace(x, offset=0, axis1=0, axis2=1):
|
|
1620
1691
|
x = convert_to_tensor(x)
|
|
1621
1692
|
dtype = standardize_dtype(x.dtype)
|
|
1622
|
-
if dtype
|
|
1623
|
-
|
|
1693
|
+
if dtype in ("bool", "int8", "int16", "uint8"):
|
|
1694
|
+
# Torch backend doesn't support uint32 dtype.
|
|
1695
|
+
dtype = "int32"
|
|
1624
1696
|
return torch.sum(
|
|
1625
1697
|
torch.diagonal(x, offset, axis1, axis2),
|
|
1626
1698
|
dim=-1,
|
|
@@ -1733,6 +1805,16 @@ def negative(x):
|
|
|
1733
1805
|
return torch.negative(x)
|
|
1734
1806
|
|
|
1735
1807
|
|
|
1808
|
+
def nextafter(x1, x2):
|
|
1809
|
+
x1 = convert_to_tensor(x1)
|
|
1810
|
+
x2 = convert_to_tensor(x2)
|
|
1811
|
+
|
|
1812
|
+
dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
|
|
1813
|
+
x1 = cast(x1, torch.float64)
|
|
1814
|
+
x2 = cast(x2, torch.float64)
|
|
1815
|
+
return cast(torch.nextafter(x1, x2), dtype)
|
|
1816
|
+
|
|
1817
|
+
|
|
1736
1818
|
def square(x):
|
|
1737
1819
|
x = convert_to_tensor(x)
|
|
1738
1820
|
if standardize_dtype(x.dtype) == "bool":
|
|
@@ -1761,6 +1843,24 @@ def transpose(x, axes=None):
|
|
|
1761
1843
|
return x.T
|
|
1762
1844
|
|
|
1763
1845
|
|
|
1846
|
+
def trapezoid(y, x=None, dx=1.0, axis=-1):
|
|
1847
|
+
y = convert_to_tensor(y)
|
|
1848
|
+
if standardize_dtype(y.dtype) == "bool":
|
|
1849
|
+
y = cast(y, config.floatx())
|
|
1850
|
+
if x is not None:
|
|
1851
|
+
x = convert_to_tensor(x)
|
|
1852
|
+
return torch.trapz(y, x=x, dim=axis)
|
|
1853
|
+
else:
|
|
1854
|
+
dx = convert_to_tensor(dx)
|
|
1855
|
+
return torch.trapz(y, dx=dx, dim=axis)
|
|
1856
|
+
|
|
1857
|
+
|
|
1858
|
+
def vander(x, N=None, increasing=False):
|
|
1859
|
+
x = convert_to_tensor(x)
|
|
1860
|
+
result_dtype = dtypes.result_type(x.dtype)
|
|
1861
|
+
return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
|
|
1862
|
+
|
|
1863
|
+
|
|
1764
1864
|
def var(x, axis=None, keepdims=False):
|
|
1765
1865
|
x = convert_to_tensor(x)
|
|
1766
1866
|
compute_dtype = dtypes.result_type(x.dtype, "float32")
|
|
@@ -54,7 +54,10 @@ class TorchTrainer(base_trainer.Trainer):
|
|
|
54
54
|
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
|
|
55
55
|
)
|
|
56
56
|
self._loss_tracker.update_state(
|
|
57
|
-
loss,
|
|
57
|
+
loss,
|
|
58
|
+
sample_weight=next(
|
|
59
|
+
i for i in tree.flatten(x) if i is not None
|
|
60
|
+
).shape[0],
|
|
58
61
|
)
|
|
59
62
|
if self.optimizer is not None:
|
|
60
63
|
loss = self.optimizer.scale_loss(loss)
|
|
@@ -90,7 +93,10 @@ class TorchTrainer(base_trainer.Trainer):
|
|
|
90
93
|
x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
|
|
91
94
|
)
|
|
92
95
|
self._loss_tracker.update_state(
|
|
93
|
-
loss,
|
|
96
|
+
loss,
|
|
97
|
+
sample_weight=next(
|
|
98
|
+
i for i in tree.flatten(x) if i is not None
|
|
99
|
+
).shape[0],
|
|
94
100
|
)
|
|
95
101
|
return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
|
|
96
102
|
|
keras/src/callbacks/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
|
|
|
8
8
|
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
|
|
9
9
|
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
|
|
10
10
|
from keras.src.callbacks.monitor_callback import MonitorCallback
|
|
11
|
+
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
|
|
11
12
|
from keras.src.callbacks.progbar_logger import ProgbarLogger
|
|
12
13
|
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
|
|
13
14
|
from keras.src.callbacks.remote_monitor import RemoteMonitor
|
|
@@ -39,6 +39,7 @@ class CallbackList(Callback):
|
|
|
39
39
|
via `Callback.set_params`.
|
|
40
40
|
"""
|
|
41
41
|
self.callbacks = tree.flatten(callbacks) if callbacks else []
|
|
42
|
+
self._in_begin_end_block_count = 0
|
|
42
43
|
self._executor = None
|
|
43
44
|
self._async_train = False
|
|
44
45
|
self._async_test = False
|
|
@@ -78,9 +79,6 @@ class CallbackList(Callback):
|
|
|
78
79
|
if not utils.is_default(cbk.on_predict_batch_end):
|
|
79
80
|
async_predict = False
|
|
80
81
|
|
|
81
|
-
if async_train or async_test or async_predict:
|
|
82
|
-
self._executor = concurrent.futures.ThreadPoolExecutor()
|
|
83
|
-
|
|
84
82
|
self._async_train = async_train
|
|
85
83
|
self._async_test = async_test
|
|
86
84
|
self._async_predict = async_predict
|
|
@@ -113,6 +111,33 @@ class CallbackList(Callback):
|
|
|
113
111
|
for callback in self.callbacks:
|
|
114
112
|
callback.set_model(model)
|
|
115
113
|
|
|
114
|
+
def _on_begin(self):
|
|
115
|
+
"""Called by `on_train/test/predict_begin`.
|
|
116
|
+
|
|
117
|
+
Start the executor for async calls if needed.
|
|
118
|
+
"""
|
|
119
|
+
self._in_begin_end_block_count += 1
|
|
120
|
+
if (
|
|
121
|
+
self._in_begin_end_block_count == 1
|
|
122
|
+
and (self._async_train or self._async_test or self._async_predict)
|
|
123
|
+
and self._executor is None
|
|
124
|
+
):
|
|
125
|
+
self._executor = concurrent.futures.ThreadPoolExecutor()
|
|
126
|
+
|
|
127
|
+
def _on_end(self):
|
|
128
|
+
"""Called by `on_train/test/predict_end`.
|
|
129
|
+
|
|
130
|
+
Shutdown the executor for async calls if all begin/end blocks completed.
|
|
131
|
+
"""
|
|
132
|
+
self._in_begin_end_block_count -= 1
|
|
133
|
+
if self._in_begin_end_block_count < 0:
|
|
134
|
+
raise ValueError(
|
|
135
|
+
"`on_xxx_end` called without corresponding `on_xxx_begin`"
|
|
136
|
+
)
|
|
137
|
+
if self._in_begin_end_block_count == 0 and self._executor is not None:
|
|
138
|
+
self._executor.shutdown()
|
|
139
|
+
self._executor = None
|
|
140
|
+
|
|
116
141
|
def _async_dispatch(self, fn, *args):
|
|
117
142
|
for future in self._futures:
|
|
118
143
|
if future.done():
|
|
@@ -121,7 +146,8 @@ class CallbackList(Callback):
|
|
|
121
146
|
future = self._executor.submit(fn, *args)
|
|
122
147
|
self._futures.append(future)
|
|
123
148
|
|
|
124
|
-
def
|
|
149
|
+
def _flush_futures(self):
|
|
150
|
+
"""Waits for all futures to complete and clears the list."""
|
|
125
151
|
for future in self._futures:
|
|
126
152
|
future.result()
|
|
127
153
|
self._futures = []
|
|
@@ -138,7 +164,7 @@ class CallbackList(Callback):
|
|
|
138
164
|
|
|
139
165
|
def on_epoch_end(self, epoch, logs=None):
|
|
140
166
|
if self._async_train:
|
|
141
|
-
self.
|
|
167
|
+
self._flush_futures()
|
|
142
168
|
|
|
143
169
|
logs = python_utils.pythonify_logs(logs)
|
|
144
170
|
for callback in self.callbacks:
|
|
@@ -204,44 +230,52 @@ class CallbackList(Callback):
|
|
|
204
230
|
callback.on_predict_batch_end(batch, logs=logs)
|
|
205
231
|
|
|
206
232
|
def on_train_begin(self, logs=None):
|
|
233
|
+
self._on_begin()
|
|
234
|
+
|
|
207
235
|
logs = python_utils.pythonify_logs(logs)
|
|
208
236
|
for callback in self.callbacks:
|
|
209
237
|
callback.on_train_begin(logs)
|
|
210
238
|
|
|
211
239
|
def on_train_end(self, logs=None):
|
|
212
240
|
if self._async_train:
|
|
213
|
-
self.
|
|
241
|
+
self._flush_futures()
|
|
214
242
|
|
|
215
243
|
logs = python_utils.pythonify_logs(logs)
|
|
216
244
|
for callback in self.callbacks:
|
|
217
245
|
callback.on_train_end(logs)
|
|
218
246
|
|
|
247
|
+
self._on_end()
|
|
248
|
+
|
|
219
249
|
def on_test_begin(self, logs=None):
|
|
250
|
+
self._on_begin()
|
|
251
|
+
|
|
220
252
|
logs = python_utils.pythonify_logs(logs)
|
|
221
253
|
for callback in self.callbacks:
|
|
222
254
|
callback.on_test_begin(logs)
|
|
223
255
|
|
|
224
256
|
def on_test_end(self, logs=None):
|
|
225
257
|
if self._async_test:
|
|
226
|
-
self.
|
|
258
|
+
self._flush_futures()
|
|
227
259
|
|
|
228
260
|
logs = python_utils.pythonify_logs(logs)
|
|
229
261
|
for callback in self.callbacks:
|
|
230
262
|
callback.on_test_end(logs)
|
|
231
263
|
|
|
264
|
+
self._on_end()
|
|
265
|
+
|
|
232
266
|
def on_predict_begin(self, logs=None):
|
|
267
|
+
self._on_begin()
|
|
268
|
+
|
|
233
269
|
logs = python_utils.pythonify_logs(logs)
|
|
234
270
|
for callback in self.callbacks:
|
|
235
271
|
callback.on_predict_begin(logs)
|
|
236
272
|
|
|
237
273
|
def on_predict_end(self, logs=None):
|
|
238
274
|
if self._async_predict:
|
|
239
|
-
self.
|
|
275
|
+
self._flush_futures()
|
|
240
276
|
|
|
241
277
|
logs = python_utils.pythonify_logs(logs)
|
|
242
278
|
for callback in self.callbacks:
|
|
243
279
|
callback.on_predict_end(logs)
|
|
244
280
|
|
|
245
|
-
|
|
246
|
-
if self._executor is not None:
|
|
247
|
-
self._executor.shutdown(cancel_futures=True)
|
|
281
|
+
self._on_end()
|
|
@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
|
|
|
283
283
|
self.model.save_weights(filepath, overwrite=True)
|
|
284
284
|
else:
|
|
285
285
|
self.model.save(filepath, overwrite=True)
|
|
286
|
+
if self.verbose > 0:
|
|
287
|
+
io_utils.print_msg(
|
|
288
|
+
f"\nEpoch {epoch + 1}: "
|
|
289
|
+
f"finished saving model to {filepath}"
|
|
290
|
+
)
|
|
286
291
|
except IsADirectoryError: # h5py 3.x
|
|
287
292
|
raise IOError(
|
|
288
293
|
"Please specify a non-directory filepath for "
|