compressed-tensors-nightly 0.5.0.20240813__py3-none-any.whl → 0.5.0.20240829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressor.py +53 -0
- compressed_tensors/compressors/naive_quantized.py +11 -2
- compressed_tensors/compressors/pack_quantized.py +7 -2
- compressed_tensors/quantization/lifecycle/apply.py +5 -0
- compressed_tensors/quantization/lifecycle/calibration.py +3 -2
- compressed_tensors/quantization/lifecycle/forward.py +48 -24
- compressed_tensors/quantization/lifecycle/helpers.py +29 -2
- compressed_tensors/quantization/lifecycle/initialize.py +22 -3
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +48 -20
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/offload.py +4 -1
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- {compressed_tensors_nightly-0.5.0.20240813.dist-info → compressed_tensors_nightly-0.5.0.20240829.dist-info}/METADATA +3 -2
- {compressed_tensors_nightly-0.5.0.20240813.dist-info → compressed_tensors_nightly-0.5.0.20240829.dist-info}/RECORD +23 -21
- {compressed_tensors_nightly-0.5.0.20240813.dist-info → compressed_tensors_nightly-0.5.0.20240829.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.5.0.20240813.dist-info → compressed_tensors_nightly-0.5.0.20240829.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.5.0.20240813.dist-info → compressed_tensors_nightly-0.5.0.20240829.dist-info}/top_level.txt +0 -0
@@ -240,6 +240,59 @@ class ModelCompressor:
|
|
240
240
|
compressed_state_dict
|
241
241
|
)
|
242
242
|
|
243
|
+
# HACK (mgoin): Post-process step for kv cache scales to take the
|
244
|
+
# k/v_proj module `output_scale` parameters, and store them in the
|
245
|
+
# parent attention module as `k_scale` and `v_scale`
|
246
|
+
#
|
247
|
+
# Example:
|
248
|
+
# Replace `model.layers.0.self_attn.k_proj.output_scale`
|
249
|
+
# with `model.layers.0.self_attn.k_scale`
|
250
|
+
if (
|
251
|
+
self.quantization_config is not None
|
252
|
+
and self.quantization_config.kv_cache_scheme is not None
|
253
|
+
):
|
254
|
+
# HACK (mgoin): We assume the quantized modules in question
|
255
|
+
# will be k_proj and v_proj since those are the default targets.
|
256
|
+
# We check that both of these modules have output activation
|
257
|
+
# quantization, and additionally check that q_proj doesn't.
|
258
|
+
q_proj_has_no_quant_output = 0
|
259
|
+
k_proj_has_quant_output = 0
|
260
|
+
v_proj_has_quant_output = 0
|
261
|
+
for name, module in model.named_modules():
|
262
|
+
if not hasattr(module, "quantization_scheme"):
|
263
|
+
continue
|
264
|
+
out_act = module.quantization_scheme.output_activations
|
265
|
+
if name.endswith(".q_proj") and out_act is None:
|
266
|
+
q_proj_has_no_quant_output += 1
|
267
|
+
elif name.endswith(".k_proj") and out_act is not None:
|
268
|
+
k_proj_has_quant_output += 1
|
269
|
+
elif name.endswith(".v_proj") and out_act is not None:
|
270
|
+
v_proj_has_quant_output += 1
|
271
|
+
|
272
|
+
assert (
|
273
|
+
q_proj_has_no_quant_output > 0
|
274
|
+
and k_proj_has_quant_output > 0
|
275
|
+
and v_proj_has_quant_output > 0
|
276
|
+
)
|
277
|
+
assert (
|
278
|
+
q_proj_has_no_quant_output
|
279
|
+
== k_proj_has_quant_output
|
280
|
+
== v_proj_has_quant_output
|
281
|
+
)
|
282
|
+
|
283
|
+
# Move all .k/v_proj.output_scale parameters to .k/v_scale
|
284
|
+
working_state_dict = {}
|
285
|
+
for key in compressed_state_dict.keys():
|
286
|
+
if key.endswith(".k_proj.output_scale"):
|
287
|
+
new_key = key.replace(".k_proj.output_scale", ".k_scale")
|
288
|
+
working_state_dict[new_key] = compressed_state_dict[key]
|
289
|
+
elif key.endswith(".v_proj.output_scale"):
|
290
|
+
new_key = key.replace(".v_proj.output_scale", ".v_scale")
|
291
|
+
working_state_dict[new_key] = compressed_state_dict[key]
|
292
|
+
else:
|
293
|
+
working_state_dict[key] = compressed_state_dict[key]
|
294
|
+
compressed_state_dict = working_state_dict
|
295
|
+
|
243
296
|
# HACK: Override the dtype_byte_size function in transformers to
|
244
297
|
# support float8 types. Fix is posted upstream
|
245
298
|
# https://github.com/huggingface/transformers/pull/30488
|
@@ -44,7 +44,12 @@ class QuantizationCompressor(Compressor):
|
|
44
44
|
type to the type specified by the layer's QuantizationArgs.
|
45
45
|
"""
|
46
46
|
|
47
|
-
COMPRESSION_PARAM_NAMES = [
|
47
|
+
COMPRESSION_PARAM_NAMES = [
|
48
|
+
"weight",
|
49
|
+
"weight_scale",
|
50
|
+
"weight_zero_point",
|
51
|
+
"weight_g_idx",
|
52
|
+
]
|
48
53
|
|
49
54
|
def compress(
|
50
55
|
self,
|
@@ -71,6 +76,7 @@ class QuantizationCompressor(Compressor):
|
|
71
76
|
prefix = name[: -(len(weight_suffix))]
|
72
77
|
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
|
73
78
|
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
|
79
|
+
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
|
74
80
|
if scale is not None and zp is not None:
|
75
81
|
# weight is quantized, compress it
|
76
82
|
quant_args = names_to_scheme[prefix]
|
@@ -82,6 +88,7 @@ class QuantizationCompressor(Compressor):
|
|
82
88
|
zero_point=zp,
|
83
89
|
args=quant_args,
|
84
90
|
dtype=quant_args.pytorch_dtype(),
|
91
|
+
g_idx=g_idx,
|
85
92
|
)
|
86
93
|
elif name.endswith("zero_point"):
|
87
94
|
if torch.all(value == 0):
|
@@ -116,12 +123,14 @@ class QuantizationCompressor(Compressor):
|
|
116
123
|
weight_data[param_name] = f.get_tensor(full_name)
|
117
124
|
|
118
125
|
if "weight_scale" in weight_data:
|
119
|
-
zero_point = weight_data.get("weight_zero_point", None)
|
120
126
|
scale = weight_data["weight_scale"]
|
127
|
+
zero_point = weight_data.get("weight_zero_point", None)
|
128
|
+
g_idx = weight_data.get("weight_g_idx", None)
|
121
129
|
decompressed = dequantize(
|
122
130
|
x_q=weight_data["weight"],
|
123
131
|
scale=scale,
|
124
132
|
zero_point=zero_point,
|
133
|
+
g_idx=g_idx,
|
125
134
|
)
|
126
135
|
yield merge_names(weight_name, "weight"), decompressed
|
127
136
|
|
@@ -44,6 +44,7 @@ class PackedQuantizationCompressor(Compressor):
|
|
44
44
|
"weight_packed",
|
45
45
|
"weight_scale",
|
46
46
|
"weight_zero_point",
|
47
|
+
"weight_g_idx",
|
47
48
|
"weight_shape",
|
48
49
|
]
|
49
50
|
|
@@ -72,6 +73,7 @@ class PackedQuantizationCompressor(Compressor):
|
|
72
73
|
prefix = name[: -(len(weight_suffix))]
|
73
74
|
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
|
74
75
|
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
|
76
|
+
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
|
75
77
|
shape = torch.tensor(value.shape)
|
76
78
|
if scale is not None and zp is not None:
|
77
79
|
# weight is quantized, compress it
|
@@ -82,6 +84,7 @@ class PackedQuantizationCompressor(Compressor):
|
|
82
84
|
x=value,
|
83
85
|
scale=scale,
|
84
86
|
zero_point=zp,
|
87
|
+
g_idx=g_idx,
|
85
88
|
args=quant_args,
|
86
89
|
dtype=torch.int8,
|
87
90
|
)
|
@@ -128,9 +131,10 @@ class PackedQuantizationCompressor(Compressor):
|
|
128
131
|
weight_data[param_name] = f.get_tensor(full_name)
|
129
132
|
|
130
133
|
if "weight_scale" in weight_data:
|
131
|
-
zero_point = weight_data.get("weight_zero_point", None)
|
132
|
-
scale = weight_data["weight_scale"]
|
133
134
|
weight = weight_data["weight_packed"]
|
135
|
+
scale = weight_data["weight_scale"]
|
136
|
+
zero_point = weight_data.get("weight_zero_point", None)
|
137
|
+
g_idx = weight_data.get("weight_g_idx", None)
|
134
138
|
num_bits = weight_data["num_bits"]
|
135
139
|
original_shape = torch.Size(weight_data["weight_shape"])
|
136
140
|
unpacked = unpack_from_int32(weight, num_bits, original_shape)
|
@@ -138,6 +142,7 @@ class PackedQuantizationCompressor(Compressor):
|
|
138
142
|
x_q=unpacked,
|
139
143
|
scale=scale,
|
140
144
|
zero_point=zero_point,
|
145
|
+
g_idx=g_idx,
|
141
146
|
)
|
142
147
|
yield merge_names(weight_name, "weight"), decompressed
|
143
148
|
|
@@ -279,9 +279,11 @@ def _load_quant_args_from_state_dict(
|
|
279
279
|
"""
|
280
280
|
scale_name = f"{base_name}_scale"
|
281
281
|
zp_name = f"{base_name}_zero_point"
|
282
|
+
g_idx_name = f"{base_name}_g_idx"
|
282
283
|
|
283
284
|
state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
|
284
285
|
state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
|
286
|
+
state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
|
285
287
|
|
286
288
|
if state_dict_scale is not None:
|
287
289
|
# module is quantized
|
@@ -291,6 +293,9 @@ def _load_quant_args_from_state_dict(
|
|
291
293
|
state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
|
292
294
|
update_parameter_data(module, state_dict_zp, zp_name)
|
293
295
|
|
296
|
+
if state_dict_g_idx is not None:
|
297
|
+
update_parameter_data(module, state_dict_g_idx, g_idx_name)
|
298
|
+
|
294
299
|
|
295
300
|
def _scheme_from_targets(
|
296
301
|
target_to_scheme: OrderedDictType[str, QuantizationScheme],
|
@@ -44,7 +44,7 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
|
|
44
44
|
return
|
45
45
|
status = getattr(module, "quantization_status", None)
|
46
46
|
if not status or status != QuantizationStatus.INITIALIZED:
|
47
|
-
|
47
|
+
_LOGGER.warning(
|
48
48
|
f"Attempting set module with status {status} to calibration mode. "
|
49
49
|
f"but status is not {QuantizationStatus.INITIALIZED} - you may "
|
50
50
|
"be calibrating an uninitialized module which may fail or attempting "
|
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
|
|
54
54
|
if quantize_weights_upfront and module.quantization_scheme.weights is not None:
|
55
55
|
# set weight scale and zero_point up front, calibration data doesn't affect it
|
56
56
|
observer = module.weight_observer
|
57
|
+
g_idx = getattr(module, "weight_g_idx", None)
|
57
58
|
|
58
59
|
offloaded = False
|
59
60
|
if is_module_offloaded(module):
|
60
61
|
module._hf_hook.pre_forward(module)
|
61
62
|
offloaded = True
|
62
63
|
|
63
|
-
scale, zero_point = observer(module.weight)
|
64
|
+
scale, zero_point = observer(module.weight, g_idx=g_idx)
|
64
65
|
update_parameter_data(module, scale, "weight_scale")
|
65
66
|
update_parameter_data(module, zero_point, "weight_zero_point")
|
66
67
|
|
@@ -25,7 +25,7 @@ from compressed_tensors.quantization.quant_args import (
|
|
25
25
|
)
|
26
26
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
27
27
|
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
28
|
-
from compressed_tensors.utils import update_parameter_data
|
28
|
+
from compressed_tensors.utils import safe_permute, update_parameter_data
|
29
29
|
from torch.nn import Module
|
30
30
|
|
31
31
|
|
@@ -45,6 +45,7 @@ def quantize(
|
|
45
45
|
zero_point: torch.Tensor,
|
46
46
|
args: QuantizationArgs,
|
47
47
|
dtype: Optional[torch.dtype] = None,
|
48
|
+
g_idx: Optional[torch.Tensor] = None,
|
48
49
|
) -> torch.Tensor:
|
49
50
|
"""
|
50
51
|
Quantize the input tensor x using the QuantizationStrategy specified in args.
|
@@ -58,6 +59,7 @@ def quantize(
|
|
58
59
|
:param zero_point: zero point tensor
|
59
60
|
:param args: quantization args dictating how to quantize x
|
60
61
|
:param dtype: optional dtype to cast the quantized output to
|
62
|
+
:param g_idx: optional mapping from column index to group index
|
61
63
|
:return: fake quantized tensor
|
62
64
|
"""
|
63
65
|
# ensure all tensors are on the same device
|
@@ -76,6 +78,7 @@ def quantize(
|
|
76
78
|
dtype=dtype,
|
77
79
|
do_quantize=True,
|
78
80
|
do_dequantize=False,
|
81
|
+
g_idx=g_idx,
|
79
82
|
)
|
80
83
|
|
81
84
|
|
@@ -86,6 +89,7 @@ def dequantize(
|
|
86
89
|
zero_point: torch.Tensor = None,
|
87
90
|
args: QuantizationArgs = None,
|
88
91
|
dtype: Optional[torch.dtype] = None,
|
92
|
+
g_idx: Optional[torch.Tensor] = None,
|
89
93
|
) -> torch.Tensor:
|
90
94
|
"""
|
91
95
|
Dequantize a quantized input tensor x_q based on the strategy specified in args. If
|
@@ -96,6 +100,7 @@ def dequantize(
|
|
96
100
|
:param zero_point: zero point tensor
|
97
101
|
:param args: quantization args used to quantize x_q
|
98
102
|
:param dtype: optional dtype to cast the dequantized output to
|
103
|
+
:param g_idx: optional mapping from column index to group index
|
99
104
|
:return: dequantized float tensor
|
100
105
|
"""
|
101
106
|
if args is None:
|
@@ -126,6 +131,7 @@ def dequantize(
|
|
126
131
|
do_quantize=False,
|
127
132
|
do_dequantize=True,
|
128
133
|
dtype=dtype,
|
134
|
+
g_idx=g_idx,
|
129
135
|
)
|
130
136
|
|
131
137
|
|
@@ -135,6 +141,7 @@ def fake_quantize(
|
|
135
141
|
scale: torch.Tensor,
|
136
142
|
zero_point: torch.Tensor,
|
137
143
|
args: QuantizationArgs,
|
144
|
+
g_idx: Optional[torch.Tensor] = None,
|
138
145
|
) -> torch.Tensor:
|
139
146
|
"""
|
140
147
|
Fake quantize the input tensor x by quantizing then dequantizing with
|
@@ -147,6 +154,7 @@ def fake_quantize(
|
|
147
154
|
:param scale: scale tensor
|
148
155
|
:param zero_point: zero point tensor
|
149
156
|
:param args: quantization args dictating how to quantize x
|
157
|
+
:param g_idx: optional mapping from column index to group index
|
150
158
|
:return: fake quantized tensor
|
151
159
|
"""
|
152
160
|
return _process_quantization(
|
@@ -156,6 +164,7 @@ def fake_quantize(
|
|
156
164
|
args=args,
|
157
165
|
do_quantize=True,
|
158
166
|
do_dequantize=True,
|
167
|
+
g_idx=g_idx,
|
159
168
|
)
|
160
169
|
|
161
170
|
|
@@ -164,21 +173,19 @@ def _process_quantization(
|
|
164
173
|
x: torch.Tensor,
|
165
174
|
scale: torch.Tensor,
|
166
175
|
zero_point: torch.Tensor,
|
176
|
+
g_idx: Optional[torch.Tensor],
|
167
177
|
args: QuantizationArgs,
|
168
178
|
dtype: Optional[torch.dtype] = None,
|
169
179
|
do_quantize: bool = True,
|
170
180
|
do_dequantize: bool = True,
|
171
181
|
) -> torch.Tensor:
|
172
|
-
|
173
182
|
q_min, q_max = calculate_range(args, x.device)
|
174
183
|
group_size = args.group_size
|
175
184
|
|
176
185
|
if args.strategy == QuantizationStrategy.GROUP:
|
177
186
|
output_dtype = dtype if dtype is not None else x.dtype
|
178
187
|
output = torch.zeros_like(x).to(output_dtype)
|
179
|
-
|
180
|
-
# TODO: vectorize the for loop
|
181
|
-
# TODO: fix genetric assumption about the tensor size for computing group
|
188
|
+
columns = output.shape[1]
|
182
189
|
|
183
190
|
# TODO: make validation step for inputs
|
184
191
|
|
@@ -187,23 +194,38 @@ def _process_quantization(
|
|
187
194
|
scale = scale.unsqueeze(1)
|
188
195
|
zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
|
189
196
|
|
190
|
-
columns = x.shape[1]
|
191
197
|
if columns >= group_size:
|
192
198
|
if columns % group_size != 0:
|
193
199
|
raise ValueError(
|
194
|
-
"
|
200
|
+
"tensor column shape must be divisble "
|
195
201
|
f"by the given group_size {group_size}"
|
196
202
|
)
|
197
|
-
for i in range(ceil(columns / group_size)):
|
198
|
-
# scale.shape should be [nchan, ndim]
|
199
|
-
# sc.shape should be [nchan, 1] after unsqueeze
|
200
|
-
sc = scale[:, i].view(-1, 1)
|
201
|
-
zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
|
202
203
|
|
203
|
-
|
204
|
+
# support column-order (default) quantization as well as other orderings
|
205
|
+
# such as activation ordering. Below checks if g_idx has been initialized
|
206
|
+
is_column_order = g_idx is None or -1 in g_idx
|
207
|
+
if is_column_order:
|
208
|
+
num_groups = int(ceil(columns / group_size))
|
209
|
+
group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
|
210
|
+
|
211
|
+
else:
|
212
|
+
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
|
213
|
+
group_sizes = group_sizes[torch.argsort(group_indices)]
|
214
|
+
|
215
|
+
perm = torch.argsort(g_idx)
|
216
|
+
x = safe_permute(x, perm, dim=1)
|
217
|
+
|
218
|
+
# TODO: experiment with vectorizing for loop for performance
|
219
|
+
end = 0
|
220
|
+
for index, group_count in enumerate(group_sizes):
|
221
|
+
sc = scale[:, index].view(-1, 1)
|
222
|
+
zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
|
223
|
+
|
224
|
+
start = end
|
225
|
+
end = start + group_count
|
204
226
|
if do_quantize:
|
205
|
-
output[:,
|
206
|
-
x[:,
|
227
|
+
output[:, start:end] = _quantize(
|
228
|
+
x[:, start:end],
|
207
229
|
sc,
|
208
230
|
zp,
|
209
231
|
q_min,
|
@@ -211,13 +233,13 @@ def _process_quantization(
|
|
211
233
|
args,
|
212
234
|
dtype=dtype,
|
213
235
|
)
|
236
|
+
|
214
237
|
if do_dequantize:
|
215
|
-
input =
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
|
238
|
+
input = output[:, start:end] if do_quantize else x[:, start:end]
|
239
|
+
output[:, start:end] = _dequantize(input, sc, zp)
|
240
|
+
|
241
|
+
if not is_column_order:
|
242
|
+
output = safe_permute(output, torch.argsort(perm), dim=1)
|
221
243
|
|
222
244
|
else: # covers channel, token and tensor strategies
|
223
245
|
if do_quantize:
|
@@ -304,10 +326,12 @@ def maybe_calibrate_or_quantize(
|
|
304
326
|
# skip quantization
|
305
327
|
return value
|
306
328
|
|
329
|
+
g_idx = getattr(module, "weight_g_idx", None)
|
330
|
+
|
307
331
|
if args.dynamic:
|
308
332
|
# dynamic quantization - get scale and zero point directly from observer
|
309
333
|
observer = getattr(module, f"{base_name}_observer")
|
310
|
-
scale, zero_point = observer(value)
|
334
|
+
scale, zero_point = observer(value, g_idx=g_idx)
|
311
335
|
else:
|
312
336
|
# static quantization - get previous scale and zero point from layer
|
313
337
|
scale = getattr(module, f"{base_name}_scale")
|
@@ -320,13 +344,13 @@ def maybe_calibrate_or_quantize(
|
|
320
344
|
# calibration mode - get new quant params from observer
|
321
345
|
observer = getattr(module, f"{base_name}_observer")
|
322
346
|
|
323
|
-
updated_scale, updated_zero_point = observer(value)
|
347
|
+
updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
|
324
348
|
|
325
349
|
# update scale and zero point
|
326
350
|
update_parameter_data(module, updated_scale, f"{base_name}_scale")
|
327
351
|
update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
|
328
352
|
|
329
|
-
return fake_quantize(value, scale, zero_point, args)
|
353
|
+
return fake_quantize(value, scale, zero_point, args, g_idx=g_idx)
|
330
354
|
|
331
355
|
|
332
356
|
@torch.no_grad()
|
@@ -16,7 +16,9 @@
|
|
16
16
|
Miscelaneous helpers for the quantization lifecycle
|
17
17
|
"""
|
18
18
|
|
19
|
+
from typing import Optional
|
19
20
|
|
21
|
+
import torch
|
20
22
|
from torch.nn import Module
|
21
23
|
|
22
24
|
|
@@ -27,16 +29,41 @@ __all__ = [
|
|
27
29
|
]
|
28
30
|
|
29
31
|
|
30
|
-
def update_layer_weight_quant_params(
|
31
|
-
|
32
|
+
def update_layer_weight_quant_params(
|
33
|
+
layer: Module,
|
34
|
+
weight: Optional[torch.Tensor] = None,
|
35
|
+
g_idx: Optional[torch.Tensor] = None,
|
36
|
+
reset_obs: bool = False,
|
37
|
+
):
|
38
|
+
"""
|
39
|
+
Update quantization parameters on layer
|
40
|
+
|
41
|
+
:param layer: input layer
|
42
|
+
:param weight: weight to update quant params with, defaults to layer weight
|
43
|
+
:param g_idx: optional mapping from column index to group index
|
44
|
+
:param reset_obs: reset the observer before calculating quant params,
|
45
|
+
defaults to False
|
46
|
+
"""
|
47
|
+
attached_weight = getattr(layer, "weight", None)
|
48
|
+
|
49
|
+
if weight is None:
|
50
|
+
weight = attached_weight
|
32
51
|
scale = getattr(layer, "weight_scale", None)
|
33
52
|
zero_point = getattr(layer, "weight_zero_point", None)
|
53
|
+
if g_idx is None:
|
54
|
+
g_idx = getattr(layer, "weight_g_idx", None)
|
34
55
|
observer = getattr(layer, "weight_observer", None)
|
35
56
|
|
36
57
|
if weight is None or observer is None or scale is None or zero_point is None:
|
37
58
|
# scale, zp, or observer not calibratable or weight not available
|
38
59
|
return
|
39
60
|
|
61
|
+
if reset_obs:
|
62
|
+
observer.reset()
|
63
|
+
|
64
|
+
if attached_weight is not None:
|
65
|
+
weight = weight.to(attached_weight.dtype)
|
66
|
+
|
40
67
|
updated_scale, updated_zero_point = observer(weight)
|
41
68
|
|
42
69
|
# update scale and zero point
|
@@ -17,8 +17,6 @@ import logging
|
|
17
17
|
from typing import Optional
|
18
18
|
|
19
19
|
import torch
|
20
|
-
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
21
|
-
from accelerate.utils import PrefixedDataset
|
22
20
|
from compressed_tensors.quantization.lifecycle.forward import (
|
23
21
|
wrap_module_forward_quantized,
|
24
22
|
)
|
@@ -86,6 +84,16 @@ def initialize_module_for_quantization(
|
|
86
84
|
|
87
85
|
offloaded = False
|
88
86
|
if is_module_offloaded(module):
|
87
|
+
try:
|
88
|
+
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
89
|
+
from accelerate.utils import PrefixedDataset
|
90
|
+
except ModuleNotFoundError:
|
91
|
+
raise ModuleNotFoundError(
|
92
|
+
"Offloaded model detected. To use CPU offloading with "
|
93
|
+
"compressed-tensors the `accelerate` package must be installed, "
|
94
|
+
"run `pip install compressed-tensors[accelerate]`"
|
95
|
+
)
|
96
|
+
|
89
97
|
offloaded = True
|
90
98
|
hook = module._hf_hook
|
91
99
|
prefix_dict = module._hf_hook.weights_map
|
@@ -141,16 +149,27 @@ def _initialize_scale_zero_point_observer(
|
|
141
149
|
weight_shape[1] // quantization_args.group_size,
|
142
150
|
)
|
143
151
|
|
144
|
-
# initializes empty scale
|
152
|
+
# initializes empty scale, zero point, and g_idx parameters for the module
|
145
153
|
init_scale = Parameter(
|
146
154
|
torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
|
147
155
|
requires_grad=False,
|
148
156
|
)
|
149
157
|
module.register_parameter(f"{base_name}_scale", init_scale)
|
150
158
|
|
159
|
+
# TODO: @kylesayrs do not initialize if symmetric
|
151
160
|
zp_dtype = quantization_args.pytorch_dtype()
|
152
161
|
init_zero_point = Parameter(
|
153
162
|
torch.empty(expected_shape, device=device, dtype=zp_dtype),
|
154
163
|
requires_grad=False,
|
155
164
|
)
|
156
165
|
module.register_parameter(f"{base_name}_zero_point", init_zero_point)
|
166
|
+
|
167
|
+
# initialize with empty for actorder, to be populated by GPTQ or state_dict
|
168
|
+
if quantization_args.actorder:
|
169
|
+
g_idx_shape = (weight_shape[1],)
|
170
|
+
g_idx_dtype = torch.int
|
171
|
+
init_g_idx = Parameter(
|
172
|
+
torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
|
173
|
+
requires_grad=False,
|
174
|
+
)
|
175
|
+
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import logging
|
16
|
+
from math import ceil
|
16
17
|
from typing import Any, Iterable, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import torch
|
@@ -21,6 +22,7 @@ from compressed_tensors.quantization.quant_args import (
|
|
21
22
|
QuantizationStrategy,
|
22
23
|
)
|
23
24
|
from compressed_tensors.registry.registry import RegistryMixin
|
25
|
+
from compressed_tensors.utils import safe_permute
|
24
26
|
from torch import FloatTensor, IntTensor, Tensor
|
25
27
|
from torch.nn import Module
|
26
28
|
|
@@ -46,15 +48,18 @@ class Observer(Module, RegistryMixin):
|
|
46
48
|
self._num_observed_tokens = None
|
47
49
|
|
48
50
|
@torch.no_grad()
|
49
|
-
def forward(
|
51
|
+
def forward(
|
52
|
+
self, observed: Tensor, g_idx: Optional[Tensor] = None
|
53
|
+
) -> Tuple[FloatTensor, IntTensor]:
|
50
54
|
"""
|
51
55
|
maps directly to get_qparams
|
52
|
-
:param observed: optional observed tensor to calculate
|
53
|
-
|
56
|
+
:param observed: optional observed tensor from which to calculate
|
57
|
+
quantization parameters
|
58
|
+
:param g_idx: optional mapping from column index to group index
|
54
59
|
:return: tuple of scale and zero point based on last observed value
|
55
60
|
"""
|
56
61
|
self.record_observed_tokens(observed)
|
57
|
-
return self.get_qparams(observed=observed)
|
62
|
+
return self.get_qparams(observed=observed, g_idx=g_idx)
|
58
63
|
|
59
64
|
def calculate_qparams(
|
60
65
|
self,
|
@@ -77,7 +82,9 @@ class Observer(Module, RegistryMixin):
|
|
77
82
|
...
|
78
83
|
|
79
84
|
def get_qparams(
|
80
|
-
self,
|
85
|
+
self,
|
86
|
+
observed: Optional[Tensor] = None,
|
87
|
+
g_idx: Optional[Tensor] = None,
|
81
88
|
) -> Tuple[FloatTensor, IntTensor]:
|
82
89
|
"""
|
83
90
|
Convenience function to wrap overwritten calculate_qparams
|
@@ -86,6 +93,7 @@ class Observer(Module, RegistryMixin):
|
|
86
93
|
|
87
94
|
:param observed: optional observed tensor to calculate quantization parameters
|
88
95
|
from
|
96
|
+
:param g_idx: optional mapping from column index to group index
|
89
97
|
:return: tuple of scale and zero point based on last observed value
|
90
98
|
"""
|
91
99
|
if observed is not None:
|
@@ -97,20 +105,42 @@ class Observer(Module, RegistryMixin):
|
|
97
105
|
self._scale, self._zero_point = self.calculate_qparams(observed)
|
98
106
|
|
99
107
|
elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
|
108
|
+
rows = observed.shape[0]
|
100
109
|
columns = observed.shape[1]
|
101
|
-
|
102
|
-
|
103
|
-
|
110
|
+
num_groups = int(ceil(columns / group_size))
|
111
|
+
self._scale = torch.empty(
|
112
|
+
(rows, num_groups), dtype=observed.dtype, device=observed.device
|
113
|
+
)
|
114
|
+
zp_dtype = self.quantization_args.pytorch_dtype()
|
115
|
+
self._zero_point = torch.empty(
|
116
|
+
(rows, num_groups), dtype=zp_dtype, device=observed.device
|
117
|
+
)
|
118
|
+
|
119
|
+
# support column-order (default) quantization as well as other orderings
|
120
|
+
# such as activation ordering. Below checks if g_idx has initialized
|
121
|
+
is_column_order = g_idx is None or -1 in g_idx
|
122
|
+
if is_column_order:
|
123
|
+
group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
|
124
|
+
else:
|
125
|
+
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
|
126
|
+
group_sizes = group_sizes[torch.argsort(group_indices)]
|
127
|
+
|
128
|
+
perm = torch.argsort(g_idx)
|
129
|
+
observed = safe_permute(observed, perm, dim=1)
|
130
|
+
|
131
|
+
# TODO: experiment with vectorizing for loop for performance
|
132
|
+
end = 0
|
133
|
+
for group_index, group_count in enumerate(group_sizes):
|
134
|
+
start = end
|
135
|
+
end = start + group_count
|
104
136
|
scale, zero_point = self.get_qparams_along_dim(
|
105
|
-
observed[:,
|
137
|
+
observed[:, start:end],
|
106
138
|
0,
|
107
|
-
tensor_id=
|
139
|
+
tensor_id=group_index,
|
108
140
|
)
|
109
|
-
scales.append(scale)
|
110
|
-
zero_points.append(zero_point)
|
111
141
|
|
112
|
-
|
113
|
-
|
142
|
+
self._scale[:, group_index] = scale.squeeze(1)
|
143
|
+
self._zero_point[:, group_index] = zero_point.squeeze(1)
|
114
144
|
|
115
145
|
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
|
116
146
|
# assume observed is transposed, because its the output, hence use dim 0
|
@@ -132,6 +162,8 @@ class Observer(Module, RegistryMixin):
|
|
132
162
|
dim: Union[int, Iterable[int]],
|
133
163
|
tensor_id: Optional[Any] = None,
|
134
164
|
):
|
165
|
+
if isinstance(dim, int):
|
166
|
+
dim = [dim]
|
135
167
|
dim = set(dim)
|
136
168
|
|
137
169
|
reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
|
@@ -171,3 +203,11 @@ class Observer(Module, RegistryMixin):
|
|
171
203
|
# observed_tokens (batch_size * sequence_length)
|
172
204
|
observed_tokens, _ = batch_tensor.shape
|
173
205
|
self._num_observed_tokens += observed_tokens
|
206
|
+
|
207
|
+
def reset(self):
|
208
|
+
"""
|
209
|
+
Reset the state of the observer
|
210
|
+
"""
|
211
|
+
self._num_observed_tokens = None
|
212
|
+
self._scale = None
|
213
|
+
self._zero_point = None
|
@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
|
|
94
94
|
return self.calculate_qparams(
|
95
95
|
observed, reduce_dims=reduce_dims, tensor_id=tensor_id
|
96
96
|
)
|
97
|
+
|
98
|
+
def reset(self):
|
99
|
+
"""
|
100
|
+
Reset the state of the observer, including min and maximum values
|
101
|
+
"""
|
102
|
+
super().reset()
|
103
|
+
self.min_val = {}
|
104
|
+
self.max_val = {}
|
@@ -0,0 +1,162 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Any, Optional, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from compressed_tensors.quantization.observers.base import Observer
|
19
|
+
from compressed_tensors.quantization.observers.helpers import calculate_qparams
|
20
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
21
|
+
from torch import FloatTensor, IntTensor, Tensor
|
22
|
+
|
23
|
+
|
24
|
+
__all__ = ["MovingAverageMSEObserver"]
|
25
|
+
|
26
|
+
|
27
|
+
@Observer.register("mse")
class MovingAverageMSEObserver(Observer):
    """
    Implements a dynamic quantization observer that sets the scale and
    zero point based on a moving average of the mse-clipped min and max observed values

    The clipped range is found by a grid search. The observed min/max are scaled
    by ``1 - i / grid`` for ``i`` in ``range(int(maxshrink * grid))``. The
    candidate whose fake-quantized reconstruction minimises
    ``sum(|q - x| ** norm)`` wins.

    :param quantization_args: quantization configuration being observed
    :param averaging_constant: weight given to each new observation when
        updating the running clipped min/max values
    :param grid: number of shrink steps per unit of range
    :param maxshrink: upper bound on the fraction by which the observed range
        is shrunk during the search
    :param norm: exponent applied to the absolute quantization error
    """

    def __init__(
        self,
        quantization_args: QuantizationArgs,
        averaging_constant: float = 0.01,
        grid: float = 100.0,
        maxshrink: float = 0.80,
        norm: float = 2.4,
    ):
        super().__init__(quantization_args=quantization_args)

        # running clipped ranges, keyed by tensor id ("default" when none given)
        self.min_val = {}
        self.max_val = {}
        self.averaging_constant = averaging_constant
        self.grid = grid
        self.maxshrink = maxshrink
        self.norm = norm

    def calculate_mse_min_max(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Computes the mse-clipped min and max values of the observed tensor by
        optimizing for quantization error

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned values will be shaped (1,) along the reduced dimensions
        :return: tuple of min and max values derived from the observed tensor,
            on the same device and dtype as the observed extrema
        """
        # imported inside the method rather than at module level;
        # presumably to avoid an import cycle with lifecycle - TODO confirm
        from compressed_tensors.quantization.lifecycle import fake_quantize

        if not reduce_dims:
            absolute_min_val, absolute_max_val = torch.aminmax(observed)
        else:
            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)

        # Keep the search state on observed's device/dtype so the comparisons
        # and selections below never mix CPU and accelerator tensors.
        # Seeding with the unclipped range means a degenerate search (e.g. no
        # finite error) returns plain min/max rather than an inverted range.
        best = torch.full_like(absolute_min_val, float("inf"))
        min_val = absolute_min_val.clone()
        max_val = absolute_max_val.clone()

        for i in range(int(self.maxshrink * self.grid)):
            p = 1 - i / self.grid
            shrinked_min_val = p * absolute_min_val
            shrinked_max_val = p * absolute_max_val

            candidate_scales, candidate_zero_points = calculate_qparams(
                shrinked_min_val, shrinked_max_val, self.quantization_args
            )
            q = fake_quantize(
                observed,
                candidate_scales,
                candidate_zero_points,
                self.quantization_args,
            )

            # error of this candidate: sum(|fake_quant(x) - x| ** norm),
            # reduced over the same dims as the min/max
            q -= observed
            q.abs_()
            q.pow_(self.norm)
            if not reduce_dims:
                err = torch.sum(q)
            else:
                err = torch.sum(q, reduce_dims, keepdim=True)

            # element-wise selection keeps the update on-device (no host sync)
            improved = err < best
            best = torch.where(improved, err, best)
            min_val = torch.where(improved, shrinked_min_val, min_val)
            max_val = torch.where(improved, shrinked_max_val, max_val)

        return min_val, max_val

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Updates the mse-clipped min and max values of the observed tensor using
        a moving average smoothed by the averaging_constant

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :param tensor_id: Optional id if different ranges of observed tensors are
            passed, useful for sharding tensors by group_size
        :return: tuple of scale and zero point derived from the observed tensor
        """
        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)

        # Resolve the key once so the lookup and the store hit the same entry.
        # An `is None` test keeps falsy ids (e.g. 0) distinct from "default".
        key = "default" if tensor_id is None else tensor_id
        running_min_val = self.min_val.get(key)
        running_max_val = self.max_val.get(key)

        if running_min_val is None or running_max_val is None:
            # first observation for this key: nothing to average with yet
            updated_min_val = min_val
            updated_max_val = max_val
        else:
            updated_min_val = running_min_val + self.averaging_constant * (
                min_val - running_min_val
            )
            updated_max_val = running_max_val + self.averaging_constant * (
                max_val - running_max_val
            )

        self.min_val[key] = updated_min_val
        self.max_val[key] = updated_max_val

        return calculate_qparams(
            updated_min_val, updated_max_val, self.quantization_args
        )

    def get_qparams_along_dim(
        self, observed, dim: int, tensor_id: Optional[Any] = None
    ):
        """
        Compute qparams with one entry per index of ``dim`` by reducing over
        every other dimension of ``observed``.

        :param observed: observed tensor to calculate quantization parameters for
        :param dim: dimension to keep
        :param tensor_id: optional id forwarded to calculate_qparams
        :return: tuple of scale and zero point
        """
        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )

    def reset(self):
        """
        Reset the state of the observer, including min and maximum values
        """
        super().reset()
        self.min_val = {}
        self.max_val = {}
|
@@ -16,7 +16,7 @@ from enum import Enum
|
|
16
16
|
from typing import Any, Dict, Optional
|
17
17
|
|
18
18
|
import torch
|
19
|
-
from pydantic import BaseModel, Field,
|
19
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
20
20
|
|
21
21
|
|
22
22
|
__all__ = [
|
@@ -68,6 +68,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
68
68
|
ranges will be observed with every sample. Defaults to False for static
|
69
69
|
quantization. Note that enabling dynamic quantization will change the default
|
70
70
|
observer to a memoryless one
|
71
|
+
:param actorder: whether to apply group quantization in decreasing order of
|
72
|
+
activation. Defaults to False for arbitrary ordering
|
71
73
|
"""
|
72
74
|
|
73
75
|
num_bits: int = 8
|
@@ -77,6 +79,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
77
79
|
strategy: Optional[QuantizationStrategy] = None
|
78
80
|
block_structure: Optional[str] = None
|
79
81
|
dynamic: bool = False
|
82
|
+
actorder: bool = False
|
80
83
|
observer: str = Field(
|
81
84
|
default="minmax",
|
82
85
|
description=(
|
@@ -98,40 +101,65 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
98
101
|
"""
|
99
102
|
from compressed_tensors.quantization.observers.base import Observer
|
100
103
|
|
101
|
-
if self.
|
104
|
+
if self.dynamic:
|
102
105
|
# override defualt observer for dynamic, you never want minmax which
|
103
106
|
# keeps state across samples for dynamic
|
104
107
|
self.observer = "memoryless"
|
105
108
|
|
106
109
|
return Observer.load_from_registry(self.observer, quantization_args=self)
|
107
110
|
|
108
|
-
@
|
109
|
-
def
|
110
|
-
|
111
|
+
@field_validator("group_size", mode="before")
def validate_group(cls, value) -> Optional[int]:
    """
    Reject group sizes below -1 before field coercion.

    ``None`` and any value >= -1 are passed through unchanged; consistency
    with the chosen strategy is not checked here.

    :param value: raw ``group_size`` input
    :return: the unchanged value (possibly ``None``)
    :raises ValueError: if ``value < -1``
    """
    if value is None:
        return value

    if value < -1:
        raise ValueError(
            f"Invalid group size {value}. Use group_size > 0 for "
            "strategy='group' and group_size = -1 for 'channel'"
        )

    return value
|
123
|
+
|
124
|
+
@model_validator(mode="before")
def validate_strategy(values) -> Dict[str, Any]:
    """
    Resolve and validate ``strategy`` from the raw model input.

    When no strategy is given, it is inferred from ``group_size``:
    ``None`` -> TENSOR, positive -> GROUP, ``-1`` -> CHANNEL.
    GROUP requires a positive ``group_size``, and ``actorder`` requires GROUP.

    :param values: raw input mapping for the model
    :return: a new mapping equal to ``values`` with ``strategy`` set to the
        resolved QuantizationStrategy (the input mapping is not modified)
    :raises ValueError: on an invalid group_size / strategy / actorder combination
    """
    model_fields = QuantizationArgs.model_fields
    strategy = values.get("strategy", model_fields["strategy"].default)
    group_size = values.get("group_size", model_fields["group_size"].default)
    actorder = values.get("actorder", model_fields["actorder"].default)

    if strategy is not None:
        strategy = QuantizationStrategy(strategy.lower())
    else:
        # use group_size to determine strategy if not given explicitly
        if group_size is None:
            strategy = QuantizationStrategy.TENSOR
        elif group_size > 0:
            strategy = QuantizationStrategy.GROUP
        elif group_size == -1:
            strategy = QuantizationStrategy.CHANNEL
        else:
            raise ValueError(
                f"Invalid group size {group_size}. Use group_size > 0 for "
                "strategy='group' and group_size = -1 for 'channel'"
            )

    if strategy == QuantizationStrategy.GROUP:
        if group_size is None or group_size <= 0:
            raise ValueError(
                f"strategy {strategy} requires group_size to be "
                "set to a positive value"
            )

    if actorder and strategy != QuantizationStrategy.GROUP:
        raise ValueError(
            "Group quantization must be specified in order to apply "
            "activation ordering"
        )

    # build a new mapping rather than writing into the caller's dict
    return {**values, "strategy": strategy}
|
135
163
|
|
136
164
|
def pytorch_dtype(self) -> torch.dtype:
|
137
165
|
if self.type == QuantizationType.FLOAT:
|
@@ -89,7 +89,7 @@ def update_parameter_data(
|
|
89
89
|
|
90
90
|
:param module: layer containing the parameter to update
|
91
91
|
:param new_param_data: tensor to update parameter with
|
92
|
-
:param param_name:
|
92
|
+
:param param_name: name of layer parameter to update
|
93
93
|
"""
|
94
94
|
device = next(module.parameters()).device
|
95
95
|
|
@@ -99,6 +99,9 @@ def update_parameter_data(
|
|
99
99
|
offloaded = True
|
100
100
|
|
101
101
|
parameter = getattr(module, param_name, None)
|
102
|
+
if parameter is None:
|
103
|
+
raise ValueError("Attempted to update uninitialized parameter")
|
104
|
+
|
102
105
|
dtype = parameter.dtype
|
103
106
|
parameter.data = new_param_data.to(device).to(dtype)
|
104
107
|
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Set, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
__all__ = ["safe_permute"]
|
21
|
+
|
22
|
+
|
23
|
+
# these datatypes are missing implementations required for standard permutation;
# (dtype, device) pairs are added here the first time advanced indexing raises a
# RuntimeError, after which safe_permute goes straight to the fallback
_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()


def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Perform out-of-place permutation without using torch.Tensor.index_put_,
    whose implementation is missing for datatypes such as `torch.float8_e4m3fn`

    Advanced indexing is tried first. If it raises RuntimeError, the
    (dtype, device) pair is remembered and a per-index copy is used instead.

    :param value: tensor to permute
    :param perm: permutation map
    :param dim: dimension along which to apply permutation; negative values
        count from the last dimension
    :return: permuted value
    """
    # The slice prefix below is `[slice(None)] * dim`. A negative dim would
    # produce an empty prefix and silently permute dimension 0 instead.
    if dim < 0:
        dim += value.ndim

    dtype_tuple = (value.dtype, value.device)

    if dtype_tuple in _EXPERIMENTAL_DTYPES:
        return _fallback_permute(value, perm, dim)

    try:
        return value[tuple([slice(None)] * dim + [perm])]
    except RuntimeError:
        # Mark dtype as experimental if advanced indexing fails
        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
        return _fallback_permute(value, perm, dim)


def _fallback_permute(
    value: torch.Tensor, perm: torch.Tensor, dim: int
) -> torch.Tensor:
    """
    Fallback permutation method for experimental dtypes.

    Copies one slice at a time: ``out[..., i, ...] = value[..., perm[i], ...]``.

    :param value: tensor to permute
    :param perm: permutation map
    :param dim: dimension along which to apply permutation; negative values
        count from the last dimension
    :return: permuted value
    """
    if dim < 0:
        dim += value.ndim

    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
    orig_slices = [slice(None)] * (dim + 1)
    perm_slices = [slice(None)] * (dim + 1)

    # Plain Python ints keep both sides of each copy on basic integer
    # indexing, instead of indexing with 0-dim tensors.
    indices = perm.tolist() if isinstance(perm, torch.Tensor) else perm
    for index, perm_index in enumerate(indices):
        orig_slices[dim] = index
        perm_slices[dim] = perm_index
        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]

    return value_ret
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: compressed-tensors-nightly
|
3
|
-
Version: 0.5.0.
|
3
|
+
Version: 0.5.0.20240829
|
4
4
|
Summary: Library for utilization of compressed safetensors of neural network models
|
5
5
|
Home-page: https://github.com/neuralmagic/compressed-tensors
|
6
6
|
Author: Neuralmagic, Inc.
|
@@ -10,8 +10,9 @@ Description-Content-Type: text/markdown
|
|
10
10
|
License-File: LICENSE
|
11
11
|
Requires-Dist: torch>=1.7.0
|
12
12
|
Requires-Dist: transformers
|
13
|
-
Requires-Dist: accelerate
|
14
13
|
Requires-Dist: pydantic>=2.0
|
14
|
+
Provides-Extra: accelerate
|
15
|
+
Requires-Dist: accelerate; extra == "accelerate"
|
15
16
|
Provides-Extra: dev
|
16
17
|
Requires-Dist: black==22.12.0; extra == "dev"
|
17
18
|
Requires-Dist: isort==5.8.0; extra == "dev"
|
@@ -6,43 +6,45 @@ compressed_tensors/compressors/base.py,sha256=-rqT2h9G2iwDkwrVj0d0jxxn9h0dccJA1m
|
|
6
6
|
compressed_tensors/compressors/dense.py,sha256=xcWECjcRY4INN6jC7vHx5wvUX3NmnKlxA9SVE1A6m2Q,1267
|
7
7
|
compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
|
8
8
|
compressed_tensors/compressors/marlin_24.py,sha256=e7fGUyZbjUpA5VUMCPxqcYPGNiwoDKupHJaXWCoVKRw,9410
|
9
|
-
compressed_tensors/compressors/model_compressor.py,sha256=
|
10
|
-
compressed_tensors/compressors/naive_quantized.py,sha256=
|
11
|
-
compressed_tensors/compressors/pack_quantized.py,sha256=
|
9
|
+
compressed_tensors/compressors/model_compressor.py,sha256=IX6IkJi5Nga1k7atGs_9L9YRHD64A9MITo0x9GqcVLI,15927
|
10
|
+
compressed_tensors/compressors/naive_quantized.py,sha256=rb7-e_ygcdwxtyWf3kpME9rVCPAawbwoAkiOWnF0pgk,5843
|
11
|
+
compressed_tensors/compressors/pack_quantized.py,sha256=DLyysMpewwxTYnxHCzWt6gg3uK9BFc9iUf3LRQUlgwY,8749
|
12
12
|
compressed_tensors/compressors/sparse_bitmask.py,sha256=kiDwBlFV0sJGLcIdDYxIiuF64ccgwDfqq1hWRQThYDc,8647
|
13
13
|
compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
|
14
14
|
compressed_tensors/config/base.py,sha256=caSZ7xZ_kgcHRMXZ5hM1i6TKbgY__CkiSjZ93imHZQ0,1562
|
15
15
|
compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
|
16
16
|
compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
|
17
17
|
compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
|
18
|
-
compressed_tensors/quantization/quant_args.py,sha256=
|
18
|
+
compressed_tensors/quantization/quant_args.py,sha256=wSC2ve1P-XRwZUpqEaqvQpj1Xe0EGgmmPEjPk9YEnyg,6797
|
19
19
|
compressed_tensors/quantization/quant_config.py,sha256=NpVu8YJ4Xw2pIQW_PGaNaml8kx1bUnxkvb0jBYWbKdE,9971
|
20
20
|
compressed_tensors/quantization/quant_scheme.py,sha256=_RKOFJI0T5xJVBLX63UeYkSY4EFAecsBnqzUIVBjeU0,6014
|
21
21
|
compressed_tensors/quantization/lifecycle/__init__.py,sha256=MXE2E7GfIfRRfhrdGy2Og3AZOz5N59B0ZGFcsD89y6c,821
|
22
|
-
compressed_tensors/quantization/lifecycle/apply.py,sha256=
|
23
|
-
compressed_tensors/quantization/lifecycle/calibration.py,sha256=
|
22
|
+
compressed_tensors/quantization/lifecycle/apply.py,sha256=DdX-ilWn1cNidPTolyrb3OVnZ2fh_dU89sQFtvdoW8E,14119
|
23
|
+
compressed_tensors/quantization/lifecycle/calibration.py,sha256=PlS_EqCOPqJD3QKuLPXO9AOtDzXtQWvEBTynFv-FFVw,2698
|
24
24
|
compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
|
25
|
-
compressed_tensors/quantization/lifecycle/forward.py,sha256=
|
25
|
+
compressed_tensors/quantization/lifecycle/forward.py,sha256=evjXwqSVvIGCW-HPBjRoAsXPLNDj1P2GLBa4oqpUjV0,13414
|
26
26
|
compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
|
27
|
-
compressed_tensors/quantization/lifecycle/helpers.py,sha256=
|
28
|
-
compressed_tensors/quantization/lifecycle/initialize.py,sha256=
|
29
|
-
compressed_tensors/quantization/observers/__init__.py,sha256=
|
30
|
-
compressed_tensors/quantization/observers/base.py,sha256=
|
27
|
+
compressed_tensors/quantization/lifecycle/helpers.py,sha256=TmLY_G5VP_Fg2Ywio_dxoHRTxOKZdT7_aG5S9WtD4zI,2424
|
28
|
+
compressed_tensors/quantization/lifecycle/initialize.py,sha256=zwtDbfFwnTYlUHq8FRQl4obDQw1RVmXUGxDdFOYqCsw,6618
|
29
|
+
compressed_tensors/quantization/observers/__init__.py,sha256=4Sa7rqi5RB_S5bPO8KmncETiqDsoMBhwP37arlQym8s,764
|
30
|
+
compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
|
31
31
|
compressed_tensors/quantization/observers/helpers.py,sha256=s_A23Qa_BLfOdHJCN5bm-qPWkhjjj_RIVrhSp1Y9Dtk,4211
|
32
32
|
compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ7tbnP-J_86QTrEfjBn6Kh1C-H8,2165
|
33
|
-
compressed_tensors/quantization/observers/min_max.py,sha256=
|
33
|
+
compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
|
34
|
+
compressed_tensors/quantization/observers/mse.py,sha256=Aeh-253Vbab1F8cYuBiGNn4OXWJ67wXQ_JVfl3mu2a8,6034
|
34
35
|
compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
|
35
36
|
compressed_tensors/quantization/utils/helpers.py,sha256=YjXABJQUnelof-z7qcwck6fnrFLh4uMSrOmPiqNp_RY,8591
|
36
37
|
compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
|
37
38
|
compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
|
38
|
-
compressed_tensors/utils/__init__.py,sha256=
|
39
|
+
compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
|
39
40
|
compressed_tensors/utils/helpers.py,sha256=d3yP9ViQ8R3GzMHfohxNlaokzyrRuj2PyjxWAJZmSws,3156
|
40
|
-
compressed_tensors/utils/offload.py,sha256=
|
41
|
+
compressed_tensors/utils/offload.py,sha256=GV9Cg__gb4jiyyqh6tS74c0Z63oZMTJKSuddRRyywpI,4033
|
41
42
|
compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
|
42
|
-
compressed_tensors/utils/
|
43
|
-
compressed_tensors/utils/
|
44
|
-
|
45
|
-
compressed_tensors_nightly-0.5.0.
|
46
|
-
compressed_tensors_nightly-0.5.0.
|
47
|
-
compressed_tensors_nightly-0.5.0.
|
48
|
-
compressed_tensors_nightly-0.5.0.
|
43
|
+
compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
|
44
|
+
compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
|
45
|
+
compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
|
46
|
+
compressed_tensors_nightly-0.5.0.20240829.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
47
|
+
compressed_tensors_nightly-0.5.0.20240829.dist-info/METADATA,sha256=lcK8iKQNb16vov8zvbn5JzLMqENOjY-QgV3jGvdfvmQ,6799
|
48
|
+
compressed_tensors_nightly-0.5.0.20240829.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
49
|
+
compressed_tensors_nightly-0.5.0.20240829.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
|
50
|
+
compressed_tensors_nightly-0.5.0.20240829.dist-info/RECORD,,
|
File without changes
|
File without changes
|