compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
15
|
+
import logging
|
16
|
+
from typing import Any, Iterable, Optional, Tuple, Union
|
16
17
|
|
17
18
|
import torch
|
18
19
|
from compressed_tensors.quantization.quant_args import (
|
@@ -24,6 +25,9 @@ from torch import FloatTensor, IntTensor, Tensor
|
|
24
25
|
from torch.nn import Module
|
25
26
|
|
26
27
|
|
28
|
+
_LOGGER = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
|
27
31
|
__all__ = ["Observer"]
|
28
32
|
|
29
33
|
|
@@ -39,7 +43,9 @@ class Observer(Module, RegistryMixin):
|
|
39
43
|
super().__init__()
|
40
44
|
self._scale = None
|
41
45
|
self._zero_point = None
|
46
|
+
self._num_observed_tokens = None
|
42
47
|
|
48
|
+
@torch.no_grad()
|
43
49
|
def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
|
44
50
|
"""
|
45
51
|
maps directly to get_qparams
|
@@ -47,11 +53,19 @@ class Observer(Module, RegistryMixin):
|
|
47
53
|
from
|
48
54
|
:return: tuple of scale and zero point based on last observed value
|
49
55
|
"""
|
56
|
+
self.record_observed_tokens(observed)
|
50
57
|
return self.get_qparams(observed=observed)
|
51
58
|
|
52
|
-
def calculate_qparams(
|
59
|
+
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Abstract hook that subclasses override to derive quantization parameters.

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: optional id forwarded by get_qparams_along_dim; accepted
        here so that a subclass which forgot to override this method fails with
        NotImplementedError rather than an unexpected-keyword TypeError
    :return: tuple of scale and zero point derived from the observed tensor
    :raises NotImplementedError: always, the base class has no implementation
    """
    raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
|
@@ -69,6 +83,7 @@ class Observer(Module, RegistryMixin):
|
|
69
83
|
Convenience function to wrap overwritten calculate_qparams
|
70
84
|
adds support to make observed tensor optional and support for tracking latest
|
71
85
|
calculated scale and zero point
|
86
|
+
|
72
87
|
:param observed: optional observed tensor to calculate quantization parameters
|
73
88
|
from
|
74
89
|
:return: tuple of scale and zero point based on last observed value
|
@@ -84,47 +99,75 @@ class Observer(Module, RegistryMixin):
|
|
84
99
|
elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
|
85
100
|
columns = observed.shape[1]
|
86
101
|
scales, zero_points = [], []
|
87
|
-
|
102
|
+
group_idxs = range(0, columns, self.quantization_args.group_size)
|
103
|
+
for group_id, group_idx in enumerate(group_idxs):
|
88
104
|
scale, zero_point = self.get_qparams_along_dim(
|
89
|
-
observed[:,
|
105
|
+
observed[:, group_idx : (group_idx + group_size)],
|
90
106
|
0,
|
107
|
+
tensor_id=group_id,
|
91
108
|
)
|
92
109
|
scales.append(scale)
|
93
110
|
zero_points.append(zero_point)
|
94
111
|
|
95
|
-
self._scale = torch.
|
96
|
-
self._zero_point = torch.
|
112
|
+
self._scale = torch.cat(scales, dim=1, out=self._scale)
|
113
|
+
self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
|
97
114
|
|
98
115
|
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
|
99
116
|
# assume observed is transposed, because it's the output, hence use dim 0
|
100
117
|
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
|
101
118
|
|
102
119
|
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
|
103
|
-
|
104
120
|
# use dim 1, assume the observed.shape = [batch, token, hidden]
|
105
121
|
# should be batch, token
|
106
|
-
|
107
122
|
self._scale, self._zero_point = self.get_qparams_along_dim(
|
108
|
-
observed,
|
123
|
+
observed,
|
124
|
+
dim={0, 1},
|
109
125
|
)
|
110
126
|
|
111
127
|
return self._scale, self._zero_point
|
112
128
|
|
113
|
-
def get_qparams_along_dim(
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
129
|
+
def get_qparams_along_dim(
    self,
    observed,
    dim: Union[int, Iterable[int]],
    tensor_id: Optional[Any] = None,
):
    """
    Calculate quantization parameters while preserving the given dimension(s).

    Every dimension of observed that is not listed in dim is reduced.

    :param observed: tensor to calculate quantization parameters for
    :param dim: a single dimension index, or an iterable of indices, to keep
    :param tensor_id: optional id forwarded unchanged to calculate_qparams
    :return: whatever calculate_qparams returns for the computed reduce_dims
    """
    # set() on a bare int raises TypeError, so wrap single indices first
    kept_dims = {dim} if isinstance(dim, int) else set(dim)

    reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in kept_dims)
    return self.calculate_qparams(
        observed, reduce_dims=reduce_dims, tensor_id=tensor_id
    )
|
121
141
|
|
122
|
-
|
123
|
-
|
124
|
-
|
142
|
+
def record_observed_tokens(self, batch_tensor: Tensor):
    """
    Counts the number of tokens observed during the forward passes. The count
    is aggregated in the _num_observed_tokens attribute of the class.

    Note: The batch_tensor is expected to have two dimensions
        (batch_size * sequence_length, num_features). This is the
        general shape expected by the forward pass of the expert
        layers in a MOE model. If the input tensor does not have
        two dimensions, a debug message is logged and the
        _num_observed_tokens attribute is left unchanged.

    :param batch_tensor: tensor passed through the observer
    :raises ValueError: if batch_tensor is not a Tensor
    """
    if not isinstance(batch_tensor, Tensor):
        raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

    if batch_tensor.ndim != 2:
        # lazy %-formatting: the message is only built when debug is enabled
        _LOGGER.debug(
            "The input tensor is expected to have two dimensions "
            "(batch_size * sequence_length, num_features). "
            "The input tensor has %s dimensions.",
            batch_tensor.ndim,
        )
        return

    if self._num_observed_tokens is None:
        # initialize the count
        self._num_observed_tokens = 0

    # batch_tensor (batch_size * sequence_length, num_features)
    # observed_tokens (batch_size * sequence_length)
    observed_tokens, _ = batch_tensor.shape
    self._num_observed_tokens += observed_tokens
|
@@ -12,42 +12,100 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from collections import Counter
|
15
16
|
from typing import Tuple
|
16
17
|
|
17
18
|
import torch
|
18
|
-
from compressed_tensors.quantization.quant_args import
|
19
|
+
from compressed_tensors.quantization.quant_args import (
|
20
|
+
FP8_DTYPE,
|
21
|
+
QuantizationArgs,
|
22
|
+
QuantizationType,
|
23
|
+
)
|
19
24
|
from torch import FloatTensor, IntTensor, Tensor
|
20
25
|
|
21
26
|
|
22
|
-
__all__ = ["calculate_qparams"]
|
27
|
+
__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
|
28
|
+
|
29
|
+
|
30
|
+
def get_observer_token_count(module: torch.nn.Module) -> Counter:
    """
    Parse the module and return the number of tokens observed by
    each module's observer.

    Only submodules whose qualified name ends in ".input_observer" are
    considered; the key stored for each is that name with the trailing
    suffix removed.

    :param module: module to parse
    :return: counter with the number of tokens observed by each observer
    """
    suffix = ".input_observer"
    token_counts = Counter()
    # use a distinct loop name so the `module` argument is not shadowed
    for name, submodule in module.named_modules():
        if name.endswith(suffix):
            # strip only the trailing suffix; str.replace would also remove
            # earlier occurrences inside the qualified name
            token_counts[name[: -len(suffix)]] = submodule._num_observed_tokens
    return token_counts
|
23
45
|
|
24
46
|
|
25
47
|
def calculate_qparams(
    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
) -> Tuple[FloatTensor, IntTensor]:
    """
    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
        from
    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
        from
    :param quantization_args: settings to quantization
    :return: tuple of the calculated scale(s) and zero point(s)
    """
    # widen the observed range so that it always contains zero
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
    device = min_vals.device

    bit_min, bit_max = calculate_range(quantization_args, device)
    bit_range = bit_max - bit_min
    zp_dtype = quantization_args.pytorch_dtype()

    if quantization_args.symmetric:
        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        scales = max_val_pos / (float(bit_range) / 2)
        # guard against zero scales
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
    else:
        scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zero_points = bit_min - (min_vals / scales)
        zero_points = torch.clamp(zero_points, bit_min, bit_max)

    # match zero-points to quantized type
    if quantization_args.type == QuantizationType.INT:
        # a float -> integer cast truncates toward zero; round first so the
        # stored zero point is the nearest integer
        zero_points = torch.round(zero_points)
    zero_points = zero_points.to(zp_dtype)

    if scales.ndim == 0:
        # always return at least one-dimensional parameters
        scales = scales.reshape(1)
        zero_points = zero_points.reshape(1)

    return scales, zero_points
|
85
|
+
|
86
|
+
|
87
|
+
def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
    """
    Calculates the effective quantization range for the given Quantization Args

    :param quantization_args: quantization args to get range of
    :param device: device to store the range to
    :return: tuple endpoints (q_min, q_max) for the given quantization range
    :raises ValueError: for float quantization with num_bits other than 8, or
        for a quantization type that is neither INT nor FLOAT
    """
    if quantization_args.type == QuantizationType.INT:
        bit_range = 2**quantization_args.num_bits
        q_max = torch.tensor(bit_range / 2 - 1, device=device)
        q_min = torch.tensor(-bit_range / 2, device=device)
    elif quantization_args.type == QuantizationType.FLOAT:
        if quantization_args.num_bits != 8:
            # note the space after the comma: the two literals are concatenated
            raise ValueError(
                "Floating point quantization is only supported for 8 bits, "
                f"got {quantization_args.num_bits}"
            )
        fp_range_info = torch.finfo(FP8_DTYPE)
        q_max = torch.tensor(fp_range_info.max, device=device)
        q_min = torch.tensor(fp_range_info.min, device=device)
    else:
        raise ValueError(f"Invalid quantization type {quantization_args.type}")

    return q_min, q_max
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Tuple
|
15
|
+
from typing import Any, Optional, Tuple
|
16
16
|
|
17
17
|
import torch
|
18
18
|
from compressed_tensors.quantization.observers.base import Observer
|
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
|
|
30
30
|
zero point based on the latest observed value without tracking state
|
31
31
|
"""
|
32
32
|
|
33
|
-
def calculate_qparams(
|
33
|
+
def calculate_qparams(
    self,
    observed: Tensor,
    tensor_id: Optional[Any] = None,
    reduce_dims: Optional[Tuple[int]] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Derives scale and zero point from the min and max of the observed tensor,
    keeping no state between calls

    :param observed: observed tensor to calculate quantization parameters for
    :param tensor_id: optional id for tensor; not used for memoryless
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :return: tuple of scale and zero point derived from the observed tensor
    """
    if reduce_dims:
        # reduce only over the requested dimensions, keeping them as size 1
        lowest = torch.amin(observed, dim=reduce_dims, keepdims=True)
        highest = torch.amax(observed, dim=reduce_dims, keepdims=True)
    else:
        # no (or empty) reduce_dims: a single min/max over the whole tensor
        lowest, highest = torch.aminmax(observed)

    return calculate_qparams(lowest, highest, self.quantization_args)
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Tuple
|
15
|
+
from typing import Any, Optional, Tuple
|
16
16
|
|
17
17
|
import torch
|
18
18
|
from compressed_tensors.quantization.observers.base import Observer
|
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
|
|
36
36
|
):
|
37
37
|
super().__init__(quantization_args=quantization_args)
|
38
38
|
|
39
|
-
self.min_val =
|
40
|
-
self.max_val =
|
39
|
+
self.min_val = {}
|
40
|
+
self.max_val = {}
|
41
41
|
self.averaging_constant = averaging_constant
|
42
42
|
|
43
|
-
def calculate_qparams(
|
43
|
+
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Updates the observed min and max using a moving average smoothed by the
    averaging_constant

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size
    :return: tuple of scale and zero point derived from the observed tensor
    """
    # only a missing id falls back to "default"; `tensor_id or "default"` would
    # also remap valid falsy ids such as group index 0
    if tensor_id is None:
        tensor_id = "default"

    if not reduce_dims:
        min_val, max_val = torch.aminmax(observed)
    else:
        min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
        max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

    running_min_val = self.min_val.get(tensor_id, None)
    running_max_val = self.max_val.get(tensor_id, None)

    if running_min_val is None or running_max_val is None:
        # first observation for this id: nothing to average against yet
        updated_min_val = min_val
        updated_max_val = max_val
    else:
        # move the running statistics toward the new observation
        updated_min_val = running_min_val + self.averaging_constant * (
            min_val - running_min_val
        )
        updated_max_val = running_max_val + self.averaging_constant * (
            max_val - running_max_val
        )

    self.min_val[tensor_id] = updated_min_val
    self.max_val[tensor_id] = updated_max_val

    return calculate_qparams(
        updated_min_val, updated_max_val, self.quantization_args
    )
|
89
|
+
|
90
|
+
def get_qparams_along_dim(
    self, observed, dim: Any, tensor_id: Optional[Any] = None
):
    """
    Calculate quantization parameters while preserving the given dimension(s).

    :param observed: tensor to calculate quantization parameters for
    :param dim: a single dimension index, or an iterable of indices, to keep;
        every other dimension of observed is reduced
    :param tensor_id: optional id forwarded unchanged to calculate_qparams
    :return: result of calculate_qparams for the computed reduce_dims
    """
    # the base class passes a set of dims for the TOKEN strategy; comparing
    # each index to that set with `!=` would silently reduce every dimension
    kept_dims = {dim} if isinstance(dim, int) else set(dim)
    reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in kept_dims)
    return self.calculate_qparams(
        observed, reduce_dims=reduce_dims, tensor_id=tensor_id
    )
|
@@ -15,10 +15,19 @@
|
|
15
15
|
from enum import Enum
|
16
16
|
from typing import Any, Dict, Optional
|
17
17
|
|
18
|
+
import torch
|
18
19
|
from pydantic import BaseModel, Field, validator
|
19
20
|
|
20
21
|
|
21
|
-
__all__ = [
|
22
|
+
__all__ = [
|
23
|
+
"FP8_DTYPE",
|
24
|
+
"QuantizationType",
|
25
|
+
"QuantizationStrategy",
|
26
|
+
"QuantizationArgs",
|
27
|
+
"round_to_quantized_type",
|
28
|
+
]
|
29
|
+
|
30
|
+
FP8_DTYPE = torch.float8_e4m3fn
|
22
31
|
|
23
32
|
|
24
33
|
class QuantizationType(str, Enum):
|
@@ -42,7 +51,7 @@ class QuantizationStrategy(str, Enum):
|
|
42
51
|
TOKEN = "token"
|
43
52
|
|
44
53
|
|
45
|
-
class QuantizationArgs(BaseModel):
|
54
|
+
class QuantizationArgs(BaseModel, use_enum_values=True):
|
46
55
|
"""
|
47
56
|
User facing arguments used to define a quantization config for weights or
|
48
57
|
activations
|
@@ -62,7 +71,7 @@ class QuantizationArgs(BaseModel):
|
|
62
71
|
"""
|
63
72
|
|
64
73
|
num_bits: int = 8
|
65
|
-
type: QuantizationType = QuantizationType.INT
|
74
|
+
type: QuantizationType = QuantizationType.INT.value
|
66
75
|
symmetric: bool = True
|
67
76
|
group_size: Optional[int] = None
|
68
77
|
strategy: Optional[QuantizationStrategy] = None
|
@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel):
|
|
123
132
|
return QuantizationStrategy.TENSOR
|
124
133
|
|
125
134
|
return value
|
135
|
+
|
136
|
+
def pytorch_dtype(self) -> torch.dtype:
    """
    Select the torch dtype used to hold values quantized with these args.

    :return: FP8_DTYPE for float quantization; for int quantization the
        narrowest of int8, int16 or int32 that covers num_bits
    :raises ValueError: if type is neither FLOAT nor INT
    """
    if self.type == QuantizationType.FLOAT:
        return FP8_DTYPE

    if self.type != QuantizationType.INT:
        raise ValueError(f"Invalid quantization type {self.type}")

    # integer quantization: pick the container width from the bit count
    bits = self.num_bits
    if bits <= 8:
        return torch.int8
    return torch.int16 if bits <= 16 else torch.int32
|
148
|
+
|
149
|
+
|
150
|
+
def round_to_quantized_type(
    tensor: torch.Tensor, args: QuantizationArgs
) -> torch.Tensor:
    """
    Snaps every element of the input tensor to its nearest quantized
    representation and returns the result in the tensor's original dtype

    :param tensor: tensor to round
    :param args: QuantizationArgs to pull appropriate dtype from
    :return: rounded tensor
    :raises ValueError: if args.type is neither FLOAT nor INT
    """
    if args.type == QuantizationType.FLOAT:
        # a round trip through FP8 leaves only FP8-representable values
        snapped = tensor.to(FP8_DTYPE)
    elif args.type == QuantizationType.INT:
        snapped = torch.round(tensor)
    else:
        raise ValueError(f"Invalid quantization type {args.type}")

    return snapped.to(tensor.dtype)
|