compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (42)
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/observers/base.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+import logging
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -24,6 +25,9 @@ from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]
 
 
@@ -39,7 +43,9 @@ class Observer(Module, RegistryMixin):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -47,11 +53,19 @@
         from
         :return: tuple of scale and zero point based on last observed value
         """
+        self.record_observed_tokens(observed)
         return self.get_qparams(observed=observed)
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -69,6 +83,7 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
         from
         :return: tuple of scale and zero point based on last observed value
@@ -84,47 +99,75 @@ class Observer(Module, RegistryMixin):
         elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
             columns = observed.shape[1]
             scales, zero_points = [], []
-            for i in range(0, columns, self.quantization_args.group_size):
+            group_idxs = range(0, columns, self.quantization_args.group_size)
+            for group_id, group_idx in enumerate(group_idxs):
                 scale, zero_point = self.get_qparams_along_dim(
-                    observed[:, i : (i + group_size)],
+                    observed[:, group_idx : (group_idx + group_size)],
                     0,
+                    tensor_id=group_id,
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
 
-            self._scale = torch.stack(scales, dim=1)
-            self._zero_point = torch.stack(zero_points, dim=1)
+            self._scale = torch.cat(scales, dim=1, out=self._scale)
+            self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
             self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
             # should be batch, token
-
             self._scale, self._zero_point = self.get_qparams_along_dim(
-                observed, dim=1
+                observed,
+                dim={0, 1},
             )
 
         return self._scale, self._zero_point
 
-    def get_qparams_along_dim(self, observed, dim: int):
-        # TODO: add documentation that specifies the shape must
-        # be padded with 1-dims so the scales are along the right channel
-        # TODO: generalize the logic for reduce_dims
-        scales, zero_points = [], []
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        dim = set(dim)
 
-        # TODO: make a more generic way to get the channel
-        num_dims = observed.shape[dim]
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
 
-        for dim_idx in range(num_dims):
-            scale, zero_point = self.calculate_qparams(
-                observed.select(dim=dim, index=dim_idx)
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+        (batch_size * sequence_length, num_features). This is the
+        general shape expected by the forward pass of the expert
+        layers in a MOE model. If the input tensor does not have
+        two dimensions, the _num_observed_tokens attribute will be set
+        to None.
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
             )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
 
-            scales.append(scale)
-            zero_points.append(zero_point)
-        # breakpoint()
-        return torch.stack(scales), torch.stack(zero_points)
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
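
Worth noting in the base-observer rewrite above: `get_qparams_along_dim` now derives `reduce_dims` as the complement of the dimensions to keep, instead of looping with `Tensor.select` and stacking the results. A minimal standalone sketch of that reduction, with illustrative shapes that are not from the package:

import torch

observed = torch.randn(4, 16, 64)  # e.g. (batch, token, hidden)

keep = {0, 1}  # per-token quantization keeps dims 0 and 1
reduce_dims = tuple(i for i in range(observed.ndim) if i not in keep)  # (2,)

min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
print(min_val.shape)  # torch.Size([4, 16, 1]): one (min, max) pair per token
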
compressed_tensors/quantization/observers/helpers.py
@@ -12,42 +12,100 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Counter
 from typing import Tuple
 
 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, module in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
+    return token_counts
 
 
 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
-    :param min_vals: tensor of min value(s) to caluclate scale(s) and zero point(s)
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
     from
-    :param max_vals: tensor of max value(s) to caluclate scale(s) and zero point(s)
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
     from
     :param quantization_args: settings to quantization
    :return: tuple of the calculated scale(s) and zero point(s)
    """
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
+
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
 
-    bit_range = 2**quantization_args.num_bits - 1
-    bit_min = -(bit_range + 1) / 2
-    bit_max = bit_min + bit_range
    if quantization_args.symmetric:
-        zero_points = torch.tensor(0).to(torch.int8)
-        max_val_pos = torch.max(-min_vals, max_vals)
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
        scales = max_val_pos / (float(bit_range) / 2)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
    else:
        scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - torch.round(min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
+
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
 
    return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
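
As a sanity check on `calculate_range` and the asymmetric branch of `calculate_qparams` above, here is the 8-bit integer case worked out with plain numbers (the observed span is made up for illustration):

num_bits = 8
bit_range = 2**num_bits        # 256
q_max = bit_range / 2 - 1      # 127
q_min = -bit_range / 2         # -128

# asymmetric: map an observed span [-0.5, 1.5] onto [-128, 127]
min_val, max_val = -0.5, 1.5
scale = (max_val - min_val) / (q_max - q_min)  # 2 / 255 ~= 0.00784
zero_point = q_min - min_val / scale           # -128 + 63.75 = -64.25
# calculate_qparams then clamps this into [q_min, q_max] and casts it to
# the dtype returned by quantization_args.pytorch_dtype() (int8 here)
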
compressed_tensors/quantization/observers/memoryless.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        tensor_id: Optional[Any] = None,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Returns the min and max values of observed tensor
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        min_val, max_val = torch.aminmax(observed)
 
-        # ensure zero is in the range
-        min_val = torch.min(min_val, torch.zeros_like(min_val))
-        max_val = torch.max(max_val, torch.zeros_like(max_val))
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
         return calculate_qparams(min_val, max_val, self.quantization_args)
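
The `reduce_dims` branch added to `MemorylessObserver.calculate_qparams` above switches between one global (min, max) pair and one pair per kept dimension. A quick sketch of the two branches, on an illustrative 2-D weight:

import torch

observed = torch.randn(128, 512)  # illustrative (out_features, in_features)

# reduce_dims falsy: a single global (min, max) pair, as before
g_min, g_max = torch.aminmax(observed)

# reduce_dims = (1,): one (min, max) pair per row, shaped (128, 1)
c_min = torch.amin(observed, dim=(1,), keepdim=True)
c_max = torch.amax(observed, dim=(1,), keepdim=True)
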
compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = float("inf")
-        self.max_val = -float("inf")
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"
 
-        min_val, max_val = torch.aminmax(observed)
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
 
-        if self.min_val == float("inf") and self.max_val == float("-inf"):
-            self.min_val = min_val
-            self.max_val = max_val
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-            self.min_val = self.min_val + self.averaging_constant * (
-                min_val - self.min_val
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
             )
-            self.max_val = self.max_val + self.averaging_constant * (
-                max_val - self.max_val
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
            )
 
-        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
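
The update rule in `MovingAverageMinMaxObserver` above is plain exponential smoothing, now kept per `tensor_id` so each group or shard tracks its own extrema. Written out with made-up numbers:

# new = old + c * (batch - old), with smoothing factor c = averaging_constant
old_min, batch_min, c = -1.0, -3.0, 0.01
new_min = old_min + c * (batch_min - old_min)  # -1.02: moves 1% of the gap
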
compressed_tensors/quantization/quant_args.py
@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
+import torch
 from pydantic import BaseModel, Field, validator
 
 
-__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn
 
 
 class QuantizationType(str, Enum):
@@ -42,7 +51,7 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
-class QuantizationArgs(BaseModel):
+class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
     activations
@@ -62,7 +71,7 @@ class QuantizationArgs(BaseModel):
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT.value
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel):
             return QuantizationStrategy.TENSOR
 
         return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
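
To make the two branches of `round_to_quantized_type` above concrete, a small sketch (input values are illustrative; requires a PyTorch build with float8 support):

import torch

x = torch.tensor([0.4, 1.7, -2.3], dtype=torch.float32)

# INT branch: round to the nearest integer while staying in the input dtype
print(torch.round(x))  # tensor([ 0.,  2., -2.])

# FLOAT branch: round-trip through FP8_DTYPE to snap onto the fp8 grid,
# then return in the original dtype
print(x.to(torch.float8_e4m3fn).to(x.dtype))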