compressed-tensors-nightly 0.4.0.20240619__py3-none-any.whl → 0.4.0.20240621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/model_compressor.py +20 -0
- compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +29 -11
- compressed_tensors/compressors/pack_quantized.py +0 -4
- compressed_tensors/config/base.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +4 -8
- compressed_tensors/quantization/lifecycle/forward.py +52 -29
- compressed_tensors/quantization/lifecycle/initialize.py +2 -1
- compressed_tensors/quantization/observers/helpers.py +44 -9
- compressed_tensors/quantization/quant_args.py +45 -1
- compressed_tensors/quantization/quant_scheme.py +14 -8
- {compressed_tensors_nightly-0.4.0.20240619.dist-info → compressed_tensors_nightly-0.4.0.20240621.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.4.0.20240619.dist-info → compressed_tensors_nightly-0.4.0.20240621.dist-info}/RECORD +16 -16
- {compressed_tensors_nightly-0.4.0.20240619.dist-info → compressed_tensors_nightly-0.4.0.20240621.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.4.0.20240619.dist-info → compressed_tensors_nightly-0.4.0.20240621.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.4.0.20240619.dist-info → compressed_tensors_nightly-0.4.0.20240621.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/__init__.py
@@ -17,8 +17,12 @@
 from .base import Compressor
 from .dense import DenseCompressor
 from .helpers import load_compressed, save_compressed, save_compressed_model
-from .int_quantized import IntQuantizationCompressor
 from .marlin_24 import Marlin24Compressor
 from .model_compressor import ModelCompressor, map_modules_to_quant_args
+from .naive_quantized import (
+    FloatQuantizationCompressor,
+    IntQuantizationCompressor,
+    QuantizationCompressor,
+)
 from .pack_quantized import PackedQuantizationCompressor
 from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed_tensors/compressors/model_compressor.py
@@ -16,9 +16,12 @@ import json
 import logging
 import operator
 import os
+import re
 from copy import deepcopy
 from typing import Any, Dict, Optional, Union

+import torch
+import transformers
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
     QUANTIZATION_CONFIG_NAME,
@@ -236,6 +239,11 @@ class ModelCompressor:
             compressed_state_dict
         )

+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
         return compressed_state_dict

     def decompress(self, model_path: str, model: Module):
@@ -313,3 +321,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
                 quantized_modules_to_args[name] = submodule.quantization_scheme.weights

     return quantized_modules_to_args
+
+
+# HACK: Override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+def new_dtype_byte_size(dtype):
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
compressed_tensors/compressors/int_quantized.py → compressed_tensors/compressors/naive_quantized.py
@@ -27,17 +27,21 @@ from torch import Tensor
 from tqdm import tqdm


-__all__ = [
+__all__ = [
+    "QuantizationCompressor",
+    "IntQuantizationCompressor",
+    "FloatQuantizationCompressor",
+]

 _LOGGER: logging.Logger = logging.getLogger(__name__)


-@Compressor.register(name=CompressionFormat.
-class
+@Compressor.register(name=CompressionFormat.naive_quantized.value)
+class QuantizationCompressor(Compressor):
     """
-
-    converted from its original float type to the
-
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
     """

     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
@@ -77,7 +81,7 @@ class IntQuantizationCompressor(Compressor):
                     scale=scale,
                     zero_point=zp,
                     args=quant_args,
-                    dtype=
+                    dtype=quant_args.pytorch_dtype(),
                 )
             elif name.endswith("zero_point"):
                 if torch.all(value == 0):
@@ -114,13 +118,27 @@ class IntQuantizationCompressor(Compressor):
             if "weight_scale" in weight_data:
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
-                if zero_point is None:
-                    # zero_point assumed to be 0 if not included in state_dict
-                    zero_point = torch.zeros_like(scale)
-
                 decompressed = dequantize(
                     x_q=weight_data["weight"],
                     scale=scale,
                     zero_point=zero_point,
                 )
                 yield merge_names(weight_name, "weight"), decompressed
+
+
+@Compressor.register(name=CompressionFormat.int_quantized.value)
+class IntQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for integer quantized models
+    """
+
+    pass
+
+
+@Compressor.register(name=CompressionFormat.float_quantized.value)
+class FloatQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for fp quantized models
+    """
+
+    pass
compressed_tensors/compressors/pack_quantized.py
@@ -126,10 +126,6 @@ class PackedQuantizationCompressor(Compressor):
             if "weight_scale" in weight_data:
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
-                if zero_point is None:
-                    # zero_point assumed to be 0 if not included in state_dict
-                    zero_point = torch.zeros_like(scale)
-
                 weight = weight_data["weight_packed"]
                 original_shape = torch.Size(weight_data["weight_shape"])
                 unpacked = unpack_4bit_ints(weight, original_shape)
compressed_tensors/config/base.py
@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
     int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"

compressed_tensors/quantization/lifecycle/apply.py
@@ -215,15 +215,11 @@ def _load_quant_args_from_state_dict(
     scale = getattr(module, scale_name, None)
     zp = getattr(module, zp_name, None)
     if scale is not None:
-        state_dict_scale = state_dict
-
-        scale.data = state_dict_scale.to(device).to(scale.dtype)
-    else:
-        scale.data = scale.data.to(device)
-
+        state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
+        scale.data = state_dict_scale.to(device).to(scale.dtype)
     if zp is not None:
         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
         if zp_from_state is not None:  # load the non-zero zero points
-            zp.data =
+            zp.data = zp_from_state.to(device).to(zp.dtype)
         else:  # fill with zeros matching scale shape
-            zp.data = torch.zeros_like(scale, dtype=
+            zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
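The fix above pulls the scale out of the state dict by its fully qualified parameter name before copying it onto the module, instead of assigning the whole dict. A toy illustration with hypothetical names:

import torch

# Hypothetical module/key names, purely to show the corrected lookup
state_dict = {"decoder.layers.0.fc1.weight_scale": torch.tensor(0.02)}
module_name, scale_name = "decoder.layers.0.fc1", "weight_scale"

state_dict_scale = state_dict[f"{module_name}.{scale_name}"]  # keyed lookup, not the whole dict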
compressed_tensors/quantization/lifecycle/forward.py
@@ -17,9 +17,11 @@ from math import ceil
 from typing import Optional

 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -80,8 +82,9 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,16 +94,9 @@ def dequantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
     :return: dequantized float tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x_q.device != scale.device:
-        scale = scale.to(x_q.device)
-    if x_q.device != zero_point.device:
-        zero_point = zero_point.to(x_q.device)
-
     if args is None:
         if scale.ndim == 0 or scale.ndim == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
@@ -115,8 +111,12 @@ def dequantize(
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
-                "dimmensions. Expected 0
+                "dimmensions. Expected 0 or 2 dimmensions."
             )
+
+    if dtype is None:
+        dtype = scale.dtype
+
     return _process_quantization(
         x=x_q,
         scale=scale,
@@ -124,6 +124,7 @@ def dequantize(
         args=args,
         do_quantize=False,
         do_dequantize=True,
+        dtype=dtype,
     )


@@ -167,19 +168,13 @@ def _process_quantization(
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
-    q_max =
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

     if args.strategy == QuantizationStrategy.GROUP:
-
-
-        # if dequantizing a quantized type infer the output type from the scale
-        output = torch.zeros_like(x, dtype=scale.dtype)
-    else:
-        output_dtype = dtype if dtype is not None else x.dtype
-        output = torch.zeros_like(x, dtype=output_dtype)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)

         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -189,7 +184,7 @@ def _process_quantization(
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

         columns = x.shape[1]
         if columns >= group_size:
@@ -202,12 +197,18 @@ def _process_quantization(
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
             sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

             idx = i * group_size
             if do_quantize:
                 output[:, idx : (idx + group_size)] = _quantize(
-                    x[:, idx : (idx + group_size)],
+                    x[:, idx : (idx + group_size)],
+                    sc,
+                    zp,
+                    q_min,
+                    q_max,
+                    args,
+                    dtype=dtype,
                 )
             if do_dequantize:
                 input = (
@@ -219,7 +220,15 @@ def _process_quantization(

     else:  # covers channel, token and tensor strategies
         if do_quantize:
-            output = _quantize(
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
         if do_dequantize:
             output = _dequantize(output if do_quantize else x, scale, zero_point)

@@ -313,14 +322,18 @@ def _quantize(
     zero_point: torch.Tensor,
     q_min: torch.Tensor,
     q_max: torch.Tensor,
+    args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
-
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
         q_min,
         q_max,
     )
-
+    quantized_value = round_to_quantized_type(clamped_value, args)
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)

@@ -331,6 +344,16 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
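Taken together, _quantize now clamps before casting (and rounds via round_to_quantized_type), while _dequantize tolerates a missing zero point. A minimal pure-torch sketch of the fp8 round trip these changes enable, mirroring the math in the hunks above rather than calling the library's public API (requires a PyTorch build with float8 support):

import torch

FP8 = torch.float8_e4m3fn
x = torch.randn(4, 8)
scale = x.abs().max() / torch.finfo(FP8).max        # simple symmetric per-tensor scale
q_min, q_max = torch.finfo(FP8).min, torch.finfo(FP8).max

# clamp first because the fp8 cast is not guaranteed to saturate
x_q = torch.clamp(x / scale, q_min, q_max).to(FP8)

# zero point omitted, as _dequantize now allows; upcast before multiplying
x_hat = x_q.to(scale.dtype) * scale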
compressed_tensors/quantization/lifecycle/initialize.py
@@ -120,8 +120,9 @@ def _initialize_scale_zero_point_observer(
     )
     module.register_parameter(f"{base_name}_scale", init_scale)

+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
compressed_tensors/quantization/observers/helpers.py
@@ -15,11 +15,15 @@
 from typing import Tuple

 import torch
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor


-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "calculate_range"]


 def calculate_qparams(
@@ -37,22 +41,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device

-
-
-
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min -
-        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)

     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)

     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
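As a quick sanity check of the new helper, the ranges it returns for the two supported types work out as below (a sketch; the constructor defaults for QuantizationArgs are inferred from the preset hunk later in this diff):

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType

# default 8-bit int args: q_min == -128.0, q_max == 127.0
q_min, q_max = calculate_range(QuantizationArgs(num_bits=8), device="cpu")

# 8-bit float args: the torch.float8_e4m3fn range, q_min == -448.0, q_max == 448.0
q_min, q_max = calculate_range(
    QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8), device="cpu"
)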
compressed_tensors/quantization/quant_args.py
@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional

+import torch
 from pydantic import BaseModel, Field, validator


-__all__ = [
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn


 class QuantizationType(str, Enum):
@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             return QuantizationStrategy.TENSOR

         return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
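A small illustration of the two additions: pytorch_dtype() picks the storage dtype implied by the scheme, and round_to_quantized_type() rounds in that type's precision while returning the original dtype (a hedged sketch; QuantizationArgs() defaulting to 8-bit int is inferred from the W8A8 preset below):

import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationType,
    round_to_quantized_type,
)

int8_args = QuantizationArgs()                           # assumed default: 8-bit int
fp8_args = QuantizationArgs(type=QuantizationType.FLOAT)

int8_args.pytorch_dtype()   # torch.int8
fp8_args.pytorch_dtype()    # torch.float8_e4m3fn

t = torch.tensor([1.26, -0.7])
round_to_quantized_type(t, int8_args)   # tensor([ 1., -1.]), still float32
round_to_quantized_type(t, fp8_args)    # nearest fp8-representable values, still float32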
compressed_tensors/quantization/quant_scheme.py
@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional

-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel


@@ -107,13 +111,15 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES


-W8A8 = dict(
-
+W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs())
+
+W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+
+FP8 = dict(
+    weights=QuantizationArgs(type=QuantizationType.FLOAT),
+    input_activations=QuantizationArgs(type=QuantizationType.FLOAT),
 )

-
+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}

-PRESET_SCHEMES = {
-    "W8A8": W8A8,
-    "W4A16": W4A16,
-}
+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
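With the preset registered, the name lookup that is_preset_scheme performs now accepts the fp8 scheme as well; a brief hedged check:

from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES, is_preset_scheme

assert is_preset_scheme("fp8")                         # case-insensitive: "fp8".upper() in PRESET_SCHEMES
assert set(PRESET_SCHEMES) == {"W8A8", "W4A16", "FP8"}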
compressed_tensors_nightly-0.4.0.20240619.dist-info/METADATA → compressed_tensors_nightly-0.4.0.20240621.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.4.0.20240619
+Version: 0.4.0.20240621
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
compressed_tensors_nightly-0.4.0.20240619.dist-info/RECORD → compressed_tensors_nightly-0.4.0.20240621.dist-info/RECORD
@@ -1,37 +1,37 @@
 compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
 compressed_tensors/base.py,sha256=OA2TOLP1gP3LSH7gp508eqr2ZtDQ-pqRHElCp-aB0vs,755
 compressed_tensors/version.py,sha256=cJJf0y0NnXErTtQtVQjOvrq9hMIkhXIfBwuu4Tuxl24,1586
-compressed_tensors/compressors/__init__.py,sha256=
+compressed_tensors/compressors/__init__.py,sha256=wmX4VnkUTS63xBwK5-6w8FP78bNZpcdcqvf2KOEC5E4,1133
 compressed_tensors/compressors/base.py,sha256=LWEgbpgTxzmoqQ7Xhq2OQszUgWoDtFuGCiV1Y8nlBGw,2134
 compressed_tensors/compressors/dense.py,sha256=G_XHbvuENyupIKlXSITOQgvPkNkcMEOLcLWQr70V9EE,1257
 compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
-compressed_tensors/compressors/int_quantized.py,sha256=Ct2vCK0yoPm6vkIFlzDMGQ7m14xT1GyURsSwH9DP770,5242
 compressed_tensors/compressors/marlin_24.py,sha256=X_BjtFB3Mn0hqiLz56UM3jGX2eNmGLnvEIPfbg7di6U,9444
-compressed_tensors/compressors/model_compressor.py,sha256=
-compressed_tensors/compressors/
+compressed_tensors/compressors/model_compressor.py,sha256=83AWAhlrR3QTNelfMGCh_10G-VfMIRXRTvV0ZZinCU8,13338
+compressed_tensors/compressors/naive_quantized.py,sha256=N3y5LxsCaTUJHT30sqEhnviZsyoz1v2eUaayE7-f8Xs,5562
+compressed_tensors/compressors/pack_quantized.py,sha256=ODb03_WaBQ1l99Gmp49olAUZ2TB_67z9qNZbc56X7NU,8275
 compressed_tensors/compressors/sparse_bitmask.py,sha256=H9oZSTYI1oRCzAMbd4zThUnZd1h2rfs8DmA3tPcvuNE,8637
 compressed_tensors/compressors/utils/__init__.py,sha256=-mbGDZh1hd9T6u62Ht_iBIK255UmMg0f5bLkSs1f9Cc,731
 compressed_tensors/compressors/utils/helpers.py,sha256=4fq7KclSIK__jemCG9pwYlgWLrQjsaAMxhIrhjdw0BQ,1506
 compressed_tensors/compressors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/compressors/utils/semi_structured_conversions.py,sha256=g1EZHzdv-ko7ufPX430dp7wE33o6FWJXuSP4zZydCu0,13488
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
-compressed_tensors/config/base.py,sha256=
+compressed_tensors/config/base.py,sha256=caSZ7xZ_kgcHRMXZ5hM1i6TKbgY__CkiSjZ93imHZQ0,1562
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=
+compressed_tensors/quantization/quant_args.py,sha256=Vc_tWSTcbZZsMJlACpLq4JEPvGx87izc8VEx-mcXjoM,5621
 compressed_tensors/quantization/quant_config.py,sha256=hL42sXp1wAZxyrkHarw7tAMRcwSVEr0MT3wmrmL3NhE,8285
-compressed_tensors/quantization/quant_scheme.py,sha256=
+compressed_tensors/quantization/quant_scheme.py,sha256=Yhaj3QJn4lifGMoQ8mlXXOdLDZA6iGMthb_0hlAzvVk,3811
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcgTVX3axnS2xV6rc5YvdzK7fSg,798
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=eQfuIGcX6KBKeMta1svviXXRpKO3og2CRrxhKlGcE_k,8756
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=tcjL_qyE3ODourNprt2bndF7_ALlUEGY2_Yag4exLoE,11908
 compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=kIEx6a7UyqAIG_ZPNBhijrDiAHnp2wR7K_GC3envz4M,4631
 compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
 compressed_tensors/quantization/observers/base.py,sha256=z_JC-CRz-PY7WlpSoyOoSQQWz5ekTEd5LbXt0iHQRes,5239
-compressed_tensors/quantization/observers/helpers.py,sha256=
+compressed_tensors/quantization/observers/helpers.py,sha256=DSNGNJpZyT2Lyu0c82dHEGf9q5vm4N3zgI3DpkBbp0Q,3597
 compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ7tbnP-J_86QTrEfjBn6Kh1C-H8,2165
 compressed_tensors/quantization/observers/min_max.py,sha256=UK7zCMzxv9GGn6BflBxdajV20RiWaCY2RHcvZodCP1w,3669
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -41,8 +41,8 @@ compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85S
 compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
 compressed_tensors/utils/helpers.py,sha256=5ull5yFT31M2zVxKeFvpvvlvX5f1Sk1LGuj_wrfZWCY,2267
 compressed_tensors/utils/safetensors_load.py,sha256=0MheXwx1jeY12PeISppiSIZHs6rmN2YddwPpFb9V67I,8527
-compressed_tensors_nightly-0.4.0.
-compressed_tensors_nightly-0.4.0.
-compressed_tensors_nightly-0.4.0.
-compressed_tensors_nightly-0.4.0.
-compressed_tensors_nightly-0.4.0.
+compressed_tensors_nightly-0.4.0.20240621.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.4.0.20240621.dist-info/METADATA,sha256=zC3A9MK7GzcOAboNXZHhw_exVI35srmQ3ocSgzAy6j0,5668
+compressed_tensors_nightly-0.4.0.20240621.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+compressed_tensors_nightly-0.4.0.20240621.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.4.0.20240621.dist-info/RECORD,,