compressed-tensors 0.10.3a20250811__py3-none-any.whl → 0.10.3a20250814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -169,7 +169,7 @@ class ModelCompressor:
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
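
Note: with this change, `quantization_format` accepts either a single format string or a list of per-scheme formats. A minimal usage sketch, assuming this hunk is the `from_pretrained_model` classmethod (the hunk header does not show the method name) and a `model` whose modules already carry quantization schemes:

    from compressed_tensors.compressors import ModelCompressor

    # Old behavior still works: one global format for every quantized module.
    compressor = ModelCompressor.from_pretrained_model(
        model, quantization_format="pack-quantized"
    )

    # New: several formats at once, one compressor loaded per entry.
    compressor = ModelCompressor.from_pretrained_model(
        model, quantization_format=["pack-quantized", "float-quantized"]
    )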
@@ -182,7 +182,6 @@ class ModelCompressor:
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -203,11 +202,14 @@ class ModelCompressor:
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=[quantization_format]
+            if isinstance(quantization_format, str)
+            else quantization_format,
         )

     @staticmethod
     def parse_sparsity_config(
-        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ) -> Union[Dict[str, Any], None]:
         """
         Parse sparsity config from quantization/compression config. Sparsity
@@ -227,7 +229,7 @@ class ModelCompressor:

     @staticmethod
     def parse_quantization_config(
-        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ) -> Union[Dict[str, Any], None]:
         """
         Parse quantization config from quantization/compression config. The
@@ -246,6 +248,7 @@ class ModelCompressor:

         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+        quantization_config.pop(TRANSFORM_CONFIG_NAME, None)

         # some fields are required, even if a qconfig is not present
         # pop them off and if nothing remains, then there is no qconfig
@@ -262,19 +265,39 @@ class ModelCompressor:

         return quantization_config

+    def _fetch_unique_quantization_formats(self) -> List[str]:
+        """
+        Get all unique compression formats present in a model.
+        :return: list of quantization formats
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format is not None and scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+
+        if (
+            len(quantization_formats) == 0
+            and self.quantization_config.format
+            != CompressionFormat.mixed_precision.value
+        ):
+            quantization_formats.append(self.quantization_config.format)
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats

         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
         # no transform compressor is required
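
The new `_fetch_unique_quantization_formats` helper gives per-scheme formats priority and only falls back to the global config format when no scheme declares one; the `mixed-precision` marker is excluded because it never names a concrete compressor. A standalone sketch of the same resolution order, using plain strings in place of scheme objects (hypothetical mirror, not library code):

    def resolve_formats(scheme_formats, global_format):
        formats = []
        for fmt in scheme_formats:  # one entry per config group
            if fmt is not None and fmt not in formats:
                formats.append(fmt)
        # fall back to the global format, unless it is the synthetic
        # "mixed-precision" tag, which never keys a real compressor
        if not formats and global_format != "mixed-precision":
            formats.append(global_format)
        return formats

    assert resolve_formats([None, None], "pack-quantized") == ["pack-quantized"]
    assert resolve_formats(
        ["nvfp4-pack-quantized", "pack-quantized"], "mixed-precision"
    ) == ["nvfp4-pack-quantized", "pack-quantized"]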
@@ -282,10 +305,21 @@ class ModelCompressor:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
+
         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            # If a list of compression_format is not provided, we resolve the
+            # relevant quantization formats using the config groups from the config
+            # and if those are not defined, we fall-back to the global quantization format
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
+            self.quantization_compressor = {}
+            for format in self.compression_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )

     # ----- used by hf quantizer ----- #
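
After `__init__`, `quantization_compressor` is a mapping from format string to compressor instance rather than a single compressor, so downstream code either looks up a module's format or iterates over `.values()`. A shape sketch, assuming `model_compressor` was built as in the `from_pretrained_model` sketch earlier (the concrete classes come from `BaseCompressor.load_from_registry`):

    # e.g. with two formats configured:
    # model_compressor.quantization_compressor == {
    #     "pack-quantized": <packed int compressor>,
    #     "float-quantized": <float compressor>,
    # }
    for fmt, compressor in model_compressor.quantization_compressor.items():
        print(fmt, type(compressor).__name__)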
@@ -380,12 +414,13 @@ class ModelCompressor:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )

         return list(unexpected_keys)
@@ -423,7 +458,21 @@ class ModelCompressor:

         # quantization first
         if prefix in module_to_scheme:
-            state_dict = self.quantization_compressor.compress(
+            if (
+                not hasattr(module.quantization_scheme, "format")
+                or module.quantization_scheme.format is None
+            ):
+                if len(self.compression_formats) > 1:
+                    raise ValueError(
+                        "Applying multiple compressors without defining "
+                        "per module formats is not supported "
+                    )
+                format = self.compression_formats[0]
+            else:
+                format = module.quantization_scheme.format
+
+            quant_compressor = self.quantization_compressor.get(format)
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=False,
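
Both the compress and decompress paths (this hunk and the next) resolve a compressor per module with the same rule: an explicit `quantization_scheme.format` wins, otherwise the single configured format is used, and a multi-format setup without per-module formats raises. The rule factored out as a standalone sketch (hypothetical helper, not part of the library):

    def pick_format(scheme_format, compression_formats):
        # Resolve which compressor format applies to one module.
        if scheme_format is None:
            if len(compression_formats) > 1:
                raise ValueError(
                    "Applying multiple compressors without defining "
                    "per module formats is not supported"
                )
            return compression_formats[0]
        return scheme_format

    assert pick_format(None, ["pack-quantized"]) == "pack-quantized"
    assert pick_format(
        "nvfp4-pack-quantized", ["pack-quantized", "nvfp4-pack-quantized"]
    ) == "nvfp4-pack-quantized"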
@@ -494,12 +543,24 @@ class ModelCompressor:

         # quantization second
         if prefix in module_to_scheme:
-            state_dict = (
-                self.quantization_compressor.decompress_module_from_state_dict(
-                    prefix,
-                    state_dict,
-                    scheme=module_to_scheme[prefix],
-                )
+
+            if (
+                not hasattr(module.quantization_scheme, "format")
+                or module.quantization_scheme.format is None
+            ):
+                if len(self.compression_formats) > 1:
+                    raise ValueError(
+                        "Applying multiple compressors without defining "
+                        "per module formats is not supported "
+                    )
+                format = self.compression_formats[0]
+            else:
+                format = module.quantization_scheme.format
+            quant_compressor = self.quantization_compressor.get(format)
+            state_dict = quant_compressor.decompress_module_from_state_dict(
+                prefix,
+                state_dict,
+                scheme=module_to_scheme[prefix],
             )

         # remove any existing parameters
@@ -538,7 +599,9 @@ class ModelCompressor:

         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format atm
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
@@ -587,14 +650,20 @@ class ModelCompressor:
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )

         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # note - decompress only supports one compressor atm
             params_to_ignore = None
-            if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+            if quant_compressor is not None:
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
@@ -605,7 +674,7 @@ class ModelCompressor:
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True

-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -628,7 +697,7 @@ class ModelCompressor:
                 # including initialization
                 load_weight_quantization=(
                     sparse_decompressed
-                    or isinstance(self.quantization_compressor, DenseCompressor)
+                    or isinstance(quant_compressor, DenseCompressor)
                 ),
             )
@@ -636,7 +705,7 @@ class ModelCompressor:
                 model.state_dict() if sparse_decompressed else model_path
             )

-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             # TODO: all weight quantization params will be moved to the compressor
@@ -674,7 +743,7 @@ class ModelCompressor:

         # serialize configs into json
         qconfig_data = (
-            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            self.quantization_config.model_dump(exclude=["quant_method"])
            if self.quantization_config is not None
            else {}
        )
--- a/compressed_tensors/config/base.py
+++ b/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"

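
The enum gains a config-level marker for models that mix storage formats; as the `__init__` hunk above shows, it is never used to key a real compressor. A quick round-trip check:

    from compressed_tensors.config import CompressionFormat

    assert CompressionFormat.mixed_precision.value == "mixed-precision"
    assert CompressionFormat("mixed-precision") is CompressionFormat.mixed_precision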
--- a/compressed_tensors/quantization/quant_config.py
+++ b/compressed_tensors/quantization/quant_config.py
@@ -234,6 +234,12 @@ class QuantizationConfig(BaseModel):
                format = CompressionFormat.int_quantized.value
            else:
                format = CompressionFormat.dense.value
+        elif isinstance(format, list):
+            format = (
+                CompressionFormat.mixed_precision.value
+                if len(format) > 1
+                else format[0]
+            )

        return QuantizationConfig(
            config_groups=config_groups,
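
When `from_pretrained` receives a list of formats, the serialized global format degenerates to the single entry for one-element lists and becomes `mixed-precision` otherwise. The collapse rule as a standalone sketch (hypothetical function mirroring the `elif` branch above):

    def collapse(fmt):
        if isinstance(fmt, list):
            return "mixed-precision" if len(fmt) > 1 else fmt[0]
        return fmt

    assert collapse(["pack-quantized"]) == "pack-quantized"
    assert collapse(["pack-quantized", "float-quantized"]) == "mixed-precision"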
--- a/compressed_tensors/quantization/quant_scheme.py
+++ b/compressed_tensors/quantization/quant_scheme.py
@@ -16,6 +16,7 @@ import warnings
 from copy import deepcopy
 from typing import List, Optional

+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
@@ -42,12 +43,14 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat for the layer
     """

     targets: List[str]
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    format: Optional[str] = None

     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
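
With the new field, a scheme can pin its own storage format instead of inheriting the global one. A minimal sketch, assuming the public `QuantizationArgs`/`QuantizationScheme` constructors (argument values are illustrative):

    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

    # int4 grouped weight quantization, stored in the packed format
    w4_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=4, type="int", symmetric=True,
            strategy="group", group_size=128,
        ),
        format="pack-quantized",
    )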
--- a/compressed_tensors/utils/offload.py
+++ b/compressed_tensors/utils/offload.py
@@ -86,6 +86,7 @@ __all__ = [
     "offloaded_dispatch",
     "disable_offloading",
     "remove_dispatch",
+    "cast_to_device",
 ]

@@ -169,6 +170,19 @@ def update_parameter_data(
 """ Candidates for Upstreaming """


+def cast_to_device(device_spec: Union[int, torch.device]) -> torch.device:
+    """
+    Convert an integer device index or torch.device into a torch.device object.
+
+    :param device_spec: Device index (int) or torch.device object.
+        Negative integers map to CPU.
+    :return: torch.device corresponding to the given device specification.
+    """
+    if isinstance(device_spec, int):
+        return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
+    return device_spec
+
+
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     for submodule in module.modules():
         if has_offloaded_params(submodule):
-            return submodule._hf_hook.execution_device
+            return cast_to_device(submodule._hf_hook.execution_device)

         param = next(submodule.parameters(recurse=False), None)
         if param is not None:
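
`cast_to_device` normalizes the integer device indices that accelerate hooks may store in `execution_device`, so `get_execution_device` always returns a `torch.device`. Expected behavior per the docstring above, assuming the helper is re-exported via `compressed_tensors.utils` as the `__all__` addition suggests:

    import torch
    from compressed_tensors.utils import cast_to_device

    assert cast_to_device(0) == torch.device("cuda:0")  # non-negative int -> cuda index
    assert cast_to_device(-1) == torch.device("cpu")    # negative int -> cpu
    assert cast_to_device(torch.device("cpu")) == torch.device("cpu")  # passthrough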
--- a/compressed_tensors/version.py
+++ b/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.10.3.a20250811'
+__version__ = version = '0.10.3.a20250814'
 __version_tuple__ = version_tuple = (0, 10, 3)
--- compressed_tensors-0.10.3a20250811.dist-info/METADATA
+++ compressed_tensors-0.10.3a20250814.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250811
+Version: 0.10.3a20250814
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- compressed_tensors-0.10.3a20250811.dist-info/RECORD
+++ compressed_tensors-0.10.3a20250814.dist-info/RECORD
@@ -1,11 +1,11 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=-gxWvDF4LCkyeDP8YlGzvBBKxo4Dk9h4NINPD61drFU,921
-compressed_tensors/version.py,sha256=9NgEdMzgL7r039RghUuU-BhjHVuGK1utk2z8Au9OlWA,523
+compressed_tensors/version.py,sha256=fAUC53w9XJ-gbZ3V6UhPrss4y7OIGxsBJ0rFa1T--zA,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=0WULLKpgWuTQLjKsCstiTssT778wp9TWMkQjHbYO4Zo,33989
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=FuPS3LYSJk0ATu6caW_GQsFi31EqFTnQtR6mIe6fDAU,37278
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=YGUMzbxekj_36ChgQnVZN6T8uDjXtGG1zfMIBGBLWco,10354
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
@@ -19,7 +19,7 @@ compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
 compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=7F9J6wgkecitK5hHuqjetZ18HExHIF4QIw1wgm2Y6U8,10099
 compressed_tensors/config/__init__.py,sha256=8sOoZ6xvYSC79mBvEtO8l6xk4PC80d29AnnJiGMrY2M,737
-compressed_tensors/config/base.py,sha256=p3glQHvC2fjodf_SvlelVrTWSIjGXgGC86t8oVOlMng,3529
+compressed_tensors/config/base.py,sha256=FaImUwb5G93en2BHUKDs76L_tO8NFpdxlfwAgQL7mNM,3569
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_24_bitmask.py,sha256=Lhj39zT2V1hxftprvxvneyhv45ShlXOKd75DBbDTyTE,1401
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
@@ -27,8 +27,8 @@ compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajC
 compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
 compressed_tensors/quantization/quant_args.py,sha256=PMoaa6hpyJLGGSeCWefGmzGVxbOtxAdDunHJi_L5gNs,12894
-compressed_tensors/quantization/quant_config.py,sha256=StEpCvc70JasE1srLaHqI-TJlasLWGtHU2o0E_gDJhQ,10400
-compressed_tensors/quantization/quant_scheme.py,sha256=3EUGCw5_e7nnmvYPK_UlQKaaskOLIAo30dHYn0z7HmQ,8521
+compressed_tensors/quantization/quant_config.py,sha256=2NgDwKuQn0f-ojiHC8c6tXtYX_zQlk26Rj-bU71QKvA,10598
+compressed_tensors/quantization/quant_scheme.py,sha256=k25Cdx7BZCvLlRlENu4BVoFxquqcErP58P3Y_1HsKB4,8661
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
@@ -57,14 +57,14 @@ compressed_tensors/utils/__init__.py,sha256=spzbjUO4-hZ2jXGST27r3MIt2yzIXsjdbEaY
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
 compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
-compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
+compressed_tensors/utils/offload.py,sha256=gFoEDaissHsLM5-JDbgPxh5hiE9VFN4HFxvszYvReos,24446
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
 compressed_tensors/utils/type.py,sha256=bNwoo_FWlvLuDpYAGGzZJITRg0JA_Ngk9LGPo-kvjeU,2554
-compressed_tensors-0.10.3a20250811.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250811.dist-info/METADATA,sha256=mYMXLEK9r53lXrMbZBRmkimI3aW-X1x4n-8DUThb0K8,7031
-compressed_tensors-0.10.3a20250811.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250811.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250811.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250814.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250814.dist-info/METADATA,sha256=-7voWXyJPB13WkMJADa57hDE4euKxrCjnQfYdHROjKg,7031
+compressed_tensors-0.10.3a20250814.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250814.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250814.dist-info/RECORD,,