compressed-tensors 0.4.0__tar.gz → 0.5.0__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (58)
  1. {compressed-tensors-0.4.0/src/compressed_tensors.egg-info → compressed_tensors-0.5.0}/PKG-INFO +12 -2
  2. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/setup.py +1 -1
  3. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/base.py +1 -0
  4. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/__init__.py +5 -1
  5. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/base.py +1 -1
  6. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/dense.py +1 -1
  7. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/marlin_24.py +11 -10
  8. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/model_compressor.py +33 -12
  9. compressed-tensors-0.4.0/src/compressed_tensors/compressors/int_quantized.py → compressed_tensors-0.5.0/src/compressed_tensors/compressors/naive_quantized.py +33 -15
  10. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/pack_quantized.py +58 -51
  11. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/sparse_bitmask.py +1 -1
  12. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/config/base.py +2 -0
  13. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/apply.py +161 -39
  15. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  16. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/forward.py +70 -25
  17. compressed_tensors-0.5.0/src/compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  18. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +30 -1
  19. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/base.py +39 -0
  20. compressed_tensors-0.5.0/src/compressed_tensors/quantization/observers/helpers.py +111 -0
  21. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/quant_args.py +45 -1
  22. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/quant_config.py +35 -2
  23. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/quant_scheme.py +105 -4
  24. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/utils/helpers.py +67 -1
  25. {compressed-tensors-0.4.0/src/compressed_tensors/compressors → compressed_tensors-0.5.0/src/compressed_tensors}/utils/__init__.py +2 -1
  26. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/utils/helpers.py +31 -2
  27. compressed_tensors-0.5.0/src/compressed_tensors/utils/offload.py +104 -0
  28. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/version.py +1 -1
  29. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0/src/compressed_tensors.egg-info}/PKG-INFO +12 -2
  30. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/SOURCES.txt +7 -6
  31. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/requires.txt +1 -0
  32. compressed_tensors-0.5.0/tests/test_registry.py +53 -0
  33. compressed-tensors-0.4.0/src/compressed_tensors/compressors/utils/helpers.py +0 -43
  34. compressed-tensors-0.4.0/src/compressed_tensors/quantization/observers/helpers.py +0 -58
  35. compressed-tensors-0.4.0/src/compressed_tensors/utils/__init__.py +0 -16
  36. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/LICENSE +0 -0
  37. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/README.md +0 -0
  38. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/pyproject.toml +0 -0
  39. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/setup.cfg +0 -0
  40. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/__init__.py +0 -0
  41. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/compressors/helpers.py +0 -0
  42. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/config/__init__.py +0 -0
  43. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/config/dense.py +0 -0
  44. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  45. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/__init__.py +0 -0
  46. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  47. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/lifecycle/frozen.py +0 -0
  48. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/__init__.py +0 -0
  49. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/memoryless.py +0 -0
  50. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/observers/min_max.py +0 -0
  51. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  52. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/registry/__init__.py +0 -0
  53. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/registry/registry.py +0 -0
  54. {compressed-tensors-0.4.0/src/compressed_tensors/compressors → compressed_tensors-0.5.0/src/compressed_tensors}/utils/permutations_24.py +0 -0
  55. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  56. {compressed-tensors-0.4.0/src/compressed_tensors/compressors → compressed_tensors-0.5.0/src/compressed_tensors}/utils/semi_structured_conversions.py +0 -0
  57. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  58. {compressed-tensors-0.4.0 → compressed_tensors-0.5.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,14 +1,24 @@
  Metadata-Version: 2.1
  Name: compressed-tensors
- Version: 0.4.0
+ Version: 0.5.0
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
  Author-email: support@neuralmagic.com
  License: Apache 2.0
  Description-Content-Type: text/markdown
- Provides-Extra: dev
  License-File: LICENSE
+ Requires-Dist: torch>=1.7.0
+ Requires-Dist: transformers
+ Requires-Dist: accelerate
+ Requires-Dist: pydantic>=2.0
+ Provides-Extra: dev
+ Requires-Dist: black==22.12.0; extra == "dev"
+ Requires-Dist: isort==5.8.0; extra == "dev"
+ Requires-Dist: wheel>=0.36.2; extra == "dev"
+ Requires-Dist: flake8>=3.8.3; extra == "dev"
+ Requires-Dist: pytest>=6.0.0; extra == "dev"
+ Requires-Dist: nbconvert>=7.16.3; extra == "dev"

  # compressed_tensors

setup.py
@@ -46,7 +46,7 @@ def _setup_packages() -> List:
      )

  def _setup_install_requires() -> List:
-     return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
+     return ["torch>=1.7.0", "transformers", "accelerate", "pydantic>=2.0"]

  def _setup_extras() -> Dict:
      return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]}
src/compressed_tensors/base.py
@@ -15,3 +15,4 @@
  SPARSITY_CONFIG_NAME = "sparsity_config"
  QUANTIZATION_CONFIG_NAME = "quantization_config"
  COMPRESSION_CONFIG_NAME = "compression_config"
+ KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
src/compressed_tensors/compressors/__init__.py
@@ -17,8 +17,12 @@
  from .base import Compressor
  from .dense import DenseCompressor
  from .helpers import load_compressed, save_compressed, save_compressed_model
- from .int_quantized import IntQuantizationCompressor
  from .marlin_24 import Marlin24Compressor
  from .model_compressor import ModelCompressor, map_modules_to_quant_args
+ from .naive_quantized import (
+     FloatQuantizationCompressor,
+     IntQuantizationCompressor,
+     QuantizationCompressor,
+ )
  from .pack_quantized import PackedQuantizationCompressor
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
src/compressed_tensors/compressors/base.py
@@ -45,7 +45,7 @@ class Compressor(RegistryMixin):
          raise NotImplementedError()

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
src/compressed_tensors/compressors/dense.py
@@ -29,6 +29,6 @@ class DenseCompressor(Compressor):
          return model_state

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          return iter([])
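Note: the `**kwargs` added to each `decompress` signature in this release lets `ModelCompressor.decompress` forward format-specific arguments (such as the `names_to_scheme` mapping used by the packed compressor further down) without breaking compressors that ignore them. A minimal sketch of the call pattern, assuming the base `Compressor` constructor still accepts an optional config:

```python
from compressed_tensors.compressors import DenseCompressor

# DenseCompressor ignores the extra keyword arguments and yields nothing;
# names_to_scheme is only consumed by the quantized compressors.
compressor = DenseCompressor(config=None)  # assumes config stays optional
weights = list(
    compressor.decompress("unused/path", device="cpu", names_to_scheme={})
)
assert weights == []
```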
src/compressed_tensors/compressors/marlin_24.py
@@ -18,15 +18,16 @@ from typing import Dict, Generator, Tuple
  import numpy as np
  import torch
  from compressed_tensors.compressors import Compressor
- from compressed_tensors.compressors.utils import (
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.utils import (
      get_permutations_24,
+     is_quantization_param,
+     merge_names,
      sparse_semi_structured_from_dense_cutlass,
      tensor_follows_mask_structure,
  )
- from compressed_tensors.config import CompressionFormat
- from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
- from compressed_tensors.quantization.lifecycle.forward import quantize
- from compressed_tensors.utils import is_quantization_param, merge_names
  from torch import Tensor
  from tqdm import tqdm

@@ -107,7 +108,7 @@ class Marlin24Compressor(Compressor):
      def compress(
          self,
          model_state: Dict[str, Tensor],
-         model_quant_args: Dict[str, QuantizationArgs],
+         names_to_scheme: Dict[str, QuantizationArgs],
          **kwargs,
      ) -> Dict[str, Tensor]:
          """
@@ -115,11 +116,11 @@ class Marlin24Compressor(Compressor):
          with the Marlin24 kernel

          :param model_state: state dict of uncompressed model
-         :param model_quant_args: quantization args for each quantized weight, needed for
+         :param names_to_scheme: quantization args for each quantized weight, needed for
              quantize function to calculate bit depth
          :return: compressed state dict
          """
-         self.validate_quant_compatability(model_quant_args)
+         self.validate_quant_compatability(names_to_scheme)

          compressed_dict = {}
          weight_suffix = ".weight"
@@ -139,7 +140,7 @@ class Marlin24Compressor(Compressor):
              value = value.to(torch.float16)

              # quantize weight, keeping it as a float16 for now
-             quant_args = model_quant_args[prefix]
+             quant_args = names_to_scheme[prefix]
              value = quantize(
                  x=value, scale=scale, zero_point=zp, args=quant_args
              )
@@ -175,7 +176,7 @@ class Marlin24Compressor(Compressor):
          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          raise NotImplementedError(
              "Decompression is not implemented for the Marlin24 Compressor."
src/compressed_tensors/compressors/model_compressor.py
@@ -16,9 +16,12 @@ import json
  import logging
  import operator
  import os
+ import re
  from copy import deepcopy
  from typing import Any, Dict, Optional, Union

+ import torch
+ import transformers
  from compressed_tensors.base import (
      COMPRESSION_CONFIG_NAME,
      QUANTIZATION_CONFIG_NAME,
@@ -36,10 +39,10 @@ from compressed_tensors.quantization.utils import (
      is_module_quantized,
      iter_named_leaf_modules,
  )
- from compressed_tensors.utils import get_safetensors_folder
+ from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
  from compressed_tensors.utils.helpers import fix_fsdp_module_name
  from torch import Tensor
- from torch.nn import Module, Parameter
+ from torch.nn import Module
  from tqdm import tqdm
  from transformers import AutoConfig
  from transformers.file_utils import CONFIG_NAME
@@ -78,6 +81,7 @@ class ModelCompressor:
      def from_pretrained(
          cls,
          pretrained_model_name_or_path: str,
+         **kwargs,
      ) -> Optional["ModelCompressor"]:
          """
          Given a path to a model config, extract the sparsity and/or quantization
@@ -86,7 +90,7 @@ class ModelCompressor:
          :param pretrained_model_name_or_path: path to model config on disk or HF hub
          :return: compressor for the extracted configs
          """
-         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
          compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
          return cls.from_compression_config(compression_config)

@@ -228,7 +232,7 @@ class ModelCompressor:
          quantized_modules_to_args = map_modules_to_quant_args(model)
          if self.quantization_compressor is not None:
              compressed_state_dict = self.quantization_compressor.compress(
-                 state_dict, model_quant_args=quantized_modules_to_args
+                 state_dict, names_to_scheme=quantized_modules_to_args
              )

          if self.sparsity_compressor is not None:
@@ -236,6 +240,11 @@ class ModelCompressor:
                  compressed_state_dict
              )

+         # HACK: Override the dtype_byte_size function in transformers to
+         # support float8 types. Fix is posted upstream
+         # https://github.com/huggingface/transformers/pull/30488
+         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
          return compressed_state_dict

      def decompress(self, model_path: str, model: Module):
@@ -252,9 +261,11 @@ class ModelCompressor:
              setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

          if self.quantization_compressor is not None:
-             apply_quantization_config(model, self.quantization_config)
+             names_to_scheme = apply_quantization_config(model, self.quantization_config)
              load_pretrained_quantization(model, model_path)
-             dense_gen = self.quantization_compressor.decompress(model_path)
+             dense_gen = self.quantization_compressor.decompress(
+                 model_path, names_to_scheme=names_to_scheme
+             )
              self._replace_weights(dense_gen, model)

              def update_status(module):
@@ -296,12 +307,10 @@ class ModelCompressor:

      def _replace_weights(self, dense_weight_generator, model):
          for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-             # loading the decompressed weights into the model
-             model_device = operator.attrgetter(name)(model).device
-             data_old = operator.attrgetter(name)(model)
-             data_dtype = data_old.dtype
-             data_new = Parameter(data.to(model_device).to(data_dtype))
-             data_old.data = data_new.data
+             split_name = name.split(".")
+             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+             module = operator.attrgetter(prefix)(model)
+             update_parameter_data(module, data, param_name)


  def map_modules_to_quant_args(model: Module) -> Dict:
@@ -313,3 +322,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
              quantized_modules_to_args[name] = submodule.quantization_scheme.weights

      return quantized_modules_to_args
+
+
+ # HACK: Override the dtype_byte_size function in transformers to support float8 types
+ # Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+ def new_dtype_byte_size(dtype):
+     if dtype == torch.bool:
+         return 1 / 8
+     bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+     if bit_search is None:
+         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+     bit_size = int(bit_search.groups()[0])
+     return bit_size // 8
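Note: `new_dtype_byte_size` recovers the bit width from the dtype's string form, which is what lets float8 checkpoints be sharded correctly on save. A small sketch of the regex behaviour (illustrative only; the float8 line assumes a torch build that defines `torch.float8_e4m3fn`):

```python
import re

import torch


def dtype_bits(dtype) -> int:
    # same regex as new_dtype_byte_size above: first digit run after a non-digit,
    # e.g. "torch.int32" -> 32, "torch.float8_e4m3fn" -> 8
    match = re.search(r"[^\d](\d+)_?", str(dtype))
    if match is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(match.groups()[0])


print(dtype_bits(torch.int32))    # 32 bits -> 4 bytes
print(dtype_bits(torch.float16))  # 16 bits -> 2 bytes
# print(dtype_bits(torch.float8_e4m3fn))  # 8 bits -> 1 byte, on float8-capable builds
```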
src/compressed_tensors/compressors/naive_quantized.py (renamed from int_quantized.py)
@@ -27,17 +27,21 @@ from torch import Tensor
  from tqdm import tqdm


- __all__ = ["IntQuantizationCompressor"]
+ __all__ = [
+     "QuantizationCompressor",
+     "IntQuantizationCompressor",
+     "FloatQuantizationCompressor",
+ ]

  _LOGGER: logging.Logger = logging.getLogger(__name__)


- @Compressor.register(name=CompressionFormat.int_quantized.value)
- class IntQuantizationCompressor(Compressor):
+ @Compressor.register(name=CompressionFormat.naive_quantized.value)
+ class QuantizationCompressor(Compressor):
      """
-     Integer compression for quantized models. Weight of each quantized layer is
-     converted from its original float type to the format specified by the layer's
-     quantization scheme.
+     Implements naive compression for quantized models. Weight of each
+     quantized layer is converted from its original float type to the closest Pytorch
+     type to the type specified by the layer's QuantizationArgs.
      """

      COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
@@ -45,14 +49,14 @@ class IntQuantizationCompressor(Compressor):
      def compress(
          self,
          model_state: Dict[str, Tensor],
-         model_quant_args: Dict[str, QuantizationArgs],
+         names_to_scheme: Dict[str, QuantizationArgs],
          **kwargs,
      ) -> Dict[str, Tensor]:
          """
          Compresses a dense state dict

          :param model_state: state dict of uncompressed model
-         :param model_quant_args: quantization args for each quantized weight, needed for
+         :param names_to_scheme: quantization args for each quantized weight, needed for
              quantize function to calculate bit depth
          :return: compressed state dict
          """
@@ -69,7 +73,7 @@ class IntQuantizationCompressor(Compressor):
                  zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                  if scale is not None and zp is not None:
                      # weight is quantized, compress it
-                     quant_args = model_quant_args[prefix]
+                     quant_args = names_to_scheme[prefix]
                      if can_quantize(value, quant_args):
                          # only quantize if not already quantized
                          value = quantize(
@@ -77,7 +81,7 @@ class IntQuantizationCompressor(Compressor):
                              scale=scale,
                              zero_point=zp,
                              args=quant_args,
-                             dtype=torch.int8,
+                             dtype=quant_args.pytorch_dtype(),
                          )
              elif name.endswith("zero_point"):
                  if torch.all(value == 0):
@@ -89,7 +93,7 @@ class IntQuantizationCompressor(Compressor):
          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
@@ -114,13 +118,27 @@ class IntQuantizationCompressor(Compressor):
              if "weight_scale" in weight_data:
                  zero_point = weight_data.get("weight_zero_point", None)
                  scale = weight_data["weight_scale"]
-                 if zero_point is None:
-                     # zero_point assumed to be 0 if not included in state_dict
-                     zero_point = torch.zeros_like(scale)
-
                  decompressed = dequantize(
                      x_q=weight_data["weight"],
                      scale=scale,
                      zero_point=zero_point,
                  )
                  yield merge_names(weight_name, "weight"), decompressed
+
+
+ @Compressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for integer quantized models
+     """
+
+     pass
+
+
+ @Compressor.register(name=CompressionFormat.float_quantized.value)
+ class FloatQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for fp quantized models
+     """
+
+     pass
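Note: the behavioural change here is `dtype=quant_args.pytorch_dtype()`. Instead of always casting to `torch.int8`, the compressor casts to the closest PyTorch type for the layer's `QuantizationArgs`, which is what lets the new `float-quantized` (FP8) format share this code path. The method body is not part of this diff; the sketch below only illustrates the kind of mapping it has to perform:

```python
import torch


# Illustrative stand-in for QuantizationArgs.pytorch_dtype(); the real method is
# defined in quant_args.py and is not shown in this diff.
def closest_pytorch_dtype(quant_type: str, num_bits: int) -> torch.dtype:
    if quant_type == "float":
        return torch.float8_e4m3fn  # fp8 weights, needs a float8-capable torch build
    if num_bits <= 8:
        return torch.int8  # int4 values still travel in an int8 container
    return torch.int16
```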
src/compressed_tensors/compressors/pack_quantized.py
@@ -29,7 +29,7 @@ from torch import Tensor
  from tqdm import tqdm


- __all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+ __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]

  _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -50,14 +50,14 @@ class PackedQuantizationCompressor(Compressor):
      def compress(
          self,
          model_state: Dict[str, Tensor],
-         model_quant_args: Dict[str, QuantizationArgs],
+         names_to_scheme: Dict[str, QuantizationArgs],
          **kwargs,
      ) -> Dict[str, Tensor]:
          """
          Compresses a dense state dict

          :param model_state: state dict of uncompressed model
-         :param model_quant_args: quantization args for each quantized weight, needed for
+         :param names_to_scheme: quantization args for each quantized weight, needed for
              quantize function to calculate bit depth
          :return: compressed state dict
          """
@@ -75,7 +75,7 @@ class PackedQuantizationCompressor(Compressor):
              shape = torch.tensor(value.shape)
              if scale is not None and zp is not None:
                  # weight is quantized, compress it
-                 quant_args = model_quant_args[prefix]
+                 quant_args = names_to_scheme[prefix]
                  if can_quantize(value, quant_args):
                      # convert weight to an int if not already compressed
                      value = quantize(
@@ -85,7 +85,7 @@ class PackedQuantizationCompressor(Compressor):
                          args=quant_args,
                          dtype=torch.int8,
                      )
-                 value = pack_4bit_ints(value.cpu())
+                 value = pack_to_int32(value.cpu(), quant_args.num_bits)
                  compressed_dict[merge_names(prefix, "weight_shape")] = shape
                  compressed_dict[merge_names(prefix, "weight_packed")] = value
                  continue
@@ -101,7 +101,10 @@ class PackedQuantizationCompressor(Compressor):
          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self,
+         path_to_model_or_tensors: str,
+         names_to_scheme: Dict[str, QuantizationArgs],
+         device: str = "cpu",
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +122,7 @@ class PackedQuantizationCompressor(Compressor):
          for weight_name in weight_mappings.keys():
              weight_data = {}
              for param_name, safe_path in weight_mappings[weight_name].items():
+                 weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                  full_name = merge_names(weight_name, param_name)
                  with safe_open(safe_path, framework="pt", device=device) as f:
                      weight_data[param_name] = f.get_tensor(full_name)
@@ -126,13 +130,10 @@ class PackedQuantizationCompressor(Compressor):
              if "weight_scale" in weight_data:
                  zero_point = weight_data.get("weight_zero_point", None)
                  scale = weight_data["weight_scale"]
-                 if zero_point is None:
-                     # zero_point assumed to be 0 if not included in state_dict
-                     zero_point = torch.zeros_like(scale)
-
                  weight = weight_data["weight_packed"]
+                 num_bits = weight_data["num_bits"]
                  original_shape = torch.Size(weight_data["weight_shape"])
-                 unpacked = unpack_4bit_ints(weight, original_shape)
+                 unpacked = unpack_from_int32(weight, num_bits, original_shape)
                  decompressed = dequantize(
                      x_q=unpacked,
                      scale=scale,
@@ -141,45 +142,50 @@ class PackedQuantizationCompressor(Compressor):
                  yield merge_names(weight_name, "weight"), decompressed


- def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
+ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
      """
-     Packs a tensor of int4 weights stored in int8 into int32s with padding
+     Packs a tensor of quantized weights stored in int8 into int32s with padding

      :param value: tensor to pack
+     :param num_bits: number of bits used to store underlying data
      :returns: packed int32 tensor
      """
      if value.dtype is not torch.int8:
          raise ValueError("Tensor must be quantized to torch.int8 before packing")

-     # need to convert to unsigned 8bit to use numpy's pack/unpack
-     temp = (value - 8).to(torch.uint8)
-     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
-     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
-     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+     if num_bits > 8:
+         raise ValueError("Packing is only supported for less than 8 bits")

-     # pad each row to fill a full 32bit int
-     pack_depth = 32
-     padding = (
-         math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
-     )
-     padded_bits = np.pad(
-         only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
-     )
+     # convert to unsigned for packing
+     offset = pow(2, num_bits) // 2
+     value = (value + offset).to(torch.uint8)
+     value = value.cpu().numpy().astype(np.uint32)
+     pack_factor = 32 // num_bits

-     # after packbits each uint8 is two packed uint4s
-     # then we keep the bit pattern the same but convert to int32
-     compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
-     compressed = np.ascontiguousarray(compressed).view(np.int32)
+     # pad input tensor and initialize packed output
+     packed_size = math.ceil(value.shape[1] / pack_factor)
+     packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+     padding = packed.shape[1] * pack_factor - value.shape[1]
+     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

-     return torch.from_numpy(compressed)
+     # pack values
+     for i in range(pack_factor):
+         packed |= value[:, i::pack_factor] << num_bits * i

+     # convert back to signed and torch
+     packed = np.ascontiguousarray(packed).view(np.int32)
+     return torch.from_numpy(packed)

- def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+
+ def unpack_from_int32(
+     value: torch.Tensor, num_bits: int, shape: torch.Size
+ ) -> torch.Tensor:
      """
-     Unpacks a tensor packed int4 weights into individual int8s, maintaining the
-     original their int4 range
+     Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
+     original their bit range

      :param value: tensor to upack
+     :param num_bits: number of bits to unpack each data point into
      :param shape: shape to unpack into, used to remove padding
      :returns: unpacked int8 tensor
      """
@@ -188,25 +194,26 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
              f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
          )

-     # unpack bits and undo padding to nearest int32 bits
-     individual_depth = 4
-     as_uint8 = value.numpy().view(np.uint8)
-     bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
-     original_row_size = int(shape[1] * individual_depth)
-     bits = bits[:, :original_row_size]
+     if num_bits > 8:
+         raise ValueError("Unpacking is only supported for less than 8 bits")
+
+     # convert packed input to unsigned numpy
+     value = value.numpy().view(np.uint32)
+     pack_factor = 32 // num_bits

-     # reformat each packed uint4 to a uint8 by filling to top 4 bits with zeros
-     # (uint8 format is required by np.packbits)
-     shape_8bit = (bits.shape[0], bits.shape[1] * 2)
-     bits_as_8bit = np.zeros(shape_8bit, dtype=np.uint8)
-     ranges = np.array([range(x, x + 4) for x in range(0, shape_8bit[1], 8)]).flatten()
-     bits_as_8bit[:, ranges] = bits
+     # unpack
+     mask = pow(2, num_bits) - 1
+     unpacked = np.zeros((value.shape[0], value.shape[1] * pack_factor))
+     for i in range(pack_factor):
+         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask

-     # repack the bits to uint8
-     repacked = np.packbits(bits_as_8bit, axis=-1, bitorder="little")
+     # remove padding
+     original_row_size = int(shape[1])
+     unpacked = unpacked[:, :original_row_size]

      # bits are packed in unsigned format, reformat to signed
-     # update the value range from uint4 to int4
-     final = repacked.astype(np.int8) - 8
+     # update the value range from unsigned to signed
+     offset = pow(2, num_bits) // 2
+     unpacked = (unpacked.astype(np.int16) - offset).astype(np.int8)

-     return torch.from_numpy(final)
+     return torch.from_numpy(unpacked)
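Note: the rewritten helpers replace the bit-level `np.packbits`/`np.unpackbits` approach with shift-and-OR packing parameterised by `num_bits`, so 4-bit and 8-bit weights share one code path. A self-contained round-trip sketch of the scheme (not the library functions themselves):

```python
import numpy as np
import torch

num_bits = 4
offset = 2 ** num_bits // 2    # 8: maps the int4 range [-8, 7] onto [0, 15]
pack_factor = 32 // num_bits   # 8 values per 32-bit word

weights = torch.tensor([[-8, -1, 0, 1, 2, 3, 4, 7]], dtype=torch.int8)
unsigned = (weights + offset).numpy().astype(np.uint32)  # shape (1, 8)

# pack: OR each value into its own num_bits-wide slot of a uint32 word
packed = np.zeros((1, 1), dtype=np.uint32)
for i in range(pack_factor):
    packed |= unsigned[:, i::pack_factor] << (num_bits * i)

# unpack: mask each slot back out, then restore the signed range
mask = 2 ** num_bits - 1
unpacked = np.zeros_like(unsigned)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
restored = torch.from_numpy(unpacked.astype(np.int16) - offset).to(torch.int8)

assert torch.equal(restored, weights)
```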
src/compressed_tensors/compressors/sparse_bitmask.py
@@ -72,7 +72,7 @@ class BitmaskCompressor(Compressor):
          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a bitmask compressed state dict located
src/compressed_tensors/config/base.py
@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
      dense = "dense"
      sparse_bitmask = "sparse-bitmask"
      int_quantized = "int-quantized"
+     float_quantized = "float-quantized"
+     naive_quantized = "naive-quantized"
      pack_quantized = "pack-quantized"
      marlin_24 = "marlin-24"

src/compressed_tensors/quantization/lifecycle/__init__.py
@@ -21,3 +21,4 @@ from .frozen import *
  from .initialize import *
  from .compressed import *
  from .apply import *
+ from .helpers import *