keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/layers/core/reversible_embedding.py ADDED
@@ -0,0 +1,390 @@
+ import copy
+
+ from keras.src import dtype_policies
+ from keras.src import layers
+ from keras.src import ops
+ from keras.src import quantizers
+ from keras.src.api_export import keras_export
+ from keras.src.backend import KerasTensor
+ from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+ @keras_export("keras.layers.ReversibleEmbedding")
+ class ReversibleEmbedding(layers.Embedding):
+     """An embedding layer which can project backwards to the input dim.
+
+     This layer is an extension of `keras.layers.Embedding` for language models.
+     This layer can be called "in reverse" with `reverse=True`, in which case the
+     layer will linearly project from `output_dim` back to `input_dim`.
+
+     By default, the reverse projection will use the transpose of the
+     `embeddings` weights to project to `input_dim` (weights are "tied"). If
+     `tie_weights=False`, the model will use a separate, trainable variable for
+     reverse projection.
+
+     This layer has no bias terms.
+
+     Args:
+         input_dim: Integer. Size of the vocabulary,
+             i.e. maximum integer index + 1.
+         output_dim: Integer. Dimension of the dense embedding.
+         tie_weights: Boolean, whether or not the matrix for embedding and
+             the matrix for the `reverse` projection should share the same
+             weights.
+         embeddings_initializer: Initializer for the `embeddings`
+             matrix (see `keras.initializers`).
+         embeddings_regularizer: Regularizer function applied to
+             the `embeddings` matrix (see `keras.regularizers`).
+         embeddings_constraint: Constraint function applied to
+             the `embeddings` matrix (see `keras.constraints`).
+         mask_zero: Boolean, whether or not the input value 0 is a special
+             "padding" value that should be masked out.
+         reverse_dtype: The dtype for the reverse projection computation.
+             Defaults to the `compute_dtype` of the layer.
+         logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
+             output logits will be scaled by
+             `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
+             range of output logits and can improve training.
+         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
+             including `name`, `trainable`, `dtype` etc.
+
+     Call arguments:
+         inputs: The tensor inputs to the layer.
+         reverse: Boolean. If `True` the layer will perform a linear projection
+             from `output_dim` to `input_dim`, instead of a normal embedding
+             call. Default to `False`.
+
+     Example:
+     ```python
+     batch_size = 16
+     vocab_size = 100
+     hidden_dim = 32
+     seq_length = 50
+
+     # Generate random inputs.
+     token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
+
+     embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
+     # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
+     hidden_states = embedding(token_ids)
+     # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
+     logits = embedding(hidden_states, reverse=True)
+     ```
+
+     References:
+     - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+     - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         output_dim,
+         tie_weights=True,
+         embeddings_initializer="uniform",
+         embeddings_regularizer=None,
+         embeddings_constraint=None,
+         mask_zero=False,
+         reverse_dtype=None,
+         logit_soft_cap=None,
+         **kwargs,
+     ):
+         super().__init__(
+             input_dim,
+             output_dim,
+             embeddings_initializer=embeddings_initializer,
+             embeddings_regularizer=embeddings_regularizer,
+             embeddings_constraint=embeddings_constraint,
+             mask_zero=mask_zero,
+             **kwargs,
+         )
+         self.tie_weights = tie_weights
+         self.reverse_dtype = reverse_dtype
+         self.logit_soft_cap = logit_soft_cap
+
+     def build(self, inputs_shape=None):
+         super().build(inputs_shape)
+         if not self.tie_weights and self.quantization_mode not in (
+             "int8",
+             "int4",
+         ):
+             self.reverse_embeddings = self.add_weight(
+                 shape=(self.output_dim, self.input_dim),
+                 initializer=self.embeddings_initializer,
+                 name="reverse_embeddings",
+                 trainable=True,
+             )
+
+     def call(self, inputs, reverse=False):
+         if not reverse:
+             return super().call(inputs)
+         else:
+             if self.tie_weights:
+                 kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
+             else:
+                 kernel = self.reverse_embeddings
+             if self.reverse_dtype is not None:
+                 inputs = ops.cast(inputs, self.reverse_dtype)
+                 kernel = ops.cast(kernel, self.reverse_dtype)
+             logits = ops.matmul(inputs, kernel)
+             # Optionally soft-cap logits.
+             if self.logit_soft_cap is not None:
+                 soft_cap = self.logit_soft_cap
+                 logits = ops.multiply(
+                     ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                 )
+             return logits
+
+     def compute_output_shape(self, input_shape, reverse=False):
+         output_shape = list(input_shape)
+         if reverse:
+             output_shape[-1] = self.input_dim
+         else:
+             output_shape += [self.output_dim]
+         return output_shape
+
+     def compute_output_spec(self, inputs, reverse=False):
+         output_shape = list(inputs.shape)
+         if reverse:
+             output_shape[-1] = self.input_dim
+         else:
+             output_shape += [self.output_dim]
+         return KerasTensor(output_shape, dtype=self.compute_dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "tie_weights": self.tie_weights,
+                 "reverse_dtype": self.reverse_dtype,
+                 "logit_soft_cap": self.logit_soft_cap,
+             }
+         )
+         return config
+
+     @property
+     def variable_serialization_spec(self):
+         # Avoid modifying the parent's spec.
+         _spec = copy.deepcopy(super().variable_serialization_spec)
+         if not self.tie_weights:
+             for mode, variable_spec in _spec.items():
+                 variable_spec.append("reverse_embeddings")
+                 if mode in ("int4", "int8"):
+                     variable_spec.append("reverse_embeddings_scale")
+         return _spec
+
+     def quantized_build(self, embeddings_shape, mode, config=None):
+         if mode == "int8":
+             self._int8_build(embeddings_shape, config)
+         elif mode == "int4":
+             self._int4_build(embeddings_shape, config)
+         else:
+             raise self._quantization_mode_error(mode)
+         self._is_quantized = True
+
+     def _int8_build(self, embeddings_shape, config=None):
+         if embeddings_shape is None:
+             embeddings_shape = (self.input_dim, self.output_dim)
+         super()._int8_build(embeddings_shape=embeddings_shape)
+
+         self.inputs_quantizer = (
+             QuantizationConfig.activation_quantizer_or_default(
+                 config, quantizers.AbsMaxQuantizer(axis=-1)
+             )
+         )
+         if not self.tie_weights:
+             self.reverse_embeddings = self.add_weight(
+                 name="reverse_embeddings",
+                 shape=(self.output_dim, self.input_dim),
+                 initializer="zeros",
+                 dtype="int8",
+                 trainable=False,
+             )
+             self.reverse_embeddings_scale = self.add_weight(
+                 name="reverse_embeddings_scale",
+                 shape=(self.input_dim,),
+                 initializer="ones",
+                 trainable=False,
+             )
+
+     def _int4_build(self, embeddings_shape, config=None):
+         if embeddings_shape is None:
+             embeddings_shape = (self.input_dim, self.output_dim)
+         super()._int4_build(embeddings_shape=embeddings_shape, config=config)
+
+         self.inputs_quantizer = (
+             QuantizationConfig.activation_quantizer_or_default(
+                 config, quantizers.AbsMaxQuantizer(axis=-1)
+             )
+         )
+         if not self.tie_weights:
+             packed_rows = (self.output_dim + 1) // 2  # ceil for odd dims
+             self.reverse_embeddings = self.add_weight(
+                 name="reverse_embeddings",
+                 shape=(packed_rows, self.input_dim),
+                 initializer="zeros",
+                 dtype="int8",
+                 trainable=False,
+             )
+             self.reverse_embeddings_scale = self.add_weight(
+                 name="reverse_embeddings_scale",
+                 shape=(self.input_dim,),
+                 initializer="ones",
+                 trainable=False,
+             )
+
+     def _int8_call(self, inputs, reverse=False):
+         if not reverse:
+             return super()._int8_call(inputs)
+         else:
+             if self.tie_weights:
+                 kernel = ops.transpose(self._embeddings)
+                 scale = ops.transpose(self.embeddings_scale)
+             else:
+                 kernel = self.reverse_embeddings
+                 scale = self.reverse_embeddings_scale
+             if self.inputs_quantizer:
+                 inputs, inputs_scale = self.inputs_quantizer(inputs)
+             else:
+                 inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+             logits = ops.matmul(inputs, kernel)
+             # De-scale outputs
+             logits = ops.cast(logits, self.compute_dtype)
+             logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+             # Optionally soft-cap logits.
+             if self.logit_soft_cap is not None:
+                 soft_cap = self.logit_soft_cap
+                 logits = ops.multiply(
+                     ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                 )
+             return logits
+
+     def _int4_call(self, inputs, reverse=False):
+         if not reverse:
+             return super()._int4_call(inputs)
+         else:
+             if self.tie_weights:
+                 embeddings = ops.transpose(self._embeddings)
+                 scale = ops.transpose(self.embeddings_scale)
+             else:
+                 embeddings = self.reverse_embeddings
+                 scale = self.reverse_embeddings_scale
+             unpacked_embeddings = quantizers.unpack_int4(
+                 embeddings, self.output_dim, axis=0
+             )
+             if self.inputs_quantizer:
+                 inputs, inputs_scale = self.inputs_quantizer(inputs)
+             else:
+                 inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+             logits = ops.matmul(inputs, unpacked_embeddings)
+             # De-scale outputs
+             logits = ops.cast(logits, self.compute_dtype)
+             logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+             # Optionally soft-cap logits.
+             if self.logit_soft_cap is not None:
+                 soft_cap = self.logit_soft_cap
+                 logits = ops.multiply(
+                     ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                 )
+             return logits
+
+     def quantize(self, mode=None, type_check=True, config=None):
+         if type_check and type(self) is not ReversibleEmbedding:
+             raise self._not_implemented_error(self.quantize)
+
+         self.quantization_config = config
+
+         embeddings_shape = (self.input_dim, self.output_dim)
+         if mode == "int8":
+             # Quantize `self._embeddings` to int8 and compute corresponding
+             # scale.
+             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                 self.quantization_config, quantizers.AbsMaxQuantizer(axis=-1)
+             )
+             embeddings_value, embeddings_scale = weight_quantizer(
+                 self._embeddings, to_numpy=True
+             )
+             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+             del self._embeddings
+             if not self.tie_weights:
+                 reverse_weight_quantizer = (
+                     QuantizationConfig.weight_quantizer_or_default(
+                         self.quantization_config,
+                         quantizers.AbsMaxQuantizer(axis=0),
+                     )
+                 )
+                 reverse_embeddings_value, reverse_embeddings_scale = (
+                     reverse_weight_quantizer(
+                         self.reverse_embeddings, to_numpy=True
+                     )
+                 )
+                 reverse_embeddings_scale = ops.squeeze(
+                     reverse_embeddings_scale, axis=0
+                 )
+                 del self.reverse_embeddings
+             self.quantized_build(
+                 embeddings_shape, mode, self.quantization_config
+             )
+             self._embeddings.assign(embeddings_value)
+             self.embeddings_scale.assign(embeddings_scale)
+             if not self.tie_weights:
+                 self.reverse_embeddings.assign(reverse_embeddings_value)
+                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+         elif mode == "int4":
+             # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
+             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                 self.quantization_config,
+                 quantizers.AbsMaxQuantizer(
+                     axis=-1,
+                     value_range=(-8, 7),
+                     output_dtype="int8",
+                 ),
+             )
+             embeddings_value, embeddings_scale = weight_quantizer(
+                 self._embeddings, to_numpy=True
+             )
+             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+             # 2. Pack two int4 values into a single int8 byte.
+             packed_embeddings_value, _, _ = quantizers.pack_int4(
+                 embeddings_value, axis=-1
+             )
+             del self._embeddings
+             if not self.tie_weights:
+                 reverse_weight_quantizer = (
+                     QuantizationConfig.weight_quantizer_or_default(
+                         self.quantization_config,
+                         quantizers.AbsMaxQuantizer(
+                             axis=0,
+                             value_range=(-8, 7),
+                             output_dtype="int8",
+                         ),
+                     )
+                 )
+                 reverse_embeddings_value, reverse_embeddings_scale = (
+                     reverse_weight_quantizer(
+                         self.reverse_embeddings, to_numpy=True
+                     )
+                 )
+                 reverse_embeddings_scale = ops.squeeze(
+                     reverse_embeddings_scale, axis=0
+                 )
+                 # Pack two int4 values into a single int8 byte.
+                 packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
+                     reverse_embeddings_value, axis=0
+                 )
+                 del self.reverse_embeddings
+             self.quantized_build(
+                 embeddings_shape, mode, self.quantization_config
+             )
+             self._embeddings.assign(packed_embeddings_value)
+             self.embeddings_scale.assign(embeddings_scale)
+             if not self.tie_weights:
+                 self.reverse_embeddings.assign(packed_reverse_embeddings_value)
+                 self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+         else:
+             raise self._quantization_mode_error(mode)
+
+         # Set new dtype policy.
+         if self.dtype_policy.quantization_mode is None:
+             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+             self.dtype_policy = policy
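For orientation, a minimal usage sketch of the new layer, based on the docstring in the hunk above (not part of the diff; it assumes a nightly build that ships `keras.layers.ReversibleEmbedding`):

```python
import numpy as np
import keras

vocab_size, hidden_dim = 100, 32
token_ids = np.random.randint(vocab_size, size=(16, 50))

embedding = keras.layers.ReversibleEmbedding(
    vocab_size,
    hidden_dim,
    tie_weights=True,     # reverse projection reuses the embedding matrix
    logit_soft_cap=30.0,  # logits squashed via tanh(logits / cap) * cap
)
hidden_states = embedding(token_ids)             # (16, 50, 32)
logits = embedding(hidden_states, reverse=True)  # (16, 50, 100)
```

With `tie_weights=True` (the default), the reverse projection adds no extra parameters; `tie_weights=False` creates a separate trainable `reverse_embeddings` weight, as shown in the `build` method above.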
keras/src/layers/input_spec.py CHANGED
@@ -111,6 +111,7 @@ class InputSpec:
              "max_ndim": self.max_ndim,
              "min_ndim": self.min_ndim,
              "axes": self.axes,
+             "optional": self.optional,
          }
  
      @classmethod
@@ -184,24 +185,24 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
          if spec.ndim is not None and not spec.allow_last_axis_squeeze:
              if ndim != spec.ndim:
                  raise ValueError(
-                     f'Input {input_index} of layer "{layer_name}" '
-                     "is incompatible with the layer: "
+                     f"Input {input_index} with name '{spec.name}' of layer "
+                     f"'{layer_name}' is incompatible with the layer: "
                      f"expected ndim={spec.ndim}, found ndim={ndim}. "
                      f"Full shape received: {shape}"
                  )
          if spec.max_ndim is not None:
              if ndim is not None and ndim > spec.max_ndim:
                  raise ValueError(
-                     f'Input {input_index} of layer "{layer_name}" '
-                     "is incompatible with the layer: "
+                     f"Input {input_index} with name '{spec.name}' of layer "
+                     f"'{layer_name}' is incompatible with the layer: "
                      f"expected max_ndim={spec.max_ndim}, "
                      f"found ndim={ndim}"
                  )
          if spec.min_ndim is not None:
              if ndim is not None and ndim < spec.min_ndim:
                  raise ValueError(
-                     f'Input {input_index} of layer "{layer_name}" '
-                     "is incompatible with the layer: "
+                     f"Input {input_index} with name '{spec.name}' of layer "
+                     f"'{layer_name}' is incompatible with the layer: "
                      f"expected min_ndim={spec.min_ndim}, "
                      f"found ndim={ndim}. "
                      f"Full shape received: {shape}"
@@ -211,8 +212,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
              dtype = backend.standardize_dtype(x.dtype)
              if dtype != spec.dtype:
                  raise ValueError(
-                     f'Input {input_index} of layer "{layer_name}" '
-                     "is incompatible with the layer: "
+                     f"Input {input_index} with name '{spec.name}' of layer "
+                     f"'{layer_name}' is incompatible with the layer: "
                      f"expected dtype={spec.dtype}, "
                      f"found dtype={dtype}"
                  )
@@ -225,11 +226,10 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                      None,
                  }:
                      raise ValueError(
-                         f'Input {input_index} of layer "{layer_name}" is '
-                         f"incompatible with the layer: expected axis {axis} "
-                         f"of input shape to have value {value}, "
-                         "but received input with "
-                         f"shape {shape}"
+                         f"Input {input_index} with name '{spec.name}' of layer "
+                         f"'{layer_name}' is incompatible with the layer: "
+                         f"expected axis {axis} of input shape to have value "
+                         f"{value}, but received input with shape {shape}"
                      )
          # Check shape.
          if spec.shape is not None:
@@ -243,8 +243,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                  if spec_dim is not None and dim is not None:
                      if spec_dim != dim:
                          raise ValueError(
-                             f'Input {input_index} of layer "{layer_name}" is '
-                             "incompatible with the layer: "
-                             f"expected shape={spec.shape}, "
-                             f"found shape={shape}"
+                             f"Input {input_index} with name '{spec.name}' of "
+                             f"layer '{layer_name}' is incompatible with the "
+                             f"layer: expected shape={spec.shape}, found "
+                             f"shape={shape}"
                          )
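The `input_spec.py` hunks serialize the new `optional` flag and include `spec.name` in compatibility errors. A small round-trip sketch of the config change (hypothetical usage; `from_config` is assumed from the `@classmethod` visible in the first hunk):

```python
from keras.layers import InputSpec

spec = InputSpec(ndim=3, optional=True)
config = spec.get_config()
assert config["optional"] is True  # now part of the serialized spec

restored = InputSpec.from_config(config)  # assumed round-trip constructor
assert restored.optional
```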
keras/src/layers/layer.py CHANGED
@@ -45,6 +45,7 @@ from keras.src.layers import input_spec
  from keras.src.metrics.metric import Metric
  from keras.src.ops.node import Node
  from keras.src.ops.operation import Operation
+ from keras.src.quantizers.quantization_config import validate_and_resolve_config
  from keras.src.utils import python_utils
  from keras.src.utils import summary_utils
  from keras.src.utils import traceback_utils
@@ -244,11 +245,13 @@ class Layer(BackendLayer, Operation):
          original_quantize_method = obj.quantize
  
          @wraps(original_quantize_method)
-         def quantize_wrapper(mode, **kwargs):
+         def quantize_wrapper(mode=None, config=None, **kwargs):
+             config = validate_and_resolve_config(mode, config)
+             mode = config.mode
              obj._check_quantize_args(mode, obj.compute_dtype)
              obj._tracker.unlock()
              try:
-                 original_quantize_method(mode, **kwargs)
+                 original_quantize_method(mode=mode, config=config, **kwargs)
              except Exception:
                  raise
              finally:
@@ -757,6 +760,15 @@ class Layer(BackendLayer, Operation):
          self._dtype_policy = policy
          if policy.quantization_mode is not None:
              if self.built and not getattr(self, "_is_quantized", False):
+                 if policy.quantization_mode == "gptq":
+                     raise ValueError(
+                         "Implicitly enabling GPTQ quantization by setting "
+                         f"`dtype_policy` to '{value}' is not supported. "
+                         "GPTQ requires a calibration dataset and a "
+                         "`GPTQConfig` object.\n\n"
+                         "Please use the `.quantize('gptq', config=...)` method "
+                         "on the layer or model instead."
+                     )
                  self.quantize(policy.quantization_mode)
  
      @property
@@ -824,9 +836,14 @@ class Layer(BackendLayer, Operation):
          #############################################################
          # 1. Convert any array arguments to tensors of correct dtype.
          def maybe_convert(x):
-             return self.dtype_policy.convert_input(
+             # Prevent _keras_mask from disappearing
+             mask = backend.get_keras_mask(x)
+             y = self.dtype_policy.convert_input(
                  x, self.autocast, self.input_dtype
              )
+             if mask is not None:
+                 backend.set_keras_mask(y, mask)
+             return y
  
          # Used to avoid expensive `tree` operations in the most common case.
          if (
@@ -1268,7 +1285,7 @@ class Layer(BackendLayer, Operation):
      def quantized_build(self, input_shape, mode):
          raise self._not_implemented_error(self.quantized_build)
  
-     def quantize(self, mode, type_check=True, config=None):
+     def quantize(self, mode=None, type_check=True, config=None):
          raise self._not_implemented_error(self.quantize)
  
      def _check_quantize_args(self, mode, compute_dtype):
@@ -1368,15 +1385,7 @@ class Layer(BackendLayer, Operation):
          for i, v in enumerate(all_vars):
              store[f"{i}"] = v
  
-     def load_own_variables(self, store):
-         """Loads the state of the layer.
-
-         You can override this method to take full control of how the state of
-         the layer is loaded upon calling `keras.models.load_model()`.
-
-         Args:
-             store: Dict from which the state of the model will be loaded.
-         """
+     def _check_load_own_variables(self, store):
          all_vars = self._trainable_variables + self._non_trainable_variables
          if len(store.keys()) != len(all_vars):
              if len(all_vars) == 0 and not self.built:
@@ -1409,6 +1418,18 @@ class Layer(BackendLayer, Operation):
                  f"{len(store.keys())} variables during loading. "
                  f"Expected: {[v.name for v in all_vars]}"
              )
+
+     def load_own_variables(self, store):
+         """Loads the state of the layer.
+
+         You can override this method to take full control of how the state of
+         the layer is loaded upon calling `keras.models.load_model()`.
+
+         Args:
+             store: Dict from which the state of the model will be loaded.
+         """
+         self._check_load_own_variables(store)
+         all_vars = self._trainable_variables + self._non_trainable_variables
          for i, v in enumerate(all_vars):
              v.assign(store[f"{i}"])
  
@@ -1889,6 +1910,10 @@ def get_shapes_dict(call_spec):
      {"input_a_shape": (2, 3)}
      ```
      """
+
+     def standardize_shape_or_none(x):
+         return None if x is None else backend.standardize_shape(x.shape)
+
      shapes_dict = {}
      for k, v in call_spec.tensor_arguments_dict.items():
          if k == "mask" or k.endswith("_mask"):
@@ -1899,10 +1924,10 @@ def get_shapes_dict(call_spec):
              continue
          if k in call_spec.nested_tensor_argument_names:
              shapes_dict[f"{k}_shape"] = tree.map_structure(
-                 lambda x: backend.standardize_shape(x.shape), v
+                 standardize_shape_or_none, v
              )
          else:
-             shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
+             shapes_dict[f"{k}_shape"] = standardize_shape_or_none(v)
      return shapes_dict
  
  
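The `layer.py` hunks make `mode` optional on `quantize`, resolve it through `validate_and_resolve_config`, and reject implicit GPTQ activation via `dtype_policy`. A minimal sketch of the mode-based entry point that keeps working (the `QuantizationConfig` constructor is not shown in this diff, so only the string mode is exercised):

```python
import numpy as np
import keras

# Post-training int8 quantization of a single built layer.
layer = keras.layers.Dense(8)
layer.build((None, 4))
layer.quantize("int8")  # resolved to a config object by the new wrapper

y = layer(np.random.rand(2, 4).astype("float32"))

# Per the new check, GPTQ cannot be enabled by assigning a dtype policy;
# it must go through `.quantize("gptq", config=...)` with a GPTQ config.
```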
keras/src/layers/merging/dot.py CHANGED
@@ -41,6 +41,7 @@ def batch_dot(x, y, axes=None):
          axes: Tuple or list of integers with target dimensions, or single
              integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`
              should be equal.
+             Note that axis `0` (the batch axis) cannot be included.
  
      Returns:
          A tensor with shape equal to the concatenation of `x`'s shape
@@ -226,7 +227,8 @@ class Dot(Merge):
              take the dot product. If a tuple, should be two integers
              corresponding to the desired axis from the first input and the
              desired axis from the second input, respectively. Note that the
-             size of the two selected axes must match.
+             size of the two selected axes must match, and that
+             axis `0` (the batch axis) cannot be included.
          normalize: Whether to L2-normalize samples along the dot product axis
              before taking the dot product. If set to `True`, then
              the output of the dot product is the cosine proximity
@@ -363,6 +365,7 @@ def dot(inputs, axes=-1, **kwargs):
          inputs: A list of input tensors (at least 2).
          axes: Integer or tuple of integers,
              axis or axes along which to take the dot product.
+             Note that axis `0` (the batch axis) cannot be included.
          normalize: Whether to L2-normalize samples along the
              dot product axis before taking the dot product.
              If set to `True`, then the output of the dot product
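The `dot.py` hunks are docstring-only: they spell out that `axes` refers to non-batch axes, so axis `0` cannot be used. A short illustration with the existing `keras.layers.Dot` API:

```python
import numpy as np
import keras

x1 = np.random.rand(4, 5)
x2 = np.random.rand(4, 5)

# Dot over the feature axis (axis 1); axis 0 is the batch axis and
# cannot be passed via `axes`.
out = keras.layers.Dot(axes=1)([x1, x2])
print(out.shape)  # (4, 1)
```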
keras/src/layers/pooling/adaptive_average_pooling1d.py ADDED
@@ -0,0 +1,65 @@
+ """Adaptive Average Pooling 1D layer."""
+
+ from keras.src.api_export import keras_export
+ from keras.src.layers.pooling.base_adaptive_pooling import (
+     BaseAdaptiveAveragePooling,
+ )
+
+
+ @keras_export("keras.layers.AdaptiveAveragePooling1D")
+ class AdaptiveAveragePooling1D(BaseAdaptiveAveragePooling):
+     """Adaptive average pooling operation for 1D temporal or spatial data.
+
+     This layer applies an adaptive average pooling operation, which pools the
+     input such that the output has a target length specified by `output_size`,
+     regardless of the input length. The kernel size and stride are automatically
+     computed to achieve the target output size.
+
+     Args:
+         output_size: Integer specifying the target output length.
+         data_format: string, either `"channels_last"` or `"channels_first"`.
+             `"channels_last"` corresponds to inputs with shape
+             `(batch, length, channels)`.
+             `"channels_first"` corresponds to inputs with shape
+             `(batch, channels, length)`.
+             Defaults to the value found in your Keras config file at
+             `~/.keras/keras.json`. If never set, `"channels_last"` is used.
+
+     Input shape:
+         - If `data_format="channels_last"`: 3D tensor
+           `(batch_size, length, channels)`
+         - If `data_format="channels_first"`: 3D tensor
+           `(batch_size, channels, length)`
+
+     Output shape:
+         - If `data_format="channels_last"`:
+           `(batch_size, output_length, channels)`
+         - If `data_format="channels_first"`:
+           `(batch_size, channels, output_length)`
+
+     Examples:
+     >>> import numpy as np
+     >>> input_seq = np.random.rand(1, 64, 3)
+     >>> layer = AdaptiveAveragePooling1D(output_size=32)
+     >>> output_seq = layer(input_seq)
+     >>> output_seq.shape
+     (1, 32, 3)
+     """
+
+     def __init__(self, output_size, data_format=None, **kwargs):
+         if isinstance(output_size, int):
+             output_size = (output_size,)
+         elif isinstance(output_size, (tuple, list)):
+             if len(output_size) != 1:
+                 raise ValueError(
+                     f"For 1D input, `output_size` tuple must have length 1. "
+                     f"Received: {output_size}"
+                 )
+             output_size = tuple(output_size)
+         else:
+             raise TypeError(
+                 f"`output_size` must be an integer or tuple of 1 integer. "
+                 f"Received: {output_size} of type {type(output_size)}"
+             )
+
+         super().__init__(output_size, data_format, **kwargs)
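A minimal sketch of the new layer with `data_format="channels_first"`, mirroring the docstring example above (assumes a nightly build that ships `keras.layers.AdaptiveAveragePooling1D`):

```python
import numpy as np
import keras

x = np.random.rand(2, 3, 64)  # (batch, channels, length)
layer = keras.layers.AdaptiveAveragePooling1D(
    output_size=16, data_format="channels_first"
)
y = layer(x)
print(y.shape)  # (2, 3, 16)
```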