compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,18 +14,28 @@
 
 from functools import wraps
 from math import ceil
+from typing import Optional
 
 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import update_parameter_data
 from torch.nn import Module
 
 
-__all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -33,14 +43,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-    q_min: torch.Tensor,
-    q_max: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
 
-    return torch.clamp(
-        torch.round(x / scale + zero_point),
-        q_min,
-        q_max,
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
 
@@ -48,9 +83,50 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimmensions. Expected 0 or 2 dimmensions."
+            )
+
+    if dtype is None:
+        dtype = scale.dtype
+
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+        dtype=dtype,
+    )
 
 
 @torch.no_grad()
@@ -61,30 +137,45 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     """
-    Fake quantize the input tensor x depending on the group_size.
-    if group_size is greater than 0, then q/dq by groups. The groups
-    must be divisible by the column size
-    if group_size is -1, then channel wise q/dq. THe input scale and
-    zero_points are reshaped to support vectorization (Assumes 1 is
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
 
     :param x: Input tensor
     :param scale: scale tensor
     :param zero_point: zero point tensor
-    :param args: quantization args that contain group_size info
+    :param args: quantization args dictating how to quantize x
     :return: fake quantized tensor
-
     """
-    bit_range = 2**args.num_bits
-    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
-    min_q = torch.tensor(-bit_range / 2, device=x.device)
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
 
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    # group
     if args.strategy == QuantizationStrategy.GROUP:
-
-        DQ = torch.zeros_like(x)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -94,7 +185,7 @@ def fake_quantize(
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
         columns = x.shape[1]
         if columns >= group_size:
@@ -106,51 +197,60 @@ def fake_quantize(
         for i in range(ceil(columns / group_size)):
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
-
-            sc = scale[:, i].unsqueeze(1)
-            zp = zero_point[:, i].unsqueeze(1)
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
             idx = i * group_size
-            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
-            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
-
-    # channel-wise
-    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
-        # before: scale shape = [channel_size]
-        # after: scale shape = [1, channel_size]
-        scale = scale.unsqueeze(0)
-        zero_point = zero_point.unsqueeze(0)
-
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
-
-    # per-token
-    elif args.strategy == QuantizationStrategy.TOKEN:
-        # before: scale shape = [num_tokens]
-        # after: scale shape = [num_tokens, 1]
-        # x.shape = 1, num_tokens, 1]
-        # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
-
-        scale = scale.unsqueeze(1)
-        zero_point = zero_point.unsqueeze(1)
-
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
-
-    else:
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)],
+                    sc,
+                    zp,
+                    q_min,
+                    q_max,
+                    args,
+                    dtype=dtype,
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
-    return DQ
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-    forward_func_orig = module.forward.__func__
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
+        if not getattr(module, "quantization_enabled", True):
+            # quantization is disabled on forward passes, return baseline
+            # forward call
+            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
         input_ = args[0]
 
         if scheme.input_activations is not None:
@@ -199,6 +299,11 @@ def maybe_calibrate_or_quantize(
     }:
         return value
 
+    if value.numel() == 0:
+        # if the tensor is empty,
+        # skip quantization
+        return value
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
@@ -208,14 +313,61 @@
     scale = getattr(module, f"{base_name}_scale")
     zero_point = getattr(module, f"{base_name}_zero_point")
 
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
+    if (
+        module.quantization_status == QuantizationStatus.CALIBRATION
+        and base_name != "weight"
+    ):
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
 
         updated_scale, updated_zero_point = observer(value)
 
         # update scale and zero point
-        device = next(module.parameters()).device
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
+        q_min,
+        q_max,
+    )
+    quantized_value = round_to_quantized_type(clamped_value, args)
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
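
As an orientation aid (not part of the diff), here is a minimal sketch of calling the quantize/dequantize/fake_quantize entry points now exported from compressed_tensors.quantization.lifecycle.forward; the tensor sizes, per-tensor scale and zero-point values, and reliance on QuantizationArgs defaults are illustrative assumptions, not taken from the package.

# Illustrative only: exercising the new public API added in 0.5.0.
import torch

from compressed_tensors.quantization.lifecycle.forward import (
    dequantize,
    fake_quantize,
    quantize,
)
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

x = torch.randn(4, 16)
scale = torch.tensor(0.1)       # assumed per-tensor scale
zero_point = torch.tensor(0)    # assumed per-tensor zero point
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)  # library defaults otherwise

x_q = quantize(x, scale, zero_point, args, dtype=torch.int8)   # clamp + round, cast to int8
x_dq = dequantize(x_q, scale, zero_point, args)                # back to scale.dtype
x_fq = fake_quantize(x, scale, zero_point, args)               # quantize then dequantize in one call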
compressed_tensors/quantization/lifecycle/frozen.py
@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return
 
+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
compressed_tensors/quantization/lifecycle/helpers.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Miscelaneous helpers for the quantization lifecycle
+"""
+
+
+from torch.nn import Module
+
+
+__all__ = [
+    "update_layer_weight_quant_params",
+    "enable_quantization",
+    "disable_quantization",
+]
+
+
+def update_layer_weight_quant_params(layer: Module):
+    weight = getattr(layer, "weight", None)
+    scale = getattr(layer, "weight_scale", None)
+    zero_point = getattr(layer, "weight_zero_point", None)
+    observer = getattr(layer, "weight_observer", None)
+
+    if weight is None or observer is None or scale is None or zero_point is None:
+        # scale, zp, or observer not calibratable or weight not available
+        return
+
+    updated_scale, updated_zero_point = observer(weight)
+
+    # update scale and zero point
+    device = next(layer.parameters()).device
+    scale.data = updated_scale.to(device)
+    zero_point.data = updated_zero_point.to(device)
+
+
+def enable_quantization(module: Module):
+    module.quantization_enabled = True
+
+
+def disable_quantization(module: Module):
+    module.quantization_enabled = False
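
A short hedged sketch (not part of the diff) of how the new enable/disable helpers relate to the quantization_enabled check added to wrapped_forward above; the Linear layer is a stand-in, since the helpers simply set a boolean attribute on any Module.

# Illustrative only.
import torch

from compressed_tensors.quantization.lifecycle.helpers import (
    disable_quantization,
    enable_quantization,
)

layer = torch.nn.Linear(16, 16)     # stand-in; real use targets a module wrapped for quantization
disable_quantization(layer)         # layer.quantization_enabled = False
print(layer.quantization_enabled)   # False -> wrapped_forward would take the baseline path
enable_quantization(layer)          # layer.quantization_enabled = True
print(layer.quantization_enabled)   # True  -> fake-quantized forward passes resume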
compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,12 +17,18 @@ import logging
 from typing import Optional
 
 import torch
+from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
 
@@ -58,7 +64,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -73,12 +84,38 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
 
+    offloaded = False
+    if is_module_offloaded(module):
+        offloaded = True
+        hook = module._hf_hook
+        prefix_dict = module._hf_hook.weights_map
+        new_prefix = {}
+
+        # recreate the prefix dict (since it is immutable)
+        # and add quantization parameters
+        for key, data in module.named_parameters():
+            if key not in prefix_dict:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+            else:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+        remove_hook_from_module(module)
+
     # wrap forward call of module to perform quantized actions based on calltime status
     wrap_module_forward_quantized(module, scheme)
 
+    if offloaded:
+        # we need to re-add the hook for offloading now that we've wrapped forward
+        add_hook_to_module(module, hook)
+        if prefix_dict is not None:
+            module._hf_hook.weights_map = new_prefix_dict
+
 
 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -88,12 +125,32 @@ def _initialize_scale_zero_point_observer(
         return  # no need to register a scale and zero point for a dynamic observer
 
     device = next(module.parameters()).device
+    if is_module_offloaded(module):
+        device = get_execution_device(module)
+
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
 
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=int), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)