compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from enum import Enum
|
|
16
16
|
from typing import Dict, List, Optional, Union
|
17
17
|
|
18
18
|
from compressed_tensors.config import CompressionFormat
|
19
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
19
20
|
from compressed_tensors.quantization.quant_scheme import (
|
20
21
|
QuantizationScheme,
|
21
22
|
preset_name_to_scheme,
|
@@ -25,6 +26,7 @@ from compressed_tensors.quantization.utils import (
|
|
25
26
|
is_module_quantized,
|
26
27
|
iter_named_leaf_modules,
|
27
28
|
module_type,
|
29
|
+
parse_out_kv_cache_args,
|
28
30
|
)
|
29
31
|
from pydantic import BaseModel, Field
|
30
32
|
from torch.nn import Module
|
@@ -117,7 +119,18 @@ class QuantizationConfig(BaseModel):
|
|
117
119
|
other quantization configs
|
118
120
|
:param format: specifies how the quantized model is stored on disk
|
119
121
|
:quantization_status: specifies the current status of all quantized layers. It is
|
120
|
-
|
122
|
+
assumed all layers are in the same state.
|
123
|
+
:param kv_cache_scheme: optional QuantizationArgs, that specify the
|
124
|
+
quantization of the kv cache. If None, kv cache is not quantized.
|
125
|
+
When applying kv cache quantization to transformer AutoModelForCausalLM,
|
126
|
+
the kv_cache_scheme gets converted into a QuantizationScheme that:
|
127
|
+
- targets the `k_proj` and `v_proj` modules of the model. The outputs
|
128
|
+
of those modules are the keys and values that might be cached
|
129
|
+
- quantizes the outputs of the aforementioned layers, so that
|
130
|
+
keys and values are compressed before storing them in the cache
|
131
|
+
There is an explicit assumption that the model contains modules with
|
132
|
+
`k_proj` and `v_proj` in their names. If this is not the case
|
133
|
+
and kv_cache_scheme != None, the quantization of kv cache will fail
|
121
134
|
:global_compression_ratio: optional informational config to report the model
|
122
135
|
compression ratio achieved by the quantization config
|
123
136
|
:ignore: optional list of layers to ignore from config_groups. Layers in this list
|
@@ -126,6 +139,7 @@ class QuantizationConfig(BaseModel):
|
|
126
139
|
|
127
140
|
config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
|
128
141
|
quant_method: str = DEFAULT_QUANTIZATION_METHOD
|
142
|
+
kv_cache_scheme: Optional[QuantizationArgs] = None
|
129
143
|
format: str = DEFAULT_QUANTIZATION_FORMAT
|
130
144
|
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
|
131
145
|
global_compression_ratio: Optional[float] = None
|
@@ -154,7 +168,7 @@ class QuantizationConfig(BaseModel):
|
|
154
168
|
) -> Optional["QuantizationConfig"]:
|
155
169
|
"""
|
156
170
|
Converts a model into its associated QuantizationConfig based on the
|
157
|
-
QuantizationScheme attached to each
|
171
|
+
QuantizationScheme attached to each quantized module
|
158
172
|
|
159
173
|
:param model: model to calculate quantization scheme of
|
160
174
|
:return: filled out QuantizationScheme for the input model
|
@@ -195,6 +209,13 @@ class QuantizationConfig(BaseModel):
|
|
195
209
|
# else we leave it off the ignore list, doesn't fall under any of the
|
196
210
|
# existing quantization schemes so it won't be quantized
|
197
211
|
|
212
|
+
kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
|
213
|
+
quant_scheme_to_layers
|
214
|
+
)
|
215
|
+
kv_cache_scheme = (
|
216
|
+
kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
|
217
|
+
)
|
218
|
+
|
198
219
|
config_groups = {}
|
199
220
|
for idx, scheme in enumerate(quant_scheme_to_layers):
|
200
221
|
group_name = "group_" + str(idx)
|
@@ -213,7 +234,19 @@ class QuantizationConfig(BaseModel):
|
|
213
234
|
return QuantizationConfig(
|
214
235
|
config_groups=config_groups,
|
215
236
|
quantization_status=quantization_status,
|
237
|
+
kv_cache_scheme=kv_cache_scheme,
|
216
238
|
global_compression_ratio=compression_ratio,
|
217
239
|
format=format,
|
218
240
|
ignore=consolidated_ignore,
|
219
241
|
)
|
242
|
+
|
243
|
+
def requires_calibration_data(self) -> bool:
    """
    Check whether applying this config needs calibration data, i.e. whether any
    activation quantization is static (not computed dynamically at runtime).

    Static input or output activations in any config group require calibration,
    and so does a static kv cache scheme. When a config is built from a model,
    the kv cache output-activation args are parsed out of the config groups and
    stored in ``kv_cache_scheme``, so they must be checked separately here.

    :return: True if calibration data is required, False otherwise
    """

    def _is_static(args) -> bool:
        # args is a QuantizationArgs-like object (has `.dynamic`) or None
        return args is not None and not args.dynamic

    if _is_static(self.kv_cache_scheme):
        return True

    return any(
        _is_static(scheme.input_activations)
        or _is_static(scheme.output_activations)
        for scheme in self.config_groups.values()
    )
|
@@ -15,7 +15,11 @@
|
|
15
15
|
from copy import deepcopy
|
16
16
|
from typing import List, Optional
|
17
17
|
|
18
|
-
from compressed_tensors.quantization.quant_args import
|
18
|
+
from compressed_tensors.quantization.quant_args import (
|
19
|
+
QuantizationArgs,
|
20
|
+
QuantizationStrategy,
|
21
|
+
QuantizationType,
|
22
|
+
)
|
19
23
|
from pydantic import BaseModel
|
20
24
|
|
21
25
|
|
@@ -53,15 +57,9 @@ class QuantizationScheme(BaseModel):
|
|
53
57
|
# default to quantizing all Linear layers
|
54
58
|
targets = ["Linear"]
|
55
59
|
|
56
|
-
# default
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
# default to 8 bit integer asymmetric quantization
|
61
|
-
input_activations = QuantizationArgs(num_bits=8, symmetric=True)
|
62
|
-
|
63
|
-
# Do not quantize the output activations
|
64
|
-
# by default
|
60
|
+
# by default, activations and weights are left unquantized
|
61
|
+
weights = None
|
62
|
+
input_activations = None
|
65
63
|
output_activations = None
|
66
64
|
|
67
65
|
return cls(
|
@@ -107,13 +105,114 @@ def is_preset_scheme(name: str) -> bool:
|
|
107
105
|
return name.upper() in PRESET_SCHEMES
|
108
106
|
|
109
107
|
|
108
|
+
# No-op preset: weights and activations are all left unquantized
UNQUANTIZED = {}

# 8 bit integer weights and 8 bit activations quantization
W8A8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
    ),
}

# 8 bit integer weights only quantization
W8A16 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
}

# 4 bit integer weights only quantization, in groups of 128
W4A16 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
        symmetric=True,
        dynamic=False,
    ),
}

# 4 bit integer weights (groups of 128) and 8 bit dynamic activations quantization
W4A8 = {
    "weights": QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        group_size=128,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
    ),
}

# FP8 weights and static (per-tensor) FP8 activations quantization
FP8 = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR,
        symmetric=True,
        dynamic=False,
    ),
}

# FP8 per-channel weights and dynamic per-token FP8 activations quantization
FP8_DYNAMIC = {
    "weights": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    ),
    "input_activations": QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=True,
        dynamic=True,
    ),
}

# Lookup table from preset name to its scheme arguments
PRESET_SCHEMES = {
    # Unquantized (no-op)
    "UNQUANTIZED": UNQUANTIZED,
    # Integer weight only schemes
    "W8A16": W8A16,
    "W4A16": W4A16,
    # Integer weight and activation schemes
    "W8A8": W8A8,
    "W4A8": W4A8,
    # Float weight and activation schemes
    "FP8": FP8,
    "FP8_DYNAMIC": FP8_DYNAMIC,
}
|
@@ -13,10 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import logging
|
16
|
-
|
16
|
+
import re
|
17
|
+
from typing import List, Optional, Tuple
|
17
18
|
|
18
19
|
import torch
|
19
20
|
from compressed_tensors.quantization.observers.base import Observer
|
21
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
22
|
+
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
20
23
|
from torch.nn import Module
|
21
24
|
from tqdm import tqdm
|
22
25
|
|
@@ -30,8 +33,12 @@ __all__ = [
|
|
30
33
|
"calculate_compression_ratio",
|
31
34
|
"get_torch_bit_depth",
|
32
35
|
"can_quantize",
|
36
|
+
"parse_out_kv_cache_args",
|
37
|
+
"KV_CACHE_TARGETS",
|
38
|
+
"is_kv_cache_quant_scheme",
|
33
39
|
]
|
34
40
|
|
41
|
+
KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
|
35
42
|
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
36
43
|
|
37
44
|
|
@@ -174,7 +181,7 @@ def calculate_compression_ratio(model: Module) -> float:
|
|
174
181
|
for parameter in model.parameters():
|
175
182
|
uncompressed_bits = get_torch_bit_depth(parameter)
|
176
183
|
compressed_bits = uncompressed_bits
|
177
|
-
if is_module_quantized(submodule):
|
184
|
+
if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
|
178
185
|
compressed_bits = submodule.quantization_scheme.weights.num_bits
|
179
186
|
|
180
187
|
num_weights = parameter.numel()
|
@@ -182,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
|
|
182
189
|
total_uncompressed += uncompressed_bits * num_weights
|
183
190
|
|
184
191
|
return total_uncompressed / total_compressed
|
192
|
+
|
193
|
+
|
194
|
+
def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
    """
    Check whether the QuantizationScheme targets the kv cache.
    It does if all the following criteria are met:
    - the scheme targets either exactly match the KV_CACHE_TARGETS
        or (for a single target) match one of the KV_CACHE_TARGETS regex patterns
    - the scheme quantizes output_activations (we want to quantize the
        outputs from the KV_CACHE_TARGETS, as they correspond to the
        keys and values that are to be saved in the cache)

    :param scheme: The QuantizationScheme to investigate
    :return: boolean flag
    """
    if len(scheme.targets) == 1:
        # a single target is matched against each KV_CACHE_TARGETS regex;
        # the pattern strings carry a "re:" prefix that is stripped first
        target = scheme.targets[0]
        is_match_targets = any(
            re.match(pattern[len("re:") :], target) for pattern in KV_CACHE_TARGETS
        )
    else:
        # several targets must be exactly the KV_CACHE_TARGETS, in any order
        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)

    is_match_output_activations = scheme.output_activations is not None
    return is_match_targets and is_match_output_activations
|
220
|
+
|
221
|
+
|
222
|
+
def parse_out_kv_cache_args(
    quant_scheme_to_layers: List[QuantizationScheme],
) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
    """
    If possible, parse out the kv cache specific QuantizationArgs
    from the list of the QuantizationSchemes. If no kv cache
    specific QuantizationArgs are available, this function acts
    as an identity function.

    Every scheme recognized by ``is_kv_cache_quant_scheme`` is removed from the
    returned list; only the output activations of the first such scheme are
    returned as the kv cache args.

    :param quant_scheme_to_layers: list of QuantizationSchemes
    :return: kv_cache_args (optional) and the (remaining or original)
        list of the QuantizationSchemes
    """
    kv_cache_schemes = []
    remaining_schemes = []
    # classify each scheme once instead of running the predicate twice
    for scheme in quant_scheme_to_layers:
        if is_kv_cache_quant_scheme(scheme):
            kv_cache_schemes.append(scheme)
        else:
            remaining_schemes.append(scheme)

    kv_cache_args = (
        kv_cache_schemes[0].output_activations if kv_cache_schemes else None
    )

    return kv_cache_args, remaining_schemes
|
@@ -12,13 +12,18 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
15
|
from typing import Optional
|
17
16
|
|
17
|
+
import torch
|
18
18
|
from transformers import AutoConfig
|
19
19
|
|
20
20
|
|
21
|
-
__all__ = [
|
21
|
+
__all__ = [
|
22
|
+
"infer_compressor_from_model_config",
|
23
|
+
"fix_fsdp_module_name",
|
24
|
+
"tensor_follows_mask_structure",
|
25
|
+
"replace_module",
|
26
|
+
]
|
22
27
|
|
23
28
|
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
|
24
29
|
|
@@ -61,3 +66,40 @@ def fix_fsdp_module_name(name: str) -> str:
|
|
61
66
|
return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
|
62
67
|
"." + FSDP_WRAPPER_NAME, ""
|
63
68
|
)
|
69
|
+
|
70
|
+
|
71
|
+
def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    Check that every consecutive chunk of ``m`` elements of ``tensor`` (in
    flattened, row-major order) contains at least ``n`` zeros, where the mask
    is given as ``"n:m"``.

    Some weights can incidentally be zero, so the check requires at least
    ``n`` zeros per chunk rather than exactly ``n``.

    :param tensor: tensor to check; its number of elements must be divisible
        by ``m``. Non-contiguous tensors (e.g. transposed views) are accepted.
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if any chunk contains fewer than ``n`` zeros
    """
    n, m = tuple(map(int, mask.split(":")))
    # reshape (unlike view) also works on non-contiguous tensors; it returns a
    # view when possible and only copies otherwise
    chunks = tensor.reshape(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (chunks == 0).sum(dim=1)

    # ">=" rather than "==" because some weights can incidentally be zero
    if not torch.all(zero_counts >= n).item():
        raise ValueError(
            f"Tensor does not follow the {mask} sparsity mask structure: "
            f"found chunks of {m} elements with fewer than {n} zeros"
        )

    return True
|
94
|
+
|
95
|
+
|
96
|
+
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    """
    Replace the submodule of ``model`` at the dotted path ``name`` with
    ``new_module``.

    :param model: root module containing the submodule to replace
    :param name: dotted path of the submodule relative to ``model``
        (e.g. ``"layers.0.k_proj"``); a name without dots refers to a direct
        child of ``model``
    :param new_module: module to install in place of the existing one
    """
    # "a.b.c" -> ("a.b", ".", "c"); "c" -> ("", "", "c")
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
|
@@ -0,0 +1,116 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
from torch.nn import Module
|
17
|
+
|
18
|
+
|
19
|
+
__all__ = [
|
20
|
+
"is_module_offloaded",
|
21
|
+
"get_execution_device",
|
22
|
+
"get_offloaded_device",
|
23
|
+
"update_prefix_dict",
|
24
|
+
"update_parameter_data",
|
25
|
+
]
|
26
|
+
|
27
|
+
|
28
|
+
def is_module_offloaded(module: Module) -> bool:
    """
    Report whether ``module`` is offloaded.

    :param module: layer to check
    :return: False when the layer has no ``_hf_hook`` attribute; otherwise the
        value of that hook's ``offload`` flag
    """
    if not hasattr(module, "_hf_hook"):
        return False
    return module._hf_hook.offload
|
34
|
+
|
35
|
+
|
36
|
+
def get_execution_device(module: Module) -> torch.device:
    """
    Resolve the device ``module`` runs on during its forward pass.

    An offloaded module reports its hook's ``execution_device``. Otherwise the
    device of the module's first parameter is used, unless that parameter sits
    on the ``meta`` device, in which case the hook's ``execution_device`` is
    consulted after all.

    :param module: layer to check
    :return: device layer is loaded onto during forward pass
    """
    if is_module_offloaded(module):
        return module._hf_hook.execution_device

    param_device = next(module.parameters()).device

    # the offload flag is only set on leaf modules, so a module whose weights
    # are on the meta device is resolved through its hook instead
    if param_device.type != "meta":
        return param_device
    return module._hf_hook.execution_device
|
50
|
+
|
51
|
+
|
52
|
+
def get_offloaded_device(module: Module) -> torch.device:
    """
    Resolve the device ``module``'s weights are kept on outside the forward
    pass.

    :param module: layer to check
    :return: for an offloaded layer, the device of the tensor stored under the
        first key of its hook's ``weights_map``; otherwise the device of the
        layer's first parameter
    """
    if not is_module_offloaded(module):
        return next(module.parameters()).device

    weights_map = module._hf_hook.weights_map
    first_key = list(weights_map.keys())[0]
    return weights_map.dataset[first_key].device
|
62
|
+
|
63
|
+
|
64
|
+
def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
    """
    Overwrite the offloaded copy of parameter ``key`` of ``module`` with
    ``data``.

    Parameter updates made to an offloaded module do not persist automatically
    between loads, so the offloaded state dict has to be patched explicitly.
    Only the offloaded state dict is written; the currently loaded module
    state is left untouched.

    :param module: layer containing the parameter to update
    :param key: name of the parameter to update
    :param data: tensor to store for that parameter in the offloaded state dict
    :raises ValueError: if ``module`` is not offloaded
    """
    if is_module_offloaded(module):
        weights_map = module._hf_hook.weights_map
        # entries in the underlying dataset are keyed by "<prefix><name>"
        weights_map.dataset[f"{weights_map.prefix}{key}"] = data
        return

    raise ValueError("Prefix dict is only applicable to offloaded modules")
|
79
|
+
|
80
|
+
|
81
|
+
def update_parameter_data(
    module: Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Replace the value of ``module.<param_name>`` with ``new_param_data``.

    The loaded parameter is updated in place, cast to its current dtype and
    moved to the device of the module's first parameter. If the module is
    offloaded, its offloaded state dict is updated too (cast to the same dtype
    and moved to the offload device), because parameter updates for offloaded
    modules do not persist automatically between loads.

    Returns silently when ``module`` has no attribute named ``param_name``.

    :param module: layer containing the parameter to update
    :param new_param_data: tensor to update parameter with
    :param param_name: name of layer parameter to update
    :raises ValueError: if the attribute exists but is None
    """
    if not hasattr(module, param_name):
        return

    target_device = next(module.parameters()).device

    offloaded = False
    if is_module_offloaded(module):
        offload_device = get_offloaded_device(module)
        offloaded = True

    param = getattr(module, param_name, None)
    if param is None:
        raise ValueError("Attempted to update uninitialized parameter")

    param_dtype = param.dtype
    param.data = new_param_data.to(target_device).to(param_dtype)

    if offloaded:
        # keep the offloaded state dict in sync, otherwise the update would be
        # lost the next time the weights are loaded from it
        update_prefix_dict(
            module, param_name, new_param_data.to(offload_device).to(param_dtype)
        )
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Set, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
|
19
|
+
|
20
|
+
__all__ = ["safe_permute"]
|
21
|
+
|
22
|
+
|
23
|
+
# (dtype, device) pairs on which advanced indexing has raised a RuntimeError,
# e.g. because an implementation required for standard permutation is missing.
# Later calls for these pairs go straight to the element-wise fallback.
_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()


def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """
    Perform out-of-place permutation without using torch.Tensor.index_put_,
    whose implementation is missing for datatypes such as `torch.float8_e4m3fn`

    :param value: tensor to permute
    :param perm: permutation map
    :param dim: dimension along which to apply permutation. Negative values
        count from the last dimension, as in other torch APIs
    :return: permuted value
    :raises IndexError: if a negative ``dim`` is out of range for ``value``
    """
    # The slice-prefix construction below (and in the fallback) assumes a
    # non-negative dim; a raw negative dim would silently permute dim 0.
    if dim < 0:
        normalized_dim = dim + value.dim()
        if normalized_dim < 0:
            raise IndexError(
                f"dim {dim} is out of range for a tensor with {value.dim()} dims"
            )
        dim = normalized_dim

    dtype_tuple = (value.dtype, value.device)

    if dtype_tuple in _EXPERIMENTAL_DTYPES:
        return _fallback_permute(value, perm, dim)

    try:
        # full slices on every leading dim, `perm` on `dim`
        return value[tuple([slice(None)] * dim + [perm])]
    except RuntimeError:
        # Mark dtype as experimental if advanced indexing fails
        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
        return _fallback_permute(value, perm, dim)


def _fallback_permute(
    value: torch.Tensor, perm: torch.Tensor, dim: int
) -> torch.Tensor:
    """
    Fallback permutation method for experimental dtypes: copies one slice
    along ``dim`` at a time instead of using advanced indexing.

    :param value: tensor to permute
    :param perm: permutation map
    :param dim: non-negative dimension along which to apply permutation
    :return: permuted value
    """
    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
    orig_slices = [slice(None)] * (dim + 1)
    perm_slices = [slice(None)] * (dim + 1)

    for index, perm_index in enumerate(perm):
        # slice `index` along dim <- slice `perm[index]` along dim
        orig_slices[dim] = index
        perm_slices[dim] = perm_index
        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]

    return value_ret
|
compressed_tensors/version.py
CHANGED