keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/layers/core/embedding.py
@@ -10,6 +10,8 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.saving import serialization_lib


 @keras_export("keras.layers.Embedding")
@@ -90,6 +92,7 @@ class Embedding(Layer):
         weights=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         input_length = kwargs.pop("input_length", None)
@@ -109,6 +112,7 @@ class Embedding(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config

         if weights is not None:
             self.build()
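The constructor now accepts and stores a `quantization_config`, which later hunks thread through `quantized_build` and `get_config`. A minimal usage sketch, assuming the nightly API in this diff; the `QuantizationConfig` constructor itself is not shown here, so the config-building lines are hypothetical and left commented:

    import keras

    layer = keras.layers.Embedding(input_dim=1000, output_dim=64)
    layer.build()

    # Quantize with the built-in default weight quantizer
    # (per-row abs-max, per the quantize() hunk further down).
    layer.quantize("int8")

    # Hypothetical: a custom config can also be passed at construction
    # time (quantization_config=...) or to quantize(config=...).
    # config = keras.quantizers.QuantizationConfig(...)
    # layer.quantize("int8", config=config)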
@@ -121,7 +125,11 @@ class Embedding(Layer):
             return
         embeddings_shape = (self.input_dim, self.output_dim)
         if self.quantization_mode:
-            self.quantized_build(embeddings_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                embeddings_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         if self.quantization_mode not in ("int8", "int4"):
             self._embeddings = self.add_weight(
                 shape=embeddings_shape,
@@ -218,24 +226,25 @@ class Embedding(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Embeddings plus optional merged LoRA-aware scale
-        # (returns (kernel, None) for None/gptq).
+        # (returns (embeddings, None) for `None` mode).
         embeddings_value, merged_kernel_scale = (
             self._get_embeddings_with_merged_lora()
         )
-
-        # Save the variables using the name as the key.
-        store["embeddings"] = embeddings_value
-        for name in self.quantization_variable_spec[mode]:
-            if name == "embeddings_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                store[str(idx)] = embeddings_value
+            elif name == "embeddings_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_embeddings_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
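`save_own_variables` now keys the store positionally from `variable_serialization_spec` (defined later in this diff) rather than by variable name, which is what lets the next hunk delete the legacy `"0"`-key detection and `_legacy_load_own_variables`. A small sketch of the resulting key layout:

    # Mirrors the save loop above: positions in the spec become store keys.
    spec = {
        None: ["embeddings"],
        "int8": ["embeddings", "embeddings_scale"],
        "int4": ["embeddings", "embeddings_scale"],
    }

    def store_keys(mode):
        return {str(idx): name for idx, name in enumerate(spec[mode])}

    print(store_keys("int8"))  # {'0': 'embeddings', '1': 'embeddings_scale'}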
@@ -244,36 +253,16 @@ class Embedding(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        self._embeddings.assign(store["embeddings"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_embeddings_a.assign(
-                ops.zeros(self.lora_embeddings_a.shape)
-            )
-            self.lora_embeddings_b.assign(
-                ops.zeros(self.lora_embeddings_b.shape)
-            )
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = [self._embeddings]
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                self._embeddings.assign(store[str(idx)])
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
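`load_own_variables` walks the same ordered spec, so positional keys are the only supported format and the legacy path is gone. A round-trip sketch for an unquantized layer (mode `None`, spec `["embeddings"]`), using a plain dict as the store:

    import keras

    layer = keras.layers.Embedding(input_dim=10, output_dim=4)
    layer.build()

    store = {}
    layer.save_own_variables(store)  # unquantized: {"0": embeddings}
    assert list(store.keys()) == ["0"]

    clone = keras.layers.Embedding(input_dim=10, output_dim=4)
    clone.build()
    clone.load_own_variables(store)  # assigns store["0"] to _embeddings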
@@ -300,45 +289,24 @@ class Embedding(Layer):
                 self.embeddings_constraint
             ),
             "mask_zero": self.mask_zero,
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}

-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)

     def _quantization_mode_error(self, mode):
         return NotImplementedError(
@@ -347,29 +315,37 @@ class Embedding(Layer):
         )

     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["embeddings_scale"],
-            "int4": ["embeddings_scale"],
+            None: [
+                "embeddings",
+            ],
+            "int8": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+            "int4": [
+                "embeddings",
+                "embeddings_scale",
+            ],
         }

-    def quantized_build(self, embeddings_shape, mode):
+    def quantized_build(self, embeddings_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(embeddings_shape)
+            self._int8_build(embeddings_shape, config)
         elif mode == "int4":
-            self._int4_build(embeddings_shape)
+            self._int4_build(embeddings_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, embeddings_shape):
+    def _int8_build(self, embeddings_shape, config=None):
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=embeddings_shape,
@@ -387,7 +363,7 @@ class Embedding(Layer):
             trainable=False,
         )

-    def _int4_build(self, embeddings_shape):
+    def _int4_build(self, embeddings_shape, config=None):
         input_dim, output_dim = embeddings_shape
         packed_rows = (output_dim + 1) // 2  # ceil for odd dims
@@ -452,31 +428,43 @@ class Embedding(Layer):
         )
         return outputs

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses.
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)

+        self.quantization_config = config
+
         embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale.
-            embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=-1, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=-1),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             del self._embeddings
-            self.quantized_build(embeddings_shape, mode)
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
-            embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings,
-                axis=-1,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             # 2. Pack two int4 values into a single int8 byte.
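Both branches now obtain their weight quantizer via `QuantizationConfig.weight_quantizer_or_default`, falling back to an `AbsMaxQuantizer`. Abs-max quantization takes a per-row scale from the largest magnitude along the axis and rounds the scaled weights into the target integer range. An illustrative NumPy sketch of that default behavior (the real `AbsMaxQuantizer` additionally guards against zero scales and handles dtypes and backends):

    import numpy as np

    def abs_max_quantize(w, axis=-1, value_range=(-127, 127)):
        # Per-row scale from the largest magnitude along `axis`.
        qmax = max(abs(value_range[0]), abs(value_range[1]))
        amax = np.max(np.abs(w), axis=axis, keepdims=True)
        scale = amax / qmax  # assumes no all-zero rows
        q = np.clip(np.round(w / scale), *value_range).astype(np.int8)
        return q, scale

    w = np.array([[0.5, -1.0, 0.25]])
    q, scale = abs_max_quantize(w)  # int8 branch: qmax = 127
    assert q.tolist() == [[64, -127, 32]]
    q4, _ = abs_max_quantize(w, value_range=(-8, 7))  # int4 branch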
@@ -484,7 +472,9 @@ class Embedding(Layer):
                 embeddings_value, axis=-1
             )
             del self._embeddings
-            self.quantized_build(embeddings_shape, mode)
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(packed_embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
         else:
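The int4 branch stores two signed 4-bit values per int8 byte via `quantizers.pack_int4`, which is why `_int4_build` above allocates `(output_dim + 1) // 2` packed rows, rounding up for odd dims. The standalone NumPy sketch below illustrates the arithmetic only; the nibble layout is an assumption, not necessarily what `pack_int4` uses:

    import numpy as np

    def pack_int4_pairs(values):
        # Pad with a zero nibble for odd lengths (the "ceil" case).
        if values.shape[-1] % 2:
            pad = np.zeros_like(values[..., :1])
            values = np.concatenate([values, pad], axis=-1)
        low, high = values[..., 0::2], values[..., 1::2]
        # Low nibble holds the even element, high nibble the odd one.
        return (high << 4) | (low & 0x0F)

    def unpack_int4_pairs(packed, orig_len):
        low = (packed << 4).astype(np.int8) >> 4  # sign-extend low nibble
        high = packed.astype(np.int8) >> 4  # arithmetic shift keeps sign
        out = np.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)
        return out[..., :orig_len]

    vals = np.array([[-8, 7, 3, -1, 5]], dtype=np.int8)  # odd length
    packed = pack_int4_pairs(vals)
    assert packed.shape[-1] == (vals.shape[-1] + 1) // 2  # 3 bytes
    assert (unpack_int4_pairs(packed, 5) == vals).all()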
@@ -524,7 +514,7 @@ class Embedding(Layer):
             `embeddings_scale`: The quantization scale for the merged
                 embeddings. This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.embeddings, None

         embeddings_value = self._embeddings
keras/src/layers/core/input_layer.py
@@ -138,6 +138,7 @@ class InputLayer(Layer):
             "sparse": self.sparse,
             "ragged": self.ragged,
             "name": self.name,
+            "optional": self.optional,
         }
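The `InputLayer` hunk persists the `optional` flag in `get_config`, so optional inputs no longer silently lose the flag on a config round trip. A sketch, assuming a build containing this change:

    import keras

    layer = keras.layers.InputLayer(shape=(4,), optional=True)
    config = layer.get_config()
    assert config["optional"] is True  # previously dropped from the config

    clone = keras.layers.InputLayer.from_config(config)
    assert clone.optional is True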