compressed-tensors-nightly 0.6.0.20240930__py3-none-any.whl → 0.6.0.20241004__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +6 -12
  3. compressed_tensors/compressors/base.py +38 -102
  4. compressed_tensors/compressors/helpers.py +6 -6
  5. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  6. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +91 -53
  7. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  8. compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
  10. compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
  11. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  12. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  13. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  14. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  15. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  16. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  17. compressed_tensors/linear/compressed_linear.py +2 -2
  18. compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  19. compressed_tensors/quantization/lifecycle/initialize.py +2 -1
  20. compressed_tensors/quantization/quant_config.py +7 -0
  21. compressed_tensors/quantization/quant_scheme.py +1 -1
  22. compressed_tensors/utils/helpers.py +17 -1
  23. {compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/METADATA +1 -1
  24. {compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/RECORD +27 -21
  25. {compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/LICENSE +0 -0
  26. {compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/WHEEL +0 -0
  27. {compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/quantized_compressors/base.py
@@ -0,0 +1,146 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ import torch
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+ __all__ = ["BaseQuantizationCompressor"]
+
+
+ class BaseQuantizationCompressor(BaseCompressor):
+     """
+     Base class representing a quant compression algorithm. Each child class should
+     implement compression_param_info, compress_weight and decompress_weight.
+
+     Compressors support compressing/decompressing a full module state dict or a single
+     quantized PyTorch leaf module.
+
+     Model Load Lifecycle (run_compressed=False):
+         - ModelCompressor.decompress()
+         - apply_quantization_config()
+         - BaseQuantizationCompressor.decompress()
+         - BaseQuantizationCompressor.decompress_weight()
+
+     Model Save Lifecycle:
+         - ModelCompressor.compress()
+         - BaseQuantizationCompressor.compress()
+         - BaseQuantizationCompressor.compress_weight()
+
+     Module Lifecycle (run_compressed=True):
+         - apply_quantization_config()
+         - compressed_module = CompressedLinear(module)
+         - initialize_module_for_quantization()
+         - BaseQuantizationCompressor.compression_param_info()
+         - register_parameters()
+         - compressed_module.forward()
+         - compressed_module.decompress()
+
+
+     :param config: config specifying compression parameters
+     """
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         names_to_scheme: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :param names_to_scheme: quantization args for each quantized weight, needed for
+             quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                 if scale is not None:
+                     # weight is quantized, compress it
+                     quant_args = names_to_scheme[prefix]
+                     compressed_data = self.compress_weight(
+                         weight=value,
+                         scale=scale,
+                         zero_point=zp,
+                         g_idx=g_idx,
+                         quantization_args=quant_args,
+                         device="cpu",
+                     )
+                     for key, value in compressed_data.items():
+                         compressed_dict[merge_names(prefix, key)] = value
+                 else:
+                     compressed_dict[name] = value.to("cpu")
+             elif name.endswith("zero_point") and torch.all(value == 0):
+                 continue
+             elif name.endswith("g_idx") and torch.any(value <= -1):
+                 continue
+             else:
+                 compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self,
+         path_to_model_or_tensors: str,
+         names_to_scheme: Dict[str, QuantizationArgs],
+         device: str = "cpu",
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param path_to_model_or_tensors: path to compressed safetensors model (directory
+             with one or more safetensors files) or compressed tensors file
+         :param names_to_scheme: quantization args for each quantized weight
+         :param device: optional device to load intermediate weights into
+         :return: compressed state dict
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if "weight_scale" in weight_data:
+                 quant_args = names_to_scheme[weight_name]
+                 decompressed = self.decompress_weight(
+                     compressed_data=weight_data, quantization_args=quant_args
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
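For orientation (not part of the diff itself): after this refactor a concrete quantized compressor only supplies the per-weight hooks. A minimal sketch, with hook signatures inferred from the keyword calls in compress()/decompress() above; MyQuantCompressor and its body are hypothetical:

    from typing import Dict, Optional

    from compressed_tensors.compressors.quantized_compressors.base import (
        BaseQuantizationCompressor,
    )
    from compressed_tensors.quantization import QuantizationArgs
    from torch import Tensor


    class MyQuantCompressor(BaseQuantizationCompressor):
        # names the base class looks up when reloading compressed safetensors
        COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

        def compress_weight(
            self,
            weight: Tensor,
            scale: Tensor,
            zero_point: Optional[Tensor] = None,
            g_idx: Optional[Tensor] = None,
            quantization_args: Optional[QuantizationArgs] = None,
            device: str = "cpu",
        ) -> Dict[str, Tensor]:
            # quantize/pack one dense weight; the base class iterates the state dict
            return {"weight": weight.to(device)}

        def decompress_weight(
            self,
            compressed_data: Dict[str, Tensor],
            quantization_args: Optional[QuantizationArgs] = None,
        ) -> Tensor:
            # invert compress_weight for a single module's parameters
            return compressed_data["weight"]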
compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py}
@@ -12,11 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import logging
  from typing import Dict, Optional, Tuple
 
  import torch
- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.quantized_compressors.base import (
+     BaseQuantizationCompressor,
+ )
  from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization import QuantizationArgs
  from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
@@ -25,16 +27,14 @@ from torch import Tensor
 
 
  __all__ = [
-     "QuantizationCompressor",
+     "NaiveQuantizationCompressor",
      "IntQuantizationCompressor",
      "FloatQuantizationCompressor",
  ]
 
- _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-
- @Compressor.register(name=CompressionFormat.naive_quantized.value)
- class QuantizationCompressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
+ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
      """
      Implements naive compression for quantized models. Weight of each
      quantized layer is converted from its original float type to the closest Pytorch
@@ -122,8 +122,8 @@ class QuantizationCompressor(Compressor):
          return decompressed_weight
 
 
- @Compressor.register(name=CompressionFormat.int_quantized.value)
- class IntQuantizationCompressor(QuantizationCompressor):
+ @BaseCompressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(NaiveQuantizationCompressor):
      """
      Alias for integer quantized models
      """
@@ -131,8 +131,8 @@ class IntQuantizationCompressor(QuantizationCompressor):
      pass
 
 
- @Compressor.register(name=CompressionFormat.float_quantized.value)
- class FloatQuantizationCompressor(QuantizationCompressor):
+ @BaseCompressor.register(name=CompressionFormat.float_quantized.value)
+ class FloatQuantizationCompressor(NaiveQuantizationCompressor):
      """
      Alias for fp quantized models
      """
compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py}
@@ -16,7 +16,10 @@ from typing import Dict, Optional, Tuple
 
  import numpy as np
  import torch
- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.quantized_compressors.base import (
+     BaseQuantizationCompressor,
+ )
  from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization import QuantizationArgs
  from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
@@ -27,8 +30,8 @@ from torch import Tensor
  __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
 
 
- @Compressor.register(name=CompressionFormat.pack_quantized.value)
- class PackedQuantizationCompressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.pack_quantized.value)
+ class PackedQuantizationCompressor(BaseQuantizationCompressor):
      """
      Compresses a quantized model by packing every eight 4-bit weights into an int32
      """
compressed_tensors/compressors/sparse_compressors/__init__.py
@@ -0,0 +1,18 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # flake8: noqa
+
+ from .base import *
+ from .dense import *
+ from .sparse_bitmask import *
compressed_tensors/compressors/sparse_compressors/base.py
@@ -0,0 +1,110 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = ["BaseSparseCompressor"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ class BaseSparseCompressor(BaseCompressor):
+     """
+     Base class representing a sparse compression algorithm. Each child class should
+     implement compression_param_info, compress_weight and decompress_weight.
+
+     Compressors support compressing/decompressing a full module state dict or a single
+     quantized PyTorch leaf module.
+
+     Model Load Lifecycle (run_compressed=False):
+         - ModelCompressor.decompress()
+         - apply_quantization_config()
+         - BaseSparseCompressor.decompress()
+         - BaseSparseCompressor.decompress_weight()
+
+     Model Save Lifecycle:
+         - ModelCompressor.compress()
+         - BaseSparseCompressor.compress()
+         - BaseSparseCompressor.compress_weight()
+
+     Module Lifecycle (run_compressed=True):
+         - apply_quantization_config()
+         - compressed_module = CompressedLinear(module)
+         - initialize_module_for_quantization()
+         - BaseSparseCompressor.compression_param_info()
+         - register_parameters()
+         - compressed_module.forward()
+         - compressed_module.decompress()
+
+
+     :param config: config specifying compression parameters
+     """
+
+     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict using bitmask compression
+
+         :param model_state: state dict of uncompressed model
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             compression_data = self.compress_weight(name, value)
+             for key in compression_data.keys():
+                 if key in compressed_dict:
+                     _LOGGER.warn(
+                         f"Expected all compressed state_dict keys to be unique, but "
+                         f"found an existing entry for {key}. The existing entry will "
+                         "be replaced."
+                     )
+
+             compressed_dict.update(compression_data)
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a bitmask compressed state dict located
+         at path_to_model_or_tensors and returns a generator
+         for sequentially decompressing back to a dense state dict
+
+         :param model_path: path to compressed safetensors model (directory with
+             one or more safetensors files) or compressed tensors file
+         :param device: device to load decompressed weights onto
+         :return: iterator for generating decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+             decompressed = self.decompress_weight(weight_data)
+             yield weight_name, decompressed
compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py}
@@ -14,13 +14,13 @@
 
  from typing import Dict, Generator, Tuple
 
- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.config import CompressionFormat
  from torch import Tensor
 
 
- @Compressor.register(name=CompressionFormat.dense.value)
- class DenseCompressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.dense.value)
+ class DenseCompressor(BaseCompressor):
      """
      Identity compressor for dense models, returns the original state_dict
      """
compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py}
@@ -12,17 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import logging
- from typing import Dict, Generator, List, Tuple, Union
+ from typing import Dict, List, Tuple, Union
 
  import numpy
  import torch
- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
  from compressed_tensors.config import CompressionFormat
- from compressed_tensors.utils import get_nested_weight_mappings, merge_names
- from safetensors import safe_open
+ from compressed_tensors.utils import merge_names
  from torch import Tensor
- from tqdm import tqdm
 
 
  __all__ = [
@@ -34,11 +32,9 @@ __all__ = [
      "unpack_bitmasks",
  ]
 
- _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-
- @Compressor.register(name=CompressionFormat.sparse_bitmask.value)
- class BitmaskCompressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.sparse_bitmask.value)
+ class BitmaskCompressor(BaseSparseCompressor):
      """
      Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
      values tensor, with their locations stored in a 2d bitmask
@@ -46,56 +42,15 @@ class BitmaskCompressor(Compressor):
 
      COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
 
-     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Compresses a dense state dict using bitmask compression
+     def compress_weight(self, name, value):
+         bitmask_tensor = BitmaskTensor.from_dense(value)
+         bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+         return bitmask_dict
 
-         :param model_state: state dict of uncompressed model
-         :return: compressed state dict
-         """
-         compressed_dict = {}
-         _LOGGER.debug(
-             f"Compressing model with {len(model_state)} parameterized layers..."
-         )
-         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-             bitmask_tensor = BitmaskTensor.from_dense(value)
-             bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-             for key in bitmask_dict.keys():
-                 if key in compressed_dict:
-                     _LOGGER.warn(
-                         f"Expected all compressed state_dict keys to be unique, but "
-                         f"found an existing entry for {key}. The existing entry will "
-                         "be replaced."
-                     )
-             compressed_dict.update(bitmask_dict)
-
-         return compressed_dict
-
-     def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
-     ) -> Generator[Tuple[str, Tensor], None, None]:
-         """
-         Reads a bitmask compressed state dict located
-         at path_to_model_or_tensors and returns a generator
-         for sequentially decompressing back to a dense state dict
-
-         :param model_path: path to compressed safetensors model (directory with
-             one or more safetensors files) or compressed tensors file
-         :param device: device to load decompressed weights onto
-         :return: iterator for generating decompressed weights
-         """
-         weight_mappings = get_nested_weight_mappings(
-             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
-         )
-         for weight_name in weight_mappings.keys():
-             weight_data = {}
-             for param_name, safe_path in weight_mappings[weight_name].items():
-                 full_name = merge_names(weight_name, param_name)
-                 with safe_open(safe_path, framework="pt", device=device) as f:
-                     weight_data[param_name] = f.get_tensor(full_name)
-             data = BitmaskTensor(**weight_data)
-             decompressed = data.decompress()
-             yield weight_name, decompressed
+     def decompress_weight(self, weight_data):
+         data = BitmaskTensor(**weight_data)
+         decompressed = data.decompress()
+         return decompressed
 
 
  class BitmaskTensor:
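Because compression now happens one weight at a time, the BitmaskTensor API behind the two new hooks can be exercised in isolation. A rough round-trip sketch; the import path follows the new layout in this diff, and the dot-joined key names are an assumption about merge_names:

    import torch
    from compressed_tensors.compressors.sparse_compressors.sparse_bitmask import (
        BitmaskTensor,
    )

    dense = torch.randn(4, 8) * (torch.rand(4, 8) > 0.7)  # mostly-zero weight
    compressed = BitmaskTensor.from_dense(dense).dict(
        name_prefix="layer.weight", device="cpu"
    )
    # expected keys: layer.weight.shape/.compressed/.bitmask/.row_offsets

    # strip the prefix back off, as BitmaskCompressor.decompress_weight receives it
    restored = BitmaskTensor(
        **{key.split(".")[-1]: val for key, val in compressed.items()}
    ).decompress()
    assert torch.equal(restored, dense)  # bitmask compression is lossless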
compressed_tensors/compressors/sparse_quantized_compressors/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # flake8: noqa
+
+ from .marlin_24 import Marlin24Compressor
compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py}
@@ -17,7 +17,7 @@ from typing import Dict, Generator, Tuple
 
  import numpy as np
  import torch
- from compressed_tensors.compressors import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
  from compressed_tensors.quantization.lifecycle.forward import quantize
@@ -35,8 +35,8 @@ from tqdm import tqdm
  _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
- @Compressor.register(name=CompressionFormat.marlin_24.value)
- class Marlin24Compressor(Compressor):
+ @BaseCompressor.register(name=CompressionFormat.marlin_24.value)
+ class Marlin24Compressor(BaseCompressor):
      """
      Compresses a quantized model with 2:4 sparsity structure for inference with the
      Marlin24 kernel. Decompression is not implemented for this compressor.
compressed_tensors/linear/compressed_linear.py
@@ -13,7 +13,7 @@
  # limitations under the License.
 
  import torch
- from compressed_tensors.compressors.base import Compressor
+ from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.quantization import (
      QuantizationScheme,
      QuantizationStatus,
@@ -44,7 +44,7 @@ class CompressedLinear(Linear):
          quantization_format: str,
      ):
          module.__class__ = CompressedLinear
-         module.compressor = Compressor.load_from_registry(quantization_format)
+         module.compressor = BaseCompressor.load_from_registry(quantization_format)
          device = next(module.parameters()).device
 
          # this will initialize all the scales and zero points
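The rename propagates to every registry call site like the one above; resolving a format string to a compressor instance reads as follows (a sketch using the registrations shown earlier in this diff):

    from compressed_tensors.compressors.base import BaseCompressor
    from compressed_tensors.config import CompressionFormat

    # "dense" resolves to DenseCompressor via the @BaseCompressor.register
    # decorators added in this diff
    compressor = BaseCompressor.load_from_registry(CompressionFormat.dense.value)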
compressed_tensors/quantization/lifecycle/calibration.py
@@ -56,10 +56,9 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
          observer = module.weight_observer
          g_idx = getattr(module, "weight_g_idx", None)
 
-         offloaded = False
-         if is_module_offloaded(module):
+         offloaded = is_module_offloaded(module)
+         if offloaded:
              module._hf_hook.pre_forward(module)
-             offloaded = True
 
          scale, zero_point = observer(module.weight, g_idx=g_idx)
          update_parameter_data(module, scale, "weight_scale")
compressed_tensors/quantization/lifecycle/initialize.py
@@ -172,9 +172,10 @@ def _initialize_scale_zero_point_observer(
          # (output_channels, 1)
          expected_shape = (weight_shape[0], 1)
      elif quantization_args.strategy == QuantizationStrategy.GROUP:
+         num_groups = weight_shape[1] // quantization_args.group_size
          expected_shape = (
              weight_shape[0],
-             weight_shape[1] // quantization_args.group_size,
+             max(num_groups, 1)
          )
 
      scale_dtype = module.weight.dtype
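The max(num_groups, 1) guard covers layers whose input dimension is smaller than group_size, where the floor division yields zero groups and the scale tensor would have collapsed to zero columns. Illustrative arithmetic with hypothetical shapes:

    weight_shape = (512, 64)  # (output_channels, input_channels)
    group_size = 128
    num_groups = weight_shape[1] // group_size              # 64 // 128 == 0
    expected_shape = (weight_shape[0], max(num_groups, 1))  # (512, 1), not (512, 0)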
compressed_tensors/quantization/quant_config.py
@@ -201,6 +201,13 @@ class QuantizationConfig(BaseModel):
          if len(quant_scheme_to_layers) == 0: # No quantized layers
              return None
 
+         # kv-cache only, no weight/activation quantization
+         if (
+             len(quantization_type_names) == 1
+             and "attention" in list(quantization_type_names)[0].lower()
+         ):
+             quantization_type_names.add("Linear")
+
          # clean up ignore list, we can leave out layers types if none of the
          # instances are quantized
          consolidated_ignore = []
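Traced with a hypothetical module set, the new branch keeps a kv-cache-only config from ending up with no weight-quantization type names at all:

    quantization_type_names = {"LlamaAttention"}  # kv-cache quantization only
    if (
        len(quantization_type_names) == 1
        and "attention" in list(quantization_type_names)[0].lower()
    ):
        quantization_type_names.add("Linear")
    # quantization_type_names -> {"LlamaAttention", "Linear"}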
compressed_tensors/quantization/quant_scheme.py
@@ -211,7 +211,7 @@ PRESET_SCHEMES = {
      "W4A16": W4A16,
      # Integer weight and activation schemes
      "W8A8": INT8_W8A8,
-     "INT8": INT8_W8A8, # alias for W8A8
+     "INT8": INT8_W8A8,  # alias for W8A8
      "W4A8": INT8_W4A8,
      # Float weight and activation schemes
      "FP8": FP8,
compressed_tensors/utils/helpers.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from typing import Optional
+ from typing import Any, Optional
 
  import torch
  from transformers import AutoConfig
@@ -23,6 +23,7 @@ __all__ = [
      "fix_fsdp_module_name",
      "tensor_follows_mask_structure",
      "replace_module",
+     "is_compressed_tensors_config",
  ]
 
  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -103,3 +104,18 @@ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Modul
          parent = model
          child_name = name
      setattr(parent, child_name, new_module)
+
+
+ def is_compressed_tensors_config(compression_config: Any) -> bool:
+     """
+     Returns True if CompressedTensorsConfig is available from transformers and
+     compression_config is an instance of CompressedTensorsConfig
+
+     See: https://github.com/huggingface/transformers/pull/31704
+     """
+     try:
+         from transformers.utils.quantization_config import CompressedTensorsConfig
+
+         return isinstance(compression_config, CompressedTensorsConfig)
+     except ImportError:
+         return False
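The helper lets callers branch on the transformers integration without a hard dependency on newer transformers releases. A usage sketch; the model path and the quantization_config attribute read are illustrative:

    from transformers import AutoConfig

    from compressed_tensors.utils.helpers import is_compressed_tensors_config

    hf_config = AutoConfig.from_pretrained("some/quantized-model")  # hypothetical
    compression_config = getattr(hf_config, "quantization_config", None)
    if is_compressed_tensors_config(compression_config):
        # safe to access CompressedTensorsConfig-specific fields here
        ...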
{compressed_tensors_nightly-0.6.0.20240930.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors-nightly
- Version: 0.6.0.20240930
+ Version: 0.6.0.20241004
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.