compressed-tensors 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,10 +16,5 @@
16
16
 
17
17
  from .base import ModelCompressor
18
18
  from .dense import DenseCompressor
19
- from .helpers import (
20
- infer_compressor_from_model_config,
21
- load_compressed,
22
- save_compressed,
23
- save_compressed_model,
24
- )
19
+ from .helpers import load_compressed, save_compressed, save_compressed_model
25
20
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
@@ -22,6 +22,7 @@ from compressed_tensors.utils import get_safetensors_folder
22
22
  from torch import Tensor
23
23
  from torch.nn import Module, Parameter
24
24
  from tqdm import tqdm
25
+ from transformers import AutoConfig
25
26
 
26
27
 
27
28
  __all__ = ["ModelCompressor"]
@@ -34,6 +35,29 @@ class ModelCompressor(RegistryMixin):
34
35
  :param config: config specifying compression parameters
35
36
  """
36
37
 
38
+ @classmethod
39
+ def from_pretrained(
40
+ cls, pretrained_model_name_or_path: str
41
+ ) -> Optional["ModelCompressor"]:
42
+ """
43
+ Given a path to a model config, extract a sparsity config if it exists and
44
+ return the associated ModelCompressor
45
+
46
+ :param pretrained_model_name_or_path: path to model config on disk or HF hub
47
+ :return: matching compressor if config contains a sparsity config
48
+ """
49
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
50
+ sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
51
+ if sparsity_config is None:
52
+ return None
53
+
54
+ format = sparsity_config.get("format")
55
+ sparsity_config = CompressionConfig.load_from_registry(
56
+ format, **sparsity_config
57
+ )
58
+ compressor = cls.load_from_registry(format, config=sparsity_config)
59
+ return compressor
60
+
37
61
  def __init__(self, config: Optional[CompressionConfig] = None):
38
62
  self.config = config
39
63
 
@@ -47,7 +71,7 @@ class ModelCompressor(RegistryMixin):
47
71
  raise NotImplementedError()
48
72
 
49
73
  def decompress(
50
- self, path_to_model_or_tensors: str
74
+ self, path_to_model_or_tensors: str, device: str = "cpu"
51
75
  ) -> Generator[Tuple[str, Tensor], None, None]:
52
76
  """
53
77
  Reads a compressed state dict located at path_to_model_or_tensors
@@ -29,6 +29,6 @@ class DenseCompressor(ModelCompressor):
29
29
  return model_state
30
30
 
31
31
  def decompress(
32
- self, path_to_model_or_tensors: str, device: str
32
+ self, path_to_model_or_tensors: str, device: str = "cpu"
33
33
  ) -> Generator[Tuple[str, Tensor], None, None]:
34
34
  return iter([])
@@ -16,45 +16,21 @@ from pathlib import Path
16
16
  from typing import Dict, Generator, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
- from compressed_tensors.base import SPARSITY_CONFIG_NAME
20
19
  from compressed_tensors.compressors import ModelCompressor
21
20
  from compressed_tensors.config import CompressionConfig, CompressionFormat
22
21
  from compressed_tensors.utils.safetensors_load import get_weight_mappings
23
22
  from safetensors import safe_open
24
23
  from safetensors.torch import save_file
25
24
  from torch import Tensor
26
- from transformers import AutoConfig
27
25
 
28
26
 
29
27
  __all__ = [
30
- "infer_compressor_from_model_config",
31
28
  "load_compressed",
32
29
  "save_compressed",
33
30
  "save_compressed_model",
34
31
  ]
35
32
 
36
33
 
37
- def infer_compressor_from_model_config(
38
- pretrained_model_name_or_path: str,
39
- ) -> Optional[ModelCompressor]:
40
- """
41
- Given a path to a model config, extract a sparsity config if it exists and return
42
- the associated ModelCompressor
43
-
44
- :param pretrained_model_name_or_path: path to model config on disk or HF hub
45
- :return: matching compressor if config contains a sparsity config
46
- """
47
- config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
48
- sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
49
- if sparsity_config is None:
50
- return None
51
-
52
- format = sparsity_config.get("format")
53
- sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
54
- compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
55
- return compressor
56
-
57
-
58
34
  def save_compressed(
59
35
  tensors: Dict[str, Tensor],
60
36
  save_path: Union[str, Path],
@@ -67,7 +67,7 @@ class BitmaskCompressor(ModelCompressor):
67
67
  f"found an existing entry for {key}. The existing entry will "
68
68
  "be replaced."
69
69
  )
70
- compressed_dict |= bitmask_dict
70
+ compressed_dict.update(bitmask_dict)
71
71
 
72
72
  return compressed_dict
73
73
 
@@ -75,8 +75,9 @@ class BitmaskCompressor(ModelCompressor):
75
75
  self, path_to_model_or_tensors: str, device: str = "cpu"
76
76
  ) -> Generator[Tuple[str, Tensor], None, None]:
77
77
  """
78
- Reads a bitmask compressed state dict located at path_to_model_or_tensors
79
- and returns a generator for sequentially decompressing back to a dense state dict
78
+ Reads a bitmask compressed state dict located
79
+ at path_to_model_or_tensors and returns a generator
80
+ for sequentially decompressing back to a dense state dict
80
81
 
81
82
  :param model_path: path to compressed safetensors model (directory with
82
83
  one or more safetensors files) or compressed tensors file
@@ -36,6 +36,7 @@ __all__ = [
36
36
  "load_pretrained_quantization",
37
37
  "apply_quantization_config",
38
38
  "apply_quantization_status",
39
+ "find_first_name_or_class_match",
39
40
  ]
40
41
 
41
42
  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -99,9 +100,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
99
100
 
100
101
  # mark appropriate layers for quantization by setting their quantization schemes
101
102
  for name, submodule in iter_named_leaf_modules(model):
102
- if _find_first_name_or_class_match(name, submodule, config.ignore):
103
+ if find_first_name_or_class_match(name, submodule, config.ignore):
103
104
  continue # layer matches ignore list, continue
104
- target = _find_first_name_or_class_match(name, submodule, target_to_scheme)
105
+ target = find_first_name_or_class_match(name, submodule, target_to_scheme)
105
106
  if target is not None:
106
107
  # target matched - add layer and scheme to target list
107
108
  submodule.quantization_scheme = target_to_scheme[target]
@@ -125,27 +126,31 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
125
126
  model.apply(freeze_module_quantization)
126
127
 
127
128
 
128
- def _find_first_name_or_class_match(
129
- name: str,
130
- module: Module,
131
- targets: Iterable[str],
129
+ def find_first_name_or_class_match(
130
+ name: str, module: Module, targets: Iterable[str], check_contains: bool = False
132
131
  ) -> Optional[str]:
133
132
  # first element of targets that matches the given name
134
133
  # if no name matches returns first target that matches the class name
135
134
  # returns None otherwise
136
135
  return _find_first_match(name, targets) or _find_first_match(
137
- module.__class__.__name__, targets
136
+ module.__class__.__name__, targets, check_contains
138
137
  )
139
138
 
140
139
 
141
- def _find_first_match(value: str, targets: Iterable[str]) -> Optional[str]:
140
+ def _find_first_match(
141
+ value: str, targets: Iterable[str], check_contains: bool = False
142
+ ) -> Optional[str]:
142
143
  # returns first element of target that matches value either
143
- # exactly or as a regex after 're:'
144
+ # exactly or as a regex after 're:'. if check_contains is set to True,
145
+ # additionally checks if the target string is contained with value.
144
146
  for target in targets:
145
147
  if target.startswith("re:"):
146
148
  pattern = target[3:]
147
149
  if re.match(pattern, value):
148
150
  return target
151
+ elif check_contains:
152
+ if target.lower() in value.lower():
153
+ return target
149
154
  elif target == value:
150
155
  return target
151
156
  return None
@@ -13,15 +13,19 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from functools import wraps
16
+ from math import ceil
16
17
 
17
18
  import torch
18
- from compressed_tensors.quantization.quant_args import QuantizationArgs
19
+ from compressed_tensors.quantization.quant_args import (
20
+ QuantizationArgs,
21
+ QuantizationStrategy,
22
+ )
19
23
  from compressed_tensors.quantization.quant_config import QuantizationStatus
20
24
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
21
25
  from torch.nn import Module
22
26
 
23
27
 
24
- __all__ = ["wrap_module_forward_quantized"]
28
+ __all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
25
29
 
26
30
 
27
31
  @torch.no_grad()
@@ -32,10 +36,9 @@ def quantize(
32
36
  q_min: torch.Tensor,
33
37
  q_max: torch.Tensor,
34
38
  ) -> torch.Tensor:
39
+
35
40
  return torch.clamp(
36
- torch.round(
37
- x / scale + zero_point,
38
- ),
41
+ torch.round(x / scale + zero_point),
39
42
  q_min,
40
43
  q_max,
41
44
  )
@@ -57,12 +60,88 @@ def fake_quantize(
57
60
  zero_point: torch.Tensor,
58
61
  args: QuantizationArgs,
59
62
  ) -> torch.Tensor:
63
+ """
64
+ Fake quantize the input tensor x depending on the group_size.
65
+ if group_size is greater than 0, then q/dq by groups. The groups
66
+ must be divisible by the column size
67
+ if group_size is -1, then channel wise q/dq. THe input scale and
68
+ zero_points are reshaped to support vectorization (Assumes 1 is
69
+ the channel dimension)
70
+
71
+ :param x: Input tensor
72
+ :param scale: scale tensor
73
+ :param zero_point: zero point tensor
74
+ :param args: quantization args that contain group_size info
75
+ :return: fake quantized tensor
76
+
77
+ """
60
78
  bit_range = 2**args.num_bits
61
79
  max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
62
80
  min_q = torch.tensor(-bit_range / 2, device=x.device)
63
- Q = torch.zeros_like(x)
64
- Q = quantize(x, scale, zero_point, min_q, max_q)
65
- return dequantize(Q, scale, zero_point)
81
+
82
+ group_size = args.group_size
83
+
84
+ # group
85
+ if args.strategy == QuantizationStrategy.GROUP:
86
+
87
+ DQ = torch.zeros_like(x)
88
+
89
+ # TODO: vectorize the for loop
90
+ # TODO: fix genetric assumption about the tensor size for computing group
91
+
92
+ # TODO: make validation step for inputs
93
+
94
+ while scale.ndim < 2:
95
+ # pad scale and zero point dims for slicing
96
+ scale = scale.unsqueeze(1)
97
+ zero_point = zero_point.unsqueeze(1)
98
+
99
+ columns = x.shape[1]
100
+ if columns >= group_size:
101
+ if columns % group_size != 0:
102
+ raise ValueError(
103
+ "tesnor column shape must be divisble "
104
+ f"by the given group_size {group_size}"
105
+ )
106
+ for i in range(ceil(columns / group_size)):
107
+ # scale.shape should be [nchan, ndim]
108
+ # sc.shape should be [nchan, 1] after unsqueeze
109
+
110
+ sc = scale[:, i].unsqueeze(1)
111
+ zp = zero_point[:, i].unsqueeze(1)
112
+
113
+ idx = i * group_size
114
+ Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
115
+ DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
116
+
117
+ # channel-wise
118
+ elif args.strategy == QuantizationStrategy.CHANNEL: # group_size == -1
119
+ # before: scale shape = [channel_size]
120
+ # after: scale shape = [1, channel_size]
121
+ scale = scale.unsqueeze(0)
122
+ zero_point = zero_point.unsqueeze(0)
123
+
124
+ Q = quantize(x, scale, zero_point, min_q, max_q)
125
+ DQ = dequantize(Q, scale, zero_point)
126
+
127
+ # per-token
128
+ elif args.strategy == QuantizationStrategy.TOKEN:
129
+ # before: scale shape = [num_tokens]
130
+ # after: scale shape = [num_tokens, 1]
131
+ # x.shape = 1, num_tokens, 1]
132
+ # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
133
+
134
+ scale = scale.unsqueeze(1)
135
+ zero_point = zero_point.unsqueeze(1)
136
+
137
+ Q = quantize(x, scale, zero_point, min_q, max_q)
138
+ DQ = dequantize(Q, scale, zero_point)
139
+
140
+ else:
141
+ Q = quantize(x, scale, zero_point, min_q, max_q)
142
+ DQ = dequantize(Q, scale, zero_point)
143
+
144
+ return DQ
66
145
 
67
146
 
68
147
  def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
@@ -76,14 +155,14 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
76
155
 
77
156
  if scheme.input_activations is not None:
78
157
  # calibrate and (fake) quantize input activations when applicable
79
- input_ = _maybe_calibrate_or_quantize(
158
+ input_ = maybe_calibrate_or_quantize(
80
159
  module, input_, "input", scheme.input_activations
81
160
  )
82
161
 
83
162
  if scheme.weights is not None:
84
163
  # calibrate and (fake) quantize weights when applicable
85
164
  unquantized_weight = self.weight.data.clone()
86
- self.weight.data = _maybe_calibrate_or_quantize(
165
+ self.weight.data = maybe_calibrate_or_quantize(
87
166
  module, self.weight, "weight", scheme.weights
88
167
  )
89
168
 
@@ -94,7 +173,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
94
173
 
95
174
  if scheme.output_activations is not None:
96
175
  # calibrate and (fake) quantize output activations when applicable
97
- output = _maybe_calibrate_or_quantize(
176
+ output = maybe_calibrate_or_quantize(
98
177
  module, output, "output", scheme.output_activations
99
178
  )
100
179
 
@@ -110,8 +189,8 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
110
189
  setattr(module, "forward", bound_wrapped_forward)
111
190
 
112
191
 
113
- def _maybe_calibrate_or_quantize(
114
- module: Module, value: Module, base_name: str, args: "QuantizationArgs"
192
+ def maybe_calibrate_or_quantize(
193
+ module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
115
194
  ) -> torch.Tensor:
116
195
  # only run quantized for the included stages
117
196
  if module.quantization_status not in {
@@ -120,17 +199,23 @@ def _maybe_calibrate_or_quantize(
120
199
  }:
121
200
  return value
122
201
 
123
- device = next(module.parameters()).device
124
- scale = getattr(module, f"{base_name}_scale")
125
- zero_point = getattr(module, f"{base_name}_zero_point")
126
-
127
- if module.quantization_status == QuantizationStatus.CALIBRATION:
128
- # get observer and get new quant params from observation
202
+ if args.dynamic:
203
+ # dynamic quantization - get scale and zero point directly from observer
129
204
  observer = getattr(module, f"{base_name}_observer")
130
- updated_scale, updated_zero_point = observer(value)
131
-
132
- # update scale and zero point
133
- scale.data = updated_scale.to(device)
134
- zero_point.data = updated_zero_point.to(device)
135
-
205
+ scale, zero_point = observer(value)
206
+ else:
207
+ # static quantization - get previous scale and zero point from layer
208
+ scale = getattr(module, f"{base_name}_scale")
209
+ zero_point = getattr(module, f"{base_name}_zero_point")
210
+
211
+ if module.quantization_status == QuantizationStatus.CALIBRATION:
212
+ # calibration mode - get new quant params from observer
213
+ observer = getattr(module, f"{base_name}_observer")
214
+
215
+ updated_scale, updated_zero_point = observer(value)
216
+
217
+ # update scale and zero point
218
+ device = next(module.parameters()).device
219
+ scale.data = updated_scale.to(device)
220
+ zero_point.data = updated_zero_point.to(device)
136
221
  return fake_quantize(value, scale, zero_point, args)
@@ -30,17 +30,17 @@ def freeze_module_quantization(module: Module):
30
30
 
31
31
  :param module: module to freeze quantization for
32
32
  """
33
- if not getattr(module, "quantization_scheme", None):
33
+ scheme = getattr(module, "quantization_scheme", None)
34
+ if not scheme:
34
35
  # no quantization scheme nothing to do
35
36
  return
36
37
 
37
- # delete observers from module
38
- observer_names = []
39
- for submodule_name, _ in module.named_modules():
40
- if "." not in submodule_name and submodule_name.endswith("_observer"):
41
- # delete any observers that belong directly to this module
42
- observer_names.append(submodule_name)
43
- for observer_name in observer_names:
44
- delattr(module, observer_name)
38
+ # delete observers from module if not dynamic
39
+ if scheme.input_activations and not scheme.input_activations.dynamic:
40
+ delattr(module, "input_observer")
41
+ if scheme.weights and not scheme.weights.dynamic:
42
+ delattr(module, "weight_observer")
43
+ if scheme.output_activations and not scheme.output_activations.dynamic:
44
+ delattr(module, "output_observer")
45
45
 
46
46
  module.quantization_status = QuantizationStatus.FROZEN
@@ -80,6 +80,13 @@ def initialize_module_for_quantization(
80
80
  def _initialize_scale_zero_point_observer(
81
81
  module: Module, base_name: str, quantization_args: QuantizationArgs
82
82
  ):
83
+ # initialize observer module and attach as submodule
84
+ observer = quantization_args.get_observer()
85
+ module.register_module(f"{base_name}_observer", observer)
86
+
87
+ if quantization_args.dynamic:
88
+ return # no need to register a scale and zero point for a dynamic observer
89
+
83
90
  device = next(module.parameters()).device
84
91
 
85
92
  # initializes empty scale and zero point parameters for the module
@@ -90,7 +97,3 @@ def _initialize_scale_zero_point_observer(
90
97
  torch.empty(0, device=device, dtype=int), requires_grad=False
91
98
  )
92
99
  module.register_parameter(f"{base_name}_zero_point", init_zero_point)
93
-
94
- # initialize observer module and attach as submodule
95
- observer = quantization_args.get_observer()
96
- module.register_module(f"{base_name}_observer", observer)
@@ -14,7 +14,11 @@
14
14
 
15
15
  from typing import Optional, Tuple
16
16
 
17
- from compressed_tensors.quantization.quant_args import QuantizationArgs
17
+ import torch
18
+ from compressed_tensors.quantization.quant_args import (
19
+ QuantizationArgs,
20
+ QuantizationStrategy,
21
+ )
18
22
  from compressed_tensors.registry.registry import RegistryMixin
19
23
  from torch import FloatTensor, IntTensor, Tensor
20
24
  from torch.nn import Module
@@ -52,6 +56,12 @@ class Observer(Module, RegistryMixin):
52
56
  """
53
57
  raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
54
58
 
59
+ def post_calculate_qparams(self) -> None:
60
+ """
61
+ Run any logic specific to its observers after running calculate_qparams
62
+ """
63
+ ...
64
+
55
65
  def get_qparams(
56
66
  self, observed: Optional[Tensor] = None
57
67
  ) -> Tuple[FloatTensor, IntTensor]:
@@ -64,6 +74,57 @@ class Observer(Module, RegistryMixin):
64
74
  :return: tuple of scale and zero point based on last observed value
65
75
  """
66
76
  if observed is not None:
67
- # re-calcualte scale and zero point, update the stored value
68
- self._scale, self._zero_point = self.calculate_qparams(observed)
77
+ group_size = self.quantization_args.group_size
78
+
79
+ if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
80
+
81
+ # re-calculate scale and zero point, update the stored value
82
+ self._scale, self._zero_point = self.calculate_qparams(observed)
83
+
84
+ elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
85
+ columns = observed.shape[1]
86
+ scales, zero_points = [], []
87
+ for i in range(0, columns, self.quantization_args.group_size):
88
+ scale, zero_point = self.get_qparams_along_dim(
89
+ observed[:, i : (i + group_size)],
90
+ 0,
91
+ )
92
+ scales.append(scale)
93
+ zero_points.append(zero_point)
94
+
95
+ self._scale = torch.stack(scales, dim=1)
96
+ self._zero_point = torch.stack(zero_points, dim=1)
97
+
98
+ elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
99
+ # assume observed is transposed, because its the output, hence use dim 0
100
+ self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
101
+
102
+ elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
103
+
104
+ # use dim 1, assume the obsersed.shape = [batch, token, hidden]
105
+ # should be batch, token
106
+
107
+ self._scale, self._zero_point = self.get_qparams_along_dim(
108
+ observed, dim=1
109
+ )
110
+
69
111
  return self._scale, self._zero_point
112
+
113
+ def get_qparams_along_dim(self, observed, dim: int):
114
+ # TODO: add documentation that specifies the shape must
115
+ # be padded with 1-dims so the scales are along the right channel
116
+ # TODO: generalize the logic for reduce_dims
117
+ scales, zero_points = [], []
118
+
119
+ # TODO: make a more generic way to get the channel
120
+ num_dims = observed.shape[dim]
121
+
122
+ for dim_idx in range(num_dims):
123
+ scale, zero_point = self.calculate_qparams(
124
+ observed.select(dim=dim, index=dim_idx)
125
+ )
126
+
127
+ scales.append(scale)
128
+ zero_points.append(zero_point)
129
+ # breakpoint()
130
+ return torch.stack(scales), torch.stack(zero_points)
@@ -23,10 +23,10 @@ from torch import FloatTensor, IntTensor, Tensor
23
23
  __all__ = ["MemorylessObserver"]
24
24
 
25
25
 
26
- @Observer.register("memoryless")
26
+ @Observer.register("memoryless", alias=["dynamic"])
27
27
  class MemorylessObserver(Observer):
28
28
  """
29
- Implements a dynamic quantization observer that sets the scale and
29
+ Implements a quantization observer that sets the scale and
30
30
  zero point based on the latest observed value without tracking state
31
31
  """
32
32
 
@@ -15,7 +15,7 @@
15
15
  from enum import Enum
16
16
  from typing import Any, Dict, Optional
17
17
 
18
- from pydantic import BaseModel, Field
18
+ from pydantic import BaseModel, Field, validator
19
19
 
20
20
 
21
21
  __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,6 +39,7 @@ class QuantizationStrategy(str, Enum):
39
39
  CHANNEL = "channel"
40
40
  GROUP = "group"
41
41
  BLOCK = "block"
42
+ TOKEN = "token"
42
43
 
43
44
 
44
45
  class QuantizationArgs(BaseModel):
@@ -53,14 +54,20 @@ class QuantizationArgs(BaseModel):
53
54
  :param group_size: group length to use for the group strategy
54
55
  :param block_structure: 2d block structure to use for the block strategy, must be
55
56
  of the format "2x4", "8x16", etc.
57
+ :param dynamic: set True to perform dynamic quantization - values will not be
58
+ calibrated during calibration phase, instead during inference new quantization
59
+ ranges will be observed with every sample. Defaults to False for static
60
+ quantization. Note that enabling dynamic quantization will change the default
61
+ observer to a memoryless one
56
62
  """
57
63
 
58
64
  num_bits: int = 8
59
65
  type: QuantizationType = QuantizationType.INT
60
66
  symmetric: bool = True
61
- strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
62
67
  group_size: Optional[int] = None
68
+ strategy: Optional[QuantizationStrategy] = None
63
69
  block_structure: Optional[str] = None
70
+ dynamic: bool = False
64
71
  observer: str = Field(
65
72
  default="minmax",
66
73
  description=(
@@ -82,4 +89,37 @@ class QuantizationArgs(BaseModel):
82
89
  """
83
90
  from compressed_tensors.quantization.observers.base import Observer
84
91
 
92
+ if self.observer == "minmax" and self.dynamic:
93
+ # override defualt observer for dynamic, you never want minmax which
94
+ # keeps state across samples for dynamic
95
+ self.observer = "memoryless"
96
+
85
97
  return Observer.load_from_registry(self.observer, quantization_args=self)
98
+
99
+ @validator("strategy", pre=True, always=True)
100
+ def validate_strategy(cls, value, values):
101
+ group_size = values.get("group_size")
102
+
103
+ # use group_size to determinine strategy if not given explicity
104
+ if group_size is not None and value is None:
105
+ if group_size > 0:
106
+ return QuantizationStrategy.GROUP
107
+
108
+ elif group_size == -1:
109
+ return QuantizationStrategy.CHANNEL
110
+
111
+ else:
112
+ raise ValueError(
113
+ f"group_size={group_size} with strategy {value} is invald. "
114
+ "group_size > 0 for strategy='group' and "
115
+ "group_size = -1 for 'channel'"
116
+ )
117
+
118
+ if value == QuantizationStrategy.GROUP:
119
+ if group_size is None:
120
+ raise ValueError(f"strategy {value} requires group_size to be set.")
121
+
122
+ if value is None:
123
+ return QuantizationStrategy.TENSOR
124
+
125
+ return value
@@ -108,6 +108,7 @@ def calculate_compression_ratio(model: Module) -> float:
108
108
  compressed_bits = uncompressed_bits
109
109
  if is_module_quantized(submodule):
110
110
  compressed_bits = submodule.quantization_scheme.weights.num_bits
111
+
111
112
  num_weights = parameter.numel()
112
113
  total_compressed += compressed_bits * num_weights
113
114
  total_uncompressed += uncompressed_bits * num_weights
@@ -12,28 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from pathlib import Path
16
- from typing import Dict, Optional, Union
17
15
 
18
- import torch
16
+ from typing import Optional
17
+
19
18
  from compressed_tensors.base import SPARSITY_CONFIG_NAME
20
19
  from compressed_tensors.compressors import ModelCompressor
21
- from compressed_tensors.config import (
22
- CompressionConfig,
23
- CompressionFormat,
24
- DenseSparsityConfig,
25
- )
26
- from safetensors.torch import save_file
27
- from torch import Tensor
20
+ from compressed_tensors.config import CompressionConfig
28
21
  from transformers import AutoConfig
29
22
 
30
23
 
31
- __all__ = [
32
- "infer_compressor_from_model_config",
33
- "load_compressed",
34
- "save_compressed",
35
- "save_compressed_model",
36
- ]
24
+ __all__ = ["infer_compressor_from_model_config"]
37
25
 
38
26
 
39
27
  def infer_compressor_from_model_config(
@@ -55,97 +43,3 @@ def infer_compressor_from_model_config(
55
43
  sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
56
44
  compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
57
45
  return compressor
58
-
59
-
60
- def save_compressed(
61
- tensors: Dict[str, Tensor],
62
- save_path: Union[str, Path],
63
- compression_format: Optional[CompressionFormat] = None,
64
- ):
65
- """
66
- Save compressed tensors to disk. If tensors are not compressed,
67
- save them as is.
68
-
69
- :param tensors: dictionary of tensors to compress
70
- :param save_path: path to save compressed tensors
71
- :param compression_format: compression format used for the tensors
72
- :return: compression config, if tensors were compressed - None otherwise
73
- """
74
- if tensors is None or len(tensors) == 0:
75
- raise ValueError("No tensors or empty tensors provided to compress")
76
-
77
- # if no compression_format specified, default to `dense_sparsity`
78
- compression_format = compression_format or CompressionFormat.dense_sparsity.value
79
-
80
- if not (
81
- compression_format in ModelCompressor.registered_names()
82
- or compression_format in ModelCompressor.registered_aliases()
83
- ):
84
- raise ValueError(
85
- f"Unknown compression format: {compression_format}. "
86
- f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}" # noqa E501
87
- )
88
-
89
- # compress
90
- compressor = ModelCompressor.load_from_registry(compression_format)
91
- # save compressed tensors
92
- compressed_tensors = compressor.compress(tensors)
93
- save_file(compressed_tensors, save_path)
94
-
95
-
96
- def load_compressed(
97
- compressed_tensors: Union[str, Path],
98
- compression_config: CompressionConfig = None,
99
- device: Optional[str] = "cpu",
100
- ) -> Dict[str, Tensor]:
101
- """
102
- Load compressed tensors from disk. If tensors are not compressed,
103
- load them as is.
104
-
105
- :param compressed_tensors: path to compressed tensors
106
- :param compression_config: compression config to use for decompressing tensors.
107
- :param device: device to move tensors to. If None, tensors are loaded on CPU.
108
- :return decompressed tensors
109
- """
110
-
111
- if compressed_tensors is None or not Path(compressed_tensors).exists():
112
- raise ValueError("No compressed tensors provided to load")
113
-
114
- # if no compression_config specified, default to `dense_sparsity`
115
- compression_config = compression_config or DenseSparsityConfig()
116
-
117
- # decompress
118
- compression_format = compression_config.format
119
- compressor = ModelCompressor.load_from_registry(
120
- compression_format, config=compression_config
121
- )
122
- return dict(compressor.decompress(compressed_tensors, device=device))
123
-
124
-
125
- def save_compressed_model(
126
- model: torch.nn.Module,
127
- filename: str,
128
- compression_format: Optional[CompressionFormat] = None,
129
- force_contiguous: bool = True,
130
- ):
131
- """
132
- Wrapper around safetensors `save_model` helper function, which allows for
133
- saving compressed model to disk.
134
-
135
- Note: The model is assumed to have a
136
- state_dict with unique entries
137
-
138
- :param model: model to save on disk
139
- :param filename: filename location to save the file
140
- :param compression_format: compression format used for the model
141
- :param force_contiguous: forcing the state_dict to be saved as contiguous tensors
142
- """
143
- state_dict = model.state_dict()
144
- if force_contiguous:
145
- state_dict = {k: v.contiguous() for k, v in state_dict.items()}
146
- try:
147
- save_compressed(state_dict, filename, compression_format=compression_format)
148
- except ValueError as e:
149
- msg = str(e)
150
- msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats." # noqa E501
151
- raise ValueError(msg)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -20,7 +20,7 @@ Requires-Dist: nbconvert >=7.16.3 ; extra == 'dev'
20
20
  Requires-Dist: pytest >=6.0.0 ; extra == 'dev'
21
21
  Requires-Dist: wheel >=0.36.2 ; extra == 'dev'
22
22
 
23
- # compressed-tensors
23
+ # compressed_tensors
24
24
 
25
25
  This repository extends a [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. `compressed-tensors` format supports multiple compression types to minimize the disk space and facilitate the tensor manipulation.
26
26
 
@@ -103,4 +103,6 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
103
103
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
104
104
  ```
105
105
 
106
+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
107
+
106
108
 
@@ -1,38 +1,38 @@
1
1
  compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
2
2
  compressed_tensors/base.py,sha256=8zbgK87LpHkKoSknM55svXCT4E4dLLjPijwF9HfzmsQ,717
3
- compressed_tensors/compressors/__init__.py,sha256=3ZHKWSIWTjMx8XXgLtoP9JaVaCTvRecguLZTxLAAkKk,898
4
- compressed_tensors/compressors/base.py,sha256=F1smyJ6x2Sfq43tuP0QE9wZuhVqnewq-XUFPMtdU9yQ,2936
5
- compressed_tensors/compressors/dense.py,sha256=_VTusI3XjaY-zOdB_d7z4zOgPTJi9TJZZHF13g9ulS4,1263
6
- compressed_tensors/compressors/helpers.py,sha256=kSseqbwnu3JHZUKH8u4kQo5bmd87FvCcmWe0u2ikysA,6421
7
- compressed_tensors/compressors/sparse_bitmask.py,sha256=PYAK_Hcy2T57zlbpwl1FYkslluIr2x-d0Rh048YAtpI,8639
3
+ compressed_tensors/compressors/__init__.py,sha256=UcHp0CwUBJoS2MBN6mLUT7B3uRf1TEoRGbME7gLPD38,841
4
+ compressed_tensors/compressors/base.py,sha256=CqQo00ZIkAWpy0yVux5TXhK7WK_6Ws6qb5mCAvIoxB4,3902
5
+ compressed_tensors/compressors/dense.py,sha256=ig9lItmyCX5-VzgMuUqea-s8fHsTjPj5-0VIsPLl0g0,1271
6
+ compressed_tensors/compressors/helpers.py,sha256=wstgUEUYUCTMMu6G1YLF9G7vXqIJPj3MsWhqwU4J6Vw,5458
7
+ compressed_tensors/compressors/sparse_bitmask.py,sha256=qXXFSf1UuQEzodB_xkQgYEJMwPgFsBgTQb8-LqesCsY,8652
8
8
  compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
9
9
  compressed_tensors/config/base.py,sha256=IP-3Y416w-811WozDzKHycIBXjdlG4Ddy7vpbwhOPD8,1373
10
10
  compressed_tensors/config/dense.py,sha256=xtkri7DkP7USu44FnSoTgTSqdGegCBtjRf3DfblSEL0,1311
11
11
  compressed_tensors/config/sparse_bitmask.py,sha256=y8fmQaOoGjIiI4FR6BJjfIqisAcqNQ_zjKyjT75bXwY,1284
12
12
  compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
13
- compressed_tensors/quantization/quant_args.py,sha256=dxSrq0_88ORQXcyIMYqoMZJvYEjnqdYl37f7lgZQqhw,2742
13
+ compressed_tensors/quantization/quant_args.py,sha256=A6b2V8lhsM8Ho8RjlPBQdxRUDNWhqq-ie5E3RR2_GNg,4360
14
14
  compressed_tensors/quantization/quant_config.py,sha256=DWx8ae3gDlw99zAn3MUN9I4qeksbbmITmOXHRynqPB8,6650
15
15
  compressed_tensors/quantization/quant_scheme.py,sha256=X3oqmZPiIKtX5tEKKUj-0N6hB68NeiU2b1GcQEQPadQ,1480
16
16
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=fM9XBtPgJX6z54PTm3Sd0SpK5od95ibwaSf2FFR8DqE,772
17
- compressed_tensors/quantization/lifecycle/apply.py,sha256=WXUL3q1g0s244k0wuqGYZPXTXiscdyrp7RScN2j_KGA,6651
17
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=LQUESSqS5a2_7ij9rHvBdLjjdTOAf9v7chsgfWwh-Jg,6973
18
18
  compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
19
- compressed_tensors/quantization/lifecycle/forward.py,sha256=hnjk7pocZLDhLdMx237FKayYdvsdKbYSjTmSN5xbQO8,4599
20
- compressed_tensors/quantization/lifecycle/frozen.py,sha256=NHNmlDIaxurifqeI_qZC8xa4BstQsBNdOCXJjRzAfNU,1596
21
- compressed_tensors/quantization/lifecycle/initialize.py,sha256=8pifqZQSgVqWYI_Qtv6QfBICPbCTFHy48OWPeQsxEHQ,3578
19
+ compressed_tensors/quantization/lifecycle/forward.py,sha256=JcxGBUsthl6_ao5vi6t7poU3YOJsBEzGpE0MEH4Kxus,7600
20
+ compressed_tensors/quantization/lifecycle/frozen.py,sha256=FF7BleuOUX46Egk7F1ZE5r4fjWt9jG5-tO8BjXU1r78,1606
21
+ compressed_tensors/quantization/lifecycle/initialize.py,sha256=U6g9qifSF6pagQZQZEwd-rwWC6uQ_dZXn1wg6nr1Abg,3697
22
22
  compressed_tensors/quantization/observers/__init__.py,sha256=DNH31NQYrIBBcmHsMyFA6whh4pbRsLwuNa6L8AeXaGc,745
23
- compressed_tensors/quantization/observers/base.py,sha256=O76dAxkin7bB602e9kjmxc84p71-PxBtjIq5L69xplI,2786
23
+ compressed_tensors/quantization/observers/base.py,sha256=UqXaR4gOUmMRLKqq4N7IrVuGL11VDWwdmYYFmhk8a3o,5097
24
24
  compressed_tensors/quantization/observers/helpers.py,sha256=SxvOf9zwZ9NDRC3E4Xm7z3RqHcbcPtCABLKX9GnGGHM,2109
25
- compressed_tensors/quantization/observers/memoryless.py,sha256=3f6bUlcf5mzOHPkTRhoQ7Zd8xu_pUmj8e3Y85fGysSU,1848
25
+ compressed_tensors/quantization/observers/memoryless.py,sha256=ZHTPh4aURE8LvHBFaP--HIC2JanMX5-VRdIkE2JHthw,1859
26
26
  compressed_tensors/quantization/observers/min_max.py,sha256=uAcZd5aY6WKM-KumTb2ybX28s8iKGVy6Nrje5Sddqew,2439
27
27
  compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
28
- compressed_tensors/quantization/utils/helpers.py,sha256=N_wYfrPcFr__Q1mn6mHoNUTclwpTW8P5PDHkR7GvXWo,3694
28
+ compressed_tensors/quantization/utils/helpers.py,sha256=U7tgFUntFbebT43HDSE80rsjlUky_ON_Y8zm__24fd4,3695
29
29
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
30
30
  compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
31
31
  compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
32
- compressed_tensors/utils/helpers.py,sha256=wLgiPrk7Vn29AijOGQGk3UnXItRd1jpROS6FxHoC4VQ,5530
32
+ compressed_tensors/utils/helpers.py,sha256=h0jfl9drs5FAx40tCHRcVtJqXixB5hT5yq_IG2aY_-w,1735
33
33
  compressed_tensors/utils/safetensors_load.py,sha256=wo9UirGrGlenBqZeqotvpCT7D5MEdjCo2J3HeRaIFoU,8502
34
- compressed_tensors-0.3.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
- compressed_tensors-0.3.1.dist-info/METADATA,sha256=vpGRbjHWdPUTl9HFoDxIkwAKQJNpff75P4pKC3nJE4A,3850
36
- compressed_tensors-0.3.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
37
- compressed_tensors-0.3.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
38
- compressed_tensors-0.3.1.dist-info/RECORD,,
34
+ compressed_tensors-0.3.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
+ compressed_tensors-0.3.3.dist-info/METADATA,sha256=ff5Bt4LgmRvE9HGubzPqXfpidTLn7vyTpAMt-k8hvu8,4059
36
+ compressed_tensors-0.3.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
37
+ compressed_tensors-0.3.3.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
38
+ compressed_tensors-0.3.3.dist-info/RECORD,,