keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/reversible_embedding.py ADDED
@@ -0,0 +1,399 @@
+import copy
+
+from keras.src import dtype_policies
+from keras.src import layers
+from keras.src import ops
+from keras.src import quantizers
+from keras.src.api_export import keras_export
+from keras.src.backend import KerasTensor
+from keras.src.backend import set_keras_mask
+from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+@keras_export("keras.layers.ReversibleEmbedding")
+class ReversibleEmbedding(layers.Embedding):
+    """An embedding layer which can project backwards to the input dim.
+
+    This layer is an extension of `keras.layers.Embedding` for language models.
+    This layer can be called "in reverse" with `reverse=True`, in which case the
+    layer will linearly project from `output_dim` back to `input_dim`.
+
+    By default, the reverse projection will use the transpose of the
+    `embeddings` weights to project to `input_dim` (weights are "tied"). If
+    `tie_weights=False`, the model will use a separate, trainable variable for
+    reverse projection.
+
+    This layer has no bias terms.
+
+    Args:
+        input_dim: Integer. Size of the vocabulary,
+            i.e. maximum integer index + 1.
+        output_dim: Integer. Dimension of the dense embedding.
+        tie_weights: Boolean, whether or not the matrix for embedding and
+            the matrix for the `reverse` projection should share the same
+            weights.
+        embeddings_initializer: Initializer for the `embeddings`
+            matrix (see `keras.initializers`).
+        embeddings_regularizer: Regularizer function applied to
+            the `embeddings` matrix (see `keras.regularizers`).
+        embeddings_constraint: Constraint function applied to
+            the `embeddings` matrix (see `keras.constraints`).
+        mask_zero: Boolean, whether or not the input value 0 is a special
+            "padding" value that should be masked out.
+        reverse_dtype: The dtype for the reverse projection computation.
+            Defaults to the `compute_dtype` of the layer.
+        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
+            output logits will be scaled by
+            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
+            range of output logits and can improve training.
+        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to the layer.
+        reverse: Boolean. If `True` the layer will perform a linear projection
+            from `output_dim` to `input_dim`, instead of a normal embedding
+            call. Default to `False`.
+
+    Example:
+    ```python
+    batch_size = 16
+    vocab_size = 100
+    hidden_dim = 32
+    seq_length = 50
+
+    # Generate random inputs.
+    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
+
+    embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
+    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
+    hidden_states = embedding(token_ids)
+    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
+    logits = embedding(hidden_states, reverse=True)
+    ```
+
+    References:
+    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        tie_weights=True,
+        embeddings_initializer="uniform",
+        embeddings_regularizer=None,
+        embeddings_constraint=None,
+        mask_zero=False,
+        reverse_dtype=None,
+        logit_soft_cap=None,
+        **kwargs,
+    ):
+        super().__init__(
+            input_dim,
+            output_dim,
+            embeddings_initializer=embeddings_initializer,
+            embeddings_regularizer=embeddings_regularizer,
+            embeddings_constraint=embeddings_constraint,
+            mask_zero=mask_zero,
+            **kwargs,
+        )
+        self.tie_weights = tie_weights
+        self.reverse_dtype = reverse_dtype
+        self.logit_soft_cap = logit_soft_cap
+
+    def build(self, inputs_shape=None):
+        super().build(inputs_shape)
+        if not self.tie_weights and self.quantization_mode not in (
+            "int8",
+            "int4",
+        ):
+            self.reverse_embeddings = self.add_weight(
+                shape=(self.output_dim, self.input_dim),
+                initializer=self.embeddings_initializer,
+                name="reverse_embeddings",
+                trainable=True,
+            )
+
+    def call(self, inputs, reverse=False):
+        if not reverse:
+            result = super().call(inputs)
+            mask = super().compute_mask(inputs)
+            if mask is not None:
+                set_keras_mask(result, mask)
+            return result
+        else:
+            if self.tie_weights:
+                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
+            else:
+                kernel = self.reverse_embeddings
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
+            logits = ops.matmul(inputs, kernel)
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def compute_mask(self, inputs, mask=None):
+        # Disable masking from super class, masking is done directly in call.
+        return None
+
+    def compute_output_shape(self, input_shape, reverse=False):
+        output_shape = list(input_shape)
+        if reverse:
+            output_shape[-1] = self.input_dim
+        else:
+            output_shape += [self.output_dim]
+        return output_shape
+
+    def compute_output_spec(self, inputs, reverse=False):
+        output_shape = list(inputs.shape)
+        if reverse:
+            output_shape[-1] = self.input_dim
+        else:
+            output_shape += [self.output_dim]
+        return KerasTensor(output_shape, dtype=self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "tie_weights": self.tie_weights,
+                "reverse_dtype": self.reverse_dtype,
+                "logit_soft_cap": self.logit_soft_cap,
+            }
+        )
+        return config
+
+    @property
+    def variable_serialization_spec(self):
+        # Avoid modifying the parent's spec.
+        _spec = copy.deepcopy(super().variable_serialization_spec)
+        if not self.tie_weights:
+            for mode, variable_spec in _spec.items():
+                variable_spec.append("reverse_embeddings")
+                if mode in ("int4", "int8"):
+                    variable_spec.append("reverse_embeddings_scale")
+        return _spec
+
+    def quantized_build(self, embeddings_shape, mode, config=None):
+        if mode == "int8":
+            self._int8_build(embeddings_shape, config)
+        elif mode == "int4":
+            self._int4_build(embeddings_shape, config)
+        else:
+            raise self._quantization_mode_error(mode)
+        self._is_quantized = True
+
+    def _int8_build(self, embeddings_shape, config=None):
+        if embeddings_shape is None:
+            embeddings_shape = (self.input_dim, self.output_dim)
+        super()._int8_build(embeddings_shape=embeddings_shape)
+
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+        if not self.tie_weights:
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(self.output_dim, self.input_dim),
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.reverse_embeddings_scale = self.add_weight(
+                name="reverse_embeddings_scale",
+                shape=(self.input_dim,),
+                initializer="ones",
+                trainable=False,
+            )
+
+    def _int4_build(self, embeddings_shape, config=None):
+        if embeddings_shape is None:
+            embeddings_shape = (self.input_dim, self.output_dim)
+        super()._int4_build(embeddings_shape=embeddings_shape, config=config)
+
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+        if not self.tie_weights:
+            packed_rows = (self.output_dim + 1) // 2  # ceil for odd dims
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(packed_rows, self.input_dim),
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.reverse_embeddings_scale = self.add_weight(
+                name="reverse_embeddings_scale",
+                shape=(self.input_dim,),
+                initializer="ones",
+                trainable=False,
+            )
+
+    def _int8_call(self, inputs, reverse=False):
+        if not reverse:
+            return super()._int8_call(inputs)
+        else:
+            if self.tie_weights:
+                kernel = ops.transpose(self._embeddings)
+                scale = ops.transpose(self.embeddings_scale)
+            else:
+                kernel = self.reverse_embeddings
+                scale = self.reverse_embeddings_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+            else:
+                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+            logits = ops.matmul(inputs, kernel)
+            # De-scale outputs
+            logits = ops.cast(logits, self.compute_dtype)
+            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def _int4_call(self, inputs, reverse=False):
+        if not reverse:
+            return super()._int4_call(inputs)
+        else:
+            if self.tie_weights:
+                embeddings = ops.transpose(self._embeddings)
+                scale = ops.transpose(self.embeddings_scale)
+            else:
+                embeddings = self.reverse_embeddings
+                scale = self.reverse_embeddings_scale
+            unpacked_embeddings = quantizers.unpack_int4(
+                embeddings, self.output_dim, axis=0
+            )
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+            else:
+                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+            logits = ops.matmul(inputs, unpacked_embeddings)
+            # De-scale outputs
+            logits = ops.cast(logits, self.compute_dtype)
+            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def quantize(self, mode=None, type_check=True, config=None):
+        if type_check and type(self) is not ReversibleEmbedding:
+            raise self._not_implemented_error(self.quantize)
+
+        self.quantization_config = config
+
+        embeddings_shape = (self.input_dim, self.output_dim)
+        if mode == "int8":
+            # Quantize `self._embeddings` to int8 and compute corresponding
+            # scale.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            del self._embeddings
+            if not self.tie_weights:
+                reverse_weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(axis=0),
+                    )
+                )
+                reverse_embeddings_value, reverse_embeddings_scale = (
+                    reverse_weight_quantizer(
+                        self.reverse_embeddings, to_numpy=True
+                    )
+                )
+                reverse_embeddings_scale = ops.squeeze(
+                    reverse_embeddings_scale, axis=0
+                )
+                del self.reverse_embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+            if not self.tie_weights:
+                self.reverse_embeddings.assign(reverse_embeddings_value)
+                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+        elif mode == "int4":
+            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+            if not self.tie_weights:
+                reverse_weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(
+                            axis=0,
+                            value_range=(-8, 7),
+                            output_dtype="int8",
+                        ),
+                    )
+                )
+                reverse_embeddings_value, reverse_embeddings_scale = (
+                    reverse_weight_quantizer(
+                        self.reverse_embeddings, to_numpy=True
+                    )
+                )
+                reverse_embeddings_scale = ops.squeeze(
+                    reverse_embeddings_scale, axis=0
+                )
+                # Pack two int4 values into a single int8 byte.
+                packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
+                    reverse_embeddings_value, axis=0
+                )
+                del self.reverse_embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(packed_embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+            if not self.tie_weights:
+                self.reverse_embeddings.assign(packed_reverse_embeddings_value)
+                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+        else:
+            raise self._quantization_mode_error(mode)
+
+        # Set new dtype policy.
+        if self.dtype_policy.quantization_mode is None:
+            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            self.dtype_policy = policy
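
Note: the docstring example above covers the float path only. For orientation, here is a minimal, hedged sketch of how the new layer could be exercised end to end. The `keras.layers.ReversibleEmbedding` path comes from the `keras_export` decorator in the diff; the concrete shapes and the `quantize("int8")` call are illustrative assumptions based on the `quantize`/`_int8_call` methods shown above, not documented usage.

```python
# Minimal sketch, assuming a Keras build that ships this layer
# (e.g. a nightly at or above the right-hand version of this diff).
import numpy as np
import keras

vocab_size, hidden_dim = 100, 32
token_ids = np.random.randint(vocab_size, size=(4, 10))

embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
hidden_states = embedding(token_ids)             # (4, 10, 32)
logits = embedding(hidden_states, reverse=True)  # (4, 10, 100)

# The `quantize` method defined above supports "int8" and "int4"; after
# quantization the reverse projection is served by `_int8_call`.
embedding.quantize("int8")
int8_logits = embedding(hidden_states, reverse=True)
```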
keras/src/layers/input_spec.py CHANGED
@@ -111,6 +111,7 @@ class InputSpec:
             "max_ndim": self.max_ndim,
             "min_ndim": self.min_ndim,
             "axes": self.axes,
+            "optional": self.optional,
         }
 
     @classmethod
@@ -184,24 +185,24 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         if spec.ndim is not None and not spec.allow_last_axis_squeeze:
             if ndim != spec.ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected ndim={spec.ndim}, found ndim={ndim}. "
                     f"Full shape received: {shape}"
                 )
         if spec.max_ndim is not None:
             if ndim is not None and ndim > spec.max_ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected max_ndim={spec.max_ndim}, "
                    f"found ndim={ndim}"
                 )
         if spec.min_ndim is not None:
             if ndim is not None and ndim < spec.min_ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                    f"expected min_ndim={spec.min_ndim}, "
                     f"found ndim={ndim}. "
                     f"Full shape received: {shape}"
@@ -211,8 +212,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
             dtype = backend.standardize_dtype(x.dtype)
             if dtype != spec.dtype:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected dtype={spec.dtype}, "
                     f"found dtype={dtype}"
                 )
@@ -225,11 +226,10 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                     None,
                 }:
                     raise ValueError(
-                        f'Input {input_index} of layer "{layer_name}" is '
-                        f"incompatible with the layer: expected axis {axis} "
-                        f"of input shape to have value {value}, "
-                        "but received input with "
-                        f"shape {shape}"
+                        f"Input {input_index} with name '{spec.name}' of layer "
+                        f"'{layer_name}' is incompatible with the layer: "
+                        f"expected axis {axis} of input shape to have value "
+                        f"{value}, but received input with shape {shape}"
                     )
         # Check shape.
         if spec.shape is not None:
@@ -243,8 +243,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                 if spec_dim is not None and dim is not None:
                     if spec_dim != dim:
                         raise ValueError(
-                            f'Input {input_index} of layer "{layer_name}" is '
-                            "incompatible with the layer: "
-                            f"expected shape={spec.shape}, "
-                            f"found shape={shape}"
+                            f"Input {input_index} with name '{spec.name}' of "
+                            f"layer '{layer_name}' is incompatible with the "
+                            f"layer: expected shape={spec.shape}, found "
+                            f"shape={shape}"
                         )
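
Note: the `input_spec.py` hunks above make two changes: `InputSpec.get_config()` now round-trips the `optional` flag, and every `assert_input_compatibility` error message names the offending spec. A small, hedged sketch of the serialization side follows; the `from_config` round trip is an assumption based on the `@classmethod` visible in the first hunk.

```python
# Hedged sketch: `optional` now survives get_config()/from_config().
import keras

spec = keras.layers.InputSpec(ndim=3, optional=True)
config = spec.get_config()
assert config["optional"] is True  # key added by this diff

restored = keras.layers.InputSpec.from_config(config)
assert restored.optional is True
```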
keras/src/layers/layer.py CHANGED
@@ -45,6 +45,7 @@ from keras.src.layers import input_spec
 from keras.src.metrics.metric import Metric
 from keras.src.ops.node import Node
 from keras.src.ops.operation import Operation
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.utils import python_utils
 from keras.src.utils import summary_utils
 from keras.src.utils import traceback_utils
@@ -244,11 +245,13 @@ class Layer(BackendLayer, Operation):
         original_quantize_method = obj.quantize
 
         @wraps(original_quantize_method)
-        def quantize_wrapper(mode, **kwargs):
+        def quantize_wrapper(mode=None, config=None, **kwargs):
+            config = validate_and_resolve_config(mode, config)
+            mode = config.mode
             obj._check_quantize_args(mode, obj.compute_dtype)
             obj._tracker.unlock()
             try:
-                original_quantize_method(mode, **kwargs)
+                original_quantize_method(mode=mode, config=config, **kwargs)
             except Exception:
                 raise
             finally:
@@ -757,6 +760,15 @@ class Layer(BackendLayer, Operation):
             self._dtype_policy = policy
         if policy.quantization_mode is not None:
             if self.built and not getattr(self, "_is_quantized", False):
+                if policy.quantization_mode == "gptq":
+                    raise ValueError(
+                        "Implicitly enabling GPTQ quantization by setting "
+                        f"`dtype_policy` to '{value}' is not supported. "
+                        "GPTQ requires a calibration dataset and a "
+                        "`GPTQConfig` object.\n\n"
+                        "Please use the `.quantize('gptq', config=...)` method "
+                        "on the layer or model instead."
+                    )
                 self.quantize(policy.quantization_mode)
 
     @property
@@ -824,9 +836,14 @@ class Layer(BackendLayer, Operation):
         #############################################################
         # 1. Convert any array arguments to tensors of correct dtype.
         def maybe_convert(x):
-            return self.dtype_policy.convert_input(
+            # Prevent _keras_mask from disappearing
+            mask = backend.get_keras_mask(x)
+            y = self.dtype_policy.convert_input(
                 x, self.autocast, self.input_dtype
             )
+            if mask is not None:
+                backend.set_keras_mask(y, mask)
+            return y
 
         # Used to avoid expensive `tree` operations in the most common case.
         if (
@@ -1268,7 +1285,7 @@ class Layer(BackendLayer, Operation):
     def quantized_build(self, input_shape, mode):
         raise self._not_implemented_error(self.quantized_build)
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         raise self._not_implemented_error(self.quantize)
 
     def _check_quantize_args(self, mode, compute_dtype):
@@ -1320,6 +1337,8 @@ class Layer(BackendLayer, Operation):
             return self._int4_call(*args, **kwargs)
         elif self.quantization_mode == "gptq":
             return self._gptq_call(*args, **kwargs)
+        elif self.quantization_mode == "awq":
+            return self._awq_call(*args, **kwargs)
         else:
             raise self._quantization_mode_error(self.quantization_mode)
 
@@ -1335,6 +1354,9 @@ class Layer(BackendLayer, Operation):
     def _gptq_call(self, *args, **kwargs):
         raise self._not_implemented_error(self._gptq_call)
 
+    def _awq_call(self, *args, **kwargs):
+        raise self._not_implemented_error(self._awq_call)
+
     def _not_implemented_error(self, attr, msg=None):
         if callable(attr):
             attr_name = attr.__name__
@@ -1368,15 +1390,7 @@ class Layer(BackendLayer, Operation):
         for i, v in enumerate(all_vars):
             store[f"{i}"] = v
 
-    def load_own_variables(self, store):
-        """Loads the state of the layer.
-
-        You can override this method to take full control of how the state of
-        the layer is loaded upon calling `keras.models.load_model()`.
-
-        Args:
-            store: Dict from which the state of the model will be loaded.
-        """
+    def _check_load_own_variables(self, store):
         all_vars = self._trainable_variables + self._non_trainable_variables
         if len(store.keys()) != len(all_vars):
             if len(all_vars) == 0 and not self.built:
@@ -1409,6 +1423,18 @@ class Layer(BackendLayer, Operation):
                 f"{len(store.keys())} variables during loading. "
                 f"Expected: {[v.name for v in all_vars]}"
             )
+
+    def load_own_variables(self, store):
+        """Loads the state of the layer.
+
+        You can override this method to take full control of how the state of
+        the layer is loaded upon calling `keras.models.load_model()`.
+
+        Args:
+            store: Dict from which the state of the model will be loaded.
+        """
+        self._check_load_own_variables(store)
+        all_vars = self._trainable_variables + self._non_trainable_variables
         for i, v in enumerate(all_vars):
             v.assign(store[f"{i}"])
 
@@ -1889,6 +1915,10 @@ def get_shapes_dict(call_spec):
     {"input_a_shape": (2, 3)}
     ```
     """
+
+    def standardize_shape_or_none(x):
+        return None if x is None else backend.standardize_shape(x.shape)
+
     shapes_dict = {}
     for k, v in call_spec.tensor_arguments_dict.items():
         if k == "mask" or k.endswith("_mask"):
@@ -1899,10 +1929,10 @@ def get_shapes_dict(call_spec):
             continue
         if k in call_spec.nested_tensor_argument_names:
             shapes_dict[f"{k}_shape"] = tree.map_structure(
-                lambda x: backend.standardize_shape(x.shape), v
+                standardize_shape_or_none, v
             )
         else:
-            shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
+            shapes_dict[f"{k}_shape"] = standardize_shape_or_none(v)
     return shapes_dict
 
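
Note: the `layer.py` hunks above relax `quantize` to `mode=None` plus an optional `config`, resolve the pair through `validate_and_resolve_config`, and route a new `"awq"` mode to `_awq_call`. A minimal, hedged sketch of the call path follows; it stays with the string-mode form, since the constructor of the resolved config class is not part of this diff.

```python
# Hedged sketch of the relaxed `quantize` signature.
import numpy as np
import keras

layer = keras.layers.Dense(16)
layer.build((None, 8))

# A bare mode string still works; per the wrapper in the diff it is
# normalized into a config whose `.mode` drives the quantized build.
layer.quantize("int8")
y = layer(np.ones((2, 8), dtype="float32"))
```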
 
keras/src/layers/merging/dot.py CHANGED
@@ -41,6 +41,7 @@ def batch_dot(x, y, axes=None):
         axes: Tuple or list of integers with target dimensions, or single
             integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`
             should be equal.
+            Note that axis `0` (the batch axis) cannot be included.
 
     Returns:
         A tensor with shape equal to the concatenation of `x`'s shape
@@ -226,7 +227,8 @@ class Dot(Merge):
             take the dot product. If a tuple, should be two integers
             corresponding to the desired axis from the first input and the
             desired axis from the second input, respectively. Note that the
-            size of the two selected axes must match.
+            size of the two selected axes must match, and that
+            axis `0` (the batch axis) cannot be included.
         normalize: Whether to L2-normalize samples along the dot product axis
             before taking the dot product. If set to `True`, then
             the output of the dot product is the cosine proximity
@@ -363,6 +365,7 @@ def dot(inputs, axes=-1, **kwargs):
         inputs: A list of input tensors (at least 2).
         axes: Integer or tuple of integers,
             axis or axes along which to take the dot product.
+            Note that axis `0` (the batch axis) cannot be included.
         normalize: Whether to L2-normalize samples along the
             dot product axis before taking the dot product.
             If set to `True`, then the output of the dot product
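
Note: the three docstring additions above make the same point: the batch axis cannot be one of the contraction axes. A short, hedged illustration with arbitrarily chosen shapes:

```python
# Hedged sketch of the batch-axis rule documented above.
import numpy as np
import keras

x = np.random.random((2, 3, 4)).astype("float32")
y = np.random.random((2, 4, 5)).astype("float32")

# Contract axis 2 of `x` with axis 1 of `y`; axis 0 is the batch axis
# and must not appear in `axes`.
z = keras.layers.Dot(axes=(2, 1))([x, y])
assert tuple(z.shape) == (2, 3, 5)
```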