keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/__init__.py CHANGED
@@ -12,6 +12,7 @@ from keras import callbacks as callbacks
12
12
  from keras import config as config
13
13
  from keras import constraints as constraints
14
14
  from keras import datasets as datasets
15
+ from keras import distillation as distillation
15
16
  from keras import distribution as distribution
16
17
  from keras import dtype_policies as dtype_policies
17
18
  from keras import export as export
@@ -10,6 +10,7 @@ from keras import callbacks as callbacks
10
10
  from keras import config as config
11
11
  from keras import constraints as constraints
12
12
  from keras import datasets as datasets
13
+ from keras import distillation as distillation
13
14
  from keras import distribution as distribution
14
15
  from keras import dtype_policies as dtype_policies
15
16
  from keras import export as export
@@ -19,6 +19,9 @@ from keras.src.callbacks.learning_rate_scheduler import (
19
19
  from keras.src.callbacks.model_checkpoint import (
20
20
  ModelCheckpoint as ModelCheckpoint,
21
21
  )
22
+ from keras.src.callbacks.orbax_checkpoint import (
23
+ OrbaxCheckpoint as OrbaxCheckpoint,
24
+ )
22
25
  from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
23
26
  from keras.src.callbacks.reduce_lr_on_plateau import (
24
27
  ReduceLROnPlateau as ReduceLROnPlateau,
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras.src.distillation.distillation_loss import (
8
+ DistillationLoss as DistillationLoss,
9
+ )
10
+ from keras.src.distillation.distillation_loss import (
11
+ FeatureDistillation as FeatureDistillation,
12
+ )
13
+ from keras.src.distillation.distillation_loss import (
14
+ LogitsDistillation as LogitsDistillation,
15
+ )
16
+ from keras.src.distillation.distiller import Distiller as Distiller
@@ -15,6 +15,9 @@ from keras.src.distribution.distribution_lib import (
15
15
  distribute_tensor as distribute_tensor,
16
16
  )
17
17
  from keras.src.distribution.distribution_lib import distribution as distribution
18
+ from keras.src.distribution.distribution_lib import (
19
+ get_device_count as get_device_count,
20
+ )
18
21
  from keras.src.distribution.distribution_lib import initialize as initialize
19
22
  from keras.src.distribution.distribution_lib import list_devices as list_devices
20
23
  from keras.src.distribution.distribution_lib import (
@@ -7,10 +7,16 @@ since your modifications would be overwritten.
7
7
  from keras.src.dtype_policies import deserialize as deserialize
8
8
  from keras.src.dtype_policies import get as get
9
9
  from keras.src.dtype_policies import serialize as serialize
10
+ from keras.src.dtype_policies.dtype_policy import (
11
+ AWQDTypePolicy as AWQDTypePolicy,
12
+ )
10
13
  from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
14
  from keras.src.dtype_policies.dtype_policy import (
12
15
  FloatDTypePolicy as FloatDTypePolicy,
13
16
  )
17
+ from keras.src.dtype_policies.dtype_policy import (
18
+ GPTQDTypePolicy as GPTQDTypePolicy,
19
+ )
14
20
  from keras.src.dtype_policies.dtype_policy import (
15
21
  QuantizedDTypePolicy as QuantizedDTypePolicy,
16
22
  )
@@ -73,6 +73,9 @@ from keras.src.layers.core.input_layer import Input as Input
73
73
  from keras.src.layers.core.input_layer import InputLayer as InputLayer
74
74
  from keras.src.layers.core.lambda_layer import Lambda as Lambda
75
75
  from keras.src.layers.core.masking import Masking as Masking
76
+ from keras.src.layers.core.reversible_embedding import (
77
+ ReversibleEmbedding as ReversibleEmbedding,
78
+ )
76
79
  from keras.src.layers.core.wrapper import Wrapper as Wrapper
77
80
  from keras.src.layers.input_spec import InputSpec as InputSpec
78
81
  from keras.src.layers.layer import Layer as Layer
@@ -110,6 +113,24 @@ from keras.src.layers.normalization.spectral_normalization import (
110
113
  from keras.src.layers.normalization.unit_normalization import (
111
114
  UnitNormalization as UnitNormalization,
112
115
  )
116
+ from keras.src.layers.pooling.adaptive_average_pooling1d import (
117
+ AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118
+ )
119
+ from keras.src.layers.pooling.adaptive_average_pooling2d import (
120
+ AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121
+ )
122
+ from keras.src.layers.pooling.adaptive_average_pooling3d import (
123
+ AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124
+ )
125
+ from keras.src.layers.pooling.adaptive_max_pooling1d import (
126
+ AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127
+ )
128
+ from keras.src.layers.pooling.adaptive_max_pooling2d import (
129
+ AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130
+ )
131
+ from keras.src.layers.pooling.adaptive_max_pooling3d import (
132
+ AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133
+ )
113
134
  from keras.src.layers.pooling.average_pooling1d import (
114
135
  AveragePooling1D as AveragePooling1D,
115
136
  )
@@ -37,6 +37,7 @@ from keras.src.ops.linalg import det as det
37
37
  from keras.src.ops.linalg import eig as eig
38
38
  from keras.src.ops.linalg import eigh as eigh
39
39
  from keras.src.ops.linalg import inv as inv
40
+ from keras.src.ops.linalg import jvp as jvp
40
41
  from keras.src.ops.linalg import lstsq as lstsq
41
42
  from keras.src.ops.linalg import lu_factor as lu_factor
42
43
  from keras.src.ops.linalg import norm as norm
@@ -63,6 +64,8 @@ from keras.src.ops.math import stft as stft
63
64
  from keras.src.ops.math import top_k as top_k
64
65
  from keras.src.ops.math import view_as_complex as view_as_complex
65
66
  from keras.src.ops.math import view_as_real as view_as_real
67
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
66
69
  from keras.src.ops.nn import average_pool as average_pool
67
70
  from keras.src.ops.nn import batch_normalization as batch_normalization
68
71
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -116,6 +119,7 @@ from keras.src.ops.nn import sparsemax as sparsemax
116
119
  from keras.src.ops.nn import squareplus as squareplus
117
120
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
118
121
  from keras.src.ops.nn import threshold as threshold
122
+ from keras.src.ops.nn import unfold as unfold
119
123
  from keras.src.ops.numpy import abs as abs
120
124
  from keras.src.ops.numpy import absolute as absolute
121
125
  from keras.src.ops.numpy import add as add
@@ -138,6 +142,7 @@ from keras.src.ops.numpy import argmin as argmin
138
142
  from keras.src.ops.numpy import argpartition as argpartition
139
143
  from keras.src.ops.numpy import argsort as argsort
140
144
  from keras.src.ops.numpy import array as array
145
+ from keras.src.ops.numpy import array_split as array_split
141
146
  from keras.src.ops.numpy import average as average
142
147
  from keras.src.ops.numpy import bartlett as bartlett
143
148
  from keras.src.ops.numpy import bincount as bincount
@@ -176,6 +181,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
176
181
  from keras.src.ops.numpy import dot as dot
177
182
  from keras.src.ops.numpy import einsum as einsum
178
183
  from keras.src.ops.numpy import empty as empty
184
+ from keras.src.ops.numpy import empty_like as empty_like
179
185
  from keras.src.ops.numpy import equal as equal
180
186
  from keras.src.ops.numpy import exp as exp
181
187
  from keras.src.ops.numpy import exp2 as exp2
@@ -207,7 +213,11 @@ from keras.src.ops.numpy import isinf as isinf
207
213
  from keras.src.ops.numpy import isnan as isnan
208
214
  from keras.src.ops.numpy import isneginf as isneginf
209
215
  from keras.src.ops.numpy import isposinf as isposinf
216
+ from keras.src.ops.numpy import isreal as isreal
210
217
  from keras.src.ops.numpy import kaiser as kaiser
218
+ from keras.src.ops.numpy import kron as kron
219
+ from keras.src.ops.numpy import lcm as lcm
220
+ from keras.src.ops.numpy import ldexp as ldexp
211
221
  from keras.src.ops.numpy import left_shift as left_shift
212
222
  from keras.src.ops.numpy import less as less
213
223
  from keras.src.ops.numpy import less_equal as less_equal
@@ -217,6 +227,7 @@ from keras.src.ops.numpy import log1p as log1p
217
227
  from keras.src.ops.numpy import log2 as log2
218
228
  from keras.src.ops.numpy import log10 as log10
219
229
  from keras.src.ops.numpy import logaddexp as logaddexp
230
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
220
231
  from keras.src.ops.numpy import logical_and as logical_and
221
232
  from keras.src.ops.numpy import logical_not as logical_not
222
233
  from keras.src.ops.numpy import logical_or as logical_or
@@ -236,6 +247,7 @@ from keras.src.ops.numpy import multiply as multiply
236
247
  from keras.src.ops.numpy import nan_to_num as nan_to_num
237
248
  from keras.src.ops.numpy import ndim as ndim
238
249
  from keras.src.ops.numpy import negative as negative
250
+ from keras.src.ops.numpy import nextafter as nextafter
239
251
  from keras.src.ops.numpy import nonzero as nonzero
240
252
  from keras.src.ops.numpy import not_equal as not_equal
241
253
  from keras.src.ops.numpy import ones as ones
@@ -244,6 +256,7 @@ from keras.src.ops.numpy import outer as outer
244
256
  from keras.src.ops.numpy import pad as pad
245
257
  from keras.src.ops.numpy import power as power
246
258
  from keras.src.ops.numpy import prod as prod
259
+ from keras.src.ops.numpy import ptp as ptp
247
260
  from keras.src.ops.numpy import quantile as quantile
248
261
  from keras.src.ops.numpy import ravel as ravel
249
262
  from keras.src.ops.numpy import real as real
@@ -280,15 +293,18 @@ from keras.src.ops.numpy import tensordot as tensordot
280
293
  from keras.src.ops.numpy import tile as tile
281
294
  from keras.src.ops.numpy import trace as trace
282
295
  from keras.src.ops.numpy import transpose as transpose
296
+ from keras.src.ops.numpy import trapezoid as trapezoid
283
297
  from keras.src.ops.numpy import tri as tri
284
298
  from keras.src.ops.numpy import tril as tril
285
299
  from keras.src.ops.numpy import triu as triu
286
300
  from keras.src.ops.numpy import true_divide as true_divide
287
301
  from keras.src.ops.numpy import trunc as trunc
288
302
  from keras.src.ops.numpy import unravel_index as unravel_index
303
+ from keras.src.ops.numpy import vander as vander
289
304
  from keras.src.ops.numpy import var as var
290
305
  from keras.src.ops.numpy import vdot as vdot
291
306
  from keras.src.ops.numpy import vectorize as vectorize
307
+ from keras.src.ops.numpy import view as view
292
308
  from keras.src.ops.numpy import vstack as vstack
293
309
  from keras.src.ops.numpy import where as where
294
310
  from keras.src.ops.numpy import zeros as zeros
@@ -8,6 +8,7 @@ from keras.src.ops.image import affine_transform as affine_transform
8
8
  from keras.src.ops.image import crop_images as crop_images
9
9
  from keras.src.ops.image import elastic_transform as elastic_transform
10
10
  from keras.src.ops.image import extract_patches as extract_patches
11
+ from keras.src.ops.image import extract_patches_3d as extract_patches_3d
11
12
  from keras.src.ops.image import gaussian_blur as gaussian_blur
12
13
  from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
13
14
  from keras.src.ops.image import map_coordinates as map_coordinates
@@ -10,6 +10,7 @@ from keras.src.ops.linalg import det as det
10
10
  from keras.src.ops.linalg import eig as eig
11
11
  from keras.src.ops.linalg import eigh as eigh
12
12
  from keras.src.ops.linalg import inv as inv
13
+ from keras.src.ops.linalg import jvp as jvp
13
14
  from keras.src.ops.linalg import lstsq as lstsq
14
15
  from keras.src.ops.linalg import lu_factor as lu_factor
15
16
  from keras.src.ops.linalg import norm as norm
@@ -4,6 +4,8 @@ This file was autogenerated. Do not edit it by hand,
4
4
  since your modifications would be overwritten.
5
5
  """
6
6
 
7
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
7
9
  from keras.src.ops.nn import average_pool as average_pool
8
10
  from keras.src.ops.nn import batch_normalization as batch_normalization
9
11
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -57,3 +59,4 @@ from keras.src.ops.nn import sparsemax as sparsemax
57
59
  from keras.src.ops.nn import squareplus as squareplus
58
60
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
59
61
  from keras.src.ops.nn import threshold as threshold
62
+ from keras.src.ops.nn import unfold as unfold
@@ -26,6 +26,7 @@ from keras.src.ops.numpy import argmin as argmin
26
26
  from keras.src.ops.numpy import argpartition as argpartition
27
27
  from keras.src.ops.numpy import argsort as argsort
28
28
  from keras.src.ops.numpy import array as array
29
+ from keras.src.ops.numpy import array_split as array_split
29
30
  from keras.src.ops.numpy import average as average
30
31
  from keras.src.ops.numpy import bartlett as bartlett
31
32
  from keras.src.ops.numpy import bincount as bincount
@@ -64,6 +65,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
64
65
  from keras.src.ops.numpy import dot as dot
65
66
  from keras.src.ops.numpy import einsum as einsum
66
67
  from keras.src.ops.numpy import empty as empty
68
+ from keras.src.ops.numpy import empty_like as empty_like
67
69
  from keras.src.ops.numpy import equal as equal
68
70
  from keras.src.ops.numpy import exp as exp
69
71
  from keras.src.ops.numpy import exp2 as exp2
@@ -95,7 +97,11 @@ from keras.src.ops.numpy import isinf as isinf
95
97
  from keras.src.ops.numpy import isnan as isnan
96
98
  from keras.src.ops.numpy import isneginf as isneginf
97
99
  from keras.src.ops.numpy import isposinf as isposinf
100
+ from keras.src.ops.numpy import isreal as isreal
98
101
  from keras.src.ops.numpy import kaiser as kaiser
102
+ from keras.src.ops.numpy import kron as kron
103
+ from keras.src.ops.numpy import lcm as lcm
104
+ from keras.src.ops.numpy import ldexp as ldexp
99
105
  from keras.src.ops.numpy import left_shift as left_shift
100
106
  from keras.src.ops.numpy import less as less
101
107
  from keras.src.ops.numpy import less_equal as less_equal
@@ -105,6 +111,7 @@ from keras.src.ops.numpy import log1p as log1p
105
111
  from keras.src.ops.numpy import log2 as log2
106
112
  from keras.src.ops.numpy import log10 as log10
107
113
  from keras.src.ops.numpy import logaddexp as logaddexp
114
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
108
115
  from keras.src.ops.numpy import logical_and as logical_and
109
116
  from keras.src.ops.numpy import logical_not as logical_not
110
117
  from keras.src.ops.numpy import logical_or as logical_or
@@ -124,6 +131,7 @@ from keras.src.ops.numpy import multiply as multiply
124
131
  from keras.src.ops.numpy import nan_to_num as nan_to_num
125
132
  from keras.src.ops.numpy import ndim as ndim
126
133
  from keras.src.ops.numpy import negative as negative
134
+ from keras.src.ops.numpy import nextafter as nextafter
127
135
  from keras.src.ops.numpy import nonzero as nonzero
128
136
  from keras.src.ops.numpy import not_equal as not_equal
129
137
  from keras.src.ops.numpy import ones as ones
@@ -132,6 +140,7 @@ from keras.src.ops.numpy import outer as outer
132
140
  from keras.src.ops.numpy import pad as pad
133
141
  from keras.src.ops.numpy import power as power
134
142
  from keras.src.ops.numpy import prod as prod
143
+ from keras.src.ops.numpy import ptp as ptp
135
144
  from keras.src.ops.numpy import quantile as quantile
136
145
  from keras.src.ops.numpy import ravel as ravel
137
146
  from keras.src.ops.numpy import real as real
@@ -168,15 +177,18 @@ from keras.src.ops.numpy import tensordot as tensordot
168
177
  from keras.src.ops.numpy import tile as tile
169
178
  from keras.src.ops.numpy import trace as trace
170
179
  from keras.src.ops.numpy import transpose as transpose
180
+ from keras.src.ops.numpy import trapezoid as trapezoid
171
181
  from keras.src.ops.numpy import tri as tri
172
182
  from keras.src.ops.numpy import tril as tril
173
183
  from keras.src.ops.numpy import triu as triu
174
184
  from keras.src.ops.numpy import true_divide as true_divide
175
185
  from keras.src.ops.numpy import trunc as trunc
176
186
  from keras.src.ops.numpy import unravel_index as unravel_index
187
+ from keras.src.ops.numpy import vander as vander
177
188
  from keras.src.ops.numpy import var as var
178
189
  from keras.src.ops.numpy import vdot as vdot
179
190
  from keras.src.ops.numpy import vectorize as vectorize
191
+ from keras.src.ops.numpy import view as view
180
192
  from keras.src.ops.numpy import vstack as vstack
181
193
  from keras.src.ops.numpy import where as where
182
194
  from keras.src.ops.numpy import zeros as zeros
@@ -7,7 +7,20 @@ since your modifications would be overwritten.
7
7
  from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
+ from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
10
11
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
12
+ from keras.src.quantizers.quantization_config import (
13
+ Float8QuantizationConfig as Float8QuantizationConfig,
14
+ )
15
+ from keras.src.quantizers.quantization_config import (
16
+ Int4QuantizationConfig as Int4QuantizationConfig,
17
+ )
18
+ from keras.src.quantizers.quantization_config import (
19
+ Int8QuantizationConfig as Int8QuantizationConfig,
20
+ )
21
+ from keras.src.quantizers.quantization_config import (
22
+ QuantizationConfig as QuantizationConfig,
23
+ )
11
24
  from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
12
25
  from keras.src.quantizers.quantizers import Quantizer as Quantizer
13
26
  from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
@@ -19,6 +19,9 @@ from keras.src.callbacks.learning_rate_scheduler import (
19
19
  from keras.src.callbacks.model_checkpoint import (
20
20
  ModelCheckpoint as ModelCheckpoint,
21
21
  )
22
+ from keras.src.callbacks.orbax_checkpoint import (
23
+ OrbaxCheckpoint as OrbaxCheckpoint,
24
+ )
22
25
  from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
23
26
  from keras.src.callbacks.reduce_lr_on_plateau import (
24
27
  ReduceLROnPlateau as ReduceLROnPlateau,
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras.src.distillation.distillation_loss import (
8
+ DistillationLoss as DistillationLoss,
9
+ )
10
+ from keras.src.distillation.distillation_loss import (
11
+ FeatureDistillation as FeatureDistillation,
12
+ )
13
+ from keras.src.distillation.distillation_loss import (
14
+ LogitsDistillation as LogitsDistillation,
15
+ )
16
+ from keras.src.distillation.distiller import Distiller as Distiller
@@ -15,6 +15,9 @@ from keras.src.distribution.distribution_lib import (
15
15
  distribute_tensor as distribute_tensor,
16
16
  )
17
17
  from keras.src.distribution.distribution_lib import distribution as distribution
18
+ from keras.src.distribution.distribution_lib import (
19
+ get_device_count as get_device_count,
20
+ )
18
21
  from keras.src.distribution.distribution_lib import initialize as initialize
19
22
  from keras.src.distribution.distribution_lib import list_devices as list_devices
20
23
  from keras.src.distribution.distribution_lib import (
@@ -7,10 +7,16 @@ since your modifications would be overwritten.
7
7
  from keras.src.dtype_policies import deserialize as deserialize
8
8
  from keras.src.dtype_policies import get as get
9
9
  from keras.src.dtype_policies import serialize as serialize
10
+ from keras.src.dtype_policies.dtype_policy import (
11
+ AWQDTypePolicy as AWQDTypePolicy,
12
+ )
10
13
  from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
14
  from keras.src.dtype_policies.dtype_policy import (
12
15
  FloatDTypePolicy as FloatDTypePolicy,
13
16
  )
17
+ from keras.src.dtype_policies.dtype_policy import (
18
+ GPTQDTypePolicy as GPTQDTypePolicy,
19
+ )
14
20
  from keras.src.dtype_policies.dtype_policy import (
15
21
  QuantizedDTypePolicy as QuantizedDTypePolicy,
16
22
  )
keras/layers/__init__.py CHANGED
@@ -73,6 +73,9 @@ from keras.src.layers.core.input_layer import Input as Input
73
73
  from keras.src.layers.core.input_layer import InputLayer as InputLayer
74
74
  from keras.src.layers.core.lambda_layer import Lambda as Lambda
75
75
  from keras.src.layers.core.masking import Masking as Masking
76
+ from keras.src.layers.core.reversible_embedding import (
77
+ ReversibleEmbedding as ReversibleEmbedding,
78
+ )
76
79
  from keras.src.layers.core.wrapper import Wrapper as Wrapper
77
80
  from keras.src.layers.input_spec import InputSpec as InputSpec
78
81
  from keras.src.layers.layer import Layer as Layer
@@ -110,6 +113,24 @@ from keras.src.layers.normalization.spectral_normalization import (
110
113
  from keras.src.layers.normalization.unit_normalization import (
111
114
  UnitNormalization as UnitNormalization,
112
115
  )
116
+ from keras.src.layers.pooling.adaptive_average_pooling1d import (
117
+ AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118
+ )
119
+ from keras.src.layers.pooling.adaptive_average_pooling2d import (
120
+ AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121
+ )
122
+ from keras.src.layers.pooling.adaptive_average_pooling3d import (
123
+ AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124
+ )
125
+ from keras.src.layers.pooling.adaptive_max_pooling1d import (
126
+ AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127
+ )
128
+ from keras.src.layers.pooling.adaptive_max_pooling2d import (
129
+ AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130
+ )
131
+ from keras.src.layers.pooling.adaptive_max_pooling3d import (
132
+ AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133
+ )
113
134
  from keras.src.layers.pooling.average_pooling1d import (
114
135
  AveragePooling1D as AveragePooling1D,
115
136
  )
keras/ops/__init__.py CHANGED
@@ -37,6 +37,7 @@ from keras.src.ops.linalg import det as det
37
37
  from keras.src.ops.linalg import eig as eig
38
38
  from keras.src.ops.linalg import eigh as eigh
39
39
  from keras.src.ops.linalg import inv as inv
40
+ from keras.src.ops.linalg import jvp as jvp
40
41
  from keras.src.ops.linalg import lstsq as lstsq
41
42
  from keras.src.ops.linalg import lu_factor as lu_factor
42
43
  from keras.src.ops.linalg import norm as norm
@@ -63,6 +64,8 @@ from keras.src.ops.math import stft as stft
63
64
  from keras.src.ops.math import top_k as top_k
64
65
  from keras.src.ops.math import view_as_complex as view_as_complex
65
66
  from keras.src.ops.math import view_as_real as view_as_real
67
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
66
69
  from keras.src.ops.nn import average_pool as average_pool
67
70
  from keras.src.ops.nn import batch_normalization as batch_normalization
68
71
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -116,6 +119,7 @@ from keras.src.ops.nn import sparsemax as sparsemax
116
119
  from keras.src.ops.nn import squareplus as squareplus
117
120
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
118
121
  from keras.src.ops.nn import threshold as threshold
122
+ from keras.src.ops.nn import unfold as unfold
119
123
  from keras.src.ops.numpy import abs as abs
120
124
  from keras.src.ops.numpy import absolute as absolute
121
125
  from keras.src.ops.numpy import add as add
@@ -138,6 +142,7 @@ from keras.src.ops.numpy import argmin as argmin
138
142
  from keras.src.ops.numpy import argpartition as argpartition
139
143
  from keras.src.ops.numpy import argsort as argsort
140
144
  from keras.src.ops.numpy import array as array
145
+ from keras.src.ops.numpy import array_split as array_split
141
146
  from keras.src.ops.numpy import average as average
142
147
  from keras.src.ops.numpy import bartlett as bartlett
143
148
  from keras.src.ops.numpy import bincount as bincount
@@ -176,6 +181,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
176
181
  from keras.src.ops.numpy import dot as dot
177
182
  from keras.src.ops.numpy import einsum as einsum
178
183
  from keras.src.ops.numpy import empty as empty
184
+ from keras.src.ops.numpy import empty_like as empty_like
179
185
  from keras.src.ops.numpy import equal as equal
180
186
  from keras.src.ops.numpy import exp as exp
181
187
  from keras.src.ops.numpy import exp2 as exp2
@@ -207,7 +213,11 @@ from keras.src.ops.numpy import isinf as isinf
207
213
  from keras.src.ops.numpy import isnan as isnan
208
214
  from keras.src.ops.numpy import isneginf as isneginf
209
215
  from keras.src.ops.numpy import isposinf as isposinf
216
+ from keras.src.ops.numpy import isreal as isreal
210
217
  from keras.src.ops.numpy import kaiser as kaiser
218
+ from keras.src.ops.numpy import kron as kron
219
+ from keras.src.ops.numpy import lcm as lcm
220
+ from keras.src.ops.numpy import ldexp as ldexp
211
221
  from keras.src.ops.numpy import left_shift as left_shift
212
222
  from keras.src.ops.numpy import less as less
213
223
  from keras.src.ops.numpy import less_equal as less_equal
@@ -217,6 +227,7 @@ from keras.src.ops.numpy import log1p as log1p
217
227
  from keras.src.ops.numpy import log2 as log2
218
228
  from keras.src.ops.numpy import log10 as log10
219
229
  from keras.src.ops.numpy import logaddexp as logaddexp
230
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
220
231
  from keras.src.ops.numpy import logical_and as logical_and
221
232
  from keras.src.ops.numpy import logical_not as logical_not
222
233
  from keras.src.ops.numpy import logical_or as logical_or
@@ -236,6 +247,7 @@ from keras.src.ops.numpy import multiply as multiply
236
247
  from keras.src.ops.numpy import nan_to_num as nan_to_num
237
248
  from keras.src.ops.numpy import ndim as ndim
238
249
  from keras.src.ops.numpy import negative as negative
250
+ from keras.src.ops.numpy import nextafter as nextafter
239
251
  from keras.src.ops.numpy import nonzero as nonzero
240
252
  from keras.src.ops.numpy import not_equal as not_equal
241
253
  from keras.src.ops.numpy import ones as ones
@@ -244,6 +256,7 @@ from keras.src.ops.numpy import outer as outer
244
256
  from keras.src.ops.numpy import pad as pad
245
257
  from keras.src.ops.numpy import power as power
246
258
  from keras.src.ops.numpy import prod as prod
259
+ from keras.src.ops.numpy import ptp as ptp
247
260
  from keras.src.ops.numpy import quantile as quantile
248
261
  from keras.src.ops.numpy import ravel as ravel
249
262
  from keras.src.ops.numpy import real as real
@@ -280,15 +293,18 @@ from keras.src.ops.numpy import tensordot as tensordot
280
293
  from keras.src.ops.numpy import tile as tile
281
294
  from keras.src.ops.numpy import trace as trace
282
295
  from keras.src.ops.numpy import transpose as transpose
296
+ from keras.src.ops.numpy import trapezoid as trapezoid
283
297
  from keras.src.ops.numpy import tri as tri
284
298
  from keras.src.ops.numpy import tril as tril
285
299
  from keras.src.ops.numpy import triu as triu
286
300
  from keras.src.ops.numpy import true_divide as true_divide
287
301
  from keras.src.ops.numpy import trunc as trunc
288
302
  from keras.src.ops.numpy import unravel_index as unravel_index
303
+ from keras.src.ops.numpy import vander as vander
289
304
  from keras.src.ops.numpy import var as var
290
305
  from keras.src.ops.numpy import vdot as vdot
291
306
  from keras.src.ops.numpy import vectorize as vectorize
307
+ from keras.src.ops.numpy import view as view
292
308
  from keras.src.ops.numpy import vstack as vstack
293
309
  from keras.src.ops.numpy import where as where
294
310
  from keras.src.ops.numpy import zeros as zeros
keras/ops/image/__init__.py CHANGED
@@ -8,6 +8,7 @@ from keras.src.ops.image import affine_transform as affine_transform
8
8
  from keras.src.ops.image import crop_images as crop_images
9
9
  from keras.src.ops.image import elastic_transform as elastic_transform
10
10
  from keras.src.ops.image import extract_patches as extract_patches
11
+ from keras.src.ops.image import extract_patches_3d as extract_patches_3d
11
12
  from keras.src.ops.image import gaussian_blur as gaussian_blur
12
13
  from keras.src.ops.image import hsv_to_rgb as hsv_to_rgb
13
14
  from keras.src.ops.image import map_coordinates as map_coordinates
keras/ops/linalg/__init__.py CHANGED
@@ -10,6 +10,7 @@ from keras.src.ops.linalg import det as det
10
10
  from keras.src.ops.linalg import eig as eig
11
11
  from keras.src.ops.linalg import eigh as eigh
12
12
  from keras.src.ops.linalg import inv as inv
13
+ from keras.src.ops.linalg import jvp as jvp
13
14
  from keras.src.ops.linalg import lstsq as lstsq
14
15
  from keras.src.ops.linalg import lu_factor as lu_factor
15
16
  from keras.src.ops.linalg import norm as norm
keras/ops/nn/__init__.py CHANGED
@@ -4,6 +4,8 @@ This file was autogenerated. Do not edit it by hand,
4
4
  since your modifications would be overwritten.
5
5
  """
6
6
 
7
+ from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8
+ from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
7
9
  from keras.src.ops.nn import average_pool as average_pool
8
10
  from keras.src.ops.nn import batch_normalization as batch_normalization
9
11
  from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
@@ -57,3 +59,4 @@ from keras.src.ops.nn import sparsemax as sparsemax
57
59
  from keras.src.ops.nn import squareplus as squareplus
58
60
  from keras.src.ops.nn import tanh_shrink as tanh_shrink
59
61
  from keras.src.ops.nn import threshold as threshold
62
+ from keras.src.ops.nn import unfold as unfold
keras/ops/numpy/__init__.py CHANGED
@@ -26,6 +26,7 @@ from keras.src.ops.numpy import argmin as argmin
26
26
  from keras.src.ops.numpy import argpartition as argpartition
27
27
  from keras.src.ops.numpy import argsort as argsort
28
28
  from keras.src.ops.numpy import array as array
29
+ from keras.src.ops.numpy import array_split as array_split
29
30
  from keras.src.ops.numpy import average as average
30
31
  from keras.src.ops.numpy import bartlett as bartlett
31
32
  from keras.src.ops.numpy import bincount as bincount
@@ -64,6 +65,7 @@ from keras.src.ops.numpy import divide_no_nan as divide_no_nan
64
65
  from keras.src.ops.numpy import dot as dot
65
66
  from keras.src.ops.numpy import einsum as einsum
66
67
  from keras.src.ops.numpy import empty as empty
68
+ from keras.src.ops.numpy import empty_like as empty_like
67
69
  from keras.src.ops.numpy import equal as equal
68
70
  from keras.src.ops.numpy import exp as exp
69
71
  from keras.src.ops.numpy import exp2 as exp2
@@ -95,7 +97,11 @@ from keras.src.ops.numpy import isinf as isinf
95
97
  from keras.src.ops.numpy import isnan as isnan
96
98
  from keras.src.ops.numpy import isneginf as isneginf
97
99
  from keras.src.ops.numpy import isposinf as isposinf
100
+ from keras.src.ops.numpy import isreal as isreal
98
101
  from keras.src.ops.numpy import kaiser as kaiser
102
+ from keras.src.ops.numpy import kron as kron
103
+ from keras.src.ops.numpy import lcm as lcm
104
+ from keras.src.ops.numpy import ldexp as ldexp
99
105
  from keras.src.ops.numpy import left_shift as left_shift
100
106
  from keras.src.ops.numpy import less as less
101
107
  from keras.src.ops.numpy import less_equal as less_equal
@@ -105,6 +111,7 @@ from keras.src.ops.numpy import log1p as log1p
105
111
  from keras.src.ops.numpy import log2 as log2
106
112
  from keras.src.ops.numpy import log10 as log10
107
113
  from keras.src.ops.numpy import logaddexp as logaddexp
114
+ from keras.src.ops.numpy import logaddexp2 as logaddexp2
108
115
  from keras.src.ops.numpy import logical_and as logical_and
109
116
  from keras.src.ops.numpy import logical_not as logical_not
110
117
  from keras.src.ops.numpy import logical_or as logical_or
@@ -124,6 +131,7 @@ from keras.src.ops.numpy import multiply as multiply
124
131
  from keras.src.ops.numpy import nan_to_num as nan_to_num
125
132
  from keras.src.ops.numpy import ndim as ndim
126
133
  from keras.src.ops.numpy import negative as negative
134
+ from keras.src.ops.numpy import nextafter as nextafter
127
135
  from keras.src.ops.numpy import nonzero as nonzero
128
136
  from keras.src.ops.numpy import not_equal as not_equal
129
137
  from keras.src.ops.numpy import ones as ones
@@ -132,6 +140,7 @@ from keras.src.ops.numpy import outer as outer
132
140
  from keras.src.ops.numpy import pad as pad
133
141
  from keras.src.ops.numpy import power as power
134
142
  from keras.src.ops.numpy import prod as prod
143
+ from keras.src.ops.numpy import ptp as ptp
135
144
  from keras.src.ops.numpy import quantile as quantile
136
145
  from keras.src.ops.numpy import ravel as ravel
137
146
  from keras.src.ops.numpy import real as real
@@ -168,15 +177,18 @@ from keras.src.ops.numpy import tensordot as tensordot
168
177
  from keras.src.ops.numpy import tile as tile
169
178
  from keras.src.ops.numpy import trace as trace
170
179
  from keras.src.ops.numpy import transpose as transpose
180
+ from keras.src.ops.numpy import trapezoid as trapezoid
171
181
  from keras.src.ops.numpy import tri as tri
172
182
  from keras.src.ops.numpy import tril as tril
173
183
  from keras.src.ops.numpy import triu as triu
174
184
  from keras.src.ops.numpy import true_divide as true_divide
175
185
  from keras.src.ops.numpy import trunc as trunc
176
186
  from keras.src.ops.numpy import unravel_index as unravel_index
187
+ from keras.src.ops.numpy import vander as vander
177
188
  from keras.src.ops.numpy import var as var
178
189
  from keras.src.ops.numpy import vdot as vdot
179
190
  from keras.src.ops.numpy import vectorize as vectorize
191
+ from keras.src.ops.numpy import view as view
180
192
  from keras.src.ops.numpy import vstack as vstack
181
193
  from keras.src.ops.numpy import where as where
182
194
  from keras.src.ops.numpy import zeros as zeros
keras/quantizers/__init__.py CHANGED
@@ -7,7 +7,20 @@ since your modifications would be overwritten.
7
7
  from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
+ from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
10
11
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
12
+ from keras.src.quantizers.quantization_config import (
13
+ Float8QuantizationConfig as Float8QuantizationConfig,
14
+ )
15
+ from keras.src.quantizers.quantization_config import (
16
+ Int4QuantizationConfig as Int4QuantizationConfig,
17
+ )
18
+ from keras.src.quantizers.quantization_config import (
19
+ Int8QuantizationConfig as Int8QuantizationConfig,
20
+ )
21
+ from keras.src.quantizers.quantization_config import (
22
+ QuantizationConfig as QuantizationConfig,
23
+ )
11
24
  from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
12
25
  from keras.src.quantizers.quantizers import Quantizer as Quantizer
13
26
  from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
keras/src/applications/imagenet_utils.py CHANGED
@@ -278,7 +278,10 @@ def _preprocess_tensor_input(x, data_format, mode):
278
278
 
279
279
  # Zero-center by mean pixel
280
280
  if data_format == "channels_first":
281
- mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
281
+ if len(x.shape) == 3:
282
+ mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
283
+ else:
284
+ mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
282
285
  else:
283
286
  mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))
284
287
  x += mean_tensor