keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/embedding.py

```diff
@@ -10,6 +10,8 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.Embedding")
```
```diff
@@ -90,6 +92,7 @@ class Embedding(Layer):
         weights=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         input_length = kwargs.pop("input_length", None)
```
```diff
@@ -109,6 +112,7 @@ class Embedding(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
 
         if weights is not None:
             self.build()
```
```diff
@@ -120,9 +124,13 @@ class Embedding(Layer):
         if self.built:
             return
         embeddings_shape = (self.input_dim, self.output_dim)
-        if self.quantization_mode
-            self.quantized_build(
-
+        if self.quantization_mode:
+            self.quantized_build(
+                embeddings_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4"):
             self._embeddings = self.add_weight(
                 shape=embeddings_shape,
                 initializer=self.embeddings_initializer,
```
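The build path above dispatches to `quantized_build()` whenever a quantized dtype policy is active, so the float `_embeddings` weight is only created for unquantized modes. A minimal sketch of that behavior, assuming the standard Keras quantized policy name `int8_from_float32`; the weight layout noted in the comment is an expectation, not output copied from the package:

```python
# Sketch only: build() under a quantized dtype policy goes through
# quantized_build() instead of creating a float embeddings table.
import keras

layer = keras.layers.Embedding(
    input_dim=1000,
    output_dim=64,
    dtype="int8_from_float32",  # quantized dtype policy (assumed name)
)
layer.build()
print([v.path for v in layer.weights])  # expect an int8 table plus a scale
```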
```diff
@@ -137,12 +145,20 @@ class Embedding(Layer):
 
     @property
     def embeddings(self):
+        if not self.built:
+            raise AttributeError(
+                "You must build the layer before accessing `embeddings`."
+            )
+        embeddings = self._embeddings
+        if self.quantization_mode == "int4":
+            embeddings = quantizers.unpack_int4(
+                embeddings, self._orig_output_dim, axis=-1
+            )
         if self.lora_enabled:
-            return self.
-                self.
-            )
-
-        return self._embeddings
+            return embeddings + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_embeddings_a, self.lora_embeddings_b
+            )
+        return embeddings
 
     def call(self, inputs):
         if inputs.dtype != "int32" and inputs.dtype != "int64":
```
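When LoRA is enabled, the property above returns the unpacked (and, for quantized modes, de-scaled) table plus the scaled low-rank update. The arithmetic is the usual LoRA merge; a small standalone NumPy illustration with made-up shapes (names mirror the diff, the numbers are arbitrary):

```python
import numpy as np

rank, alpha = 4, 8
W = np.random.randn(100, 16).astype("float32")    # dequantized base table
A = np.random.randn(100, rank).astype("float32")  # lora_embeddings_a
B = np.random.randn(rank, 16).astype("float32")   # lora_embeddings_b
W_eff = W + (alpha / rank) * (A @ B)              # what `embeddings` returns
```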
```diff
@@ -189,13 +205,13 @@ class Embedding(Layer):
         self._tracker.unlock()
         self.lora_embeddings_a = self.add_weight(
             name="lora_embeddings_a",
-            shape=(self.
+            shape=(self.input_dim, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.embeddings_regularizer,
         )
         self.lora_embeddings_b = self.add_weight(
             name="lora_embeddings_b",
-            shape=(rank, self.
+            shape=(rank, self.output_dim),
             initializer=initializers.get(b_initializer),
             regularizer=self.embeddings_regularizer,
         )
```
```diff
@@ -209,19 +225,26 @@ class Embedding(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-
-
-
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Embeddings plus optional merged LoRA-aware scale
+        # (returns (embeddings, None) for `None` mode).
+        embeddings_value, merged_kernel_scale = (
             self._get_embeddings_with_merged_lora()
         )
-
-
-        if
-
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                store[str(idx)] = embeddings_value
+            elif name == "embeddings_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_embeddings_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-
-
-                store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
```
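The rewritten `save_own_variables()` walks `variable_serialization_spec` (added further down in this diff) and writes each variable under a positional string key. A hedged sketch of the resulting store layout:

```python
# Assumed layout; keys come from enumerating the spec for the active mode.
import keras

layer = keras.layers.Embedding(input_dim=10, output_dim=4)
layer.build()
store = {}
layer.save_own_variables(store)
# Unquantized mode:  {"0": embeddings}
# int8 / int4 modes: {"0": (packed) quantized embeddings, "1": embeddings_scale}
```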
```diff
@@ -229,16 +252,17 @@ class Embedding(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-
-
-
-
-
-
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                self._embeddings.assign(store[str(idx)])
             else:
-
-
-                variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
```
```diff
@@ -265,62 +289,63 @@ class Embedding(Layer):
                 self.embeddings_constraint
             ),
             "mask_zero": self.mask_zero,
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-
-
-
-
-
-
-                "and thus it doesn't have any variables. "
-                f"However the weights file lists {len(store.keys())} "
-                "variables for this layer.\n"
-                "In most cases, this error indicates that either:\n\n"
-                "1. The layer is owned by a parent layer that "
-                "implements a `build()` method, but calling the "
-                "parent's `build()` method did NOT create the state of "
-                f"the child layer '{self.name}'. A `build()` method "
-                "must create ALL state for the layer, including "
-                "the state of any children layers.\n\n"
-                "2. You need to implement "
-                "the `def build_from_config(self, config)` method "
-                f"on layer '{self.name}', to specify how to rebuild "
-                "it during loading. "
-                "In this case, you might also want to implement the "
-                "method that generates the build config at saving time, "
-                "`def get_build_config(self)`. "
-                "The method `build_from_config()` is meant "
-                "to create the state "
-                "of the layer (i.e. its variables) upon deserialization.",
-            )
-        raise ValueError(
-            f"Layer '{self.name}' expected {len(all_vars)} variables, "
-            "but received "
-            f"{len(store.keys())} variables during loading. "
-            f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
-
-
+        )
+        return super().from_config(config)
 
     def _quantization_mode_error(self, mode):
         return NotImplementedError(
-            "Invalid quantization mode. Expected 'int8'. "
+            "Invalid quantization mode. Expected one of ('int8', 'int4'). "
             f"Received: quantization_mode={mode}"
         )
 
-
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
+
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "embeddings",
+            ],
+            "int8": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+            "int4": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+        }
+
+    def quantized_build(self, embeddings_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(embeddings_shape)
+            self._int8_build(embeddings_shape, config)
+        elif mode == "int4":
+            self._int4_build(embeddings_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, embeddings_shape):
+    def _int8_build(self, embeddings_shape, config=None):
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=embeddings_shape,
```
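`get_config()` now serializes `quantization_config`, and the new `from_config()` override deserializes it before delegating to the parent class, so a plain config round trip preserves it. A short sketch using the default `quantization_config=None`, so no particular config class is assumed:

```python
import keras

layer = keras.layers.Embedding(input_dim=10, output_dim=4)
config = layer.get_config()       # now contains a "quantization_config" entry
restored = keras.layers.Embedding.from_config(config)
```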
```diff
@@ -338,10 +363,27 @@ class Embedding(Layer):
             trainable=False,
         )
 
-    def
-
-
+    def _int4_build(self, embeddings_shape, config=None):
+        input_dim, output_dim = embeddings_shape
+        packed_rows = (output_dim + 1) // 2  # ceil for odd dims
+
+        # Embeddings are stored *packed*: each int8 byte contains two int4
+        # values.
+        self._embeddings = self.add_weight(
+            name="embeddings",
+            shape=(input_dim, packed_rows),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        self.embeddings_scale = self.add_weight(
+            name="embeddings_scale",
+            shape=(self.input_dim,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original output_dim for unpacking at runtime.
+        self._orig_output_dim = output_dim
 
     def _int8_call(self, inputs, training=None):
         # We cannot update quantized self._embeddings, so the custom gradient is
```
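`_int4_build()` stores two signed int4 values per int8 byte, which is why the packed table has `(output_dim + 1) // 2` columns. The nibble arithmetic below is a standalone NumPy sketch of that idea; Keras's own `pack_int4`/`unpack_int4` additionally handle axis selection and padding, so this is illustrative, not the library implementation:

```python
import numpy as np

output_dim = 5                       # odd on purpose
packed_cols = (output_dim + 1) // 2  # -> 3 bytes per row
vals = np.array([-8, -1, 0, 3, 7], dtype=np.int8)  # int4 range is [-8, 7]
padded = np.append(vals, 0).astype(np.int8)        # pad to an even length
low, high = padded[0::2] & 0x0F, padded[1::2] & 0x0F
packed = ((high << 4) | low).astype(np.int8)       # packed_cols bytes

# Unpack: pull the nibbles back out and sign-extend them.
low_u = (packed & 0x0F).astype(np.int8)
high_u = ((packed.astype(np.uint8) >> 4) & 0x0F).astype(np.int8)
low_u[low_u > 7] -= 16
high_u[high_u > 7] -= 16
unpacked = np.stack([low_u, high_u], axis=1).reshape(-1)[:output_dim]
assert np.array_equal(unpacked, vals)
```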
```diff
@@ -363,49 +405,165 @@ class Embedding(Layer):
         )
         return outputs
 
-    def
-        #
+    def _int4_call(self, inputs, training=None):
+        # We cannot update quantized self._embeddings, so the custom gradient is
+        # not needed
+        if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
+            inputs = ops.cast(inputs, "int32")
+        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+        unpacked_embeddings = quantizers.unpack_int4(
+            self._embeddings, self._orig_output_dim, axis=-1
+        )
+        outputs = ops.take(unpacked_embeddings, inputs, axis=0)
+        # De-scale outputs
+        outputs = ops.divide(
+            ops.cast(outputs, dtype=self.compute_dtype),
+            ops.expand_dims(embeddings_scale, axis=-1),
+        )
+        if self.lora_enabled:
+            lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
+            lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)
+            outputs = ops.add(
+                outputs, (self.lora_alpha / self.lora_rank) * lora_outputs
+            )
+        return outputs
+
+    def quantize(self, mode=None, type_check=True, config=None):
+        # Prevent quantization of the subclasses.
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
             # Quantize `self._embeddings` to int8 and compute corresponding
-            # scale
-
-            self.
+            # scale.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=-1),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             del self._embeddings
-
-
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
+        elif mode == "int4":
+            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(packed_embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set new dtype policy.
         if self.dtype_policy.quantization_mode is None:
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
     def _get_embeddings_with_merged_lora(self):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        """Returns the embeddings with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        embeddings tensor that includes the adaptations from LoRA. This is
+        useful for deploying the model or for continuing training after
+        permanently applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base embeddings to float.
+        2. Compute the LoRA delta (`lora_embeddings_a @ lora_embeddings_b`) and
+           add it to the dequantized embeddings.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `embeddings` property (which computes the merge in floating-point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original embeddings and scale
+        without modification.
+
+        Returns:
+            A tuple `(embeddings_value, embeddings_scale)`:
+                `embeddings_value`: The merged embeddings. A quantized tensor if
+                    quantization is active, otherwise a high precision tensor.
+                `embeddings_scale`: The quantization scale for the merged
+                    embeddings. This is `None` if the layer is not quantized.
+        """
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
+            return self.embeddings, None
+
+        embeddings_value = self._embeddings
+        embeddings_scale = self.embeddings_scale
+        if not self.lora_enabled:
             return embeddings_value, embeddings_scale
-
+
+        # Dequantize embeddings to float.
+        if self.quantization_mode == "int4":
+            unpacked_embeddings = quantizers.unpack_int4(
+                embeddings_value, self._orig_output_dim, axis=-1
+            )
+            float_embeddings = ops.divide(
+                ops.cast(unpacked_embeddings, self.compute_dtype),
+                ops.expand_dims(embeddings_scale, axis=-1),
+            )
+            quant_range = (-8, 7)
+        elif self.quantization_mode == "int8":
+            float_embeddings = ops.divide(
+                ops.cast(embeddings_value, self.compute_dtype),
+                ops.expand_dims(embeddings_scale, axis=-1),
+            )
+            quant_range = (-127, 127)
+        else:
+            raise ValueError(
+                f"Unsupported quantization mode: {self.quantization_mode}"
+            )
+
+        # Merge LoRA weights in float domain.
+        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+            self.lora_embeddings_a, self.lora_embeddings_b
+        )
+        merged_float_embeddings = ops.add(float_embeddings, lora_delta)
+
+        # Requantize.
+        requantized_embeddings, embeddings_scale = quantizers.abs_max_quantize(
+            merged_float_embeddings,
+            axis=-1,
+            value_range=quant_range,
+            dtype="int8",
+            to_numpy=True,
+        )
+        embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+
+        # Pack if int4.
+        if self.quantization_mode == "int4":
+            embeddings_value, _, _ = quantizers.pack_int4(
+                requantized_embeddings, axis=-1
+            )
+        else:
+            embeddings_value = requantized_embeddings
+        return embeddings_value, embeddings_scale
```
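Taken together, the code paths above let an `Embedding` be quantized to int4 in place and then used like any other layer; saving it later routes through `_get_embeddings_with_merged_lora()`, which folds in any LoRA delta and re-quantizes. A short end-to-end sketch based on the hunks above (the lookup goes through `_int4_call()`; the printed shape is the usual `(batch, sequence, output_dim)`):

```python
import numpy as np
import keras

layer = keras.layers.Embedding(input_dim=100, output_dim=16)
layer.build()
layer.quantize("int4")      # packs the table and creates embeddings_scale

tokens = np.array([[3, 14, 15]])
outputs = layer(tokens)     # dequantized lookup via _int4_call()
print(outputs.shape)        # (1, 3, 16)
```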