compressed-tensors 0.3.0 (compressed_tensors-0.3.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +21 -0
- compressed_tensors/base.py +16 -0
- compressed_tensors/compressors/__init__.py +25 -0
- compressed_tensors/compressors/base.py +79 -0
- compressed_tensors/compressors/dense.py +34 -0
- compressed_tensors/compressors/helpers.py +161 -0
- compressed_tensors/compressors/sparse_bitmask.py +238 -0
- compressed_tensors/config/__init__.py +18 -0
- compressed_tensors/config/base.py +42 -0
- compressed_tensors/config/dense.py +36 -0
- compressed_tensors/config/sparse_bitmask.py +36 -0
- compressed_tensors/quantization/__init__.py +21 -0
- compressed_tensors/quantization/lifecycle/__init__.py +22 -0
- compressed_tensors/quantization/lifecycle/apply.py +173 -0
- compressed_tensors/quantization/lifecycle/calibration.py +51 -0
- compressed_tensors/quantization/lifecycle/forward.py +136 -0
- compressed_tensors/quantization/lifecycle/frozen.py +46 -0
- compressed_tensors/quantization/lifecycle/initialize.py +96 -0
- compressed_tensors/quantization/observers/__init__.py +21 -0
- compressed_tensors/quantization/observers/base.py +69 -0
- compressed_tensors/quantization/observers/helpers.py +53 -0
- compressed_tensors/quantization/observers/memoryless.py +48 -0
- compressed_tensors/quantization/observers/min_max.py +65 -0
- compressed_tensors/quantization/quant_args.py +85 -0
- compressed_tensors/quantization/quant_config.py +171 -0
- compressed_tensors/quantization/quant_scheme.py +39 -0
- compressed_tensors/quantization/utils/__init__.py +16 -0
- compressed_tensors/quantization/utils/helpers.py +115 -0
- compressed_tensors/registry/__init__.py +17 -0
- compressed_tensors/registry/registry.py +360 -0
- compressed_tensors/utils/__init__.py +16 -0
- compressed_tensors/utils/helpers.py +151 -0
- compressed_tensors/utils/safetensors_load.py +237 -0
- compressed_tensors-0.3.0.dist-info/METADATA +22 -0
- compressed_tensors-0.3.0.dist-info/RECORD +37 -0
- compressed_tensors-0.3.0.dist-info/WHEEL +5 -0
- compressed_tensors-0.3.0.dist-info/top_level.txt +1 -0

compressed_tensors/config/sparse_bitmask.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from compressed_tensors.config import CompressionConfig, CompressionFormat
+
+
+__all__ = ["BitmaskConfig"]
+
+
+@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskConfig(CompressionConfig):
+    """
+    Configuration for storing a sparse model using
+    bitmask compression
+
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, such as
+        "unstructured", "2:4", "8:16" etc
+    """
+
+    format: str = CompressionFormat.sparse_bitmask.value
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = "unstructured"
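
For orientation, a minimal usage sketch of the config above. It assumes CompressionConfig behaves as a pydantic-style model (as the class-level field annotations suggest), so the fields can be passed as keyword arguments; the example values are illustrative.

    from compressed_tensors.config.sparse_bitmask import BitmaskConfig

    config = BitmaskConfig(global_sparsity=0.5, sparsity_structure="2:4")
    print(config.format)              # the CompressionFormat.sparse_bitmask.value string
    print(config.global_sparsity)     # 0.5
    print(config.sparsity_structure)  # "2:4"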

compressed_tensors/quantization/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .quant_args import *
+from .quant_config import *
+from .quant_scheme import *
+from .lifecycle import *

compressed_tensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .calibration import *
+from .forward import *
+from .frozen import *
+from .initialize import *
+from .apply import *

compressed_tensors/quantization/lifecycle/apply.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from collections import OrderedDict
+from typing import Dict, Iterable, Optional
+
+from compressed_tensors.quantization.lifecycle.calibration import (
+    set_module_for_calibration,
+)
+from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
+from compressed_tensors.quantization.lifecycle.initialize import (
+    initialize_module_for_quantization,
+)
+from compressed_tensors.quantization.quant_config import (
+    QuantizationConfig,
+    QuantizationStatus,
+)
+from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+from torch.nn import Module
+
+
+__all__ = [
+    "load_pretrained_quantization",
+    "apply_quantization_config",
+    "apply_quantization_status",
+]
+
+from compressed_tensors.quantization.utils.helpers import is_module_quantized
+from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+
+
+def load_pretrained_quantization(model: Module, model_name_or_path: str):
+    """
+    Loads the quantization parameters (scale and zero point) from model_name_or_path to
+    a model that has already been initialized with a quantization config
+
+    :param model: model to load pretrained quantization parameters to
+    :param model_name_or_path: Hugging Face stub or local folder containing a quantized
+        model, which is used to load quantization parameters
+    """
+    model_path = get_safetensors_folder(model_name_or_path)
+    state_dict = get_quantization_state_dict(model_path)
+
+    for name, submodule in iter_named_leaf_modules(model):
+        if not is_module_quantized(submodule):
+            continue
+        if submodule.quantization_scheme.weights is not None:
+            base_name = "weight"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.input_activations is not None:
+            base_name = "input"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+        if submodule.quantization_scheme.output_activations is not None:
+            base_name = "output"
+            _load_quant_args_from_state_dict(
+                base_name=base_name,
+                module_name=name,
+                module=submodule,
+                state_dict=state_dict,
+            )
+
+
+def apply_quantization_config(model: Module, config: QuantizationConfig):
+    """
+    Initializes the model for quantization in-place based on the given config
+
+    :param model: model to apply quantization config to
+    :param config: quantization config
+    """
+    # build mapping of targets to schemes for easier matching
+    # use ordered dict to preserve target ordering in config
+    target_to_scheme = OrderedDict()
+    for scheme in config.config_groups.values():
+        for target in scheme.targets:
+            target_to_scheme[target] = scheme
+
+    # mark appropriate layers for quantization by setting their quantization schemes
+    for name, submodule in iter_named_leaf_modules(model):
+        if _find_first_name_or_class_match(name, submodule, config.ignore):
+            continue  # layer matches ignore list, continue
+        target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
+        if target is not None:
+            # target matched - add layer and scheme to target list
+            submodule.quantization_scheme = target_to_scheme[target]
+
+    # apply current quantization status across all targeted layers
+    apply_quantization_status(model, config.quantization_status)
+
+
+def apply_quantization_status(model: Module, status: QuantizationStatus):
+    """
+    Applies in place the quantization lifecycle up to the given status
+
+    :param model: model to apply quantization to
+    :param status: status to update the module to
+    """
+    if status >= QuantizationStatus.INITIALIZED:
+        model.apply(initialize_module_for_quantization)
+    if status >= QuantizationStatus.CALIBRATION:
+        model.apply(set_module_for_calibration)
+    if status >= QuantizationStatus.FROZEN:
+        model.apply(freeze_module_quantization)
+
+
+def _find_first_name_or_class_match(
+    name: str,
+    module: Module,
+    targets: Iterable[str],
+) -> Optional[str]:
+    # first element of targets that matches the given name
+    # if no name matches returns first target that matches the class name
+    # returns None otherwise
+    return _find_first_match(name, targets) or _find_first_match(
+        module.__class__.__name__, targets
+    )
+
+
+def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
+    # returns first element of target that matches value either
+    # exactly or as a regex after 're:'
+    for target in targets:
+        if target.startswith("re:"):
+            pattern = target[3:]
+            if re.match(pattern, value):
+                return target
+        elif target == value:
+            return target
+    return None
+
+
+def _load_quant_args_from_state_dict(
+    base_name: str, module_name: str, module: Module, state_dict: Dict
+):
+    """
+    Loads scale and zero point from a state_dict into the specified module
+
+    :param base_name: quantization target, one of: weights, input_activations or
+        output_activations
+    :param module_name: pytorch module name to look up in state_dict
+    :module: pytorch module associated with module_name
+    :state_dict: state_dict to search for matching quantization parameters
+    """
+    scale_name = f"{base_name}_scale"
+    zp_name = f"{base_name}_zero_point"
+    device = next(module.parameters()).device
+
+    scale = getattr(module, scale_name)
+    zp = getattr(module, zp_name)
+    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
+    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
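
The scheme `targets` and config `ignore` entries above are matched against both a module's name and its class name via `_find_first_name_or_class_match`; entries prefixed with "re:" are treated as regular expressions, everything else as an exact string. A small standalone illustration of that matching convention follows (it re-implements the private helper rather than importing it, and the module names are hypothetical):

    import re

    def first_match(value, targets):
        # mirrors _find_first_match above: exact string match, or regex when the
        # target carries a "re:" prefix
        for target in targets:
            if target.startswith("re:"):
                if re.match(target[3:], value):
                    return target
            elif target == value:
                return target
        return None

    targets = ["re:.*q_proj$", "Linear"]
    print(first_match("model.layers.0.self_attn.q_proj", targets))  # "re:.*q_proj$"
    print(first_match("Linear", targets))                           # "Linear" (class-name match)
    print(first_match("model.embed_tokens", targets))               # None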

compressed_tensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "set_module_for_calibration",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def set_module_for_calibration(module: Module):
+    """
+    marks a layer as ready for calibration which activates observers
+    to update scales and zero points on each forward pass
+
+    apply to full model with `model.apply(set_module_for_calibration)`
+
+    :param module: module to set for calibration
+    """
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme nothing to do
+        return
+    status = getattr(module, "quantization_status", None)
+    if not status or status != QuantizationStatus.INITIALIZED:
+        _LOGGER.warning(
+            f"Attempting to set module with status {status} to calibration mode, "
+            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
+            "be calibrating an uninitialized module, which may fail, or attempting "
+            "to re-calibrate a frozen module"
+        )
+
+    module.quantization_status = QuantizationStatus.CALIBRATION

compressed_tensors/quantization/lifecycle/forward.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import wraps
+
+import torch
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from torch.nn import Module
+
+
+__all__ = ["wrap_module_forward_quantized"]
+
+
+@torch.no_grad()
+def quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+) -> torch.Tensor:
+    return torch.clamp(
+        torch.round(
+            x / scale + zero_point,
+        ),
+        q_min,
+        q_max,
+    )
+
+
+@torch.no_grad()
+def dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
+
+
+@torch.no_grad()
+def fake_quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+) -> torch.Tensor:
+    bit_range = 2**args.num_bits
+    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
+    min_q = torch.tensor(-bit_range / 2, device=x.device)
+    Q = torch.zeros_like(x)
+    Q = quantize(x, scale, zero_point, min_q, max_q)
+    return dequantize(Q, scale, zero_point)
+
+
+def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    forward_func_orig = module.forward.__func__
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+        input_ = args[0]
+
+        if scheme.input_activations is not None:
+            # calibrate and (fake) quantize input activations when applicable
+            input_ = _maybe_calibrate_or_quantize(
+                module, input_, "input", scheme.input_activations
+            )
+
+        if scheme.weights is not None:
+            # calibrate and (fake) quantize weights when applicable
+            unquantized_weight = self.weight.data.clone()
+            self.weight.data = _maybe_calibrate_or_quantize(
+                module, self.weight, "weight", scheme.weights
+            )
+
+        # perform wrapped forward call
+        output = forward_func_orig.__get__(module, module.__class__)(
+            input_, *args[1:], **kwargs
+        )
+
+        if scheme.output_activations is not None:
+            # calibrate and (fake) quantize output activations when applicable
+            output = _maybe_calibrate_or_quantize(
+                module, output, "output", scheme.output_activations
+            )
+
+        # restore back to unquantized_value
+        if scheme.weights is not None:
+            self.weight.data = unquantized_weight
+
+        return output
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
+def _maybe_calibrate_or_quantize(
+    module: Module, value: Module, base_name: str, args: "QuantizationArgs"
+) -> torch.Tensor:
+    # only run quantized for the included stages
+    if module.quantization_status not in {
+        QuantizationStatus.CALIBRATION,
+        QuantizationStatus.FROZEN,
+    }:
+        return value
+
+    device = next(module.parameters()).device
+    scale = getattr(module, f"{base_name}_scale")
+    zero_point = getattr(module, f"{base_name}_zero_point")
+
+    if module.quantization_status == QuantizationStatus.CALIBRATION:
+        # get observer and get new quant params from observation
+        observer = getattr(module, f"{base_name}_observer")
+        updated_scale, updated_zero_point = observer(value)
+
+        # update scale and zero point
+        scale.data = updated_scale.to(device)
+        zero_point.data = updated_zero_point.to(device)
+
+    return fake_quantize(value, scale, zero_point, args)
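
As a worked example of the round trip above: with num_bits=8 the integer range is [-128, 127], and fake quantization rounds x / scale + zero_point into that range, then maps it back with (x_q - zero_point) * scale. The sketch below re-derives the same formulas standalone (quantize/dequantize/fake_quantize are module-level helpers not listed in __all__), and the scale and input values are illustrative.

    import torch

    x = torch.tensor([0.05, -1.3, 2.4])
    scale = torch.tensor(2.5 / 127)   # e.g. an observed abs-max of 2.5 mapped to 127
    zero_point = torch.tensor(0)
    q_min, q_max = torch.tensor(-128.0), torch.tensor(127.0)  # -(2**8)/2 .. (2**8)/2 - 1

    x_q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
    x_dq = (x_q - zero_point) * scale   # what fake_quantize returns
    print(x_q)    # ≈ tensor([  3., -66., 122.])
    print(x_dq)   # ≈ tensor([ 0.0591, -1.2992,  2.4016])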

compressed_tensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "freeze_module_quantization",
+]
+
+
+def freeze_module_quantization(module: Module):
+    """
+    deletes observers so static quantization is completed.
+
+    apply to full model with `model.apply(freeze_module_quantization)`
+
+    :param module: module to freeze quantization for
+    """
+    if not getattr(module, "quantization_scheme", None):
+        # no quantization scheme nothing to do
+        return
+
+    # delete observers from module
+    observer_names = []
+    for submodule_name, _ in module.named_modules():
+        if "." not in submodule_name and submodule_name.endswith("_observer"):
+            # delete any observers that belong directly to this module
+            observer_names.append(submodule_name)
+    for observer_name in observer_names:
+        delattr(module, observer_name)
+
+    module.quantization_status = QuantizationStatus.FROZEN

compressed_tensors/quantization/lifecycle/initialize.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import Optional
+
+import torch
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from torch.nn import Module, Parameter
+
+
+__all__ = [
+    "initialize_module_for_quantization",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def initialize_module_for_quantization(
+    module: Module,
+    scheme: Optional[QuantizationScheme] = None,
+):
+    """
+    attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme
+
+    apply to full model with `model.apply(initialize_module_for_quantization)`
+
+    :param module: module to set for calibration
+    :param scheme: scheme to use for quantization. if None is provided,
+        will attempt to use scheme stored in the module under `quantization_scheme`,
+        if not provided, the layer will be skipped
+    """
+    scheme = scheme or getattr(module, "quantization_scheme", None)
+    if scheme is None:
+        # no scheme passed and layer not targeted for quantization - skip
+        return
+
+    if scheme.input_activations is not None:
+        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+    if scheme.weights is not None:
+        if hasattr(module, "weight"):
+            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+        else:
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for weight quantization but "
+                "has no attribute weight, skipping weight quantization "
+                f"for {type(module)}"
+            )
+    if scheme.output_activations is not None:
+        _initialize_scale_zero_point_observer(
+            module, "output", scheme.output_activations
+        )
+
+    module.quantization_scheme = scheme
+    module.quantization_status = QuantizationStatus.INITIALIZED
+
+    # wrap forward call of module to perform quantized actions based on calltime status
+    wrap_module_forward_quantized(module, scheme)
+
+
+def _initialize_scale_zero_point_observer(
+    module: Module, base_name: str, quantization_args: QuantizationArgs
+):
+    device = next(module.parameters()).device
+
+    # initializes empty scale and zero point parameters for the module
+    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+    module.register_parameter(f"{base_name}_scale", init_scale)
+
+    init_zero_point = Parameter(
+        torch.empty(0, device=device, dtype=int), requires_grad=False
+    )
+    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
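
Putting the lifecycle pieces together, a minimal sketch of quantizing a single layer end to end. It assumes QuantizationScheme and QuantizationArgs (defined in quant_scheme.py and quant_args.py, whose contents are not shown in this diff) can be constructed with just `targets` and `num_bits` and that the default observer on QuantizationArgs resolves via get_observer(); treat the exact constructor arguments as assumptions rather than the library's documented API.

    import torch
    from torch.nn import Linear
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationScheme,
        freeze_module_quantization,
        initialize_module_for_quantization,
        set_module_for_calibration,
    )

    layer = Linear(64, 64)
    scheme = QuantizationScheme(targets=["Linear"], weights=QuantizationArgs(num_bits=8))

    initialize_module_for_quantization(layer, scheme)  # attaches weight_scale / weight_zero_point / weight_observer
    set_module_for_calibration(layer)                  # observer now updates qparams on each forward pass
    layer(torch.randn(4, 64))                          # calibration pass through the wrapped forward
    freeze_module_quantization(layer)                  # drops the observer; scale and zero point are now fixed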

compressed_tensors/quantization/observers/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .helpers import *
+from .base import *
+from .memoryless import *
+from .min_max import *

compressed_tensors/quantization/observers/base.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.registry.registry import RegistryMixin
+from torch import FloatTensor, IntTensor, Tensor
+from torch.nn import Module
+
+
+__all__ = ["Observer"]
+
+
+class Observer(Module, RegistryMixin):
+    """
+    Base Observer class to be subclassed for specific implementation.
+    Subclasses should override `calculate_qparams` to return a scale, zero_point
+    pair
+    """
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        self.quantization_args: QuantizationArgs = quantization_args
+        super().__init__()
+        self._scale = None
+        self._zero_point = None
+
+    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        maps directly to get_qparams
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        return self.get_qparams(observed=observed)
+
+    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+        """
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+
+    def get_qparams(
+        self, observed: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Convenience function to wrap overwritten calculate_qparams
+        adds support to make observed tensor optional and support for tracking latest
+        calculated scale and zero point
+        :param observed: optional observed tensor to calculate quantization parameters
+            from
+        :return: tuple of scale and zero point based on last observed value
+        """
+        if observed is not None:
+            # re-calculate scale and zero point, update the stored value
+            self._scale, self._zero_point = self.calculate_qparams(observed)
+        return self._scale, self._zero_point
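
To round out the observer interface, a sketch of a custom observer. Only the calculate_qparams override is prescribed by the base class above; the @Observer.register(name=...) decorator mirrors the RegistryMixin pattern used for CompressionConfig earlier in this diff, and the "absmax" name and the scale/zero-point shapes are illustrative assumptions.

    import torch
    from compressed_tensors.quantization.observers import Observer

    @Observer.register(name="absmax")  # hypothetical registry name
    class AbsMaxObserver(Observer):
        """Toy observer: symmetric scale from the absolute max of the observed tensor."""

        def calculate_qparams(self, observed: torch.Tensor):
            num_bits = getattr(self.quantization_args, "num_bits", 8)
            # map the largest magnitude onto the top of the signed integer range
            scale = observed.detach().abs().max().reshape(1) / (2 ** (num_bits - 1) - 1)
            zero_point = torch.zeros(1, dtype=torch.int64)
            return scale, zero_point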