keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -458,6 +458,94 @@ def average_pool(
     return outputs
 
 
+def adaptive_average_pool(inputs, output_size, data_format=None):
+    """Adaptive average pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive average pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
+def adaptive_max_pool(inputs, output_size, data_format=None):
+    """Adaptive max pooling(1D/2D/3D) with channels_last support."""
+    inputs = convert_to_tensor(inputs)
+    num_spatial_dims = inputs.ndim - 2
+
+    data_format = backend.standardize_data_format(data_format)
+    orig_format = data_format
+    if data_format == "channels_last":
+        inputs = _transpose_spatial_inputs(inputs)
+
+    if isinstance(output_size, int):
+        torch_output_size = (
+            output_size
+            if num_spatial_dims == 1
+            else (output_size,) * num_spatial_dims
+        )
+    else:
+        torch_output_size = standardize_tuple(
+            output_size, num_spatial_dims, "output_size"
+        )
+
+    if get_device() == "meta":
+        inputs = torch.empty(
+            size=inputs.shape, dtype=inputs.dtype, device="cpu"
+        )
+
+    if num_spatial_dims == 1:
+        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 2:
+        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)
+    elif num_spatial_dims == 3:
+        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)
+    else:
+        raise ValueError(
+            "Inputs to adaptive max pooling must have ndim=3, 4 or 5, "
+            f"Received input shape: {inputs.shape}."
+        )
+
+    outputs = res[0] if isinstance(res, tuple) else res
+
+    if orig_format == "channels_last":
+        outputs = _transpose_spatial_outputs(outputs)
+    return outputs
+
+
 def conv(
     inputs,
     kernel,
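Note on the hunk above: these torch-backend helpers back the new adaptive pooling layers listed in the file summary (keras/src/layers/pooling/adaptive_*.py). A minimal sketch of calling the backend function directly, assuming the PyTorch backend is active in this nightly build; the public layer API is not shown in this diff:

    # Sketch only: assumes KERAS_BACKEND="torch" and this nightly build.
    import numpy as np
    from keras.src.backend.torch.nn import adaptive_average_pool

    x = np.random.rand(2, 8, 8, 3).astype("float32")  # NHWC batch
    y = adaptive_average_pool(x, output_size=4, data_format="channels_last")
    print(tuple(y.shape))  # (2, 4, 4, 3)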
@@ -755,12 +843,26 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)
 
+    # We only apply the squeeze fix if we are on an MPS device,
+    # as this change breaks tests on other platforms that
+    # expect the original tensor shape to be preserved.
+    if (
+        torch.backends.mps.is_available()
+        and target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch, does reduction of `sum` over all rows,
     # change reduction to `none` to keep dim
     if from_logits:
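The added branch only fires on Apple MPS devices: when both `target` and `output` carry a trailing singleton channel, it is squeezed away before the shape check, so `(batch, 1)` tensors are compared as `(batch,)`. A plain-PyTorch illustration of the shape handling (applied unconditionally here for clarity):

    import torch

    target = torch.zeros(4, 1)
    output = torch.full((4, 1), 0.3)
    if target.ndim > 1 and target.shape[-1] == 1 and output.shape[-1] == 1:
        # Mirrors the MPS-only squeeze above.
        target = torch.squeeze(target, -1).contiguous()
        output = torch.squeeze(output, -1).contiguous()
    print(target.shape, output.shape)  # torch.Size([4]) torch.Size([4])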
@@ -1092,3 +1194,26 @@ def dot_product_attention(
         scale=scale,
     )
     return torch.transpose(attention_output, axis1, axis0)
+
+
+def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    """Native PyTorch implementation of Unfold.
+    Extract sliding local blocks from a **NCHW** batched image tensor.
+
+    Args:
+        input: 4-D tensor, shape (N, C, H, W) **required**.
+        kernel_size: int or (kH, kW)
+        dilation: int or (dH, dW), default 1
+        padding: int or (pH, pW), default 0
+        stride: int or (sH, sW), default 1
+
+    Returns:
+        3-D tensor, shape (N, C*kH*kW, L)
+    """
+    return tnn.unfold(
+        input,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
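The new `unfold` is a thin wrapper over `torch.nn.functional.unfold`, so the output shape follows the usual im2col formula: an (N, C, H, W) input with a kH x kW kernel yields (N, C*kH*kW, L), where L is the number of sliding-window positions. A quick sanity check in plain PyTorch (the wrapper simply forwards these arguments):

    import torch
    import torch.nn.functional as F

    x = torch.arange(48, dtype=torch.float32).reshape(1, 3, 4, 4)
    blocks = F.unfold(x, kernel_size=2)  # stride=1, padding=0, dilation=1
    print(blocks.shape)  # torch.Size([1, 12, 9]); C*kH*kW = 12, L = (4-2+1)**2 = 9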
@@ -313,18 +313,19 @@ def append(x1, x2, axis=None):
     return torch.cat((x1, x2), dim=axis)
 
 
-def arange(start, stop=None, step=1, dtype=None):
+def arange(start, stop=None, step=None, dtype=None):
     if dtype is None:
-        dtypes_to_resolve = [
-            getattr(start, "dtype", type(start)),
-            getattr(step, "dtype", type(step)),
-        ]
+        dtypes_to_resolve = [getattr(start, "dtype", type(start))]
         if stop is not None:
             dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
+        if step is not None:
+            dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
         dtype = dtypes.result_type(*dtypes_to_resolve)
     dtype = to_torch_dtype(dtype)
     if stop is None:
-        return torch.arange(end=start, dtype=dtype, device=get_device())
+        start, stop = 0, start
+    if step is None:
+        step = 1
     return torch.arange(
         start, stop, step=step, dtype=dtype, device=get_device()
     )
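The `arange` rewrite makes `step=None` the default so an omitted step no longer participates in dtype inference, and normalizes the one-argument form to an explicit `(0, start)` range so both call forms go through the same `torch.arange` call. For typical calls the result is unchanged, for example:

    # Assumes the torch backend; keras.ops.arange dispatches to the backend above.
    from keras import ops

    print(ops.arange(5))           # [0 1 2 3 4], integer dtype
    print(ops.arange(0, 1, 0.25))  # [0.  0.25 0.5  0.75], float dtype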
@@ -410,6 +411,12 @@ def array(x, dtype=None):
     return convert_to_tensor(x, dtype=dtype)
 
 
+def view(x, dtype=None):
+    dtype = to_torch_dtype(dtype)
+    x = convert_to_tensor(x)
+    return x.view(dtype=dtype)
+
+
 def average(x, axis=None, weights=None):
     x = convert_to_tensor(x)
     dtypes_to_resolve = [x.dtype, float]
@@ -763,6 +770,12 @@ def empty(shape, dtype=None):
     return torch.empty(size=shape, dtype=dtype, device=get_device())
 
 
+def empty_like(x, dtype=None):
+    x = convert_to_tensor(x)
+    dtype = to_torch_dtype(dtype or x.dtype)
+    return torch.empty_like(x, dtype=dtype, device=get_device())
+
+
 def equal(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.eq(x1, x2)
@@ -945,6 +958,37 @@ def isposinf(x):
     return torch.isposinf(x)
 
 
+def isreal(x):
+    x = convert_to_tensor(x)
+    return torch.isreal(x)
+
+
+def kron(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return torch.kron(x1, x2)
+
+
+def lcm(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    return torch.lcm(x1, x2)
+
+
+def ldexp(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+
+    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:
+        raise TypeError(
+            f"ldexp exponent must be an integer type. "
+            f"Received: x2 dtype={x2.dtype}"
+        )
+
+    return cast(torch.ldexp(x1, x2), dtype)
+
+
 def less(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.less(x1, x2)
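These torch implementations back the matching additions to keras/ops/numpy/__init__.py in the file summary and follow NumPy semantics; for instance, `ldexp(x1, x2)` computes `x1 * 2**x2` and requires an integer exponent. A hedged example, assuming the ops are exported under `keras.ops` with their NumPy names:

    from keras import ops

    print(ops.ldexp(1.5, 3))         # 12.0, i.e. 1.5 * 2**3
    print(ops.lcm(4, 6))             # 12
    print(ops.kron([1, 2], [0, 1]))  # [0 1 0 2]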
@@ -1041,6 +1085,15 @@ def logaddexp(x1, x2):
     return torch.logaddexp(x1, x2)
 
 
+def logaddexp2(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, dtype)
+    x2 = cast(x2, dtype)
+    return torch.logaddexp2(x1, x2)
+
+
 def logical_and(x1, x2):
     x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
     return torch.logical_and(x1, x2)
@@ -1329,6 +1382,18 @@ def prod(x, axis=None, keepdims=False, dtype=None):
     return x
 
 
+def ptp(x, axis=None, keepdims=False):
+    x = convert_to_tensor(x)
+    if axis is None:
+        return x.max() - x.min()
+    elif axis == ():
+        return torch.zeros_like(x)
+    else:
+        return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(
+            x, dim=axis, keepdim=keepdims
+        )
+
+
 def quantile(x, q, axis=None, method="linear", keepdims=False):
     x = convert_to_tensor(x)
     q = convert_to_tensor(q)
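`ptp` ("peak to peak") is max minus min, globally or per axis; the `axis == ()` branch matches NumPy, where reducing over no axes gives zeros of the input shape. Assuming it is exported as `keras.ops.ptp`:

    from keras import ops

    x = [[1, 5], [3, 9]]
    print(ops.ptp(x))          # 8
    print(ops.ptp(x, axis=0))  # [2 4]
    print(ops.ptp(x, axis=1))  # [4 6]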
@@ -1434,7 +1499,7 @@ def searchsorted(sorted_sequence, values, side="left"):
             "to extend it to N-D sequences. Received: "
             f"sorted_sequence.shape={sorted_sequence.shape}"
         )
-    out_int32 = len(sorted_sequence) <= np.iinfo(np.int32).max
+    out_int32 = sorted_sequence.shape[0] <= np.iinfo(np.int32).max
     return torch.searchsorted(
         sorted_sequence, values, side=side, out_int32=out_int32
     )
@@ -1506,6 +1571,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
@@ -1619,8 +1690,9 @@ def tile(x, repeats):
 def trace(x, offset=0, axis1=0, axis2=1):
     x = convert_to_tensor(x)
     dtype = standardize_dtype(x.dtype)
-    if dtype != "int64":
-        dtype = dtypes.result_type(dtype, "int32")
+    if dtype in ("bool", "int8", "int16", "uint8"):
+        # Torch backend doesn't support uint32 dtype.
+        dtype = "int32"
     return torch.sum(
         torch.diagonal(x, offset, axis1, axis2),
         dim=-1,
@@ -1733,6 +1805,16 @@ def negative(x):
     return torch.negative(x)
 
 
+def nextafter(x1, x2):
+    x1 = convert_to_tensor(x1)
+    x2 = convert_to_tensor(x2)
+
+    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
+    x1 = cast(x1, torch.float64)
+    x2 = cast(x2, torch.float64)
+    return cast(torch.nextafter(x1, x2), dtype)
+
+
 def square(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "bool":
@@ -1761,6 +1843,24 @@ def transpose(x, axes=None):
     return x.T
 
 
+def trapezoid(y, x=None, dx=1.0, axis=-1):
+    y = convert_to_tensor(y)
+    if standardize_dtype(y.dtype) == "bool":
+        y = cast(y, config.floatx())
+    if x is not None:
+        x = convert_to_tensor(x)
+        return torch.trapz(y, x=x, dim=axis)
+    else:
+        dx = convert_to_tensor(dx)
+        return torch.trapz(y, dx=dx, dim=axis)
+
+
+def vander(x, N=None, increasing=False):
+    x = convert_to_tensor(x)
+    result_dtype = dtypes.result_type(x.dtype)
+    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)
+
+
 def var(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     compute_dtype = dtypes.result_type(x.dtype, "float32")
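`trapezoid` integrates sampled values with the trapezoidal rule and `vander` builds a Vandermonde matrix (highest power first unless `increasing=True`), both delegating to the corresponding torch functions. Assuming they are exported under `keras.ops`:

    from keras import ops

    print(ops.trapezoid([1.0, 2.0, 3.0]))  # 4.0 = 0.5*(1+2) + 0.5*(2+3), dx=1
    print(ops.vander([1, 2, 3], N=3))
    # [[1 1 1]
    #  [4 2 1]
    #  [9 3 1]]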
@@ -54,7 +54,10 @@ class TorchTrainer(base_trainer.Trainer):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
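This train-step change, and the matching test-step change in the next hunk, fixes the loss-tracker batch size when the first flattened input leaf is `None` (for example an optional input): the batch dimension is now read from the first non-`None` leaf. A small sketch of the failure mode it avoids; the input structure here is hypothetical:

    import numpy as np
    from keras import tree

    x = {"optional_ids": None, "tokens": np.zeros((32, 16))}
    leaves = tree.flatten(x)  # first leaf may be None
    batch = next(i for i in leaves if i is not None).shape[0]
    print(batch)  # 32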
@@ -90,7 +93,10 @@ class TorchTrainer(base_trainer.Trainer):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
@@ -8,6 +8,7 @@ from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
 from keras.src.callbacks.monitor_callback import MonitorCallback
+from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor
@@ -39,6 +39,7 @@ class CallbackList(Callback):
             via `Callback.set_params`.
         """
         self.callbacks = tree.flatten(callbacks) if callbacks else []
+        self._in_begin_end_block_count = 0
         self._executor = None
         self._async_train = False
         self._async_test = False
@@ -78,9 +79,6 @@ class CallbackList(Callback):
             if not utils.is_default(cbk.on_predict_batch_end):
                 async_predict = False
 
-        if async_train or async_test or async_predict:
-            self._executor = concurrent.futures.ThreadPoolExecutor()
-
         self._async_train = async_train
         self._async_test = async_test
         self._async_predict = async_predict
@@ -113,6 +111,33 @@ class CallbackList(Callback):
         for callback in self.callbacks:
             callback.set_model(model)
 
+    def _on_begin(self):
+        """Called by `on_train/test/predict_begin`.
+
+        Start the executor for async calls if needed.
+        """
+        self._in_begin_end_block_count += 1
+        if (
+            self._in_begin_end_block_count == 1
+            and (self._async_train or self._async_test or self._async_predict)
+            and self._executor is None
+        ):
+            self._executor = concurrent.futures.ThreadPoolExecutor()
+
+    def _on_end(self):
+        """Called by `on_train/test/predict_end`.
+
+        Shutdown the executor for async calls if all begin/end blocks completed.
+        """
+        self._in_begin_end_block_count -= 1
+        if self._in_begin_end_block_count < 0:
+            raise ValueError(
+                "`on_xxx_end` called without corresponding `on_xxx_begin`"
+            )
+        if self._in_begin_end_block_count == 0 and self._executor is not None:
+            self._executor.shutdown()
+            self._executor = None
+
     def _async_dispatch(self, fn, *args):
         for future in self._futures:
             if future.done():
@@ -121,7 +146,8 @@ class CallbackList(Callback):
         future = self._executor.submit(fn, *args)
         self._futures.append(future)
 
-    def _clear_futures(self):
+    def _flush_futures(self):
+        """Waits for all futures to complete and clears the list."""
         for future in self._futures:
             future.result()
         self._futures = []
@@ -138,7 +164,7 @@ class CallbackList(Callback):
 
     def on_epoch_end(self, epoch, logs=None):
         if self._async_train:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
@@ -204,44 +230,52 @@ class CallbackList(Callback):
             callback.on_predict_batch_end(batch, logs=logs)
 
     def on_train_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_begin(logs)
 
     def on_train_end(self, logs=None):
         if self._async_train:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_train_end(logs)
 
+        self._on_end()
+
     def on_test_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_begin(logs)
 
     def on_test_end(self, logs=None):
         if self._async_test:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_test_end(logs)
 
+        self._on_end()
+
     def on_predict_begin(self, logs=None):
+        self._on_begin()
+
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_begin(logs)
 
     def on_predict_end(self, logs=None):
         if self._async_predict:
-            self._clear_futures()
+            self._flush_futures()
 
         logs = python_utils.pythonify_logs(logs)
         for callback in self.callbacks:
             callback.on_predict_end(logs)
 
-    def __del__(self):
-        if self._executor is not None:
-            self._executor.shutdown(cancel_futures=True)
+        self._on_end()
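With this hunk the executor lifecycle moves from `__del__` to explicit reference counting: each `on_*_begin` bumps a counter and lazily creates the thread pool, each `on_*_end` decrements it, and the pool is only shut down when the outermost block closes, so a nested run (for example `evaluate` inside `fit`) no longer tears down the executor the outer run still needs. A stripped-down sketch of the same pattern, independent of Keras:

    import concurrent.futures

    class RefCountedExecutor:
        def __init__(self):
            self._count = 0
            self._executor = None

        def begin(self):
            self._count += 1
            if self._count == 1 and self._executor is None:
                self._executor = concurrent.futures.ThreadPoolExecutor()

        def end(self):
            self._count -= 1
            if self._count < 0:
                raise ValueError("end() without a matching begin()")
            if self._count == 0 and self._executor is not None:
                self._executor.shutdown()
                self._executor = None

    pool = RefCountedExecutor()
    pool.begin()  # outer fit
    pool.begin()  # nested evaluate
    pool.end()    # nested evaluate ends: executor kept alive
    pool.end()    # outer fit ends: executor shut down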
@@ -283,6 +283,11 @@ class ModelCheckpoint(MonitorCallback):
                     self.model.save_weights(filepath, overwrite=True)
                 else:
                     self.model.save(filepath, overwrite=True)
+                if self.verbose > 0:
+                    io_utils.print_msg(
+                        f"\nEpoch {epoch + 1}: "
+                        f"finished saving model to {filepath}"
+                    )
             except IsADirectoryError:  # h5py 3.x
                 raise IOError(
                     "Please specify a non-directory filepath for "