compressed-tensors-nightly 0.4.0.20240618__py3-none-any.whl → 0.4.0.20240620__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/model_compressor.py +20 -0
- compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +29 -11
- compressed_tensors/compressors/pack_quantized.py +0 -4
- compressed_tensors/config/base.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +4 -8
- compressed_tensors/quantization/lifecycle/forward.py +52 -21
- compressed_tensors/quantization/lifecycle/initialize.py +2 -1
- compressed_tensors/quantization/observers/helpers.py +44 -9
- compressed_tensors/quantization/quant_args.py +45 -1
- compressed_tensors/quantization/quant_scheme.py +14 -8
- {compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/RECORD +16 -16
- {compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/__init__.py

@@ -17,8 +17,12 @@
 from .base import Compressor
 from .dense import DenseCompressor
 from .helpers import load_compressed, save_compressed, save_compressed_model
-from .int_quantized import IntQuantizationCompressor
 from .marlin_24 import Marlin24Compressor
 from .model_compressor import ModelCompressor, map_modules_to_quant_args
+from .naive_quantized import (
+    FloatQuantizationCompressor,
+    IntQuantizationCompressor,
+    QuantizationCompressor,
+)
 from .pack_quantized import PackedQuantizationCompressor
 from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed_tensors/compressors/model_compressor.py

@@ -16,9 +16,12 @@ import json
 import logging
 import operator
 import os
+import re
 from copy import deepcopy
 from typing import Any, Dict, Optional, Union

+import torch
+import transformers
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
     QUANTIZATION_CONFIG_NAME,

@@ -236,6 +239,11 @@ class ModelCompressor:
             compressed_state_dict
         )

+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
         return compressed_state_dict

     def decompress(self, model_path: str, model: Module):

@@ -313,3 +321,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
             quantized_modules_to_args[name] = submodule.quantization_scheme.weights

     return quantized_modules_to_args
+
+
+# HACK: Override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+def new_dtype_byte_size(dtype):
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
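As a quick illustration (not part of the diff), the override sizes a dtype by parsing the bit width out of its string name, which is what lets float8 weights be sized correctly during serialization. The sketch below re-implements the same parsing under a hypothetical name, dtype_byte_size_sketch, and assumes a PyTorch build that exposes torch.float8_e4m3fn (the dtype quant_args.py adopts):

import re

import torch


def dtype_byte_size_sketch(dtype: torch.dtype) -> float:
    # same strategy as new_dtype_byte_size above: grab the first digit run
    # following a non-digit in e.g. "torch.float8_e4m3fn" or "torch.float16"
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


print(dtype_byte_size_sketch(torch.float16))        # 2
print(dtype_byte_size_sketch(torch.float8_e4m3fn))  # 1, the float8 case the hack exists for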
compressed_tensors/compressors/{int_quantized.py → naive_quantized.py}

@@ -27,17 +27,21 @@ from torch import Tensor
 from tqdm import tqdm


-__all__ = ["IntQuantizationCompressor"]
+__all__ = [
+    "QuantizationCompressor",
+    "IntQuantizationCompressor",
+    "FloatQuantizationCompressor",
+]

 _LOGGER: logging.Logger = logging.getLogger(__name__)


-@Compressor.register(name=CompressionFormat.int_quantized.value)
-class IntQuantizationCompressor(Compressor):
+@Compressor.register(name=CompressionFormat.naive_quantized.value)
+class QuantizationCompressor(Compressor):
     """
-    converted from its original float type to the
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
     """

     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

@@ -77,7 +81,7 @@ class IntQuantizationCompressor(Compressor):
                     scale=scale,
                     zero_point=zp,
                     args=quant_args,
-                    dtype=
+                    dtype=quant_args.pytorch_dtype(),
                 )
             elif name.endswith("zero_point"):
                 if torch.all(value == 0):

@@ -114,13 +118,27 @@ class IntQuantizationCompressor(Compressor):
            if "weight_scale" in weight_data:
                zero_point = weight_data.get("weight_zero_point", None)
                scale = weight_data["weight_scale"]
-               if zero_point is None:
-                   # zero_point assumed to be 0 if not included in state_dict
-                   zero_point = torch.zeros_like(scale)
-
                decompressed = dequantize(
                    x_q=weight_data["weight"],
                    scale=scale,
                    zero_point=zero_point,
                )
                yield merge_names(weight_name, "weight"), decompressed
+
+
+@Compressor.register(name=CompressionFormat.int_quantized.value)
+class IntQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for integer quantized models
+    """
+
+    pass
+
+
+@Compressor.register(name=CompressionFormat.float_quantized.value)
+class FloatQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for fp quantized models
+    """
+
+    pass
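For orientation, a rough sketch (not taken from the diff) of the per-weight cast the renamed compressor now performs: quantize is called with the dtype returned by pytorch_dtype() instead of a hard-coded integer type. The import paths follow the file layout shown here, and the toy scale/zero-point values are made up for the example:

import torch
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType

quant_args = QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8)
weight = torch.randn(4, 8)
scale = weight.abs().max() / torch.finfo(torch.float8_e4m3fn).max  # toy per-tensor scale
zero_point = torch.zeros_like(scale)

# mirrors the compress() call changed above: the stored dtype now follows the args
compressed = quantize(
    x=weight,
    scale=scale,
    zero_point=zero_point,
    args=quant_args,
    dtype=quant_args.pytorch_dtype(),
)
print(compressed.dtype)  # torch.float8_e4m3fn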
compressed_tensors/compressors/pack_quantized.py

@@ -126,10 +126,6 @@ class PackedQuantizationCompressor(Compressor):
            if "weight_scale" in weight_data:
                zero_point = weight_data.get("weight_zero_point", None)
                scale = weight_data["weight_scale"]
-               if zero_point is None:
-                   # zero_point assumed to be 0 if not included in state_dict
-                   zero_point = torch.zeros_like(scale)
-
                weight = weight_data["weight_packed"]
                original_shape = torch.Size(weight_data["weight_shape"])
                unpacked = unpack_4bit_ints(weight, original_shape)
compressed_tensors/config/base.py

@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
     int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"

compressed_tensors/quantization/lifecycle/apply.py

@@ -215,15 +215,11 @@ def _load_quant_args_from_state_dict(
     scale = getattr(module, scale_name, None)
     zp = getattr(module, zp_name, None)
     if scale is not None:
-        state_dict_scale = state_dict
-        scale.data = state_dict_scale.to(device).to(scale.dtype)
-    else:
-        scale.data = scale.data.to(device)
+        state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
+        scale.data = state_dict_scale.to(device).to(scale.dtype)
     if zp is not None:
         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
         if zp_from_state is not None:  # load the non-zero zero points
-            zp.data =
+            zp.data = zp_from_state.to(device).to(zp.dtype)
         else:  # fill with zeros matching scale shape
-            zp.data = torch.zeros_like(scale, dtype=
+            zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
compressed_tensors/quantization/lifecycle/forward.py

@@ -17,9 +17,11 @@ from math import ceil
 from typing import Optional

 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme

@@ -80,8 +82,9 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If

@@ -91,6 +94,7 @@ def dequantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
     :return: dequantized float tensor
     """
     if args is None:

@@ -107,8 +111,12 @@ def dequantize(
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
-            "dimmensions. Expected 0
+            "dimmensions. Expected 0 or 2 dimmensions."
         )
+
+    if dtype is None:
+        dtype = scale.dtype
+
     return _process_quantization(
         x=x_q,
         scale=scale,

@@ -116,6 +124,7 @@ def dequantize(
         args=args,
         do_quantize=False,
         do_dequantize=True,
+        dtype=dtype,
     )


@@ -159,19 +168,13 @@ def _process_quantization(
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

     if args.strategy == QuantizationStrategy.GROUP:
-            # if dequantizing a quantized type infer the output type from the scale
-            output = torch.zeros_like(x, dtype=scale.dtype)
-        else:
-            output_dtype = dtype if dtype is not None else x.dtype
-            output = torch.zeros_like(x, dtype=output_dtype)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)

     # TODO: vectorize the for loop
     # TODO: fix genetric assumption about the tensor size for computing group

@@ -181,7 +184,7 @@ def _process_quantization(
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

         columns = x.shape[1]
         if columns >= group_size:

@@ -194,12 +197,18 @@ def _process_quantization(
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
             sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

             idx = i * group_size
             if do_quantize:
                 output[:, idx : (idx + group_size)] = _quantize(
-                    x[:, idx : (idx + group_size)],
+                    x[:, idx : (idx + group_size)],
+                    sc,
+                    zp,
+                    q_min,
+                    q_max,
+                    args,
+                    dtype=dtype,
                 )
             if do_dequantize:
                 input = (

@@ -211,7 +220,15 @@ def _process_quantization(

     else:  # covers channel, token and tensor strategies
         if do_quantize:
-            output = _quantize(
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
         if do_dequantize:
             output = _dequantize(output if do_quantize else x, scale, zero_point)


@@ -305,14 +322,18 @@ def _quantize(
     zero_point: torch.Tensor,
     q_min: torch.Tensor,
     q_max: torch.Tensor,
+    args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
         q_min,
         q_max,
     )
+    quantized_value = round_to_quantized_type(clamped_value, args)
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)


@@ -323,6 +344,16 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
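A small usage sketch of the relaxed dequantize signature (assuming the import path mirrors the file layout; these helpers are typically also re-exported from compressed_tensors.quantization): symmetric decompression can now omit the zero point entirely, and the output dtype falls back to the scale's dtype:

import torch
from compressed_tensors.quantization.lifecycle.forward import dequantize

x_q = torch.tensor([[-12, 0, 25]], dtype=torch.int8)
scale = torch.tensor(0.05)  # 0-dim scale, so a per-tensor strategy is inferred

# zero_point defaults to None; output comes back in scale.dtype (float32 here)
x = dequantize(x_q=x_q, scale=scale)
print(x)  # tensor([[-0.6000, 0.0000, 1.2500]])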
compressed_tensors/quantization/lifecycle/initialize.py

@@ -120,8 +120,9 @@ def _initialize_scale_zero_point_observer(
     )
     module.register_parameter(f"{base_name}_scale", init_scale)

+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
compressed_tensors/quantization/observers/helpers.py

@@ -15,11 +15,15 @@
 from typing import Tuple

 import torch
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor


-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "calculate_range"]

@@ -37,22 +41,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device

+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min -
-        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)

     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)

     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
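For context, a sketch of what the new helper returns (not part of the diff; import paths assumed from the file layout): INT args yield the usual signed-integer endpoints, while 8-bit FLOAT args yield the torch.finfo limits of float8_e4m3fn:

import torch
from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType

q_min, q_max = calculate_range(QuantizationArgs(num_bits=8), "cpu")
print(q_min.item(), q_max.item())  # -128.0 127.0

q_min, q_max = calculate_range(
    QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8), "cpu"
)
print(q_min.item(), q_max.item())  # -448.0 448.0, i.e. torch.finfo(torch.float8_e4m3fn)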
compressed_tensors/quantization/quant_args.py

@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional

+import torch
 from pydantic import BaseModel, Field, validator


-__all__ = [
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn


 class QuantizationType(str, Enum):

@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             return QuantizationStrategy.TENSOR

         return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
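A brief sketch of the two additions (not from the diff; constructor defaults assumed from this file): pytorch_dtype() picks the storage dtype implied by the args, and round_to_quantized_type() snaps values onto that grid while keeping the original floating dtype:

import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationType,
    round_to_quantized_type,
)

w4 = QuantizationArgs(num_bits=4)
print(w4.pytorch_dtype())   # torch.int8 (anything up to 8 bits is stored as int8)

fp8 = QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8)
print(fp8.pytorch_dtype())  # torch.float8_e4m3fn

x = torch.tensor([0.30, 1.26, 2.00])
print(round_to_quantized_type(x, fp8))  # nearest fp8-representable values, still float32
print(round_to_quantized_type(x, w4))   # integer rounding for INT args: tensor([0., 1., 2.])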
compressed_tensors/quantization/quant_scheme.py

@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional

-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel


@@ -107,13 +111,15 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES


-W8A8 = dict(
+W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs())
+
+W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+
+FP8 = dict(
+    weights=QuantizationArgs(type=QuantizationType.FLOAT),
+    input_activations=QuantizationArgs(type=QuantizationType.FLOAT),
 )

+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}

-PRESET_SCHEMES = {
-    "W8A8": W8A8,
-    "W4A16": W4A16,
-}
+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
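And a short sketch of looking up the new preset (assuming the symbols are importable from compressed_tensors.quantization.quant_scheme as laid out above; they are usually also re-exported from compressed_tensors.quantization):

from compressed_tensors.quantization.quant_args import QuantizationType
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES, is_preset_scheme

print(is_preset_scheme("fp8"))  # True, the lookup is case-insensitive via name.upper()
fp8 = PRESET_SCHEMES["FP8"]
print(fp8["weights"].type == QuantizationType.FLOAT)  # True
print(fp8["input_activations"].num_bits)              # 8, the QuantizationArgs default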
{compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.4.0.20240618
+Version: 0.4.0.20240620
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors_nightly-0.4.0.20240618.dist-info → compressed_tensors_nightly-0.4.0.20240620.dist-info}/RECORD

@@ -1,37 +1,37 @@
 compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
 compressed_tensors/base.py,sha256=OA2TOLP1gP3LSH7gp508eqr2ZtDQ-pqRHElCp-aB0vs,755
 compressed_tensors/version.py,sha256=cJJf0y0NnXErTtQtVQjOvrq9hMIkhXIfBwuu4Tuxl24,1586
-compressed_tensors/compressors/__init__.py,sha256=
+compressed_tensors/compressors/__init__.py,sha256=wmX4VnkUTS63xBwK5-6w8FP78bNZpcdcqvf2KOEC5E4,1133
 compressed_tensors/compressors/base.py,sha256=LWEgbpgTxzmoqQ7Xhq2OQszUgWoDtFuGCiV1Y8nlBGw,2134
 compressed_tensors/compressors/dense.py,sha256=G_XHbvuENyupIKlXSITOQgvPkNkcMEOLcLWQr70V9EE,1257
 compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
-compressed_tensors/compressors/int_quantized.py,sha256=Ct2vCK0yoPm6vkIFlzDMGQ7m14xT1GyURsSwH9DP770,5242
 compressed_tensors/compressors/marlin_24.py,sha256=X_BjtFB3Mn0hqiLz56UM3jGX2eNmGLnvEIPfbg7di6U,9444
-compressed_tensors/compressors/model_compressor.py,sha256=
-compressed_tensors/compressors/pack_quantized.py,sha256=
+compressed_tensors/compressors/model_compressor.py,sha256=83AWAhlrR3QTNelfMGCh_10G-VfMIRXRTvV0ZZinCU8,13338
+compressed_tensors/compressors/naive_quantized.py,sha256=N3y5LxsCaTUJHT30sqEhnviZsyoz1v2eUaayE7-f8Xs,5562
+compressed_tensors/compressors/pack_quantized.py,sha256=ODb03_WaBQ1l99Gmp49olAUZ2TB_67z9qNZbc56X7NU,8275
 compressed_tensors/compressors/sparse_bitmask.py,sha256=H9oZSTYI1oRCzAMbd4zThUnZd1h2rfs8DmA3tPcvuNE,8637
 compressed_tensors/compressors/utils/__init__.py,sha256=-mbGDZh1hd9T6u62Ht_iBIK255UmMg0f5bLkSs1f9Cc,731
 compressed_tensors/compressors/utils/helpers.py,sha256=4fq7KclSIK__jemCG9pwYlgWLrQjsaAMxhIrhjdw0BQ,1506
 compressed_tensors/compressors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/compressors/utils/semi_structured_conversions.py,sha256=g1EZHzdv-ko7ufPX430dp7wE33o6FWJXuSP4zZydCu0,13488
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
-compressed_tensors/config/base.py,sha256=
+compressed_tensors/config/base.py,sha256=caSZ7xZ_kgcHRMXZ5hM1i6TKbgY__CkiSjZ93imHZQ0,1562
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=
+compressed_tensors/quantization/quant_args.py,sha256=Vc_tWSTcbZZsMJlACpLq4JEPvGx87izc8VEx-mcXjoM,5621
 compressed_tensors/quantization/quant_config.py,sha256=hL42sXp1wAZxyrkHarw7tAMRcwSVEr0MT3wmrmL3NhE,8285
-compressed_tensors/quantization/quant_scheme.py,sha256
+compressed_tensors/quantization/quant_scheme.py,sha256=Yhaj3QJn4lifGMoQ8mlXXOdLDZA6iGMthb_0hlAzvVk,3811
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcgTVX3axnS2xV6rc5YvdzK7fSg,798
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=eQfuIGcX6KBKeMta1svviXXRpKO3og2CRrxhKlGcE_k,8756
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=tcjL_qyE3ODourNprt2bndF7_ALlUEGY2_Yag4exLoE,11908
 compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=kIEx6a7UyqAIG_ZPNBhijrDiAHnp2wR7K_GC3envz4M,4631
 compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
 compressed_tensors/quantization/observers/base.py,sha256=z_JC-CRz-PY7WlpSoyOoSQQWz5ekTEd5LbXt0iHQRes,5239
-compressed_tensors/quantization/observers/helpers.py,sha256=
+compressed_tensors/quantization/observers/helpers.py,sha256=DSNGNJpZyT2Lyu0c82dHEGf9q5vm4N3zgI3DpkBbp0Q,3597
 compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ7tbnP-J_86QTrEfjBn6Kh1C-H8,2165
 compressed_tensors/quantization/observers/min_max.py,sha256=UK7zCMzxv9GGn6BflBxdajV20RiWaCY2RHcvZodCP1w,3669
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656

@@ -41,8 +41,8 @@ compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85S
 compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
 compressed_tensors/utils/helpers.py,sha256=5ull5yFT31M2zVxKeFvpvvlvX5f1Sk1LGuj_wrfZWCY,2267
 compressed_tensors/utils/safetensors_load.py,sha256=0MheXwx1jeY12PeISppiSIZHs6rmN2YddwPpFb9V67I,8527
-compressed_tensors_nightly-0.4.0.20240618.dist-info/LICENSE,sha256=
-compressed_tensors_nightly-0.4.0.20240618.dist-info/METADATA,sha256=
-compressed_tensors_nightly-0.4.0.20240618.dist-info/WHEEL,sha256=
-compressed_tensors_nightly-0.4.0.20240618.dist-info/top_level.txt,sha256=
-compressed_tensors_nightly-0.4.0.20240618.dist-info/RECORD,,
+compressed_tensors_nightly-0.4.0.20240620.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.4.0.20240620.dist-info/METADATA,sha256=SIDlitJJYg5Kj4OMjexqjCVR1DmdCFND-At86Hrnqt4,5668
+compressed_tensors_nightly-0.4.0.20240620.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+compressed_tensors_nightly-0.4.0.20240620.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.4.0.20240620.dist-info/RECORD,,
The remaining dist-info files (LICENSE, WHEEL, top_level.txt) are unchanged between the two versions.