keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ from keras.src.backend import any_symbolic_tensors
  from keras.src.backend.common.backend_utils import canonicalize_axis
  from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
  from keras.src.ops.operation import Operation
+ from keras.src.quantizers.gptq_config import GPTQConfig
 
  """Int8-related classes and methods"""
 
@@ -72,6 +73,23 @@ def abs_max_quantize(
      epsilon=backend.epsilon(),
      to_numpy=False,
  ):
+     """
+     Quantizes the input tensor using the absolute maximum quantization scheme.
+ 
+     Args:
+         inputs: Input tensor to quantize.
+         axis: Axis along which to compute the quantization range.
+         value_range: Tuple of the minimum and maximum values of the
+             quantization range.
+         dtype: Data type of the quantized output.
+         epsilon: Small value to avoid division by zero.
+         to_numpy: Whether to perform the quantization in numpy. This performs
+             the computation on the host CPU and can be useful for saving
+             memory on the device. If False, the computation is performed on
+             the device.
+ 
+     Returns:
+         A tuple of the quantized tensor and the scale.
+     """
      if to_numpy:
          # Save memory on the device using numpy
          original_dtype = backend.standardize_dtype(inputs.dtype)
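The new docstring and `to_numpy` flag are easiest to read alongside a usage sketch. A minimal example, assuming the function is exported as `keras.quantizers.abs_max_quantize` (as the `keras_export` decorators elsewhere in this file suggest):

```python
import numpy as np
from keras import quantizers

x = np.random.uniform(-1.0, 1.0, size=(4, 8)).astype("float32")

# Quantize to int8; the scale is derived from the absolute maximum
# along the reduced axis.
quantized, scale = quantizers.abs_max_quantize(x, axis=-1)

# to_numpy=True runs the same computation on the host CPU, which can
# reduce accelerator memory pressure when quantizing large weights.
quantized_host, scale_host = quantizers.abs_max_quantize(
    x, axis=-1, to_numpy=True
)
```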
@@ -104,31 +122,69 @@ def abs_max_quantize(
  class AbsMaxQuantizer(Quantizer):
      def __init__(
          self,
-         axis,
+         axis=None,  # Deprecated, provide axis in __call__ instead.
          value_range=(-127, 127),
          epsilon=backend.epsilon(),
          output_dtype="int8",
      ):
          Quantizer.__init__(self, output_dtype=output_dtype)
-         if isinstance(axis, int):
-             axis = (axis,)
-         self.axis = tuple(axis)
+         if axis is not None:
+             if isinstance(axis, int):
+                 axis = (axis,)
+             self.axis = tuple(axis)
+         else:
+             self.axis = None
          self.value_range = value_range
          self.epsilon = epsilon
+         if output_dtype == "int8":
+             if value_range[0] < -128 or value_range[1] > 127:
+                 raise ValueError(
+                     f"Quantizer with output_dtype='int8' requires value_range "
+                     f"to be within the interval [-128, 127]. Received: "
+                     f"value_range={value_range}"
+                 )
 
-     def __call__(self, x):
+     def __call__(self, x, axis=None, to_numpy=False):
+         """
+         Quantizes the input tensor.
+ 
+         Args:
+             x: Input tensor to quantize.
+             axis: Axis along which to compute the quantization range. If None,
+                 uses the axis specified in the constructor. If None and no axis
+                 was specified in the constructor, defaults to -1.
+             to_numpy: Whether to perform the quantization in numpy. This
+                 performs the computation on the host CPU and can be useful for
+                 saving memory on the device. If False, the computation is
+                 performed on the device.
+ 
+         Returns:
+             A tuple of the quantized tensor and the scale.
+         """
+         if axis is None:
+             axis = self.axis
+         if axis is None:
+             # Default to -1 if no axis is specified
+             axis = -1
          quantized_x, scale = abs_max_quantize(
-             x, self.axis, self.value_range, self.output_dtype, self.epsilon
+             x,
+             axis,
+             self.value_range,
+             self.output_dtype,
+             self.epsilon,
+             to_numpy,
          )
          return quantized_x, scale
 
      def get_config(self):
-         return {
-             "axis": self.axis,
+         config = {
              "value_range": self.value_range,
              "epsilon": self.epsilon,
              "output_dtype": self.output_dtype,
          }
+         if self.axis is not None:
+             config["axis"] = self.axis
+         return config
 
 
  def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
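Since `axis` moves from the constructor to `__call__`, a quick sketch of the updated calling convention (constructor `axis` still works but is deprecated):

```python
import numpy as np
from keras import quantizers

quantizer = quantizers.AbsMaxQuantizer()  # no constructor axis
x = np.random.uniform(-1.0, 1.0, size=(16, 32)).astype("float32")

# Axis is now supplied at call time; it falls back to the constructor
# value, then to -1 when neither is given.
q, scale = quantizer(x, axis=0)
q_default, scale_default = quantizer(x)  # same as axis=-1 here

# get_config() omits "axis" when it was never set, so older configs
# that stored an axis still deserialize unchanged.
print(quantizer.get_config())
```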
@@ -280,7 +336,7 @@ def fake_quant_with_min_max_vars(
          ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
      )
      x_clamped = ops.clip(
-         x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype)
+         ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
      )
      x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
      result = ops.multiply(
@@ -317,6 +373,7 @@ def fake_quant_with_min_max_vars(
              grad_min = ops.sum(grad_min, axis=axes)
          else:
              grad_min = ops.sum(grad_min)
+         grad_min = ops.reshape(grad_min, ops.shape(min_val))
 
          # Gradient for max_val
          # When x is clipped to max, the gradient flows to max_val
@@ -326,6 +383,7 @@ def fake_quant_with_min_max_vars(
              grad_max = ops.sum(grad_max, axis=axes)
          else:
              grad_max = ops.sum(grad_max)
+         grad_max = ops.reshape(grad_max, ops.shape(max_val))
 
          return dx, grad_min, grad_max
 
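The two `reshape` additions pin the gradients of `min_val` and `max_val` to the parameter shapes. For context, a hedged sketch of the public entry point these gradients back (assuming the signature `fake_quant_with_min_max_vars(inputs, min_vals, max_vals, num_bits=8, ...)` and the `keras.quantizers` export; both are assumptions, not confirmed by this diff):

```python
import numpy as np
from keras import ops, quantizers

x = np.linspace(-1.5, 1.5, num=8, dtype="float32")

# Values are nudged onto an 8-bit grid between min and max; inputs outside
# the range are clamped, which is exactly where the min/max gradients
# (reshaped above to match the variable shapes) come from.
y = quantizers.fake_quant_with_min_max_vars(x, -1.0, 1.0, num_bits=8)
print(ops.convert_to_numpy(y))
```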
@@ -377,7 +435,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
 
 
  @keras_export("keras.quantizers.pack_int4")
- def pack_int4(arr, axis=0):
+ def pack_int4(arr, axis=0, dtype="int8"):
      """Pack an int4 tensor into an int8 tensor with packed nibbles.
 
      The input values must already be int8 in the signed range `[-8, 7]` and
@@ -389,8 +447,11 @@ def pack_int4(arr, axis=0):
      the value from the second row.
 
      Args:
-         arr: An int8 tensor containing int4 values in the range `[-8, 7]`.
+         arr: An `int8` or `uint8` tensor containing int4 values in the range
+             `[-8, 7]`.
          axis: The axis along which to pack the tensor. Defaults to 0.
+         dtype: The data type of the input and packed tensor. Can be
+             `"int8"` or `"uint8"`. Defaults to `"int8"`.
 
      Returns:
          tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is
@@ -450,9 +511,14 @@ def pack_int4(arr, axis=0):
      True
      ```
      """
-     if backend.standardize_dtype(arr.dtype) != "int8":
+     if dtype not in ("int8", "uint8"):
+         raise ValueError(
+             f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+         )
+     if backend.standardize_dtype(arr.dtype) != dtype:
          raise TypeError(
-             "Expected int8 tensor for packing, got {}".format(arr.dtype)
+             f"Expected {dtype} tensor for packing, got "
+             f"{backend.standardize_dtype(arr.dtype)}."
          )
 
      rank = getattr(arr.shape, "rank", None) or len(arr.shape)
@@ -486,12 +552,12 @@ def pack_int4(arr, axis=0):
      low = padded[::2, ...]
      high = padded[1::2, ...]
 
-     mask = ops.array(0x0F, dtype="int8")
+     mask = ops.array(0x0F, dtype=dtype)
      low_u = ops.bitwise_and(low, mask)
      high_u = ops.bitwise_and(high, mask)
 
      packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
-     packed = ops.cast(packed, "int8")
+     packed = ops.cast(packed, dtype)
 
      # 5-6. Restore shape.
      packed = ops.transpose(packed, inv_perm)  # back to original order
@@ -500,7 +566,7 @@
 
 
  @keras_export("keras.quantizers.unpack_int4")
- def unpack_int4(packed, orig_len, axis=0):
+ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
      """Unpack a packed int4 back to an int8 tensor in the range [-8, 7].
 
      This function reverses the packing performed by `pack_int4`, restoring
@@ -518,6 +584,8 @@ def unpack_int4(packed, orig_len, axis=0):
              packed. This is used to remove any padding that may have
              been added during packing to ensure an even number of rows.
          axis: The axis along which the tensor was packed. Defaults to 0.
+         dtype: The data type of the input and unpacked tensor. Can be
+             `"int8"` or `"uint8"`. Defaults to `"int8"`.
 
      Returns:
          unpacked: An int8 tensor with the same shape as the original
@@ -574,13 +642,27 @@ def unpack_int4(packed, orig_len, axis=0):
      True
      ```
      """
-     if backend.standardize_dtype(packed.dtype) != "int8":
+     if dtype not in ("int8", "uint8"):
+         raise ValueError(
+             f"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'."
+         )
+
+     if backend.standardize_dtype(packed.dtype) not in ("int8", "uint8"):
          raise TypeError(
-             f"Expected int8 tensor for unpacking, got {packed.dtype}"
+             f"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}"
          )
 
-     rank = getattr(packed.shape, "rank", None) or len(packed.shape)
+     def to_signed(x):
+         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+ 
+         Uses a branchless XOR approach: (x ^ 8) - 8
+         This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+         """
+         dtype_x = backend.standardize_dtype(x.dtype)
+         eight = ops.cast(8, dtype_x)
+         return ops.subtract(ops.bitwise_xor(x, eight), eight)
 
+     rank = getattr(packed.shape, "rank", None) or len(packed.shape)
      if axis < 0:
          axis += rank
 
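The branchless mapping in `to_signed` is easy to verify in plain Python; `(x ^ 8) - 8` reinterprets a 4-bit unsigned nibble as a signed int4 without any `where` branch:

```python
# Check the documented mapping: 0..7 stay put, 8..15 wrap to -8..-1.
for nibble in range(16):
    print(nibble, "->", (nibble ^ 8) - 8)
```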
@@ -591,16 +673,15 @@ def unpack_int4(packed, orig_len, axis=0):
          low_unpacked = ops.bitwise_and(packed, mask)
          high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)
 
-         # Convert values from [0, 15] to [-8, 7].
-         low_signed = ops.where(
-             low_unpacked < 8, low_unpacked, low_unpacked - 16
-         )
-         high_signed = ops.where(
-             high_unpacked < 8, high_unpacked, high_unpacked - 16
-         )
+         if dtype == "int8":
+             low_unpacked = to_signed(low_unpacked)
+             high_unpacked = to_signed(high_unpacked)
+ 
+         low_final = ops.cast(low_unpacked, dtype)
+         high_final = ops.cast(high_unpacked, dtype)
 
          # Interleave and reshape
-         stacked = ops.stack([low_signed, high_signed], axis=1)
+         stacked = ops.stack([low_final, high_final], axis=1)
          unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))
 
          # Remove padding and return
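Taken together, the `pack_int4` / `unpack_int4` changes keep the int8 round trip intact while adding a uint8 carrier option. A minimal round-trip sketch using the exported helpers:

```python
import numpy as np
from keras import ops, quantizers

# int4 values in the signed range [-8, 7], stored in an int8 carrier.
arr = ops.convert_to_tensor(
    np.array([[-8, -1, 0, 7], [3, -4, 5, -6]], dtype="int8")
)

# Two rows pack into one row of nibble pairs along axis 0.
packed, packed_shape, orig_rows = quantizers.pack_int4(arr, axis=0)
restored = quantizers.unpack_int4(packed, orig_rows, axis=0)

print(
    np.array_equal(ops.convert_to_numpy(arr), ops.convert_to_numpy(restored))
)  # True
```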
@@ -612,25 +693,313 @@ def unpack_int4(packed, orig_len, axis=0):
      transposed = ops.transpose(packed, perm)
 
      # 1. Split nibbles.
-     mask = ops.array(0x0F, dtype="int8")  # int8 arrays
+     mask = ops.array(0x0F, dtype=packed.dtype)
      low = ops.bitwise_and(transposed, mask)
      high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)
 
-     eight = ops.array(8, dtype="int8")
-     sixteen = ops.array(16, dtype="int8")
-
-     def to_signed(x):
-         return ops.where(x < eight, x, x - sixteen)
+     # 2. Conditionally convert to signed.
+     if dtype == "int8":
+         low = to_signed(low)
+         high = to_signed(high)
 
-     low = to_signed(low)
-     high = to_signed(high)
+     low = ops.cast(low, dtype)
+     high = ops.cast(high, dtype)
 
-     # 2. Interleave and reshape.
-     stacked = ops.stack([low, high], axis=1)  # (pairs, 2, ...)
+     # 3. Interleave and reshape.
+     stacked = ops.stack([low, high], axis=1)
      unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))
 
      # 4. Remove padding and restore original layout.
      unpacked = unpacked[:orig_len, ...]
      unpacked = ops.transpose(unpacked, inv_perm)
 
-     return unpacked  # dtype is int8
+     return unpacked
+
+
+ class GPTQQuantizer(Quantizer):
+     """A class that handles the quantization of weights using the GPTQ
+     method.
+ 
+     This class provides methods to find quantization parameters (scale and
+     zero) for a given tensor and can be used to quantize weights in a GPTQ
+     context.
+ 
+     Args:
+         config: A `GPTQConfig` instance whose settings drive the
+             quantization: `weight_bits` is the number of bits to quantize
+             to (e.g., 4); `per_channel` selects per-channel (`True`) or
+             per-tensor (`False`) quantization, defaulting to `False`;
+             `symmetric` selects symmetric (`True`) or asymmetric (`False`)
+             quantization, defaulting to `False`; `group_size` is the size
+             of weight groups, where -1 indicates that grouping is not
+             used and is the default.
+         compute_dtype: The dtype used for computation. Defaults to
+             `"float32"`.
+     """
+ 
+     def __init__(
+         self,
+         config=GPTQConfig(tokenizer=None, dataset=None),
+         compute_dtype="float32",
+     ):
+         Quantizer.__init__(self)
+         self.weight_bits = config.weight_bits
+         self.per_channel = config.per_channel
+         self.symmetric = config.symmetric
+         self.group_size = config.group_size
+         self.compute_dtype = compute_dtype
+ 
+         # These are now determined later by `find_params`
+         self.scale = None
+         self.zero = None
+         self.maxq = None
+ 
+     def find_params(self, input_tensor):
+         """Finds quantization parameters (scale and zero) for a tensor."""
+         self.scale, self.zero, self.maxq = compute_quantization_parameters(
+             input_tensor,
+             bits=self.weight_bits,
+             symmetric=self.symmetric,
+             per_channel=self.per_channel,
+             group_size=self.group_size,
+             compute_dtype=self.compute_dtype,
+         )
+         return self.scale, self.zero, self.maxq
+ 
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "weight_bits": self.weight_bits,
+                 "per_channel": self.per_channel,
+                 "symmetric": self.symmetric,
+                 "group_size": self.group_size,
+             }
+         )
+         return config
+ 
+     @classmethod
+     def from_config(cls, config):
+         gptq = GPTQConfig(
+             tokenizer=None,
+             dataset=None,
+             weight_bits=config["weight_bits"],
+             per_channel=config["per_channel"],
+             symmetric=config["symmetric"],
+             group_size=config["group_size"],
+         )
+         return cls(gptq)
+
+
+ def compute_quantization_parameters(
+     x,
+     *,
+     bits,
+     symmetric=False,
+     per_channel=False,
+     group_size=-1,
+     compute_dtype="float32",
+ ):
+     """
+     Computes the scale and zero-point for quantizing weight tensors.
+ 
+     This function calculates the scale and zero-point required for
+     quantizing a given weight tensor `x` based on the specified parameters.
+     It supports grouped, per-channel, per-tensor, symmetric, and asymmetric
+     quantization.
+ 
+     For grouped quantization (per_channel=True, group_size > 0), the output
+     shapes are [out_features, n_groups] where n_groups is the number of
+     groups along the in_features dimension.
+ 
+     Args:
+         x: KerasTensor. The weight tensor to quantize with shape
+             [out_features, in_features].
+         bits: int. The number of bits to quantize to (e.g., 4).
+         symmetric: bool. Whether to use symmetric quantization.
+         per_channel: bool. Whether to quantize per channel.
+         group_size: int. The group size for quantization. -1 means no
+             grouping.
+         compute_dtype: str. The dtype for computation. Defaults to
+             "float32".
+ 
+     Returns:
+         scale: KerasTensor. The scale tensor for quantization.
+         zero: KerasTensor. The zero tensor for quantization.
+         maxq: scalar. The maximum quantization value.
+     """
+     # Input validation
+     if x is None:
+         raise ValueError("Input tensor cannot be None.")
+     if len(x.shape) < 2:
+         raise ValueError(
+             f"Input weight tensor must have a rank of at "
+             f"least 2, but got rank {len(x.shape)}."
+         )
+     if ops.size(x) == 0:
+         raise ValueError("Input tensor 'x' cannot be empty.")
+ 
+     out_features, in_features = x.shape[0], x.shape[1]
+ 
+     # Determine number of groups for quantization
+     if per_channel and group_size > 0:
+         n_groups = (in_features + group_size - 1) // group_size
+     else:
+         n_groups = 1
+ 
+     # Compute min/max values based on quantization mode
+     if n_groups > 1:
+         # Grouped quantization: output shape [out_features, n_groups]
+         remainder = in_features % group_size
+         if remainder != 0:
+             pad_size = group_size - remainder
+             x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+ 
+         x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+         min_values = ops.min(x_grouped, axis=2)
+         max_values = ops.max(x_grouped, axis=2)
+     else:
+         # Per-channel or per-tensor: compute stats along rows
+         reduction_shape = [out_features, -1] if per_channel else [1, -1]
+         x_reshaped = ops.reshape(x, reduction_shape)
+         min_values = ops.min(x_reshaped, axis=1)
+         max_values = ops.max(x_reshaped, axis=1)
+ 
+     # Symmetric quantization: make range symmetric around zero
+     if symmetric:
+         max_abs = ops.maximum(ops.abs(min_values), max_values)
+         min_values = ops.where(
+             ops.less(min_values, 0), ops.negative(max_abs), min_values
+         )
+         max_values = max_abs
+ 
+     # Ensure non-zero range to avoid division errors
+     zero_range = ops.equal(min_values, max_values)
+     min_values = ops.where(
+         zero_range, ops.subtract(min_values, 1), min_values
+     )
+     max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)
+ 
+     # Compute scale and zero-point
+     maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
+     scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+ 
+     if symmetric:
+         zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
+     else:
+         zero = ops.round(ops.divide(ops.negative(min_values), scale))
+ 
+     # Reshape output to [out_features, n_groups] or [out_features, 1]
+     if n_groups > 1:
+         pass  # Already [out_features, n_groups]
+     elif per_channel:
+         scale = ops.reshape(scale, [-1, 1])
+         zero = ops.reshape(zero, [-1, 1])
+     else:
+         # Per-tensor: tile single value to [out_features, 1]
+         scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+         zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))
+ 
+     return scale, ops.cast(zero, "uint8"), maxq
+
+
+ def quantize_with_zero_point(input_tensor, scale, zero, maxq):
+     """Quantize a float tensor into discrete levels [0, maxq] using
+     per-tensor/per-channel/grouped scaling.
+ 
+     Returns `q` (same dtype as the inputs and scales; float is fine), with
+     values in [0, maxq].
+ 
+     Args:
+         input_tensor: KerasTensor. The input tensor to quantize.
+         scale: KerasTensor. The scale tensor for quantization.
+         zero: KerasTensor. The zero tensor for quantization.
+         maxq: KerasTensor. The maximum quantization value.
+ 
+     Returns:
+         KerasTensor. The quantized tensor.
+     """
+     # Guard against divide-by-zero
+     epsilon = ops.cast(1e-8, dtype=scale.dtype)
+     safe_scale = ops.where(ops.equal(scale, 0), epsilon, scale)
+ 
+     quantized_tensor = ops.round(
+         ops.add(
+             ops.divide(input_tensor, safe_scale), ops.cast(zero, scale.dtype)
+         )
+     )
+     quantized_tensor = ops.clip(quantized_tensor, 0, maxq)
+     return quantized_tensor
+
+
+ def dequantize_with_zero_point(input_tensor, scale, zero):
+     """
+     Dequantizes a quantized tensor using the provided scale and zero
+     tensors.
+ 
+     Args:
+         input_tensor: KerasTensor. The quantized tensor to dequantize.
+         scale: KerasTensor. The scale tensor for dequantization.
+         zero: KerasTensor. The zero tensor for dequantization.
+ 
+     Returns:
+         KerasTensor. The dequantized tensor.
+     """
+     return ops.multiply(
+         scale, ops.subtract(input_tensor, ops.cast(zero, scale.dtype))
+     )
+
+
+ def quantize_with_sz_map(weights_matrix, scale, zero, g_idx, maxq):
+     """Quantize the weight matrix from group params.
+ 
+     This function uses the provided scale and zero tensors to quantize the
+     input weights_matrix according to the group indices. It maps each
+     column of the weights_matrix to its corresponding group parameters and
+     performs the quantization operation.
+ 
+     Args:
+         weights_matrix: 2D tensor of shape [out_features, in_features].
+         scale: Per-group scale tensor of shape [out_features, n_groups].
+         zero: Per-group zero-point tensor of shape [out_features, n_groups].
+         g_idx: Integer tensor of shape [in_features,] mapping each column
+             to its group index.
+         maxq: Scalar (float) representing the maximum integer quantization
+             level (e.g., 2^bits - 1).
+ 
+     Returns:
+         A tensor with the same shape as `weights_matrix` containing the
+         quantized weights produced using the provided group parameters.
+     """
+     groups = ops.cast(g_idx, "int32")
+     scale_cols = ops.take(scale, groups, axis=1)  # [out_features, in_features]
+     zero_cols = ops.take(zero, groups, axis=1)  # [out_features, in_features]
+ 
+     # Quantize elementwise using the mapped per-column parameters.
+     return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)
+
+
+ def dequantize_with_sz_map(weights_matrix, scale, zero, g_idx):
+     """Rebuild a dequantized weight matrix from group params.
+ 
+     This function uses the provided scale and zero tensors to dequantize
+     the input weights_matrix according to the group indices. It maps each
+     column of the weights_matrix to its corresponding group parameters and
+     performs the dequantization operation.
+ 
+     Args:
+         weights_matrix: 2D tensor of shape [out_features, in_features].
+         scale: Per-group scale tensor of shape [out_features, n_groups].
+         zero: Per-group zero-point tensor of shape [out_features, n_groups].
+         g_idx: Integer tensor of shape [in_features,] mapping each column
+             to its group index.
+ 
+     Returns:
+         A tensor with the same shape as `weights_matrix` containing the
+         dequantized weights produced using the provided group parameters.
+     """
+     # Map group indices to scales and zeros
+     groups = ops.cast(g_idx, "int32")
+     scales_mapped = ops.take(scale, groups, axis=1)
+     zeros_mapped = ops.take(zero, groups, axis=1)
+     zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)
+ 
+     dequantized = ops.multiply(
+         ops.subtract(weights_matrix, zeros_mapped), scales_mapped
+     )
+ 
+     return dequantized
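To make the new GPTQ helpers concrete, here is a small quantize/dequantize round trip. This is a sketch assuming the helpers are importable from `keras.src.quantizers.quantizers` (the file this hunk appears to modify, per the file list above); shapes follow the docstrings:

```python
import numpy as np
from keras import ops
from keras.src.quantizers.quantizers import (
    compute_quantization_parameters,
    dequantize_with_zero_point,
    quantize_with_zero_point,
)

# Toy [out_features, in_features] weight matrix.
weights = np.random.normal(size=(8, 16)).astype("float32")

# 4-bit asymmetric, per-channel: scale and zero come back with shape
# [out_features, 1], and maxq is 2**4 - 1 = 15.
scale, zero, maxq = compute_quantization_parameters(
    weights, bits=4, symmetric=False, per_channel=True
)

q = quantize_with_zero_point(weights, scale, zero, maxq)  # levels in [0, 15]
w_hat = dequantize_with_zero_point(q, scale, zero)

# Per-element reconstruction error is bounded by roughly scale / 2.
err = ops.max(ops.abs(ops.subtract(weights, w_hat)))
print(ops.convert_to_numpy(err))
```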
@@ -0,0 +1,23 @@
+ import re
+
+
+ def should_quantize_layer(layer, filters):
+     """Determines if a layer should be quantized based on filters.
+ 
+     Args:
+         layer: The layer to check.
+         filters: A regex string, a list of regex strings, or a callable.
+             If None, returns True.
+ 
+     Returns:
+         True if the layer should be quantized, False otherwise.
+     """
+     if filters is None:
+         return True
+     if isinstance(filters, str):
+         return bool(re.search(filters, layer.name))
+     if isinstance(filters, (list, tuple)):
+         return any(re.search(pat, layer.name) for pat in filters)
+     if callable(filters):
+         return filters(layer)
+     return True
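A quick behavioral sketch of the new filter helper (assuming it lives at `keras.src.quantizers.utils`, which matches the +23-line file in the list above; the layer name is illustrative):

```python
from keras import layers
from keras.src.quantizers.utils import should_quantize_layer

dense = layers.Dense(4, name="attention_output_dense")

print(should_quantize_layer(dense, None))  # True: no filter quantizes all
print(should_quantize_layer(dense, r"dense$"))  # True: re.search on the name
print(should_quantize_layer(dense, [r"^conv", r"attention"]))  # True: any match
print(should_quantize_layer(dense, lambda l: isinstance(l, layers.Dense)))  # True
```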
@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
  from keras.src.utils import jax_utils
  from keras.src.utils.naming import auto_name
 
+ GLOBAL_SEED_GENERATOR = "global_seed_generator"
+
 
  @keras_export("keras.random.SeedGenerator")
  class SeedGenerator:
@@ -27,7 +29,7 @@ class SeedGenerator:
      a local `StateGenerator` with either a deterministic or random initial
      state.
 
-     Remark concerning the JAX backen: Note that the use of a local
+     Remark concerning the JAX backend: Note that the use of a local
      `StateGenerator` as seed argument is required for JIT compilation of
      RNG with the JAX backend, because the use of global state is not
      supported.
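The (now correctly spelled) remark describes the standard pattern: keep a `SeedGenerator` on the layer and pass it as `seed`, so RNG state updates stay traceable under JIT. A minimal sketch with an illustrative `NoisyDense` layer (the layer itself is not part of this diff):

```python
import keras
from keras import layers


class NoisyDense(layers.Dense):
    """Dense layer that adds Gaussian noise during training."""

    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        # A local generator is required for jitted RNG on the JAX backend.
        self.seed_generator = keras.random.SeedGenerator(seed=1337)

    def call(self, inputs, training=None):
        outputs = super().call(inputs)
        if training:
            noise = keras.random.normal(
                keras.ops.shape(outputs), stddev=0.1, seed=self.seed_generator
            )
            outputs = outputs + noise
        return outputs
```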
@@ -109,7 +111,7 @@ class SeedGenerator:
          return new_seed_value
 
      def get_config(self):
-         return {"seed": self._initial_seed}
+         return {"seed": self._initial_seed, "name": self.name}
 
      @classmethod
      def from_config(cls, config):
@@ -133,10 +135,10 @@ def global_seed_generator():
          "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
          "```"
      )
-     gen = global_state.get_global_attribute("global_seed_generator")
+     gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
      if gen is None:
          gen = SeedGenerator()
-         global_state.set_global_attribute("global_seed_generator", gen)
+         global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
      return gen
 
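The `get_config` change means a generator's name now survives serialization. A short sketch of the round trip through the public `keras.random.SeedGenerator` API:

```python
import keras

gen = keras.random.SeedGenerator(seed=42, name="dropout_seed")
config = gen.get_config()  # now {"seed": 42, "name": "dropout_seed"}

restored = keras.random.SeedGenerator.from_config(config)
print(restored.name)  # "dropout_seed": no auto-generated rename on reload
```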