keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/quantizers/quantization_config.py (new file)
@@ -0,0 +1,246 @@
+ from keras.src.api_export import keras_export
+ from keras.src.dtype_policies import QUANTIZATION_MODES
+ from keras.src.saving import serialization_lib
+
+
+ @keras_export("keras.quantizers.QuantizationConfig")
+ class QuantizationConfig:
+     """Base class for quantization configs.
+
+     Subclasses must implement the `mode` property and the `get_config` and
+     `from_config` class methods.
+
+     Args:
+         weight_quantizer: Quantizer for weights.
+         activation_quantizer: Quantizer for activations.
+     """
+
+     def __init__(self, weight_quantizer=None, activation_quantizer=None):
+         self.weight_quantizer = weight_quantizer
+         self.activation_quantizer = activation_quantizer
+
+     @property
+     def mode(self):
+         raise NotImplementedError(
+             "Subclasses must implement this property. Do not instantiate "
+             "QuantizationConfig directly."
+         )
+
+     def get_config(self):
+         return {
+             "weight_quantizer": serialization_lib.serialize_keras_object(
+                 self.weight_quantizer
+             ),
+             "activation_quantizer": serialization_lib.serialize_keras_object(
+                 self.activation_quantizer
+             ),
+         }
+
+     @classmethod
+     def from_config(cls, config):
+         weight_quantizer = serialization_lib.deserialize_keras_object(
+             config.get("weight_quantizer")
+         )
+         activation_quantizer = serialization_lib.deserialize_keras_object(
+             config.get("activation_quantizer")
+         )
+         return cls(
+             weight_quantizer=weight_quantizer,
+             activation_quantizer=activation_quantizer,
+         )
+
+     @staticmethod
+     def weight_quantizer_or_default(config, default):
+         if config is not None and config.weight_quantizer is not None:
+             return config.weight_quantizer
+         return default
+
+     @staticmethod
+     def activation_quantizer_or_default(config, default):
+         if config is not None:
+             return config.activation_quantizer
+         return default
+
+
+ @keras_export("keras.quantizers.Int8QuantizationConfig")
+ class Int8QuantizationConfig(QuantizationConfig):
+     """Int8 quantization config.
+
+     Args:
+         weight_quantizer: Quantizer for weights.
+         activation_quantizer: Quantizer for activations. If "default", uses
+             AbsMaxQuantizer with axis=-1.
+     """
+
+     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+         from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+         if activation_quantizer == "default":
+             activation_quantizer = AbsMaxQuantizer()
+         super().__init__(weight_quantizer, activation_quantizer)
+         if self.weight_quantizer is not None:
+             if self.weight_quantizer.output_dtype != "int8":
+                 raise ValueError(
+                     "Int8QuantizationConfig requires a weight_quantizer "
+                     "with output_dtype='int8'. Received: "
+                     f"output_dtype={self.weight_quantizer.output_dtype}"
+                 )
+
+     @property
+     def mode(self):
+         return "int8"
+
+
+ @keras_export("keras.quantizers.Int4QuantizationConfig")
+ class Int4QuantizationConfig(QuantizationConfig):
+     """Int4 quantization config.
+
+     Args:
+         weight_quantizer: Quantizer for weights.
+         activation_quantizer: Quantizer for activations. If "default", uses
+             AbsMaxQuantizer with axis=-1.
+     """
+
+     def __init__(self, weight_quantizer=None, activation_quantizer="default"):
+         from keras.src.quantizers.quantizers import AbsMaxQuantizer
+
+         if activation_quantizer == "default":
+             activation_quantizer = AbsMaxQuantizer()
+         super().__init__(weight_quantizer, activation_quantizer)
+         if self.weight_quantizer is not None:
+             if self.weight_quantizer.value_range != (-8, 7):
+                 raise ValueError(
+                     "Int4QuantizationConfig requires a weight_quantizer "
+                     "with value_range=(-8, 7). Received: "
+                     f"value_range={self.weight_quantizer.value_range}"
+                 )
+
+             if self.weight_quantizer.output_dtype != "int8":
+                 raise ValueError(
+                     "Int4QuantizationConfig requires a weight_quantizer "
+                     "with output_dtype='int8'. Received: "
+                     f"output_dtype={self.weight_quantizer.output_dtype}"
+                 )
+
+     @property
+     def mode(self):
+         return "int4"
+
+
+ @keras_export("keras.quantizers.Float8QuantizationConfig")
+ class Float8QuantizationConfig(QuantizationConfig):
+     """FP8 quantization config.
+
+     FP8 mixed-precision training does not support user defined quantizers.
+     This config is only used to indicate that FP8 mixed-precision training
+     should be used.
+     """
+
+     def __init__(self):
+         super().__init__(None, None)
+
+     @property
+     def mode(self):
+         return "float8"
+
+     def get_config(self):
+         return {}
+
+     @classmethod
+     def from_config(cls, config):
+         return cls()
+
+
+ def validate_and_resolve_config(mode, config):
+     """Validate and resolve quantization config.
+
+     This function validates the quantization config and resolves the mode.
+     If mode is not provided, it is inferred from the config.
+     If config is not provided, a default config is inferred from the mode.
+
+     Args:
+         mode: Quantization mode.
+         config: Quantization config.
+     """
+     # 1. Backwards compatibility: handle string shortcuts.
+     if isinstance(config, str):
+         mode = config
+         config = None
+
+     _validate_mode(mode)
+
+     # 2. Resolve "mode" into a config object.
+     if config is None:
+         if mode == "int8":
+             config = Int8QuantizationConfig()
+         elif mode == "int4":
+             config = Int4QuantizationConfig()
+         elif mode == "float8":
+             config = Float8QuantizationConfig()
+         elif mode == "gptq":
+             raise ValueError(
+                 "For GPTQ, you must pass a `GPTQConfig` object in the "
+                 "`config` argument."
+             )
+         elif mode == "awq":
+             raise ValueError(
+                 "For AWQ, you must pass an `AWQConfig` object in the "
+                 "`config` argument."
+             )
+         else:
+             if mode is not None:
+                 raise ValueError(
+                     f"Invalid quantization mode. Received: mode={mode}"
+                 )
+             raise ValueError(
+                 "You must provide either `mode` or `config` to `quantize`."
+             )
+     else:
+         if not isinstance(config, QuantizationConfig):
+             raise ValueError(
+                 "Argument `config` must be an instance of "
+                 "`QuantizationConfig`. "
+                 f"Received: config={config} (of type {type(config)})"
+             )
+
+     # 3. Validation: prevent contradictions.
+     if mode is not None and config.mode != mode:
+         raise ValueError(
+             f"Contradictory arguments: mode='{mode}' but "
+             f"config.mode='{config.mode}'"
+         )
+
+     # Ensure mode is consistent.
+     mode = config.mode
+
+     # Ensure the mode derived from the config is valid.
+     _validate_mode(mode)
+
+     if mode == "gptq":
+         from keras.src.quantizers.gptq_config import GPTQConfig
+
+         if not isinstance(config, GPTQConfig):
+             raise ValueError(
+                 "Mode 'gptq' requires a valid `config` argument of type "
+                 f"`GPTQConfig`. Received: {type(config)}"
+             )
+
+     if mode == "awq":
+         from keras.src.quantizers.awq_config import AWQConfig
+
+         if not isinstance(config, AWQConfig):
+             raise ValueError(
+                 "Mode 'awq' requires a valid `config` argument of type "
+                 f"`AWQConfig`. Received: {type(config)}"
+             )
+
+     return config
+
+
+ def _validate_mode(mode):
+     """Validates quantization mode."""
+     if mode is not None and mode not in QUANTIZATION_MODES:
+         raise ValueError(
+             "Invalid quantization mode. "
+             f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}"
+         )
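
The new config classes and the resolver added above can be exercised directly. The following is a minimal usage sketch, not part of the diff: the import path follows the file list above, and how `Model.quantize` consumes a config is only inferred from the resolver's error messages.

# Hypothetical usage sketch of the new quantization config API (assumption,
# not taken from the diff itself).
from keras.src.quantizers.quantization_config import (
    Int8QuantizationConfig,
    validate_and_resolve_config,
)
from keras.src.quantizers.quantizers import AbsMaxQuantizer

# Per-channel int8 weight quantizer; output_dtype must be "int8", otherwise
# Int8QuantizationConfig raises a ValueError.
config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=0, output_dtype="int8")
)
assert config.mode == "int8"

# A bare mode string still resolves to a default config ...
assert validate_and_resolve_config("int8", None).mode == "int8"
# ... and an explicit config object is passed through unchanged.
assert validate_and_resolve_config(None, config) is config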
keras/src/quantizers/quantizers.py
@@ -73,6 +73,23 @@ def abs_max_quantize(
      epsilon=backend.epsilon(),
      to_numpy=False,
  ):
+     """
+     Quantizes the input tensor using the absolute maximum quantization scheme.
+
+     Args:
+         inputs: Input tensor to quantize.
+         axis: Axis along which to compute the quantization range.
+         value_range: Tuple of the minimum and maximum values of the
+             quantization range.
+         dtype: Data type of the quantized output.
+         epsilon: Small value to avoid division by zero.
+         to_numpy: Whether to perform the quantization in numpy. This performs
+             the computation on the host CPU and can be useful for saving
+             memory on the device. If False, the computation is performed on
+             the device.
+
+     Returns:
+         A tuple of the quantized tensor and the scale.
+     """
      if to_numpy:
          # Save memory on the device using numpy
          original_dtype = backend.standardize_dtype(inputs.dtype)
@@ -105,31 +122,69 @@ def abs_max_quantize(
  class AbsMaxQuantizer(Quantizer):
      def __init__(
          self,
-         axis,
+         axis=None,  # Deprecated, provide axis in __call__ instead.
          value_range=(-127, 127),
          epsilon=backend.epsilon(),
         output_dtype="int8",
      ):
          Quantizer.__init__(self, output_dtype=output_dtype)
-         if isinstance(axis, int):
-             axis = (axis,)
-         self.axis = tuple(axis)
+         if axis is not None:
+             if isinstance(axis, int):
+                 axis = (axis,)
+             self.axis = tuple(axis)
+         else:
+             self.axis = None
          self.value_range = value_range
          self.epsilon = epsilon
+         if output_dtype == "int8":
+             if value_range[0] < -128 or value_range[1] > 127:
+                 raise ValueError(
+                     f"Quantizer with output_dtype='int8' requires value_range "
+                     f"to be within the interval [-128, 127]. Received: "
+                     f"value_range={value_range}"
+                 )

-     def __call__(self, x):
+     def __call__(self, x, axis=None, to_numpy=False):
+         """
+         Quantizes the input tensor.
+
+         Args:
+             x: Input tensor to quantize.
+             axis: Axis along which to compute the quantization range. If
+                 None, uses the axis specified in the constructor. If None
+                 and no axis was specified in the constructor, defaults to -1.
+             to_numpy: Whether to perform the quantization in numpy. This
+                 performs the computation on the host CPU and can be useful
+                 for saving memory on the device. If False, the computation
+                 is performed on the device.
+
+         Returns:
+             A tuple of the quantized tensor and the scale.
+         """
+         if axis is None:
+             axis = self.axis
+         if axis is None:
+             # Default to -1 if no axis is specified
+             axis = -1
          quantized_x, scale = abs_max_quantize(
-             x, self.axis, self.value_range, self.output_dtype, self.epsilon
+             x,
+             axis,
+             self.value_range,
+             self.output_dtype,
+             self.epsilon,
+             to_numpy,
          )
          return quantized_x, scale

      def get_config(self):
-         return {
-             "axis": self.axis,
+         config = {
              "value_range": self.value_range,
              "epsilon": self.epsilon,
              "output_dtype": self.output_dtype,
          }
+         if self.axis is not None:
+             config["axis"] = self.axis
+         return config


  def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
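
A small sketch of the call-time `axis` behavior introduced above (assumed usage, not part of the diff):

# Sketch of the new call-time axis behavior of AbsMaxQuantizer.
from keras import ops
from keras.src.quantizers.quantizers import AbsMaxQuantizer

x = ops.reshape(ops.arange(12, dtype="float32"), (3, 4))
quantizer = AbsMaxQuantizer()               # no axis fixed at construction
q_rows, scale_rows = quantizer(x, axis=-1)  # axis chosen per call
q_default, scale_default = quantizer(x)     # falls back to axis=-1
# `axis` is omitted from the serialized config when it was never set:
assert "axis" not in quantizer.get_config()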
@@ -281,7 +336,7 @@ def fake_quant_with_min_max_vars(
              ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
          )
          x_clamped = ops.clip(
-             x, ops.cast(nudged_min, x.dtype), ops.cast(nudged_max, x.dtype)
+             ops.cast(x, nudged_min.dtype), nudged_min, nudged_max
          )
          x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
          result = ops.multiply(
@@ -318,6 +373,7 @@ def fake_quant_with_min_max_vars(
                  grad_min = ops.sum(grad_min, axis=axes)
              else:
                  grad_min = ops.sum(grad_min)
+             grad_min = ops.reshape(grad_min, ops.shape(min_val))

              # Gradient for max_val
              # When x is clipped to max, the gradient flows to max_val
@@ -327,6 +383,7 @@ def fake_quant_with_min_max_vars(
                  grad_max = ops.sum(grad_max, axis=axes)
              else:
                  grad_max = ops.sum(grad_max)
+             grad_max = ops.reshape(grad_max, ops.shape(max_val))

              return dx, grad_min, grad_max

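
The two `ops.reshape` additions above make the reduced gradients match the shapes of `min_val` and `max_val`, since a full reduction otherwise yields a 0-d scalar. A quick NumPy stand-in, not taken from the diff:

# Illustration of why the gradient reshape is needed: after a full reduction
# the gradient is a 0-d scalar, but it must match the (1,)-shaped variable.
import numpy as np

min_val = np.zeros((1,))
grad_min = np.sum(np.ones((3, 4)))               # shape () after reduction
grad_min = np.reshape(grad_min, np.shape(min_val))
assert grad_min.shape == min_val.shape == (1,)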
@@ -596,11 +653,14 @@ def unpack_int4(packed, orig_len, axis=0, dtype="int8"):
      )

      def to_signed(x):
-         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7]."""
+         """Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].
+
+         Uses a branchless XOR approach: (x ^ 8) - 8
+         This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1
+         """
          dtype_x = backend.standardize_dtype(x.dtype)
          eight = ops.cast(8, dtype_x)
-         sixteen = ops.cast(16, dtype_x)
-         return ops.where(x < eight, x, x - sixteen)
+         return ops.subtract(ops.bitwise_xor(x, eight), eight)

      rank = getattr(packed.shape, "rank", None) or len(packed.shape)
      if axis < 0:
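
The branchless mapping used above can be checked exhaustively over all 16 nibble values; this check is illustrative and not part of the diff:

# (x ^ 8) - 8 leaves 0..7 unchanged and maps 8..15 to -8..-1, matching the
# old where(x < 8, x, x - 16) branch.
for x in range(16):
    assert (x ^ 8) - 8 == (x if x < 8 else x - 16)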
@@ -691,7 +751,7 @@ class GPTQQuantizer(Quantizer):
          self.zero = None
          self.maxq = None

-     def find_params(self, input_tensor, weight=True):
+     def find_params(self, input_tensor):
          """Finds quantization parameters (scale and zero) for a given tensor."""
          self.scale, self.zero, self.maxq = compute_quantization_parameters(
              input_tensor,
@@ -699,7 +759,6 @@ class GPTQQuantizer(Quantizer):
              symmetric=self.symmetric,
              per_channel=self.per_channel,
              group_size=self.group_size,
-             weight=weight,
              compute_dtype=self.compute_dtype,
          )
          return self.scale, self.zero, self.maxq
@@ -736,98 +795,105 @@ def compute_quantization_parameters(
      symmetric=False,
      per_channel=False,
      group_size=-1,
-     weight=False,
      compute_dtype="float32",
  ):
      """
-     Computes the scale and zero-point for quantization.
+     Computes the scale and zero-point for quantizing weight tensors.

      This function calculates the scale and zero-point required for quantizing
-     a given tensor `x` based on the specified parameters. It supports grouped,
-     per-channel, per-tensor, symmetric, and asymmetric quantization - along
-     with any combinations of these.
+     a given weight tensor `x` based on the specified parameters. It supports
+     grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.
+
+     For grouped quantization (per_channel=True, group_size > 0), the output
+     shapes are [out_features, n_groups] where n_groups is the number of
+     groups along the in_features dimension.

      Args:
-         x: KerasTensor. The input tensor to quantize.
+         x: KerasTensor. The weight tensor to quantize with shape
+             [out_features, in_features].
          bits: int. The number of bits to quantize to (e.g., 4).
          symmetric: bool. Whether to use symmetric quantization.
          per_channel: bool. Whether to quantize per channel.
-         group_size: int. The group size for quantization.
-         weight: bool. Whether the input tensor is a weight tensor.
+         group_size: int. The group size for quantization. -1 means no grouping.
+         compute_dtype: str. The dtype for computation. Defaults to "float32".

      Returns:
          scale: KerasTensor. The scale tensor for quantization.
          zero: KerasTensor. The zero tensor for quantization.
          maxq: scalar. The maximum quantization value.
      """
+     # Input validation
      if x is None:
          raise ValueError(f"Input tensor {x} cannot be None.")
-
-     # For weights, we typically expect at least a 2D tensor.
-     if weight and len(x.shape) < 2:
+     if len(x.shape) < 2:
          raise ValueError(
              f"Input weight tensor {x} must have a rank of at "
              f"least 2, but got rank {len(x.shape)}."
          )
-
      if ops.size(x) == 0:
          raise ValueError("Input tensor 'x' cannot be empty.")

-     original_shape = x.shape
-
-     if per_channel:
-         if weight:
-             if group_size != -1:
-                 input_reshaped = ops.reshape(x, [-1, group_size])
-             else:
-                 input_reshaped = ops.reshape(x, [original_shape[0], -1])
-     else:  # per-tensor
-         input_reshaped = ops.reshape(x, [1, -1])
+     out_features, in_features = x.shape[0], x.shape[1]

-     # Find min/max values
-     min_values = ops.min(input_reshaped, axis=1)
-     max_values = ops.max(input_reshaped, axis=1)
+     # Determine number of groups for quantization
+     if per_channel and group_size > 0:
+         n_groups = (in_features + group_size - 1) // group_size
+     else:
+         n_groups = 1
+
+     # Compute min/max values based on quantization mode
+     if n_groups > 1:
+         # Grouped quantization: output shape [out_features, n_groups]
+         remainder = in_features % group_size
+         if remainder != 0:
+             pad_size = group_size - remainder
+             x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)
+
+         x_grouped = ops.reshape(x, [out_features, n_groups, group_size])
+         min_values = ops.min(x_grouped, axis=2)
+         max_values = ops.max(x_grouped, axis=2)
+     else:
+         # Per-channel or per-tensor: compute stats along rows
+         reduction_shape = [out_features, -1] if per_channel else [1, -1]
+         x_reshaped = ops.reshape(x, reduction_shape)
+         min_values = ops.min(x_reshaped, axis=1)
+         max_values = ops.max(x_reshaped, axis=1)

-     # Apply symmetric quantization logic if enabled
+     # Symmetric quantization: make range symmetric around zero
      if symmetric:
-         max_values = ops.maximum(ops.abs(min_values), max_values)
+         max_abs = ops.maximum(ops.abs(min_values), max_values)
          min_values = ops.where(
-             ops.less(min_values, 0), ops.negative(max_values), min_values
+             ops.less(min_values, 0), ops.negative(max_abs), min_values
          )
+         max_values = max_abs

-     # Ensure range is not zero to avoid division errors
+     # Ensure non-zero range to avoid division errors
      zero_range = ops.equal(min_values, max_values)
      min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
      max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)

+     # Compute scale and zero-point
      maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)
-
-     # Calculate scale and zero-point
      scale = ops.divide(ops.subtract(max_values, min_values), maxq)
+     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
+
      if symmetric:
          zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
      else:
          zero = ops.round(ops.divide(ops.negative(min_values), scale))

-     # Ensure scale is non-zero
-     scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)
-
-     if weight:
-         # Per-channel, non-grouped case: simple reshape is correct.
-         if per_channel and group_size == -1:
-             scale = ops.reshape(scale, [-1, 1])
-             zero = ops.reshape(zero, [-1, 1])
-         elif not per_channel:
-             num_rows = original_shape[0]
-             scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
-             zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))
-     if per_channel:
+     # Reshape output to [out_features, n_groups] or [out_features, 1]
+     if n_groups > 1:
+         pass  # Already [out_features, n_groups]
+     elif per_channel:
          scale = ops.reshape(scale, [-1, 1])
          zero = ops.reshape(zero, [-1, 1])
+     else:
+         # Per-tensor: tile single value to [out_features, 1]
+         scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))
+         zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))

-     zero = ops.cast(zero, "uint8")
-
-     return scale, zero, maxq
+     return scale, ops.cast(zero, "uint8"), maxq


  def quantize_with_zero_point(input_tensor, scale, zero, maxq):
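
Worked shape example for the grouped path above (illustrative numbers, not taken from the diff):

# Grouped-quantization shapes, following the logic in the new code.
out_features, in_features, group_size = 256, 200, 64

# Ceil-division used to count groups:
n_groups = (in_features + group_size - 1) // group_size
assert n_groups == 4

# in_features is padded up to a multiple of group_size before reshaping:
remainder = in_features % group_size
pad_size = group_size - remainder if remainder != 0 else 0
assert in_features + pad_size == n_groups * group_size == 256

# scale and zero therefore come out with shape [out_features, n_groups]:
assert (out_features, n_groups) == (256, 4)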
keras/src/quantizers/utils.py (new file)
@@ -0,0 +1,23 @@
+ import re
+
+
+ def should_quantize_layer(layer, filters):
+     """Determines if a layer should be quantized based on filters.
+
+     Args:
+         layer: The layer to check.
+         filters: A regex string, a list of regex strings, or a callable.
+             If None, returns True.
+
+     Returns:
+         True if the layer should be quantized, False otherwise.
+     """
+     if filters is None:
+         return True
+     if isinstance(filters, str):
+         return bool(re.search(filters, layer.name))
+     if isinstance(filters, (list, tuple)):
+         return any(re.search(pat, layer.name) for pat in filters)
+     if callable(filters):
+         return filters(layer)
+     return True
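
Usage sketch for the filter helper above (assumed usage, with made-up layer names):

from keras import layers
from keras.src.quantizers.utils import should_quantize_layer

dense = layers.Dense(8, name="decoder_ffn_dense")
assert should_quantize_layer(dense, None)                         # no filter
assert should_quantize_layer(dense, r"ffn")                       # regex match
assert not should_quantize_layer(dense, [r"attention", r"lm_head"])
assert should_quantize_layer(dense, lambda layer: "dense" in layer.name)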
keras/src/random/seed_generator.py
@@ -8,6 +8,8 @@ from keras.src.backend.common import global_state
  from keras.src.utils import jax_utils
  from keras.src.utils.naming import auto_name

+ GLOBAL_SEED_GENERATOR = "global_seed_generator"
+

  @keras_export("keras.random.SeedGenerator")
  class SeedGenerator:
@@ -27,7 +29,7 @@ class SeedGenerator:
      a local `StateGenerator` with either a deterministic or random initial
      state.

-     Remark concerning the JAX backen: Note that the use of a local
+     Remark concerning the JAX backend: Note that the use of a local
      `StateGenerator` as seed argument is required for JIT compilation of
      RNG with the JAX backend, because the use of global state is not
      supported.
@@ -109,7 +111,7 @@ class SeedGenerator:
          return new_seed_value

      def get_config(self):
-         return {"seed": self._initial_seed}
+         return {"seed": self._initial_seed, "name": self.name}

      @classmethod
      def from_config(cls, config):
@@ -133,10 +135,10 @@ def global_seed_generator():
          "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
          "```"
      )
-     gen = global_state.get_global_attribute("global_seed_generator")
+     gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
      if gen is None:
          gen = SeedGenerator()
-         global_state.set_global_attribute("global_seed_generator", gen)
+         global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
      return gen

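
Sketch of the serialization change above (assumed usage, not part of the diff): `name` is now included in the SeedGenerator config, so a generator rebuilt from its config keeps its name.

from keras.random import SeedGenerator

gen = SeedGenerator(seed=42, name="my_seed_gen")
config = gen.get_config()
assert config["seed"] == 42 and config["name"] == "my_seed_gen"
assert SeedGenerator.from_config(config).name == "my_seed_gen"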