compressed-tensors 0.10.3a20250731__py3-none-any.whl → 0.10.3a20250806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -124,8 +124,13 @@ def dequantize(
                     strategy=QuantizationStrategy.GROUP, group_size=group_size
                 )
             else:
+                rows, cols = x_q.shape[-2], x_q.shape[-1]
+                block_height = rows // scale.shape[0]  # Rows per block
+                block_width = cols // scale.shape[1]  # Columns per block
+
                 args = QuantizationArgs(
-                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                    strategy=QuantizationStrategy.BLOCK,
+                    block_structure=[block_height, block_width],
                 )
         else:
             raise ValueError(
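The hunk above changes how `block_structure` is inferred during dequantization: previously `scale.shape` (the number of blocks along each dimension) was passed where block dimensions are expected; the new code derives rows and columns per block from the quantized tensor and the scale grid. A minimal sketch of that arithmetic, using made-up shapes rather than anything taken from the package:

```python
import torch

# hypothetical shapes, for illustration only
x_q = torch.empty(128, 256)   # quantized weight
scale = torch.empty(4, 2)     # one scale per block -> a 4 x 2 grid of blocks

rows, cols = x_q.shape[-2], x_q.shape[-1]
block_height = rows // scale.shape[0]   # 128 // 4 = 32 rows per block
block_width = cols // scale.shape[1]    # 256 // 2 = 128 columns per block

assert [block_height, block_width] == [32, 128]
```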
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from collections import defaultdict
+from typing import List, Optional, Tuple, Set
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -90,6 +94,8 @@ class TransformFactory(RegistryMixin, ABC):
         for _, module in match_named_modules(model, arg.targets, arg.ignore):
             self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -97,9 +103,17 @@ class TransformFactory(RegistryMixin, ABC):
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -128,8 +142,9 @@ class TransformFactory(RegistryMixin, ABC):
                     raise ValueError("Offloaded training is not supported")
                 P.register_parametrization(module, "weight", transform)
 
-            # transform is no longer needed (unfusing is not supported)
-            delete_offload_module(module, transform_name)
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -143,6 +158,31 @@ class TransformFactory(RegistryMixin, ABC):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+
 
 class TransformBase(InternalModule, ABC):
     """
@@ -151,6 +191,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
@@ -70,6 +70,7 @@ class RandomMatrixFactory(TransformFactory):
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)
 
 
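The added `.contiguous()` call guards against returning an inverse with non-standard strides, which downstream serialization (for example safetensors-style saving) generally expects to be contiguous. An illustration of the distinction, using a transposed view as a common source of non-contiguity (not the package's inversion code):

```python
import torch

data = torch.randn(8, 8).t()    # a transposed view: same storage, swapped strides
print(data.is_contiguous())     # False
data = data.contiguous()        # copy into standard row-major layout before saving
print(data.is_contiguous())     # True
```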
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@ __all__ = [
 ]
 
 
+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
 
-                if not any(is_match(name, module, ign) for ign in ignore):
+                if not any(is_match(name, module, ign, fused) for ign in ignore):
                     yield name, module
 
     if warn_on_fail:
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@ def match_named_parameters(
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}
 
-                    if not any(_match_name(param_fqn, ign) for ign in ignore):
+                    if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                         yield param_fqn, module, param
 
     if warn_on_fail:
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
 
 
-def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )
 
 
-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-    regex matches or if target string exactly matches name
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
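The new `fused` handling rewrites a fused module name into its shard names and accepts the match if any shard name satisfies the target. A runnable, self-contained version of that rewrite is sketched below; `match_name` here is a local illustration of the logic in the hunk above, not the compressed-tensors API:

```python
import re
from typing import Iterable, Mapping, Optional


def match_name(name: str, target: str,
               fused: Optional[Mapping[str, Iterable[str]]] = None) -> bool:
    # rewrite a fused-module name into its shard names and match any of them
    if fused is not None:
        for fused_suffix, shard_suffixes in fused.items():
            if name.endswith(fused_suffix):
                stem = name.removesuffix(fused_suffix)
                return any(
                    match_name(stem + shard_suffix, target)
                    for shard_suffix in shard_suffixes
                )

    if target.startswith("re:"):
        return re.match(target.removeprefix("re:"), name) is not None
    return target == name


fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(match_name("model.layers.0.self_attn.qkv_proj", "re:.*q_proj", fused))  # True
print(match_name("model.layers.0.self_attn.qkv_proj", "re:.*o_proj", fused))  # False
```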
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
 
 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
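With this change, a module whose MRO contains a class literally named `LinearBase` (as vLLM's tensor-parallel linear layers do) now satisfies a `"Linear"` target. A minimal sketch, with dummy classes standing in for the vLLM ones:

```python
import torch


class LinearBase(torch.nn.Module):
    """Stand-in for vLLM's LinearBase; only the class name matters here."""


class RowParallelLinear(LinearBase):
    """Stand-in for a vLLM tensor-parallel linear layer."""


def matches(module: torch.nn.Module, target: str) -> bool:
    # mirrors the updated _match_class: exact class-name match, plus the
    # LinearBase -> "Linear" exception
    return any(
        issubclass(cls, torch.nn.Module)
        and (cls.__name__ == target
             or (cls.__name__ == "LinearBase" and target == "Linear"))
        for cls in module.__class__.__mro__
    )


print(matches(torch.nn.Linear(2, 2), "Linear"))  # True: exact class-name match
print(matches(RowParallelLinear(), "Linear"))    # True: via the LinearBase exception
```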
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.3.a20250731'
+__version__ = version = '0.10.3.a20250806'
 __version_tuple__ = version_tuple = (0, 10, 3)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250731
+Version: 0.10.3a20250806
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=cuOuj6FL5GE-iPKjLVFuRjlwW0_6uDC3tDxFkkHyXFg,523
+compressed_tensors/version.py,sha256=AuoKIjSgjjAcZIPZe3HN5zhNJ7enhDAjwQrqUHPg76o,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -32,7 +32,7 @@ compressed_tensors/quantization/quant_scheme.py,sha256=xk2LPn18tjS1PEOyf0WKvavBq
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=lQwibkDGroJqONhP9ATZWwaZF9suPmCZMQEagFlFc94,17329
+compressed_tensors/quantization/lifecycle/forward.py,sha256=HzfoRkK3CkEHuCqRWatq0kyu5sFx8ULZHNmmjRNIpWI,17571
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=BM7bR_uNa-Ex4T-roHonWiRaxCi5sFysXyl0cFh1ZVs,10257
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -45,9 +45,9 @@ compressed_tensors/transform/transform_args.py,sha256=jJY-Qt996w45LWQ10AHd7tUtNr
 compressed_tensors/transform/transform_config.py,sha256=A3RuLNDqBNEByQNeu40Kg7sItwE6kWgnX18Umg1uONI,2128
 compressed_tensors/transform/transform_scheme.py,sha256=uGLC4avdbhrVqNC3-Eo0p7WzNRQK92Fpg0N9hWiuCRQ,1752
 compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/factory/base.py,sha256=Zplf8QO-mFqGwDEhLdYL_afSu7v4nMa79oNhidRNPvY,5880
+compressed_tensors/transform/factory/base.py,sha256=NJ3lI95tJk6gHOeZEVheQ_Ae7NHhhUG_9FHXu613x30,7740
 compressed_tensors/transform/factory/hadamard.py,sha256=B0BVjbF3y707MO6L2XfEoZJTQU965vU9dUPLOiUSXII,4193
-compressed_tensors/transform/factory/matrix_multiply.py,sha256=LdoV2E12HTucmUWcw7UKOpRNnL8QhOOIUnNVlpOpGiI,3925
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=kCB7cfM_PCgJDyyhg2d1rKTEiyuscwzhprXY7VfIx6E,3989
 compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Umme71pnjMPgwYoGlwjKlU27UHZ4,1634
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
@@ -56,14 +56,14 @@ compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiP
 compressed_tensors/utils/__init__.py,sha256=KZctuotCmX4byXhwDvSeXgp-Ny_awpziAX-WUkZfodI,853
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
-compressed_tensors/utils/match.py,sha256=ZVBPzrGYExq7-6RRUlU5XeCjl0ooLaNUoDO6Cgnn9cY,7220
+compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
 compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.3a20250731.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250731.dist-info/METADATA,sha256=1NCpfVbLTf6aGJ38rJz3Lmu9DptHpuYm5vTRxIB9PB8,7031
-compressed_tensors-0.10.3a20250731.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250731.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250731.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250806.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250806.dist-info/METADATA,sha256=e8DIx-6UDn2Wj7fGLEBgVru2k9Tme9dOPgxS_ciZDcw,7031
+compressed_tensors-0.10.3a20250806.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250806.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250806.dist-info/RECORD,,