keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/reversible_embedding.py
ADDED

@@ -0,0 +1,399 @@
+import copy
+
+from keras.src import dtype_policies
+from keras.src import layers
+from keras.src import ops
+from keras.src import quantizers
+from keras.src.api_export import keras_export
+from keras.src.backend import KerasTensor
+from keras.src.backend import set_keras_mask
+from keras.src.quantizers.quantization_config import QuantizationConfig
+
+
+@keras_export("keras.layers.ReversibleEmbedding")
+class ReversibleEmbedding(layers.Embedding):
+    """An embedding layer which can project backwards to the input dim.
+
+    This layer is an extension of `keras.layers.Embedding` for language models.
+    This layer can be called "in reverse" with `reverse=True`, in which case the
+    layer will linearly project from `output_dim` back to `input_dim`.
+
+    By default, the reverse projection will use the transpose of the
+    `embeddings` weights to project to `input_dim` (weights are "tied"). If
+    `tie_weights=False`, the model will use a separate, trainable variable for
+    reverse projection.
+
+    This layer has no bias terms.
+
+    Args:
+        input_dim: Integer. Size of the vocabulary,
+            i.e. maximum integer index + 1.
+        output_dim: Integer. Dimension of the dense embedding.
+        tie_weights: Boolean, whether or not the matrix for embedding and
+            the matrix for the `reverse` projection should share the same
+            weights.
+        embeddings_initializer: Initializer for the `embeddings`
+            matrix (see `keras.initializers`).
+        embeddings_regularizer: Regularizer function applied to
+            the `embeddings` matrix (see `keras.regularizers`).
+        embeddings_constraint: Constraint function applied to
+            the `embeddings` matrix (see `keras.constraints`).
+        mask_zero: Boolean, whether or not the input value 0 is a special
+            "padding" value that should be masked out.
+        reverse_dtype: The dtype for the reverse projection computation.
+            Defaults to the `compute_dtype` of the layer.
+        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
+            output logits will be scaled by
+            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
+            range of output logits and can improve training.
+        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
+            including `name`, `trainable`, `dtype` etc.
+
+    Call arguments:
+        inputs: The tensor inputs to the layer.
+        reverse: Boolean. If `True` the layer will perform a linear projection
+            from `output_dim` to `input_dim`, instead of a normal embedding
+            call. Default to `False`.
+
+    Example:
+    ```python
+    batch_size = 16
+    vocab_size = 100
+    hidden_dim = 32
+    seq_length = 50
+
+    # Generate random inputs.
+    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))
+
+    embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)
+    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
+    hidden_states = embedding(token_ids)
+    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
+    logits = embedding(hidden_states, reverse=True)
+    ```
+
+    References:
+    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        tie_weights=True,
+        embeddings_initializer="uniform",
+        embeddings_regularizer=None,
+        embeddings_constraint=None,
+        mask_zero=False,
+        reverse_dtype=None,
+        logit_soft_cap=None,
+        **kwargs,
+    ):
+        super().__init__(
+            input_dim,
+            output_dim,
+            embeddings_initializer=embeddings_initializer,
+            embeddings_regularizer=embeddings_regularizer,
+            embeddings_constraint=embeddings_constraint,
+            mask_zero=mask_zero,
+            **kwargs,
+        )
+        self.tie_weights = tie_weights
+        self.reverse_dtype = reverse_dtype
+        self.logit_soft_cap = logit_soft_cap
+
+    def build(self, inputs_shape=None):
+        super().build(inputs_shape)
+        if not self.tie_weights and self.quantization_mode not in (
+            "int8",
+            "int4",
+        ):
+            self.reverse_embeddings = self.add_weight(
+                shape=(self.output_dim, self.input_dim),
+                initializer=self.embeddings_initializer,
+                name="reverse_embeddings",
+                trainable=True,
+            )
+
+    def call(self, inputs, reverse=False):
+        if not reverse:
+            result = super().call(inputs)
+            mask = super().compute_mask(inputs)
+            if mask is not None:
+                set_keras_mask(result, mask)
+            return result
+        else:
+            if self.tie_weights:
+                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
+            else:
+                kernel = self.reverse_embeddings
+            if self.reverse_dtype is not None:
+                inputs = ops.cast(inputs, self.reverse_dtype)
+                kernel = ops.cast(kernel, self.reverse_dtype)
+            logits = ops.matmul(inputs, kernel)
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def compute_mask(self, inputs, mask=None):
+        # Disable masking from super class, masking is done directly in call.
+        return None
+
+    def compute_output_shape(self, input_shape, reverse=False):
+        output_shape = list(input_shape)
+        if reverse:
+            output_shape[-1] = self.input_dim
+        else:
+            output_shape += [self.output_dim]
+        return output_shape
+
+    def compute_output_spec(self, inputs, reverse=False):
+        output_shape = list(inputs.shape)
+        if reverse:
+            output_shape[-1] = self.input_dim
+        else:
+            output_shape += [self.output_dim]
+        return KerasTensor(output_shape, dtype=self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "tie_weights": self.tie_weights,
+                "reverse_dtype": self.reverse_dtype,
+                "logit_soft_cap": self.logit_soft_cap,
+            }
+        )
+        return config
+
+    @property
+    def variable_serialization_spec(self):
+        # Avoid modifying the parent's spec.
+        _spec = copy.deepcopy(super().variable_serialization_spec)
+        if not self.tie_weights:
+            for mode, variable_spec in _spec.items():
+                variable_spec.append("reverse_embeddings")
+                if mode in ("int4", "int8"):
+                    variable_spec.append("reverse_embeddings_scale")
+        return _spec
+
+    def quantized_build(self, embeddings_shape, mode, config=None):
+        if mode == "int8":
+            self._int8_build(embeddings_shape, config)
+        elif mode == "int4":
+            self._int4_build(embeddings_shape, config)
+        else:
+            raise self._quantization_mode_error(mode)
+        self._is_quantized = True
+
+    def _int8_build(self, embeddings_shape, config=None):
+        if embeddings_shape is None:
+            embeddings_shape = (self.input_dim, self.output_dim)
+        super()._int8_build(embeddings_shape=embeddings_shape)
+
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+        if not self.tie_weights:
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(self.output_dim, self.input_dim),
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.reverse_embeddings_scale = self.add_weight(
+                name="reverse_embeddings_scale",
+                shape=(self.input_dim,),
+                initializer="ones",
+                trainable=False,
+            )
+
+    def _int4_build(self, embeddings_shape, config=None):
+        if embeddings_shape is None:
+            embeddings_shape = (self.input_dim, self.output_dim)
+        super()._int4_build(embeddings_shape=embeddings_shape, config=config)
+
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+        if not self.tie_weights:
+            packed_rows = (self.output_dim + 1) // 2  # ceil for odd dims
+            self.reverse_embeddings = self.add_weight(
+                name="reverse_embeddings",
+                shape=(packed_rows, self.input_dim),
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.reverse_embeddings_scale = self.add_weight(
+                name="reverse_embeddings_scale",
+                shape=(self.input_dim,),
+                initializer="ones",
+                trainable=False,
+            )
+
+    def _int8_call(self, inputs, reverse=False):
+        if not reverse:
+            return super()._int8_call(inputs)
+        else:
+            if self.tie_weights:
+                kernel = ops.transpose(self._embeddings)
+                scale = ops.transpose(self.embeddings_scale)
+            else:
+                kernel = self.reverse_embeddings
+                scale = self.reverse_embeddings_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+            else:
+                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+            logits = ops.matmul(inputs, kernel)
+            # De-scale outputs
+            logits = ops.cast(logits, self.compute_dtype)
+            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def _int4_call(self, inputs, reverse=False):
+        if not reverse:
+            return super()._int4_call(inputs)
+        else:
+            if self.tie_weights:
+                embeddings = ops.transpose(self._embeddings)
+                scale = ops.transpose(self.embeddings_scale)
+            else:
+                embeddings = self.reverse_embeddings
+                scale = self.reverse_embeddings_scale
+            unpacked_embeddings = quantizers.unpack_int4(
+                embeddings, self.output_dim, axis=0
+            )
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+            else:
+                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+            logits = ops.matmul(inputs, unpacked_embeddings)
+            # De-scale outputs
+            logits = ops.cast(logits, self.compute_dtype)
+            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.multiply(
+                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap
+                )
+            return logits
+
+    def quantize(self, mode=None, type_check=True, config=None):
+        if type_check and type(self) is not ReversibleEmbedding:
+            raise self._not_implemented_error(self.quantize)
+
+        self.quantization_config = config
+
+        embeddings_shape = (self.input_dim, self.output_dim)
+        if mode == "int8":
+            # Quantize `self._embeddings` to int8 and compute corresponding
+            # scale.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            del self._embeddings
+            if not self.tie_weights:
+                reverse_weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(axis=0),
+                    )
+                )
+                reverse_embeddings_value, reverse_embeddings_scale = (
+                    reverse_weight_quantizer(
+                        self.reverse_embeddings, to_numpy=True
+                    )
+                )
+                reverse_embeddings_scale = ops.squeeze(
+                    reverse_embeddings_scale, axis=0
+                )
+                del self.reverse_embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+            if not self.tie_weights:
+                self.reverse_embeddings.assign(reverse_embeddings_value)
+                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+        elif mode == "int4":
+            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+            if not self.tie_weights:
+                reverse_weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(
+                            axis=0,
+                            value_range=(-8, 7),
+                            output_dtype="int8",
+                        ),
+                    )
+                )
+                reverse_embeddings_value, reverse_embeddings_scale = (
+                    reverse_weight_quantizer(
+                        self.reverse_embeddings, to_numpy=True
+                    )
+                )
+                reverse_embeddings_scale = ops.squeeze(
+                    reverse_embeddings_scale, axis=0
+                )
+                # Pack two int4 values into a single int8 byte.
+                packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(
+                    reverse_embeddings_value, axis=0
+                )
+                del self.reverse_embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(packed_embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+            if not self.tie_weights:
+                self.reverse_embeddings.assign(packed_reverse_embeddings_value)
+                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
+        else:
+            raise self._quantization_mode_error(mode)

+        # Set new dtype policy.
+        if self.dtype_policy.quantization_mode is None:
+            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+            self.dtype_policy = policy
keras/src/layers/input_spec.py
CHANGED

@@ -111,6 +111,7 @@ class InputSpec:
             "max_ndim": self.max_ndim,
             "min_ndim": self.min_ndim,
             "axes": self.axes,
+            "optional": self.optional,
         }

     @classmethod
@@ -184,24 +185,24 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         if spec.ndim is not None and not spec.allow_last_axis_squeeze:
             if ndim != spec.ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected ndim={spec.ndim}, found ndim={ndim}. "
                     f"Full shape received: {shape}"
                 )
         if spec.max_ndim is not None:
             if ndim is not None and ndim > spec.max_ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected max_ndim={spec.max_ndim}, "
                     f"found ndim={ndim}"
                 )
         if spec.min_ndim is not None:
             if ndim is not None and ndim < spec.min_ndim:
                 raise ValueError(
-                    f'Input {input_index} of layer "{layer_name}" '
-                    "is incompatible with the layer: "
+                    f"Input {input_index} with name '{spec.name}' of layer "
+                    f"'{layer_name}' is incompatible with the layer: "
                     f"expected min_ndim={spec.min_ndim}, "
                     f"found ndim={ndim}. "
                     f"Full shape received: {shape}"
@@ -211,8 +212,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
         dtype = backend.standardize_dtype(x.dtype)
         if dtype != spec.dtype:
             raise ValueError(
-                f'Input {input_index} of layer "{layer_name}" '
-                "is incompatible with the layer: "
+                f"Input {input_index} with name '{spec.name}' of layer "
+                f"'{layer_name}' is incompatible with the layer: "
                 f"expected dtype={spec.dtype}, "
                 f"found dtype={dtype}"
             )
@@ -225,11 +226,10 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                     None,
                 }:
                     raise ValueError(
-                        f'Input {input_index} of layer "{layer_name}" is '
-                        f"incompatible with the layer: expected axis {axis} "
-                        f"of input shape to have value {value}, "
-                        "but received input with "
-                        f"shape {shape}"
+                        f"Input {input_index} with name '{spec.name}' of layer "
+                        f"'{layer_name}' is incompatible with the layer: "
+                        f"expected axis {axis} of input shape to have value "
+                        f"{value}, but received input with shape {shape}"
                     )
         # Check shape.
         if spec.shape is not None:
@@ -243,8 +243,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
                 if spec_dim is not None and dim is not None:
                     if spec_dim != dim:
                         raise ValueError(
-                            f'Input {input_index} of layer "{layer_name}" is '
-                            "incompatible with the layer: "
-                            f"expected shape={spec.shape}, "
-                            f"found shape={shape}"
+                            f"Input {input_index} with name '{spec.name}' of "
+                            f"layer '{layer_name}' is incompatible with the "
+                            f"layer: expected shape={spec.shape}, found "
+                            f"shape={shape}"
                         )
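
The `input_spec.py` changes serialize the `optional` flag in `get_config()` and rewrite every incompatibility error so it names the input index, the spec's `name`, and the owning layer. An illustrative trigger follows; the printed message is paraphrased from the new format strings, not captured from a run:

```python
import numpy as np
import keras

layer = keras.layers.Dense(4)
try:
    # Dense declares `InputSpec(min_ndim=2, ...)`, so a 1-D input is rejected.
    layer(np.arange(3.0))
except ValueError as e:
    print(e)
    # Roughly: "Input 0 with name 'None' of layer 'dense' is incompatible
    # with the layer: expected min_ndim=2, found ndim=1.
    # Full shape received: (3,)"
```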
keras/src/layers/layer.py
CHANGED

@@ -45,6 +45,7 @@ from keras.src.layers import input_spec
 from keras.src.metrics.metric import Metric
 from keras.src.ops.node import Node
 from keras.src.ops.operation import Operation
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.utils import python_utils
 from keras.src.utils import summary_utils
 from keras.src.utils import traceback_utils
@@ -244,11 +245,13 @@ class Layer(BackendLayer, Operation):
         original_quantize_method = obj.quantize

         @wraps(original_quantize_method)
-        def quantize_wrapper(mode, **kwargs):
+        def quantize_wrapper(mode=None, config=None, **kwargs):
+            config = validate_and_resolve_config(mode, config)
+            mode = config.mode
             obj._check_quantize_args(mode, obj.compute_dtype)
             obj._tracker.unlock()
             try:
-                original_quantize_method(mode, **kwargs)
+                original_quantize_method(mode=mode, config=config, **kwargs)
             except Exception:
                 raise
             finally:
@@ -757,6 +760,15 @@ class Layer(BackendLayer, Operation):
         self._dtype_policy = policy
         if policy.quantization_mode is not None:
             if self.built and not getattr(self, "_is_quantized", False):
+                if policy.quantization_mode == "gptq":
+                    raise ValueError(
+                        "Implicitly enabling GPTQ quantization by setting "
+                        f"`dtype_policy` to '{value}' is not supported. "
+                        "GPTQ requires a calibration dataset and a "
+                        "`GPTQConfig` object.\n\n"
+                        "Please use the `.quantize('gptq', config=...)` method "
+                        "on the layer or model instead."
+                    )
                 self.quantize(policy.quantization_mode)

     @property
@@ -824,9 +836,14 @@ class Layer(BackendLayer, Operation):
         #############################################################
         # 1. Convert any array arguments to tensors of correct dtype.
         def maybe_convert(x):
-            return self.dtype_policy.convert_input(
+            # Prevent _keras_mask from disappearing
+            mask = backend.get_keras_mask(x)
+            y = self.dtype_policy.convert_input(
                 x, self.autocast, self.input_dtype
             )
+            if mask is not None:
+                backend.set_keras_mask(y, mask)
+            return y

         # Used to avoid expensive `tree` operations in the most common case.
         if (
@@ -1268,7 +1285,7 @@ class Layer(BackendLayer, Operation):
     def quantized_build(self, input_shape, mode):
         raise self._not_implemented_error(self.quantized_build)

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         raise self._not_implemented_error(self.quantize)

     def _check_quantize_args(self, mode, compute_dtype):
@@ -1320,6 +1337,8 @@ class Layer(BackendLayer, Operation):
             return self._int4_call(*args, **kwargs)
         elif self.quantization_mode == "gptq":
             return self._gptq_call(*args, **kwargs)
+        elif self.quantization_mode == "awq":
+            return self._awq_call(*args, **kwargs)
         else:
             raise self._quantization_mode_error(self.quantization_mode)

@@ -1335,6 +1354,9 @@ class Layer(BackendLayer, Operation):
     def _gptq_call(self, *args, **kwargs):
         raise self._not_implemented_error(self._gptq_call)

+    def _awq_call(self, *args, **kwargs):
+        raise self._not_implemented_error(self._awq_call)
+
     def _not_implemented_error(self, attr, msg=None):
         if callable(attr):
             attr_name = attr.__name__
@@ -1368,15 +1390,7 @@ class Layer(BackendLayer, Operation):
         for i, v in enumerate(all_vars):
             store[f"{i}"] = v

-    def load_own_variables(self, store):
-        """Loads the state of the layer.
-
-        You can override this method to take full control of how the state of
-        the layer is loaded upon calling `keras.models.load_model()`.
-
-        Args:
-            store: Dict from which the state of the model will be loaded.
-        """
+    def _check_load_own_variables(self, store):
         all_vars = self._trainable_variables + self._non_trainable_variables
         if len(store.keys()) != len(all_vars):
             if len(all_vars) == 0 and not self.built:
@@ -1409,6 +1423,18 @@ class Layer(BackendLayer, Operation):
                 f"{len(store.keys())} variables during loading. "
                 f"Expected: {[v.name for v in all_vars]}"
             )
+
+    def load_own_variables(self, store):
+        """Loads the state of the layer.
+
+        You can override this method to take full control of how the state of
+        the layer is loaded upon calling `keras.models.load_model()`.
+
+        Args:
+            store: Dict from which the state of the model will be loaded.
+        """
+        self._check_load_own_variables(store)
+        all_vars = self._trainable_variables + self._non_trainable_variables
         for i, v in enumerate(all_vars):
             v.assign(store[f"{i}"])

@@ -1889,6 +1915,10 @@ def get_shapes_dict(call_spec):
    {"input_a_shape": (2, 3)}
    ```
    """
+
+    def standardize_shape_or_none(x):
+        return None if x is None else backend.standardize_shape(x.shape)
+
    shapes_dict = {}
    for k, v in call_spec.tensor_arguments_dict.items():
        if k == "mask" or k.endswith("_mask"):
@@ -1899,10 +1929,10 @@ def get_shapes_dict(call_spec):
            continue
        if k in call_spec.nested_tensor_argument_names:
            shapes_dict[f"{k}_shape"] = tree.map_structure(
-                lambda x: backend.standardize_shape(x.shape), v
+                standardize_shape_or_none, v
            )
        else:
-            shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
+            shapes_dict[f"{k}_shape"] = standardize_shape_or_none(v)
    return shapes_dict

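
The `layer.py` changes route `quantize()` through `validate_and_resolve_config`, add an `awq` dispatch branch with an `_awq_call` stub, preserve `_keras_mask` across input dtype conversion, split the variable-count check into `_check_load_own_variables`, and block implicit GPTQ via dtype policies. A hedged sketch of the policy-driven path that still works, assuming `keras.layers.Dense` and the standard `"int8_from_float32"` policy name:

```python
import keras

# Assigning a quantized dtype policy to an already-built layer still
# quantizes it in place via the (now config-aware) `quantize()` wrapper.
layer = keras.layers.Dense(16)
layer.build((None, 8))
layer.dtype_policy = "int8_from_float32"
print(layer.quantization_mode)  # "int8"

# GPTQ, by contrast, can no longer be enabled this way: per the new check,
# a GPTQ-mode policy on a built layer raises a ValueError pointing to
# `layer.quantize("gptq", config=...)` with a calibration dataset.
```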
keras/src/layers/merging/dot.py
CHANGED

@@ -41,6 +41,7 @@ def batch_dot(x, y, axes=None):
        axes: Tuple or list of integers with target dimensions, or single
            integer. The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`
            should be equal.
+            Note that axis `0` (the batch axis) cannot be included.

    Returns:
        A tensor with shape equal to the concatenation of `x`'s shape
@@ -226,7 +227,8 @@ class Dot(Merge):
            take the dot product. If a tuple, should be two integers
            corresponding to the desired axis from the first input and the
            desired axis from the second input, respectively. Note that the
-            size of the two selected axes must match.
+            size of the two selected axes must match, and that
+            axis `0` (the batch axis) cannot be included.
        normalize: Whether to L2-normalize samples along the dot product axis
            before taking the dot product. If set to `True`, then
            the output of the dot product is the cosine proximity
@@ -363,6 +365,7 @@ def dot(inputs, axes=-1, **kwargs):
        inputs: A list of input tensors (at least 2).
        axes: Integer or tuple of integers,
            axis or axes along which to take the dot product.
+            Note that axis `0` (the batch axis) cannot be included.
        normalize: Whether to L2-normalize samples along the
            dot product axis before taking the dot product.
            If set to `True`, then the output of the dot product
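
The `dot.py` docstring updates spell out an existing constraint: the contraction `axes` for `batch_dot`, `Dot`, and `dot` must not include axis `0`, which is reserved as the batch axis. For example:

```python
import numpy as np
import keras

x1 = np.random.random((4, 3, 5))
x2 = np.random.random((4, 5, 2))

# Contract axis 2 of `x1` with axis 1 of `x2`; axis 0 (the batch axis)
# is never passed in `axes`.
y = keras.layers.Dot(axes=(2, 1))([x1, x2])
print(y.shape)  # (4, 3, 2)
```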
|