keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
```diff
--- a/keras/src/layers/core/embedding.py
+++ b/keras/src/layers/core/embedding.py
@@ -10,6 +10,8 @@ from keras.src import regularizers
 from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.saving import serialization_lib
 
 
 @keras_export("keras.layers.Embedding")
@@ -90,6 +92,7 @@ class Embedding(Layer):
         weights=None,
         lora_rank=None,
         lora_alpha=None,
+        quantization_config=None,
         **kwargs,
     ):
         input_length = kwargs.pop("input_length", None)
@@ -109,6 +112,7 @@ class Embedding(Layer):
         self.lora_rank = lora_rank
         self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
+        self.quantization_config = quantization_config
 
         if weights is not None:
             self.build()
@@ -121,7 +125,11 @@ class Embedding(Layer):
             return
         embeddings_shape = (self.input_dim, self.output_dim)
         if self.quantization_mode:
-            self.quantized_build(embeddings_shape, mode=self.quantization_mode)
+            self.quantized_build(
+                embeddings_shape,
+                mode=self.quantization_mode,
+                config=self.quantization_config,
+            )
         if self.quantization_mode not in ("int8", "int4"):
             self._embeddings = self.add_weight(
                 shape=embeddings_shape,
@@ -218,24 +226,25 @@ class Embedding(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Embeddings plus optional merged LoRA-aware scale
-        # (returns (
+        # (returns (embeddings, None) for `None` mode).
         embeddings_value, merged_kernel_scale = (
             self._get_embeddings_with_merged_lora()
         )
-
-
-
-
-
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                store[str(idx)] = embeddings_value
+            elif name == "embeddings_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_embeddings_with_merged_lora()`
-                store[
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -244,36 +253,16 @@ class Embedding(Layer):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-
-
-
-
-
-
-
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_embeddings_a.assign(
-                ops.zeros(self.lora_embeddings_a.shape)
-            )
-            self.lora_embeddings_b.assign(
-                ops.zeros(self.lora_embeddings_b.shape)
-            )
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = [self._embeddings]
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "embeddings":
+                self._embeddings.assign(store[str(idx)])
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
@@ -300,45 +289,24 @@ class Embedding(Layer):
                 self.embeddings_constraint
             ),
             "mask_zero": self.mask_zero,
+            "quantization_config": serialization_lib.serialize_keras_object(
+                self.quantization_config
+            ),
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-
-
-
-
-
-
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+        config["quantization_config"] = (
+            serialization_lib.deserialize_keras_object(
+                config.get("quantization_config", None)
             )
+        )
+        return super().from_config(config)
 
     def _quantization_mode_error(self, mode):
         return NotImplementedError(
@@ -347,29 +315,37 @@ class Embedding(Layer):
         )
 
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [
-
-
+            None: [
+                "embeddings",
+            ],
+            "int8": [
+                "embeddings",
+                "embeddings_scale",
+            ],
+            "int4": [
+                "embeddings",
+                "embeddings_scale",
+            ],
         }
 
-    def quantized_build(self, embeddings_shape, mode):
+    def quantized_build(self, embeddings_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(embeddings_shape)
+            self._int8_build(embeddings_shape, config)
         elif mode == "int4":
-            self._int4_build(embeddings_shape)
+            self._int4_build(embeddings_shape, config)
         else:
             raise self._quantization_mode_error(mode)
         self._is_quantized = True
 
-    def _int8_build(self, embeddings_shape):
+    def _int8_build(self, embeddings_shape, config=None):
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=embeddings_shape,
@@ -387,7 +363,7 @@ class Embedding(Layer):
             trainable=False,
         )
 
-    def _int4_build(self, embeddings_shape):
+    def _int4_build(self, embeddings_shape, config=None):
         input_dim, output_dim = embeddings_shape
         packed_rows = (output_dim + 1) // 2  # ceil for odd dims
 
@@ -452,31 +428,43 @@ class Embedding(Layer):
         )
         return outputs
 
-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses.
         if type_check and (type(self) is not Embedding):
             raise self._not_implemented_error(self.quantize)
 
+        self.quantization_config = config
+
         embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale.
-
-            self.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(axis=-1),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
             )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             del self._embeddings
-            self.quantized_build(embeddings_shape, mode)
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
-
-            self.
-
-
-
-
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                self.quantization_config,
+                quantizers.AbsMaxQuantizer(
+                    axis=-1,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            embeddings_value, embeddings_scale = weight_quantizer(
+                self._embeddings, to_numpy=True
            )
             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             # 2. Pack two int4 values into a single int8 byte.
@@ -484,7 +472,9 @@ class Embedding(Layer):
                 embeddings_value, axis=-1
             )
             del self._embeddings
-            self.quantized_build(embeddings_shape, mode)
+            self.quantized_build(
+                embeddings_shape, mode, self.quantization_config
+            )
             self._embeddings.assign(packed_embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
         else:
@@ -524,7 +514,7 @@ class Embedding(Layer):
             `embeddings_scale`: The quantization scale for the merged
                 embeddings. This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode in (None, "gptq"):
+        if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
             return self.embeddings, None
 
         embeddings_value = self._embeddings
```