compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +137 -9
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/config/base.py +6 -1
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +63 -9
- compressed_tensors/quantization/lifecycle/calibration.py +7 -7
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +126 -44
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -20
- compressed_tensors/quantization/lifecycle/initialize.py +138 -55
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +54 -14
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +102 -24
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +12 -13
- compressed_tensors/quantization/utils/helpers.py +44 -19
- compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors/utils/helpers.py +30 -1
- compressed_tensors/utils/offload.py +14 -2
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/utils/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/pack_quantized.py +0 -219
- compressed_tensors-0.5.0.dist-info/RECORD +0 -48
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_args.py CHANGED

@@ -13,10 +13,10 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
-from pydantic import BaseModel, Field,
+from pydantic import BaseModel, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -25,6 +25,7 @@ __all__ = [
     "QuantizationStrategy",
     "QuantizationArgs",
     "round_to_quantized_type",
+    "ActivationOrdering",
 ]
 
 FP8_DTYPE = torch.float8_e4m3fn
@@ -51,6 +52,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+    accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
@@ -68,15 +82,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         ranges will be observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization will change the default
         observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=(
@@ -98,41 +115,102 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
-        if self.
+        if self.dynamic:
             # override defualt observer for dynamic, you never want minmax which
             # keeps state across samples for dynamic
             self.observer = "memoryless"
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
-
-
-
-
-        # use group_size to determinine strategy if not given explicity
-        if group_size is not None and value is None:
-            if group_size > 0:
-                return QuantizationStrategy.GROUP
+    def get_kv_cache(self):
+        """Get the singleton KV Cache"""
+        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
 
-
-                return QuantizationStrategy.CHANNEL
+        return QuantizedKVParameterCache(self)
 
-
-
-
-
-                    "group_size = -1 for 'channel'"
-                )
+    @field_validator("type", mode="before")
+    def validate_type(cls, value) -> QuantizationType:
+        if isinstance(value, str):
+            return QuantizationType(value.lower())
 
-
-            if group_size is None:
-                raise ValueError(f"strategy {value} requires group_size to be set.")
+        return value
 
+    @field_validator("group_size", mode="before")
+    def validate_group(cls, value) -> Union[int, None]:
         if value is None:
-            return
+            return value
+
+        if value < -1:
+            raise ValueError(
+                f"Invalid group size {value}. Use group_size > 0 for "
+                "strategy='group' and group_size = -1 for 'channel'"
+            )
+
+        return value
+
+    @field_validator("strategy", mode="before")
+    def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
+        if isinstance(value, str):
+            return QuantizationStrategy(value.lower())
+
+        return value
+
+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
 
         return value
 
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        # extract user-passed values from dictionary
+        strategy = model.strategy
+        group_size = model.group_size
+        actorder = model.actorder
+
+        # infer strategy
+        if strategy is None:
+            if group_size is None:
+                strategy = QuantizationStrategy.TENSOR
+            elif group_size > 0:
+                strategy = QuantizationStrategy.GROUP
+            elif group_size == -1:
+                strategy = QuantizationStrategy.CHANNEL
+            else:
+                raise ValueError(
+                    f"Invalid group size {group_size}. Use group_size > 0 for "
+                    "strategy='group' and group_size = -1 for 'channel'"
+                )
+
+        # validate strategy and group
+        if strategy == QuantizationStrategy.GROUP:
+            if group_size is None or group_size <= 0:
+                raise ValueError(
+                    f"strategy {strategy} requires group_size to be "
+                    "set to a positive value"
+                )
+        if (
+            group_size is not None
+            and group_size > 0
+            and strategy != QuantizationStrategy.GROUP
+        ):
+            raise ValueError("group_size requires strategy to be set to 'group'")
+
+        # validate activation ordering and strategy
+        if actorder is not None and strategy != QuantizationStrategy.GROUP:
+            raise ValueError(
+                "Must use group quantization strategy in order to apply "
+                "activation ordering"
+            )
+
+        # write back modified values
+        model.strategy = strategy
+        return model
+
     def pytorch_dtype(self) -> torch.dtype:
         if self.type == QuantizationType.FLOAT:
             return FP8_DTYPE

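For orientation only: a minimal sketch of how the reworked QuantizationArgs validation behaves, assuming nothing beyond the hunks above (values and printed expectations are illustrative, not captured output).

from compressed_tensors.quantization.quant_args import (
    ActivationOrdering,
    QuantizationArgs,
    QuantizationStrategy,
)

# validate_model_after infers the strategy from group_size:
# group_size > 0 -> GROUP, group_size == -1 -> CHANNEL, unset -> TENSOR
args = QuantizationArgs(num_bits=4, group_size=128, actorder=True)
print(args.strategy == QuantizationStrategy.GROUP)  # True
print(args.actorder == ActivationOrdering.GROUP)    # True, bool True normalizes to GROUP

# activation ordering without group quantization is rejected by the model validator
try:
    QuantizationArgs(num_bits=8, group_size=-1, actorder="weight")
except Exception as err:
    print(type(err).__name__)  # pydantic validation error wrapping the ValueError above
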
compressed_tensors/quantization/quant_config.py CHANGED

@@ -24,7 +24,7 @@ from compressed_tensors.quantization.quant_scheme import (
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
-
+    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,7 +177,9 @@ class QuantizationConfig(BaseModel):
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in
+        for name, submodule in iter_named_quantizable_modules(
+            model, include_children=True, include_attn=True
+        ):
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:
@@ -199,6 +201,13 @@ class QuantizationConfig(BaseModel):
         if len(quant_scheme_to_layers) == 0:  # No quantized layers
             return None
 
+        # kv-cache only, no weight/activation quantization
+        if (
+            len(quantization_type_names) == 1
+            and "attention" in list(quantization_type_names)[0].lower()
+        ):
+            quantization_type_names.add("Linear")
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
@@ -241,6 +250,9 @@ class QuantizationConfig(BaseModel):
         )
 
     def requires_calibration_data(self):
+        if self.kv_cache_scheme is not None:
+            return True
+
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
                 if not scheme.input_activations.dynamic:

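A hedged sketch of the new kv-cache handling in requires_calibration_data, assuming QuantizationConfig accepts an empty config_groups mapping alongside the kv_cache_scheme field referenced in the hunk above:

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme=QuantizationArgs(num_bits=8),
)

# kv-cache quantization needs calibration data, so this now short-circuits to True
print(config.requires_calibration_data())
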
compressed_tensors/quantization/quant_scheme.py CHANGED

@@ -57,15 +57,9 @@ class QuantizationScheme(BaseModel):
             # default to quantizing all Linear layers
             targets = ["Linear"]
 
-        # default
-
-
-
-        # default to 8 bit integer asymmetric quantization
-        input_activations = QuantizationArgs(num_bits=8, symmetric=True)
-
-        # Do not quantize the output activations
-        # by default
+        # by default, activations and weights are left unquantized
+        weights = None
+        input_activations = None
         output_activations = None
 
         return cls(
@@ -111,8 +105,10 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES
 
 
+UNQUANTIZED = dict()
+
 # 8 bit integer weights and 8 bit activations quantization
-
+INT8_W8A8 = dict(
     weights=QuantizationArgs(
         num_bits=8,
         type=QuantizationType.INT,
@@ -153,7 +149,7 @@ W4A16 = dict(
 )
 
 # 4 bit integer weights and 8 bit activations quantization
-
+INT8_W4A8 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.INT,
@@ -208,12 +204,15 @@ FP8_DYNAMIC = dict(
 )
 
 PRESET_SCHEMES = {
+    # Unquantized (no-op)
+    "UNQUANTIZED": UNQUANTIZED,
     # Integer weight only schemes
     "W8A16": W8A16,
     "W4A16": W4A16,
     # Integer weight and activation schemes
-    "W8A8":
-    "
+    "W8A8": INT8_W8A8,
+    "INT8": INT8_W8A8,  # alias for W8A8
+    "W4A8": INT8_W4A8,
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,

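For reference, a small sketch of the preset changes (the UNQUANTIZED preset and the INT8 alias), assuming PRESET_SCHEMES and is_preset_scheme are imported from the module shown above:

from compressed_tensors.quantization.quant_scheme import (
    PRESET_SCHEMES,
    is_preset_scheme,
)

print(is_preset_scheme("int8"))                          # True, lookup is case-insensitive
print(PRESET_SCHEMES["W8A8"] is PRESET_SCHEMES["INT8"])  # True, "INT8" aliases W8A8
print(PRESET_SCHEMES["UNQUANTIZED"])                     # {} -> nothing gets quantized
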
compressed_tensors/quantization/utils/helpers.py CHANGED

@@ -13,8 +13,7 @@
 # limitations under the License.
 
 import logging
-import
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@ __all__ = [
     "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
-    "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
     "get_torch_bit_depth",
@@ -36,9 +34,14 @@ __all__ = [
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
     "is_kv_cache_quant_scheme",
+    "iter_named_leaf_modules",
+    "iter_named_quantizable_modules",
 ]
 
-
+# target the self_attn layer
+# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
-def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
     themselves are not yielded
-
     :param model: model to get leaf modules of
     :returns: generator tuple of (name, leaf_submodule)
     """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
                 yield name, submodule
 
 
+def iter_named_quantizable_modules(
+    model: Module, include_children: bool = True, include_attn: bool = False
+) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yield name and submodule of
+    - leaf modules, set by include_children
+    - attention modyles, set by include_attn
+
+    :param model: model to get leaf modules of
+    :param include_children: flag to get the leaf modules
+    :param inlcude_attn: flag to get the attention modules
+    :returns: generator tuple of (name, submodule)
+    """
+    for name, submodule in model.named_modules():
+        if include_children:
+            children = list(submodule.children())
+            if len(children) == 0 and not isinstance(submodule, Observer):
+                yield name, submodule
+            else:
+                has_non_observer_children = False
+                for child in children:
+                    if not isinstance(child, Observer):
+                        has_non_observer_children = True
+
+                if not has_non_observer_children:
+                    yield name, submodule
+        if include_attn:
+            if name.endswith("self_attn"):
+                yield name, submodule
+
+
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor
@@ -181,7 +214,7 @@ def calculate_compression_ratio(model: Module) -> float:
         for parameter in model.parameters():
             uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
-            if is_module_quantized(submodule):
+            if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
 
             num_weights = parameter.numel()
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     :param scheme: The QuantizationScheme to investigate
     :return: boolean flag
     """
-
-
-
-        is_match_targets = any(
-            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-        )
-    else:
-        # match on the exact KV_CACHE_TARGETS
-        # if there are multiple targets
-        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+    for target in scheme.targets:
+        if target in KV_CACHE_TARGETS:
+            return True
 
-
-    return is_match_targets and is_match_output_activations
+    return False
 
 
 def parse_out_kv_cache_args(

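A minimal usage sketch of the new iterator; the import path mirrors the quant_config.py hunk above, and the toy model is purely illustrative:

import torch
from compressed_tensors.quantization.utils import iter_named_quantizable_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# include_children yields leaf modules (observers excluded); include_attn would also
# yield any submodule whose name ends with "self_attn"
for name, submodule in iter_named_quantizable_modules(
    model, include_children=True, include_attn=False
):
    print(name, type(submodule).__name__)  # e.g. "0 Linear", "1 ReLU"
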
compressed_tensors/utils/helpers.py CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from transformers import AutoConfig
@@ -22,6 +22,8 @@ __all__ = [
     "infer_compressor_from_model_config",
     "fix_fsdp_module_name",
     "tensor_follows_mask_structure",
+    "replace_module",
+    "is_compressed_tensors_config",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -90,3 +92,30 @@ def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
             raise ValueError()
 
     return True
+
+
+def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model
+        child_name = name
+    setattr(parent, child_name, new_module)
+
+
+def is_compressed_tensors_config(compression_config: Any) -> bool:
+    """
+    Returns True if CompressedTensorsConfig is available from transformers and
+    compression_config is an instance of CompressedTensorsConfig
+
+    See: https://github.com/huggingface/transformers/pull/31704
+    """
+    try:
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        return isinstance(compression_config, CompressedTensorsConfig)
+    except ImportError:
+        return False

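A short sketch of replace_module on a toy nested module (names and shapes are illustrative; the helper lives in the module shown above):

import torch
from compressed_tensors.utils.helpers import replace_module

model = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()),
    torch.nn.Linear(16, 4),
)

# dotted names are resolved through get_submodule before the child is swapped in place
replace_module(model, "0.0", torch.nn.Linear(16, 16, bias=False))
print(model[0][0])  # Linear(in_features=16, out_features=16, bias=False)
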
compressed_tensors/utils/offload.py CHANGED

@@ -40,7 +40,13 @@ def get_execution_device(module: Module) -> torch.device:
     """
     if is_module_offloaded(module):
         return module._hf_hook.execution_device
-
+    device = next(module.parameters()).device
+
+    # offload only gets set for leaf modules, fallback to checking for device type
+    if device.type == "meta":
+        return module._hf_hook.execution_device
+
+    return device
 
 
 def get_offloaded_device(module: Module) -> torch.device:
@@ -83,8 +89,11 @@ update_parameter_data(
 
     :param module: layer containing the parameter to update
     :param new_param_data: tensor to update parameter with
-    :param param_name:
+    :param param_name: name of layer parameter to update
     """
+    if not hasattr(module, param_name):
+        return
+
     device = next(module.parameters()).device
 
     offloaded = False
@@ -93,6 +102,9 @@
         offloaded = True
 
     parameter = getattr(module, param_name, None)
+    if parameter is None:
+        raise ValueError("Attempted to update uninitialized parameter")
+
     dtype = parameter.dtype
     parameter.data = new_param_data.to(device).to(dtype)
 

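A hedged sketch of the relaxed update_parameter_data behaviour for a plain, non-offloaded module, using the signature documented in the hunk above:

import torch
from compressed_tensors.utils.offload import update_parameter_data

layer = torch.nn.Linear(4, 4)

# copies the tensor into layer.weight, cast to the parameter's device and dtype
update_parameter_data(layer, torch.ones(4, 4), "weight")

# a parameter name the module does not have is now silently ignored
update_parameter_data(layer, torch.ones(4), "weight_scale")
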
compressed_tensors/utils/permute.py ADDED

@@ -0,0 +1,70 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Set, Tuple
+
+import torch
+
+
+__all__ = ["safe_permute"]
+
+
+# these datatypes are missing implementations required for standard permutation
+_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
+
+
+def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    Perform out-of-place permutation without using torch.Tensor.index_put_,
+    whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    dtype_tuple = (value.dtype, value.device)
+
+    if dtype_tuple in _EXPERIMENTAL_DTYPES:
+        return _fallback_permute(value, perm, dim)
+
+    try:
+        return value[tuple([slice(None)] * dim + [perm])]
+    except RuntimeError:
+        # Mark dtype as experimental if advanced indexing fails
+        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
+        return _fallback_permute(value, perm, dim)
+
+
+def _fallback_permute(
+    value: torch.Tensor, perm: torch.Tensor, dim: int
+) -> torch.Tensor:
+    """
+    Fallback permutation method for experimental dtypes.
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
+    orig_slices = [slice(None)] * (dim + 1)
+    perm_slices = [slice(None)] * (dim + 1)
+
+    for index, perm_index in enumerate(perm):
+        orig_slices[dim] = index
+        perm_slices[dim] = perm_index
+        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
+
+    return value_ret

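A minimal usage sketch of safe_permute; the fallback path only triggers for dtypes such as torch.float8_e4m3fn where advanced indexing raises:

import torch
from compressed_tensors.utils.permute import safe_permute

weight = torch.arange(12, dtype=torch.float32).reshape(4, 3)
perm = torch.tensor([2, 0, 1, 3])

# rows are reordered out-of-place according to perm
print(safe_permute(weight, perm, dim=0))
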
compressed_tensors/version.py CHANGED