keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +16 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +6 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +16 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +12 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/dtypes.py +6 -12
  28. keras/src/backend/common/name_scope.py +2 -1
  29. keras/src/backend/common/variables.py +38 -20
  30. keras/src/backend/jax/core.py +126 -78
  31. keras/src/backend/jax/distribution_lib.py +16 -2
  32. keras/src/backend/jax/layer.py +3 -1
  33. keras/src/backend/jax/linalg.py +4 -0
  34. keras/src/backend/jax/nn.py +511 -29
  35. keras/src/backend/jax/numpy.py +109 -23
  36. keras/src/backend/jax/optimizer.py +3 -2
  37. keras/src/backend/jax/trainer.py +18 -3
  38. keras/src/backend/numpy/linalg.py +4 -0
  39. keras/src/backend/numpy/nn.py +313 -2
  40. keras/src/backend/numpy/numpy.py +97 -8
  41. keras/src/backend/openvino/__init__.py +1 -0
  42. keras/src/backend/openvino/core.py +6 -23
  43. keras/src/backend/openvino/linalg.py +4 -0
  44. keras/src/backend/openvino/nn.py +271 -20
  45. keras/src/backend/openvino/numpy.py +1369 -195
  46. keras/src/backend/openvino/random.py +7 -14
  47. keras/src/backend/tensorflow/layer.py +43 -9
  48. keras/src/backend/tensorflow/linalg.py +24 -0
  49. keras/src/backend/tensorflow/nn.py +545 -1
  50. keras/src/backend/tensorflow/numpy.py +351 -56
  51. keras/src/backend/tensorflow/trainer.py +6 -2
  52. keras/src/backend/torch/core.py +3 -1
  53. keras/src/backend/torch/linalg.py +4 -0
  54. keras/src/backend/torch/nn.py +125 -0
  55. keras/src/backend/torch/numpy.py +109 -9
  56. keras/src/backend/torch/trainer.py +8 -2
  57. keras/src/callbacks/__init__.py +1 -0
  58. keras/src/callbacks/callback_list.py +45 -11
  59. keras/src/callbacks/model_checkpoint.py +5 -0
  60. keras/src/callbacks/orbax_checkpoint.py +332 -0
  61. keras/src/callbacks/terminate_on_nan.py +54 -5
  62. keras/src/datasets/cifar10.py +5 -0
  63. keras/src/distillation/__init__.py +1 -0
  64. keras/src/distillation/distillation_loss.py +390 -0
  65. keras/src/distillation/distiller.py +598 -0
  66. keras/src/distribution/distribution_lib.py +14 -0
  67. keras/src/dtype_policies/__init__.py +4 -0
  68. keras/src/dtype_policies/dtype_policy.py +180 -1
  69. keras/src/export/__init__.py +2 -0
  70. keras/src/export/export_utils.py +39 -2
  71. keras/src/export/litert.py +248 -0
  72. keras/src/export/onnx.py +6 -0
  73. keras/src/export/openvino.py +1 -1
  74. keras/src/export/tf2onnx_lib.py +3 -0
  75. keras/src/layers/__init__.py +13 -0
  76. keras/src/layers/activations/softmax.py +9 -4
  77. keras/src/layers/attention/attention.py +1 -1
  78. keras/src/layers/attention/multi_head_attention.py +4 -1
  79. keras/src/layers/core/dense.py +406 -102
  80. keras/src/layers/core/einsum_dense.py +521 -116
  81. keras/src/layers/core/embedding.py +257 -99
  82. keras/src/layers/core/input_layer.py +1 -0
  83. keras/src/layers/core/reversible_embedding.py +399 -0
  84. keras/src/layers/input_spec.py +17 -17
  85. keras/src/layers/layer.py +50 -15
  86. keras/src/layers/merging/concatenate.py +6 -5
  87. keras/src/layers/merging/dot.py +4 -1
  88. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  89. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  90. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  91. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  92. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  93. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  94. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  95. keras/src/layers/preprocessing/discretization.py +6 -5
  96. keras/src/layers/preprocessing/feature_space.py +8 -4
  97. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  98. keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
  99. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  100. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  101. keras/src/layers/preprocessing/index_lookup.py +19 -1
  102. keras/src/layers/preprocessing/normalization.py +16 -1
  103. keras/src/layers/preprocessing/string_lookup.py +26 -28
  104. keras/src/layers/regularization/dropout.py +43 -1
  105. keras/src/layers/rnn/gru.py +1 -1
  106. keras/src/layers/rnn/lstm.py +2 -2
  107. keras/src/layers/rnn/rnn.py +19 -0
  108. keras/src/layers/rnn/simple_rnn.py +1 -1
  109. keras/src/legacy/preprocessing/image.py +4 -1
  110. keras/src/legacy/preprocessing/sequence.py +20 -12
  111. keras/src/losses/loss.py +1 -1
  112. keras/src/losses/losses.py +24 -0
  113. keras/src/metrics/confusion_metrics.py +7 -6
  114. keras/src/models/cloning.py +4 -0
  115. keras/src/models/functional.py +11 -3
  116. keras/src/models/model.py +195 -44
  117. keras/src/ops/image.py +257 -20
  118. keras/src/ops/linalg.py +93 -0
  119. keras/src/ops/nn.py +268 -2
  120. keras/src/ops/numpy.py +701 -44
  121. keras/src/ops/operation.py +90 -29
  122. keras/src/ops/operation_utils.py +2 -0
  123. keras/src/optimizers/adafactor.py +29 -10
  124. keras/src/optimizers/base_optimizer.py +22 -3
  125. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  126. keras/src/optimizers/muon.py +65 -31
  127. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  128. keras/src/quantizers/__init__.py +14 -1
  129. keras/src/quantizers/awq.py +361 -0
  130. keras/src/quantizers/awq_config.py +140 -0
  131. keras/src/quantizers/awq_core.py +217 -0
  132. keras/src/quantizers/gptq.py +346 -207
  133. keras/src/quantizers/gptq_config.py +63 -13
  134. keras/src/quantizers/gptq_core.py +328 -215
  135. keras/src/quantizers/quantization_config.py +246 -0
  136. keras/src/quantizers/quantizers.py +407 -38
  137. keras/src/quantizers/utils.py +23 -0
  138. keras/src/random/seed_generator.py +6 -4
  139. keras/src/saving/file_editor.py +81 -6
  140. keras/src/saving/orbax_util.py +26 -0
  141. keras/src/saving/saving_api.py +37 -14
  142. keras/src/saving/saving_lib.py +1 -1
  143. keras/src/testing/__init__.py +1 -0
  144. keras/src/testing/test_case.py +45 -5
  145. keras/src/trainers/compile_utils.py +38 -17
  146. keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
  147. keras/src/tree/torchtree_impl.py +215 -0
  148. keras/src/tree/tree_api.py +6 -1
  149. keras/src/utils/backend_utils.py +31 -4
  150. keras/src/utils/dataset_utils.py +234 -35
  151. keras/src/utils/file_utils.py +49 -11
  152. keras/src/utils/image_utils.py +14 -2
  153. keras/src/utils/jax_layer.py +244 -55
  154. keras/src/utils/module_utils.py +29 -0
  155. keras/src/utils/progbar.py +10 -12
  156. keras/src/utils/python_utils.py +5 -0
  157. keras/src/utils/rng_utils.py +9 -1
  158. keras/src/utils/tracking.py +70 -5
  159. keras/src/version.py +1 -1
  160. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  161. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
  162. keras/src/quantizers/gptq_quant.py +0 -133
  163. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  164. {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,282 @@
+ import types
+
  from keras.src import ops
+ from keras.src import quantizers
  from keras.src.layers import Dense
  from keras.src.layers import EinsumDense
- from keras.src.quantizers.gptq_quant import dequantize
+ from keras.src.ops import linalg
+ from keras.src.quantizers.gptq_config import GPTQConfig
+ from keras.src.quantizers.quantizers import GPTQQuantizer
+ from keras.src.quantizers.quantizers import compute_quantization_parameters
+ from keras.src.quantizers.quantizers import dequantize_with_zero_point
+ from keras.src.quantizers.quantizers import quantize_with_zero_point
+
+
+ def _stable_permutation(metric):
+     """Return a stable permutation that sorts `metric` in descending order.
+     Uses an index-based jitter to break ties deterministically."""
+     n = ops.shape(metric)[0]
+     idx = ops.arange(0, n, dtype="int32")
+     # tiny jitter = (idx / n) * 1e-12 so it never flips a real strict ordering
+     jitter = ops.divide(ops.cast(idx, "float32"), ops.cast(n, "float32"))
+     metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12))
+     # argsort by negative to get descending
+     return ops.argsort(ops.negative(metric_jittered))
+
+
+ def gptq_quantize_matrix(
+     weights_transpose,
+     inv_hessian,
+     *,
+     blocksize=128,
+     group_size=-1,
+     activation_order=False,
+     order_metric=None,
+     compute_scale_zero=compute_quantization_parameters,
+ ):
+     """
+     Implements the GPTQ error correction updates.
+
+     For a single column update (column j):
+         e = (w_j - q_j) / invH[j, j]
+         W[:, j+1:] -= e * invH[j, j+1:]
+     where:
+     - w_j is the original column,
+     - q_j is the quantized column,
+     - invH is the inverse Hessian,
+     - e is the propagated error term.
+
+     Across entire blocks:
+         W[:, future] -= E_block * invH[block, future]
+     where:
+     - E_block is the quantization error accumulated for the current block,
+     - invH[block, future] denotes the cross-block slice of the inverse Hessian,
+     - W[:, future] are the columns yet to be quantized.
+
+     Args:
+         weights_transpose: Transposed weight matrix [out_features, in_features]
+             to quantize.
+         inv_hessian: Inverse Hessian matrix [in_features, in_features] for
+             error propagation.
+         blocksize: Size of the blocks to process (default: 128).
+         group_size: Size of the groups for parameter reuse
+             (default: -1, no grouping).
+         activation_order: Whether to apply activation-order permutation
+             (default: False).
+         order_metric: Metric for ordering features
+             (default: None, uses 1 / diag(invH)).
+         compute_scale_zero: Function to compute scale and zero for
+             quantization.
+
+     Returns:
+         quantized_weights: Quantized weight matrix [out_features, in_features].
+         scale: float32. Scale parameters for quantization
+             [out_features, num_groups].
+         zero: Zero-point parameters for quantization [out_features, num_groups].
+         g_idx: int32. Group indices for each feature [in_features].
+     """
+     in_features = ops.shape(weights_transpose)[1]
+
+     if activation_order:
+         # Use 1 / diag(inverse_hessian) as importance proxy by default.
+         if order_metric is None:
+             order_metric = ops.reciprocal(
+                 ops.add(ops.diagonal(inv_hessian), 1e-12)
+             )
+         else:
+             # sanitize provided metric
+             order_metric = ops.cast(order_metric, "float32")
+             order_metric = ops.where(
+                 ops.isfinite(order_metric),
+                 order_metric,
+                 ops.zeros_like(order_metric),
+             )
+         # Sort in descending order by importance
+         perm = _stable_permutation(order_metric)
+         inv_perm = ops.argsort(perm)
+
+         weights_transpose = ops.take(weights_transpose, perm, axis=1)
+         inv_hessian = ops.take(
+             ops.take(inv_hessian, perm, axis=0), perm, axis=1
+         )
+     else:
+         perm = inv_perm = None
+
+     # weights_buffer: [out_features, in_features]
+     weights_buffer = weights_transpose
+     # Buffer for the final quantized matrix: [out_features, in_features]
+     quantized_weights_buffer = ops.zeros_like(weights_transpose, dtype="int32")
+
+     scale_chunks = []
+     zero_chunks = []
+
+     # Compute effective group size
+     effective_group = in_features if group_size == -1 else group_size
+
+     # Process features in blocks
+     for block_start in range(0, in_features, blocksize):
+         block_end = min(block_start + blocksize, in_features)
+         block_size = block_end - block_start
+
+         # Block views
+         # block_weights: [out_features, block_size]
+         block_weights = weights_buffer[:, block_start:block_end]
+         # block_error: [out_features, block_size]
+         block_error = ops.zeros_like(block_weights)
+         # block_inv_hessian: [block_size, block_size]
+         block_inv_hessian = inv_hessian[
+             block_start:block_end, block_start:block_end
+         ]
+
+         # Per-group cached params for reuse within the group
+         cached_scale = None
+         cached_zero = None
+         cached_maxq = None
+         cached_group_start = -1
+
+         for block_idx in range(block_size):
+             # Current global column index, represents the original column
+             # in the weight matrix
+             global_idx = block_start + block_idx
+             # weight_column: [out_features,]
+             weight_column = block_weights[:, block_idx]
+             # Group-wise parameter reuse (compute once per group)
+             if not effective_group == in_features:  # group_size != -1
+                 # Determine the group start index for the current column
+                 group_start = (global_idx // effective_group) * effective_group
+                 if group_start != cached_group_start:
+                     # New group encountered, compute & cache params
+                     # for this group
+                     group_end = min(group_start + effective_group, in_features)
+                     group_slice = weights_buffer[:, group_start:group_end]
+                     cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                         group_slice
+                     )
+                     # Store params once per group (in the order encountered).
+                     scale_chunks.append(cached_scale)
+                     zero_chunks.append(cached_zero)
+                     cached_group_start = group_start
+                 scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+             else:
+                 # Single global group covering all columns.
+                 if cached_scale is None:
+                     cached_scale, cached_zero, cached_maxq = compute_scale_zero(
+                         weights_buffer
+                     )
+                     scale_chunks.append(cached_scale)
+                     zero_chunks.append(cached_zero)
+                     cached_group_start = 0
+                 scale, zero, maxq = cached_scale, cached_zero, cached_maxq
+
+             # Quantize column and store it.
+             # quantized_column: [out_features, 1]
+             quantized_column = quantize_with_zero_point(
+                 ops.expand_dims(weight_column, 1), scale, zero, maxq
+             )
+
+             # Store quantized column in the buffer.
+             quantized_weights_buffer = ops.slice_update(
+                 quantized_weights_buffer,
+                 (0, global_idx),
+                 ops.cast(quantized_column, "int32"),
+             )
+             # Dequantize column to compute error.
+             # dequantized_col: [out_features,]
+             dequantized_col = dequantize_with_zero_point(
+                 quantized_column, scale, zero
+             )[:, 0]
+             # Error feedback for remaining columns within the block
+             # block_inv_hessian_diag: scalar
+             current_block_influence = block_inv_hessian[block_idx, block_idx]
+             # We divide by current_block_influence to get the
+             # correct scaling of the error term.
+             err = ops.divide(
+                 ops.subtract(weight_column, dequantized_col),
+                 current_block_influence,
+             )
+             # Record error for propagation to future blocks
+             block_error = ops.slice_update(
+                 block_error, (0, block_idx), ops.expand_dims(err, 1)
+             )
+
+             # Update remaining columns in the current block
+             # (those before the current column have already been quantized)
+             # Propagate error to remaining columns in the block.
+             if block_idx < block_size - 1:
+                 # update: [out_features, block_size - block_idx - 1]
+                 update = ops.matmul(
+                     ops.expand_dims(err, 1),
+                     ops.expand_dims(
+                         block_inv_hessian[block_idx, block_idx + 1 :], 0
+                     ),
+                 )
+                 # tail is a view of the remaining columns in the block
+                 # to be updated
+                 # tail: [out_features, block_size - block_idx - 1]
+                 tail = block_weights[:, block_idx + 1 :]
+                 block_weights = ops.slice_update(
+                     block_weights,
+                     (0, block_idx + 1),
+                     ops.subtract(tail, update),
+                 )
+
+         # Propagate block errors to future features (beyond the block)
+         if block_end < in_features:
+             # Total update for all future columns, based on the
+             # accumulated error in this block. This is calculated
+             # as the matrix product of the block_error and the
+             # relevant slice of the inverse Hessian.
+             # total_update: [out_features, in_features - block_end]
+             total_update = ops.matmul(
+                 block_error, inv_hessian[block_start:block_end, block_end:]
+             )
+             # Update the remaining weights in the buffer. This is done
+             # by subtracting the total_update from the remaining columns.
+             weights_buffer = ops.concatenate(
+                 [
+                     weights_buffer[:, :block_end],
+                     ops.subtract(weights_buffer[:, block_end:], total_update),
+                 ],
+                 axis=1,
+             )
+
+     # Build group indices for each (possibly permuted) column
+     # base_group = effective_group (int)
+     base_group = effective_group
+
+     # g_idx in permuted domain
+     g_idx = ops.arange(0, in_features, dtype="int32")
+     g_idx = ops.divide(g_idx, base_group)
+     g_idx = ops.cast(g_idx, "float32")
+
+     # Map group indices and quantized weights back to original column order
+     if activation_order:
+         g_idx = ops.take(g_idx, inv_perm, axis=0)
+         quantized_weights_buffer = ops.take(
+             quantized_weights_buffer, inv_perm, axis=1
+         )
+
+     # Concatenate recorded group params
+     if len(scale_chunks) == 0:
+         # Edge case: no groups recorded (empty input); fall back to whole matrix
+         s, z, _ = compute_scale_zero(weights_transpose)
+         scale = s
+         zero = z
+     else:
+         scale = ops.concatenate(scale_chunks, axis=1)
+         zero = ops.concatenate(zero_chunks, axis=1)
+
+     return quantized_weights_buffer, scale, zero, g_idx


  class GPTQ:
-     def __init__(self, layer):
+     def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)):
          self.original_layer = layer
          self.num_samples = 0
-         self.quantizer = None
+         self.config = config
+         self.quantizer = GPTQQuantizer(
+             config, compute_dtype=layer.variable_dtype
+         )

          # Explicitly handle each supported layer type
          if isinstance(layer, Dense) or (
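The error-correction rule documented in `gptq_quantize_matrix` above (quantize one column, divide the residual by the diagonal entry of the inverse Hessian, subtract the outer product from the not-yet-quantized columns) can be reproduced in a few lines of plain NumPy. The sketch below is illustrative only: a crude round-to-grid stands in for `quantize_with_zero_point`/`dequantize_with_zero_point`, and none of the names below are part of Keras.

import numpy as np

rng = np.random.default_rng(0)
out_features, in_features = 4, 8
W = rng.normal(size=(out_features, in_features))   # plays the role of weights_transpose
H = W.T @ W + 1e-2 * np.eye(in_features)           # stand-in Hessian from calibration data
Hinv = np.linalg.inv(H)

def fake_quantize(col, step=0.05):
    # Crude stand-in for the quantize -> dequantize round trip.
    return np.round(col / step) * step

Wq = np.zeros_like(W)
for j in range(in_features):
    q = fake_quantize(W[:, j])
    Wq[:, j] = q
    err = (W[:, j] - q) / Hinv[j, j]                # e = (w_j - q_j) / invH[j, j]
    W[:, j + 1:] -= np.outer(err, Hinv[j, j + 1:])  # W[:, j+1:] -= e * invH[j, j+1:]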
@@ -16,21 +284,18 @@ class GPTQ:
          ):
              # For a standard Dense layer, the dimensions are straightforward.
              self.kernel_shape = layer.kernel.shape
-             self.rows = self.kernel_shape[0]  # Input features
-             self.columns = self.kernel_shape[1]  # Output features
-             self.layer = layer  # The layer itself can be used directly.
+             # rows: [input_features]
+             self.rows = self.kernel_shape[0]
+             # columns: [output_features]
+             self.columns = self.kernel_shape[1]
+             self.layer = layer

          # Handle 3D EinsumDense layers (typically from attention blocks).
          elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:
              # For EinsumDense, we determine the effective 2D dimensions.
              self.kernel_shape = layer.kernel.shape
              shape = list(self.kernel_shape)
-             try:
-                 d_model_dim_index = shape.index(max(shape))
-             except ValueError:
-                 raise TypeError(
-                     f"Could not determine hidden dimension from shape {shape}"
-                 )
+             d_model_dim_index = shape.index(max(shape))

              if d_model_dim_index == 0:  # QKV projection case
                  in_features, heads, head_dim = shape
@@ -47,17 +312,9 @@ class GPTQ:

              # Create a temporary object that holds a reshaped
              # 2D version of the kernel.
-             self.layer = type(
-                 "temp",
-                 (object,),
-                 {
-                     "kernel": ops.reshape(
-                         layer.kernel, (self.rows, self.columns)
-                     ),
-                     "bias": layer.bias,
-                 },
-             )()
-
+             self.layer = types.SimpleNamespace(
+                 kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),
+             )
          else:
              # Raise an error if the layer is not supported.
              raise TypeError(f"Unsupported layer type for GPTQ: {type(layer)}")
@@ -86,53 +343,55 @@ class GPTQ:
          pre-initialized Hessian matrix `self.hessian`.
          """
          if input_batch is None:
-             raise ValueError("Input tensor 'input_batch' cannot be None.")
+             raise ValueError("Input tensor cannot be None.")

          if len(input_batch.shape) < 2:
              raise ValueError(
-                 f"Input tensor 'input_batch' must have a rank of at least 2 "
-                 f"(e.g., [batch, features]), but got rank "
-                 f"{len(input_batch.shape)}."
+                 "Input tensor must have rank >= 2 "
+                 f"(got rank {len(input_batch.shape)})."
              )
          if ops.size(input_batch) == 0:
-             raise ValueError("Input tensor 'input_batch' cannot be empty.")
-
+             raise ValueError("Input tensor cannot be empty.")
          if len(input_batch.shape) > 2:
+             # [batch, features]
              input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))
-         input_batch = ops.cast(input_batch, "float32")
+         x = ops.cast(input_batch, "float32")
+
+         num_new_samples = ops.shape(x)[0]
+         num_prev_samples = self.num_samples
+         total_samples = ops.add(num_prev_samples, num_new_samples)

-         if self.hessian.shape[0] != input_batch.shape[-1]:
+         if ops.shape(self.hessian)[0] != ops.shape(x)[-1]:
              raise ValueError(
-                 f"Hessian dimensions ({self.hessian.shape[0]}) do not"
-                 "match input features ({input_batch.shape[-1]})."
+                 f"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not "
+                 f"match input features ({ops.shape(x)[-1]})."
              )

-         current_hessian = ops.multiply(
-             2, ops.matmul(ops.transpose(input_batch), input_batch)
+         # gram_matrix: [features, features]
+         gram_matrix = ops.matmul(ops.transpose(x), x)
+         # Ensures numerical stability and symmetry in case of large floating
+         # point activations.
+         gram_matrix = ops.divide(
+             ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0
          )

-         if self.num_samples == 0:
-             self.hessian = current_hessian
-         else:
-             total_samples = ops.add(self.num_samples, input_batch.shape[0])
-             old_hessian_weight = ops.divide(self.num_samples, total_samples)
-             current_hessian_weight = ops.divide(
-                 input_batch.shape[0], total_samples
+         # Decay previous mean and add current per-sample contribution
+         # (factor 2/N)
+         if self.num_samples > 0:
+             self.hessian = ops.multiply(
+                 self.hessian, ops.divide(num_prev_samples, total_samples)
              )

-             # Update the accumulated Hessian
-             old_term = ops.multiply(self.hessian, old_hessian_weight)
-             current_term = ops.multiply(current_hessian, current_hessian_weight)
-             self.hessian = ops.add(old_term, current_term)
+         self.hessian = ops.add(
+             self.hessian,
+             ops.multiply(ops.divide(2.0, total_samples), gram_matrix),
+         )

-         self.num_samples = ops.add(self.num_samples, input_batch.shape[0])
+         self.num_samples = self.num_samples + ops.shape(x)[0] or 0

-     def quantize_and_correct_block(
+     def quantize_and_correct_layer(
          self,
          blocksize=128,
-         hessian_damping=0.01,
-         group_size=-1,
-         activation_order=False,
      ):
          """
          Performs GPTQ quantization and correction on the layer's weights.
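The rewritten accumulation in the hunk above keeps `self.hessian` equal to 2/N * X^T X over all calibration rows seen so far, updated one batch at a time by decaying the previous estimate. A NumPy sketch (hypothetical variable names) checking the streaming form against the closed form:

import numpy as np

rng = np.random.default_rng(0)
features = 16
batches = [rng.normal(size=(32, features)) for _ in range(3)]

# Streaming update: decay the previous estimate, then add 2/N times the new Gram matrix.
H = np.zeros((features, features))
n_seen = 0
for x in batches:
    gram = x.T @ x
    gram = (gram + gram.T) / 2.0            # enforce symmetry
    n_total = n_seen + x.shape[0]
    if n_seen > 0:
        H *= n_seen / n_total
    H += (2.0 / n_total) * gram
    n_seen = n_total

# Closed form over all samples: H = (2 / N) * X^T X
X = np.concatenate(batches, axis=0)
assert np.allclose(H, 2.0 * X.T @ X / X.shape[0])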
@@ -143,66 +402,42 @@ class GPTQ:
          quantization error by updating the remaining weights.

          The algorithm follows these main steps:
-         1. **Initialization**: It optionally reorders the weight columns based
+         1. Initialization: It optionally reorders the weight columns based
              on activation magnitudes (`activation_order=True`) to protect more
              salient
              weights.
-         2. **Hessian Modification**: The Hessian matrix, pre-computed from
+         2. Hessian Modification: The Hessian matrix, pre-computed from
              calibration data, is dampened to ensure its invertibility and
              stability.
-         3. **Iterative Quantization**: The function iterates through the
+         3. Iterative Quantization: The function iterates through the
              weight columns in blocks (`blocksize`). In each iteration, it:
              a. Quantizes one column.
              b. Calculates the quantization error.
              c. Updates the remaining weights in the *current* block by
                 distributing the error, using the inverse Hessian.
-         4. **Block-wise Correction**: After a block is quantized, the total
+         4. Block-wise Correction: After a block is quantized, the total
              error from that block is propagated to the *next* block of weights
              to be processed.
-         5. **Finalization**: The quantized weights are reordered back if
+         5. Finalization: The quantized weights are reordered back if
              `activation_order` was used, and the layer's weights are updated.
-
          This implementation is based on the official GPTQ paper and repository.
          For more details, see:
          - Paper: https://arxiv.org/abs/2210.17323
          - Original Code: https://github.com/IST-DASLab/gptq

+
          Args:
              blocksize: (int, optional) The size of the weight block to process
                  at a time. Defaults to 128.
-             hessian_damping: (float, optional) The percentage of dampening to
-                 add the
-                 Hessian's diagonal. A value of 0.01 is recommended.
-                 Defaults to 0.01.
-             group_size: (int, optional) The number of weights that share the
-                 same quantization parameters (scale and zero-point).
-                 A value of -1 indicates per-channel quantization.
-             activation_order: (bool, optional) If True, reorders weight columns
-                 based
-                 on their activation's second-order information.
          """
-
-         weights_matrix = ops.transpose(ops.cast(self.layer.kernel, "float32"))
-         hessian_matrix = ops.cast(self.hessian, "float32")
-
-         if activation_order:
-             permutation = ops.argsort(
-                 ops.negative(ops.diagonal(hessian_matrix))
-             )
-             weights_matrix = ops.take(weights_matrix, permutation, axis=1)
-             hessian_matrix = ops.take(
-                 ops.take(hessian_matrix, permutation, axis=0),
-                 permutation,
-                 axis=1,
-             )
-             inverse_permutation = ops.argsort(permutation)
+         weights_matrix = ops.transpose(self.layer.kernel)

          # Dampen the Hessian for Stability
-         hessian_diagonal = ops.diagonal(hessian_matrix)
+         hessian_diagonal = ops.diagonal(self.hessian)
          dead_diagonal = ops.equal(hessian_diagonal, 0.0)
          hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)
          hessian_matrix = ops.add(
-             hessian_matrix,
+             self.hessian,
              ops.diag(
                  ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))
              ),
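Step 1 of the docstring above (activation ordering) relies on the fact that a permutation produced by argsort can be undone by argsort-ing the permutation itself, which is how `inv_perm` is built in `gptq_quantize_matrix`. A small NumPy sketch of that round trip (names are illustrative):

import numpy as np

metric = np.array([0.3, 2.0, 1.1, 0.7])       # importance per column, e.g. diag(H)
perm = np.argsort(-metric)                     # descending importance: [1, 2, 3, 0]
inv_perm = np.argsort(perm)                    # undoes the permutation

W = np.arange(8).reshape(2, 4)
W_perm = W[:, perm]                            # quantize in this order...
assert np.array_equal(W_perm[:, inv_perm], W)  # ...then map results back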
@@ -210,7 +445,7 @@ class GPTQ:

          # Add dampening factor to the Hessian diagonal
          damping_factor = ops.multiply(
-             hessian_damping, ops.mean(hessian_diagonal)
+             self.config.hessian_damping, ops.mean(hessian_diagonal)
          )
          hessian_diagonal = ops.add(hessian_diagonal, damping_factor)
          hessian_matrix = ops.add(
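The dampening step above nudges every diagonal entry of the Hessian by `hessian_damping * mean(diag(H))` (after replacing dead zero entries with 1.0) so the matrix stays invertible. A simplified NumPy sketch of the same idea; `dampen_hessian` is a hypothetical helper, not a Keras API:

import numpy as np

def dampen_hessian(H, hessian_damping=0.01):
    diag = np.diag(H).copy()
    dead = diag == 0.0
    # Revive dead (all-zero) diagonal entries.
    H = H + np.diag(np.where(dead, 1.0, 0.0))
    diag = np.where(dead, 1.0, diag)
    # Add a fraction of the mean diagonal to every diagonal entry.
    return H + np.eye(H.shape[0]) * hessian_damping * np.mean(diag)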
@@ -221,130 +456,34 @@ class GPTQ:
          )

          # Compute the inverse Hessian, which is used for error correction
-         inverse_hessian = ops.linalg.inv(hessian_matrix)
-         quantized_weights = ops.zeros_like(weights_matrix)
-
-         for block_start in range(0, self.rows, blocksize):
-             block_end = min(ops.add(block_start, blocksize), self.rows)
-             block_size = ops.subtract(block_end, block_start)
-             # Extract the current block of weights and its corresponding
-             # Hessian
-             block_weights = weights_matrix[:, block_start:block_end]
-             block_quantized = ops.zeros_like(block_weights)
-             block_errors = ops.zeros_like(block_weights)
-             block_inverse_hessian = inverse_hessian[
-                 block_start:block_end, block_start:block_end
-             ]
-
-             # Process one column at a time within the block
-             for col_idx in range(block_size):
-                 weight_column = block_weights[:, col_idx]
-                 diagonal_element = block_inverse_hessian[col_idx, col_idx]
-
-                 if group_size != -1:
-                     if ops.mod(ops.add(block_start, col_idx), group_size) == 0:
-                         self.quantizer.find_params(
-                             weights_matrix[
-                                 :,
-                                 (ops.add(block_start, col_idx)) : (
-                                     ops.add(
-                                         ops.add(block_start, col_idx),
-                                         group_size,
-                                     )
-                                 ),
-                             ],
-                             weight=True,
-                         )
-                 else:
-                     self.quantizer.find_params(
-                         ops.expand_dims(weight_column, 1), weight=True
-                     )
-
-                 # Quantize the current weight column
-                 quantized_column = dequantize(
-                     ops.expand_dims(weight_column, 1),
-                     self.quantizer.scale,
-                     self.quantizer.zero,
-                     self.quantizer.maxq,
-                 )[:, 0]
-
-                 block_quantized = ops.slice_update(
-                     block_quantized,
-                     (0, col_idx),
-                     ops.expand_dims(quantized_column, axis=1),
-                 )
-                 quantization_error = ops.divide(
-                     ops.subtract(weight_column, quantized_column),
-                     diagonal_element,
-                 )
-                 block_errors = ops.slice_update(
-                     block_errors,
-                     (0, col_idx),
-                     ops.expand_dims(quantization_error, axis=1),
-                 )
-
-                 if ops.less(col_idx, ops.subtract(block_size, 1)):
-                     error_update = ops.matmul(
-                         ops.expand_dims(quantization_error, 1),
-                         ops.expand_dims(
-                             block_inverse_hessian[
-                                 col_idx, ops.add(col_idx, 1) :
-                             ],
-                             0,
-                         ),
-                     )
-
-                     # Efficiently update the remaining part of the
-                     # block_weights tensor.
-                     slice_to_update = block_weights[:, ops.add(col_idx, 1) :]
-                     updated_slice = ops.subtract(slice_to_update, error_update)
-                     block_weights = ops.slice_update(
-                         block_weights, (0, ops.add(col_idx, 1)), updated_slice
-                     )
-
-             # Update the full quantized matrix with the processed block
-             quantized_weights = ops.concatenate(
-                 [
-                     quantized_weights[:, :block_start],
-                     block_quantized,
-                     quantized_weights[:, block_end:],
-                 ],
-                 axis=1,
-             )
-
-             if block_end < self.rows:
-                 total_error_update = ops.matmul(
-                     block_errors,
-                     inverse_hessian[block_start:block_end, block_end:],
-                 )
-                 weights_matrix = ops.concatenate(
-                     [
-                         weights_matrix[:, :block_end],
-                         ops.subtract(
-                             weights_matrix[:, block_end:], total_error_update
-                         ),
-                     ],
-                     axis=1,
-                 )
-
-         if activation_order:
-             quantized_weights = ops.take(
-                 quantized_weights, inverse_permutation, axis=1
-             )
-
-         quantized_weights = ops.transpose(quantized_weights)
+         inverse_hessian = linalg.inv(hessian_matrix)
+
+         quantized, scale, zero, g_idx = gptq_quantize_matrix(
+             weights_matrix,
+             inv_hessian=inverse_hessian,
+             blocksize=blocksize,
+             group_size=self.config.group_size,
+             activation_order=self.config.activation_order,
+             order_metric=ops.diagonal(hessian_matrix),
+             compute_scale_zero=self.quantizer.find_params,
+         )
+         quantized = ops.cast(
+             quantized, self.original_layer.quantized_kernel.dtype
+         )

-         if isinstance(self.original_layer, EinsumDense):
-             quantized_weights = ops.reshape(
-                 quantized_weights, self.kernel_shape
+         if self.config.weight_bits == 4:
+             # For 4-bit weights, we need to pack them into bytes
+             quantized, _, _ = quantizers.pack_int4(
+                 quantized, axis=0, dtype="uint8"
              )

-         # Set the new quantized weights in the original layer
-         new_weights = [ops.convert_to_numpy(quantized_weights)]
-         if self.original_layer.bias is not None:
-             new_weights.append(ops.convert_to_numpy(self.original_layer.bias))
-
-         self.original_layer.set_weights(new_weights)
+         del self.original_layer._kernel
+         self.original_layer.quantized_kernel.assign(quantized)
+         self.original_layer.kernel_scale.assign(scale)
+         self.original_layer.kernel_zero.assign(zero)
+         self.original_layer.g_idx.assign(g_idx)
+         self.original_layer.is_gptq_calibrated = True

      def free(self):
-         self.hessian = None
+         del self.hessian
+         del self.layer
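When `weight_bits == 4`, the quantized kernel above is stored two values per byte via `quantizers.pack_int4` before being assigned to `quantized_kernel`. The sketch below only illustrates the general nibble-packing idea with hypothetical helper names; the actual layout is defined by keras/src/quantizers/quantizers.py, not by this code.

import numpy as np

def pack_int4_pairs(values):
    """Pack an even-length array of unsigned 4-bit codes (0..15) into uint8,
    two codes per byte. Illustration only; not keras pack_int4."""
    values = np.asarray(values, dtype=np.uint8)
    low, high = values[0::2], values[1::2]
    return (high << 4) | low

def unpack_int4_pairs(packed):
    packed = np.asarray(packed, dtype=np.uint8)
    low = packed & 0x0F
    high = packed >> 4
    return np.stack([low, high], axis=-1).reshape(-1)

codes = np.array([0, 15, 7, 8, 3, 12], dtype=np.uint8)
assert np.array_equal(unpack_int4_pairs(pack_int4_pairs(codes)), codes)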