compressed-tensors 0.10.3a20250728__py3-none-any.whl → 0.10.3a20250805__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
--- a/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -61,6 +61,27 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         "weight_global_scale",
     )
 
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
+        }
+        return output
+
     def compress_weight(
         self,
         weight: Tensor,
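For orientation, the new `compression_param_info` reports only shape and dtype metadata. A minimal sketch of the arithmetic, assuming a hypothetical (4096, 4096) weight, where two FP4 values pack into each uint8 byte along the last dim:

```python
import torch

# Two FP4 values pack into one uint8 byte, halving the last dimension.
# The (4096, 4096) weight shape is a made-up example, not from the diff.
weight_shape = torch.Size((4096, 4096))
expected = {
    "weight_packed": (
        torch.Size((weight_shape[0], weight_shape[1] // 2)),
        torch.uint8,
    ),
}
print(expected)  # {'weight_packed': (torch.Size([4096, 2048]), torch.uint8)}
```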
--- a/compressed_tensors/quantization/lifecycle/forward.py
+++ b/compressed_tensors/quantization/lifecycle/forward.py
@@ -124,8 +124,13 @@ def dequantize(
                     strategy=QuantizationStrategy.GROUP, group_size=group_size
                 )
             else:
+                rows, cols = x_q.shape[-2], x_q.shape[-1]
+                block_height = rows // scale.shape[0]  # Rows per block
+                block_width = cols // scale.shape[1]  # Columns per block
+
                 args = QuantizationArgs(
-                    strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                    strategy=QuantizationStrategy.BLOCK,
+                    block_structure=[block_height, block_width],
                 )
         else:
             raise ValueError(
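This change infers per-block dimensions from the ratio of tensor shape to scale-grid shape instead of passing `scale.shape` itself as the block structure. A quick sketch with hypothetical shapes:

```python
import torch

# A (256, 512) quantized tensor with a (2, 4) grid of scales implies
# 128x128 blocks; previously scale.shape, i.e. (2, 4), was passed as
# the block structure itself.
x_q = torch.empty(256, 512)
scale = torch.empty(2, 4)
block_height = x_q.shape[-2] // scale.shape[0]  # 128 rows per block
block_width = x_q.shape[-1] // scale.shape[1]   # 128 columns per block
assert [block_height, block_width] == [128, 128]
```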
@@ -257,13 +262,10 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
 
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]
 
         # TODO: make validation step for inputs
 
@@ -293,14 +295,12 @@ def _process_quantization(
         perm = torch.argsort(g_idx)
         x = safe_permute(x, perm, dim=1)
 
-        x = torch.reshape(
-            x,
-            (
-                x.shape[0],
-                ceil(x.shape[1] / group_size),
-                group_size,
-            ),
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = (
+            ceil(x.shape[-1] / group_size),
+            group_size,
         )
+        x = x.unflatten(-1, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
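Together with the `columns = output.shape[-1]` change above, indexing only the last dim removes the old `squeeze(0)`/`unsqueeze(0)` workaround: `Tensor.unflatten` splits the last dim into `(num_groups, group_size)` while leaving any leading batch dims intact. A small sketch under hypothetical shapes:

```python
import torch
from math import ceil

group_size = 128
x = torch.randn(4, 128, 4096)  # hypothetical batched activation
# Split only the last dim into (num_groups, group_size); leading dims survive.
x = x.unflatten(-1, (ceil(x.shape[-1] / group_size), group_size))
assert x.shape == (4, 128, 32, 128)
```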
@@ -323,19 +323,12 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        output = torch.reshape(
-            output,
-            (output.shape[0], output.shape[1] * output.shape[2]),
-        )
-
+        output = output.flatten(start_dim=-2)
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
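`flatten(start_dim=-2)` merges the `(num_groups, group_size)` pair back into a single column dim, the inverse of the `unflatten` above, again independent of leading dims. Continuing the hypothetical shapes:

```python
import torch

output = torch.randn(4, 128, 32, 128)  # hypothetical grouped output
# Merge (num_groups, group_size) back into the original last dim.
output = output.flatten(start_dim=-2)
assert output.shape == (4, 128, 4096)
```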
--- a/compressed_tensors/quantization/utils/helpers.py
+++ b/compressed_tensors/quantization/utils/helpers.py
@@ -175,20 +175,16 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        if len(value.shape) > 2:
-            value = value.squeeze(0)
 
-        dim = {0, 1}
-        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        reduce_dims = -1
         keep_dims = False
-        value = torch.reshape(
-            value,
-            (
-                value.shape[0],
-                math.ceil(value.shape[1] / args.group_size),
-                args.group_size,
-            ),
+
+        reshaped_dims = (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
         )
+        value = value.unflatten(-1, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,
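With the group dim isolated at the end by `unflatten`, a single `reduce_dims = -1` yields per-group statistics for any number of leading dims. A sketch of the resulting shape, using made-up sizes:

```python
import math
import torch

group_size = 128
value = torch.randn(4, 128, 4096)  # hypothetical batched activation
value = value.unflatten(-1, (math.ceil(value.shape[-1] / group_size), group_size))
# One statistic per (batch, token, group) after reducing the last dim.
group_absmax = value.abs().amax(dim=-1)
assert group_absmax.shape == (4, 128, 32)
```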
--- a/compressed_tensors/utils/match.py
+++ b/compressed_tensors/utils/match.py
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@ __all__ = [
 ]
 
 
+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
 
-        if not any(is_match(name, module, ign) for ign in ignore):
+        if not any(is_match(name, module, ign, fused) for ign in ignore):
             yield name, module
 
     if warn_on_fail:
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@ def match_named_parameters(
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}
 
-            if not any(_match_name(param_fqn, ign) for ign in ignore):
+            if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                 yield param_fqn, module, param
 
     if warn_on_fail:
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
 
 
-def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )
 
 
-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-    regex matches or if target string exactly matches name
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
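To make the fused matching concrete, here is a hedged usage sketch of the internal `_match_name` helper shown above (module and target names are hypothetical, mirroring the vLLM mapping quoted in the docstring; requires this release of the library):

```python
from compressed_tensors.utils.match import _match_name

fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# The fused suffix is stripped and each shard suffix is tried against the
# target, so targets written for unfused checkpoints still match.
assert _match_name("model.layers.0.self_attn.qkv_proj", "re:.*q_proj", fused)
assert not _match_name("model.layers.0.self_attn.qkv_proj", "re:.*o_proj", fused)
```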
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
 
 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
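`_match_class` walks the module's MRO, so a target of `"Linear"` matches any `torch.nn.Linear` subclass; the new clause additionally lets vLLM's `LinearBase`, which (as far as we can tell) subclasses `torch.nn.Module` rather than `torch.nn.Linear`, satisfy the same target. A minimal sketch of the MRO walk:

```python
import torch

class MyLinear(torch.nn.Linear):  # hypothetical subclass
    pass

# "Linear" appears among the MRO class names, so the target matches.
mro_names = [cls.__name__ for cls in MyLinear.__mro__]
assert "Linear" in mro_names
```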
--- a/compressed_tensors/version.py
+++ b/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.3.a20250728'
+__version__ = version = '0.10.3.a20250805'
 __version_tuple__ = version_tuple = (0, 10, 3)
--- a/compressed_tensors-0.10.3a20250728.dist-info/METADATA
+++ b/compressed_tensors-0.10.3a20250805.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250728
+Version: 0.10.3a20250805
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- a/compressed_tensors-0.10.3a20250728.dist-info/RECORD
+++ b/compressed_tensors-0.10.3a20250805.dist-info/RECORD
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=EY3NpvLIsm31BPA-e32djbQIUYdm3sP8W28lHH72d0Y,523
+compressed_tensors/version.py,sha256=UcH3DkUtSV6xgd1l5QTWXLV_iWa7GzNrCWIOpZvkzkE,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -9,7 +9,7 @@ compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=e-2n
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=YGUMzbxekj_36ChgQnVZN6T8uDjXtGG1zfMIBGBLWco,10354
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
-compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
+compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=tKEaYom4SdMwZWg4MDMMMLNGTLgcVT20lPzewboVpMM,7145
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=47W1hFTi5YHVNKEWptzztsSutwI1kxy2Troh-NW1y14,11244
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
@@ -32,11 +32,11 @@ compressed_tensors/quantization/quant_scheme.py,sha256=xk2LPn18tjS1PEOyf0WKvavBq
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=V98jWzb3rfV91EC6kfzAyXtmnbLjNF01Rd_EHU2bLo8,17506
+compressed_tensors/quantization/lifecycle/forward.py,sha256=HzfoRkK3CkEHuCqRWatq0kyu5sFx8ULZHNmmjRNIpWI,17571
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=BM7bR_uNa-Ex4T-roHonWiRaxCi5sFysXyl0cFh1ZVs,10257
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=Je96Wai9SOizbdE5ph0nsJ86zS96lE4fkf_9q9o2tpA,17212
+compressed_tensors/quantization/utils/helpers.py,sha256=7a89X0kg6xDGplw6trOrkRQzMRPu-txY_qvEt07Vcgc,17036
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
 compressed_tensors/transform/__init__.py,sha256=v2wfl4CMfA6KbD7Hxx_MbRev63y_6QLDlccZq-WTtdw,907
@@ -56,14 +56,14 @@ compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiP
 compressed_tensors/utils/__init__.py,sha256=KZctuotCmX4byXhwDvSeXgp-Ny_awpziAX-WUkZfodI,853
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
-compressed_tensors/utils/match.py,sha256=ZVBPzrGYExq7-6RRUlU5XeCjl0ooLaNUoDO6Cgnn9cY,7220
+compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
 compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.3a20250728.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250728.dist-info/METADATA,sha256=rQbbrFahVspKPEfY86EpebdjgoYAtSyyH7JLOPTPcrg,7031
-compressed_tensors-0.10.3a20250728.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250728.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250728.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250805.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250805.dist-info/METADATA,sha256=8SpvZ9SNB_DGL6L4I8QrtLczHtxI17ezOlwf6Ew_4R8,7031
+compressed_tensors-0.10.3a20250805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250805.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250805.dist-info/RECORD,,