keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/embedding.py

@@ -10,6 +10,8 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.Embedding")
@@ -90,6 +92,7 @@ class Embedding(Layer):
         weights=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         input_length = kwargs.pop("input_length", None)
@@ -109,6 +112,7 @@ class Embedding(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
 
         if weights is not None:
             self.build()
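Illustrative sketch (not part of the diff above): the new `quantization_config` argument is stored on the layer as-is. The fields of `QuantizationConfig` itself are not shown in this hunk, so the snippet below only exercises the default of `None`.

import keras

# Sketch under the new nightly: `quantization_config` defaults to None and is
# kept verbatim on the instance; a QuantizationConfig object (see
# keras/src/quantizers/quantization_config.py in this release) could be passed
# instead.
layer = keras.layers.Embedding(1000, 64, quantization_config=None)
assert layer.quantization_config is None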
@@ -120,9 +124,13 @@ class Embedding(Layer):
         if self.built:
             return
         embeddings_shape = (self.input_dim, self.output_dim)
-        if self.quantization_mode is not None:
-            self.quantized_build(embeddings_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
+        if self.quantization_mode:
+            self.quantized_build(
+                embeddings_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
+        if self.quantization_mode not in ("int8", "int4"):
             self._embeddings = self.add_weight(
                 shape=embeddings_shape,
                 initializer=self.embeddings_initializer,
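Usage sketch (not part of the diff): when the layer is created under a quantized dtype policy, `build()` now routes through `quantized_build(..., config=self.quantization_config)` and skips the float embedding table for int8/int4 modes. The `int8_from_float32` policy name below follows the existing Keras naming pattern; treat this as an assumption-laden sketch, not a verified session.

import keras

# Building under an int8 policy allocates the quantized `embeddings` weight
# plus `embeddings_scale` instead of a float embedding matrix.
layer = keras.layers.Embedding(1000, 64, dtype="int8_from_float32")
layer.build(None)
print([v.name for v in layer.weights])  # expect: ['embeddings', 'embeddings_scale']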
@@ -137,12 +145,20 @@ class Embedding(Layer):
 
     @property
     def embeddings(self):
+        if not self.built:
+            raise AttributeError(
+                "You must build the layer before accessing `embeddings`."
+            )
+        embeddings = self._embeddings
+        if self.quantization_mode == "int4":
+            embeddings = quantizers.unpack_int4(
+                embeddings, self._orig_output_dim, axis=-1
+            )
         if self.lora_enabled:
-            return self._embeddings + (
-                self.lora_alpha / self.lora_rank
-            ) * ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b)
-
-        return self._embeddings
+            return embeddings + (self.lora_alpha / self.lora_rank) * ops.matmul(
+                self.lora_embeddings_a, self.lora_embeddings_b
+            )
+        return embeddings
 
     def call(self, inputs):
         if inputs.dtype != "int32" and inputs.dtype != "int64":
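For reference, the effective matrix returned by the rewritten `embeddings` property is the stored table plus the scaled LoRA product. A minimal NumPy sketch of that arithmetic (names are illustrative, not the layer's internals):

import numpy as np

rng = np.random.default_rng(0)
base = rng.standard_normal((1000, 64)).astype("float32")   # stored embeddings
lora_a = rng.standard_normal((1000, 4)).astype("float32")  # lora_embeddings_a
lora_b = np.zeros((4, 64), dtype="float32")                # lora_embeddings_b
lora_alpha, lora_rank = 4.0, 4

# Mirrors `embeddings + (lora_alpha / lora_rank) * (lora_a @ lora_b)` above.
effective = base + (lora_alpha / lora_rank) * (lora_a @ lora_b)
assert effective.shape == (1000, 64)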
@@ -189,13 +205,13 @@ class Embedding(Layer):
         self._tracker.unlock()
         self.lora_embeddings_a = self.add_weight(
             name="lora_embeddings_a",
-            shape=(self.embeddings.shape[0], rank),
+            shape=(self.input_dim, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.embeddings_regularizer,
         )
         self.lora_embeddings_b = self.add_weight(
             name="lora_embeddings_b",
-            shape=(rank, self.embeddings.shape[1]),
+            shape=(rank, self.output_dim),
             initializer=initializers.get(b_initializer),
             regularizer=self.embeddings_regularizer,
         )
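Usage sketch (not in the diff): with this change, `enable_lora(rank)` sizes the adapters from the layer configuration rather than from `self.embeddings.shape`, which also works when the embeddings are stored packed (int4).

import keras

layer = keras.layers.Embedding(input_dim=1000, output_dim=64)
layer.build(None)
layer.enable_lora(rank=4)
assert layer.lora_embeddings_a.shape == (1000, 4)  # (input_dim, rank)
assert layer.lora_embeddings_b.shape == (4, 64)    # (rank, output_dim)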
@@ -209,19 +225,26 @@ class Embedding(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        embeddings_value, embeddings_scale = (
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        # Embeddings plus optional merged LoRA-aware scale
+        # (returns (embeddings, None) for `None` mode).
+        embeddings_value, merged_kernel_scale = (
             self._get_embeddings_with_merged_lora()
         )
-        target_variables = [embeddings_value]
-        if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
-                target_variables.append(embeddings_scale)
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                store[str(idx)] = embeddings_value
+            elif name == "embeddings_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_embeddings_with_merged_lora()`
+                store[str(idx)] = merged_kernel_scale
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            store[str(i)] = variable
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -229,16 +252,17 @@ class Embedding(Layer):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        target_variables = [self._embeddings]
-        if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
-                target_variables.append(self.embeddings_scale)
+        mode = self.quantization_mode
+        if mode not in self.variable_serialization_spec:
+            raise self._quantization_mode_error(mode)
+
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                self._embeddings.assign(store[str(idx)])
             else:
-                raise self._quantization_mode_error(self.quantization_mode)
-        for i, variable in enumerate(target_variables):
-            variable.assign(store[str(i)])
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
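And the matching load side, sketched under the assumption that the target layer was built under the corresponding quantized policy so the same variables already exist before loading:

import keras

source = keras.layers.Embedding(8, 4)
source.build(None)
source.quantize("int8")
store = {}
source.save_own_variables(store)

# A fresh layer built under the matching policy restores the same positional
# keys in spec order.
target = keras.layers.Embedding(8, 4, dtype="int8_from_float32")
target.build(None)
target.load_own_variables(store)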
@@ -265,62 +289,63 @@ class Embedding(Layer):
                 self.embeddings_constraint
             ),
             "mask_zero": self.mask_zero,
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
-
-    """Quantization-related (int8) methods"""
+        )
+        return super().from_config(config)
 
     def _quantization_mode_error(self, mode):
         return NotImplementedError(
-            "Invalid quantization mode. Expected 'int8'. "
+            "Invalid quantization mode. Expected one of ('int8', 'int4'). "
             f"Received: quantization_mode={mode}"
         )
 
-    def quantized_build(self, embeddings_shape, mode):
+    @property
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
+
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
+        """
+        return {
+            None: [
+                "embeddings",
+            ],
+            "int8": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+            "int4": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+        }
+
+    def quantized_build(self, embeddings_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(embeddings_shape)
+            self._int8_build(embeddings_shape, config)
+        elif mode == "int4":
+            self._int4_build(embeddings_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, embeddings_shape):
+    def _int8_build(self, embeddings_shape, config=None):
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=embeddings_shape,
@@ -338,10 +363,27 @@ class Embedding(Layer):
             trainable=False,
         )
 
-    def quantized_call(self, *args, **kwargs):
-        if self.quantization_mode != "int8":
-            raise self._quantization_mode_error(self.quantization_mode)
-        return super().quantized_call(*args, **kwargs)
+    def _int4_build(self, embeddings_shape, config=None):
+        input_dim, output_dim = embeddings_shape
+        packed_rows = (output_dim + 1) // 2  # ceil for odd dims
+
+        # Embeddings are stored *packed*: each int8 byte contains two int4
+        # values.
+        self._embeddings = self.add_weight(
+            name="embeddings",
+            shape=(input_dim, packed_rows),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        self.embeddings_scale = self.add_weight(
+            name="embeddings_scale",
+            shape=(self.input_dim,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original output_dim for unpacking at runtime.
+        self._orig_output_dim = output_dim
 
     def _int8_call(self, inputs, training=None):
         # We cannot update quantized self._embeddings, so the custom gradient is
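Worked sketch of the packed-storage math (not from the diff): two int4 values share one int8 byte along the last axis, so an `(input_dim, output_dim)` table is stored as `(input_dim, ceil(output_dim / 2))`. The nibble packing below illustrates the idea behind `quantizers.pack_int4`/`unpack_int4`; the exact bit layout Keras uses may differ.

input_dim, output_dim = 1000, 65      # odd output_dim on purpose
packed_rows = (output_dim + 1) // 2   # -> 33, matching `_int4_build`

def pack_pair(low, high):
    # Two signed int4 values in [-8, 7] packed into one byte (low/high nibble).
    return ((high & 0x0F) << 4) | (low & 0x0F)

def unpack_pair(byte):
    def sign_extend(nibble):
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte & 0x0F), sign_extend((byte >> 4) & 0x0F)

assert unpack_pair(pack_pair(-3, 5)) == (-3, 5)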
@@ -363,49 +405,165 @@ class Embedding(Layer):
             )
         return outputs
 
-    def quantize(self, mode, type_check=True):
-        # Prevent quantization of the subclasses
+    def _int4_call(self, inputs, training=None):
+        # We cannot update quantized self._embeddings, so the custom gradient is
+        # not needed
+        if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
+            inputs = ops.cast(inputs, "int32")
+        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+        unpacked_embeddings = quantizers.unpack_int4(
+            self._embeddings, self._orig_output_dim, axis=-1
+        )
+        outputs = ops.take(unpacked_embeddings, inputs, axis=0)
+        # De-scale outputs
+        outputs = ops.divide(
+            ops.cast(outputs, dtype=self.compute_dtype),
+            ops.expand_dims(embeddings_scale, axis=-1),
+        )
+        if self.lora_enabled:
+            lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
+            lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)
+            outputs = ops.add(
+                outputs, (self.lora_alpha / self.lora_rank) * lora_outputs
+            )
+        return outputs
+
+    def quantize(self, mode=None, type_check=True, config=None):
+        # Prevent quantization of the subclasses.
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
             # Quantize `self._embeddings` to int8 and compute corresponding
-            # scale
-            embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=-1, to_numpy=True
+            # scale.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=-1),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             del self._embeddings
-        self.quantized_build(embeddings_shape, mode)
-        if mode == "int8":
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
+        elif mode == "int4":
+            # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_embeddings_value, _, _ = quantizers.pack_int4(
+                embeddings_value, axis=-1
+            )
+            del self._embeddings
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
+            self._embeddings.assign(packed_embeddings_value)
+            self.embeddings_scale.assign(embeddings_scale)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set new dtype policy.
         if self.dtype_policy.quantization_mode is None:
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
     def _get_embeddings_with_merged_lora(self):
-        if self.dtype_policy.quantization_mode is not None:
-            embeddings_value = self._embeddings
-            embeddings_scale = self.embeddings_scale
-            if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into embeddings
-                # Note that this is a lossy compression
-                embeddings_value = ops.divide(
-                    embeddings_value, ops.expand_dims(embeddings_scale, axis=-1)
-                )
-                embeddings_value = ops.add(
-                    embeddings_value,
-                    ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
-                )
-                embeddings_value, embeddings_scale = (
-                    quantizers.abs_max_quantize(
-                        embeddings_value, axis=-1, to_numpy=True
-                    )
-                )
-                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+        """Returns the embeddings with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        embeddings tensor that includes the adaptations from LoRA. This is
+        useful for deploying the model or for continuing training after
+        permanently applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base embeddings to float.
+        2. Compute the LoRA delta (`lora_embeddings_a @ lora_embeddings_b`) and
+           add it to the dequantized embeddings.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `embeddings` property (which computes the merge in floating-point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original embeddings and scale
+        without modification.
+
+        Returns:
+            A tuple `(embeddings_value, embeddings_scale)`:
+                `embeddings_value`: The merged embeddings. A quantized tensor if
+                    quantization is active, otherwise a high precision tensor.
+                `embeddings_scale`: The quantization scale for the merged
+                    embeddings. This is `None` if the layer is not quantized.
+        """
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
+            return self.embeddings, None
+
+        embeddings_value = self._embeddings
+        embeddings_scale = self.embeddings_scale
+        if not self.lora_enabled:
             return embeddings_value, embeddings_scale
-        return self.embeddings, None
+
+        # Dequantize embeddings to float.
+        if self.quantization_mode == "int4":
+            unpacked_embeddings = quantizers.unpack_int4(
+                embeddings_value, self._orig_output_dim, axis=-1
+            )
+            float_embeddings = ops.divide(
+                ops.cast(unpacked_embeddings, self.compute_dtype),
+                ops.expand_dims(embeddings_scale, axis=-1),
+            )
+            quant_range = (-8, 7)
+        elif self.quantization_mode == "int8":
+            float_embeddings = ops.divide(
+                ops.cast(embeddings_value, self.compute_dtype),
+                ops.expand_dims(embeddings_scale, axis=-1),
+            )
+            quant_range = (-127, 127)
+        else:
+            raise ValueError(
+                f"Unsupported quantization mode: {self.quantization_mode}"
+            )
+
+        # Merge LoRA weights in float domain.
+        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+            self.lora_embeddings_a, self.lora_embeddings_b
+        )
+        merged_float_embeddings = ops.add(float_embeddings, lora_delta)
+
+        # Requantize.
+        requantized_embeddings, embeddings_scale = quantizers.abs_max_quantize(
+            merged_float_embeddings,
+            axis=-1,
+            value_range=quant_range,
+            dtype="int8",
+            to_numpy=True,
+        )
+        embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+
+        # Pack if int4.
+        if self.quantization_mode == "int4":
+            embeddings_value, _, _ = quantizers.pack_int4(
+                requantized_embeddings, axis=-1
+            )
+        else:
+            embeddings_value = requantized_embeddings
+        return embeddings_value, embeddings_scale
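End-to-end usage sketch for the new int4 path (assumes the nightly build described by this diff; not a verified session):

import numpy as np
import keras

layer = keras.layers.Embedding(input_dim=1000, output_dim=64)
layer.build(None)
layer.quantize("int4")   # packs the table and stores a per-row scale

tokens = np.array([[1, 2, 3]], dtype="int32")
outputs = layer(tokens)  # `_int4_call` unpacks and de-scales on the fly
print(outputs.shape)     # (1, 3, 64)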
keras/src/layers/core/input_layer.py

@@ -138,6 +138,7 @@ class InputLayer(Layer):
             "sparse": self.sparse,
             "ragged": self.ragged,
             "name": self.name,
+            "optional": self.optional,
         }
 
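Small sketch of the `InputLayer` change (not in the diff): the `optional` flag now survives a `get_config()` round trip.

import keras

layer = keras.layers.InputLayer(shape=(4,), optional=True)
print(layer.get_config()["optional"])  # True -> now preserved when saving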