keras-nightly 3.12.0.dev2025090203__py3-none-any.whl → 3.14.0.dev2025122704__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (152)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +14 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +10 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +14 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +10 -0
  24. keras/quantizers/__init__.py +12 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +485 -20
  35. keras/src/backend/jax/numpy.py +98 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +85 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1081 -163
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +333 -55
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +87 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +299 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +2 -0
  68. keras/src/dtype_policies/dtype_policy.py +90 -0
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +290 -102
  80. keras/src/layers/core/einsum_dense.py +377 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +390 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +45 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/index_lookup.py +19 -1
  97. keras/src/layers/preprocessing/normalization.py +16 -1
  98. keras/src/layers/preprocessing/string_lookup.py +26 -28
  99. keras/src/layers/regularization/dropout.py +43 -1
  100. keras/src/layers/rnn/gru.py +1 -1
  101. keras/src/layers/rnn/lstm.py +2 -2
  102. keras/src/layers/rnn/rnn.py +19 -0
  103. keras/src/layers/rnn/simple_rnn.py +1 -1
  104. keras/src/legacy/preprocessing/image.py +4 -1
  105. keras/src/legacy/preprocessing/sequence.py +20 -12
  106. keras/src/losses/loss.py +1 -1
  107. keras/src/metrics/confusion_metrics.py +7 -6
  108. keras/src/models/cloning.py +4 -0
  109. keras/src/models/functional.py +11 -3
  110. keras/src/models/model.py +180 -38
  111. keras/src/ops/image.py +181 -0
  112. keras/src/ops/linalg.py +93 -0
  113. keras/src/ops/nn.py +268 -2
  114. keras/src/ops/numpy.py +581 -44
  115. keras/src/ops/operation.py +2 -3
  116. keras/src/ops/operation_utils.py +2 -0
  117. keras/src/optimizers/adafactor.py +29 -10
  118. keras/src/optimizers/base_optimizer.py +22 -3
  119. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  120. keras/src/optimizers/muon.py +65 -31
  121. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  122. keras/src/quantizers/__init__.py +12 -1
  123. keras/src/quantizers/gptq.py +347 -207
  124. keras/src/quantizers/gptq_config.py +63 -13
  125. keras/src/quantizers/gptq_core.py +328 -215
  126. keras/src/quantizers/quantization_config.py +232 -0
  127. keras/src/quantizers/quantizers.py +398 -38
  128. keras/src/quantizers/utils.py +23 -0
  129. keras/src/random/seed_generator.py +4 -2
  130. keras/src/saving/saving_lib.py +1 -1
  131. keras/src/testing/__init__.py +1 -0
  132. keras/src/testing/test_case.py +45 -5
  133. keras/src/trainers/compile_utils.py +38 -17
  134. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  135. keras/src/tree/torchtree_impl.py +215 -0
  136. keras/src/tree/tree_api.py +6 -1
  137. keras/src/utils/backend_utils.py +31 -4
  138. keras/src/utils/dataset_utils.py +234 -35
  139. keras/src/utils/file_utils.py +49 -11
  140. keras/src/utils/image_utils.py +14 -2
  141. keras/src/utils/jax_layer.py +187 -36
  142. keras/src/utils/module_utils.py +18 -0
  143. keras/src/utils/progbar.py +10 -12
  144. keras/src/utils/python_utils.py +5 -0
  145. keras/src/utils/rng_utils.py +9 -1
  146. keras/src/utils/tracking.py +65 -0
  147. keras/src/version.py +1 -1
  148. {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.14.0.dev2025122704.dist-info}/METADATA +16 -6
  149. {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.14.0.dev2025122704.dist-info}/RECORD +151 -134
  150. keras/src/quantizers/gptq_quant.py +0 -133
  151. {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.14.0.dev2025122704.dist-info}/WHEEL +0 -0
  152. {keras_nightly-3.12.0.dev2025090203.dist-info → keras_nightly-3.14.0.dev2025122704.dist-info}/top_level.txt +0 -0
keras/__init__.py CHANGED
@@ -12,6 +12,7 @@ from keras import callbacks as callbacks
12
12
  from keras import config as config
13
13
  from keras import constraints as constraints
14
14
  from keras import datasets as datasets
15
+ from keras import distillation as distillation
15
16
  from keras import distribution as distribution
16
17
  from keras import dtype_policies as dtype_policies
17
18
  from keras import export as export
@@ -10,6 +10,7 @@ from keras import callbacks as callbacks
10
10
  from keras import config as config
11
11
  from keras import constraints as constraints
12
12
  from keras import datasets as datasets
13
+ from keras import distillation as distillation
13
14
  from keras import distribution as distribution
14
15
  from keras import dtype_policies as dtype_policies
15
16
  from keras import export as export
@@ -19,6 +19,9 @@ from keras.src.callbacks.learning_rate_scheduler import (
19
19
  from keras.src.callbacks.model_checkpoint import (
20
20
  ModelCheckpoint as ModelCheckpoint,
21
21
  )
22
+ from keras.src.callbacks.orbax_checkpoint import (
23
+ OrbaxCheckpoint as OrbaxCheckpoint,
24
+ )
22
25
  from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
23
26
  from keras.src.callbacks.reduce_lr_on_plateau import (
24
27
  ReduceLROnPlateau as ReduceLROnPlateau,
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras.src.distillation.distillation_loss import (
8
+ DistillationLoss as DistillationLoss,
9
+ )
10
+ from keras.src.distillation.distillation_loss import (
11
+ FeatureDistillation as FeatureDistillation,
12
+ )
13
+ from keras.src.distillation.distillation_loss import (
14
+ LogitsDistillation as LogitsDistillation,
15
+ )
16
+ from keras.src.distillation.distiller import Distiller as Distiller
@@ -15,6 +15,9 @@ from keras.src.distribution.distribution_lib import (
15
15
  distribute_tensor as distribute_tensor,
16
16
  )
17
17
  from keras.src.distribution.distribution_lib import distribution as distribution
18
+ from keras.src.distribution.distribution_lib import (
19
+ get_device_count as get_device_count,
20
+ )
18
21
  from keras.src.distribution.distribution_lib import initialize as initialize
19
22
  from keras.src.distribution.distribution_lib import list_devices as list_devices
20
23
  from keras.src.distribution.distribution_lib import (
@@ -11,6 +11,9 @@ from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
11
  from keras.src.dtype_policies.dtype_policy import (
12
12
  FloatDTypePolicy as FloatDTypePolicy,
13
13
  )
14
+ from keras.src.dtype_policies.dtype_policy import (
15
+ GPTQDTypePolicy as GPTQDTypePolicy,
16
+ )
14
17
  from keras.src.dtype_policies.dtype_policy import (
15
18
  QuantizedDTypePolicy as QuantizedDTypePolicy,
16
19
  )
@@ -73,6 +73,9 @@ from keras.src.layers.core.input_layer import Input as Input
73
73
  from keras.src.layers.core.input_layer import InputLayer as InputLayer
74
74
  from keras.src.layers.core.lambda_layer import Lambda as Lambda
75
75
  from keras.src.layers.core.masking import Masking as Masking
76
+ from keras.src.layers.core.reversible_embedding import (
77
+ ReversibleEmbedding as ReversibleEmbedding,
78
+ )
76
79
  from keras.src.layers.core.wrapper import Wrapper as Wrapper
77
80
  from keras.src.layers.input_spec import InputSpec as InputSpec
78
81
  from keras.src.layers.layer import Layer as Layer
@@ -110,6 +113,24 @@ from keras.src.layers.normalization.spectral_normalization import (
110
113
  from keras.src.layers.normalization.unit_normalization import (
111
114
  UnitNormalization as UnitNormalization,
112
115
  )
116
+ from keras.src.layers.pooling.adaptive_average_pooling1d import (
117
+ AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118
+ )
119
+ from keras.src.layers.pooling.adaptive_average_pooling2d import (
120
+ AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121
+ )
122
+ from keras.src.layers.pooling.adaptive_average_pooling3d import (
123
+ AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124
+ )
125
+ from keras.src.layers.pooling.adaptive_max_pooling1d import (
126
+ AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127
+ )
128
+ from keras.src.layers.pooling.adaptive_max_pooling2d import (
129
+ AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130
+ )
131
+ from keras.src.layers.pooling.adaptive_max_pooling3d import (
132
+ AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133
+ )
113
134
  from keras.src.layers.pooling.average_pooling1d import (
114
135
  AveragePooling1D as AveragePooling1D,
115
136
  )
@@ -37,6 +37,7 @@ from keras.src.ops.linalg import det as det
37
37
  from keras.src.ops.linalg import eig as eig
38
38
  from keras.src.ops.linalg import eigh as eigh
39
39
  from keras.src.ops.linalg import inv as inv
40
+ from keras.src.ops.linalg import jvp as jvp
40
41
  from keras.src.ops.linalg import lstsq as lstsq
41
42
  from keras.src.ops.linalg import lu_factor as lu_factor
42
43
  from keras.src.ops.linalg import norm as norm
@@ -63,6 +64,8 @@ from keras.src.ops.math import stft as stft
63
64
  from keras.src.ops.math import top_k as top_k
64
65
  from keras.src.ops.math import view_as_complex as view_as_complex
65
66
  from keras.src.ops.math import view_as_real as view_as_real
67
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
66
69
  from keras.src.ops.nn import average_pool as average_pool
67
70
  from keras.src.ops.nn import batch_normalization as batch_normalization
68
71
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -116,6 +119,7 @@ from keras.src.ops.nn import sparsemax as sparsemax
116
119
  from keras.src.ops.nn import squareplus as squareplus
117
120
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
118
121
  from keras.src.ops.nn import threshold as threshold
122
+ from keras.src.ops.nn import unfold as unfold
119
123
  from keras.src.ops.numpy import abs as abs
120
124
  from keras.src.ops.numpy import absolute as absolute
121
125
  from keras.src.ops.numpy import add as add
@@ -138,6 +142,7 @@ from keras.src.ops.numpy import argmin as argmin
138
142
  from keras.src.ops.numpy import argpartition as argpartition
139
143
  from keras.src.ops.numpy import argsort as argsort
140
144
  from keras.src.ops.numpy import array as array
145
+ from keras.src.ops.numpy import array_split as array_split
141
146
  from keras.src.ops.numpy import average as average
142
147
  from keras.src.ops.numpy import bartlett as bartlett
143
148
  from keras.src.ops.numpy import bincount as bincount
@@ -176,6 +181,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
176
181
  from keras.src.ops.numpy import dot as dot
177
182
  from keras.src.ops.numpy import einsum as einsum
178
183
  from keras.src.ops.numpy import empty as empty
184
+ from keras.src.ops.numpy import empty_like as empty_like
179
185
  from keras.src.ops.numpy import equal as equal
180
186
  from keras.src.ops.numpy import exp as exp
181
187
  from keras.src.ops.numpy import exp2 as exp2
@@ -207,7 +213,11 @@ from keras.src.ops.numpy import isinf as isinf
207
213
  from keras.src.ops.numpy import isnan as isnan
208
214
  from keras.src.ops.numpy import isneginf as isneginf
209
215
  from keras.src.ops.numpy import isposinf as isposinf
216
+ from keras.src.ops.numpy import isreal as isreal
210
217
  from keras.src.ops.numpy import kaiser as kaiser
218
+ from keras.src.ops.numpy import kron as kron
219
+ from keras.src.ops.numpy import lcm as lcm
220
+ from keras.src.ops.numpy import ldexp as ldexp
211
221
  from keras.src.ops.numpy import left_shift as left_shift
212
222
  from keras.src.ops.numpy import less as less
213
223
  from keras.src.ops.numpy import less_equal as less_equal
@@ -217,6 +227,7 @@ from keras.src.ops.numpy import log1p as log1p
217
227
  from keras.src.ops.numpy import log2 as log2
218
228
  from keras.src.ops.numpy import log10 as log10
219
229
  from keras.src.ops.numpy import logaddexp as logaddexp
230
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
220
231
  from keras.src.ops.numpy import logical_and as logical_and
221
232
  from keras.src.ops.numpy import logical_not as logical_not
222
233
  from keras.src.ops.numpy import logical_or as logical_or
@@ -280,15 +291,18 @@ from keras.src.ops.numpy import tensordot as tensordot
280
291
  from keras.src.ops.numpy import tile as tile
281
292
  from keras.src.ops.numpy import trace as trace
282
293
  from keras.src.ops.numpy import transpose as transpose
294
+ from keras.src.ops.numpy import trapezoid as trapezoid
283
295
  from keras.src.ops.numpy import tri as tri
284
296
  from keras.src.ops.numpy import tril as tril
285
297
  from keras.src.ops.numpy import triu as triu
286
298
  from keras.src.ops.numpy import true_divide as true_divide
287
299
  from keras.src.ops.numpy import trunc as trunc
288
300
  from keras.src.ops.numpy import unravel_index as unravel_index
301
+ from keras.src.ops.numpy import vander as vander
289
302
  from keras.src.ops.numpy import var as var
290
303
  from keras.src.ops.numpy import vdot as vdot
291
304
  from keras.src.ops.numpy import vectorize as vectorize
305
+ from keras.src.ops.numpy import view as view
292
306
  from keras.src.ops.numpy import vstack as vstack
293
307
  from keras.src.ops.numpy import where as where
294
308
  from keras.src.ops.numpy import zeros as zeros
@@ -8,6 +8,7 @@ from keras.src.ops.image import affine_transform as affine_transform
8
8
  from keras.src.ops.image import crop_images as crop_images
9
9
  from keras.src.ops.image import elastic_transform as elastic_transform
10
10
  from keras.src.ops.image import extract_patches as extract_patches
11
+ from keras.src.ops.image import extract_patches_3d as extract_patches_3d
11
12
  from keras.src.ops.image import gaussian_blur as gaussian_blur
12
13
  from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
13
14
  from keras.src.ops.image import map_coordinates as map_coordinates
@@ -10,6 +10,7 @@ from keras.src.ops.linalg import det as det
10
10
  from keras.src.ops.linalg import eig as eig
11
11
  from keras.src.ops.linalg import eigh as eigh
12
12
  from keras.src.ops.linalg import inv as inv
13
+ from keras.src.ops.linalg import jvp as jvp
13
14
  from keras.src.ops.linalg import lstsq as lstsq
14
15
  from keras.src.ops.linalg import lu_factor as lu_factor
15
16
  from keras.src.ops.linalg import norm as norm
@@ -4,6 +4,8 @@ This file was autogenerated. Do not edit it by hand,
4
4
  since your modifications would be overwritten.
5
5
  """
6
6
 
7
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
7
9
  from keras.src.ops.nn import average_pool as average_pool
8
10
  from keras.src.ops.nn import batch_normalization as batch_normalization
9
11
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -57,3 +59,4 @@ from keras.src.ops.nn import sparsemax as sparsemax
57
59
  from keras.src.ops.nn import squareplus as squareplus
58
60
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
59
61
  from keras.src.ops.nn import threshold as threshold
62
+ from keras.src.ops.nn import unfold as unfold
@@ -26,6 +26,7 @@ from keras.src.ops.numpy import argmin as argmin
26
26
  from keras.src.ops.numpy import argpartition as argpartition
27
27
  from keras.src.ops.numpy import argsort as argsort
28
28
  from keras.src.ops.numpy import array as array
29
+ from keras.src.ops.numpy import array_split as array_split
29
30
  from keras.src.ops.numpy import average as average
30
31
  from keras.src.ops.numpy import bartlett as bartlett
31
32
  from keras.src.ops.numpy import bincount as bincount
@@ -64,6 +65,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
64
65
  from keras.src.ops.numpy import dot as dot
65
66
  from keras.src.ops.numpy import einsum as einsum
66
67
  from keras.src.ops.numpy import empty as empty
68
+ from keras.src.ops.numpy import empty_like as empty_like
67
69
  from keras.src.ops.numpy import equal as equal
68
70
  from keras.src.ops.numpy import exp as exp
69
71
  from keras.src.ops.numpy import exp2 as exp2
@@ -95,7 +97,11 @@ from keras.src.ops.numpy import isinf as isinf
95
97
  from keras.src.ops.numpy import isnan as isnan
96
98
  from keras.src.ops.numpy import isneginf as isneginf
97
99
  from keras.src.ops.numpy import isposinf as isposinf
100
+ from keras.src.ops.numpy import isreal as isreal
98
101
  from keras.src.ops.numpy import kaiser as kaiser
102
+ from keras.src.ops.numpy import kron as kron
103
+ from keras.src.ops.numpy import lcm as lcm
104
+ from keras.src.ops.numpy import ldexp as ldexp
99
105
  from keras.src.ops.numpy import left_shift as left_shift
100
106
  from keras.src.ops.numpy import less as less
101
107
  from keras.src.ops.numpy import less_equal as less_equal
@@ -105,6 +111,7 @@ from keras.src.ops.numpy import log1p as log1p
105
111
  from keras.src.ops.numpy import log2 as log2
106
112
  from keras.src.ops.numpy import log10 as log10
107
113
  from keras.src.ops.numpy import logaddexp as logaddexp
114
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
108
115
  from keras.src.ops.numpy import logical_and as logical_and
109
116
  from keras.src.ops.numpy import logical_not as logical_not
110
117
  from keras.src.ops.numpy import logical_or as logical_or
@@ -168,15 +175,18 @@ from keras.src.ops.numpy import tensordot as tensordot
168
175
  from keras.src.ops.numpy import tile as tile
169
176
  from keras.src.ops.numpy import trace as trace
170
177
  from keras.src.ops.numpy import transpose as transpose
178
+ from keras.src.ops.numpy import trapezoid as trapezoid
171
179
  from keras.src.ops.numpy import tri as tri
172
180
  from keras.src.ops.numpy import tril as tril
173
181
  from keras.src.ops.numpy import triu as triu
174
182
  from keras.src.ops.numpy import true_divide as true_divide
175
183
  from keras.src.ops.numpy import trunc as trunc
176
184
  from keras.src.ops.numpy import unravel_index as unravel_index
185
+ from keras.src.ops.numpy import vander as vander
177
186
  from keras.src.ops.numpy import var as var
178
187
  from keras.src.ops.numpy import vdot as vdot
179
188
  from keras.src.ops.numpy import vectorize as vectorize
189
+ from keras.src.ops.numpy import view as view
180
190
  from keras.src.ops.numpy import vstack as vstack
181
191
  from keras.src.ops.numpy import where as where
182
192
  from keras.src.ops.numpy import zeros as zeros
@@ -8,6 +8,18 @@ from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
10
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11
+ from keras.src.quantizers.quantization_config import (
12
+ Float8QuantizationConfig as Float8QuantizationConfig,
13
+ )
14
+ from keras.src.quantizers.quantization_config import (
15
+ Int4QuantizationConfig as Int4QuantizationConfig,
16
+ )
17
+ from keras.src.quantizers.quantization_config import (
18
+ Int8QuantizationConfig as Int8QuantizationConfig,
19
+ )
20
+ from keras.src.quantizers.quantization_config import (
21
+ QuantizationConfig as QuantizationConfig,
22
+ )
11
23
  from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
12
24
  from keras.src.quantizers.quantizers import Quantizer as Quantizer
13
25
  from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
@@ -19,6 +19,9 @@ from keras.src.callbacks.learning_rate_scheduler import (
19
19
  from keras.src.callbacks.model_checkpoint import (
20
20
  ModelCheckpoint as ModelCheckpoint,
21
21
  )
22
+ from keras.src.callbacks.orbax_checkpoint import (
23
+ OrbaxCheckpoint as OrbaxCheckpoint,
24
+ )
22
25
  from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
23
26
  from keras.src.callbacks.reduce_lr_on_plateau import (
24
27
  ReduceLROnPlateau as ReduceLROnPlateau,
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras.src.distillation.distillation_loss import (
8
+ DistillationLoss as DistillationLoss,
9
+ )
10
+ from keras.src.distillation.distillation_loss import (
11
+ FeatureDistillation as FeatureDistillation,
12
+ )
13
+ from keras.src.distillation.distillation_loss import (
14
+ LogitsDistillation as LogitsDistillation,
15
+ )
16
+ from keras.src.distillation.distiller import Distiller as Distiller
@@ -15,6 +15,9 @@ from keras.src.distribution.distribution_lib import (
15
15
  distribute_tensor as distribute_tensor,
16
16
  )
17
17
  from keras.src.distribution.distribution_lib import distribution as distribution
18
+ from keras.src.distribution.distribution_lib import (
19
+ get_device_count as get_device_count,
20
+ )
18
21
  from keras.src.distribution.distribution_lib import initialize as initialize
19
22
  from keras.src.distribution.distribution_lib import list_devices as list_devices
20
23
  from keras.src.distribution.distribution_lib import (
@@ -11,6 +11,9 @@ from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
11
  from keras.src.dtype_policies.dtype_policy import (
12
12
  FloatDTypePolicy as FloatDTypePolicy,
13
13
  )
14
+ from keras.src.dtype_policies.dtype_policy import (
15
+ GPTQDTypePolicy as GPTQDTypePolicy,
16
+ )
14
17
  from keras.src.dtype_policies.dtype_policy import (
15
18
  QuantizedDTypePolicy as QuantizedDTypePolicy,
16
19
  )
keras/layers/__init__.py CHANGED
@@ -73,6 +73,9 @@ from keras.src.layers.core.input_layer import Input as Input
73
73
  from keras.src.layers.core.input_layer import InputLayer as InputLayer
74
74
  from keras.src.layers.core.lambda_layer import Lambda as Lambda
75
75
  from keras.src.layers.core.masking import Masking as Masking
76
+ from keras.src.layers.core.reversible_embedding import (
77
+ ReversibleEmbedding as ReversibleEmbedding,
78
+ )
76
79
  from keras.src.layers.core.wrapper import Wrapper as Wrapper
77
80
  from keras.src.layers.input_spec import InputSpec as InputSpec
78
81
  from keras.src.layers.layer import Layer as Layer
@@ -110,6 +113,24 @@ from keras.src.layers.normalization.spectral_normalization import (
110
113
  from keras.src.layers.normalization.unit_normalization import (
111
114
  UnitNormalization as UnitNormalization,
112
115
  )
116
+ from keras.src.layers.pooling.adaptive_average_pooling1d import (
117
+ AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118
+ )
119
+ from keras.src.layers.pooling.adaptive_average_pooling2d import (
120
+ AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121
+ )
122
+ from keras.src.layers.pooling.adaptive_average_pooling3d import (
123
+ AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124
+ )
125
+ from keras.src.layers.pooling.adaptive_max_pooling1d import (
126
+ AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127
+ )
128
+ from keras.src.layers.pooling.adaptive_max_pooling2d import (
129
+ AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130
+ )
131
+ from keras.src.layers.pooling.adaptive_max_pooling3d import (
132
+ AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133
+ )
113
134
  from keras.src.layers.pooling.average_pooling1d import (
114
135
  AveragePooling1D as AveragePooling1D,
115
136
  )
keras/ops/__init__.py CHANGED
@@ -37,6 +37,7 @@ from keras.src.ops.linalg import det as det
37
37
  from keras.src.ops.linalg import eig as eig
38
38
  from keras.src.ops.linalg import eigh as eigh
39
39
  from keras.src.ops.linalg import inv as inv
40
+ from keras.src.ops.linalg import jvp as jvp
40
41
  from keras.src.ops.linalg import lstsq as lstsq
41
42
  from keras.src.ops.linalg import lu_factor as lu_factor
42
43
  from keras.src.ops.linalg import norm as norm
@@ -63,6 +64,8 @@ from keras.src.ops.math import stft as stft
63
64
  from keras.src.ops.math import top_k as top_k
64
65
  from keras.src.ops.math import view_as_complex as view_as_complex
65
66
  from keras.src.ops.math import view_as_real as view_as_real
67
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
66
69
  from keras.src.ops.nn import average_pool as average_pool
67
70
  from keras.src.ops.nn import batch_normalization as batch_normalization
68
71
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -116,6 +119,7 @@ from keras.src.ops.nn import sparsemax as sparsemax
116
119
  from keras.src.ops.nn import squareplus as squareplus
117
120
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
118
121
  from keras.src.ops.nn import threshold as threshold
122
+ from keras.src.ops.nn import unfold as unfold
119
123
  from keras.src.ops.numpy import abs as abs
120
124
  from keras.src.ops.numpy import absolute as absolute
121
125
  from keras.src.ops.numpy import add as add
@@ -138,6 +142,7 @@ from keras.src.ops.numpy import argmin as argmin
138
142
  from keras.src.ops.numpy import argpartition as argpartition
139
143
  from keras.src.ops.numpy import argsort as argsort
140
144
  from keras.src.ops.numpy import array as array
145
+ from keras.src.ops.numpy import array_split as array_split
141
146
  from keras.src.ops.numpy import average as average
142
147
  from keras.src.ops.numpy import bartlett as bartlett
143
148
  from keras.src.ops.numpy import bincount as bincount
@@ -176,6 +181,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
176
181
  from keras.src.ops.numpy import dot as dot
177
182
  from keras.src.ops.numpy import einsum as einsum
178
183
  from keras.src.ops.numpy import empty as empty
184
+ from keras.src.ops.numpy import empty_like as empty_like
179
185
  from keras.src.ops.numpy import equal as equal
180
186
  from keras.src.ops.numpy import exp as exp
181
187
  from keras.src.ops.numpy import exp2 as exp2
@@ -207,7 +213,11 @@ from keras.src.ops.numpy import isinf as isinf
207
213
  from keras.src.ops.numpy import isnan as isnan
208
214
  from keras.src.ops.numpy import isneginf as isneginf
209
215
  from keras.src.ops.numpy import isposinf as isposinf
216
+ from keras.src.ops.numpy import isreal as isreal
210
217
  from keras.src.ops.numpy import kaiser as kaiser
218
+ from keras.src.ops.numpy import kron as kron
219
+ from keras.src.ops.numpy import lcm as lcm
220
+ from keras.src.ops.numpy import ldexp as ldexp
211
221
  from keras.src.ops.numpy import left_shift as left_shift
212
222
  from keras.src.ops.numpy import less as less
213
223
  from keras.src.ops.numpy import less_equal as less_equal
@@ -217,6 +227,7 @@ from keras.src.ops.numpy import log1p as log1p
217
227
  from keras.src.ops.numpy import log2 as log2
218
228
  from keras.src.ops.numpy import log10 as log10
219
229
  from keras.src.ops.numpy import logaddexp as logaddexp
230
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
220
231
  from keras.src.ops.numpy import logical_and as logical_and
221
232
  from keras.src.ops.numpy import logical_not as logical_not
222
233
  from keras.src.ops.numpy import logical_or as logical_or
@@ -280,15 +291,18 @@ from keras.src.ops.numpy import tensordot as tensordot
280
291
  from keras.src.ops.numpy import tile as tile
281
292
  from keras.src.ops.numpy import trace as trace
282
293
  from keras.src.ops.numpy import transpose as transpose
294
+ from keras.src.ops.numpy import trapezoid as trapezoid
283
295
  from keras.src.ops.numpy import tri as tri
284
296
  from keras.src.ops.numpy import tril as tril
285
297
  from keras.src.ops.numpy import triu as triu
286
298
  from keras.src.ops.numpy import true_divide as true_divide
287
299
  from keras.src.ops.numpy import trunc as trunc
288
300
  from keras.src.ops.numpy import unravel_index as unravel_index
301
+ from keras.src.ops.numpy import vander as vander
289
302
  from keras.src.ops.numpy import var as var
290
303
  from keras.src.ops.numpy import vdot as vdot
291
304
  from keras.src.ops.numpy import vectorize as vectorize
305
+ from keras.src.ops.numpy import view as view
292
306
  from keras.src.ops.numpy import vstack as vstack
293
307
  from keras.src.ops.numpy import where as where
294
308
  from keras.src.ops.numpy import zeros as zeros
@@ -8,6 +8,7 @@ from keras.src.ops.image import affine_transform as affine_transform
8
8
  from keras.src.ops.image import crop_images as crop_images
9
9
  from keras.src.ops.image import elastic_transform as elastic_transform
10
10
  from keras.src.ops.image import extract_patches as extract_patches
11
+ from keras.src.ops.image import extract_patches_3d as extract_patches_3d
11
12
  from keras.src.ops.image import gaussian_blur as gaussian_blur
12
13
  from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
13
14
  from keras.src.ops.image import map_coordinates as map_coordinates
@@ -10,6 +10,7 @@ from keras.src.ops.linalg import det as det
10
10
  from keras.src.ops.linalg import eig as eig
11
11
  from keras.src.ops.linalg import eigh as eigh
12
12
  from keras.src.ops.linalg import inv as inv
13
+ from keras.src.ops.linalg import jvp as jvp
13
14
  from keras.src.ops.linalg import lstsq as lstsq
14
15
  from keras.src.ops.linalg import lu_factor as lu_factor
15
16
  from keras.src.ops.linalg import norm as norm
keras/ops/nn/__init__.py CHANGED
@@ -4,6 +4,8 @@ This file was autogenerated. Do not edit it by hand,
4
4
  since your modifications would be overwritten.
5
5
  """
6
6
 
7
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
7
9
  from keras.src.ops.nn import average_pool as average_pool
8
10
  from keras.src.ops.nn import batch_normalization as batch_normalization
9
11
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -57,3 +59,4 @@ from keras.src.ops.nn import sparsemax as sparsemax
57
59
  from keras.src.ops.nn import squareplus as squareplus
58
60
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
59
61
  from keras.src.ops.nn import threshold as threshold
62
+ from keras.src.ops.nn import unfold as unfold
@@ -26,6 +26,7 @@ from keras.src.ops.numpy import argmin as argmin
26
26
  from keras.src.ops.numpy import argpartition as argpartition
27
27
  from keras.src.ops.numpy import argsort as argsort
28
28
  from keras.src.ops.numpy import array as array
29
+ from keras.src.ops.numpy import array_split as array_split
29
30
  from keras.src.ops.numpy import average as average
30
31
  from keras.src.ops.numpy import bartlett as bartlett
31
32
  from keras.src.ops.numpy import bincount as bincount
@@ -64,6 +65,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
64
65
  from keras.src.ops.numpy import dot as dot
65
66
  from keras.src.ops.numpy import einsum as einsum
66
67
  from keras.src.ops.numpy import empty as empty
68
+ from keras.src.ops.numpy import empty_like as empty_like
67
69
  from keras.src.ops.numpy import equal as equal
68
70
  from keras.src.ops.numpy import exp as exp
69
71
  from keras.src.ops.numpy import exp2 as exp2
@@ -95,7 +97,11 @@ from keras.src.ops.numpy import isinf as isinf
95
97
  from keras.src.ops.numpy import isnan as isnan
96
98
  from keras.src.ops.numpy import isneginf as isneginf
97
99
  from keras.src.ops.numpy import isposinf as isposinf
100
+ from keras.src.ops.numpy import isreal as isreal
98
101
  from keras.src.ops.numpy import kaiser as kaiser
102
+ from keras.src.ops.numpy import kron as kron
103
+ from keras.src.ops.numpy import lcm as lcm
104
+ from keras.src.ops.numpy import ldexp as ldexp
99
105
  from keras.src.ops.numpy import left_shift as left_shift
100
106
  from keras.src.ops.numpy import less as less
101
107
  from keras.src.ops.numpy import less_equal as less_equal
@@ -105,6 +111,7 @@ from keras.src.ops.numpy import log1p as log1p
105
111
  from keras.src.ops.numpy import log2 as log2
106
112
  from keras.src.ops.numpy import log10 as log10
107
113
  from keras.src.ops.numpy import logaddexp as logaddexp
114
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
108
115
  from keras.src.ops.numpy import logical_and as logical_and
109
116
  from keras.src.ops.numpy import logical_not as logical_not
110
117
  from keras.src.ops.numpy import logical_or as logical_or
@@ -168,15 +175,18 @@ from keras.src.ops.numpy import tensordot as tensordot
168
175
  from keras.src.ops.numpy import tile as tile
169
176
  from keras.src.ops.numpy import trace as trace
170
177
  from keras.src.ops.numpy import transpose as transpose
178
+ from keras.src.ops.numpy import trapezoid as trapezoid
171
179
  from keras.src.ops.numpy import tri as tri
172
180
  from keras.src.ops.numpy import tril as tril
173
181
  from keras.src.ops.numpy import triu as triu
174
182
  from keras.src.ops.numpy import true_divide as true_divide
175
183
  from keras.src.ops.numpy import trunc as trunc
176
184
  from keras.src.ops.numpy import unravel_index as unravel_index
185
+ from keras.src.ops.numpy import vander as vander
177
186
  from keras.src.ops.numpy import var as var
178
187
  from keras.src.ops.numpy import vdot as vdot
179
188
  from keras.src.ops.numpy import vectorize as vectorize
189
+ from keras.src.ops.numpy import view as view
180
190
  from keras.src.ops.numpy import vstack as vstack
181
191
  from keras.src.ops.numpy import where as where
182
192
  from keras.src.ops.numpy import zeros as zeros
@@ -8,6 +8,18 @@ from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
10
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11
+ from keras.src.quantizers.quantization_config import (
12
+ Float8QuantizationConfig as Float8QuantizationConfig,
13
+ )
14
+ from keras.src.quantizers.quantization_config import (
15
+ Int4QuantizationConfig as Int4QuantizationConfig,
16
+ )
17
+ from keras.src.quantizers.quantization_config import (
18
+ Int8QuantizationConfig as Int8QuantizationConfig,
19
+ )
20
+ from keras.src.quantizers.quantization_config import (
21
+ QuantizationConfig as QuantizationConfig,
22
+ )
11
23
  from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
12
24
  from keras.src.quantizers.quantizers import Quantizer as Quantizer
13
25
  from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
@@ -278,7 +278,10 @@ def _preprocess_tensor_input(x, data_format, mode):
278
278
 
279
279
  # Zero-center by mean pixel
280
280
  if data_format == "channels_first":
281
- mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
281
+ if len(x.shape) == 3:
282
+ mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
283
+ else:
284
+ mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
282
285
  else:
283
286
  mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))
284
287
  x += mean_tensor
@@ -1,4 +1,5 @@
1
1
  import functools
2
+ import math
2
3
  import operator
3
4
  import re
4
5
  import warnings
@@ -96,13 +97,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
96
97
  )
97
98
 
98
99
  if torch_output_padding >= stride:
99
- raise ValueError(
100
- f"The padding arguments (padding={padding}) and "
101
- f"output_padding={output_padding}) lead to a Torch "
102
- f"output_padding ({torch_output_padding}) that is greater than "
103
- f"strides ({stride}). This is not supported. You can change the "
104
- f"padding arguments, kernel or stride, or run on another backend. "
100
+ warnings.warn(
101
+ f"Torch backend requires output_padding < stride. "
102
+ f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
103
+ f"for stride {stride}.",
104
+ UserWarning,
105
105
  )
106
+ torch_output_padding = stride - 1
106
107
 
107
108
  return torch_padding, torch_output_padding
108
109
 
@@ -184,6 +185,22 @@ def compute_conv_transpose_padding_args_for_torch(
184
185
  torch_paddings.append(torch_padding)
185
186
  torch_output_paddings.append(torch_output_padding)
186
187
 
188
+ # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
189
+ corrected_output_paddings = []
190
+ for s, op in zip(
191
+ strides
192
+ if isinstance(strides, (list, tuple))
193
+ else [strides] * num_spatial_dims,
194
+ torch_output_paddings,
195
+ ):
196
+ max_allowed = max(0, s - 1)
197
+ if op > max_allowed:
198
+ corrected_output_paddings.append(max_allowed)
199
+ else:
200
+ corrected_output_paddings.append(op)
201
+
202
+ torch_output_paddings = corrected_output_paddings
203
+
187
204
  return torch_paddings, torch_output_paddings
188
205
 
189
206
 
@@ -523,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
523
540
  -1 - axis
524
541
  )
525
542
  return x[tuple(slices)]
543
+
544
+
545
+ def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
546
+ """Compute small and big window sizes for adaptive pooling."""
547
+ small = math.ceil(input_dim / output_dim)
548
+ big = small + 1
549
+ return small, big