keras-nightly 3.14.0.dev2026012704__py3-none-any.whl → 3.14.0.dev2026012904__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +1 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +1 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +3 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +1 -0
  7. keras/ops/numpy/__init__.py +1 -0
  8. keras/quantizers/__init__.py +3 -0
  9. keras/src/backend/jax/core.py +12 -2
  10. keras/src/backend/jax/numpy.py +5 -0
  11. keras/src/backend/numpy/numpy.py +5 -0
  12. keras/src/backend/openvino/numpy.py +6 -0
  13. keras/src/backend/tensorflow/numpy.py +21 -0
  14. keras/src/backend/torch/numpy.py +10 -0
  15. keras/src/callbacks/orbax_checkpoint.py +41 -8
  16. keras/src/dtype_policies/__init__.py +2 -0
  17. keras/src/dtype_policies/dtype_policy.py +80 -1
  18. keras/src/layers/core/dense.py +278 -95
  19. keras/src/layers/core/einsum_dense.py +350 -181
  20. keras/src/layers/core/embedding.py +236 -49
  21. keras/src/layers/core/reversible_embedding.py +177 -35
  22. keras/src/layers/preprocessing/discretization.py +30 -1
  23. keras/src/ops/numpy.py +54 -0
  24. keras/src/quantizers/__init__.py +6 -0
  25. keras/src/quantizers/quantization_config.py +98 -4
  26. keras/src/quantizers/quantizers.py +262 -32
  27. keras/src/saving/file_editor.py +7 -1
  28. keras/src/saving/saving_api.py +66 -2
  29. keras/src/saving/saving_lib.py +46 -47
  30. keras/src/version.py +1 -1
  31. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/METADATA +1 -1
  32. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/RECORD +34 -34
  33. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/WHEEL +0 -0
  34. {keras_nightly-3.14.0.dev2026012704.dist-info → keras_nightly-3.14.0.dev2026012904.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+ import math
  import warnings

  from keras.src import backend
@@ -11,6 +12,8 @@ from keras.src.api_export import keras_export
  from keras.src.backend import KerasTensor
  from keras.src.layers.layer import Layer
  from keras.src.quantizers.quantization_config import QuantizationConfig
+ from keras.src.quantizers.quantization_config import get_block_size_for_layer
+ from keras.src.quantizers.quantizers import dequantize_with_sz_map
  from keras.src.saving import serialization_lib

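The two new imports carry most of this change: `get_block_size_for_layer` reads the effective int4 block size from a layer's quantization config, and `dequantize_with_sz_map` reverses grouped (sub-channel) quantization using a scale/zero map plus a per-column group index. Their implementations are not shown in this diff, so the NumPy sketch below is only one plausible reading of the dequantization step, with the affine formula `(q - zero) / scale` assumed by analogy with the symmetric per-channel path (which divides by the scale) plus the new zero point:

    # Assumed semantics only; the real helper lives in keras/src/quantizers/quantizers.py.
    import numpy as np

    def dequantize_with_sz_map_sketch(q, scale, zero, g_idx, group_axis=-1):
        # q: quantized values with `output_dim` columns on `group_axis`.
        # scale/zero: one entry per group; g_idx: group index of each column.
        g = g_idx.astype("int64")
        per_col_scale = np.take(scale, g, axis=group_axis)  # expand groups to columns
        per_col_zero = np.take(zero, g, axis=group_axis)
        return (q - per_col_zero) / per_col_scale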
@@ -229,20 +232,37 @@ class Embedding(Layer):
          if mode not in self.variable_serialization_spec:
              raise self._quantization_mode_error(mode)

-         # Embeddings plus optional merged LoRA-aware scale
-         # (returns (embeddings, None) for `None` mode).
-         embeddings_value, merged_kernel_scale = (
+         # Embeddings plus optional merged LoRA-aware scale/zero (returns
+         # (embeddings, None, None) for `None` mode).
+         embeddings_value, merged_embeddings_scale, merged_embeddings_zero = (
              self._get_embeddings_with_merged_lora()
          )
          idx = 0
          for name in self.variable_serialization_spec[mode]:
              if name == "embeddings":
                  store[str(idx)] = embeddings_value
+             elif name == "embeddings_zero":
+                 if merged_embeddings_zero is None:
+                     # embeddings_zero only exists for sub-channel int4
+                     # quantization
+                     continue
+                 store[str(idx)] = merged_embeddings_zero
+             elif name == "g_idx" and not hasattr(self, "g_idx"):
+                 # g_idx only exists for sub-channel int4 quantization
+                 continue
              elif name == "embeddings_scale" and mode in ("int4", "int8"):
                  # For int4/int8, the merged LoRA scale (if any) comes from
                  # `_get_embeddings_with_merged_lora()`
-                 store[str(idx)] = merged_kernel_scale
+                 store[str(idx)] = merged_embeddings_scale
              else:
+                 # Generic handling for subclass variables:
+                 # Check if the attribute exists on the instance before saving.
+                 # This supports optional variables in subclasses (e.g.,
+                 # `reverse_embeddings_zero` in ReversibleEmbedding) that are
+                 # present in the spec but may not exist on the object depending
+                 # on configuration (e.g., per-channel vs. sub-channel).
+                 if not hasattr(self, name):
+                     continue
                  store[str(idx)] = getattr(self, name)
              idx += 1

@@ -260,7 +280,21 @@ class Embedding(Layer):
          for name in self.variable_serialization_spec[mode]:
              if name == "embeddings":
                  self._embeddings.assign(store[str(idx)])
+             elif name == "embeddings_zero" and not hasattr(
+                 self, "embeddings_zero"
+             ):
+                 # embeddings_zero only exists for sub-channel int4 quantization
+                 continue
+             elif name == "g_idx" and not hasattr(self, "g_idx"):
+                 # g_idx only exists for sub-channel int4 quantization
+                 continue
              else:
+                 # Generic handling for subclass variables:
+                 # Check if the attribute exists before attempting to assign.
+                 # If the variable is in the spec but missing from the object,
+                 # we skip it to prevent AttributeError.
+                 if not hasattr(self, name):
+                     continue
                  getattr(self, name).assign(store[str(idx)])
              idx += 1
          if self.lora_enabled:
@@ -333,6 +367,8 @@ class Embedding(Layer):
              "int4": [
                  "embeddings",
                  "embeddings_scale",
+                 "embeddings_zero",
+                 "g_idx",
              ],
          }

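With `embeddings_zero` and `g_idx` appended to the int4 spec, a sub-channel-quantized Embedding serializes four positional entries; per-channel layers skip the two new names via the `hasattr`/`None` checks above, so their indices stay contiguous. An illustrative layout, with shapes taken from `_int4_build` below:

    # store["0"]: packed int4 embeddings, int8, shape (input_dim, (output_dim + 1) // 2)
    # store["1"]: embeddings_scale, shape (input_dim, n_groups)      # (input_dim,) per-channel
    # store["2"]: embeddings_zero, int8, shape (input_dim, n_groups) # sub-channel only
    # store["3"]: g_idx, float32, shape (output_dim,)                # sub-channel only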
@@ -364,11 +400,17 @@ class Embedding(Layer):
          )

      def _int4_build(self, embeddings_shape, config=None):
+         """Build variables for int4 quantization.
+
+         Args:
+             embeddings_shape: Original shape `(input_dim, output_dim)`.
+             config: Optional quantization config specifying block_size.
+         """
          input_dim, output_dim = embeddings_shape
-         packed_rows = (output_dim + 1) // 2  # ceil for odd dims
+         packed_rows = (output_dim + 1) // 2

-         # Embeddings are stored *packed*: each int8 byte contains two int4
-         # values.
+         # Embeddings are stored packed: each int8 byte contains two
+         # int4 values.
          self._embeddings = self.add_weight(
              name="embeddings",
              shape=(input_dim, packed_rows),
@@ -376,13 +418,46 @@ class Embedding(Layer):
              dtype="int8",
              trainable=False,
          )
+
+         block_size = get_block_size_for_layer(self, config)
+         self._int4_block_size = block_size
+
+         if block_size is None or block_size == -1:
+             scale_shape = (self.input_dim,)
+         else:
+             n_groups = math.ceil(output_dim / block_size)
+             scale_shape = (self.input_dim, n_groups)
+
          self.embeddings_scale = self.add_weight(
              name="embeddings_scale",
-             shape=(self.input_dim,),
+             shape=scale_shape,
              initializer="ones",
              trainable=False,
          )
-         # Record original output_dim for unpacking at runtime.
+
+         # Sub-channel quantization uses asymmetric quantization with
+         # zero point
+         if block_size is not None and block_size > 0:
+             self.embeddings_zero = self.add_weight(
+                 name="embeddings_zero",
+                 shape=scale_shape,
+                 initializer="zeros",
+                 dtype="int8",
+                 trainable=False,
+             )
+             self.g_idx = self.add_weight(
+                 name="g_idx",
+                 shape=(output_dim,),
+                 initializer="zeros",
+                 dtype="float32",
+                 trainable=False,
+             )
+             self.g_idx.assign(
+                 ops.floor_divide(
+                     ops.arange(output_dim, dtype="float32"), block_size
+                 )
+             )
+
          self._orig_output_dim = output_dim

      def _int8_call(self, inputs, training=None):
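A quick worked example of the new variable shapes, using assumed sizes (`input_dim=1000`, `output_dim=96`, `block_size=32`) plugged into the formulas above:

    import math

    input_dim, output_dim, block_size = 1000, 96, 32
    packed_rows = (output_dim + 1) // 2             # 48: two int4 values per int8 byte
    n_groups = math.ceil(output_dim / block_size)   # 3 groups per embedding row
    scale_shape = (input_dim, n_groups)             # (1000, 3) for scale and zero point
    g_idx = [i // block_size for i in range(output_dim)]  # [0]*32 + [1]*32 + [2]*32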
@@ -406,20 +481,38 @@ class Embedding(Layer):
          return outputs

      def _int4_call(self, inputs, training=None):
-         # We cannot update quantized self._embeddings, so the custom gradient is
-         # not needed
+         """Forward pass for int4 quantized Embedding layer."""
          if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
              inputs = ops.cast(inputs, "int32")
-         embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+
          unpacked_embeddings = quantizers.unpack_int4(
              self._embeddings, self._orig_output_dim, axis=-1
          )
          outputs = ops.take(unpacked_embeddings, inputs, axis=0)
-         # De-scale outputs
-         outputs = ops.divide(
-             ops.cast(outputs, dtype=self.compute_dtype),
-             ops.expand_dims(embeddings_scale, axis=-1),
-         )
+
+         block_size = getattr(self, "_int4_block_size", None)
+
+         if block_size is None or block_size == -1:
+             embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+             outputs = ops.divide(
+                 ops.cast(outputs, dtype=self.compute_dtype),
+                 ops.expand_dims(embeddings_scale, axis=-1),
+             )
+         else:
+             # Sub-channel: look up scale/zero for each input token,
+             # then dequantize using g_idx to expand groups
+             embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
+             embeddings_zero = ops.take(self.embeddings_zero, inputs, axis=0)
+
+             # Scale/zero are [batch..., n_groups], g_idx is [output_dim]
+             outputs = dequantize_with_sz_map(
+                 ops.cast(outputs, dtype=self.compute_dtype),
+                 embeddings_scale,
+                 embeddings_zero,
+                 self.g_idx,
+                 group_axis=-1,
+             )
+
          if self.lora_enabled:
              lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
              lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)
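For the sub-channel branch, the shapes flow as follows (batch/sequence sizes are illustrative; the group bookkeeping comes straight from the code above):

    # inputs:                                          (batch, seq) integer ids
    # ops.take(unpacked_embeddings, inputs, axis=0):   (batch, seq, output_dim)
    # ops.take(self.embeddings_scale, inputs, axis=0): (batch, seq, n_groups)
    # ops.take(self.embeddings_zero, inputs, axis=0):  (batch, seq, n_groups)
    # dequantize_with_sz_map(...):                     (batch, seq, output_dim), scale/zero
    #                                                  expanded to columns via g_idx (group_axis=-1)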
@@ -454,20 +547,52 @@ class Embedding(Layer):
              self._embeddings.assign(embeddings_value)
              self.embeddings_scale.assign(embeddings_scale)
          elif mode == "int4":
-             # Quantize to int4 values (stored in int8 dtype, range [-8, 7]).
-             weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
-                 self.quantization_config,
-                 quantizers.AbsMaxQuantizer(
-                     axis=-1,
-                     value_range=(-8, 7),
-                     output_dtype="int8",
-                 ),
-             )
-             embeddings_value, embeddings_scale = weight_quantizer(
-                 self._embeddings, to_numpy=True
+             from keras.src.quantizers.quantization_config import (
+                 Int4QuantizationConfig,
              )
-             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-             # 2. Pack two int4 values into a single int8 byte.
+
+             block_size = None
+             if isinstance(self.quantization_config, Int4QuantizationConfig):
+                 block_size = self.quantization_config.block_size
+
+             use_grouped = block_size is not None and block_size != -1
+
+             if not use_grouped:
+                 # Per-channel quantization
+                 weight_quantizer = (
+                     QuantizationConfig.weight_quantizer_or_default(
+                         self.quantization_config,
+                         quantizers.AbsMaxQuantizer(
+                             axis=-1,
+                             value_range=(-8, 7),
+                             output_dtype="int8",
+                         ),
+                     )
+                 )
+                 embeddings_value, embeddings_scale = weight_quantizer(
+                     self._embeddings, to_numpy=True
+                 )
+                 embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
+             else:
+                 # Sub-channel quantization with asymmetric zero point
+                 input_dim, output_dim = ops.shape(self._embeddings)
+                 # Transpose to put output_dim first for grouped quantization
+                 embeddings_t = ops.transpose(self._embeddings)
+
+                 embeddings_value_t, scale_t, zero_t = (
+                     quantizers.abs_max_quantize_grouped_with_zero_point(
+                         embeddings_t,
+                         block_size=block_size,
+                         value_range=(-8, 7),
+                         dtype="int8",
+                         to_numpy=True,
+                     )
+                 )
+                 # Transpose back to (input_dim, output_dim) layout
+                 embeddings_value = ops.transpose(embeddings_value_t)
+                 embeddings_scale = ops.transpose(scale_t)
+                 embeddings_zero = ops.transpose(zero_t)
+
              packed_embeddings_value, _, _ = quantizers.pack_int4(
                  embeddings_value, axis=-1
              )
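Triggering the grouped path presumably looks something like the sketch below. The diff only shows that `Int4QuantizationConfig` exposes a `block_size` attribute and that the quantization entry point threads a `config` object down to `_int4_build`, so the constructor call and the `config=` keyword here are assumptions, not confirmed API:

    import keras
    from keras.src.quantizers.quantization_config import Int4QuantizationConfig

    layer = keras.layers.Embedding(input_dim=1000, output_dim=96)
    layer.build(None)
    # Hypothetical call signature; block_size=-1 (or no block_size at all)
    # would keep the existing per-channel behavior.
    layer.quantize("int4", config=Int4QuantizationConfig(block_size=32))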
@@ -477,12 +602,22 @@ class Embedding(Layer):
              )
              self._embeddings.assign(packed_embeddings_value)
              self.embeddings_scale.assign(embeddings_scale)
+             if use_grouped:
+                 self.embeddings_zero.assign(embeddings_zero)
          else:
              raise self._quantization_mode_error(mode)

          # Set new dtype policy.
          if self.dtype_policy.quantization_mode is None:
-             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
+             policy_name = mode
+             if mode == "int4":
+                 # Include block_size in policy name for sub-channel quantization
+                 block_size = get_block_size_for_layer(self, config)
+                 block_size_value = -1 if block_size is None else block_size
+                 policy_name = f"int4/{block_size_value}"
+             policy = dtype_policies.get(
+                 f"{policy_name}_from_{self.dtype_policy.name}"
+             )
              self.dtype_policy = policy

      def _get_embeddings_with_merged_lora(self):
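The block size now travels inside the dtype policy name, which is presumably what the `dtype_policy.py` changes elsewhere in this diff exist to parse. From the f-strings above, a layer with a `float32` base policy ends up with:

    # per-channel int4 (block_size None or -1):  "int4/-1_from_float32"
    # sub-channel int4 with block_size=32:       "int4/32_from_float32"
    # int8 (unchanged):                          "int8_from_float32"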
@@ -508,29 +643,46 @@ class Embedding(Layer):
          without modification.

          Returns:
-             A tuple `(embeddings_value, embeddings_scale)`:
+             A tuple `(embeddings_value, embeddings_scale, embeddings_zero)`:
                  `embeddings_value`: The merged embeddings. A quantized tensor if
                      quantization is active, otherwise a high precision tensor.
                  `embeddings_scale`: The quantization scale for the merged
                      embeddings. This is `None` if the layer is not quantized.
+                 `embeddings_zero`: The zero point for sub-channel quantization.
+                     This is `None` for per-channel quantization modes.
          """
          if self.dtype_policy.quantization_mode in (None, "gptq", "awq"):
-             return self.embeddings, None
+             return self.embeddings, None, None

          embeddings_value = self._embeddings
          embeddings_scale = self.embeddings_scale
+         embeddings_zero = getattr(self, "embeddings_zero", None)
+
          if not self.lora_enabled:
-             return embeddings_value, embeddings_scale
+             return embeddings_value, embeddings_scale, embeddings_zero
+
+         block_size = getattr(self, "_int4_block_size", None)

          # Dequantize embeddings to float.
          if self.quantization_mode == "int4":
              unpacked_embeddings = quantizers.unpack_int4(
                  embeddings_value, self._orig_output_dim, axis=-1
              )
-             float_embeddings = ops.divide(
-                 ops.cast(unpacked_embeddings, self.compute_dtype),
-                 ops.expand_dims(embeddings_scale, axis=-1),
-             )
+             if block_size is None or block_size == -1:
+                 # Per-channel dequantization
+                 float_embeddings = ops.divide(
+                     ops.cast(unpacked_embeddings, self.compute_dtype),
+                     ops.expand_dims(embeddings_scale, axis=-1),
+                 )
+             else:
+                 # Sub-channel: grouped dequantization using shared utility
+                 float_embeddings = dequantize_with_sz_map(
+                     ops.cast(unpacked_embeddings, self.compute_dtype),
+                     embeddings_scale,
+                     self.embeddings_zero,
+                     self.g_idx,
+                     group_axis=-1,
+                 )
              quant_range = (-8, 7)
          elif self.quantization_mode == "int8":
              float_embeddings = ops.divide(
@@ -550,20 +702,55 @@ class Embedding(Layer):
          merged_float_embeddings = ops.add(float_embeddings, lora_delta)

          # Requantize.
-         requantized_embeddings, embeddings_scale = quantizers.abs_max_quantize(
-             merged_float_embeddings,
-             axis=-1,
-             value_range=quant_range,
-             dtype="int8",
-             to_numpy=True,
-         )
-         embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-
-         # Pack if int4.
          if self.quantization_mode == "int4":
+             if block_size is None or block_size == -1:
+                 # Per-channel re-quantization
+                 requantized_embeddings, new_scale = quantizers.abs_max_quantize(
+                     merged_float_embeddings,
+                     axis=-1,
+                     value_range=quant_range,
+                     dtype="int8",
+                     to_numpy=True,
+                 )
+                 new_scale = ops.squeeze(new_scale, axis=-1)
+                 embeddings_zero = None
+             else:
+                 # Grouped re-quantization (asymmetric with zero point)
+                 merged_np = merged_float_embeddings
+                 # Transpose to (output_dim, input_dim) for grouped quantization
+                 merged_t = ops.transpose(merged_np)
+
+                 requantized_t, scale_t, zero_t = (
+                     quantizers.abs_max_quantize_grouped_with_zero_point(
+                         merged_t,
+                         block_size=block_size,
+                         value_range=quant_range,
+                         dtype="int8",
+                         to_numpy=True,
+                     )
+                 )
+                 # Transpose back
+                 requantized_embeddings = ops.transpose(requantized_t)
+                 new_scale = ops.transpose(scale_t)
+                 embeddings_zero = ops.transpose(zero_t)
+
+             # Pack for int4
              embeddings_value, _, _ = quantizers.pack_int4(
                  requantized_embeddings, axis=-1
              )
+             embeddings_scale = new_scale
          else:
+             # int8 re-quantization
+             requantized_embeddings, embeddings_scale = (
+                 quantizers.abs_max_quantize(
+                     merged_float_embeddings,
+                     axis=-1,
+                     value_range=quant_range,
+                     dtype="int8",
+                     to_numpy=True,
+                 )
+             )
+             embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
              embeddings_value = requantized_embeddings
-         return embeddings_value, embeddings_scale
+             embeddings_zero = None
+         return embeddings_value, embeddings_scale, embeddings_zero
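For reference, the per-channel requantize/dequantize pair used above reduces to a plain abs-max round trip; the grouped path presumably does the same per block_size-wide slice, with the extra zero point absorbing asymmetry. A minimal NumPy illustration with assumed values, and assuming the scale is the multiplier that the later divide-by-scale undoes:

    import numpy as np

    row = np.array([0.20, -0.50, 0.35, 0.10], dtype="float32")
    scale = 7.0 / np.abs(row).max()            # abs-max scale for value_range (-8, 7)
    q = np.clip(np.round(row * scale), -8, 7)  # int4 values stored in int8: [3, -7, 5, 1]
    recovered = q / scale                      # what the divide-by-scale dequantization yields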