compressed-tensors 0.10.1a20250604__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

compressed_tensors/compressors/model_compressors/model_compressor.py

@@ -68,6 +68,10 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.compressors import BaseQuantizationCompressor
+
+
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -257,7 +261,9 @@ class ModelCompressor:
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, DenseCompressor]
+        ] = None
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(

compressed_tensors/compressors/sparse_compressors/dense.py

@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """
@@ -47,3 +52,16 @@ class DenseCompressor(BaseCompressor):
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
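
Note: the new `decompress_module_from_state_dict` is a pass-through, so callers that hold either a `BaseQuantizationCompressor` or a `DenseCompressor` (see the `ModelCompressor` typing change above) can call it unconditionally. A minimal sketch of that contract, assuming the default `BaseCompressor` constructor is accepted; the prefix value is a hypothetical illustration:

    # Sketch only: dense "decompression" is a shallow copy of the input dict.
    import torch
    from compressed_tensors.compressors.sparse_compressors.dense import DenseCompressor

    compressor = DenseCompressor()  # assumption: default config is accepted
    state_dict = {"layer.weight": torch.randn(16, 16)}
    out = compressor.decompress_module_from_state_dict(
        prefix="layer.",      # hypothetical module prefix
        state_dict=state_dict,
        scheme=None,          # unused by the dense pass-through
    )
    assert out is not state_dict                              # new dict object
    assert out["layer.weight"] is state_dict["layer.weight"]  # same tensors (shallow copy)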

compressed_tensors/quantization/lifecycle/apply.py

@@ -183,9 +183,7 @@ def apply_quantization_config(
                        replace_module(model, name, compressed_linear)
 
            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = _scheme_from_targets(
-                target_to_scheme, targets, name
-            )
+            submodule.quantization_scheme = scheme
 
            names_to_scheme[name] = submodule.quantization_scheme
 

compressed_tensors/transform/__init__.py

@@ -18,3 +18,8 @@
 from .transform_args import *
 from .transform_scheme import *
 from .transform_config import *
+
+from .factory.base import *
+from .factory.hadamard import *
+from .factory.matrix_multiply import *
+from .factory.random_hadamard import *

compressed_tensors/transform/factory/__init__.py (new file)

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

compressed_tensors/transform/factory/base.py (new file)

@@ -0,0 +1,164 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn.utils.parametrize as P
+from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
+from compressed_tensors.registry.registry import RegistryMixin, T
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformLocation,
+    TransformScheme,
+)
+from compressed_tensors.utils import (
+    align_module_device,
+    has_offloaded_params,
+    patch_attr,
+    register_offload_module,
+    update_offload_parameter,
+)
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+__all__ = ["TransformFactory", "TransformBase"]
+
+
+class TransformFactory(RegistryMixin, ABC):
+    """
+    Abstract factory base used to create and apply transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        self.name = name
+        self.scheme = scheme
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+
+    @classmethod
+    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
+        """
+        Create a transform factory from a scheme
+
+        :param scheme: defines how transforms should be created
+        :param kwargs: TransformFactory constructor arguments
+        :return: subclass of `TransformFactory` corresponding to the scheme type
+        """
+        constructor = cls.get_value_from_registry(name=scheme.type)
+        return constructor(scheme=scheme, **kwargs)
+
+    @abstractmethod
+    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
+        """
+        Abstract method which defines how a transform should be created. May utilize
+        caching to maximize shared memory
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        :return: instance of TransformBase
+        """
+        raise NotImplementedError()
+
+    def apply_to_model(self, model: Module):
+        """
+        Create transforms and apply them to the model
+
+        :param model: module to apply transforms to
+        """
+        for arg in self.scheme.apply:
+            for name, module in list(model.named_modules()):
+                if is_target(name, module, arg.targets, arg.ignore):
+                    self._apply_to_module(module, arg)
+
+    def _apply_to_module(self, module: Module, args: TransformArgs):
+        """
+        Create transforms and apply them to the module
+
+        :param module: target module to apply transforms to
+        :param args: defines how the transform will be applied to the target module
+        """
+        # create transform as submodule
+        transform_name = f"{self.name}_{args.location.value}"
+        transform = self.create_transform(module, args)
+        register_offload_module(module, transform_name, transform)  # (1)
+
+        # register input transformation hook
+        if args.location == TransformLocation.INPUT:
+
+            def input_hook(_, args):
+                input = args[0]
+                return transform(input)
+
+            module.register_forward_pre_hook(input_hook, prepend=True)
+
+        # eagerly apply transformation to weight
+        elif args.location in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        ):
+            assert isinstance(module, torch.nn.Linear)
+            assert module.bias is None
+
+            with torch.no_grad(), align_module_device(module):
+                update_offload_parameter(module, "weight", transform(module.weight))
+
+            if self.scheme.requires_grad:
+                # for training, the weight changes with every forward pass
+                # so we can leverage parametrization to propagate the gradient
+                if has_offloaded_params(module):
+                    raise ValueError("Offloaded training is not supported")
+                P.register_parametrization(module, "weight", transform)
+
+        # register output transformation hook
+        elif args.location == TransformLocation.OUTPUT:
+
+            def output_hook(_, _input, output):
+                return transform(output)
+
+            module.register_forward_hook(output_hook)
+
+        # other locations such as q_attn and k_attn have not been implemented
+        else:
+            raise NotImplementedError()
+
+    # (1) even in the `weight` cases, this submodule attachment is needed in order
+    # to support saving in the frozen state
+
+
+class TransformBase(Module, ABC):
+    """
+    Represents the application of a transform accord to TransformArgs
+    """
+
+    args: TransformArgs
+    weight: Parameter
+
+    @abstractmethod
+    def forward(self, value: Tensor) -> Tensor:
+        raise NotImplementedError()
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        with patch_attr(self.args, "inverse", not self.args.inverse):
+            return self.forward(value)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(inverse={self.args.inverse})"

compressed_tensors/transform/factory/hadamard.py (new file)

@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("hadamard")
+class HadamardFactory(TransformFactory):
+    """
+    Factory used to apply hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a HadamardTransform for applying to a module. Transforms with the same
+        size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        return HadamardTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = deterministic_hadamard_matrix(size, dtype, device)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+
+class HadamardTransform(TransformBase):
+    def __init__(self, weight: Parameter, args: TransformArgs):
+        super().__init__()
+        self.weight = weight
+        self.args = args
+
+    def forward(self, value: Tensor) -> Tensor:
+        if not self.args.inverse:
+            weight = self.weight
+        else:
+            weight = self.weight.T
+
+        return apply_transform_weight(weight, value, self.args.location)
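
Note: `HadamardFactory` builds one `Parameter` per `(size, dtype, device)` key via `ParameterizedDefaultDict`, so modules with matching shapes share a single transform weight. A rough sketch of that caching behaviour, assuming `TransformScheme(type=...)` and `TransformArgs(targets=..., location=...)` construct with the remaining fields defaulted:

    # Sketch only: two same-shaped Linear modules receive the same cached weight.
    import torch
    from compressed_tensors.transform import TransformArgs, TransformScheme
    from compressed_tensors.transform.factory.hadamard import HadamardFactory

    factory = HadamardFactory("u", TransformScheme(type="hadamard"))
    args = TransformArgs(targets=["Linear"], location="weight_input")

    a = torch.nn.Linear(32, 32, bias=False)
    b = torch.nn.Linear(32, 32, bias=False)

    t_a = factory.create_transform(a, args)
    t_b = factory.create_transform(b, args)
    assert t_a.weight is t_b.weight  # cached by (size, dtype, device)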

compressed_tensors/transform/factory/matrix_multiply.py (new file)

@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("random-matrix")
+class RandomMatrixFactory(TransformFactory):
+    """
+    Factory used to apply random matrix transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.inverses = ParameterizedDefaultDict(self._create_inverse)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a RandomMatrixTransform for applying to a module. Transforms with the
+        same size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        if args.inverse:
+            weight = self.inverses[weight]
+
+        return RandomMatrixTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = torch.rand(
+            (size, size), generator=self.generator, dtype=dtype, device=device
+        )
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+    def _create_inverse(self, weight: Parameter) -> Parameter:
+        data = high_precision_invert(weight.data)
+        return Parameter(data, requires_grad=False)
+
+
+class RandomMatrixTransform(TransformBase):
+    def __init__(self, weight: Tensor, args: TransformArgs):
+        super().__init__()
+        self.weight = weight  # is an inverse if args.inverse
+        self.args = args
+
+    def forward(self, value: Tensor) -> Parameter:
+        return apply_transform_weight(self.weight, value, self.args.location)
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        inverse = high_precision_invert(self.weight)
+        return apply_transform_weight(inverse, value, self.args.location)
+
+
+def high_precision_invert(weight: Tensor) -> Tensor:
+    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
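
Note: the inverse path goes through `high_precision_invert`, which inverts in float32 and casts back, so a low-precision transform and its cached inverse compose to (approximately) the identity. A quick numeric check of that helper in isolation (matrix size and tolerance chosen only for illustration):

    # Sketch only: invert in float32, then confirm w @ w_inv is close to identity.
    import torch
    from compressed_tensors.transform.factory.matrix_multiply import high_precision_invert

    torch.manual_seed(0)
    w = torch.rand(8, 8)               # float32, as _create_weight builds
    w_inv = high_precision_invert(w)
    assert torch.allclose(w @ w_inv, torch.eye(8), atol=1e-2)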

compressed_tensors/transform/factory/random_hadamard.py (new file)

@@ -0,0 +1,34 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from compressed_tensors.transform import HadamardFactory, TransformFactory
+from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
+from torch import device, dtype
+from torch.nn import Parameter
+
+
+@TransformFactory.register("random-hadamard")
+class RandomHadamardFactory(HadamardFactory):
+    """
+    Factory used to apply random hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used to transform weight randomization
+    """
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = random_hadamard_matrix(size, dtype, device, self.generator)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)

compressed_tensors/transform/transform_args.py

@@ -13,15 +13,31 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Any, List
+from typing import List
 
 from pydantic import BaseModel, Field, field_validator
 
 
-__all__ = ["TransformArgs"]
+__all__ = ["TransformArgs", "TransformLocation"]
 
 
 class TransformLocation(str, Enum):
+    """
+    Enum representing which parameters/activations a transform weight should be applied
+    to on a given module.
+
+    | -------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    | Name            | Runtime  | Values        | Locations Where Inverse Could Be Applied                     |  # noqa: E501
+    | --------------- | -------- | ------------- | ------------------------------------------------------------ |  # noqa: E501
+    | `INPUT`         | online   | activations   | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT`     |  # noqa: E501
+    | `WEIGHT_INPUT`  | offline  | weight        | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT`            |  # noqa: E501
+    | `WEIGHT_OUTPUT` | offline  | weight        | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`             |  # noqa: E501
+    | `OUTPUT`        | online   | activations   | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`      |  # noqa: E501
+    | `K_CACHE`       | online   | key_values    | `q_proj.Q_ATTN`                                              |  # noqa: E501
+    | `Q_ATTN`        | online   | query_values  | `k_proj.K_CACHE`                                             |  # noqa: E501
+    | -------------------------------------------------------------------------------------------------------- |  # noqa: E501
+    """
+
     INPUT = "input"
     WEIGHT_INPUT = "weight_input"
     WEIGHT_OUTPUT = "weight_output"

compressed_tensors/transform/utils/__init__.py (new file)

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

compressed_tensors/transform/utils/hadamard.py (new file)

@@ -0,0 +1,160 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from pathlib import Path
+from typing import Optional
+
+import torch
+from safetensors import safe_open
+
+
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
+
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """
+    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
+    `n` must be a power of 2.
+
+    Adapated from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
+    :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :return: hadamard matrix of size `size`
+    """
+    if size <= 0:
+        raise ValueError("Cannot construct deterministic hadamard of size <= 0")
+
+    log2 = int(math.log2(size))
+    if size != 2**log2:
+        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+
+    H = torch.tensor([[1]], dtype=dtype, device=device)
+
+    # Sylvester's construction
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
+
+    return H / math.sqrt(size)
+
+
+def random_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+    """
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non powers of 2
+    and randomization using a seeded generator
+
+    Adapated from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501
+
+    :param size: The dimension of the hamadard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :param gen: Optional generator random values
+    :return: randomly generated hadamard matrix
+    """
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
+    Q = Q * 2 - 1
+    Q = torch.diag(Q)
+    return _matmul_hadU(Q) / math.sqrt(size)
+
+
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
+
+
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
+
+    # Check if we have the determined hadamard matrix
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
+    # Reshape diag matrix with randomized -1/+1
+    input = X.clone().view(-1, size, 1)
+    output = input.clone()
+    while input.shape[1] > K:
+        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+        output = output.view(input.shape)
+        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+        output = output.view(input.shape[0], input.shape[1], -1)
+        (input, output) = (output, input)
+    assert input.shape[1] == K
+    del output
+
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input
+
+    # normalize
+    return input.view(X.shape)
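
Note: the `1 / sqrt(size)` normalization in both constructors makes the resulting matrices orthonormal, which is what lets the factories above invert a Hadamard transform with a plain transpose. A small check against `deterministic_hadamard_matrix` (size and dtype chosen here only for illustration):

    # Sketch only: Sylvester-constructed, normalized Hadamard matrices satisfy H @ H.T == I.
    import torch
    from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix, is_pow2

    H = deterministic_hadamard_matrix(16, dtype=torch.float32)
    assert torch.allclose(H @ H.T, torch.eye(16), atol=1e-6)
    assert is_pow2(16) and not is_pow2(12)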

compressed_tensors/transform/utils/utils.py (new file)

@@ -0,0 +1,91 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_matrix_size", "apply_transform_weight"]
+
+
+def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
+    """
+    Determine the size of a matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :return: size of matrix
+    """
+    assert isinstance(module, torch.nn.Linear)
+    if location in ("input", TransformLocation.WEIGHT_INPUT):
+        return module.in_features
+    else:
+        return module.out_features
+
+
+def apply_transform_weight(
+    weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+) -> torch.Tensor:
+    """
+    Using the transform location, determine how to apply the transform weight to the
+    given value. For more info on input and output transforms, see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x  be input activation
+         W  be weight,
+         yh, xh, Wh  be transformed output, input, weight
+
+    note that
+         y = (x W.T)  // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi  be transform matrices on input side
+         U, Ui  be transform matrices on output side
+
+    pick  xh = (x V)
+          Wh = (U.T W Vi.T)
+          yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)  // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param weight: transform weight to apply
+    :param value: value to apply weight to
+    :param location: determines how weight should be applied
+    :return: value after transform weight has been applied
+    """
+
+    if location == TransformLocation.INPUT:
+        return value @ weight
+
+    elif location == TransformLocation.WEIGHT_INPUT:
+        return value @ weight.T
+
+    elif location == TransformLocation.WEIGHT_OUTPUT:
+        return weight.T @ value
+
+    elif location == TransformLocation.OUTPUT:
+        return value @ weight
+
+    else:
+        raise NotImplementedError(f"{location} has not been implemented yet")
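
Note: the docstring derivation above can be checked numerically. The sketch below builds invertible V and U, applies the substitutions `xh = x V`, `Wh = U.T W Vi.T`, `yh = y U`, and confirms `yh == xh @ Wh.T`; the shapes and the use of QR to get invertible matrices are illustrative choices, not part of this diff:

    # Sketch only: verify the docstring identity yh = (xh)(Wh).T on random data.
    import torch

    torch.manual_seed(0)
    x = torch.randn(3, 8)                      # activations
    W = torch.randn(4, 8)                      # Linear weight (out_features x in_features)
    V = torch.linalg.qr(torch.randn(8, 8)).Q   # invertible input-side transform
    U = torch.linalg.qr(torch.randn(4, 4)).Q   # invertible output-side transform

    y = x @ W.T                                # torch.nn.Linear forward
    xh = x @ V                                 # TransformLocation.INPUT:  value @ weight
    Wh = U.T @ W @ torch.linalg.inv(V).T       # both WEIGHT_* locations folded into W
    yh = y @ U                                 # TransformLocation.OUTPUT: value @ weight

    assert torch.allclose(yh, xh @ Wh.T, atol=1e-4)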

compressed_tensors/utils/helpers.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -38,6 +39,8 @@ __all__ = [
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "patch_attr",
+    "ParameterizedDefaultDict",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +331,53 @@ def unpack_bitmasks(
     )
 
     return unpacked_bitmasks_torch
+
+
+@contextlib.contextmanager
+def patch_attr(base: object, attr: str, value: Any):
+    """
+    Patch the value of an object attribute. Original value is restored upon exit
+
+    :param base: object which has the attribute to patch
+    :param attr: name of the the attribute to patch
+    :param value: used to replace original value
+
+    Usage:
+    >>> from types import SimpleNamespace
+    >>> obj = SimpleNamespace()
+    >>> with patch_attr(obj, "attribute", "value"):
+    ...     assert obj.attribute == "value"
+    >>> assert not hasattr(obj, "attribute")
+    """
+    _sentinel = object()
+    original_value = getattr(base, attr, _sentinel)
+
+    setattr(base, attr, value)
+    try:
+        yield
+    finally:
+        if original_value is not _sentinel:
+            setattr(base, attr, original_value)
+        else:
+            delattr(base, attr)
+
+
+class ParameterizedDefaultDict(dict):
+    """
+    Similar to `collections.DefaultDict`, but upon fetching a key which is missing,
+    the key is passed as arguments to the `default_factory`
+
+    :param default_factory: function which takes a key as input and returns the
+        corresponding default value
+    """
+
+    def __init__(self, default_factory: Callable[[Any], Any]):
+        self.default_factory = default_factory
+
+    def __missing__(self, key):
+        if isinstance(key, tuple):
+            value = self.default_factory(*key)
+        else:
+            value = self.default_factory(key)
+        self[key] = value
+        return value
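
Note: `patch_attr` already carries a doctest; `ParameterizedDefaultDict` is the piece the transform factories lean on for caching, and its tuple-unpacking behaviour is easy to miss. A short sketch:

    # Sketch only: missing keys are forwarded to the factory; tuple keys are unpacked.
    from compressed_tensors.utils.helpers import ParameterizedDefaultDict

    cache = ParameterizedDefaultDict(lambda size, dtype: (size, dtype, object()))
    first = cache[64, "bfloat16"]    # factory called as default_factory(64, "bfloat16")
    second = cache[64, "bfloat16"]   # cached, factory not called again
    assert first is second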

compressed_tensors/utils/offload.py

@@ -14,27 +14,30 @@
 """
 Utilities associated with offloading functionality provided by `accelerate`.
 
-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
-| Operation | Without offloading support             | With offloading support                          | # noqa: E501
-| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
-| Add       | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
-| Check     | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
-| Onload    | N/A                                    | with align_module_device(module)                 | # noqa: E501
-| Update    | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
-| Delete    | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
-| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
+| Operation  | Without offloading support             | With offloading support                          | # noqa: E501
+| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check      | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload     | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete     | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| Add Module | module.register_module(name, child)    | register_offload_module(name, child)             | # noqa: E501
+| Del Module | del module.name                        | delete_offload_module(module, name)              | # noqa: E501
+| ------------------------------------------------------------------------------------------------------ | # noqa: E501
 """
 
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union
+from operator import attrgetter
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
 
 import torch
+from compressed_tensors.utils import patch_attr
 
 
 try:
-    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
@@ -45,10 +48,12 @@ try:
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
+        find_tied_parameters,
         set_module_tensor_to_device,
     )
 
     _has_accelerate = True
+
 except ImportError:
     _has_accelerate = False
     AlignDevicesHook = None
@@ -58,8 +63,8 @@ except ImportError:
     PrefixedDataset = None
     set_module_tensor_to_device = None
     named_module_tensors = None
-    dispatch_model = None
     attach_align_device_hook = None
+    find_tied_parameters = None
 
 
 __all__ = [
@@ -78,22 +83,28 @@ __all__ = [
     "align_module_device",
     "register_offload_module",
     "delete_offload_module",
-    "force_cpu_offload",
+    "offloaded_dispatch",
+    "disable_offloading",
+    "remove_dispatch",
 ]
 
 
 def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-
             if fallback == "error":
-                raise ValueError(
-                    "Please install `accelerate` in order to use this function"
-                )
 
-            @wraps(func)
-            def fallback_fn(*args, **kwargs):
-                return fallback
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    raise ValueError(
+                        "Please install `accelerate` in order to use this function"
+                    )
+
+            else:
+
+                @wraps(func)
+                def fallback_fn(*args, **kwargs):
+                    return fallback
 
             return fallback_fn
 
@@ -160,22 +171,22 @@ def update_parameter_data(
 
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`
 
     :param module: module to check, may be offloaded
    :return: onload device of module
     """
-    if has_offloaded_params(module):
-        return module._hf_hook.execution_device
+    for submodule in module.modules():
+        if has_offloaded_params(submodule):
+            return submodule._hf_hook.execution_device
 
-    first_param = next(module.parameters(), None)
-    if first_param is None:
-        warnings.warn(
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(submodule.parameters(recurse=False), None)
+        if param is not None:
+            return param.device
 
-    return first_param.device
+    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
+    return torch.device("cpu")
 
 
 def register_offload_parameter(
@@ -196,9 +207,24 @@ def register_offload_parameter(
     has_onload = any(p.device != torch.device("meta") for p in module.parameters())
     module.register_parameter(name, parameter)
 
+    # do everything AlignDevicesHook.init_hook does
+    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
     if has_offloaded_params(module):
-        weights_map = module._hf_hook.weights_map
-        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        hook: AlignDevicesHook = module._hf_hook
+        assert hook.weights_map is not None
+
+        # append to original_devices
+        hook.original_devices[name] = parameter.device
+
+        # append to weights map
+        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
+
+        # append to tied_params_map
+        offloaded = hook.weights_map[name]
+        if hook.tied_params_map is not None:
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+
+    # perform offloading
     if not has_onload:
         set_module_tensor_to_device(module, name, "meta")
 
@@ -206,7 +232,7 @@
 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -219,7 +245,7 @@
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
@@ -227,7 +253,7 @@
         )
 
     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)
 
     # update offload dict
@@ -412,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
         hook: AlignDevicesHook = base._hf_hook
         assert hook.offload
         assert hook.weights_map is not None
-        assert hook.tied_params_map is not None
 
         # offloading kwargs for submodule
         place_submodules = False
@@ -427,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
             module, include_buffers=offload_buffers, recurse=place_submodules
         ):
             offloaded = param.to(offload_device)
-            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            if hook.tied_params_map is not None:
+                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
             offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
 
         # if the parent places submodules, offload here
@@ -455,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
 
     base.register_module(name, module)
 
-    # (1): Since we cannot know which pointers are shared when we add parameters in an
-    # online way, assume that all pointers are shared. This comes at no runtime cost
-
 
 def delete_offload_module(base: torch.nn.Module, name: str):
     """
@@ -474,46 +497,106 @@ def delete_offload_module(base: torch.nn.Module, name: str):
 
 
 @check_accelerate(fallback="error")
-def force_cpu_offload(
-    module: torch.nn.Module, execution_device: torch.device
+def offloaded_dispatch(
+    module: torch.nn.Module,
+    execution_device: torch.device,
+    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
 ) -> torch.nn.Module:
     """
-    Force cpu offloading a module, primarily used for testing
+    Unlike `dispatch_model`, this function forces a module (and its submodules) to
+    offload all parameters and replace them with meta tensors, utiliizing the
+    `AlignDevicesHook` to control onloading and offloading.
 
     :param module: module containing parameters to offload
-    :param execution_device: execution device submodules
-    :return: module with hooks to perform cpu offloading
-    """
-    # edge case: there is a bug in `dispatch_model` which causes
-    # the function to only work if the model contains submodules
-    if next(module.children(), None) is None:
-        attach_align_device_hook(
-            module,
-            execution_device=execution_device,
-            offload=True,
-            weights_map=module.state_dict(),
-            tied_params_map={},
-        )
-        return module
+    :param execution_device: device that modules will onload and execute on
+    :param offload_device: device that module parameters will offload to
+    :return: module with offloading device hooks
+    """
+    if offload_device == "disk":
+        raise NotImplementedError("Disk offloading is not currently supported")
+
+    # remove any existing hooks
+    remove_dispatch(module)
+
+    # create weights map
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
+
+    # create tied params map
+    tied_params = find_tied_parameters(module)
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            data_ptr = attrgetter(param_name)(module).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # recursively attaches hooks to all submodules
+    attach_align_device_hook(
+        module,
+        execution_device=execution_device,
+        offload=True,
+        weights_map=weights_map,
+        tied_params_map=tied_params_map,
+    )
 
-    device_map = {}
+    # when saving a model, `PretrainedModel.save_pretrained` will only
+    # onload weights if the following requirements are met
+    # if (
+    #     hasattr(self, "hf_device_map")
+    #     and len(set(self.hf_device_map.values())) > 1
+    #     and ("cpu" in self.hf_device_map.values()
+    #     or "disk" in self.hf_device_map.values())
+    # ):
+    # because this function always offloads, disregard actual devices and
+    # always use `cpu` and `cuda:0` to guarantee this condition passes
+    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
 
-    def collect_device_map(name: List[str], module: torch.nn.Module):
-        if next(module.parameters(recurse=False), None) is not None:
-            device_map[".".join(name)] = "cpu"
-            return
+    return module
 
-        else:
-            for submodule_name, submodule in module.named_children():
-                name.append(submodule_name)
-                collect_device_map(name, submodule)
-                name.pop()
 
-    collect_device_map([], module)
+def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
+    """
+    Remove any existing dispatches from module
 
-    return dispatch_model(
-        module, device_map, main_device=execution_device, force_hooks=True
-    )
+    :param module: module which may be dispatched with hf hooks
+    :return: module without dispatch
+    """
+    remove_hook_from_module(module, recurse=True)
+    if hasattr(module, "hf_device_map"):
+        delattr(module, "hf_device_map")
+
+    return module
+
+
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters(recurse=False):
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)
 
 
 """ Upstreamed Functions """
@@ -583,3 +666,7 @@ def align_module_device(
 
     else:
         yield
+
+
+# (1): Since we cannot know which pointers are shared when we add parameters in an
+# online way, assume that all pointers are shared. This has virtually no runtime cost

compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.1.a20250604'
-__version_tuple__ = version_tuple = (0, 10, 1)
+__version__ = version = '0.10.2'
+__version_tuple__ = version_tuple = (0, 10, 2)

compressed_tensors-0.10.1a20250604.dist-info/METADATA → compressed_tensors-0.10.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.1a20250604
+Version: 0.10.2
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

compressed_tensors-0.10.1a20250604.dist-info/RECORD → compressed_tensors-0.10.2.dist-info/RECORD

@@ -1,11 +1,11 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=Fn_EICNqvyc3DZaaQYlxe0v5oHAtjV8CgHwL8zEhApQ,523
+compressed_tensors/version.py,sha256=hLMY-mTgNhuVqmjaSY9lkyEvKWbHFLs0gilovrevj8M,513
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=72h2tWDIGbbqLQF8MDzOehy18eu5TvsCLd_AuzGv_O4,32517
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=nt0KxhZakDdlTIebBYcSvqxLCZhA6p6IL_1AYiHLFug,32695
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=ByE3z61boZ5wdz0nhc-2CJH61bSixJQE78pfkS6XRDg,10269
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
@@ -13,7 +13,7 @@ compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=G
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=_66tQ8bxslDUdas-ULORXblPw9kdNNn1UJJU9-ZOGPY,11380
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
-compressed_tensors/compressors/sparse_compressors/dense.py,sha256=rPaxbP7P52prWNs4lGaiBbpNvsQLElFMwOrq1oBP2Yg,1733
+compressed_tensors/compressors/sparse_compressors/dense.py,sha256=-OujJ1e0iXBvxYVULrIGvAZ9l-IC0mXczZRnimQdgo4,2314
 compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py,sha256=4cwkj40SFrXEyE_jyt2xjz3R-gTdU9uMpMFUKo1pRBA,8643
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0FI9ep_XtUQOxj0P5utJt3vKEYOHjWEPp-Xd9aY,5820
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
@@ -30,7 +30,7 @@ compressed_tensors/quantization/quant_args.py,sha256=2OpiiSdl4KidzNmjx7J8UlQoAYm
 compressed_tensors/quantization/quant_config.py,sha256=aFi6PKqmEX9iP9O8GVn3mEUjRDEwk_hOCbmmiq-j9oU,10198
 compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF5hLz89J2PVYWBBnXRc,7102
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
-compressed_tensors/quantization/lifecycle/apply.py,sha256=DOoxH4jM8r0270GGGUFOpRrgwaisiJi7TV-Q6E8qM8E,18067
+compressed_tensors/quantization/lifecycle/apply.py,sha256=v7D0TJU_eLT20Odn_J1VCPo2twll2ra-wxlEGBKB2OA,17990
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
 compressed_tensors/quantization/lifecycle/forward.py,sha256=JWOQ-03bsgh9_nnOLAjmLZ0S8bFQA-GjwDK6YUBwcrU,14883
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
@@ -39,19 +39,28 @@ compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5
 compressed_tensors/quantization/utils/helpers.py,sha256=bqxNL2NU1XVsSxNzmDVZE3zd65PlLFq1Ir-RHwff8G0,17840
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
-compressed_tensors/transform/__init__.py,sha256=oa5VdrE-GtDYYceXNSwj5X_ropoXLLukm6Aufcc9WhY,747
-compressed_tensors/transform/transform_args.py,sha256=Sazu_4kXL7IvIEgTaimgo8dV-qacXf_t1NLEfDvPJEU,1759
+compressed_tensors/transform/__init__.py,sha256=mtUOzwq-H7fXGi7sMmfe7zj83fjMg_LAu4DjTZ5vaHk,886
+compressed_tensors/transform/transform_args.py,sha256=8-Ab5_dFfdObfwVCgrWrEWcoVRzXmMBSDSUxjftI-Ss,3177
 compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
 compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
+compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/transform/factory/base.py,sha256=yVrYWEnrr2RFWE5AjSNeXzO9aXc443dTNMVSxuLztz8,5940
+compressed_tensors/transform/factory/hadamard.py,sha256=zkq6w8uJXRLokUXajAkFb2fJrH0K3SL6qrR2dARrAr8,3139
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=0g4sYC_tOmCjOomae2gl54UTXiFdl0mCCkmbqIRX8yw,3613
+compressed_tensors/transform/factory/random_hadamard.py,sha256=TFInxbHslqREOFFiy_mpR88eEYXQnslxXmyh-ZbN-MU,1499
+compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/transform/utils/hadamard.py,sha256=U27Kvo-eDebKcVt8oXTSIAaQ5DvPQj9tDv2hdXHCPPQ,5584
+compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
+compressed_tensors/transform/utils/utils.py,sha256=PRPTYwPs2nnNaQMq2GEbC4QYKHFKlZwaRyPgdDhl66g,2992
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
-compressed_tensors/utils/helpers.py,sha256=RrNvzD08naEjEiXdU-FdZjQVda1nQywu1hA_GCDj0vg,10415
-compressed_tensors/utils/offload.py,sha256=hAGjp9aS0HpFVhjYMGf-WTm76WMY6cS-YXhVEn80qPE,20196
+compressed_tensors/utils/helpers.py,sha256=cPg-ikdeA92aIGwBONg8GmPNvcGlFhozyJVwsRiXBTA,11981
+compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.1a20250604.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.1a20250604.dist-info/METADATA,sha256=Y26QINEI1iAOl6vWp7KaEsgXHCNYlLIJPQXLRAe4KpQ,7005
-compressed_tensors-0.10.1a20250604.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.1a20250604.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.1a20250604.dist-info/RECORD,,
+compressed_tensors-0.10.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.2.dist-info/METADATA,sha256=EJZIGQMfAp5b8c9p9vblRCx6hQPg32jHoDGhNv5W96k,6996
+compressed_tensors-0.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.2.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.2.dist-info/RECORD,,