keras-nightly 3.12.0.dev2025090303__py3-none-any.whl → 3.12.0.dev2025090403__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
- import random
1
+ import math
2
+ from contextlib import contextmanager
2
3
 
3
4
  import numpy as np
4
5
  from absl import logging
@@ -9,15 +10,109 @@ from keras.src.layers import Dense
9
10
  from keras.src.layers import EinsumDense
10
11
  from keras.src.layers import Embedding
11
12
  from keras.src.quantizers.gptq import GPTQ
12
- from keras.src.quantizers.gptq_quant import GPTQQuantization
13
13
 
14
14
 
15
- def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
15
@contextmanager
def stream_hessians(layers_map, gptq_objects):
    """Temporarily patch layers so their inputs stream into GPTQ objects.

    While the context is active, `layer.call` of every layer in `layers_map`
    is replaced by a wrapper that:

    1. takes the layer input from the first positional argument, or from
       the `inputs` keyword argument when no positional argument is given,
    2. reshapes it to 2D `(-1, rows)` with `rows = gptq_objects[name].rows`,
    3. passes the result to `gptq_objects[name].update_hessian_with_batch`,
    4. delegates to the original `layer.call` and returns its output.

    On exit, every layer that was actually patched gets its original `call`
    back. This holds when the body of the `with` block raises and also when
    patching itself fails part-way through `layers_map`; in the latter case
    the original exception propagates unchanged.

    Args:
        layers_map: Dict mapping names to the layers to patch. Every key
            must also be present in `gptq_objects` by the time the patched
            layer is called.
        gptq_objects: Dict mapping the same names to objects exposing a
            `rows` attribute and an `update_hessian_with_batch` method.

    Yields:
        None. The patched state is active only inside the `with` block.

    Example:

    ```python
    with stream_hessians(layers_map, gptq_objects):
        for sample in calibration_inputs:
            _ = block(sample)  # wrappers forward inputs to the GPTQ objects
    # original `layer.call` methods are restored here
    ```
    """
    # Only layers whose `call` was successfully replaced are recorded, so
    # the cleanup below never looks up a layer that was not patched.
    original_calls = {}

    def create_hook(name, original_call_func):
        def hook(*args, **kwargs):
            inp = args[0] if args else kwargs["inputs"]
            # Flatten every leading dimension so the second dimension
            # equals the `rows` value reported by the GPTQ object.
            num_features = gptq_objects[name].rows
            input_2d = ops.reshape(inp, (-1, num_features))
            gptq_objects[name].update_hessian_with_batch(input_2d)
            return original_call_func(*args, **kwargs)

        return hook

    try:
        for name, layer in layers_map.items():
            original_call = layer.call
            layer.call = create_hook(name, original_call)
            original_calls[name] = original_call
        yield
    finally:
        for name, original_call in original_calls.items():
            layers_map[name].call = original_call
82
+
83
+
84
+ def get_dataloader(
85
+ tokenizer,
86
+ sequence_length,
87
+ dataset,
88
+ num_samples=128,
89
+ *,
90
+ strategy="strided",
91
+ seed=42,
92
+ stride=None,
93
+ eos_id=None,
94
+ ):
18
95
  """
19
- all_tokens = []
96
+ Prepares and chunks the calibration dataloader, repeating short datasets.
97
+ All processing happens on the CPU.
20
98
 
99
+ Args:
100
+ tokenizer: The tokenizer to use for text splitting.
101
+ sequence_length: The length of each input sequence.
102
+ dataset: The dataset to sample from.
103
+ num_samples: The number of samples to generate.
104
+ strategy: The sampling strategy to use. Possible values are
105
+ 1. "strided": Samples are taken at regular intervals.
106
+ 2. "linspace": Samples are taken at evenly spaced intervals.
107
+ 3. "random": Samples are taken at random positions.
108
+ seed: The random seed for reproducibility. Used only if
109
+ strategy="random"
110
+ stride: The stride length for "strided" sampling.
111
+ eos_id: The end-of-sequence token ID.
112
+
113
+ Returns:
114
+ np.ndarray of shape (num_samples, 1, sequence_length), dtype int32.
115
+ """
21
116
  if not hasattr(dataset, "__iter__") or isinstance(dataset, (str, bytes)):
22
117
  raise TypeError(
23
118
  "The `dataset` argument must be an iterable (e.g., a list of "
@@ -27,44 +122,72 @@ def get_dataloader(tokenizer, sequence_length, dataset, num_samples=128):
27
122
  )
28
123
 
29
124
  dataset_list = list(dataset)
30
-
31
125
  if not dataset_list:
32
126
  raise ValueError("Provided dataset is empty.")
33
127
 
128
+ pieces = []
34
129
  if isinstance(dataset_list[0], str):
35
- logging.info("(Dataset contains strings, tokenizing now...)")
36
- full_text = "\n\n".join(dataset_list)
37
- all_tokens = tokenizer.tokenize(full_text)
130
+ for i, s in enumerate(dataset_list):
131
+ toks = np.asarray(tokenizer.tokenize(s)).reshape(-1)
132
+ pieces.append(toks)
133
+ # avoid windows that span document boundaries
134
+ if eos_id is not None and i < len(dataset_list) - 1:
135
+ pieces.append(np.array([eos_id], dtype=np.int32))
38
136
  else:
39
- logging.info("(Dataset is pre-tokenized, concatenating...)")
40
- all_tokens = np.concatenate(
41
- [ops.convert_to_numpy(s).reshape(-1) for s in dataset_list], axis=0
42
- )
43
-
44
- all_tokens = np.array(all_tokens, dtype=np.int32)
137
+ for s in dataset_list:
138
+ toks = ops.convert_to_numpy(s).reshape(-1)
139
+ pieces.append(toks.astype(np.int32, copy=False))
140
+
141
+ all_tokens = (
142
+ pieces[0].astype(np.int32, copy=False)
143
+ if len(pieces) == 1
144
+ else np.concatenate(pieces, axis=0).astype(np.int32, copy=False)
145
+ )
45
146
 
46
- # Repeat data if it's too short
47
147
  required_tokens = num_samples * sequence_length
48
- if len(all_tokens) < required_tokens:
49
- logging.info(
50
- f"Warning: Dataset is too short ({len(all_tokens)} tokens)."
51
- " Repeating data to generate {num_samples} samples."
52
- )
53
- repeats = -(-required_tokens // len(all_tokens)) # Ceiling division
148
+ if all_tokens.size < required_tokens:
149
+ repeats = math.ceil(required_tokens / max(1, all_tokens.size))
54
150
  all_tokens = np.tile(all_tokens, repeats)
55
151
 
56
- # Chunk the token list into samples
152
+ max_start = all_tokens.size - sequence_length
153
+ if max_start < 0:
154
+ raise ValueError(
155
+ f"Not enough tokens to form one sample of length {sequence_length} "
156
+ f"(have {all_tokens.size})."
157
+ )
57
158
 
58
- calibration_samples = []
59
- for _ in range(num_samples):
60
- # Generate a random starting index
61
- start_index = random.randint(0, len(all_tokens) - sequence_length - 1)
62
- end_index = start_index + sequence_length
63
- sample = all_tokens[start_index:end_index]
64
- calibration_samples.append(np.reshape(sample, (1, sequence_length)))
159
+ # Choose deterministic, well-spread starts by default
160
+ if strategy == "random":
161
+ rng = np.random.default_rng(seed)
162
+ starts = rng.integers(
163
+ 0, max_start + 1, size=num_samples, dtype=np.int64
164
+ )
165
+ elif strategy == "linspace":
166
+ # even coverage with no RNG
167
+ starts = np.linspace(0, max_start, num_samples, dtype=np.int64)
168
+ elif strategy == "strided":
169
+ # stride chosen to cover the space roughly uniformly
170
+ if stride is None:
171
+ stride = max(1, (max_start + 1) // num_samples)
172
+ # offset derived deterministically from seed
173
+ offset = (
174
+ (abs(hash(("gptq-calib", seed))) % (max_start + 1))
175
+ if max_start > 0
176
+ else 0
177
+ )
178
+ starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % (
179
+ max_start + 1
180
+ )
181
+ else:
182
+ raise ValueError(f"Unknown strategy: {strategy}")
65
183
 
66
- final_array = np.stack(calibration_samples, axis=0)
67
- return final_array
184
+ # Gather contiguous windows
185
+ # sliding_window_view avoids building a big index matrix
186
+ windows = np.lib.stride_tricks.sliding_window_view(
187
+ all_tokens, sequence_length
188
+ )
189
+ samples = windows[starts] # (num_samples, sequence_length)
190
+ return samples.astype(np.int32)[:, None, :]
68
191
 
69
192
 
70
193
  def _find_layers_recursive(layer, prefix, found_layers):
@@ -82,6 +205,41 @@ def _find_layers_recursive(layer, prefix, found_layers):
82
205
  _find_layers_recursive(sub_layer, layer_name, found_layers)
83
206
 
84
207
 
208
+ def _get_backbone_layers(model):
209
+ """Extract embedding and transformer layers from a KerasHub model."""
210
+ backbone = model.backbone
211
+ if not hasattr(backbone, "transformer_layers"):
212
+ raise ValueError(
213
+ "The model's backbone does not have a 'transformer_layers' "
214
+ "attribute. Please ensure you are using a standard KerasHub "
215
+ "transformer model."
216
+ )
217
+ transformer_blocks = backbone.transformer_layers
218
+
219
+ if hasattr(backbone, "token_embedding"):
220
+ embedding_layer = backbone.token_embedding
221
+ elif hasattr(backbone, "embedding"):
222
+ embedding_layer = backbone.embedding
223
+ else:
224
+ raise ValueError(
225
+ "Could not automatically find an embedding layer in the model."
226
+ )
227
+ return embedding_layer, transformer_blocks
228
+
229
+
230
+ def _get_custom_layers(model):
231
+ """Heuristic for extracting embedding + transformer blocks from a custom
232
+ model."""
233
+ embedding_layer = None
234
+ transformer_blocks = []
235
+ for layer in model.layers:
236
+ if isinstance(layer, Embedding) and embedding_layer is None:
237
+ embedding_layer = layer
238
+ elif getattr(layer, "_layers", None): # container-like block
239
+ transformer_blocks.append(layer)
240
+ return embedding_layer, transformer_blocks
241
+
242
+
85
243
  def find_layers_in_block(block):
86
244
  """
87
245
  A pluggable, generic function to find all Dense and EinsumDense layers
@@ -93,16 +251,7 @@ def find_layers_in_block(block):
93
251
  return found_layers
94
252
 
95
253
 
96
- def apply_gptq_layerwise(
97
- model,
98
- dataloader,
99
- num_samples,
100
- hessian_damping,
101
- group_size,
102
- symmetric,
103
- activation_order,
104
- weight_bits,
105
- ):
254
+ def apply_gptq_layerwise(model, dataloader, config):
106
255
  """Applies GPTQ quantization layer-by-layer to a Keras model.
107
256
 
108
257
  This function is designed to work with common transformer architectures,
@@ -134,63 +283,24 @@ def apply_gptq_layerwise(
134
283
  attempt to automatically discover its structure.
135
284
  dataloader: An iterable providing calibration data. Each item should
136
285
  be a batch of token IDs suitable for the model's embedding layer.
137
- num_samples: (int) The number of samples from the dataloader to use for
138
- calibration.
139
- hessian_damping: (float) The percentage of dampening to add to the
140
- Hessian diagonal for stabilization during inverse calculation.
141
- A value of 0.01 is common.
142
- group_size: (int) The size of the groups to use for quantization. A
143
- value of 128 means that 128 weights will share the same scaling
144
- factor. Use -1 for per-channel quantization.
145
- symmetric: (bool) If True, symmetric quantization is used. Otherwise,
146
- asymmetric quantization is used.
147
- activation_order: (bool) If True, reorders the weight columns based on
148
- activation magnitude, which can improve quantization accuracy.
149
- weight_bits: (int) The number of bits to use for the quantized weights,
150
- e.g., 4 for 4-bit quantization.
286
+ config: A GPTQConfiguration object.
151
287
 
152
288
  Raises:
153
289
  ValueError: If the function cannot automatically find an embedding
154
290
  layer or any transformer-like blocks to quantize within the model.
155
291
  """
292
+
293
+ num_samples = config.num_samples
294
+
156
295
  logging.info("Starting model quantization...")
157
296
  embedding_layer = None
158
297
  transformer_blocks = []
159
298
  if hasattr(model, "backbone"):
160
299
  logging.info("Detected KerasHub model structure.")
161
- backbone = model.backbone
162
-
163
- # Add the check for the 'transformer_layers' attribute.
164
- if hasattr(backbone, "transformer_layers"):
165
- transformer_blocks = backbone.transformer_layers
166
- else:
167
- # Raise a specific error if the attribute is missing.
168
- raise ValueError(
169
- "The model's backbone does not have a 'transformer_layers' "
170
- "attribute. Please ensure you are using a standard KerasHub "
171
- "transformer model."
172
- )
173
- # Find the embedding layer by checking for common names or by type.
174
- if hasattr(backbone, "token_embedding"):
175
- embedding_layer = backbone.token_embedding
176
- elif hasattr(backbone, "embedding"):
177
- embedding_layer = backbone.embedding
178
- else:
179
- raise ValueError(
180
- "Could not automatically find an embedding layer in the model."
181
- )
182
-
300
+ embedding_layer, transformer_blocks = _get_backbone_layers(model)
183
301
  else:
184
302
  logging.info("Detected custom model structure.")
185
- for layer in model.layers:
186
- # The first Embedding layer found is assumed to be the main one.
187
- if isinstance(layer, Embedding) and embedding_layer is None:
188
- embedding_layer = layer
189
- # A "block" is a container-like layer with its own sub-layers
190
- # that we can quantize. This is a heuristic that works for the
191
- # test.
192
- elif hasattr(layer, "_layers") and layer._layers:
193
- transformer_blocks.append(layer)
303
+ embedding_layer, transformer_blocks = _get_custom_layers(model)
194
304
 
195
305
  if embedding_layer is None:
196
306
  raise ValueError(
@@ -207,6 +317,8 @@ def apply_gptq_layerwise(
207
317
  embedding_layer(ops.convert_to_tensor(batch, dtype="int32"))
208
318
  for batch in dataloader
209
319
  ]
320
+ num_samples = min(num_samples, len(inputs))
321
+
210
322
  progbar = keras_utils.Progbar(target=len(transformer_blocks))
211
323
 
212
324
  for block_idx, block in enumerate(transformer_blocks):
@@ -221,73 +333,23 @@ def apply_gptq_layerwise(
221
333
  else:
222
334
  logging.info(f"Found layers: {list(sub_layers_map.keys())}")
223
335
  gptq_objects = {
224
- name: GPTQ(layer) for name, layer in sub_layers_map.items()
336
+ name: GPTQ(layer, config)
337
+ for name, layer in sub_layers_map.items()
225
338
  }
226
339
 
227
- captured_inputs = {name: [] for name in sub_layers_map.keys()}
228
- original_calls = {}
229
-
230
- def create_hook(name, original_call_func):
231
- """A factory for creating a hook to capture layer inputs."""
232
-
233
- def hook(*args, **kwargs):
234
- if args:
235
- inp = args[0]
236
- else:
237
- inp = kwargs["inputs"]
238
- captured_inputs[name].append(inp)
239
- return original_call_func(*args, **kwargs)
240
-
241
- return hook
242
-
243
- try:
244
- for name, layer in sub_layers_map.items():
245
- original_call = layer.call
246
- original_calls[name] = original_call
247
- layer.call = create_hook(name, original_call)
248
-
249
- logging.info(f"Capturing activations for block {block_idx}...")
340
+ with stream_hessians(sub_layers_map, gptq_objects):
250
341
  for sample_idx in range(num_samples):
251
342
  current_input = inputs[sample_idx]
252
343
  if len(current_input.shape) == 2:
253
344
  current_input = ops.expand_dims(current_input, axis=0)
254
345
  _ = block(current_input)
255
346
 
256
- finally:
257
- for name, layer in sub_layers_map.items():
258
- if name in original_calls:
259
- layer.call = original_calls[name]
260
-
261
- logging.info(f"Building Hessians for block {block_idx}...")
262
- for name, gptq_object in gptq_objects.items():
263
- layer_inputs = ops.concatenate(captured_inputs[name], axis=0)
264
-
265
- # Explicitly reshape the input tensor to be 2D, with the second
266
- # dimension matching the number of input features expected by
267
- # the layer's kernel.
268
- # This correctly handles inputs of any dimensionality
269
- # (e.g., 3D or 4D).
270
- num_features = gptq_object.rows
271
- input_reshaped = ops.reshape(layer_inputs, (-1, num_features))
272
- gptq_object.update_hessian_with_batch(input_reshaped)
273
-
274
- quantizer = GPTQQuantization(
275
- weight_bits,
276
- per_channel=True,
277
- symmetric=symmetric,
278
- group_size=group_size,
279
- )
280
347
  for name, gptq_object in gptq_objects.items():
281
348
  logging.info(f"Quantizing {name}...")
282
- gptq_object.quantizer = quantizer
283
- gptq_object.quantize_and_correct_block(
284
- hessian_damping=hessian_damping,
285
- group_size=group_size,
286
- activation_order=activation_order,
287
- )
349
+ gptq_object.quantize_and_correct_layer()
288
350
  gptq_object.free()
289
351
 
290
- del gptq_objects, captured_inputs, original_calls
352
+ del gptq_objects
291
353
 
292
354
  if block_idx < len(transformer_blocks) - 1:
293
355
  logging.info(f"Generating inputs for block {block_idx + 1}...")
@@ -304,32 +366,23 @@ def apply_gptq_layerwise(
304
366
  logging.info("Quantization process complete.")
305
367
 
306
368
 
307
- def quantize_model(model, config):
369
def gptq_quantize(model, config):
    """
    Top-level function to quantize a Keras model using GPTQ.

    Builds the calibration samples with `get_dataloader` from
    `config.tokenizer`, `config.sequence_length`, `config.dataset` and
    `config.num_samples`, then passes the first `config.num_samples` of
    them, together with `model` and `config`, to `apply_gptq_layerwise`.
    """
    logging.info("Starting GPTQ quantization process...")

    # Materialize every calibration sample in a single call.
    requested = config.num_samples
    samples = get_dataloader(
        config.tokenizer,
        config.sequence_length,
        config.dataset,
        num_samples=requested,
    )

    # The materialized samples support slicing, so the calibration subset
    # can be taken directly.
    calibration_samples = samples[: config.num_samples]

    apply_gptq_layerwise(model, calibration_samples, config)
@@ -9,6 +9,7 @@ from keras.src.backend import any_symbolic_tensors
9
9
  from keras.src.backend.common.backend_utils import canonicalize_axis
10
10
  from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
11
11
  from keras.src.ops.operation import Operation
12
+ from keras.src.quantizers.gptq_config import GPTQConfig
12
13
 
13
14
  """Int8-related classes and methods"""
14
15
 
@@ -634,3 +635,202 @@ def unpack_int4(packed, orig_len, axis=0):
634
635
  unpacked = ops.transpose(unpacked, inv_perm)
635
636
 
636
637
  return unpacked # dtype is int8
638
+
639
+
640
class GPTQQuantizer(Quantizer):
    """A class that handles the quantization of weights using GPTQ method.

    This class finds quantization parameters (scale and zero) for a given
    tensor and keeps them, together with `maxq`, on the instance.

    Args:
        config: Optional `GPTQConfig`. Its `weight_bits`, `per_channel`,
            `symmetric` and `group_size` attributes are copied onto the
            quantizer. When omitted, a
            `GPTQConfig(tokenizer=None, dataset=None)` is created for this
            instance.
    """

    def __init__(self, config=None):
        super().__init__()
        # Build the fallback config per call rather than in the signature:
        # a default written in the signature is evaluated once, when the
        # class body is executed, and then shared by every instance.
        if config is None:
            config = GPTQConfig(tokenizer=None, dataset=None)
        self.weight_bits = config.weight_bits
        self.per_channel = config.per_channel
        self.symmetric = config.symmetric
        self.group_size = config.group_size

        # Populated by `find_params`.
        self.scale = None
        self.zero = None
        self.maxq = None

    def find_params(self, input_tensor, weight=False):
        """Finds quantization parameters (scale and zero) for a given tensor.

        The result of `compute_quantization_parameters` is stored on the
        instance and also returned as `(scale, zero, maxq)`.
        """
        self.scale, self.zero, self.maxq = compute_quantization_parameters(
            input_tensor,
            bits=self.weight_bits,
            symmetric=self.symmetric,
            per_channel=self.per_channel,
            group_size=self.group_size,
            weight=weight,
        )
        return self.scale, self.zero, self.maxq

    def ready(self):
        """Checks if the quantization parameters have been computed."""
        return (
            self.scale is not None
            and self.zero is not None
            and self.maxq is not None
        )

    def get_config(self):
        """Returns the parent config extended with the GPTQ settings."""
        config = super().get_config()
        config.update(
            {
                "weight_bits": self.weight_bits,
                "per_channel": self.per_channel,
                "symmetric": self.symmetric,
                "group_size": self.group_size,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuilds a quantizer from a dict produced by `get_config`.

        Only `weight_bits`, `per_channel`, `symmetric` and `group_size` are
        read from `config`; all four keys are required.
        """
        gptq = GPTQConfig(
            tokenizer=None,
            dataset=None,
            weight_bits=config["weight_bits"],
            per_channel=config["per_channel"],
            symmetric=config["symmetric"],
            group_size=config["group_size"],
        )
        return cls(gptq)
713
+
714
+
715
def compute_quantization_parameters(
    x, *, bits, symmetric=False, per_channel=False, group_size=-1, weight=False
):
    """
    Computes the scale and zero-point for quantization.

    Args:
        x: KerasTensor. The input tensor to quantize.
        bits: int. The number of bits to quantize to (e.g., 4).
        symmetric: bool. Whether to use symmetric quantization.
        per_channel: bool. Whether to quantize per channel. When `True`,
            statistics are taken per group of `group_size` values, or per
            entry along the first axis when `group_size == -1`. This applies
            to weight and non-weight tensors alike.
        group_size: int. The group size for quantization. `-1` disables
            grouping.
        weight: bool. Whether the input tensor is a weight tensor. Weight
            tensors must have rank >= 2; in the per-tensor case their scale
            and zero are tiled to one row per entry of the first axis.

    Returns:
        A `(scale, zero, maxq)` tuple. `maxq` is `2**bits - 1` cast to
        float32. With `per_channel=True`, `scale` and `zero` have shape
        `(-1, 1)`.

    Raises:
        ValueError: If `x` is `None`, is empty, or is a weight tensor of
            rank lower than 2.
    """
    if x is None:
        raise ValueError(f"Input tensor {x} cannot be None.")

    # For weights, we typically expect at least a 2D tensor.
    if weight and len(x.shape) < 2:
        raise ValueError(
            f"Input weight tensor {x} must have a rank of at "
            f"least 2, but got rank {len(x.shape)}."
        )

    if ops.size(x) == 0:
        raise ValueError("Input tensor 'x' cannot be empty.")

    original_shape = x.shape

    # Every combination of flags must produce `input_reshaped`; previously
    # `per_channel=True` with `weight=False` left it unassigned.
    if not per_channel:  # per-tensor
        input_reshaped = ops.reshape(x, [1, -1])
    elif group_size != -1:
        input_reshaped = ops.reshape(x, [-1, group_size])
    else:
        input_reshaped = ops.reshape(x, [original_shape[0], -1])

    # Find min/max values
    min_values = ops.min(input_reshaped, axis=1)
    max_values = ops.max(input_reshaped, axis=1)

    # Apply symmetric quantization logic if enabled
    if symmetric:
        max_values = ops.maximum(ops.abs(min_values), max_values)
        min_values = ops.where(
            ops.less(min_values, 0), ops.negative(max_values), min_values
        )

    # Ensure range is not zero to avoid division errors
    zero_range = ops.equal(min_values, max_values)
    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)
    max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)

    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), "float32")

    # Calculate scale and zero-point
    scale = ops.divide(ops.subtract(max_values, min_values), maxq)
    if symmetric:
        zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))
    else:
        zero = ops.round(ops.divide(ops.negative(min_values), scale))

    # Ensure scale is non-zero
    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)

    if per_channel:
        # One (scale, zero) pair per row or group, as a column vector.
        scale = ops.reshape(scale, [-1, 1])
        zero = ops.reshape(zero, [-1, 1])
    elif weight:
        # Per-tensor weights: repeat the single pair for every row.
        num_rows = original_shape[0]
        scale = ops.tile(ops.reshape(scale, (1, 1)), (num_rows, 1))
        zero = ops.tile(ops.reshape(zero, (1, 1)), (num_rows, 1))

    return scale, zero, maxq
795
+
796
+
797
def quantize_with_zero_point(input_tensor, scale, zero, maxq):
    """Map a float tensor onto the discrete levels `[0, maxq]`.

    Computes `round(input_tensor / scale + zero)` and clips the result to
    `[0, maxq]`. Entries of `scale` equal to zero are replaced by `1e-8`
    (cast to `scale.dtype`) before dividing.

    Args:
        input_tensor: KerasTensor. The input tensor to quantize.
        scale: KerasTensor. The scale tensor for quantization.
        zero: KerasTensor. The zero tensor for quantization.
        maxq: KerasTensor. The maximum quantization value.

    Returns:
        KerasTensor. The quantized tensor, with values in `[0, maxq]`.
    """
    # Substitute a tiny positive value wherever the scale is exactly zero.
    tiny = ops.cast(1e-8, dtype=scale.dtype)
    nonzero_scale = ops.where(ops.equal(scale, 0), tiny, scale)

    shifted = ops.add(ops.divide(input_tensor, nonzero_scale), zero)
    return ops.clip(ops.round(shifted), 0, maxq)
822
+
823
+
824
def dequantize_with_zero_point(input_tensor, scale, zero):
    """
    Dequantizes a quantized tensor as `scale * (input_tensor - zero)`.

    Args:
        input_tensor: KerasTensor. The quantized tensor to dequantize.
        scale: KerasTensor. The scale tensor for dequantization.
        zero: KerasTensor. The zero tensor for dequantization.

    Returns:
        KerasTensor. The dequantized tensor.
    """
    centered = ops.subtract(input_tensor, zero)
    return ops.multiply(scale, centered)