compressed-tensors 0.10.3a20250731__py3-none-any.whl → 0.10.3a20250805__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
--- a/compressed_tensors/quantization/lifecycle/forward.py
+++ b/compressed_tensors/quantization/lifecycle/forward.py
@@ -124,8 +124,13 @@ def dequantize(
                     strategy=QuantizationStrategy.GROUP, group_size=group_size
                 )
             else:
+                rows, cols = x_q.shape[-2], x_q.shape[-1]
+                block_height = rows // scale.shape[0]  # Rows per block
+                block_width = cols // scale.shape[1]  # Columns per block
+
                 args = QuantizationArgs(
-                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                    strategy=QuantizationStrategy.BLOCK,
+                    block_structure=[block_height, block_width],
                 )
         else:
             raise ValueError(
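In this hunk, `dequantize` stops passing `scale.shape` (the number of blocks per dimension) as `block_structure` and instead derives the size of each block from the tensor and scale shapes. A minimal standalone sketch of that arithmetic, assuming a 2-D quantized tensor with one scale per block (`infer_block_structure` is a hypothetical helper, not part of the library's API):

```python
import torch

def infer_block_structure(x_q: torch.Tensor, scale: torch.Tensor) -> list[int]:
    # One scale per block: the scale grid has
    # (rows / block_height) x (cols / block_width) entries
    rows, cols = x_q.shape[-2], x_q.shape[-1]
    block_height = rows // scale.shape[0]  # rows covered by each scale entry
    block_width = cols // scale.shape[1]   # columns covered by each scale entry
    return [block_height, block_width]

# e.g. a 256x512 weight quantized in 128x128 blocks carries a 2x4 scale grid
x_q = torch.zeros(256, 512, dtype=torch.int8)
scale = torch.ones(2, 4)
assert infer_block_structure(x_q, scale) == [128, 128]
```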
--- a/compressed_tensors/utils/match.py
+++ b/compressed_tensors/utils/match.py
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple

 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@ __all__ = [
 ]


+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}

-                if not any(is_match(name, module, ign) for ign in ignore):
+                if not any(is_match(name, module, ign, fused) for ign in ignore):
                     yield name, module

     if warn_on_fail:
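The new `fused` parameter threads through to `is_match`, so targets written against an unfused checkpoint can still match vLLM-style fused modules. A usage sketch, assuming the import path referenced in the docstrings above; the toy model and its module names are hypothetical:

```python
import torch
from compressed_tensors.utils.match import match_named_modules

class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Single fused projection standing in for separate q/k/v shards
        self.qkv_proj = torch.nn.Linear(64, 192)

model = torch.nn.ModuleDict({"attn": ToyAttention()})

# Suffix-to-shards mapping, mirroring vLLM's packed_modules_mapping
fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# "re:.*q_proj$" matches no module name directly, but resolves to
# "attn.qkv_proj" through the fused mapping
matches = list(match_named_modules(model, ["re:.*q_proj$"], fused=fused))
print(matches)  # [("attn.qkv_proj", Linear(...))]
```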
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@ def match_named_parameters(
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}

-                    if not any(_match_name(param_fqn, ign) for ign in ignore):
+                    if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                         yield param_fqn, module, param

     if warn_on_fail:
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")


-def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )


-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-    regex matches or if target string exactly matches name
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
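`_match_name` resolves a fused name by stripping the fused suffix and retrying the match once per shard suffix, so a match on any packed shard counts as a match on the fused module. A self-contained re-implementation of that logic for illustration (the real helper is the private `_match_name` above):

```python
import re

# Standalone copy of the fused-name resolution shown in the hunk above,
# for illustration only.
def match_name(name: str, target: str, fused=None) -> bool:
    if fused is not None:
        for fused_suffix in fused:
            if name.endswith(fused_suffix):
                stripped = name.removesuffix(fused_suffix)
                # Retry the match once per shard that was packed into the
                # fused module, e.g. qkv_proj -> q_proj, k_proj, v_proj
                return any(
                    match_name(stripped + shard, target)
                    for shard in fused[fused_suffix]
                )
    if target.startswith("re:"):
        return re.match(target.removeprefix("re:"), name) is not None
    return target == name

fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
assert match_name("model.layers.0.self_attn.qkv_proj",
                  "re:.*k_proj$", fused)      # resolves via the k_proj shard
assert not match_name("model.layers.0.self_attn.qkv_proj",
                      "re:.*o_proj$", fused)  # no shard matches
```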
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:

 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
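`_match_class` walks the class MRO, so the `LinearBase` exception lets vLLM's parallel linear layers, which do not inherit from `torch.nn.Linear`, satisfy a plain `Linear` target. A standalone sketch using local stand-in classes that share vLLM's names (so the special case fires); `match_class` mirrors the private helper above:

```python
import torch

# Local stand-in for vLLM's LinearBase; only the class name matters here.
class LinearBase(torch.nn.Module):
    pass

class ColumnParallelLinear(LinearBase):
    pass

def match_class(module: torch.nn.Module, target: str) -> bool:
    return any(
        issubclass(cls, torch.nn.Module)
        and (
            cls.__name__ == target
            or (cls.__name__ == "LinearBase" and target == "Linear")
        )
        for cls in module.__class__.__mro__
    )

# A vLLM-style parallel linear now satisfies a plain "Linear" target,
# even though torch.nn.Linear is nowhere in its MRO.
assert match_class(ColumnParallelLinear(), "Linear")
assert match_class(torch.nn.Linear(8, 8), "Linear")
```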
--- a/compressed_tensors/version.py
+++ b/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.10.3.a20250731'
+__version__ = version = '0.10.3.a20250805'
 __version_tuple__ = version_tuple = (0, 10, 3)
--- a/compressed_tensors-0.10.3a20250731.dist-info/METADATA
+++ b/compressed_tensors-0.10.3a20250805.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250731
+Version: 0.10.3a20250805
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- a/compressed_tensors-0.10.3a20250731.dist-info/RECORD
+++ b/compressed_tensors-0.10.3a20250805.dist-info/RECORD
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=cuOuj6FL5GE-iPKjLVFuRjlwW0_6uDC3tDxFkkHyXFg,523
+compressed_tensors/version.py,sha256=UcH3DkUtSV6xgd1l5QTWXLV_iWa7GzNrCWIOpZvkzkE,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -32,7 +32,7 @@ compressed_tensors/quantization/quant_scheme.py,sha256=xk2LPn18tjS1PEOyf0WKvavBq
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=lQwibkDGroJqONhP9ATZWwaZF9suPmCZMQEagFlFc94,17329
+compressed_tensors/quantization/lifecycle/forward.py,sha256=HzfoRkK3CkEHuCqRWatq0kyu5sFx8ULZHNmmjRNIpWI,17571
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=BM7bR_uNa-Ex4T-roHonWiRaxCi5sFysXyl0cFh1ZVs,10257
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -56,14 +56,14 @@ compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiP
 compressed_tensors/utils/__init__.py,sha256=KZctuotCmX4byXhwDvSeXgp-Ny_awpziAX-WUkZfodI,853
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
-compressed_tensors/utils/match.py,sha256=ZVBPzrGYExq7-6RRUlU5XeCjl0ooLaNUoDO6Cgnn9cY,7220
+compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
 compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.3a20250731.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250731.dist-info/METADATA,sha256=1NCpfVbLTf6aGJ38rJz3Lmu9DptHpuYm5vTRxIB9PB8,7031
-compressed_tensors-0.10.3a20250731.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250731.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250731.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250805.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250805.dist-info/METADATA,sha256=8SpvZ9SNB_DGL6L4I8QrtLczHtxI17ezOlwf6Ew_4R8,7031
+compressed_tensors-0.10.3a20250805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250805.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250805.dist-info/RECORD,,