compressed-tensors 0.8.0__tar.gz → 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/PKG-INFO +1 -1
  2. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/setup.py +32 -6
  3. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
  4. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  5. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
  6. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
  7. {compressed-tensors-0.8.0/src/compressed_tensors/config → compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors}/__init__.py +2 -1
  8. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  9. compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  10. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  11. {compressed-tensors-0.8.0/src/compressed_tensors/compressors/sparse_compressors → compressed-tensors-0.9.0/src/compressed_tensors/config}/__init__.py +2 -1
  12. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/base.py +1 -0
  13. compressed-tensors-0.9.0/src/compressed_tensors/config/sparse_24_bitmask.py +40 -0
  14. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/linear/compressed_linear.py +3 -1
  15. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/apply.py +48 -2
  16. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/forward.py +2 -2
  17. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  18. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_args.py +16 -3
  19. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_config.py +3 -3
  20. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_scheme.py +17 -24
  21. compressed-tensors-0.9.0/src/compressed_tensors/utils/helpers.py +326 -0
  22. compressed-tensors-0.9.0/src/compressed_tensors/utils/offload.py +404 -0
  23. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/safetensors_load.py +83 -17
  24. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/version.py +1 -1
  25. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
  26. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
  27. compressed-tensors-0.8.0/src/compressed_tensors/utils/helpers.py +0 -121
  28. compressed-tensors-0.8.0/src/compressed_tensors/utils/offload.py +0 -116
  29. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/LICENSE +0 -0
  30. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/README.md +0 -0
  31. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/pyproject.toml +0 -0
  32. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/setup.cfg +0 -0
  33. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/__init__.py +0 -0
  34. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/base.py +0 -0
  35. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/__init__.py +0 -0
  36. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/base.py +0 -0
  37. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/helpers.py +0 -0
  38. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  39. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  40. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  41. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  42. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  43. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/dense.py +0 -0
  44. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  45. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/linear/__init__.py +0 -0
  46. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/__init__.py +0 -0
  47. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  48. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  49. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  50. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  51. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  52. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/registry/__init__.py +0 -0
  53. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/registry/registry.py +0 -0
  54. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/__init__.py +0 -0
  55. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/permutations_24.py +0 -0
  56. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/permute.py +0 -0
  57. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  58. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  59. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/requires.txt +0 -0
  60. {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
--- compressed-tensors-0.8.0/PKG-INFO
+++ compressed-tensors-0.9.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors
-Version: 0.8.0
+Version: 0.9.0
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- compressed-tensors-0.8.0/setup.py
+++ compressed-tensors-0.9.0/setup.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,7 +15,33 @@
 import os
 from setuptools import setup, find_packages
 from typing import List, Dict, Tuple
-from utils.artifacts import get_release_and_version
+
+
+def get_release_and_version(package_path: str) -> Tuple[bool, bool, str, str, str, str]:
+    """
+    Load version and release info from compressed-tensors package
+    """
+    # compressed-tensors/src/compressed_tensors/version.py always exists, default source of truth
+    version_path = os.path.join(package_path, "version.py")
+
+    # exec() cannot set local variables so need to manually
+    locals_dict = {}
+    exec(open(version_path).read(), globals(), locals_dict)
+    is_release = locals_dict.get("is_release", False)
+    version = locals_dict.get("version", "unknown")
+    version_major = locals_dict.get("version_major", "unknown")
+    version_minor = locals_dict.get("version_minor", "unknown")
+    version_bug = locals_dict.get("version_bug", "unknown")
+
+    print(f"Loaded version {version} from {version_path}")
+
+    return (
+        is_release,
+        version,
+        version_major,
+        version_minor,
+        version_bug,
+    )
 
 
 package_path = os.path.join(
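
Note: the inlined get_release_and_version relies on the Python subtlety flagged by its comment. Inside a function, exec() cannot reliably create local variables, so the version module is executed against an explicit dictionary and the values are read back out of it. A self-contained sketch of the same pattern (the source string below is a stand-in for version.py):

    # Stand-in for the contents of version.py, for illustration only
    source = 'version = "0.9.0"\nis_release = True'

    captured = {}
    exec(source, globals(), captured)  # assignments land in `captured`, not in our locals
    print(captured.get("version", "unknown"))  # -> 0.9.0
    print(captured.get("is_release", False))   # -> True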
@@ -35,7 +61,7 @@ if is_release:
     _PACKAGE_NAME = "compressed-tensors"
 else:
     _PACKAGE_NAME = "compressed-tensors-nightly"
-
+
 
 def _setup_long_description() -> Tuple[str, str]:
     return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
@@ -44,7 +70,7 @@ def _setup_packages() -> List:
     return find_packages(
         "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
     )
-
+
 
 def _setup_install_requires() -> List:
     return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
 
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -17,14 +17,14 @@ import logging
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -39,6 +39,8 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -103,12 +105,13 @@ class ModelCompressor:
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
         return cls.from_compression_config(compression_config)
 
     @classmethod
     def from_compression_config(
-        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
@@ -135,7 +138,7 @@ class ModelCompressor:
                 format, **sparsity_config
             )
         if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)
 
         return cls(
             sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -191,7 +194,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None
 
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
@@ -212,7 +215,7 @@
 
         if is_compressed_tensors_config(compression_config):
             q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None
 
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
@@ -265,7 +268,11 @@
         state_dict = model.state_dict()
 
         compressed_state_dict = state_dict
-        quantized_modules_to_args = map_modules_to_quant_args(model)
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
@@ -276,8 +283,14 @@
             )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
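
Note: compression now resolves the sparsity config's targets/ignore lists into concrete layer names before the sparsity compressor runs. A hedged sketch of driving this entry point (the checkpoint path is a placeholder, the toy model stands in for a real sparse/quantized model, and the compress(model) call shape is assumed from the method body above):

    import torch
    from compressed_tensors.compressors import ModelCompressor

    model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # toy stand-in
    compressor = ModelCompressor.from_pretrained("/path/to/compressed-checkpoint")
    if compressor is not None:
        # sparse targets are expanded internally from the sparsity config
        compressed_state_dict = compressor.compress(model)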
@@ -295,23 +308,44 @@
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
-        if self.sparsity_compressor is not None:
+        sparse_decompressed = False
+
+        if (
+            self.sparsity_compressor is not None
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            # Sparse decompression is applied on the model_path
             dense_gen = self.sparsity_compressor.decompress(model_path)
             self._replace_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
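
Note: decompression now runs in two phases: sparse decompression first, against the on-disk path, then quantized decompression against either the path or, if sparse decompression ran, the in-memory state dict. A hedged sketch of the caller side (path and model are placeholders; the decompress(model_path, model) shape follows the docstring above):

    from compressed_tensors.compressors import ModelCompressor

    # `model` is assumed to be the dense module skeleton the weights load into
    compressor = ModelCompressor.from_pretrained("/path/to/compressed-checkpoint")
    if compressor is not None:
        compressor.decompress("/path/to/compressed-checkpoint", model)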
@@ -361,15 +395,35 @@
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
@@ -390,3 +444,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
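
Note: the new override_quantization_status helper is a standard save/restore context manager, so its contract is easy to check in isolation. A minimal sketch, assuming a QuantizationConfig can be built with an empty config_groups (the config object is illustrative only):

    from compressed_tensors.compressors.model_compressors.model_compressor import (
        override_quantization_status,
    )
    from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

    config = QuantizationConfig(config_groups={})  # illustrative config
    config.quantization_status = QuantizationStatus.INITIALIZED

    with override_quantization_status(config, QuantizationStatus.FROZEN):
        assert config.quantization_status == QuantizationStatus.FROZEN
    assert config.quantization_status == QuantizationStatus.INITIALIZED  # restored on exit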
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -137,6 +152,21 @@
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
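
Note: the decompress split above is a plain isinstance dispatch: strings and Paths route to the safetensors-backed loader, anything else is treated as an in-memory state dict. A self-contained sketch of the same pattern (the loader bodies are stand-ins, not the library's implementations):

    from pathlib import Path
    from typing import Any, Dict, Generator, Tuple, Union

    def decompress(src: Union[str, Path, Dict[str, Any]]) -> Generator[Tuple[str, Any], None, None]:
        # Mirrors BaseQuantizationCompressor.decompress: dispatch on source type
        if isinstance(src, (str, Path)):
            yield from _from_path(src)
        else:
            yield from _from_state_dict(src)

    def _from_path(path):  # stand-in for the safetensors-backed loader
        yield "weight", f"tensor loaded from {path}"

    def _from_state_dict(state_dict):  # stand-in for the in-memory variant
        yield from state_dict.items()

    print(list(decompress({"weight": 1.0})))      # in-memory dict route
    print(list(decompress("model.safetensors")))  # filesystem route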
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
@@ -93,9 +93,11 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=quantization_args.pytorch_dtype(),
             )
+        else:
+            quantized_weight = weight
 
-            if device is not None:
-                quantized_weight = quantized_weight.to(device)
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
 
         return {"weight": quantized_weight}
 
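
Note: in both quantized compressors, quantization_args moved ahead of zero_point and is no longer optional, so 0.8.0-era callers passing arguments positionally would now bind them to the wrong parameters; keyword calls are unaffected. A hedged sketch of the safe call shape (compressor, tensors, and quant_args are placeholders, not values from this diff):

    # `compressor` is a NaiveQuantizationCompressor instance; `weight`, `scale`,
    # `zero_point` are tensors and `quant_args` a matching QuantizationArgs.
    compressed = compressor.compress_weight(
        weight=weight,
        scale=scale,
        quantization_args=quant_args,  # now required, third in the signature
        zero_point=zero_point,
    )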
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
@@ -94,6 +94,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=torch.int8,
             )
+        else:
+            quantized_weight = weight
 
         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/sparse_compressors/__init__.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors/__init__.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 # flake8: noqa
+
 from .base import *
 from .dense import *
+from .sparse_24_bitmask import *
 from .sparse_bitmask import *
--- compressed-tensors-0.8.0/src/compressed_tensors/compressors/sparse_compressors/base.py
+++ compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight.
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,11 +60,17 @@
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,7 +78,14 @@
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
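
Note: since should_compress is a static method with no other dependencies, its matching rules can be exercised directly. A small sketch, assuming the import path mirrors the file layout above (the layer names are illustrative):

    from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

    name = "model.layers.0.mlp.down_proj.weight"
    print(BaseSparseCompressor.should_compress(name))                                    # True: any .weight
    print(BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.bias"))     # False: not a weight
    print(BaseSparseCompressor.should_compress(name, {"model.layers.0.mlp.down_proj"}))  # True: prefix targeted
    print(BaseSparseCompressor.should_compress(name, {"model.layers.1.mlp.down_proj"}))  # False: not targeted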