compressed-tensors 0.3.3__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/PKG-INFO +52 -4
  2. compressed-tensors-0.3.3/src/compressed_tensors.egg-info/PKG-INFO → compressed_tensors-0.5.0/README.md +40 -14
  3. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/setup.py +25 -4
  4. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/base.py +3 -1
  5. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/__init__.py +9 -1
  6. compressed_tensors-0.5.0/src/compressed_tensors/compressors/base.py +60 -0
  7. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/dense.py +5 -5
  8. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/helpers.py +12 -12
  9. compressed_tensors-0.5.0/src/compressed_tensors/compressors/marlin_24.py +251 -0
  10. compressed_tensors-0.5.0/src/compressed_tensors/compressors/model_compressor.py +336 -0
  11. compressed_tensors-0.5.0/src/compressed_tensors/compressors/naive_quantized.py +144 -0
  12. compressed_tensors-0.5.0/src/compressed_tensors/compressors/pack_quantized.py +219 -0
  13. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/sparse_bitmask.py +4 -4
  14. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/config/base.py +9 -4
  15. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/config/dense.py +4 -4
  16. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/config/sparse_bitmask.py +3 -3
  17. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  18. compressed_tensors-0.5.0/src/compressed_tensors/quantization/lifecycle/apply.py +351 -0
  19. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  20. compressed_tensors-0.5.0/src/compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors-0.5.0/src/compressed_tensors/quantization/lifecycle/forward.py +373 -0
  22. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors-0.5.0/src/compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  24. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  25. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/base.py +66 -23
  26. compressed_tensors-0.5.0/src/compressed_tensors/quantization/observers/helpers.py +111 -0
  27. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/memoryless.py +17 -9
  28. compressed_tensors-0.5.0/src/compressed_tensors/quantization/observers/min_max.py +96 -0
  29. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/quant_args.py +47 -3
  30. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/quant_config.py +104 -23
  31. compressed_tensors-0.5.0/src/compressed_tensors/quantization/quant_scheme.py +220 -0
  32. compressed_tensors-0.5.0/src/compressed_tensors/quantization/utils/helpers.py +250 -0
  33. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/utils/__init__.py +4 -0
  34. compressed_tensors-0.5.0/src/compressed_tensors/utils/helpers.py +92 -0
  35. compressed_tensors-0.5.0/src/compressed_tensors/utils/offload.py +104 -0
  36. compressed_tensors-0.5.0/src/compressed_tensors/utils/permutations_24.py +65 -0
  37. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/utils/safetensors_load.py +3 -2
  38. compressed_tensors-0.5.0/src/compressed_tensors/utils/semi_structured_conversions.py +341 -0
  39. compressed_tensors-0.5.0/src/compressed_tensors/version.py +53 -0
  40. compressed-tensors-0.3.3/README.md → compressed_tensors-0.5.0/src/compressed_tensors.egg-info/PKG-INFO +64 -1
  41. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/SOURCES.txt +10 -1
  42. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/requires.txt +6 -5
  43. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/tests/test_registry.py +8 -8
  44. compressed-tensors-0.3.3/src/compressed_tensors/compressors/base.py +0 -103
  45. compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/apply.py +0 -178
  46. compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/forward.py +0 -221
  47. compressed-tensors-0.3.3/src/compressed_tensors/quantization/observers/helpers.py +0 -53
  48. compressed-tensors-0.3.3/src/compressed_tensors/quantization/observers/min_max.py +0 -65
  49. compressed-tensors-0.3.3/src/compressed_tensors/quantization/quant_scheme.py +0 -39
  50. compressed-tensors-0.3.3/src/compressed_tensors/quantization/utils/helpers.py +0 -116
  51. compressed-tensors-0.3.3/src/compressed_tensors/utils/helpers.py +0 -45
  52. compressed-tensors-0.3.3/tests/test_bitmask.py +0 -120
  53. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/LICENSE +0 -0
  54. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/pyproject.toml +0 -0
  55. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/setup.cfg +0 -0
  56. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/__init__.py +0 -0
  57. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/config/__init__.py +0 -0
  58. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/__init__.py +0 -0
  59. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/__init__.py +0 -0
  60. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  61. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/registry/__init__.py +0 -0
  62. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/registry/registry.py +0 -0
  63. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  64. {compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/PKG-INFO
@@ -1,15 +1,24 @@
  Metadata-Version: 2.1
  Name: compressed-tensors
- Version: 0.3.3
+ Version: 0.5.0
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
  Author-email: support@neuralmagic.com
  License: Apache 2.0
- Platform: UNKNOWN
  Description-Content-Type: text/markdown
- Provides-Extra: dev
  License-File: LICENSE
+ Requires-Dist: torch>=1.7.0
+ Requires-Dist: transformers
+ Requires-Dist: accelerate
+ Requires-Dist: pydantic>=2.0
+ Provides-Extra: dev
+ Requires-Dist: black==22.12.0; extra == "dev"
+ Requires-Dist: isort==5.8.0; extra == "dev"
+ Requires-Dist: wheel>=0.36.2; extra == "dev"
+ Requires-Dist: flake8>=3.8.3; extra == "dev"
+ Requires-Dist: pytest>=6.0.0; extra == "dev"
+ Requires-Dist: nbconvert>=7.16.3; extra == "dev"

  # compressed_tensors

@@ -81,7 +90,7 @@ from compressed_tensors import save_compressed_model, load_compressed, BitmaskCo
  from transformers import AutoModelForCausalLM

  model_name = "neuralmagic/llama2.c-stories110M-pruned50"
- model = AutoModelForCausalLM.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")

  original_state_dict = model.state_dict()

@@ -97,3 +106,42 @@ state_dict = dict(load_compressed("compressed_model.safetensors", compression_co
  For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).


+ ## Saving a Compressed Model with PTQ
+
+ We can use compressed-tensors to run basic post training quantization (PTQ) and save the quantized model compressed on disk
+
+ ```python
+ model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0", torch_dtype="auto")
+
+ config = QuantizationConfig.parse_file("./examples/bit_packing/int4_config.json")
+ config.quantization_status = QuantizationStatus.CALIBRATION
+ apply_quantization_config(model, config)
+
+ dataset = load_dataset("ptb_text_only")["train"]
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize_function(examples):
+     return tokenizer(examples["sentence"], padding=False, truncation=True, max_length=1024)
+
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
+ data_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=DefaultDataCollator())
+
+ with torch.no_grad():
+     for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
+         sample = {key: value.to(device) for key,value in sample.items()}
+         _ = model(**sample)
+
+         if idx >= 512:
+             break
+
+ model.apply(freeze_module_quantization)
+ model.apply(compress_quantized_weights)
+
+ output_dir = "./ex_llama1.1b_w4a16_packed_quantize"
+ compressor = ModelCompressor(quantization_config=config)
+ compressed_state_dict = compressor.compress(model)
+ model.save_pretrained(output_dir, state_dict=compressed_state_dict)
+ ```
+
+ For more in-depth tutorial on quantization compression, refer to the [notebook](./examples/quantize_and_pack_int4.ipynb).
compressed-tensors-0.3.3/src/compressed_tensors.egg-info/PKG-INFO → compressed_tensors-0.5.0/README.md
@@ -1,16 +1,3 @@
- Metadata-Version: 2.1
- Name: compressed-tensors
- Version: 0.3.3
- Summary: Library for utilization of compressed safetensors of neural network models
- Home-page: https://github.com/neuralmagic/compressed-tensors
- Author: Neuralmagic, Inc.
- Author-email: support@neuralmagic.com
- License: Apache 2.0
- Platform: UNKNOWN
- Description-Content-Type: text/markdown
- Provides-Extra: dev
- License-File: LICENSE
-
  # compressed_tensors

  This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.
@@ -81,7 +68,7 @@ from compressed_tensors import save_compressed_model, load_compressed, BitmaskCo
  from transformers import AutoModelForCausalLM

  model_name = "neuralmagic/llama2.c-stories110M-pruned50"
- model = AutoModelForCausalLM.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")

  original_state_dict = model.state_dict()

@@ -97,3 +84,42 @@ state_dict = dict(load_compressed("compressed_model.safetensors", compression_co
  For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).


+ ## Saving a Compressed Model with PTQ
+
+ We can use compressed-tensors to run basic post training quantization (PTQ) and save the quantized model compressed on disk
+
+ ```python
+ model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0", torch_dtype="auto")
+
+ config = QuantizationConfig.parse_file("./examples/bit_packing/int4_config.json")
+ config.quantization_status = QuantizationStatus.CALIBRATION
+ apply_quantization_config(model, config)
+
+ dataset = load_dataset("ptb_text_only")["train"]
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def tokenize_function(examples):
+     return tokenizer(examples["sentence"], padding=False, truncation=True, max_length=1024)
+
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
+ data_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=DefaultDataCollator())
+
+ with torch.no_grad():
+     for idx, sample in tqdm(enumerate(data_loader), desc="Running calibration"):
+         sample = {key: value.to(device) for key,value in sample.items()}
+         _ = model(**sample)
+
+         if idx >= 512:
+             break
+
+ model.apply(freeze_module_quantization)
+ model.apply(compress_quantized_weights)
+
+ output_dir = "./ex_llama1.1b_w4a16_packed_quantize"
+ compressor = ModelCompressor(quantization_config=config)
+ compressed_state_dict = compressor.compress(model)
+ model.save_pretrained(output_dir, state_dict=compressed_state_dict)
+ ```
+
+ For more in-depth tutorial on quantization compression, refer to the [notebook](./examples/quantize_and_pack_int4.ipynb).
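The PTQ snippet added to the README above is shown without its imports or the `device` variable it references. A minimal sketch of the setup it appears to rely on; the exact export locations of the `compressed_tensors` helpers are assumptions based on the package layout shown in this diff:

```python
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

# assumed import paths; the lifecycle helpers live under
# src/compressed_tensors/quantization/lifecycle/ in this release
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
    compress_quantized_weights,
    freeze_module_quantization,
)

# the snippet moves calibration batches to the same device the model was loaded on
device = "cuda:0"
```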
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/setup.py
@@ -12,9 +12,30 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
+ import os
  from setuptools import setup, find_packages
  from typing import List, Dict, Tuple
+ from utils.artifacts import get_release_and_version
+
+
+ package_path = os.path.join(
+     os.path.dirname(os.path.realpath(__file__)), "src", "compressed_tensors"
+ )
+ (
+     is_release,
+     version,
+     version_major,
+     version_minor,
+     version_bug,
+ ) = get_release_and_version(package_path)
+
+ version_nm_deps = f"{version_major}.{version_minor}.0"
+
+ if is_release:
+     _PACKAGE_NAME = "compressed-tensors"
+ else:
+     _PACKAGE_NAME = "compressed-tensors-nightly"
+

  def _setup_long_description() -> Tuple[str, str]:
      return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
@@ -25,14 +46,14 @@ def _setup_packages() -> List:
      )

  def _setup_install_requires() -> List:
-     return ["torch>=1.7.0", "transformers<4.41", "pydantic<2.7"]
+     return ["torch>=1.7.0", "transformers", "accelerate", "pydantic>=2.0"]

  def _setup_extras() -> Dict:
      return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]}

  setup(
-     name="compressed-tensors",
-     version="0.3.3",
+     name=_PACKAGE_NAME,
+     version=version,
      author="Neuralmagic, Inc.",
      author_email="support@neuralmagic.com",
      license="Apache 2.0",
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/base.py
@@ -13,4 +13,6 @@
  # limitations under the License.

  SPARSITY_CONFIG_NAME = "sparsity_config"
- QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
+ QUANTIZATION_CONFIG_NAME = "quantization_config"
+ COMPRESSION_CONFIG_NAME = "compression_config"
+ KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/__init__.py
@@ -14,7 +14,15 @@

  # flake8: noqa

- from .base import ModelCompressor
+ from .base import Compressor
  from .dense import DenseCompressor
  from .helpers import load_compressed, save_compressed, save_compressed_model
+ from .marlin_24 import Marlin24Compressor
+ from .model_compressor import ModelCompressor, map_modules_to_quant_args
+ from .naive_quantized import (
+     FloatQuantizationCompressor,
+     IntQuantizationCompressor,
+     QuantizationCompressor,
+ )
+ from .pack_quantized import PackedQuantizationCompressor
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed_tensors-0.5.0/src/compressed_tensors/compressors/base.py
@@ -0,0 +1,60 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, Generator, Tuple, Union
+
+ from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.quantization import QuantizationConfig
+ from compressed_tensors.registry import RegistryMixin
+ from torch import Tensor
+
+
+ __all__ = ["Compressor"]
+
+
+ class Compressor(RegistryMixin):
+     """
+     Base class representing a model compression algorithm
+
+     :param config: config specifying compression parameters
+     """
+
+     def __init__(
+         self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None
+     ):
+         self.config = config
+
+     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :return: compressed state dict
+         """
+         raise NotImplementedError()
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param model_path: path to compressed safetensors model (directory with
+             one or more safetensors files) or compressed tensors file
+         :param device: optional device to load intermediate weights into
+         :return: compressed state dict
+         """
+         raise NotImplementedError()
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/dense.py
@@ -14,21 +14,21 @@

  from typing import Dict, Generator, Tuple

- from compressed_tensors.compressors import ModelCompressor
+ from compressed_tensors.compressors import Compressor
  from compressed_tensors.config import CompressionFormat
  from torch import Tensor


- @ModelCompressor.register(name=CompressionFormat.dense_sparsity.value)
- class DenseCompressor(ModelCompressor):
+ @Compressor.register(name=CompressionFormat.dense.value)
+ class DenseCompressor(Compressor):
      """
      Identity compressor for dense models, returns the original state_dict
      """

-     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
          return model_state

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          return iter([])
{compressed-tensors-0.3.3 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/helpers.py
@@ -16,8 +16,8 @@ from pathlib import Path
  from typing import Dict, Generator, Optional, Tuple, Union

  import torch
- from compressed_tensors.compressors import ModelCompressor
- from compressed_tensors.config import CompressionConfig, CompressionFormat
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
  from compressed_tensors.utils.safetensors_load import get_weight_mappings
  from safetensors import safe_open
  from safetensors.torch import save_file
@@ -48,20 +48,20 @@ def save_compressed(
      if tensors is None or len(tensors) == 0:
          raise ValueError("No tensors or empty tensors provided to compress")

-     # if no compression_format specified, default to `dense_sparsity`
-     compression_format = compression_format or CompressionFormat.dense_sparsity.value
+     # if no compression_format specified, default to `dense`
+     compression_format = compression_format or CompressionFormat.dense.value

      if not (
-         compression_format in ModelCompressor.registered_names()
-         or compression_format in ModelCompressor.registered_aliases()
+         compression_format in Compressor.registered_names()
+         or compression_format in Compressor.registered_aliases()
      ):
          raise ValueError(
              f"Unknown compression format: {compression_format}. "
-             f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}"  # noqa E501
+             f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}"  # noqa E501
          )

      # compress
-     compressor = ModelCompressor.load_from_registry(compression_format)
+     compressor = Compressor.load_from_registry(compression_format)
      # save compressed tensors
      compressed_tensors = compressor.compress(tensors)
      save_file(compressed_tensors, save_path)
@@ -69,7 +69,7 @@

  def load_compressed(
      compressed_tensors: Union[str, Path],
-     compression_config: CompressionConfig = None,
+     compression_config: SparsityCompressionConfig = None,
      device: Optional[str] = "cpu",
  ) -> Generator[Tuple[str, Tensor], None, None]:
      """
@@ -90,9 +90,9 @@

      if (
          compression_config is None
-         or compression_config.format == CompressionFormat.dense_sparsity.value
+         or compression_config.format == CompressionFormat.dense.value
      ):
-         # if no compression_config specified, or `dense_sparsity` format specified,
+         # if no compression_config specified, or `dense` format specified,
          # assume tensors are not compressed on disk
          weight_mappings = get_weight_mappings(compressed_tensors)
          for weight_name, file_with_weight_name in weight_mappings.items():
@@ -102,7 +102,7 @@
      else:
          # decompress tensors
          compression_format = compression_config.format
-         compressor = ModelCompressor.load_from_registry(
+         compressor = Compressor.load_from_registry(
              compression_format, config=compression_config
          )
          yield from compressor.decompress(compressed_tensors, device=device)
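The rename from `ModelCompressor` to `Compressor` in the diffs above turns the base class into a format registry: `save_compressed` and `load_compressed` resolve the concrete implementation with `Compressor.load_from_registry(...)`, and each format registers itself the way `DenseCompressor` does. A minimal sketch of how a format plugs into that registry, modeled on the code above (the `"identity-demo"` format name is hypothetical, for illustration only):

```python
from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import Compressor
from torch import Tensor


@Compressor.register(name="identity-demo")  # hypothetical format name
class IdentityDemoCompressor(Compressor):
    """Toy compressor that passes tensors through unchanged."""

    def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        # a real format would transform the dense state dict here
        return dict(model_state)

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
        # a real format would stream (name, tensor) pairs back from disk
        return iter([])


# save_compressed/load_compressed look the class up by its registered name
compressor = Compressor.load_from_registry("identity-demo")
```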
compressed_tensors-0.5.0/src/compressed_tensors/compressors/marlin_24.py
@@ -0,0 +1,251 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ import numpy as np
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.utils import (
+     get_permutations_24,
+     is_quantization_param,
+     merge_names,
+     sparse_semi_structured_from_dense_cutlass,
+     tensor_follows_mask_structure,
+ )
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.marlin_24.value)
+ class Marlin24Compressor(Compressor):
+     """
+     Compresses a quantized model with 2:4 sparsity structure for inference with the
+     Marlin24 kernel. Decompression is not implemented for this compressor.
+     """
+
+     COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]
+
+     @staticmethod
+     def validate_quant_compatability(
+         model_quant_args: Dict[str, QuantizationArgs]
+     ) -> bool:
+         """
+         Checks if every quantized module in the model is compatible with Marlin24
+         compression. Quantization must be channel or group strategy with group_size
+         of 128. Only symmetric quantization is supported
+
+         :param model_quant_args: dictionary of mapping module names to their
+             quantization configuration
+         :return: True if all modules are compatible with Marlin24 compression, raises
+             a ValueError otherwise
+         """
+         for name, quant_args in model_quant_args.items():
+             strategy = quant_args.strategy
+             group_size = quant_args.group_size
+             symmetric = quant_args.symmetric
+             if (
+                 strategy is not QuantizationStrategy.GROUP.value
+                 and strategy is not QuantizationStrategy.CHANNEL.value
+             ):
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for group and channel "
+                     f"quantization strategies, got {strategy} in {name}"
+                 )
+
+             if group_size is not None and group_size != 128:
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for group size 128, "
+                     f"got {group_size} in {name}"
+                 )
+
+             if not symmetric:
+                 raise ValueError(
+                     f"Marlin24 Compressor is only valid for symmetric quantzation, "
+                     f"got symmetric={symmetric} in {name}"
+                 )
+
+         return True
+
+     @staticmethod
+     def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
+         """
+         Checks if a tensor fits the required 2:4 sparsity structure
+
+         :param name: name of the tensor to check
+         :param weight: tensor to check for sparsity structure
+         :return: True if all rows match the 2:4 sparsity structure, raises
+             ValueError otherwise
+         """
+
+         if not tensor_follows_mask_structure(weight):
+             raise ValueError(
+                 "Marlin24 Compressor is only compatible with weights that have "
+                 f"a 2:4 sparsity structure. Found segments in {name} "
+                 "that do not match the expected structure."
+             )
+
+         return True
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         names_to_scheme: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a quantized state_dict with 2:4 sparsity structure for inference
+         with the Marlin24 kernel
+
+         :param model_state: state dict of uncompressed model
+         :param names_to_scheme: quantization args for each quantized weight, needed for
+             quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         self.validate_quant_compatability(names_to_scheme)
+
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 if scale is not None:  # weight is quantized, compress it
+
+                     # Marlin24 kernel requires float16 inputs
+                     scale = scale.to(torch.float16)
+                     value = value.to(torch.float16)
+
+                     # quantize weight, keeping it as a float16 for now
+                     quant_args = names_to_scheme[prefix]
+                     value = quantize(
+                         x=value, scale=scale, zero_point=zp, args=quant_args
+                     )
+
+                     # compress based on sparsity structure
+                     self.validate_sparsity_structure(prefix, value)
+                     value, meta = compress_weight_24(value)
+                     meta = meta.cpu()
+
+                     # Marlin24 kernel expects input dim first
+                     value = value.t().contiguous().cpu()
+                     scale = scale.t().contiguous().cpu()
+                     og_weight_shape = value.shape
+
+                     # Marlin24 kernel expects unsigned values, shift zero-point
+                     value += (1 << quant_args.num_bits) // 2
+
+                     # pack quantized weight and scale
+                     value = pack_weight_24(value, quant_args)
+                     packed_scale = pack_scales_24(scale, quant_args, og_weight_shape)
+                     meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
+
+                     # save compressed values
+                     compressed_dict[merge_names(prefix, "scale_packed")] = packed_scale
+                     compressed_dict[merge_names(prefix, "weight_packed")] = value
+                     compressed_dict[merge_names(prefix, "meta")] = meta
+                     continue
+
+             if not is_quantization_param(name):
+                 # export unquantized parameters without modifying
+                 compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         raise NotImplementedError(
+             "Decompression is not implemented for the Marlin24 Compressor."
+         )
+
+
+ def compress_weight_24(weight: Tensor):
+     weight = weight.contiguous()
+     w_comp, meta = sparse_semi_structured_from_dense_cutlass(weight)
+     w_comp = w_comp.contiguous()
+     return w_comp, meta
+
+
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile):
+     assert q_w.shape == (size_k, size_n)
+     assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+     assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
+
+     # Permute weights to 16x64 marlin tiles
+     q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+     q_w = q_w.permute((0, 2, 1, 3))
+     q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+     q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+     return q_w
+
+
+ def pack_weight_24(
+     weight: Tensor,
+     quantization_args: QuantizationArgs,
+     tile: int = 16,
+ ):
+     size_k = weight.shape[0]
+     size_n = weight.shape[1]
+     num_bits = quantization_args.num_bits
+     pack_factor = 32 // num_bits
+
+     # Reshuffle to marlin_24 format
+     perm, _, _ = get_permutations_24(num_bits)
+     q_w = marlin_permute_weights(weight, size_k, size_n, perm, tile)
+
+     q_w = q_w.cpu().numpy().astype(np.uint32)
+
+     q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
+     for i in range(pack_factor):
+         q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+     q_packed = torch.from_numpy(q_packed.astype(np.int32))
+
+     return q_packed
+
+
+ def pack_scales_24(scales, quantization_args, w_shape):
+     size_k = w_shape[0]
+     size_n = w_shape[1]
+     num_bits = quantization_args.num_bits
+
+     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
+
+     if (
+         quantization_args.strategy is QuantizationStrategy.GROUP
+         and quantization_args.group_size < size_k
+     ):
+         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
+     else:  # channelwise
+         scales = scales.reshape((-1, len(scale_perm_single_2_4)))[
+             :, scale_perm_single_2_4
+         ]
+     scales = scales.reshape((-1, size_n)).contiguous()
+
+     return scales
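In `pack_weight_24` above, `pack_factor = 32 // num_bits`, so with 4-bit weights eight values share one 32-bit word and value `i` of each group lands in bits `[4*i, 4*i + 4)`; the preceding `value += (1 << num_bits) // 2` shifts the symmetric int4 range into unsigned values first. A standalone numpy sketch of that packing loop and its inverse (illustration only, not part of the package):

```python
import numpy as np

num_bits = 4
pack_factor = 32 // num_bits  # 8 four-bit values per 32-bit word

# one row of eight already-unsigned 4-bit values (0..15), as pack_weight_24 expects
q_w = np.array([[3, 7, 0, 15, 1, 2, 8, 5]], dtype=np.uint32)

q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    # value i of each group is shifted into bits [num_bits*i, num_bits*(i+1))
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

# unpacking reverses the shifts and recovers the original values
mask = (1 << num_bits) - 1
unpacked = np.stack(
    [(q_packed >> (num_bits * i)) & mask for i in range(pack_factor)], axis=-1
).reshape(q_w.shape)
assert np.array_equal(unpacked, q_w)
```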