compressed-tensors 0.10.2a20250609__py3-none-any.whl → 0.10.2a20250612__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/sparse_compressors/base.py +1 -19
- compressed_tensors/compressors/sparse_compressors/dense.py +19 -1
- compressed_tensors/quantization/lifecycle/apply.py +1 -3
- compressed_tensors/transform/__init__.py +5 -0
- compressed_tensors/transform/factory/__init__.py +13 -0
- compressed_tensors/transform/factory/base.py +164 -0
- compressed_tensors/transform/factory/hadamard.py +79 -0
- compressed_tensors/transform/factory/matrix_multiply.py +90 -0
- compressed_tensors/transform/factory/random_hadamard.py +34 -0
- compressed_tensors/utils/offload.py +49 -48
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/METADATA +1 -1
- {compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/RECORD +16 -11
- {compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_compressors/base.py
CHANGED
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import logging
-from typing import
+from typing import Dict, Generator, Optional, Set, Tuple
 
-import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
@@ -27,10 +26,6 @@ from torch import Tensor
 from tqdm import tqdm
 
 
-if TYPE_CHECKING:
-    from compressed_tensors.quantization import QuantizationScheme
-
-
 __all__ = ["BaseSparseCompressor"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -205,16 +200,3 @@ class BaseSparseCompressor(BaseCompressor):
         return (
             name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
         )
-
-    def decompress_module_from_state_dict(
-        self,
-        prefix: str,
-        state_dict: Dict[str, torch.Tensor],
-        scheme: "QuantizationScheme",
-    ) -> Dict[str, torch.Tensor]:
-        """
-        This function is implemented as a workaround because of how
-        `ModelCompressor.quantization_compressor` can be set to either
-        an instance of `BaseQuantizationCompressor` or `BaseSparseCompressor`.
-        """
-        return state_dict.copy()
compressed_tensors/compressors/sparse_compressors/dense.py
CHANGED
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """
@@ -47,3 +52,16 @@ class DenseCompressor(BaseCompressor):
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
compressed_tensors/quantization/lifecycle/apply.py
CHANGED
@@ -183,9 +183,7 @@ def apply_quantization_config(
                 replace_module(model, name, compressed_linear)
 
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme =
-                target_to_scheme, targets, name
-            )
+            submodule.quantization_scheme = scheme
 
             names_to_scheme[name] = submodule.quantization_scheme
 
compressed_tensors/transform/factory/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/transform/factory/base.py
ADDED
@@ -0,0 +1,164 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn.utils.parametrize as P
+from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
+from compressed_tensors.registry.registry import RegistryMixin, T
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformLocation,
+    TransformScheme,
+)
+from compressed_tensors.utils import (
+    align_module_device,
+    has_offloaded_params,
+    patch_attr,
+    register_offload_module,
+    update_offload_parameter,
+)
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+__all__ = ["TransformFactory", "TransformBase"]
+
+
+class TransformFactory(RegistryMixin, ABC):
+    """
+    Abstract factory base used to create and apply transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        self.name = name
+        self.scheme = scheme
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+
+    @classmethod
+    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
+        """
+        Create a transform factory from a scheme
+
+        :param scheme: defines how transforms should be created
+        :param kwargs: TransformFactory constructor arguments
+        :return: subclass of `TransformFactory` corresponding to the scheme type
+        """
+        constructor = cls.get_value_from_registry(name=scheme.type)
+        return constructor(scheme=scheme, **kwargs)
+
+    @abstractmethod
+    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
+        """
+        Abstract method which defines how a transform should be created. May utilize
+        caching to maximize shared memory
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        :return: instance of TransformBase
+        """
+        raise NotImplementedError()
+
+    def apply_to_model(self, model: Module):
+        """
+        Create transforms and apply them to the model
+
+        :param model: module to apply transforms to
+        """
+        for arg in self.scheme.apply:
+            for name, module in list(model.named_modules()):
+                if is_target(name, module, arg.targets, arg.ignore):
+                    self._apply_to_module(module, arg)
+
+    def _apply_to_module(self, module: Module, args: TransformArgs):
+        """
+        Create transforms and apply them to the module
+
+        :param module: target module to apply transforms to
+        :param args: defines how the transform will be applied to the target module
+        """
+        # create transform as submodule
+        transform_name = f"{self.name}_{args.location.value}"
+        transform = self.create_transform(module, args)
+        register_offload_module(module, transform_name, transform)  # (1)
+
+        # register input transformation hook
+        if args.location == TransformLocation.INPUT:
+
+            def input_hook(_, args):
+                input = args[0]
+                return transform(input)
+
+            module.register_forward_pre_hook(input_hook, prepend=True)
+
+        # eagerly apply transformation to weight
+        elif args.location in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        ):
+            assert isinstance(module, torch.nn.Linear)
+            assert module.bias is None
+
+            with torch.no_grad(), align_module_device(module):
+                update_offload_parameter(module, "weight", transform(module.weight))
+
+            if self.scheme.requires_grad:
+                # for training, the weight changes with every forward pass
+                # so we can leverage parametrization to propagate the gradient
+                if has_offloaded_params(module):
+                    raise ValueError("Offloaded training is not supported")
+                P.register_parametrization(module, "weight", transform)
+
+        # register output transformation hook
+        elif args.location == TransformLocation.OUTPUT:
+
+            def output_hook(_, _input, output):
+                return transform(output)
+
+            module.register_forward_hook(output_hook)
+
+        # other locations such as q_attn and k_attn have not been implemented
+        else:
+            raise NotImplementedError()
+
+    # (1) even in the `weight` cases, this submodule attachment is needed in order
+    # to support saving in the frozen state
+
+
+class TransformBase(Module, ABC):
+    """
+    Represents the application of a transform according to TransformArgs
+    """
+
+    args: TransformArgs
+    weight: Parameter
+
+    @abstractmethod
+    def forward(self, value: Tensor) -> Tensor:
+        raise NotImplementedError()
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        with patch_attr(self.args, "inverse", not self.args.inverse):
+            return self.forward(value)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
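
Taken together, the new factory API above is driven entirely by a TransformScheme: `from_scheme` looks up the registered subclass by `scheme.type`, and `apply_to_model` walks `scheme.apply`, attaching a transform submodule plus the matching input, weight, or output hook to every targeted module. A rough usage sketch, illustrative only and not part of the diff (it assumes a scheme whose `type` matches one of the registered factory names and whose `apply` list holds `TransformArgs`):

    import torch
    from compressed_tensors.transform.factory.base import TransformFactory

    model = torch.nn.Linear(64, 64, bias=False)
    scheme = ...  # a TransformScheme, e.g. type="hadamard", apply=[<TransformArgs>]
    factory = TransformFactory.from_scheme(scheme, name="r1", seed=0)
    factory.apply_to_model(model)  # registers hooks / rewrites weights per TransformArgs.location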
compressed_tensors/transform/factory/hadamard.py
ADDED
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("hadamard")
+class HadamardFactory(TransformFactory):
+    """
+    Factory used to apply hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a HadamardTransform for applying to a module. Transforms with the same
+        size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        return HadamardTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = deterministic_hadamard_matrix(size)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+
+class HadamardTransform(TransformBase):
+    def __init__(self, weight: Parameter, args: TransformArgs):
+        super().__init__()
+        self.weight = weight
+        self.args = args
+
+    def forward(self, value: Tensor) -> Tensor:
+        if not self.args.inverse:
+            weight = self.weight
+        else:
+            weight = self.weight.T
+
+        return apply_transform_weight(weight, value, self.args.location)
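
Two details of the Hadamard factory above are easy to miss: generated matrices are cached in a `ParameterizedDefaultDict` keyed by `(size, dtype, device)`, so targets that share a key also share a single `Parameter`, and the inverse direction simply uses the transpose of that shared weight. A minimal sketch of what the caching implies (the `factory`, `linear_a`/`linear_b`, and `args_a`/`args_b` names are hypothetical, not from the diff):

    # Hypothetical: both Linear modules resolve to the same (size, dtype, device)
    # key, so create_transform hands back transforms wrapping the same Parameter.
    t_a = factory.create_transform(linear_a, args_a)
    t_b = factory.create_transform(linear_b, args_b)
    assert t_a.weight is t_b.weight  # shared via ParameterizedDefaultDict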
compressed_tensors/transform/factory/matrix_multiply.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("random-matrix")
+class RandomMatrixFactory(TransformFactory):
+    """
+    Factory used to apply random matrix transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.inverses = ParameterizedDefaultDict(self._create_inverse)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a RandomMatrixTransform for applying to a module. Transforms with the
+        same size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        if args.inverse:
+            weight = self.inverses[weight]
+
+        return RandomMatrixTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = torch.rand(
+            (size, size), generator=self.generator, dtype=dtype, device=device
+        )
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+    def _create_inverse(self, weight: Parameter) -> Parameter:
+        data = high_precision_invert(weight.data)
+        return Parameter(data, requires_grad=False)
+
+
+class RandomMatrixTransform(TransformBase):
+    def __init__(self, weight: Tensor, args: TransformArgs):
+        super().__init__()
+        self.weight = weight  # is an inverse if args.inverse
+        self.args = args
+
+    def forward(self, value: Tensor) -> Parameter:
+        return apply_transform_weight(self.weight, value, self.args.location)
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        inverse = high_precision_invert(self.weight)
+        return apply_transform_weight(inverse, value, self.args.location)
+
+
+def high_precision_invert(weight: Tensor) -> Tensor:
+    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
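
Unlike the Hadamard case, a random matrix has no cheap transpose inverse, so the factory above caches an explicit inverse per weight (`_create_inverse`), and `right_inverse` recomputes one via `high_precision_invert`, which inverts in float32 before casting back. A small illustrative check of that helper, not part of the diff (run alongside the function above):

    import torch

    # well-conditioned half-precision matrix; the inversion itself happens in float32
    w = (torch.eye(8) + 0.01 * torch.rand(8, 8)).to(torch.float16)
    w_inv = high_precision_invert(w)  # defined at the bottom of the file above
    print(torch.dist(w.float() @ w_inv.float(), torch.eye(8)))  # ~0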
compressed_tensors/transform/factory/random_hadamard.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from compressed_tensors.transform import HadamardFactory, TransformFactory
+from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
+from torch import device, dtype
+from torch.nn import Parameter
+
+
+@TransformFactory.register("random-hadamard")
+class RandomHadamardFactory(HadamardFactory):
+    """
+    Factory used to apply random hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = random_hadamard_matrix(size, self.generator)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
compressed_tensors/utils/offload.py
CHANGED
@@ -14,27 +14,29 @@
 """
 Utilities associated with offloading functionality provided by `accelerate`.
 
-|
-| Operation
-|
-| Add
-| Check
-| Onload
-| Update
-| Delete
-|
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
+| Operation  | Without offloading support             | With offloading support                          | # noqa: E501
+| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check      | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload     | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete     | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| Add Module | module.register_module(name, child)    | register_offload_module(name, child)             | # noqa: E501
+| Del Module | del module.name                        | delete_offload_module(module, name)              | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
 """
 
 import contextlib
 import warnings
 from functools import wraps
-from
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
 
 import torch
 
 
 try:
-    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
@@ -45,10 +47,12 @@ try:
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
+        find_tied_parameters,
         set_module_tensor_to_device,
     )
 
     _has_accelerate = True
+
 except ImportError:
     _has_accelerate = False
     AlignDevicesHook = None
@@ -58,8 +62,8 @@ except ImportError:
     PrefixedDataset = None
     set_module_tensor_to_device = None
     named_module_tensors = None
-    dispatch_model = None
     attach_align_device_hook = None
+    find_tied_parameters = None
 
 
 __all__ = [
@@ -78,14 +82,13 @@ __all__ = [
     "align_module_device",
     "register_offload_module",
     "delete_offload_module",
-    "
+    "offloaded_dispatch",
 ]
 
 
 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-
             if fallback == "error":
 
                 @wraps(func)
@@ -479,46 +482,44 @@ def delete_offload_module(base: torch.nn.Module, name: str):
 
 
 @check_accelerate(fallback="error")
-def
-    module: torch.nn.Module,
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
 ) -> torch.nn.Module:
     """
-
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utilizing the
+    `AlignDevicesHook` to control onloading and offloading.
 
     :param module: module containing parameters to offload
-    :param execution_device:
-    :
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            name.append(submodule_name)
-            collect_device_map(name, submodule)
-            name.pop()
-
-    collect_device_map([], module)
-
-    return dispatch_model(
-        module, device_map, main_device=execution_device, force_hooks=True
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # create weights map
+    weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
+
+    # create tied params map
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attaches hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
     )
+    return module
 
 
 """ Upstreamed Functions """
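
The net effect of the `offloaded_dispatch` rewrite above is that every parameter stays on the offload device (CPU, since disk offloading raises `NotImplementedError`) and accelerate's `AlignDevicesHook` onloads it to the execution device around each forward call. A usage sketch, assuming `offloaded_dispatch` is re-exported from `compressed_tensors.utils` like the other helpers in `__all__`, that accelerate is installed, and that a CUDA device is available:

    import torch
    from compressed_tensors.utils import offloaded_dispatch  # assumed re-export

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
    model = offloaded_dispatch(model, execution_device=torch.device("cuda:0"))
    # parameters now live in the CPU weights map and are moved to cuda:0 by the
    # AlignDevicesHook for the duration of each forward pass
    out = model(torch.zeros(1, 16, device="cuda:0"))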
compressed_tensors/version.py
CHANGED

{compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.2a20250609
+Version: 0.10.2a20250612
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=
+compressed_tensors/version.py,sha256=F2izwCTRKbiv1mAW6qD3TbJD5cXQrz4zRmew4qZ4Ud0,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -12,8 +12,8 @@ compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0
 compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=_66tQ8bxslDUdas-ULORXblPw9kdNNn1UJJU9-ZOGPY,11380
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
-compressed_tensors/compressors/sparse_compressors/base.py,sha256=
-compressed_tensors/compressors/sparse_compressors/dense.py,sha256
+compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
+compressed_tensors/compressors/sparse_compressors/dense.py,sha256=-OujJ1e0iXBvxYVULrIGvAZ9l-IC0mXczZRnimQdgo4,2314
 compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py,sha256=4cwkj40SFrXEyE_jyt2xjz3R-gTdU9uMpMFUKo1pRBA,8643
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0FI9ep_XtUQOxj0P5utJt3vKEYOHjWEPp-Xd9aY,5820
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
@@ -30,7 +30,7 @@ compressed_tensors/quantization/quant_args.py,sha256=2OpiiSdl4KidzNmjx7J8UlQoAYm
 compressed_tensors/quantization/quant_config.py,sha256=aFi6PKqmEX9iP9O8GVn3mEUjRDEwk_hOCbmmiq-j9oU,10198
 compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF5hLz89J2PVYWBBnXRc,7102
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=v7D0TJU_eLT20Odn_J1VCPo2twll2ra-wxlEGBKB2OA,17990
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
 compressed_tensors/quantization/lifecycle/forward.py,sha256=JWOQ-03bsgh9_nnOLAjmLZ0S8bFQA-GjwDK6YUBwcrU,14883
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
@@ -39,22 +39,27 @@ compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5
 compressed_tensors/quantization/utils/helpers.py,sha256=bqxNL2NU1XVsSxNzmDVZE3zd65PlLFq1Ir-RHwff8G0,17840
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
-compressed_tensors/transform/__init__.py,sha256=
+compressed_tensors/transform/__init__.py,sha256=mtUOzwq-H7fXGi7sMmfe7zj83fjMg_LAu4DjTZ5vaHk,886
 compressed_tensors/transform/transform_args.py,sha256=8-Ab5_dFfdObfwVCgrWrEWcoVRzXmMBSDSUxjftI-Ss,3177
 compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
 compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
+compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/transform/factory/base.py,sha256=yVrYWEnrr2RFWE5AjSNeXzO9aXc443dTNMVSxuLztz8,5940
+compressed_tensors/transform/factory/hadamard.py,sha256=tuFVpKsv__SeDj-QwKxtipLvjb993DOSIFvWcUh42Ww,3124
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=0g4sYC_tOmCjOomae2gl54UTXiFdl0mCCkmbqIRX8yw,3613
+compressed_tensors/transform/factory/random_hadamard.py,sha256=6kqr9z6kFc-2qRNskhWRsLGTDT_NfNAkFcTLMqQJcWA,1484
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/transform/utils/hadamard.py,sha256=SmPZmnHtc5N36gJA5EbM1T65uf4w1_flgl7SWBeg_W8,5642
 compressed_tensors/transform/utils/utils.py,sha256=PRPTYwPs2nnNaQMq2GEbC4QYKHFKlZwaRyPgdDhl66g,2992
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
 compressed_tensors/utils/helpers.py,sha256=cPg-ikdeA92aIGwBONg8GmPNvcGlFhozyJVwsRiXBTA,11981
-compressed_tensors/utils/offload.py,sha256=
+compressed_tensors/utils/offload.py,sha256=myV7iC75gA8A3BGgwR3uoeaJkIC9oigKp9CcqsHsVJc,20686
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
+compressed_tensors-0.10.2a20250612.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.2a20250612.dist-info/METADATA,sha256=541wdYU5905X69fwti-7pubCIzjsENQnbOxpJt4X2qQ,7005
+compressed_tensors-0.10.2a20250612.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.2a20250612.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.2a20250612.dist-info/RECORD,,
{compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/WHEEL
RENAMED
File without changes

{compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/licenses/LICENSE
RENAMED
File without changes

{compressed_tensors-0.10.2a20250609.dist-info → compressed_tensors-0.10.2a20250612.dist-info}/top_level.txt
RENAMED
File without changes