compressed-tensors 0.9.5a20250428__tar.gz → 0.9.5a20250507__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {compressed_tensors-0.9.5a20250428/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250507}/PKG-INFO +1 -1
  2. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +20 -38
  3. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/quantized_compressors/base.py +92 -85
  4. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +23 -11
  5. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/apply.py +4 -4
  6. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/initialize.py +2 -1
  7. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/helpers.py +7 -0
  8. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/safetensors_load.py +10 -10
  9. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/version.py +1 -1
  10. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  11. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +7 -7
  12. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/quantized_compressors/test_int_quant.py +9 -9
  13. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/quantized_compressors/test_pack_quant.py +17 -21
  14. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +4 -4
  15. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_initialize.py +3 -1
  16. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/.gitkeep +0 -0
  17. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/actions/test/action.yml +0 -0
  18. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/scripts/step-status +0 -0
  19. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/build-test.yml +0 -0
  20. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/build.yml +0 -0
  21. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/report.yml +0 -0
  22. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/test-check.yaml +0 -0
  23. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/test.yml +0 -0
  24. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/trigger-all.yml +0 -0
  25. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.github/workflows/upload.yml +0 -0
  26. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/.gitignore +0 -0
  27. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/LICENSE +0 -0
  28. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/Makefile +0 -0
  29. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/README.md +0 -0
  30. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  31. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/bit_packing/int4_config.json +0 -0
  32. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/bitmask_compression.ipynb +0 -0
  33. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  34. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  35. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/llama_1.1b/example_quant_config.json +0 -0
  36. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  37. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/examples/quantize_and_pack_int4.ipynb +0 -0
  38. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/pyproject.toml +0 -0
  39. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/setup.cfg +0 -0
  40. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/setup.py +0 -0
  41. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/__init__.py +0 -0
  42. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/README.md +0 -0
  43. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/__init__.py +0 -0
  44. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/base.py +0 -0
  45. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/__init__.py +0 -0
  46. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/base.py +0 -0
  47. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/helpers.py +0 -0
  48. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  49. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  50. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  51. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  52. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  53. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  54. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  55. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  56. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  57. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/config/__init__.py +0 -0
  59. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/config/base.py +0 -0
  60. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/config/dense.py +0 -0
  61. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  62. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  63. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/linear/__init__.py +0 -0
  64. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  65. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/__init__.py +0 -0
  66. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  67. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  68. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  69. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  70. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/quant_args.py +0 -0
  71. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/quant_config.py +0 -0
  72. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  73. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  74. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  75. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/registry/__init__.py +0 -0
  76. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/registry/registry.py +0 -0
  77. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/__init__.py +0 -0
  78. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/offload.py +0 -0
  79. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/permutations_24.py +0 -0
  80. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/permute.py +0 -0
  81. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  82. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  83. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  84. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors.egg-info/requires.txt +0 -0
  85. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  86. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/__init__.py +0 -0
  87. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/conftest.py +0 -0
  88. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/__init__.py +0 -0
  89. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/model_compressors/__init__.py +0 -0
  90. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  91. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  92. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  93. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  94. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  95. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  96. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_configs/__init__.py +0 -0
  97. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_configs/test_base.py +0 -0
  98. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  99. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_linear/__init__.py +0 -0
  100. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_linear/test_compressed_linear.py +0 -0
  101. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/__init__.py +0 -0
  102. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/__init__.py +0 -0
  103. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/conftest.py +0 -0
  104. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  105. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  106. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  107. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  108. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  109. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  110. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_configs/__init__.py +0 -0
  111. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  112. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  113. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_quant_args.py +0 -0
  114. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_quant_config.py +0 -0
  115. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_quant_scheme.py +0 -0
  116. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  117. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_registry.py +0 -0
  118. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_utils/__init__.py +0 -0
  119. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_utils/test_helpers.py +0 -0
  120. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_utils/test_offload.py +0 -0
  121. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_utils/test_safetensors_load.py +0 -0
  122. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/testing_utils.py +0 -0
  123. {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/utils/copyright.py +0 -0

{compressed_tensors-0.9.5a20250428/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250507}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250428
+Version: 0.9.5a20250507
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -19,7 +19,7 @@ import os
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -36,12 +36,12 @@ from compressed_tensors.config import CompressionFormat, SparsityCompressionConf
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
+    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -64,7 +64,7 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+__all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,20 +372,17 @@ class ModelCompressor:
         :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
+
         if state_dict is None:
             state_dict = model.state_dict()
 
-        compressed_state_dict = state_dict
-
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
-
         if self.quantization_compressor is not None:
-            compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=quantized_modules_to_args
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
             )
 
+            # TODO: consider sparse compression to also be compression
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
                     QuantizationStatus.COMPRESSED
@@ -397,8 +394,8 @@ class ModelCompressor:
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict,
+            state_dict = self.sparsity_compressor.compress(
+                state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -407,7 +404,7 @@ class ModelCompressor:
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return compressed_state_dict
+        return state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -576,8 +573,8 @@ class ModelCompressor:
         :param model: The model whose weights are to be updated.
         """
 
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(name)(model)
+        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(mod_path)(model)
 
             params_device = next(module.parameters()).device
             device = "cpu" if has_offloaded_params(module) else params_device
@@ -605,30 +602,15 @@ class ModelCompressor:
             update_parameter_data(module, param_data, param_name)
 
 
-def map_modules_to_quant_args(
-    model: Module,
-) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
+def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Given a pytorch model, map out the submodule name (usually linear layers)
-    to the weight QuantizationArgs. If running input activation quantization, will also
-    map to the input QuantizationArgs in a tuple.
-
-    :param model: pytorch model
+    Returns a dictionary which maps quantized module names to their quantization schemes
     """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            if submodule.quantization_scheme.weights is not None:
-                name = fix_fsdp_module_name(name)
-                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
-            if submodule.quantization_scheme.input_activations is not None:
-                weight_args = quantized_modules_to_args.get(name)
-                quantized_modules_to_args[name] = (
-                    weight_args,
-                    submodule.quantization_scheme.input_activations,
-                )
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
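
The rewritten map_module_to_scheme returns one QuantizationScheme per quantized module in place of the old QuantizationArgs-or-tuple union; callers now pull weights or input_activations off the scheme as needed. A minimal sketch of the new mapping shape, assuming a model with a single 4-bit group-quantized linear layer (the module name "decoder.0" and the scheme contents are illustrative, not from a real checkpoint):

    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

    # hypothetical equivalent of map_module_to_scheme(model) for one quantized layer
    module_to_scheme = {
        "decoder.0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(
                num_bits=4, symmetric=True, strategy="group", group_size=128
            ),
        )
    }

    # compressors now read the per-tensor args off the scheme
    weight_args = module_to_scheme["decoder.0"].weights
    assert weight_args.group_size == 128
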

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -14,15 +14,16 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
+    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -81,87 +82,87 @@ class BaseQuantizationCompressor(BaseCompressor):
         :return: compressed state dict
         """
         compressed_dict = {}
-        weight_suffix = ".weight"
-        input_zp_suffix = ".input_zero_point"
-        weight_zp_suffix = ".weight_zero_point"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
+        save_device = "cpu"
+
+        uncompressed_names = list(model_state.keys())
+        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+            value = model_state[name]
+
+            # compress weights
+            if name.endswith("weight"):
+                prefix = remove_suffix(name, "weight")
+
+                # gather qparams
+                scale = model_state.get(prefix + "weight_scale", None)
+                g_idx = model_state.get(prefix + "weight_g_idx", None)
+                zp = model_state.get(prefix + "weight_zero_point", None)
+
+                # if scale does not exist, then weight cannot be compressed
+                if scale is None:
+                    compressed_dict[name] = value.to(save_device)
+                    continue
+
+                # compress values on cpu (memory movement too expensive)
+                module_path = prefix[:-1] if prefix.endswith(".") else prefix
+                quant_args = names_to_scheme[module_path].weights
+                compressed_values = self.compress_weight(
+                    weight=value,
+                    scale=scale,
+                    zero_point=zp,
+                    g_idx=g_idx,
+                    quantization_args=quant_args,
+                    device="cpu",
+                )
+
+                # update state dict
+                for key, value in compressed_values.items():
+                    compressed_dict[prefix + key] = value.to(save_device)
 
-        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
-            # check if the parameter we're compressing is the weight zp
-            # or the input zp
-            is_weight_zp = name.endswith(weight_zp_suffix)
-            is_input_zp = name.endswith(input_zp_suffix)
-
-            # if we're saving the weight zp, fetch weight quant args
-            if is_weight_zp:
-                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
-                if isinstance(quant_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    quant_args_zp = quant_args_zp[0]
-
-            # if we're saving the input zp, fetch input quant args
-            if is_input_zp:
-                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
-                if isinstance(input_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    input_args_zp = input_args_zp[-1]
-
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    if isinstance(names_to_scheme[prefix], tuple):
-                        quant_args = names_to_scheme[prefix][0]
-                    else:
-                        quant_args = names_to_scheme[prefix]
-
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            # only save zp if asym and not packed zp
-            elif is_weight_zp and (
-                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
-            ):
-                continue
-            # only save if asym
-            elif is_input_zp and input_args_zp.symmetric:
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
             else:
-                compressed_dict[name] = value.to("cpu")
+                # omit saving zero points for symmetric or packed quantization
+                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                    continue
+
+                # omit saving for g_idx if uninitialized
+                # TODO: does this case actually occur?
+                elif name.endswith("g_idx") and torch.any(value <= -1):
+                    continue
+
+                compressed_dict[name] = value.to(save_device)
 
         return compressed_dict
 
-    def _check_if_zp_pack_quantized(self, quant_args):
+    def _skip_zp(
+        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
+    ) -> bool:
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        if isinstance(self, PackedQuantizationCompressor):
-            if not quant_args.symmetric and quant_args.strategy in [
-                QuantizationStrategy.GROUP.value,
-                QuantizationStrategy.CHANNEL.value,
-            ]:
-                return True
-        return False
+        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
+        scheme = names_to_scheme[module_name]
+
+        if zp_name == "weight_zero_point":
+            args = scheme.weights
+        if zp_name == "input_zero_point":
+            args = scheme.input_activations
+        if zp_name == "output_zero_point":
+            args = scheme.output_activations
+
+        symmetric = args.symmetric
+        packable_strategies = [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]
+        packed = (
+            isinstance(self, PackedQuantizationCompressor)
+            and args.strategy in packable_strategies
+        )
+
+        return symmetric or packed
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -170,8 +171,9 @@ class BaseQuantizationCompressor(BaseCompressor):
            dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization args for each quantized weight
-        :param device: optional device to load intermediate weights into
+        :param names_to_scheme: quantization scheme for each quantized weight
+        :param device: optional device to load intermediate weights into (must be `str`,
+            not `torch.device`)
         :return: compressed state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -184,37 +186,42 @@ class BaseQuantizationCompressor(BaseCompressor):
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: str,
+    ):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
            weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield weight_name, weight_data
+            yield module_path, weight_data
 
     def _decompress_from_state_dict(self, state_dict, names_to_scheme):
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, param_value in weight_mappings[weight_name].items():
+            for param_name, param_value in weight_mappings[module_path].items():
                 weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[module_path]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield weight_name, weight_data
+            yield module_path, weight_data
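
The compress hunk above folds the old per-parameter zero-point bookkeeping into _skip_zp: a zero point is omitted from the saved state dict when its args are symmetric (the stored tensor would be all zeros) or when a packed compressor stores it packed alongside the weights. A rough sketch that mirrors the decision without a compressor instance (the scheme contents are illustrative):

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationScheme,
        QuantizationStrategy,
    )

    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, symmetric=True),
        input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    )

    def would_skip_zp(args: QuantizationArgs, is_packed_compressor: bool) -> bool:
        # mirrors _skip_zp: drop symmetric zero points, and drop packed ones
        # for the group/channel strategies PackedQuantizationCompressor handles
        packable = [QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value]
        return args.symmetric or (is_packed_compressor and args.strategy in packable)

    print(would_skip_zp(scheme.weights, is_packed_compressor=False))            # True
    print(would_skip_zp(scheme.input_activations, is_packed_compressor=False))  # False
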

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
@@ -19,7 +19,11 @@ import numpy as np
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.forward import quantize
 from compressed_tensors.utils import (
     get_permutations_24,
@@ -44,19 +48,25 @@ class Marlin24Compressor(BaseCompressor):
 
     @staticmethod
     def validate_quant_compatability(
-        model_quant_args: Dict[str, QuantizationArgs]
+        names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24
         compression. Quantization must be channel or group strategy with group_size
         of 128. Only symmetric quantization is supported
 
-        :param model_quant_args: dictionary of mapping module names to their
-            quantization configuration
+        :param names_to_scheme: dictionary of mapping module names to their
+            quantization schemes
         :return: True if all modules are compatible with Marlin24 compression, raises
             a ValueError otherwise
         """
-        for name, quant_args in model_quant_args.items():
+        for name, scheme in names_to_scheme.items():
+            quant_args = scheme.weights
+            if quant_args is None:
+                raise ValueError(
+                    "Marlin24 Compressor is only valid for weight quantization schemes"
+                )
+
             strategy = quant_args.strategy
             group_size = quant_args.group_size
             symmetric = quant_args.symmetric
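
With whole schemes as input, the validator can now fail fast on schemes that quantize activations but not weights. A hedged sketch of the new error path, assuming Marlin24Compressor is importable from compressed_tensors.compressors (the module name "fc1" and the scheme are made up):

    from compressed_tensors.compressors import Marlin24Compressor
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

    # hypothetical activation-only scheme: weights=None, so Marlin24 cannot apply
    act_only = QuantizationScheme(
        targets=["Linear"],
        input_activations=QuantizationArgs(num_bits=8, symmetric=True),
    )

    try:
        Marlin24Compressor.validate_quant_compatability({"fc1": act_only})
    except ValueError as err:
        print(err)  # Marlin24 Compressor is only valid for weight quantization schemes
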
@@ -114,7 +124,7 @@ class Marlin24Compressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -122,8 +132,8 @@ class Marlin24Compressor(BaseCompressor):
         with the Marlin24 kernel
 
         :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization args for each quantized weight, needed for
-            quantize function to calculate bit depth
+        :param names_to_scheme: quantization scheme for each quantized weight, needed
+            for quantize function to calculate bit depth
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)
@@ -146,7 +156,7 @@ class Marlin24Compressor(BaseCompressor):
                     value = value.to(torch.float16)
 
                     # quantize weight, keeping it as a float16 for now
-                    quant_args = names_to_scheme[prefix]
+                    quant_args = names_to_scheme[prefix].weights
                     value = quantize(
                         x=value, scale=scale, zero_point=zp, args=quant_args
                     )
@@ -215,7 +225,7 @@ def pack_weight_24(
     weight: Tensor,
     quantization_args: QuantizationArgs,
     tile: int = 16,
-):
+) -> torch.Tensor:
     size_k = weight.shape[0]
     size_n = weight.shape[1]
     num_bits = quantization_args.num_bits
@@ -236,7 +246,9 @@ def pack_weight_24(
     return q_packed
 
 
-def pack_scales_24(scales, quantization_args, w_shape):
+def pack_scales_24(
+    scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size
+) -> torch.Tensor:
     size_k = w_shape[0]
     size_n = w_shape[1]
     num_bits = quantization_args.num_bits

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -119,7 +119,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> OrderedDict:
+) -> Dict[str, QuantizationScheme]:
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally converts quantizable modules to compressed_linear modules
@@ -131,7 +131,7 @@ def apply_quantization_config(
     """
     # Workaround for when HF Quantizer passes None, see PR #180
     if config is None:
-        return OrderedDict()
+        return dict()
 
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
@@ -141,7 +141,7 @@ def apply_quantization_config(
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
     config = process_quantization_config(config)
-    names_to_scheme = OrderedDict()
+    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -187,7 +187,7 @@ def apply_quantization_config(
                 target_to_scheme, targets, name
             )
 
-            names_to_scheme[name] = submodule.quantization_scheme.weights
+            names_to_scheme[name] = submodule.quantization_scheme
 
     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,6 +14,7 @@
 
 
 import logging
+import math
 from enum import Enum
 from typing import Optional
 
@@ -162,7 +163,7 @@ def _initialize_scale_zero_point(
         # (output_channels, 1)
         expected_shape = (weight_shape[0], 1)
     elif quantization_args.strategy == QuantizationStrategy.GROUP:
-        num_groups = weight_shape[1] // quantization_args.group_size
+        num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
         expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
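
The math.ceil fix matters whenever the number of input columns is not an exact multiple of group_size: the trailing partial group still needs its own column in the scale and zero-point tensors, which floor division silently dropped. A quick worked example with made-up shapes:

    import math

    columns, group_size = 320, 128  # hypothetical weight_shape[1] and group size
    print(columns // group_size)            # 2: old shape, last 64 columns got no scale
    print(math.ceil(columns / group_size))  # 3: new shape covers all 320 columns
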

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/helpers.py
@@ -38,6 +38,7 @@ __all__ = [
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "remove_suffix",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +329,9 @@ def unpack_bitmasks(
     )
 
     return unpacked_bitmasks_torch
+
+
+def remove_suffix(value: str, suffix: str) -> str:
+    # can replace with str.removesuffix in python3.9+
+    assert value.endswith(suffix)
+    return value[: -len(suffix)]
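
As its comment says, remove_suffix is a stand-in for str.removesuffix (Python 3.9+), with one difference: the assert makes a missing suffix an error rather than a silent no-op. For example:

    from compressed_tensors.utils import remove_suffix

    print(remove_suffix("decoder.0.weight", "weight"))  # "decoder.0."
    # str.removesuffix returns the input unchanged on a mismatch;
    # remove_suffix("decoder.0.bias", "weight") raises AssertionError instead
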

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/utils/safetensors_load.py
@@ -234,11 +234,11 @@ def get_nested_weight_mappings(
     for key, file_location in weight_mappings.items():
         matched = False
         for param_name in params_to_nest:
-            dense_param = match_param_name(key, param_name)
-            if dense_param:
-                if dense_param not in nested_weight_mappings:
-                    nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = file_location
+            module_path = match_param_name(key, param_name)
+            if module_path:
+                if module_path not in nested_weight_mappings:
+                    nested_weight_mappings[module_path] = {}
+                nested_weight_mappings[module_path][param_name] = file_location
                 matched = True
         if return_unmatched_params and not matched:
             unmatched_params[key] = file_location
@@ -271,11 +271,11 @@ def get_nested_mappings_from_state_dict(
     nested_weight_mappings = {}
     for key in state_dict.keys():
         for param_name in params_to_nest:
-            dense_param = match_param_name(key, param_name)
-            if dense_param:
-                if dense_param not in nested_weight_mappings:
-                    nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = state_dict[key]
+            module_path = match_param_name(key, param_name)
+            if module_path:
+                if module_path not in nested_weight_mappings:
+                    nested_weight_mappings[module_path] = {}
+                nested_weight_mappings[module_path][param_name] = state_dict[key]
     return nested_weight_mappings
 
 

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.9.5.a20250428'
+__version__ = version = '0.9.5.a20250507'
 __version_tuple__ = version_tuple = (0, 9, 5)

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507/src/compressed_tensors.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250428
+Version: 0.9.5a20250507
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

{compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250507}/tests/test_compressors/quantized_compressors/test_fp8_quant.py
@@ -84,9 +84,9 @@ def test_quant_format(strategy, group_size, sc, zp):
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = FloatQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
+    module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=quantized_modules_to_args
+        dense_state_dict, names_to_scheme=module_name_to_scheme
     )
 
     # state_dict params should be the same, minus the zero_point if symmetric
@@ -140,15 +140,15 @@ def test_reload_match(
     )
 
     compressor = FloatQuantizationCompressor(config=quant_config)
-    quantized_modules_to_args = {
-        "dummy": quant_config.config_groups["group_1"].weights,
+    module_name_to_scheme = {
+        "dummy": quant_config.config_groups["group_1"],
     }
     compressed_state_dict = compressor.compress(
-        model.state_dict(), names_to_scheme=quantized_modules_to_args
+        model.state_dict(), names_to_scheme=module_name_to_scheme
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_args
+        tmp_path, names_to_scheme=module_name_to_scheme
     )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
@@ -158,7 +158,7 @@
         model.dummy.weight,
         scale=model.dummy.weight_scale,
         zero_point=model.dummy.weight_zero_point,
-        args=quantized_modules_to_args["dummy"],
+        args=module_name_to_scheme["dummy"].weights,
     )
     assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight"))