compressed-tensors 0.10.1a20250604__py3-none-any.whl → 0.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +7 -1
- compressed_tensors/compressors/sparse_compressors/dense.py +19 -1
- compressed_tensors/quantization/lifecycle/apply.py +1 -3
- compressed_tensors/transform/__init__.py +5 -0
- compressed_tensors/transform/factory/__init__.py +13 -0
- compressed_tensors/transform/factory/base.py +164 -0
- compressed_tensors/transform/factory/hadamard.py +79 -0
- compressed_tensors/transform/factory/matrix_multiply.py +90 -0
- compressed_tensors/transform/factory/random_hadamard.py +34 -0
- compressed_tensors/transform/transform_args.py +18 -2
- compressed_tensors/transform/utils/__init__.py +13 -0
- compressed_tensors/transform/utils/hadamard.py +160 -0
- compressed_tensors/transform/utils/hadamards.safetensors +0 -0
- compressed_tensors/transform/utils/utils.py +91 -0
- compressed_tensors/utils/helpers.py +53 -0
- compressed_tensors/utils/offload.py +158 -71
- compressed_tensors/version.py +2 -2
- {compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/METADATA +1 -1
- {compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/RECORD +22 -13
- {compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressors/model_compressor.py
CHANGED
@@ -68,6 +68,10 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.compressors import BaseQuantizationCompressor
+
+
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -257,7 +261,9 @@ class ModelCompressor:
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, DenseCompressor]
+        ] = None
 
         if sparsity_config is not None:
            self.sparsity_compressor = BaseCompressor.load_from_registry(
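The `TYPE_CHECKING` import plus attribute annotation above is a standard way to type a field without paying for (or risking a cycle from) a runtime import. A generic sketch of the pattern, not the library's code; the package and class names below are hypothetical, and the forward references are quoted so nothing is evaluated at runtime:

from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # imported only for static analysis; never executed at runtime
    from mypackage.compressors import BaseQuantizationCompressor, DenseCompressor

class Holder:
    def __init__(self):
        # quoted forward reference: valid even though the names are unimported
        self.compressor: Optional[
            Union["BaseQuantizationCompressor", "DenseCompressor"]
        ] = None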
compressed_tensors/compressors/sparse_compressors/dense.py
CHANGED
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """
@@ -47,3 +52,16 @@ class DenseCompressor(BaseCompressor):
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
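Because `DenseCompressor` stores weights uncompressed, the new `decompress_module_from_state_dict` is a no-op copy; callers can invoke the same method on either compressor type without branching. A hedged sketch of such a call site (the `compressor` and `scheme` variables are hypothetical):

import torch

state_dict = {"weight": torch.randn(4, 4)}
# Works whether `compressor` is a BaseQuantizationCompressor or a DenseCompressor;
# in the dense case the input simply round-trips:
# decompressed = compressor.decompress_module_from_state_dict("layer.", state_dict, scheme)
decompressed = state_dict.copy()  # what the dense implementation returns
assert torch.equal(decompressed["weight"], state_dict["weight"])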
compressed_tensors/quantization/lifecycle/apply.py
CHANGED
@@ -183,9 +183,7 @@ def apply_quantization_config(
                 replace_module(model, name, compressed_linear)
 
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = _scheme_from_targets(
-                target_to_scheme, targets, name
-            )
+            submodule.quantization_scheme = scheme
 
             names_to_scheme[name] = submodule.quantization_scheme
 
compressed_tensors/transform/factory/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/transform/factory/base.py
ADDED
@@ -0,0 +1,164 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn.utils.parametrize as P
+from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
+from compressed_tensors.registry.registry import RegistryMixin, T
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformLocation,
+    TransformScheme,
+)
+from compressed_tensors.utils import (
+    align_module_device,
+    has_offloaded_params,
+    patch_attr,
+    register_offload_module,
+    update_offload_parameter,
+)
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+__all__ = ["TransformFactory", "TransformBase"]
+
+
+class TransformFactory(RegistryMixin, ABC):
+    """
+    Abstract factory base used to create and apply transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        self.name = name
+        self.scheme = scheme
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+
+    @classmethod
+    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
+        """
+        Create a transform factory from a scheme
+
+        :param scheme: defines how transforms should be created
+        :param kwargs: TransformFactory constructor arguments
+        :return: subclass of `TransformFactory` corresponding to the scheme type
+        """
+        constructor = cls.get_value_from_registry(name=scheme.type)
+        return constructor(scheme=scheme, **kwargs)
+
+    @abstractmethod
+    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
+        """
+        Abstract method which defines how a transform should be created. May utilize
+        caching to maximize shared memory
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        :return: instance of TransformBase
+        """
+        raise NotImplementedError()
+
+    def apply_to_model(self, model: Module):
+        """
+        Create transforms and apply them to the model
+
+        :param model: module to apply transforms to
+        """
+        for arg in self.scheme.apply:
+            for name, module in list(model.named_modules()):
+                if is_target(name, module, arg.targets, arg.ignore):
+                    self._apply_to_module(module, arg)
+
+    def _apply_to_module(self, module: Module, args: TransformArgs):
+        """
+        Create transforms and apply them to the module
+
+        :param module: target module to apply transforms to
+        :param args: defines how the transform will be applied to the target module
+        """
+        # create transform as submodule
+        transform_name = f"{self.name}_{args.location.value}"
+        transform = self.create_transform(module, args)
+        register_offload_module(module, transform_name, transform)  # (1)
+
+        # register input transformation hook
+        if args.location == TransformLocation.INPUT:
+
+            def input_hook(_, args):
+                input = args[0]
+                return transform(input)
+
+            module.register_forward_pre_hook(input_hook, prepend=True)
+
+        # eagerly apply transformation to weight
+        elif args.location in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        ):
+            assert isinstance(module, torch.nn.Linear)
+            assert module.bias is None
+
+            with torch.no_grad(), align_module_device(module):
+                update_offload_parameter(module, "weight", transform(module.weight))
+
+            if self.scheme.requires_grad:
+                # for training, the weight changes with every forward pass
+                # so we can leverage parametrization to propagate the gradient
+                if has_offloaded_params(module):
+                    raise ValueError("Offloaded training is not supported")
+                P.register_parametrization(module, "weight", transform)
+
+        # register output transformation hook
+        elif args.location == TransformLocation.OUTPUT:
+
+            def output_hook(_, _input, output):
+                return transform(output)
+
+            module.register_forward_hook(output_hook)
+
+        # other locations such as q_attn and k_attn have not been implemented
+        else:
+            raise NotImplementedError()
+
+    # (1) even in the `weight` cases, this submodule attachment is needed in order
+    # to support saving in the frozen state
+
+
+class TransformBase(Module, ABC):
+    """
+    Represents the application of a transform accord to TransformArgs
+    """
+
+    args: TransformArgs
+    weight: Parameter
+
+    @abstractmethod
+    def forward(self, value: Tensor) -> Tensor:
+        raise NotImplementedError()
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        with patch_attr(self.args, "inverse", not self.args.inverse):
+            return self.forward(value)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
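For orientation, here is a hedged sketch of how a registered factory is meant to be driven, assuming a `TransformScheme` whose `type` names a registered factory and whose `apply` list holds `TransformArgs` (those field names are taken from their usage above; anything beyond them is an assumption):

from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformFactory

scheme = TransformScheme(
    type="hadamard",  # resolved through the registry by from_scheme()
    apply=[TransformArgs(targets=["Linear"], location="weight_input")],
)
factory = TransformFactory.from_scheme(scheme, name="u")
# factory.apply_to_model(model)  # hooks or rewrites every matched submodule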
compressed_tensors/transform/factory/hadamard.py
ADDED
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("hadamard")
+class HadamardFactory(TransformFactory):
+    """
+    Factory used to apply hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a HadamardTransform for applying to a module. Transforms with the same
+        size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        return HadamardTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = deterministic_hadamard_matrix(size, dtype, device)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+
+class HadamardTransform(TransformBase):
+    def __init__(self, weight: Parameter, args: TransformArgs):
+        super().__init__()
+        self.weight = weight
+        self.args = args
+
+    def forward(self, value: Tensor) -> Tensor:
+        if not self.args.inverse:
+            weight = self.weight
+        else:
+            weight = self.weight.T
+
+        return apply_transform_weight(weight, value, self.args.location)
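`HadamardTransform.forward` uses `weight.T` for the inverse because the factory's weights are Hadamard matrices scaled by 1/sqrt(size), which are orthogonal. A minimal standalone check of that property (plain torch, independent of the library API):

import math
import torch

H = torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / math.sqrt(2)  # order-2 Hadamard
x = torch.randn(5, 2)
# applying H and then H.T recovers the input, so H.T acts as the inverse
assert torch.allclose((x @ H) @ H.T, x, atol=1e-6)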
compressed_tensors/transform/factory/matrix_multiply.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("random-matrix")
+class RandomMatrixFactory(TransformFactory):
+    """
+    Factory used to apply random matrix transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.inverses = ParameterizedDefaultDict(self._create_inverse)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a RandomMatrixTransform for applying to a module. Transforms with the
+        same size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        if args.inverse:
+            weight = self.inverses[weight]
+
+        return RandomMatrixTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = torch.rand(
+            (size, size), generator=self.generator, dtype=dtype, device=device
+        )
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+    def _create_inverse(self, weight: Parameter) -> Parameter:
+        data = high_precision_invert(weight.data)
+        return Parameter(data, requires_grad=False)
+
+
+class RandomMatrixTransform(TransformBase):
+    def __init__(self, weight: Tensor, args: TransformArgs):
+        super().__init__()
+        self.weight = weight  # is an inverse if args.inverse
+        self.args = args
+
+    def forward(self, value: Tensor) -> Parameter:
+        return apply_transform_weight(self.weight, value, self.args.location)
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        inverse = high_precision_invert(self.weight)
+        return apply_transform_weight(inverse, value, self.args.location)
+
+
+def high_precision_invert(weight: Tensor) -> Tensor:
+    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
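`high_precision_invert` inverts in float32 before casting back because direct inversion in half precision is numerically unstable. A standalone round-trip sketch (a well-conditioned matrix is chosen so the tolerance is meaningful):

import torch

W = torch.eye(8) + 0.1 * torch.rand(8, 8)  # well-conditioned random matrix
W_inv = torch.linalg.inv(W.to(torch.float32)).to(W.dtype)  # same op as above
x = torch.randn(2, 8)
# applying the transform and then its inverse recovers the input
assert torch.allclose(x @ W @ W_inv, x, atol=1e-4)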
compressed_tensors/transform/factory/random_hadamard.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from compressed_tensors.transform import HadamardFactory, TransformFactory
+from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
+from torch import device, dtype
+from torch.nn import Parameter
+
+
+@TransformFactory.register("random-hadamard")
+class RandomHadamardFactory(HadamardFactory):
+    """
+    Factory used to apply random hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = random_hadamard_matrix(size, dtype, device, self.generator)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
compressed_tensors/transform/transform_args.py
CHANGED
@@ -13,15 +13,31 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import
+from typing import List
 
 from pydantic import BaseModel, Field, field_validator
 
 
-__all__ = ["TransformArgs"]
+__all__ = ["TransformArgs", "TransformLocation"]
 
 
 class TransformLocation(str, Enum):
+    """
+    Enum representing which parameters/activations a transform weight should be applied
+    to on a given module.
+
+    | -------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    | Name            | Runtime     | Values       | Locations Where Inverse Could Be Applied                  |  # noqa: E501
+    | --------------- | ----------- | ------------ | --------------------------------------------------------- |  # noqa: E501
+    | `INPUT`         | online      | activations  | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT`  |  # noqa: E501
+    | `WEIGHT_INPUT`  | offline     | weight       | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT`         |  # noqa: E501
+    | `WEIGHT_OUTPUT` | offline     | weight       | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`          |  # noqa: E501
+    | `OUTPUT`        | online      | activations  | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`   |  # noqa: E501
+    | `K_CACHE`       | online      | key_values   | `q_proj.Q_ATTN`                                           |  # noqa: E501
+    | `Q_ATTN`        | online      | query_values | `k_proj.K_CACHE`                                          |  # noqa: E501
+    | -------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    """
+
     INPUT = "input"
     WEIGHT_INPUT = "weight_input"
     WEIGHT_OUTPUT = "weight_output"
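To make the table concrete, here is a hedged sketch of a pair of TransformArgs realizing the `this.INPUT` / `this.WEIGHT_INPUT` cancellation row, using only fields visible in this diff (`targets`, `location`, `inverse`); any other constructor arguments would be assumptions:

from compressed_tensors.transform import TransformArgs

# rotate activations entering a Linear layer...
forward_args = TransformArgs(targets=["Linear"], location="input")
# ...and fold the inverse rotation into that layer's weight, so the pair cancels
inverse_args = TransformArgs(targets=["Linear"], location="weight_input", inverse=True)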
compressed_tensors/transform/utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/transform/utils/hadamard.py
ADDED
@@ -0,0 +1,160 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from pathlib import Path
+from typing import Optional
+
+import torch
+from safetensors import safe_open
+
+
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
+
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """
+    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
+    `n` must be a power of 2.
+
+    Adapated from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
+    :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :return: hadamard matrix of size `size`
+    """
+    if size <= 0:
+        raise ValueError("Cannot construct deterministic hadamard of size <= 0")
+
+    log2 = int(math.log2(size))
+    if size != 2**log2:
+        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+
+    H = torch.tensor([[1]], dtype=dtype, device=device)
+
+    # Sylvester's construction
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
+
+    return H / math.sqrt(size)
+
+
+def random_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+    """
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non powers of 2
+    and randomization using a seeded generator
+
+    Adapated from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501
+
+    :param size: The dimension of the hamadard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :param gen: Optional generator random values
+    :return: randomly generated hadamard matrix
+    """
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
+    Q = Q * 2 - 1
+    Q = torch.diag(Q)
+    return _matmul_hadU(Q) / math.sqrt(size)
+
+
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
+
+
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
+
+    # Check if we have the determined hadamard matrix
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
+    # Reshape diag matrix with randomized -1/+1
+    input = X.clone().view(-1, size, 1)
+    output = input.clone()
+    while input.shape[1] > K:
+        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+        output = output.view(input.shape)
+        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+        output = output.view(input.shape[0], input.shape[1], -1)
+        (input, output) = (output, input)
+    assert input.shape[1] == K
+    del output
+
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input
+
+    # normalize
+    return input.view(X.shape)
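As a sanity check on the construction above: Sylvester doubling followed by the 1/sqrt(size) normalization yields an orthogonal matrix. A standalone sketch (plain torch, independent of the file's helpers):

import math
import torch

size = 8  # must be a power of 2 for Sylvester's construction
H = torch.tensor([[1.0]])
for _ in range(int(math.log2(size))):
    H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
H = H / math.sqrt(size)
# rows are mutually orthogonal unit vectors, so H @ H.T is the identity
assert torch.allclose(H @ H.T, torch.eye(size), atol=1e-6)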
compressed_tensors/transform/utils/hadamards.safetensors
ADDED
Binary file
compressed_tensors/transform/utils/utils.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_matrix_size", "apply_transform_weight"]
+
+
+def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
+    """
+    Determine the size of a matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :return: size of matrix
+    """
+    assert isinstance(module, torch.nn.Linear)
+    if location in ("input", TransformLocation.WEIGHT_INPUT):
+        return module.in_features
+    else:
+        return module.out_features
+
+
+def apply_transform_weight(
+    weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+) -> torch.Tensor:
+    """
+    Using the transform location, determine how to apply the transform weight to the
+    given value. For more info on input and output transforms, see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x  be input activation
+         W  be weight,
+         yh, xh, Wh be transformed output, input, weight
+
+    note that
+        y  = (x W.T)  // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi be transform matrices on input side
+         U, Ui be transform matrices on output side
+
+    pick xh = (x V)
+         Wh = (U.T W Vi.T)
+         yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)    // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param weight: transform weight to apply
+    :param value: value to apply weight to
+    :param location: determines how weight should be applied
+    :return: value after transform weight has been applied
+    """
+
+    if location == TransformLocation.INPUT:
+        return value @ weight
+
+    elif location == TransformLocation.WEIGHT_INPUT:
+        return value @ weight.T
+
+    elif location == TransformLocation.WEIGHT_OUTPUT:
+        return weight.T @ value
+
+    elif location == TransformLocation.OUTPUT:
+        return value @ weight
+
+    else:
+        raise NotImplementedError(f"{location} has not been implemented yet")
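The docstring derivation above can be checked numerically: for any invertible V and any U, transforming the input with V and the weight with U.T W Vi.T leaves the output unchanged except for the trailing U. A small float64 sketch:

import torch

x = torch.randn(2, 4, dtype=torch.float64)   # input activations
W = torch.randn(3, 4, dtype=torch.float64)   # Linear weight (out x in)
V = torch.randn(4, 4, dtype=torch.float64)   # input-side transform
U = torch.randn(3, 3, dtype=torch.float64)   # output-side transform
Vi = torch.linalg.inv(V)

xh = x @ V                # transformed input
Wh = U.T @ W @ Vi.T       # transformed weight
yh = xh @ Wh.T            # what the transformed Linear computes
assert torch.allclose(yh, (x @ W.T) @ U, atol=1e-8)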
compressed_tensors/utils/helpers.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -38,6 +39,8 @@ __all__ = [
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "patch_attr",
+    "ParameterizedDefaultDict",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +331,53 @@ def unpack_bitmasks(
     )
 
     return unpacked_bitmasks_torch
+
+
+@contextlib.contextmanager
+def patch_attr(base: object, attr: str, value: Any):
+    """
+    Patch the value of an object attribute. Original value is restored upon exit
+
+    :param base: object which has the attribute to patch
+    :param attr: name of the the attribute to patch
+    :param value: used to replace original value
+
+    Usage:
+    >>> from types import SimpleNamespace
+    >>> obj = SimpleNamespace()
+    >>> with patch_attr(obj, "attribute", "value"):
+    ...     assert obj.attribute == "value"
+    >>> assert not hasattr(obj, "attribute")
+    """
+    _sentinel = object()
+    original_value = getattr(base, attr, _sentinel)
+
+    setattr(base, attr, value)
+    try:
+        yield
+    finally:
+        if original_value is not _sentinel:
+            setattr(base, attr, original_value)
+        else:
+            delattr(base, attr)
+
+
+class ParameterizedDefaultDict(dict):
+    """
+    Similar to `collections.DefaultDict`, but upon fetching a key which is missing,
+    the key is passed as arguments to the `default_factory`
+
+    :param default_factory: function which takes a key as input and returns the
+        corresponding default value
+    """
+
+    def __init__(self, default_factory: Callable[[Any], Any]):
+        self.default_factory = default_factory
+
+    def __missing__(self, key):
+        if isinstance(key, tuple):
+            value = self.default_factory(*key)
+        else:
+            value = self.default_factory(key)
+        self[key] = value
+        return value
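The tuple-splatting behavior of `ParameterizedDefaultDict.__missing__` is what lets the transform factories above write `self.weights[size, dtype, device]` and lazily build one shared weight per combination. A small usage sketch:

from compressed_tensors.utils.helpers import ParameterizedDefaultDict

cache = ParameterizedDefaultDict(lambda size, dtype: f"weight({size}, {dtype})")
assert cache[64, "bf16"] == "weight(64, bf16)"  # tuple key splatted into factory
assert (64, "bf16") in cache                    # result memoized on first access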
compressed_tensors/utils/offload.py
CHANGED
@@ -14,27 +14,30 @@
 """
 Utilities associated with offloading functionality provided by `accelerate`.
 
-|
-| Operation
-|
-| Add
-| Check
-| Onload
-| Update
-| Delete
-|
+| ------------------------------------------------------------------------------------------------------ |  # noqa: E501
+| Operation  | Without offloading support             | With offloading support                          |  # noqa: E501
+| ---------- | -------------------------------------- | ------------------------------------------------ |  # noqa: E501
+| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  |  # noqa: E501
+| Check      | N/A                                    | has_offloaded_params(module)                     |  # noqa: E501
+| Onload     | N/A                                    | with align_module_device(module)                 |  # noqa: E501
+| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) |  # noqa: E501
+| Delete     | del module.name                        | delete_offload_parameter(module, name)           |  # noqa: E501
+| Add Module | module.register_module(name, child)    | register_offload_module(name, child)             |  # noqa: E501
+| Del Module | del module.name                        | delete_offload_module(module, name)              |  # noqa: E501
+| ------------------------------------------------------------------------------------------------------ |  # noqa: E501
 """
 
 import contextlib
 import warnings
 from functools import wraps
-from
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
 
 import torch
+from compressed_tensors.utils import patch_attr
 
 
 try:
-    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
@@ -45,10 +48,12 @@ try:
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
+        find_tied_parameters,
         set_module_tensor_to_device,
     )
 
     _has_accelerate = True
+
 except ImportError:
     _has_accelerate = False
     AlignDevicesHook = None
@@ -58,8 +63,8 @@ except ImportError:
     PrefixedDataset = None
     set_module_tensor_to_device = None
     named_module_tensors = None
-    dispatch_model = None
     attach_align_device_hook = None
+    find_tied_parameters = None
 
 
 __all__ = [
@@ -78,22 +83,28 @@ __all__ = [
     "align_module_device",
     "register_offload_module",
     "delete_offload_module",
-    "
+    "offloaded_dispatch",
+    "disable_offloading",
+    "remove_dispatch",
 ]
 
 
 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-
             if fallback == "error":
-                raise ValueError(
-                    "Please install `accelerate` in order to use this function"
-                )
 
-
-
-
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    raise ValueError(
+                        "Please install `accelerate` in order to use this function"
+                    )
+
+            else:
+
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    return fallback
 
             return fallback_fn
 
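The rewritten `check_accelerate` now wraps both fallback paths with `functools.wraps`, preserving the decorated function's name and docstring. A hedged sketch of the behavior when `accelerate` is not installed (the decorated function here is hypothetical):

@check_accelerate(fallback=False)
def uses_accelerate() -> bool:  # hypothetical example function
    ...

# without accelerate, calls return the fallback value instead of executing;
# with fallback="error", the same call would raise a ValueError instead
assert uses_accelerate() is False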
@@ -160,22 +171,22 @@ def update_parameter_data(
 
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`
 
     :param module: module to check, may be offloaded
     :return: onload device of module
     """
-
-
+    for submodule in module.modules():
+        if has_offloaded_params(submodule):
+            return submodule._hf_hook.execution_device
 
-
-
-
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(submodule.parameters(recurse=False), None)
+        if param is not None:
+            return param.device
 
-
+    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
+    return torch.device("cpu")
 
 
 def register_offload_parameter(
@@ -196,9 +207,24 @@ def register_offload_parameter(
     has_onload = any(p.device != torch.device("meta") for p in module.parameters())
     module.register_parameter(name, parameter)
 
+    # do everything AlignDevicesHook.init_hook does
+    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
     if has_offloaded_params(module):
-
-
+        hook: AlignDevicesHook = module._hf_hook
+        assert hook.weights_map is not None
+
+        # append to original_devices
+        hook.original_devices[name] = parameter.device
+
+        # append to weights map
+        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
+
+        # append to tied_params_map
+        offloaded = hook.weights_map[name]
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+
+    # perform offloading
     if not has_onload:
         set_module_tensor_to_device(module, name, "meta")
 
@@ -206,7 +232,7 @@
 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data:
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -219,7 +245,7 @@ def update_offload_parameter(
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
@@ -227,7 +253,7 @@ def update_offload_parameter(
     )
 
     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)
 
     # update offload dict
@@ -412,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
     hook: AlignDevicesHook = base._hf_hook
     assert hook.offload
     assert hook.weights_map is not None
-    assert hook.tied_params_map is not None
 
     # offloading kwargs for submodule
     place_submodules = False
@@ -427,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
         module, include_buffers=offload_buffers, recurse=place_submodules
     ):
         offloaded = param.to(offload_device)
-        hook.tied_params_map
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
         offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
 
     # if the parent places submodules, offload here
@@ -455,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
 
     base.register_module(name, module)
 
-    # (1): Since we cannot know which pointers are shared when we add parameters in an
-    # online way, assume that all pointers are shared. This comes at no runtime cost
-
 
 def delete_offload_module(base: torch.nn.Module, name: str):
     """
@@ -474,46 +497,106 @@ def delete_offload_module(base: torch.nn.Module, name: str):
 
 
 @check_accelerate(fallback="error")
-def
-    module: torch.nn.Module,
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
 ) -> torch.nn.Module:
     """
-
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utiliizing the
+    `AlignDevicesHook` to control onloading and offloading.
 
     :param module: module containing parameters to offload
-    :param execution_device:
-    :
-
-
-
-
-
-
-
-
-
-
-
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
+    """
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # remove any existing hooks
+    remove_dispatch(module)
+
+    # create weights map
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
+
+    # create tied params map
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attaches hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
+    )
 
-
+    # when saving a model, `PretrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #     or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
 
-
-    if next(module.parameters(recurse=False), None) is not None:
-        device_map[".".join(name)] = "cpu"
-        return
+    return module
 
-    else:
-        for submodule_name, submodule in module.named_children():
-            name.append(submodule_name)
-            collect_device_map(name, submodule)
-            name.pop()
 
-
+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
 
-
-
-
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
+    return module
+
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters(recurse=False):
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)
 
 
 """ Upstreamed Functions """
@@ -583,3 +666,7 @@ def align_module_device(
 
     else:
         yield
+
+
+# (1): Since we cannot know which pointers are shared when we add parameters in an
+# online way, assume that all pointers are shared. This has virtually no runtime cost
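A hedged usage sketch of the two new entry points, assuming `accelerate` is installed: `offloaded_dispatch` offloads every parameter and attaches the alignment hooks, while `disable_offloading` keeps weights onloaded across repeated forward passes and offloads once at exit:

import torch
from compressed_tensors.utils.offload import disable_offloading, offloaded_dispatch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
offloaded_dispatch(model, execution_device=torch.device("cpu"))

with disable_offloading():
    # each hooked module onloads once, then stays onloaded until context exit
    for _ in range(3):
        model(torch.randn(1, 8))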
compressed_tensors/version.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.1a20250604'
-__version_tuple__ = version_tuple = (0, 10,
+__version__ = version = '0.10.2'
+__version_tuple__ = version_tuple = (0, 10, 2)
{compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.1a20250604
+Version: 0.10.2
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/RECORD
RENAMED
@@ -1,11 +1,11 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=
+compressed_tensors/version.py,sha256=hLMY-mTgNhuVqmjaSY9lkyEvKWbHFLs0gilovrevj8M,513
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=nt0KxhZakDdlTIebBYcSvqxLCZhA6p6IL_1AYiHLFug,32695
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=ByE3z61boZ5wdz0nhc-2CJH61bSixJQE78pfkS6XRDg,10269
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
@@ -13,7 +13,7 @@ compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=G
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=_66tQ8bxslDUdas-ULORXblPw9kdNNn1UJJU9-ZOGPY,11380
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
-compressed_tensors/compressors/sparse_compressors/dense.py,sha256
+compressed_tensors/compressors/sparse_compressors/dense.py,sha256=-OujJ1e0iXBvxYVULrIGvAZ9l-IC0mXczZRnimQdgo4,2314
 compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py,sha256=4cwkj40SFrXEyE_jyt2xjz3R-gTdU9uMpMFUKo1pRBA,8643
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0FI9ep_XtUQOxj0P5utJt3vKEYOHjWEPp-Xd9aY,5820
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
@@ -30,7 +30,7 @@ compressed_tensors/quantization/quant_args.py,sha256=2OpiiSdl4KidzNmjx7J8UlQoAYm
 compressed_tensors/quantization/quant_config.py,sha256=aFi6PKqmEX9iP9O8GVn3mEUjRDEwk_hOCbmmiq-j9oU,10198
 compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF5hLz89J2PVYWBBnXRc,7102
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=v7D0TJU_eLT20Odn_J1VCPo2twll2ra-wxlEGBKB2OA,17990
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
 compressed_tensors/quantization/lifecycle/forward.py,sha256=JWOQ-03bsgh9_nnOLAjmLZ0S8bFQA-GjwDK6YUBwcrU,14883
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
@@ -39,19 +39,28 @@ compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5
 compressed_tensors/quantization/utils/helpers.py,sha256=bqxNL2NU1XVsSxNzmDVZE3zd65PlLFq1Ir-RHwff8G0,17840
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
-compressed_tensors/transform/__init__.py,sha256=
-compressed_tensors/transform/transform_args.py,sha256=
+compressed_tensors/transform/__init__.py,sha256=mtUOzwq-H7fXGi7sMmfe7zj83fjMg_LAu4DjTZ5vaHk,886
+compressed_tensors/transform/transform_args.py,sha256=8-Ab5_dFfdObfwVCgrWrEWcoVRzXmMBSDSUxjftI-Ss,3177
 compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
 compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
+compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/transform/factory/base.py,sha256=yVrYWEnrr2RFWE5AjSNeXzO9aXc443dTNMVSxuLztz8,5940
+compressed_tensors/transform/factory/hadamard.py,sha256=zkq6w8uJXRLokUXajAkFb2fJrH0K3SL6qrR2dARrAr8,3139
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=0g4sYC_tOmCjOomae2gl54UTXiFdl0mCCkmbqIRX8yw,3613
+compressed_tensors/transform/factory/random_hadamard.py,sha256=TFInxbHslqREOFFiy_mpR88eEYXQnslxXmyh-ZbN-MU,1499
+compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/transform/utils/hadamard.py,sha256=U27Kvo-eDebKcVt8oXTSIAaQ5DvPQj9tDv2hdXHCPPQ,5584
+compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
+compressed_tensors/transform/utils/utils.py,sha256=PRPTYwPs2nnNaQMq2GEbC4QYKHFKlZwaRyPgdDhl66g,2992
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
-compressed_tensors/utils/helpers.py,sha256=
-compressed_tensors/utils/offload.py,sha256=
+compressed_tensors/utils/helpers.py,sha256=cPg-ikdeA92aIGwBONg8GmPNvcGlFhozyJVwsRiXBTA,11981
+compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
+compressed_tensors-0.10.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.2.dist-info/METADATA,sha256=EJZIGQMfAp5b8c9p9vblRCx6hQPg32jHoDGhNv5W96k,6996
+compressed_tensors-0.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.2.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.2.dist-info/RECORD,,
{compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/WHEEL
RENAMED
File without changes

{compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/licenses/LICENSE
RENAMED
File without changes

{compressed_tensors-0.10.1a20250604.dist-info → compressed_tensors-0.10.2.dist-info}/top_level.txt
RENAMED
File without changes