keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective registries. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
--- a/keras/src/utils/dataset_utils.py
+++ b/keras/src/utils/dataset_utils.py
@@ -6,17 +6,22 @@ from multiprocessing.pool import ThreadPool

 import numpy as np

+from keras.src import backend
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
 from keras.src.utils.module_utils import grain
-from keras.src.utils.module_utils import tensorflow as tf


 @keras_export("keras.utils.split_dataset")
 def split_dataset(
-    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+    dataset,
+    left_size=None,
+    right_size=None,
+    shuffle=False,
+    seed=None,
+    preferred_backend=None,
 ):
     """Splits a dataset into a left half and a right half (e.g. train / test).

@@ -37,27 +42,86 @@ def split_dataset(
             Defaults to `None`.
         shuffle: Boolean, whether to shuffle the data before splitting it.
         seed: A random seed for shuffling.
+        preferred_backend: String specifying which backend
+            (e.g. "tensorflow", "torch") to use. If `None`, the
+            backend is inferred from the type of `dataset`: if
+            `dataset` is a `tf.data.Dataset`, the "tensorflow" backend
+            is used; if `dataset` is a `torch.utils.data.Dataset`,
+            the "torch" backend is used; and if `dataset` is a
+            list/tuple/np.array, the current Keras backend is used.
+            Defaults to `None`.

     Returns:
-        A tuple of two `tf.data.Dataset` objects:
-        the left and right splits.
-
+        A tuple of two dataset objects, the left and right splits. The exact
+        type of the returned objects depends on the `preferred_backend`.
+        For example, with a "tensorflow" backend,
+        `tf.data.Dataset` objects are returned. With a "torch" backend,
+        `torch.utils.data.Dataset` objects are returned.
     Example:

     >>> data = np.random.random(size=(1000, 4))
     >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
-    >>> int(left_ds.cardinality())
-    800
-    >>> int(right_ds.cardinality())
-    200
+    >>> # For a tf.data.Dataset, you can use .cardinality()
+    >>> # >>> int(left_ds.cardinality())
+    >>> # 800
+    >>> # For a torch.utils.data.Dataset, you can use len()
+    >>> # >>> len(left_ds)
+    >>> # 800
     """
+    preferred_backend = preferred_backend or _infer_preferred_backend(dataset)
+    if preferred_backend != "torch":
+        return _split_dataset_tf(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+    else:
+        return _split_dataset_torch(
+            dataset,
+            left_size=left_size,
+            right_size=right_size,
+            shuffle=shuffle,
+            seed=seed,
+        )
+
+
+def _split_dataset_tf(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `tf.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer, it
+            signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `tf.data.Dataset` objects:
+        the left and right splits.
+    """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     dataset_type_spec = _get_type_spec(dataset)

     if dataset_type_spec is None:
         raise TypeError(
             "The `dataset` argument must be either"
-            "a `tf.data.Dataset`, a `torch.utils.data.Dataset`"
-            "object, or a list/tuple of arrays. "
+            "a `tf.data.Dataset` object, or"
+            "a list/tuple of arrays. "
             f"Received: dataset={dataset} of type {type(dataset)}"
         )

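Taken together, `split_dataset` now routes to a backend-specific splitter, so callers can pin the returned dataset type independently of the active Keras backend. A minimal usage sketch (shapes illustrative; the "torch" path assumes torch is installed):

    import numpy as np
    import keras

    data = np.random.random(size=(1000, 4)).astype("float32")

    # Force the torch code path; the splits are torch.utils.data.Dataset
    # objects (Subsets), so len() works on them directly.
    left, right = keras.utils.split_dataset(
        data, left_size=0.8, preferred_backend="torch"
    )
    print(len(left), len(right))  # 800 200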
@@ -106,6 +170,103 @@
     return left_split, right_split


+def _split_dataset_torch(
+    dataset, left_size=None, right_size=None, shuffle=False, seed=None
+):
+    """Splits a dataset into a left half and a right half (e.g. train / test).
+
+    Args:
+        dataset:
+            A `torch.utils.data.Dataset` object,
+            or a list/tuple of arrays with the same length.
+        left_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the left dataset. If integer, it
+            signifies the number of samples to pack in the left dataset. If
+            `None`, defaults to the complement to `right_size`.
+            Defaults to `None`.
+        right_size: If float (in the range `[0, 1]`), it signifies
+            the fraction of the data to pack in the right dataset.
+            If integer, it signifies the number of samples to pack
+            in the right dataset.
+            If `None`, defaults to the complement to `left_size`.
+            Defaults to `None`.
+        shuffle: Boolean, whether to shuffle the data before splitting it.
+        seed: A random seed for shuffling.
+
+    Returns:
+        A tuple of two `torch.utils.data.Dataset` objects:
+        the left and right splits.
+    """
+    import torch
+    from torch.utils.data import TensorDataset
+    from torch.utils.data import random_split
+
+    dataset_type_spec = _get_type_spec(dataset)
+    if dataset_type_spec is None:
+        raise TypeError(
+            "The `dataset` argument must be a `torch.utils.data.Dataset`"
+            " object, or a list/tuple of arrays."
+            f" Received: dataset={dataset} of type {type(dataset)}"
+        )
+
+    if not isinstance(dataset, torch.utils.data.Dataset):
+        if dataset_type_spec is np.ndarray:
+            dataset = TensorDataset(torch.from_numpy(dataset))
+        elif dataset_type_spec in (list, tuple):
+            tensors = [torch.from_numpy(x) for x in dataset]
+            dataset = TensorDataset(*tensors)
+        elif is_tf_dataset(dataset):
+            dataset_as_list = _convert_dataset_to_list(
+                dataset, dataset_type_spec
+            )
+            tensors = [
+                torch.from_numpy(np.array(sample))
+                for sample in zip(*dataset_as_list)
+            ]
+            dataset = TensorDataset(*tensors)
+
+    if right_size is None and left_size is None:
+        raise ValueError(
+            "At least one of the `left_size` or `right_size` "
+            "must be specified. "
+            "Received: left_size=None and right_size=None"
+        )
+
+    # Calculate total length and rescale split sizes
+    total_length = len(dataset)
+    left_size, right_size = _rescale_dataset_split_sizes(
+        left_size, right_size, total_length
+    )
+
+    # Shuffle the dataset if required
+    if shuffle:
+        generator = torch.Generator()
+        if seed is not None:
+            generator.manual_seed(seed)
+        else:
+            generator.seed()
+    else:
+        generator = None
+
+    left_split, right_split = random_split(
+        dataset, [left_size, right_size], generator=generator
+    )
+
+    return left_split, right_split
+
+
+def _infer_preferred_backend(dataset):
+    """Infer the backend from the dataset type."""
+    if isinstance(dataset, (list, tuple, np.ndarray)):
+        return backend.backend()
+    if is_tf_dataset(dataset):
+        return "tensorflow"
+    elif is_torch_dataset(dataset):
+        return "torch"
+    else:
+        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
+
+
 def _convert_dataset_to_list(
     dataset,
     dataset_type_spec,
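For reference, the torch branch above is a thin wrapper around `torch.utils.data.random_split`; this standalone sketch (plain torch, not the Keras API; sizes illustrative) shows the same mechanics, including the seeded generator used when `shuffle=True`:

    import numpy as np
    import torch
    from torch.utils.data import TensorDataset, random_split

    features = torch.from_numpy(np.random.random((10, 4)))
    dataset = TensorDataset(features)

    # A manually seeded generator makes the shuffle reproducible.
    generator = torch.Generator().manual_seed(42)
    left, right = random_split(dataset, [8, 2], generator=generator)
    print(len(left), len(right))  # 8 2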
@@ -208,7 +369,7 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
         )

         return iter(zip(*dataset))
-    elif dataset_type_spec is tf.data.Dataset:
+    elif is_tf_dataset(dataset):
         if is_batched(dataset):
             dataset = dataset.unbatch()
         return iter(dataset)
@@ -242,6 +403,9 @@
     Yields:
         data_sample: The next sample.
     """
+    from keras.src.trainers.data_adapters.data_adapter_utils import (
+        is_tensorflow_tensor,
+    )
    from keras.src.trainers.data_adapters.data_adapter_utils import (
        is_torch_tensor,
    )
@@ -249,8 +413,10 @@
     try:
         dataset_iterator = iter(dataset_iterator)
         first_sample = next(dataset_iterator)
-        if isinstance(first_sample, (tf.Tensor, np.ndarray)) or is_torch_tensor(
-            first_sample
+        if (
+            isinstance(first_sample, np.ndarray)
+            or is_tensorflow_tensor(first_sample)
+            or is_torch_tensor(first_sample)
         ):
             first_sample_shape = np.array(first_sample).shape
         else:
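The point of swapping the direct `tf.Tensor` isinstance check for `is_tensorflow_tensor` is that TensorFlow never has to be imported just to inspect a sample. A sketch of the underlying idea (`_looks_like_tf_tensor` is a hypothetical name; the real helper lives in `keras.src.trainers.data_adapters.data_adapter_utils`):

    def _looks_like_tf_tensor(x):
        # Walk the class hierarchy and match on the module name, so the
        # check never triggers an `import tensorflow` itself.
        for cls in type(x).__mro__:
            if cls.__module__.startswith("tensorflow"):
                return True
        return False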
@@ -291,23 +457,40 @@
             yield sample


-def is_torch_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ == "Dataset" and str(
-                parent.__module__
-            ).startswith("torch.utils.data"):
-                return True
-    return False
+def is_tf_dataset(dataset):
+    return _mro_matches(
+        dataset,
+        class_names=("DatasetV2", "Dataset"),
+        module_substrings=(
+            "tensorflow.python.data",  # TF classic
+            "tensorflow.data",  # newer TF paths
+        ),
+    )


 def is_grain_dataset(dataset):
-    if hasattr(dataset, "__class__"):
-        for parent in dataset.__class__.__mro__:
-            if parent.__name__ in (
-                "MapDataset",
-                "IterDataset",
-            ) and str(parent.__module__).startswith("grain._src.python"):
+    return _mro_matches(
+        dataset,
+        class_names=("MapDataset", "IterDataset"),
+        module_prefixes=("grain._src.python",),
+    )
+
+
+def is_torch_dataset(dataset):
+    return _mro_matches(dataset, ("Dataset",), ("torch.utils.data",))
+
+
+def _mro_matches(
+    dataset, class_names, module_prefixes=(), module_substrings=()
+):
+    if not hasattr(dataset, "__class__"):
+        return False
+    for parent in dataset.__class__.__mro__:
+        if parent.__name__ in class_names:
+            mod = str(parent.__module__)
+            if any(mod.startswith(pref) for pref in module_prefixes):
+                return True
+            if any(subs in mod for subs in module_substrings):
                 return True
     return False

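With `_mro_matches`, all three `is_*_dataset` predicates identify datasets purely by class name and module path, so none of TensorFlow, torch, or grain has to be imported by `dataset_utils` itself. An illustrative check, assuming TensorFlow is installed:

    import tensorflow as tf
    from keras.src.utils.dataset_utils import is_tf_dataset, is_torch_dataset

    ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    print(is_tf_dataset(ds))     # True: a Dataset class is found in the MRO
    print(is_torch_dataset(ds))  # False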
@@ -441,8 +624,10 @@
     dataset_as_list, dataset_type_spec, original_dataset
 ):
     """Restore the dataset from the list of arrays."""
-    if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
-        original_dataset
+    if (
+        dataset_type_spec in [tuple, list]
+        or is_tf_dataset(original_dataset)
+        or is_torch_dataset(original_dataset)
     ):
         # Save structure by taking the first element.
         element_spec = dataset_as_list[0]
@@ -483,7 +668,9 @@
         return list
     elif isinstance(dataset, np.ndarray):
         return np.ndarray
-    elif isinstance(dataset, tf.data.Dataset):
+    elif is_tf_dataset(dataset):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         return tf.data.Dataset
     elif is_torch_dataset(dataset):
         from torch.utils.data import Dataset as TorchDataset
@@ -543,6 +730,8 @@
         order.
     """
     if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
         os_module = tf.io.gfile
         path_module = tf.io.gfile
     else:
@@ -647,7 +836,12 @@


 def iter_valid_files(directory, follow_links, formats):
-    io_module = tf.io.gfile if file_utils.is_remote_path(directory) else os
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        io_module = tf.io.gfile
+    else:
+        io_module = os

     if not follow_links:
         walk = io_module.walk(directory)
@@ -674,9 +868,12 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
         paths, and `labels` is a list of integer labels corresponding
         to these files.
     """
-    path_module = (
-        tf.io.gfile if file_utils.is_remote_path(directory) else os.path
-    )
+    if file_utils.is_remote_path(directory):
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        path_module = tf.io.gfile
+    else:
+        path_module = os.path

     dirname = os.path.basename(directory)
     valid_files = iter_valid_files(directory, follow_links, formats)
@@ -746,6 +943,8 @@ def labels_to_dataset_tf(labels, label_mode, num_classes):
     Returns:
         A `tf.data.Dataset` instance.
     """
+    from keras.src.utils.module_utils import tensorflow as tf
+
     label_ds = tf.data.Dataset.from_tensor_slices(labels)
     if label_mode == "binary":
         label_ds = label_ds.map(
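The recurring `from keras.src.utils.module_utils import tensorflow as tf` in the hunks above is what makes these TensorFlow uses lazy: the import happens on first attribute access, not at module load. A simplified sketch of that idea (Keras' real `LazyModule` in `keras.src.utils.module_utils` also handles availability checks and install hints, so treat this as an approximation):

    import importlib


    class LazyModule:
        def __init__(self, name):
            self.name = name
            self._module = None

        def __getattr__(self, attr):
            # Only runs for attributes not found on the proxy itself;
            # the first access imports the real module.
            if self._module is None:
                self._module = importlib.import_module(self.name)
            return getattr(self._module, attr)


    tensorflow = LazyModule("tensorflow")
    # No import has happened yet; touching `tensorflow.io.gfile` would
    # trigger the real `import tensorflow`.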
--- a/keras/src/utils/file_utils.py
+++ b/keras/src/utils/file_utils.py
@@ -2,6 +2,7 @@ import hashlib
 import os
 import re
 import shutil
+import sys
 import tarfile
 import tempfile
 import urllib
@@ -52,17 +53,32 @@
     return is_path_in_dir(info.linkname, base_dir=tip)


-def filter_safe_paths(members):
+def filter_safe_zipinfos(members):
     base_dir = resolve_path(".")
     for finfo in members:
         valid_path = False
-        if is_path_in_dir(finfo.name, base_dir):
+        if is_path_in_dir(finfo.filename, base_dir):
             valid_path = True
             yield finfo
-        elif finfo.issym() or finfo.islnk():
+        if not valid_path:
+            warnings.warn(
+                "Skipping invalid path during archive extraction: "
+                f"'{finfo.name}'.",
+                stacklevel=2,
+            )
+
+
+def filter_safe_tarinfos(members):
+    base_dir = resolve_path(".")
+    for finfo in members:
+        valid_path = False
+        if finfo.issym() or finfo.islnk():
             if is_link_in_dir(finfo, base_dir):
                 valid_path = True
                 yield finfo
+        elif is_path_in_dir(finfo.name, base_dir):
+            valid_path = True
+            yield finfo
         if not valid_path:
             warnings.warn(
                 "Skipping invalid path during archive extraction: "
@@ -71,6 +87,35 @@ def filter_safe_paths(members):
             )


+def extract_open_archive(archive, path="."):
+    """Extracts an open tar or zip archive to the provided directory.
+
+    This function filters unsafe paths during extraction.
+
+    Args:
+        archive: The archive object, either a `TarFile` or a `ZipFile`.
+        path: Where to extract the archive file.
+    """
+    if isinstance(archive, zipfile.ZipFile):
+        # Zip archive.
+        archive.extractall(
+            path, members=filter_safe_zipinfos(archive.infolist())
+        )
+    else:
+        # Tar archive.
+        extractall_kwargs = {}
+        # The `filter="data"` option was added in Python 3.12. It became the
+        # default starting from Python 3.14. So we only specify it between
+        # those two versions.
+        if sys.version_info >= (3, 12) and sys.version_info < (3, 14):
+            extractall_kwargs = {"filter": "data"}
+        archive.extractall(
+            path,
+            members=filter_safe_tarinfos(archive),
+            **extractall_kwargs,
+        )
+
+
 def extract_archive(file_path, path=".", archive_format="auto"):
     """Extracts an archive if it matches a support format.

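A hedged usage example of the new helper (archive path illustrative): zip members are screened by `filter_safe_zipinfos`, tar members by `filter_safe_tarinfos`, and on Python 3.12-3.13 tar extraction additionally passes `filter="data"`:

    import tarfile

    from keras.src.utils.file_utils import extract_open_archive

    # Illustrative archive name; any tar or zip opened with the stdlib
    # works, and unsafe member paths are skipped with a warning.
    with tarfile.open("weights.tar.gz", "r:gz") as archive:
        extract_open_archive(archive, path="./weights")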
@@ -112,14 +157,7 @@ def extract_archive(file_path, path=".", archive_format="auto"):
         if is_match_fn(file_path):
             with open_fn(file_path) as archive:
                 try:
-                    if zipfile.is_zipfile(file_path):
-                        # Zip archive.
-                        archive.extractall(path)
-                    else:
-                        # Tar archive, perhaps unsafe. Filter paths.
-                        archive.extractall(
-                            path, members=filter_safe_paths(archive)
-                        )
+                    extract_open_archive(archive, path)
                 except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
                     if os.path.exists(path):
                         if os.path.isfile(path):
--- a/keras/src/utils/image_utils.py
+++ b/keras/src/utils/image_utils.py
@@ -175,12 +175,24 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
         **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
     """
     data_format = backend.standardize_data_format(data_format)
+
+    # Infer format from path if not explicitly provided
+    if file_format is None and isinstance(path, (str, pathlib.Path)):
+        file_format = pathlib.Path(path).suffix[1:].lower()
+
+    # Normalize jpg → jpeg for Pillow compatibility
+    if file_format and file_format.lower() == "jpg":
+        file_format = "jpeg"
+
     img = array_to_img(x, data_format=data_format, scale=scale)
-    if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
+
+    # Handle RGBA → RGB conversion for JPEG
+    if img.mode == "RGBA" and file_format == "jpeg":
         warnings.warn(
-            "The JPG format does not support RGBA images, converting to RGB."
+            "The JPEG format does not support RGBA images, converting to RGB."
         )
         img = img.convert("RGB")
+
     img.save(path, format=file_format, **kwargs)

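Net effect of the `save_img` hunk, as a small usage sketch (array values illustrative): the format is inferred from the path suffix, "jpg" is normalized to "jpeg" before reaching Pillow, and RGBA input is converted to RGB with a warning:

    import numpy as np
    from keras.utils import save_img

    img = np.random.random((64, 64, 4))  # RGBA float image
    # Format inferred from the ".jpg" suffix, normalized to "jpeg";
    # the RGBA array is converted to RGB before saving.
    save_img("out.jpg", img)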