compressed-tensors 0.3.1__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {compressed-tensors-0.3.1/src/compressed_tensors.egg-info → compressed-tensors-0.3.3}/PKG-INFO +4 -2
  2. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/README.md +3 -1
  3. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/setup.py +1 -1
  4. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/compressors/__init__.py +1 -6
  5. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/compressors/base.py +25 -1
  6. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/compressors/dense.py +1 -1
  7. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/compressors/helpers.py +0 -24
  8. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/compressors/sparse_bitmask.py +4 -3
  9. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/lifecycle/apply.py +14 -9
  10. compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/forward.py +221 -0
  11. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/lifecycle/frozen.py +9 -9
  12. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/lifecycle/initialize.py +7 -4
  13. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/observers/base.py +64 -3
  14. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/observers/memoryless.py +2 -2
  15. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/quant_args.py +42 -2
  16. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/utils/helpers.py +1 -0
  17. compressed-tensors-0.3.3/src/compressed_tensors/utils/helpers.py +45 -0
  18. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3/src/compressed_tensors.egg-info}/PKG-INFO +4 -2
  19. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors.egg-info/SOURCES.txt +1 -0
  20. compressed-tensors-0.3.1/src/compressed_tensors/quantization/lifecycle/forward.py +0 -136
  21. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/LICENSE +0 -0
  22. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/pyproject.toml +0 -0
  23. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/setup.cfg +0 -0
  24. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/__init__.py +0 -0
  25. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/base.py +0 -0
  26. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/config/__init__.py +0 -0
  27. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/config/base.py +0 -0
  28. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/config/dense.py +0 -0
  29. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  30. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/__init__.py +0 -0
  31. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  32. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/lifecycle/calibration.py +0 -0
  33. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/observers/__init__.py +0 -0
  34. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/observers/helpers.py +0 -0
  35. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/observers/min_max.py +0 -0
  36. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/quant_config.py +0 -0
  37. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  38. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  39. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/registry/__init__.py +0 -0
  40. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/registry/registry.py +0 -0
  41. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/utils/__init__.py +0 -0
  42. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  43. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  44. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors.egg-info/requires.txt +0 -0
  45. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  46. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/tests/test_bitmask.py +0 -0
  47. {compressed-tensors-0.3.1 → compressed-tensors-0.3.3}/tests/test_registry.py +0 -0
compressed-tensors-0.3.3/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors
- Version: 0.3.1
+ Version: 0.3.3
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
@@ -11,7 +11,7 @@ Description-Content-Type: text/markdown
  Provides-Extra: dev
  License-File: LICENSE

- # compressed-tensors
+ # compressed_tensors

  This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.

@@ -94,4 +94,6 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
  ```

+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
+

compressed-tensors-0.3.3/README.md
@@ -1,4 +1,4 @@
- # compressed-tensors
+ # compressed_tensors

  This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.

@@ -80,3 +80,5 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
  # load compressed model weights (`dict` turns generator into a dictionary)
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
  ```
+
+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
compressed-tensors-0.3.3/setup.py
@@ -32,7 +32,7 @@ def _setup_extras() -> Dict:

  setup(
      name="compressed-tensors",
-     version="0.3.1",
+     version="0.3.3",
      author="Neuralmagic, Inc.",
      author_email="support@neuralmagic.com",
      license="Apache 2.0",
compressed-tensors-0.3.3/src/compressed_tensors/compressors/__init__.py
@@ -16,10 +16,5 @@

  from .base import ModelCompressor
  from .dense import DenseCompressor
- from .helpers import (
-     infer_compressor_from_model_config,
-     load_compressed,
-     save_compressed,
-     save_compressed_model,
- )
+ from .helpers import load_compressed, save_compressed, save_compressed_model
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed-tensors-0.3.3/src/compressed_tensors/compressors/base.py
@@ -22,6 +22,7 @@ from compressed_tensors.utils import get_safetensors_folder
  from torch import Tensor
  from torch.nn import Module, Parameter
  from tqdm import tqdm
+ from transformers import AutoConfig


  __all__ = ["ModelCompressor"]
@@ -34,6 +35,29 @@ class ModelCompressor(RegistryMixin):
      :param config: config specifying compression parameters
      """

+     @classmethod
+     def from_pretrained(
+         cls, pretrained_model_name_or_path: str
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a path to a model config, extract a sparsity config if it exists and
+         return the associated ModelCompressor
+
+         :param pretrained_model_name_or_path: path to model config on disk or HF hub
+         :return: matching compressor if config contains a sparsity config
+         """
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
+         if sparsity_config is None:
+             return None
+
+         format = sparsity_config.get("format")
+         sparsity_config = CompressionConfig.load_from_registry(
+             format, **sparsity_config
+         )
+         compressor = cls.load_from_registry(format, config=sparsity_config)
+         return compressor
+
      def __init__(self, config: Optional[CompressionConfig] = None):
          self.config = config

@@ -47,7 +71,7 @@ class ModelCompressor(RegistryMixin):
          raise NotImplementedError()

      def decompress(
-         self, path_to_model_or_tensors: str
+         self, path_to_model_or_tensors: str, device: str = "cpu"
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
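The new `from_pretrained` classmethod takes over from the old standalone `infer_compressor_from_model_config` helper as the entry point on the compressor itself. A minimal usage sketch, assuming a hypothetical checkpoint directory `./my-sparse-model` whose `config.json` carries a sparsity config entry:

```python
# Minimal sketch of the classmethod added above; "./my-sparse-model" is a
# hypothetical path, and from_pretrained returns None when the model config
# has no sparsity config attached.
from compressed_tensors.compressors import ModelCompressor

compressor = ModelCompressor.from_pretrained("./my-sparse-model")
if compressor is not None:
    # decompress now defaults its new device argument to "cpu"
    for name, dense_tensor in compressor.decompress("./my-sparse-model", device="cpu"):
        print(name, tuple(dense_tensor.shape))
```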
compressed-tensors-0.3.3/src/compressed_tensors/compressors/dense.py
@@ -29,6 +29,6 @@ class DenseCompressor(ModelCompressor):
          return model_state

      def decompress(
-         self, path_to_model_or_tensors: str, device: str
+         self, path_to_model_or_tensors: str, device: str = "cpu"
      ) -> Generator[Tuple[str, Tensor], None, None]:
          return iter([])
compressed-tensors-0.3.3/src/compressed_tensors/compressors/helpers.py
@@ -16,45 +16,21 @@ from pathlib import Path
  from typing import Dict, Generator, Optional, Tuple, Union

  import torch
- from compressed_tensors.base import SPARSITY_CONFIG_NAME
  from compressed_tensors.compressors import ModelCompressor
  from compressed_tensors.config import CompressionConfig, CompressionFormat
  from compressed_tensors.utils.safetensors_load import get_weight_mappings
  from safetensors import safe_open
  from safetensors.torch import save_file
  from torch import Tensor
- from transformers import AutoConfig


  __all__ = [
-     "infer_compressor_from_model_config",
      "load_compressed",
      "save_compressed",
      "save_compressed_model",
  ]


- def infer_compressor_from_model_config(
-     pretrained_model_name_or_path: str,
- ) -> Optional[ModelCompressor]:
-     """
-     Given a path to a model config, extract a sparsity config if it exists and return
-     the associated ModelCompressor
-
-     :param pretrained_model_name_or_path: path to model config on disk or HF hub
-     :return: matching compressor if config contains a sparsity config
-     """
-     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-     sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
-     if sparsity_config is None:
-         return None
-
-     format = sparsity_config.get("format")
-     sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
-     compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
-     return compressor
-
-
  def save_compressed(
      tensors: Dict[str, Tensor],
      save_path: Union[str, Path],
compressed-tensors-0.3.3/src/compressed_tensors/compressors/sparse_bitmask.py
@@ -67,7 +67,7 @@ class BitmaskCompressor(ModelCompressor):
                          f"found an existing entry for {key}. The existing entry will "
                          "be replaced."
                      )
-             compressed_dict |= bitmask_dict
+             compressed_dict.update(bitmask_dict)

          return compressed_dict

@@ -75,8 +75,9 @@
          self, path_to_model_or_tensors: str, device: str = "cpu"
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
-         Reads a bitmask compressed state dict located at path_to_model_or_tensors
-         and returns a generator for sequentially decompressing back to a dense state dict
+         Reads a bitmask compressed state dict located
+         at path_to_model_or_tensors and returns a generator
+         for sequentially decompressing back to a dense state dict

          :param model_path: path to compressed safetensors model (directory with
              one or more safetensors files) or compressed tensors file
compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -36,6 +36,7 @@ __all__ = [
      "load_pretrained_quantization",
      "apply_quantization_config",
      "apply_quantization_status",
+     "find_first_name_or_class_match",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -99,9 +100,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):

      # mark appropriate layers for quantization by setting their quantization schemes
      for name, submodule in iter_named_leaf_modules(model):
-         if _find_first_name_or_class_match(name, submodule, config.ignore):
+         if find_first_name_or_class_match(name, submodule, config.ignore):
              continue  # layer matches ignore list, continue
-         target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
+         target = find_first_name_or_class_match(name, submodule, target_to_scheme)
          if target is not None:
              # target matched - add layer and scheme to target list
              submodule.quantization_scheme = target_to_scheme[target]
@@ -125,27 +126,31 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
          model.apply(freeze_module_quantization)


- def _find_first_name_or_class_match(
-     name: str,
-     module: Module,
-     targets: Iterable[str],
+ def find_first_name_or_class_match(
+     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
  ) -> Optional[str]:
      # first element of targets that matches the given name
      # if no name matches returns first target that matches the class name
      # returns None otherwise
      return _find_first_match(name, targets) or _find_first_match(
-         module.__class__.__name__, targets
+         module.__class__.__name__, targets, check_contains
      )


- def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
+ def _find_first_match(
+     value: str, targets: Iterable[str], check_contains: bool = False
+ ) -> Optional[str]:
      # returns first element of target that matches value either
-     # exactly or as a regex after 're:'
+     # exactly or as a regex after 're:'. if check_contains is set to True,
+     # additionally checks if the target string is contained with value.
      for target in targets:
          if target.startswith("re:"):
              pattern = target[3:]
              if re.match(pattern, value):
                  return target
+         elif check_contains:
+             if target.lower() in value.lower():
+                 return target
          elif target == value:
              return target
      return None
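To make the renamed matcher's behavior concrete, a small illustration (assuming the function is imported straight from the `apply` module; not taken from the package tests): exact names and class names match as before, `re:` targets are treated as regular expressions against the module name, and the new `check_contains` flag additionally accepts a case-insensitive substring match against the class name.

```python
# Hedged illustration of find_first_name_or_class_match as modified above.
import torch.nn as nn
from compressed_tensors.quantization.lifecycle.apply import find_first_name_or_class_match

layer = nn.Linear(8, 8)

# regex target matches the module name; plain target matches the class name
print(find_first_name_or_class_match("model.layers.0.mlp", layer, ["re:.*mlp"]))  # "re:.*mlp"
print(find_first_name_or_class_match("model.lm_head", layer, ["Linear"]))         # "Linear"

# check_contains=True lets a substring of the class name ("Linear") match
print(find_first_name_or_class_match("model.lm_head", layer, ["Lin"], check_contains=True))  # "Lin"
```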
compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/forward.py (new file)
@@ -0,0 +1,221 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import wraps
+ from math import ceil
+
+ import torch
+ from compressed_tensors.quantization.quant_args import (
+     QuantizationArgs,
+     QuantizationStrategy,
+ )
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from torch.nn import Module
+
+
+ __all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
+
+
+ @torch.no_grad()
+ def quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     q_min: torch.Tensor,
+     q_max: torch.Tensor,
+ ) -> torch.Tensor:
+
+     return torch.clamp(
+         torch.round(x / scale + zero_point),
+         q_min,
+         q_max,
+     )
+
+
+ @torch.no_grad()
+ def dequantize(
+     x_q: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+ ) -> torch.Tensor:
+     return (x_q - zero_point) * scale
+
+
+ @torch.no_grad()
+ def fake_quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: QuantizationArgs,
+ ) -> torch.Tensor:
+     """
+     Fake quantize the input tensor x depending on the group_size.
+     if group_size is greater than 0, then q/dq by groups. The groups
+     must be divisible by the column size
+     if group_size is -1, then channel wise q/dq. THe input scale and
+     zero_points are reshaped to support vectorization (Assumes 1 is
+     the channel dimension)
+
+     :param x: Input tensor
+     :param scale: scale tensor
+     :param zero_point: zero point tensor
+     :param args: quantization args that contain group_size info
+     :return: fake quantized tensor
+
+     """
+     bit_range = 2**args.num_bits
+     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
+     min_q = torch.tensor(-bit_range / 2, device=x.device)
+
+     group_size = args.group_size
+
+     # group
+     if args.strategy == QuantizationStrategy.GROUP:
+
+         DQ = torch.zeros_like(x)
+
+         # TODO: vectorize the for loop
+         # TODO: fix genetric assumption about the tensor size for computing group
+
+         # TODO: make validation step for inputs
+
+         while scale.ndim < 2:
+             # pad scale and zero point dims for slicing
+             scale = scale.unsqueeze(1)
+             zero_point = zero_point.unsqueeze(1)
+
+         columns = x.shape[1]
+         if columns >= group_size:
+             if columns % group_size != 0:
+                 raise ValueError(
+                     "tesnor column shape must be divisble "
+                     f"by the given group_size {group_size}"
+                 )
+         for i in range(ceil(columns / group_size)):
+             # scale.shape should be [nchan, ndim]
+             # sc.shape should be [nchan, 1] after unsqueeze
+
+             sc = scale[:, i].unsqueeze(1)
+             zp = zero_point[:, i].unsqueeze(1)
+
+             idx = i * group_size
+             Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
+             DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
+
+     # channel-wise
+     elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
+         # before: scale shape = [channel_size]
+         # after: scale shape = [1, channel_size]
+         scale = scale.unsqueeze(0)
+         zero_point = zero_point.unsqueeze(0)
+
+         Q = quantize(x, scale, zero_point, min_q, max_q)
+         DQ = dequantize(Q, scale, zero_point)
+
+     # per-token
+     elif args.strategy == QuantizationStrategy.TOKEN:
+         # before: scale shape = [num_tokens]
+         # after: scale shape = [num_tokens, 1]
+         # x.shape = 1, num_tokens, 1]
+         # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
+
+         scale = scale.unsqueeze(1)
+         zero_point = zero_point.unsqueeze(1)
+
+         Q = quantize(x, scale, zero_point, min_q, max_q)
+         DQ = dequantize(Q, scale, zero_point)
+
+     else:
+         Q = quantize(x, scale, zero_point, min_q, max_q)
+         DQ = dequantize(Q, scale, zero_point)
+
+     return DQ
+
+
+ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+     # expects a module already initialized and injected with the parameters in
+     # initialize_module_for_quantization
+     forward_func_orig = module.forward.__func__
+
+     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+     def wrapped_forward(self, *args, **kwargs):
+         input_ = args[0]
+
+         if scheme.input_activations is not None:
+             # calibrate and (fake) quantize input activations when applicable
+             input_ = maybe_calibrate_or_quantize(
+                 module, input_, "input", scheme.input_activations
+             )
+
+         if scheme.weights is not None:
+             # calibrate and (fake) quantize weights when applicable
+             unquantized_weight = self.weight.data.clone()
+             self.weight.data = maybe_calibrate_or_quantize(
+                 module, self.weight, "weight", scheme.weights
+             )
+
+         # perform wrapped forward call
+         output = forward_func_orig.__get__(module, module.__class__)(
+             input_, *args[1:], **kwargs
+         )
+
+         if scheme.output_activations is not None:
+             # calibrate and (fake) quantize output activations when applicable
+             output = maybe_calibrate_or_quantize(
+                 module, output, "output", scheme.output_activations
+             )
+
+         # restore back to unquantized_value
+         if scheme.weights is not None:
+             self.weight.data = unquantized_weight
+
+         return output
+
+     # bind wrapped forward to module class so reference to `self` is correct
+     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+     # set forward to wrapped forward
+     setattr(module, "forward", bound_wrapped_forward)
+
+
+ def maybe_calibrate_or_quantize(
+     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
+ ) -> torch.Tensor:
+     # only run quantized for the included stages
+     if module.quantization_status not in {
+         QuantizationStatus.CALIBRATION,
+         QuantizationStatus.FROZEN,
+     }:
+         return value
+
+     if args.dynamic:
+         # dynamic quantization - get scale and zero point directly from observer
+         observer = getattr(module, f"{base_name}_observer")
+         scale, zero_point = observer(value)
+     else:
+         # static quantization - get previous scale and zero point from layer
+         scale = getattr(module, f"{base_name}_scale")
+         zero_point = getattr(module, f"{base_name}_zero_point")
+
+         if module.quantization_status == QuantizationStatus.CALIBRATION:
+             # calibration mode - get new quant params from observer
+             observer = getattr(module, f"{base_name}_observer")
+
+             updated_scale, updated_zero_point = observer(value)
+
+             # update scale and zero point
+             device = next(module.parameters()).device
+             scale.data = updated_scale.to(device)
+             zero_point.data = updated_zero_point.to(device)
+     return fake_quantize(value, scale, zero_point, args)
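The rewritten `fake_quantize` is strategy-aware; for the GROUP branch it expects one `(scale, zero_point)` column per group of `group_size` weight columns. A shape sketch under that assumption (example values are mine, not from the package):

```python
# Hedged sketch: group-wise fake quantization with the helpers defined above.
import torch
from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

rows, cols, group_size = 4, 8, 4                     # two groups of four columns
weight = torch.randn(rows, cols)
scale = torch.rand(rows, cols // group_size) + 0.01  # [rows, num_groups]
zero_point = torch.zeros(rows, cols // group_size)

args = QuantizationArgs(num_bits=8, group_size=group_size)  # strategy inferred as GROUP
w_fq = fake_quantize(weight, scale, zero_point, args)       # q/dq one group at a time
assert w_fq.shape == weight.shape
```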
compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -30,17 +30,17 @@ def freeze_module_quantization(module: Module):

      :param module: module to freeze quantization for
      """
-     if not getattr(module, "quantization_scheme", None):
+     scheme = getattr(module, "quantization_scheme", None)
+     if not scheme:
          # no quantization scheme nothing to do
          return

-     # delete observers from module
-     observer_names = []
-     for submodule_name, _ in module.named_modules():
-         if "." not in submodule_name and submodule_name.endswith("_observer"):
-             # delete any observers that belong directly to this module
-             observer_names.append(submodule_name)
-     for observer_name in observer_names:
-         delattr(module, observer_name)
+     # delete observers from module if not dynamic
+     if scheme.input_activations and not scheme.input_activations.dynamic:
+         delattr(module, "input_observer")
+     if scheme.weights and not scheme.weights.dynamic:
+         delattr(module, "weight_observer")
+     if scheme.output_activations and not scheme.output_activations.dynamic:
+         delattr(module, "output_observer")

      module.quantization_status = QuantizationStatus.FROZEN
compressed-tensors-0.3.3/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -80,6 +80,13 @@ def initialize_module_for_quantization(
  def _initialize_scale_zero_point_observer(
      module: Module, base_name: str, quantization_args: QuantizationArgs
  ):
+     # initialize observer module and attach as submodule
+     observer = quantization_args.get_observer()
+     module.register_module(f"{base_name}_observer", observer)
+
+     if quantization_args.dynamic:
+         return  # no need to register a scale and zero point for a dynamic observer
+
      device = next(module.parameters()).device

      # initializes empty scale and zero point parameters for the module
@@ -90,7 +97,3 @@ def _initialize_scale_zero_point_observer(
          torch.empty(0, device=device, dtype=int), requires_grad=False
      )
      module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-     # initialize observer module and attach as submodule
-     observer = quantization_args.get_observer()
-     module.register_module(f"{base_name}_observer", observer)
compressed-tensors-0.3.3/src/compressed_tensors/quantization/observers/base.py
@@ -14,7 +14,11 @@

  from typing import Optional, Tuple

- from compressed_tensors.quantization.quant_args import QuantizationArgs
+ import torch
+ from compressed_tensors.quantization.quant_args import (
+     QuantizationArgs,
+     QuantizationStrategy,
+ )
  from compressed_tensors.registry.registry import RegistryMixin
  from torch import FloatTensor, IntTensor, Tensor
  from torch.nn import Module
@@ -52,6 +56,12 @@ class Observer(Module, RegistryMixin):
          """
          raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

+     def post_calculate_qparams(self) -> None:
+         """
+         Run any logic specific to its observers after running calculate_qparams
+         """
+         ...
+
      def get_qparams(
          self, observed: Optional[Tensor] = None
      ) -> Tuple[FloatTensor, IntTensor]:
@@ -64,6 +74,57 @@ class Observer(Module, RegistryMixin):
          :return: tuple of scale and zero point based on last observed value
          """
          if observed is not None:
-             # re-calcualte scale and zero point, update the stored value
-             self._scale, self._zero_point = self.calculate_qparams(observed)
+             group_size = self.quantization_args.group_size
+
+             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+
+                 # re-calculate scale and zero point, update the stored value
+                 self._scale, self._zero_point = self.calculate_qparams(observed)
+
+             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                 columns = observed.shape[1]
+                 scales, zero_points = [], []
+                 for i in range(0, columns, self.quantization_args.group_size):
+                     scale, zero_point = self.get_qparams_along_dim(
+                         observed[:, i : (i + group_size)],
+                         0,
+                     )
+                     scales.append(scale)
+                     zero_points.append(zero_point)
+
+                 self._scale = torch.stack(scales, dim=1)
+                 self._zero_point = torch.stack(zero_points, dim=1)
+
+             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+                 # assume observed is transposed, because its the output, hence use dim 0
+                 self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+
+             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+
+                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]
+                 # should be batch, token
+
+                 self._scale, self._zero_point = self.get_qparams_along_dim(
+                     observed, dim=1
+                 )
+
          return self._scale, self._zero_point
+
+     def get_qparams_along_dim(self, observed, dim: int):
+         # TODO: add documentation that specifies the shape must
+         # be padded with 1-dims so the scales are along the right channel
+         # TODO: generalize the logic for reduce_dims
+         scales, zero_points = [], []
+
+         # TODO: make a more generic way to get the channel
+         num_dims = observed.shape[dim]
+
+         for dim_idx in range(num_dims):
+             scale, zero_point = self.calculate_qparams(
+                 observed.select(dim=dim, index=dim_idx)
+             )
+
+             scales.append(scale)
+             zero_points.append(zero_point)
+         # breakpoint()
+         return torch.stack(scales), torch.stack(zero_points)
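The per-dimension helper stacks one `(scale, zero_point)` pair per index along the requested dimension, which is what the TOKEN and CHANNEL branches above rely on. A rough sketch of calling it directly through a registered observer (the `minmax` default and the shapes are my assumptions, not asserted by the diff):

```python
# Hedged sketch of get_qparams_along_dim from the Observer base class above.
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=True)   # default "minmax" observer
observer = args.get_observer()

activations = torch.randn(2, 5, 64)                   # [batch, tokens, hidden]
scale, zero_point = observer.get_qparams_along_dim(activations, dim=1)
# one stacked qparam entry per token index along dim 1
```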
compressed-tensors-0.3.3/src/compressed_tensors/quantization/observers/memoryless.py
@@ -23,10 +23,10 @@ from torch import FloatTensor, IntTensor, Tensor
  __all__ = ["MemorylessObserver"]


- @Observer.register("memoryless")
+ @Observer.register("memoryless", alias=["dynamic"])
  class MemorylessObserver(Observer):
      """
-     Implements a dynamic quantization observer that sets the scale and
+     Implements a quantization observer that sets the scale and
      zero point based on the latest observed value without tracking state
      """

compressed-tensors-0.3.3/src/compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
  from enum import Enum
  from typing import Any, Dict, Optional

- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, Field, validator


  __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
      CHANNEL = "channel"
      GROUP = "group"
      BLOCK = "block"
+     TOKEN = "token"


  class QuantizationArgs(BaseModel):
@@ -53,14 +54,20 @@ class QuantizationArgs(BaseModel):
      :param group_size: group length to use for the group strategy
      :param block_structure: 2d block structure to use for the block strategy, must be
          of the format "2x4", "8x16", etc.
+     :param dynamic: set True to perform dynamic quantization - values will not be
+         calibrated during calibration phase, instead during inference new quantization
+         ranges will be observed with every sample. Defaults to False for static
+         quantization. Note that enabling dynamic quantization will change the default
+         observer to a memoryless one
      """

      num_bits: int = 8
      type: QuantizationType = QuantizationType.INT
      symmetric: bool = True
-     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
      group_size: Optional[int] = None
+     strategy: Optional[QuantizationStrategy] = None
      block_structure: Optional[str] = None
+     dynamic: bool = False
      observer: str = Field(
          default="minmax",
          description=(
@@ -82,4 +89,37 @@ class QuantizationArgs(BaseModel):
          """
          from compressed_tensors.quantization.observers.base import Observer

+         if self.observer == "minmax" and self.dynamic:
+             # override defualt observer for dynamic, you never want minmax which
+             # keeps state across samples for dynamic
+             self.observer = "memoryless"
+
          return Observer.load_from_registry(self.observer, quantization_args=self)
+
+     @validator("strategy", pre=True, always=True)
+     def validate_strategy(cls, value, values):
+         group_size = values.get("group_size")
+
+         # use group_size to determinine strategy if not given explicity
+         if group_size is not None and value is None:
+             if group_size > 0:
+                 return QuantizationStrategy.GROUP
+
+             elif group_size == -1:
+                 return QuantizationStrategy.CHANNEL
+
+             else:
+                 raise ValueError(
+                     f"group_size={group_size} with strategy {value} is invald. "
+                     "group_size > 0 for strategy='group' and "
+                     "group_size = -1 for 'channel'"
+                 )
+
+         if value == QuantizationStrategy.GROUP:
+             if group_size is None:
+                 raise ValueError(f"strategy {value} requires group_size to be set.")
+
+         if value is None:
+             return QuantizationStrategy.TENSOR
+
+         return value
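The new validator ties `group_size` and `strategy` together: when `strategy` is omitted it is inferred from `group_size`, and a bare config still defaults to tensor-wide quantization. A small sketch of the resulting behavior (the assertions are my reading of the validator above, not package tests):

```python
# Hedged sketch of the strategy inference added to QuantizationArgs above.
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy

assert QuantizationArgs(num_bits=4, group_size=128).strategy == QuantizationStrategy.GROUP
assert QuantizationArgs(num_bits=8, group_size=-1).strategy == QuantizationStrategy.CHANNEL
assert QuantizationArgs(num_bits=8).strategy == QuantizationStrategy.TENSOR

# dynamic=True keeps the default "minmax" name, but get_observer() swaps in the
# memoryless observer, matching the "dynamic" alias registered in the
# memoryless observer hunk above
dynamic_args = QuantizationArgs(num_bits=8, dynamic=True)
observer = dynamic_args.get_observer()
```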
compressed-tensors-0.3.3/src/compressed_tensors/quantization/utils/helpers.py
@@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
          compressed_bits = uncompressed_bits
          if is_module_quantized(submodule):
              compressed_bits = submodule.quantization_scheme.weights.num_bits
+
          num_weights = parameter.numel()
          total_compressed += compressed_bits * num_weights
          total_uncompressed += uncompressed_bits * num_weights
compressed-tensors-0.3.3/src/compressed_tensors/utils/helpers.py (new file)
@@ -0,0 +1,45 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Optional
+
+ from compressed_tensors.base import SPARSITY_CONFIG_NAME
+ from compressed_tensors.compressors import ModelCompressor
+ from compressed_tensors.config import CompressionConfig
+ from transformers import AutoConfig
+
+
+ __all__ = ["infer_compressor_from_model_config"]
+
+
+ def infer_compressor_from_model_config(
+     pretrained_model_name_or_path: str,
+ ) -> Optional[ModelCompressor]:
+     """
+     Given a path to a model config, extract a sparsity config if it exists and return
+     the associated ModelCompressor
+
+     :param pretrained_model_name_or_path: path to model config on disk or HF hub
+     :return: matching compressor if config contains a sparsity config
+     """
+     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+     sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
+     if sparsity_config is None:
+         return None
+
+     format = sparsity_config.get("format")
+     sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
+     compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
+     return compressor
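The standalone helper removed from `compressors/helpers.py` above now lives in this new `utils/helpers.py` module (with `ModelCompressor.from_pretrained` as the class-level equivalent). A usage sketch, with a hypothetical checkpoint path:

```python
# Hedged sketch of the relocated helper; "./my-sparse-model" is a hypothetical
# directory whose config.json may or may not carry a sparsity_config entry.
from compressed_tensors.utils.helpers import infer_compressor_from_model_config

compressor = infer_compressor_from_model_config("./my-sparse-model")
if compressor is None:
    print("no sparsity config found - weights are stored dense")
```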
compressed-tensors-0.3.3/src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors
- Version: 0.3.1
+ Version: 0.3.3
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
@@ -11,7 +11,7 @@ Description-Content-Type: text/markdown
  Provides-Extra: dev
  License-File: LICENSE

- # compressed-tensors
+ # compressed_tensors

  This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.

@@ -94,4 +94,6 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
  ```

+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
+

compressed-tensors-0.3.3/src/compressed_tensors.egg-info/SOURCES.txt
@@ -39,6 +39,7 @@ src/compressed_tensors/quantization/utils/helpers.py
  src/compressed_tensors/registry/__init__.py
  src/compressed_tensors/registry/registry.py
  src/compressed_tensors/utils/__init__.py
+ src/compressed_tensors/utils/helpers.py
  src/compressed_tensors/utils/safetensors_load.py
  tests/test_bitmask.py
  tests/test_registry.py
compressed-tensors-0.3.1/src/compressed_tensors/quantization/lifecycle/forward.py (deleted)
@@ -1,136 +0,0 @@
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from functools import wraps
-
- import torch
- from compressed_tensors.quantization.quant_args import QuantizationArgs
- from compressed_tensors.quantization.quant_config import QuantizationStatus
- from compressed_tensors.quantization.quant_scheme import QuantizationScheme
- from torch.nn import Module
-
-
- __all__ = ["wrap_module_forward_quantized"]
-
-
- @torch.no_grad()
- def quantize(
-     x: torch.Tensor,
-     scale: torch.Tensor,
-     zero_point: torch.Tensor,
-     q_min: torch.Tensor,
-     q_max: torch.Tensor,
- ) -> torch.Tensor:
-     return torch.clamp(
-         torch.round(
-             x / scale + zero_point,
-         ),
-         q_min,
-         q_max,
-     )
-
-
- @torch.no_grad()
- def dequantize(
-     x_q: torch.Tensor,
-     scale: torch.Tensor,
-     zero_point: torch.Tensor,
- ) -> torch.Tensor:
-     return (x_q - zero_point) * scale
-
-
- @torch.no_grad()
- def fake_quantize(
-     x: torch.Tensor,
-     scale: torch.Tensor,
-     zero_point: torch.Tensor,
-     args: QuantizationArgs,
- ) -> torch.Tensor:
-     bit_range = 2**args.num_bits
-     max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
-     min_q = torch.tensor(-bit_range / 2, device=x.device)
-     Q = torch.zeros_like(x)
-     Q = quantize(x, scale, zero_point, min_q, max_q)
-     return dequantize(Q, scale, zero_point)
-
-
- def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
-     # expects a module already initialized and injected with the parameters in
-     # initialize_module_for_quantization
-     forward_func_orig = module.forward.__func__
-
-     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-     def wrapped_forward(self, *args, **kwargs):
-         input_ = args[0]
-
-         if scheme.input_activations is not None:
-             # calibrate and (fake) quantize input activations when applicable
-             input_ = _maybe_calibrate_or_quantize(
-                 module, input_, "input", scheme.input_activations
-             )
-
-         if scheme.weights is not None:
-             # calibrate and (fake) quantize weights when applicable
-             unquantized_weight = self.weight.data.clone()
-             self.weight.data = _maybe_calibrate_or_quantize(
-                 module, self.weight, "weight", scheme.weights
-             )
-
-         # perform wrapped forward call
-         output = forward_func_orig.__get__(module, module.__class__)(
-             input_, *args[1:], **kwargs
-         )
-
-         if scheme.output_activations is not None:
-             # calibrate and (fake) quantize output activations when applicable
-             output = _maybe_calibrate_or_quantize(
-                 module, output, "output", scheme.output_activations
-             )
-
-         # restore back to unquantized_value
-         if scheme.weights is not None:
-             self.weight.data = unquantized_weight
-
-         return output
-
-     # bind wrapped forward to module class so reference to `self` is correct
-     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
-     # set forward to wrapped forward
-     setattr(module, "forward", bound_wrapped_forward)
-
-
- def _maybe_calibrate_or_quantize(
-     module: Module, value: Module, base_name: str, args: "QuantizationArgs"
- ) -> torch.Tensor:
-     # only run quantized for the included stages
-     if module.quantization_status not in {
-         QuantizationStatus.CALIBRATION,
-         QuantizationStatus.FROZEN,
-     }:
-         return value
-
-     device = next(module.parameters()).device
-     scale = getattr(module, f"{base_name}_scale")
-     zero_point = getattr(module, f"{base_name}_zero_point")
-
-     if module.quantization_status == QuantizationStatus.CALIBRATION:
-         # get observer and get new quant params from observation
-         observer = getattr(module, f"{base_name}_observer")
-         updated_scale, updated_zero_point = observer(value)
-
-         # update scale and zero point
-         scale.data = updated_scale.to(device)
-         zero_point.data = updated_zero_point.to(device)
-
-     return fake_quantize(value, scale, zero_point, args)