keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026013004__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/export/tfsm_layer.py +34 -0
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/saving_api.py +66 -2
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/RECORD +23 -23
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/top_level.txt +0 -0
keras/src/layers/core/embedding.py

```diff
@@ -1,3 +1,4 @@
+import math
 import warnings
 
 from keras.src import backend
```
```diff
@@ -11,6 +12,8 @@ from keras.src.api_export import keras_export
 from keras.src.backend import KerasTensor
 from keras.src.layers.layer import Layer
 from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import get_block_size_for_layer
+from keras.src.quantizers.quantizers import dequantize_with_sz_map
 from keras.src.saving import serialization_lib
 
 
```
```diff
@@ -229,20 +232,37 @@ class Embedding(Layer):
         if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Embeddings plus optional merged LoRA-aware scale
-        # (
-        embeddings_value,
+        # Embeddings plus optional merged LoRA-aware scale/zero (returns
+        # (embeddings, None, None) for `None` mode).
+        embeddings_value, merged_embeddings_scale, merged_embeddings_zero = (
             self._get_embeddings_with_merged_lora()
         )
         idx = 0
         for name in self.variable_serialization_spec[mode]:
             if name == "embeddings":
                 store[str(idx)] = embeddings_value
+            elif name == "embeddings_zero":
+                if merged_embeddings_zero is None:
+                    # embeddings_zero only exists for sub-channel int4
+                    # quantization
+                    continue
+                store[str(idx)] = merged_embeddings_zero
+            elif name == "g_idx" and not hasattr(self, "g_idx"):
+                # g_idx only exists for sub-channel int4 quantization
+                continue
             elif name == "embeddings_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_embeddings_with_merged_lora()`
-                store[str(idx)] =
+                store[str(idx)] = merged_embeddings_scale
             else:
+                # Generic handling for subclass variables:
+                # Check if the attribute exists on the instance before saving.
+                # This supports optional variables in subclasses (e.g.,
+                # `reverse_embeddings_zero` in ReversibleEmbedding) that are
+                # present in the spec but may not exist on the object depending
+                # on configuration (e.g., per-channel vs. sub-channel).
+                if not hasattr(self, name):
+                    continue
                 store[str(idx)] = getattr(self, name)
             idx += 1
 
```
```diff
@@ -260,7 +280,21 @@ class Embedding(Layer):
         for name in self.variable_serialization_spec[mode]:
             if name == "embeddings":
                 self._embeddings.assign(store[str(idx)])
+            elif name == "embeddings_zero" and not hasattr(
+                self, "embeddings_zero"
+            ):
+                # embeddings_zero only exists for sub-channel int4 quantization
+                continue
+            elif name == "g_idx" and not hasattr(self, "g_idx"):
+                # g_idx only exists for sub-channel int4 quantization
+                continue
             else:
+                # Generic handling for subclass variables:
+                # Check if the attribute exists before attempting to assign.
+                # If the variable is in the spec but missing from the object,
+                # we skip it to prevent AttributeError.
+                if not hasattr(self, name):
+                    continue
                 getattr(self, name).assign(store[str(idx)])
             idx += 1
         if self.lora_enabled:
```
```diff
@@ -333,6 +367,8 @@ class Embedding(Layer):
             "int4": [
                 "embeddings",
                 "embeddings_scale",
+                "embeddings_zero",
+                "g_idx",
             ],
         }
 
```
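How the index-keyed variable store lines up with this spec can be illustrated with a short, self-contained sketch (plain Python; `build_store` and the placeholder values are hypothetical, not Keras code). Entries that do not exist on a given layer, such as `embeddings_zero` and `g_idx` under per-channel int4, are skipped without consuming an index, mirroring the `continue` branches in the save/load hunks above.

```python
# Hypothetical sketch of the index-keyed store produced for the "int4" spec
# above. `layer_vars` stands in for the variables actually present on a layer;
# values are placeholder strings, not tensors.
SPEC_INT4 = ["embeddings", "embeddings_scale", "embeddings_zero", "g_idx"]

def build_store(layer_vars):
    store = {}
    idx = 0
    for name in SPEC_INT4:
        if name not in layer_vars:
            # Optional entries (zero point, g_idx) are skipped for
            # per-channel int4, mirroring the `continue` branches above.
            continue
        store[str(idx)] = layer_vars[name]
        idx += 1
    return store

# Sub-channel int4: all four entries present -> keys "0".."3".
print(build_store({"embeddings": "E", "embeddings_scale": "S",
                   "embeddings_zero": "Z", "g_idx": "G"}))
# Per-channel int4: zero point and g_idx absent -> keys "0" and "1".
print(build_store({"embeddings": "E", "embeddings_scale": "S"}))
```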
```diff
@@ -364,11 +400,17 @@ class Embedding(Layer):
         )
 
     def _int4_build(self, embeddings_shape, config=None):
+        """Build variables for int4 quantization.
+
+        Args:
+            embeddings_shape: Original shape `(input_dim, output_dim)`.
+            config: Optional quantization config specifying block_size.
+        """
         input_dim, output_dim = embeddings_shape
-        packed_rows = (output_dim + 1) // 2
+        packed_rows = (output_dim + 1) // 2
 
-        # Embeddings are stored
-        # values.
+        # Embeddings are stored packed: each int8 byte contains two
+        # int4 values.
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=(input_dim, packed_rows),
```
```diff
@@ -376,13 +418,46 @@ class Embedding(Layer):
             dtype="int8",
             trainable=False,
         )
+
+        block_size = get_block_size_for_layer(self, config)
+        self._int4_block_size = block_size
+
+        if block_size is None or block_size == -1:
+            scale_shape = (self.input_dim,)
+        else:
+            n_groups = math.ceil(output_dim / block_size)
+            scale_shape = (self.input_dim, n_groups)
+
         self.embeddings_scale = self.add_weight(
             name="embeddings_scale",
-            shape=
+            shape=scale_shape,
             initializer="ones",
             trainable=False,
         )
-
+
+        # Sub-channel quantization uses asymmetric quantization with
+        # zero point
+        if block_size is not None and block_size > 0:
+            self.embeddings_zero = self.add_weight(
+                name="embeddings_zero",
+                shape=scale_shape,
+                initializer="zeros",
+                dtype="int8",
+                trainable=False,
+            )
+            self.g_idx = self.add_weight(
+                name="g_idx",
+                shape=(output_dim,),
+                initializer="zeros",
+                dtype="float32",
+                trainable=False,
+            )
+            self.g_idx.assign(
+                ops.floor_divide(
+                    ops.arange(output_dim, dtype="float32"), block_size
+                )
+            )
+
         self._orig_output_dim = output_dim
 
     def _int8_call(self, inputs, training=None):
```
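The shape arithmetic in `_int4_build` can be checked with a small numpy-only example (the concrete numbers below are illustrative, not taken from the diff): packed storage halves the output dimension, per-group scales and zero points get one column per block, and `g_idx` maps each output column to its block.

```python
import math
import numpy as np

# Illustrative dimensions; real values come from the layer and its config.
input_dim, output_dim, block_size = 4, 10, 4

packed_rows = (output_dim + 1) // 2            # two int4 values per int8 byte
n_groups = math.ceil(output_dim / block_size)  # blocks along output_dim

embeddings_shape = (input_dim, packed_rows)    # packed int8 storage
scale_shape = (input_dim, n_groups)            # per-row, per-group scale
zero_shape = scale_shape                       # matching zero points

# g_idx maps each output column to the block it belongs to.
g_idx = np.floor_divide(np.arange(output_dim), block_size)

print(embeddings_shape, scale_shape, zero_shape)  # (4, 5) (4, 3) (4, 3)
print(g_idx)                                      # [0 0 0 0 1 1 1 1 2 2]
```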
```diff
@@ -406,20 +481,38 @@ class Embedding(Layer):
         return outputs
 
     def _int4_call(self, inputs, training=None):
-
-        # not needed
+        """Forward pass for int4 quantized Embedding layer."""
         if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
             inputs = ops.cast(inputs, "int32")
-
+
         unpacked_embeddings = quantizers.unpack_int4(
             self._embeddings, self._orig_output_dim, axis=-1
         )
         outputs = ops.take(unpacked_embeddings, inputs, axis=0)
-
-
-
-
-
+
+        block_size = getattr(self, "_int4_block_size", None)
+
+        if block_size is None or block_size == -1:
+            embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+            outputs = ops.divide(
+                ops.cast(outputs, dtype=self.compute_dtype),
+                ops.expand_dims(embeddings_scale, axis=-1),
+            )
+        else:
+            # Sub-channel: look up scale/zero for each input token,
+            # then dequantize using g_idx to expand groups
+            embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+            embeddings_zero = ops.take(self.embeddings_zero, inputs, axis=0)
+
+            # Scale/zero are [batch..., n_groups], g_idx is [output_dim]
+            outputs = dequantize_with_sz_map(
+                ops.cast(outputs, dtype=self.compute_dtype),
+                embeddings_scale,
+                embeddings_zero,
+                self.g_idx,
+                group_axis=-1,
+            )
+
         if self.lora_enabled:
             lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
             lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)
```
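For intuition, here is a rough numpy sketch of what a scale/zero-map dequantization does. It assumes the common grouped convention `x ≈ (q - zero) / scale`, with `g_idx` selecting each column's group; the actual `dequantize_with_sz_map` implementation in Keras may differ in details, and the function name below is marked as a sketch for that reason.

```python
import numpy as np

def dequantize_with_sz_map_sketch(q, scale, zero, g_idx):
    """Grouped dequantization sketch: `q` is [..., output_dim], `scale` and
    `zero` are [..., n_groups], `g_idx` maps each column to its group."""
    g = g_idx.astype(int)
    col_scale = np.take(scale, g, axis=-1)  # expand per-group values to columns
    col_zero = np.take(zero, g, axis=-1)
    return (q - col_zero) / col_scale

q = np.array([[3.0, -2.0, 7.0, 0.0]])   # unpacked int4 codes for one token
scale = np.array([[2.0, 4.0]])          # two groups of two columns each
zero = np.array([[1.0, -1.0]])
g_idx = np.array([0, 0, 1, 1])

print(dequantize_with_sz_map_sketch(q, scale, zero, g_idx))
# [[ 1.   -1.5   2.    0.25]]
```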
```diff
@@ -454,20 +547,52 @@ class Embedding(Layer):
             self._embeddings.assign(embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
         elif mode == "int4":
-
-
-                self.quantization_config,
-                quantizers.AbsMaxQuantizer(
-                    axis=-1,
-                    value_range=(-8, 7),
-                    output_dtype="int8",
-                ),
-            )
-            embeddings_value, embeddings_scale = weight_quantizer(
-                self._embeddings, to_numpy=True
+            from keras.src.quantizers.quantization_config import (
+                Int4QuantizationConfig,
             )
-
-
+
+            block_size = None
+            if isinstance(self.quantization_config, Int4QuantizationConfig):
+                block_size = self.quantization_config.block_size
+
+            use_grouped = block_size is not None and block_size != -1
+
+            if not use_grouped:
+                # Per-channel quantization
+                weight_quantizer = (
+                    QuantizationConfig.weight_quantizer_or_default(
+                        self.quantization_config,
+                        quantizers.AbsMaxQuantizer(
+                            axis=-1,
+                            value_range=(-8, 7),
+                            output_dtype="int8",
+                        ),
+                    )
+                )
+                embeddings_value, embeddings_scale = weight_quantizer(
+                    self._embeddings, to_numpy=True
+                )
+                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+            else:
+                # Sub-channel quantization with asymmetric zero point
+                input_dim, output_dim = ops.shape(self._embeddings)
+                # Transpose to put output_dim first for grouped quantization
+                embeddings_t = ops.transpose(self._embeddings)
+
+                embeddings_value_t, scale_t, zero_t = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        embeddings_t,
+                        block_size=block_size,
+                        value_range=(-8, 7),
+                        dtype="int8",
+                        to_numpy=True,
+                    )
+                )
+                # Transpose back to (input_dim, output_dim) layout
+                embeddings_value = ops.transpose(embeddings_value_t)
+                embeddings_scale = ops.transpose(scale_t)
+                embeddings_zero = ops.transpose(zero_t)
+
             packed_embeddings_value, _, _ = quantizers.pack_int4(
                 embeddings_value, axis=-1
             )
```
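The grouped quantizer used here is new in this build. As a point of reference, the sketch below shows one way asymmetric block-wise quantization along the last axis can be written in numpy, using the same `q = round(x * scale) + zero` convention as the dequantization sketch above. `abs_max_quantize_grouped_with_zero_point` in Keras is assumed to do something equivalent, but its exact formula, grouping axis, and layout conventions are not visible in this diff, so treat this as an illustration rather than the library code.

```python
import numpy as np

def grouped_asymmetric_quantize_sketch(x, block_size, qmin=-8, qmax=7):
    """Quantize blocks of `block_size` trailing columns so that each block's
    [min, max] range maps onto the integer range [qmin, qmax]."""
    *lead, n = x.shape
    n_groups = -(-n // block_size)                     # ceil division
    pad = n_groups * block_size - n
    xp = np.pad(x, [(0, 0)] * len(lead) + [(0, pad)])  # pad the last block
    xg = xp.reshape(*lead, n_groups, block_size)

    g_min, g_max = xg.min(axis=-1), xg.max(axis=-1)
    scale = (qmax - qmin) / np.maximum(g_max - g_min, 1e-7)
    zero = np.round(qmin - g_min * scale)

    q = np.clip(np.round(xg * scale[..., None]) + zero[..., None], qmin, qmax)
    q = q.reshape(*lead, n_groups * block_size)[..., :n].astype(np.int8)
    return q, scale, zero.astype(np.int8)

x = np.random.randn(3, 8).astype(np.float32)    # toy 2D tensor
q, scale, zero = grouped_asymmetric_quantize_sketch(x, block_size=4)
print(q.shape, scale.shape, zero.shape)         # (3, 8) (3, 2) (3, 2)
```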
```diff
@@ -477,12 +602,22 @@ class Embedding(Layer):
             )
             self._embeddings.assign(packed_embeddings_value)
             self.embeddings_scale.assign(embeddings_scale)
+            if use_grouped:
+                self.embeddings_zero.assign(embeddings_zero)
         else:
             raise self._quantization_mode_error(mode)
 
         # Set new dtype policy.
         if self.dtype_policy.quantization_mode is None:
-
+            policy_name = mode
+            if mode == "int4":
+                # Include block_size in policy name for sub-channel quantization
+                block_size = get_block_size_for_layer(self, config)
+                block_size_value = -1 if block_size is None else block_size
+                policy_name = f"int4/{block_size_value}"
+            policy = dtype_policies.get(
+                f"{policy_name}_from_{self.dtype_policy.name}"
+            )
             self.dtype_policy = policy
 
     def _get_embeddings_with_merged_lora(self):
```
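The resulting dtype-policy names can be illustrated with plain strings. The format comes from the code above; the example inputs below are hypothetical.

```python
# Mirrors the policy-name construction above; inputs are hypothetical.
def int4_policy_name(block_size, base_policy_name):
    block_size_value = -1 if block_size is None else block_size
    return f"int4/{block_size_value}_from_{base_policy_name}"

print(int4_policy_name(None, "float32"))       # int4/-1_from_float32
print(int4_policy_name(32, "mixed_bfloat16"))  # int4/32_from_mixed_bfloat16
```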
```diff
@@ -508,29 +643,46 @@ class Embedding(Layer):
         without modification.
 
         Returns:
-            A tuple `(embeddings_value, embeddings_scale)`:
+            A tuple `(embeddings_value, embeddings_scale, embeddings_zero)`:
                 `embeddings_value`: The merged embeddings. A quantized tensor if
                     quantization is active, otherwise a high precision tensor.
                 `embeddings_scale`: The quantization scale for the merged
                     embeddings. This is `None` if the layer is not quantized.
+                `embeddings_zero`: The zero point for sub-channel quantization.
+                    This is `None` for per-channel quantization modes.
         """
         if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
-            return self.embeddings, None
+            return self.embeddings, None, None
 
         embeddings_value = self._embeddings
         embeddings_scale = self.embeddings_scale
+        embeddings_zero = getattr(self, "embeddings_zero", None)
+
         if not self.lora_enabled:
-            return embeddings_value, embeddings_scale
+            return embeddings_value, embeddings_scale, embeddings_zero
+
+        block_size = getattr(self, "_int4_block_size", None)
 
         # Dequantize embeddings to float.
         if self.quantization_mode == "int4":
             unpacked_embeddings = quantizers.unpack_int4(
                 embeddings_value, self._orig_output_dim, axis=-1
             )
-
-
-            ops.
-
+            if block_size is None or block_size == -1:
+                # Per-channel dequantization
+                float_embeddings = ops.divide(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    ops.expand_dims(embeddings_scale, axis=-1),
+                )
+            else:
+                # Sub-channel: grouped dequantization using shared utility
+                float_embeddings = dequantize_with_sz_map(
+                    ops.cast(unpacked_embeddings, self.compute_dtype),
+                    embeddings_scale,
+                    self.embeddings_zero,
+                    self.g_idx,
+                    group_axis=-1,
+                )
             quant_range = (-8, 7)
         elif self.quantization_mode == "int8":
             float_embeddings = ops.divide(
```
```diff
@@ -550,20 +702,55 @@ class Embedding(Layer):
         merged_float_embeddings = ops.add(float_embeddings, lora_delta)
 
         # Requantize.
-        requantized_embeddings, embeddings_scale = quantizers.abs_max_quantize(
-            merged_float_embeddings,
-            axis=-1,
-            value_range=quant_range,
-            dtype="int8",
-            to_numpy=True,
-        )
-        embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-
-        # Pack if int4.
         if self.quantization_mode == "int4":
+            if block_size is None or block_size == -1:
+                # Per-channel re-quantization
+                requantized_embeddings, new_scale = quantizers.abs_max_quantize(
+                    merged_float_embeddings,
+                    axis=-1,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
+                )
+                new_scale = ops.squeeze(new_scale, axis=-1)
+                embeddings_zero = None
+            else:
+                # Grouped re-quantization (asymmetric with zero point)
+                merged_np = merged_float_embeddings
+                # Transpose to (output_dim, input_dim) for grouped quantization
+                merged_t = ops.transpose(merged_np)
+
+                requantized_t, scale_t, zero_t = (
+                    quantizers.abs_max_quantize_grouped_with_zero_point(
+                        merged_t,
+                        block_size=block_size,
+                        value_range=quant_range,
+                        dtype="int8",
+                        to_numpy=True,
+                    )
+                )
+                # Transpose back
+                requantized_embeddings = ops.transpose(requantized_t)
+                new_scale = ops.transpose(scale_t)
+                embeddings_zero = ops.transpose(zero_t)
+
+            # Pack for int4
             embeddings_value, _, _ = quantizers.pack_int4(
                 requantized_embeddings, axis=-1
             )
+            embeddings_scale = new_scale
         else:
+            # int8 re-quantization
+            requantized_embeddings, embeddings_scale = (
+                quantizers.abs_max_quantize(
+                    merged_float_embeddings,
+                    axis=-1,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
+                )
+            )
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             embeddings_value = requantized_embeddings
-
+            embeddings_zero = None
+        return embeddings_value, embeddings_scale, embeddings_zero
```