compressed-tensors 0.10.3a20250731__py3-none-any.whl → 0.10.3a20250806__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- compressed_tensors/quantization/lifecycle/forward.py +6 -1
- compressed_tensors/transform/factory/base.py +48 -3
- compressed_tensors/transform/factory/matrix_multiply.py +1 -0
- compressed_tensors/utils/match.py +67 -13
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/METADATA +1 -1
- {compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/RECORD +10 -10
- {compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py CHANGED
@@ -124,8 +124,13 @@ def dequantize(
                     strategy=QuantizationStrategy.GROUP, group_size=group_size
                 )
             else:
+                rows, cols = x_q.shape[-2], x_q.shape[-1]
+                block_height = rows // scale.shape[0]  # Rows per block
+                block_width = cols // scale.shape[1]  # Columns per block
+
                 args = QuantizationArgs(
-                    strategy=QuantizationStrategy.BLOCK,
+                    strategy=QuantizationStrategy.BLOCK,
+                    block_structure=[block_height, block_width],
                 )
         else:
             raise ValueError(
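The new branch infers a `block_structure` from the shapes of the quantized tensor and its scale, so block-quantized tensors can be dequantized without explicit args. A minimal sketch of that arithmetic, using hypothetical shapes (not taken from the package):

```python
import torch

# hypothetical: a (128, 256) quantized weight with one scale per (32, 128) block
x_q = torch.zeros(128, 256, dtype=torch.int8)
scale = torch.ones(4, 2)

rows, cols = x_q.shape[-2], x_q.shape[-1]
block_height = rows // scale.shape[0]  # 128 // 4 = 32 rows per block
block_width = cols // scale.shape[1]   # 256 // 2 = 128 columns per block

assert [block_height, block_width] == [32, 128]
```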
compressed_tensors/transform/factory/base.py CHANGED
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from
+from collections import defaultdict
+from typing import List, Optional, Tuple, Set
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -90,6 +94,8 @@ class TransformFactory(RegistryMixin, ABC):
             for _, module in match_named_modules(model, arg.targets, arg.ignore):
                 self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -97,9 +103,17 @@ class TransformFactory(RegistryMixin, ABC):
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -128,8 +142,9 @@ class TransformFactory(RegistryMixin, ABC):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
-
-
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -143,6 +158,31 @@ class TransformFactory(RegistryMixin, ABC):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+
 
 class TransformBase(InternalModule, ABC):
     """
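`_update_tied_weights` records parameters that share storage by grouping them on `data_ptr()`, then writes the shared names into `_dynamic_tied_weights_keys` so they can be deduplicated at save time. A minimal sketch of the same grouping idea on plain modules (the two `Linear` layers below are hypothetical, not factory transforms):

```python
from collections import defaultdict

import torch
from torch import nn

a = nn.Linear(4, 4, bias=False)
b = nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie both modules to the same underlying tensor

ptr_to_keys = defaultdict(list)
for module in (a, b):
    for name, param in module.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((module, name))

# groups with more than one entry are tied weights
tied = [keys for keys in ptr_to_keys.values() if len(keys) > 1]
assert len(tied) == 1 and len(tied[0]) == 2
```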
@@ -151,6 +191,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
compressed_tensors/transform/factory/matrix_multiply.py CHANGED
@@ -70,6 +70,7 @@ class RandomMatrixFactory(TransformFactory):
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)
 
 
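The added `.contiguous()` call guards against handing a non-contiguous inverse to the serializer; safetensors, for example, refuses to save tensors that are not contiguous in memory. A minimal sketch of the failure mode with a hypothetical tensor (independent of the factory's actual weights):

```python
import torch

weight = torch.randn(8, 8)
inverse = torch.linalg.inv(weight).T  # a transposed view is not contiguous

assert not inverse.is_contiguous()
data = inverse.contiguous()  # copy into contiguous memory before serializing
assert data.is_contiguous()
```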
compressed_tensors/utils/match.py CHANGED
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@ __all__ = [
 ]
 
 
+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
 
-        if not any(is_match(name, module, ign) for ign in ignore):
+        if not any(is_match(name, module, ign, fused) for ign in ignore):
             yield name, module
 
     if warn_on_fail:
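The new `fused` parameter lets matching treat vLLM-style fused module names (e.g. `qkv_proj`) as if their unfused shards were present, using a mapping shaped like vLLM's `packed_modules_mapping`. A minimal usage sketch, assuming a toy model whose q/k/v projection is already fused (module names are illustrative):

```python
import torch
from compressed_tensors.utils.match import match_named_modules


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = torch.nn.Linear(8, 24)  # fused q/k/v projection


model = torch.nn.ModuleDict({"attn": Attention()})
fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# "attn.qkv_proj" matches because the unfused "attn.q_proj" would match the target
matched = [name for name, _ in match_named_modules(model, ["re:.*q_proj$"], fused=fused)]
print(matched)  # ['attn.qkv_proj']
```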
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@ def match_named_parameters(
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}
 
-            if not any(_match_name(param_fqn, ign) for ign in ignore):
+            if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                 yield param_fqn, module, param
 
     if warn_on_fail:
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
 
 
-def is_match(
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )
 
 
-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
    """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
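The fused branch of `_match_name` strips the fused suffix and retries the match with each shard suffix substituted in. A self-contained sketch of that substitution, re-implemented here for illustration rather than imported from the package:

```python
import re
from typing import Iterable, Mapping, Optional


def match_name(name: str, target: str, fused: Optional[Mapping[str, Iterable[str]]] = None) -> bool:
    """Mirror of the fused-suffix substitution described above."""
    if fused is not None:
        for fused_suffix in fused:
            if name.endswith(fused_suffix):
                stem = name.removesuffix(fused_suffix)
                return any(match_name(stem + shard, target) for shard in fused[fused_suffix])
    if target.startswith("re:"):
        return re.match(target.removeprefix("re:"), name) is not None
    return name == target


fused = {"gate_up_proj": ["gate_proj", "up_proj"]}
assert match_name("model.layers.0.mlp.gate_up_proj", "re:.*up_proj$", fused)
assert not match_name("model.layers.0.mlp.gate_up_proj", "re:.*down_proj$", fused)
```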
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
 
 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
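With the `_match_class` change, a target of `Linear` also catches vLLM linear layers, which derive from vLLM's `LinearBase` (a `torch.nn.Module` subclass) rather than from `torch.nn.Linear`. A stand-in sketch of that check; the `LinearBase` and `RowParallelLinear` classes below are mocks, not vLLM imports:

```python
import torch


class LinearBase(torch.nn.Module):  # stand-in for vllm's LinearBase
    pass


class RowParallelLinear(LinearBase):  # stand-in for a vLLM linear layer
    pass


def matches_class(module: torch.nn.Module, target: str) -> bool:
    # mirror of the updated _match_class logic
    return any(
        issubclass(cls, torch.nn.Module)
        and (cls.__name__ == target or (cls.__name__ == "LinearBase" and target == "Linear"))
        for cls in module.__class__.__mro__
    )


assert matches_class(RowParallelLinear(), "Linear")        # matched via the LinearBase exception
assert not matches_class(torch.nn.LayerNorm(8), "Linear")  # unrelated modules still excluded
```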
compressed_tensors/version.py CHANGED

{compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.
+Version: 0.10.3a20250806
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/RECORD RENAMED
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=
+compressed_tensors/version.py,sha256=AuoKIjSgjjAcZIPZe3HN5zhNJ7enhDAjwQrqUHPg76o,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -32,7 +32,7 @@ compressed_tensors/quantization/quant_scheme.py,sha256=xk2LPn18tjS1PEOyf0WKvavBq
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=HzfoRkK3CkEHuCqRWatq0kyu5sFx8ULZHNmmjRNIpWI,17571
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=BM7bR_uNa-Ex4T-roHonWiRaxCi5sFysXyl0cFh1ZVs,10257
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -45,9 +45,9 @@ compressed_tensors/transform/transform_args.py,sha256=jJY-Qt996w45LWQ10AHd7tUtNr
 compressed_tensors/transform/transform_config.py,sha256=A3RuLNDqBNEByQNeu40Kg7sItwE6kWgnX18Umg1uONI,2128
 compressed_tensors/transform/transform_scheme.py,sha256=uGLC4avdbhrVqNC3-Eo0p7WzNRQK92Fpg0N9hWiuCRQ,1752
 compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/factory/base.py,sha256=
+compressed_tensors/transform/factory/base.py,sha256=NJ3lI95tJk6gHOeZEVheQ_Ae7NHhhUG_9FHXu613x30,7740
 compressed_tensors/transform/factory/hadamard.py,sha256=B0BVjbF3y707MO6L2XfEoZJTQU965vU9dUPLOiUSXII,4193
-compressed_tensors/transform/factory/matrix_multiply.py,sha256=
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=kCB7cfM_PCgJDyyhg2d1rKTEiyuscwzhprXY7VfIx6E,3989
 compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Umme71pnjMPgwYoGlwjKlU27UHZ4,1634
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
@@ -56,14 +56,14 @@ compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiP
 compressed_tensors/utils/__init__.py,sha256=KZctuotCmX4byXhwDvSeXgp-Ny_awpziAX-WUkZfodI,853
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
-compressed_tensors/utils/match.py,sha256=
+compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
 compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
-compressed_tensors-0.10.
+compressed_tensors-0.10.3a20250806.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250806.dist-info/METADATA,sha256=e8DIx-6UDn2Wj7fGLEBgVru2k9Tme9dOPgxS_ciZDcw,7031
+compressed_tensors-0.10.3a20250806.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250806.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250806.dist-info/RECORD,,
{compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/WHEEL RENAMED
File without changes

{compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/licenses/LICENSE RENAMED
File without changes

{compressed_tensors-0.10.3a20250731.dist-info → compressed_tensors-0.10.3a20250806.dist-info}/top_level.txt RENAMED
File without changes