compressed-tensors 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +16 -0
  3. compressed_tensors/compressors/__init__.py +25 -0
  4. compressed_tensors/compressors/base.py +79 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +161 -0
  7. compressed_tensors/compressors/sparse_bitmask.py +238 -0
  8. compressed_tensors/config/__init__.py +18 -0
  9. compressed_tensors/config/base.py +42 -0
  10. compressed_tensors/config/dense.py +36 -0
  11. compressed_tensors/config/sparse_bitmask.py +36 -0
  12. compressed_tensors/quantization/__init__.py +21 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +22 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +173 -0
  15. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  16. compressed_tensors/quantization/lifecycle/forward.py +136 -0
  17. compressed_tensors/quantization/lifecycle/frozen.py +46 -0
  18. compressed_tensors/quantization/lifecycle/initialize.py +96 -0
  19. compressed_tensors/quantization/observers/__init__.py +21 -0
  20. compressed_tensors/quantization/observers/base.py +69 -0
  21. compressed_tensors/quantization/observers/helpers.py +53 -0
  22. compressed_tensors/quantization/observers/memoryless.py +48 -0
  23. compressed_tensors/quantization/observers/min_max.py +65 -0
  24. compressed_tensors/quantization/quant_args.py +85 -0
  25. compressed_tensors/quantization/quant_config.py +171 -0
  26. compressed_tensors/quantization/quant_scheme.py +39 -0
  27. compressed_tensors/quantization/utils/__init__.py +16 -0
  28. compressed_tensors/quantization/utils/helpers.py +115 -0
  29. compressed_tensors/registry/__init__.py +17 -0
  30. compressed_tensors/registry/registry.py +360 -0
  31. compressed_tensors/utils/__init__.py +16 -0
  32. compressed_tensors/utils/helpers.py +151 -0
  33. compressed_tensors/utils/safetensors_load.py +237 -0
  34. compressed_tensors-0.3.0.dist-info/METADATA +22 -0
  35. compressed_tensors-0.3.0.dist-info/RECORD +37 -0
  36. compressed_tensors-0.3.0.dist-info/WHEEL +5 -0
  37. compressed_tensors-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,360 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Universal registry to support registration and loading of child classes and plugins
17
+ of neuralmagic utilities
18
+ """
19
+
20
+ import importlib
21
+ from collections import defaultdict
22
+ from typing import Any, Dict, List, Optional, Type, Union
23
+
24
+
25
# public API of the registry module (consumed by `from ... import *`)
__all__ = [
    "RegistryMixin",
    "register",
    "get_from_registry",
    "registered_names",
    "registered_aliases",
    "standardize_lookup_name",
]


# per parent class: standardized alias -> standardized registered name.
# every registered name is also stored here as an alias of itself, so all
# lookups resolve through this table first
_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
# per parent class: standardized registered name -> registered value
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
37
+
38
+
39
def standardize_lookup_name(name: str) -> str:
    """
    Normalize a registry name so that spelling variants share one lookup key.

    Underscores and spaces are both turned into hyphens, then the result is
    lowercased.

    example:
    ```
    standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
    ```

    :param name: name to standardize
    :return: standardized name
    """
    separators_to_hyphen = str.maketrans({"_": "-", " ": "-"})
    return name.translate(separators_to_hyphen).lower()
54
+
55
+
56
def standardize_alias_name(
    name: Union[None, str, List[str]]
) -> Union[None, str, List[str]]:
    """
    Apply standardize_lookup_name to a single alias or to each alias in a list.

    :param name: None, one alias, or a list of aliases
    :return: None when given None, a standardized string when given a string,
        otherwise a new list of standardized aliases
    """
    if isinstance(name, str):
        return standardize_lookup_name(name)
    if name is not None:
        # anything that is neither None nor a string is treated as a list of aliases
        return list(map(standardize_lookup_name, name))
    return None
65
+
66
+
67
class RegistryMixin:
    """
    Universal registry to support registration and loading of child classes and
    plugins of neuralmagic utilities.

    Classes that require a registry or plugins may add the `RegistryMixin` and use
    `register` and `load_from_registry` as the main entrypoints for adding new
    implementations and loading requested values from its registry.

    If a class should only have its child classes in its registry, the class should
    set the static attribute `registry_requires_subclass` to True

    example
    ```python
    class Dataset(RegistryMixin):
        pass


    # register with default name
    @Dataset.register()
    class ImageNetDataset(Dataset):
        pass

    # load as "ImageNetDataset" (stored under the standardized name
    # "imagenetdataset")
    imagenet = Dataset.load_from_registry("ImageNetDataset")

    # register with custom name
    @Dataset.register(name="cifar-dataset")
    class Cifar(Dataset):
        pass

    Note: the name will be standardized for lookup in the registry.
    For example, if a class is registered as "cifar_dataset" or
    "cifar dataset", it will be stored as "cifar-dataset". The user
    will be able to load the class with any of the three name variants.

    # register with multiple aliases
    @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
    class Cifar10(Dataset):
        pass

    # load as "cifar-dataset"
    cifar = Dataset.load_from_registry("cifar-dataset")

    # load from custom file that implements a dataset
    mnist = Dataset.load_from_registry("/path/to/mnist_dataset.py:MnistDataset")
    ```
    """

    # set to True in child class to add check that registered/retrieved values
    # implement the class it is registered to
    registry_requires_subclass: bool = False

    @classmethod
    def register(
        cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
    ):
        """
        Decorator for registering a value (ie class or function) wrapped by this
        decorator to the base class (class that .register is called from)

        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the wrapped value as,
            defaults to None
        :return: register decorator; it registers the value and returns it
            unchanged
        """

        def decorator(value: Any):
            cls.register_value(value, name=name, alias=alias)
            return value

        return decorator

    @classmethod
    def register_value(
        cls,
        value: Any,
        name: Optional[str] = None,
        alias: Union[str, List[str], None] = None,
    ):
        """
        Registers the given value to the class `.register_value` is called from

        :param value: value to register
        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the wrapped value as,
            defaults to None
        """
        register(
            parent_class=cls,
            value=value,
            name=name,
            alias=alias,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
        """
        :param name: name of registered class to load
        :param constructor_kwargs: arguments to pass to the constructor retrieved
            from the registry
        :return: loaded object registered to this class under the given name,
            constructed with the given kwargs. Raises error if the name is
            not found in the registry
        """
        constructor = cls.get_value_from_registry(name=name)
        return constructor(**constructor_kwargs)

    @classmethod
    def get_value_from_registry(cls, name: str):
        """
        :param name: name to retrieve from the registry
        :return: value retrieved from the registry for the given name, raises
            error if not found
        """
        return get_from_registry(
            parent_class=cls,
            name=name,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def registered_names(cls) -> List[str]:
        """
        :return: list of all names registered to this class
        """
        return registered_names(cls)

    @classmethod
    def registered_aliases(cls) -> List[str]:
        """
        :return: list of all aliases registered to this class
        """
        return registered_aliases(cls)
200
+
201
+
202
def register(
    parent_class: Type,
    value: Any,
    name: Optional[str] = None,
    alias: Union[List[str], str, None] = None,
    require_subclass: bool = False,
):
    """
    Register ``value`` under ``name`` (and any aliases) for ``parent_class``.

    :param parent_class: class to register the name under
    :param value: the value to register
    :param name: name to register the wrapped value as, defaults to value.__name__
    :param alias: alias or list of aliases to register the wrapped value as,
        defaults to None
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :raises ValueError: if require_subclass is set and value is not a subclass of
        parent_class
    :raises KeyError: if the standardized name or an alias is already present in
        the alias registry of parent_class
    :raises RuntimeError: if name is already registered to a different value
    """
    if name is None:
        # default name
        name = value.__name__

    name = standardize_lookup_name(name)
    alias = standardize_alias_name(alias)

    # validate before touching the alias registry so that a rejected value does
    # not leave alias entries behind that would block a later valid registration
    if require_subclass:
        _validate_subclass(parent_class, value)

    register_alias(name=name, alias=alias, parent_class=parent_class)

    if name in _REGISTRY[parent_class]:
        # name already exists - raise error if two different values are attempting
        # to share the same name
        # NOTE(review): register_alias above raises KeyError whenever `name` is
        # already in the alias registry, which every successful registration adds
        # it to, so this branch looks unreachable in practice - confirm whether
        # re-registering the same value is meant to be allowed
        registered_value = _REGISTRY[parent_class][name]
        if registered_value is not value:
            raise RuntimeError(
                f"Attempting to register name {name} as {value} "
                f"however {name} has already been registered as {registered_value}"
            )
    else:
        _REGISTRY[parent_class][name] = value
240
+
241
+
242
def get_from_registry(
    parent_class: Type, name: str, require_subclass: bool = False
) -> Any:
    """
    Retrieve the value registered to ``parent_class`` under ``name``.

    ``name`` is standardized and resolved through the alias registry before
    lookup. A name of the form ``"/path/to/file.py:Value"`` instead loads
    ``Value`` from that Python file. The path and value name are used exactly as
    given, because standardizing them would lowercase the value name and rewrite
    underscores and spaces in the path.

    :param parent_class: class that the name is registered under
    :param name: name to retrieve from the registry of the class
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :return: value retrieved from the registry for the given name
    :raises KeyError: if the name is not registered to parent_class
    :raises ValueError: if require_subclass is set and the value is not a
        subclass of parent_class
    """
    lookup_name = standardize_lookup_name(name)

    if ":" in lookup_name:
        # user specifying specific module to load and value to import; split the
        # raw name on the LAST colon so paths containing a drive letter still work
        module_path, value_name = name.rsplit(":", 1)
        retrieved_value = _import_and_get_value_from_module(module_path, value_name)
    else:
        # look up name in alias registry; unknown names pass through unchanged
        # so the error below can report them
        registered_name = _ALIAS_REGISTRY[parent_class].get(lookup_name, lookup_name)
        # look up name in registry
        retrieved_value = _REGISTRY[parent_class].get(registered_name)
        if retrieved_value is None:
            raise KeyError(
                f"Unable to find {name} registered under type {parent_class}.\n"
                f"Registered values for {parent_class}: "
                f"{registered_names(parent_class)}\n"
                f"Registered aliases for {parent_class}: "
                f"{registered_aliases(parent_class)}"
            )

    if require_subclass:
        _validate_subclass(parent_class, retrieved_value)

    return retrieved_value
277
+
278
+
279
def registered_names(parent_class: Type) -> List[str]:
    """
    List the standardized names registered for a class.

    :param parent_class: class to look up the registry of
    :return: all names registered to the given class
    """
    names_to_values = _REGISTRY[parent_class]
    return [*names_to_values]
285
+
286
+
287
def registered_aliases(parent_class: Type) -> List[str]:
    """
    List the aliases registered for a class, excluding the registered names
    (which are also stored in the alias registry as aliases of themselves).

    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class, in registration order
    """
    names = set(registered_names(parent_class))
    # iterate the dict rather than taking a set difference so the order is
    # deterministic (registration order) instead of hash-seed dependent
    return [
        alias_name
        for alias_name in _ALIAS_REGISTRY[parent_class]
        if alias_name not in names
    ]
297
+
298
+
299
def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Map each given alias, plus the name itself, to ``name`` in the alias
    registry of ``parent_class``.

    The name is always registered as an alias of itself, so lookups by the
    registered name resolve through the same table as lookups by alias.
    Either every alias is registered or, on conflict, none is.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    :raises KeyError: if an alias equals the name, or if the name or any alias is
        already registered (or repeated) for parent_class
    """
    if alias is None:
        aliases = []
    elif isinstance(alias, list):
        # copy so the caller's list is not mutated by the append below
        aliases = list(alias)
    else:
        aliases = [alias]

    if name in aliases:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    aliases.append(name)

    class_aliases = _ALIAS_REGISTRY[parent_class]
    # collect everything first so that a conflict on a later alias does not
    # leave the earlier ones half-registered
    pending: Dict[str, str] = {}
    for alias_name in aliases:
        if alias_name in class_aliases or alias_name in pending:
            registered_as = class_aliases.get(alias_name, pending.get(alias_name))
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                f"{registered_as}"
            )
        pending[alias_name] = name
    class_aliases.update(pending)
332
+
333
+
334
+ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
335
+ # import the given module path and try to get the value_name if it is included
336
+ # in the module
337
+
338
+ # load module
339
+ spec = importlib.util.spec_from_file_location(
340
+ f"plugin_module_for_{value_name}", module_path
341
+ )
342
+ module = importlib.util.module_from_spec(spec)
343
+ spec.loader.exec_module(module)
344
+
345
+ # get value from module
346
+ value = getattr(module, value_name, None)
347
+
348
+ if not value:
349
+ raise RuntimeError(
350
+ f"Unable to find attribute {value_name} in module {module_path}"
351
+ )
352
+ return value
353
+
354
+
355
+ def _validate_subclass(parent_class: Type, child_class: Type):
356
+ if not issubclass(child_class, parent_class):
357
+ raise ValueError(
358
+ f"class {child_class} is not a subclass of the class it is "
359
+ f"registered for: {parent_class}."
360
+ )
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+ from .safetensors_load import *
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import Dict, Optional, Union
17
+
18
+ import torch
19
+ from compressed_tensors.base import SPARSITY_CONFIG_NAME
20
+ from compressed_tensors.compressors import ModelCompressor
21
+ from compressed_tensors.config import (
22
+ CompressionConfig,
23
+ CompressionFormat,
24
+ DenseSparsityConfig,
25
+ )
26
+ from safetensors.torch import save_file
27
+ from torch import Tensor
28
+ from transformers import AutoConfig
29
+
30
+
31
# public API of this helpers module (consumed by `from ... import *`)
__all__ = [
    "infer_compressor_from_model_config",
    "load_compressed",
    "save_compressed",
    "save_compressed_model",
]
37
+
38
+
39
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional[ModelCompressor]:
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config,
        None otherwise
    :raises ValueError: if the sparsity config has no "format" entry
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
    if sparsity_config is None:
        return None

    # the "format" entry selects both the config class and the compressor
    compression_format = sparsity_config.get("format")
    if compression_format is None:
        raise ValueError(
            f"{SPARSITY_CONFIG_NAME} in the config of "
            f"{pretrained_model_name_or_path} does not specify a format"
        )
    sparsity_config = CompressionConfig.load_from_registry(
        compression_format, **sparsity_config
    )
    return ModelCompressor.load_from_registry(
        compression_format, config=sparsity_config
    )
58
+
59
+
60
def save_compressed(
    tensors: Dict[str, Tensor],
    save_path: Union[str, Path],
    compression_format: Optional[CompressionFormat] = None,
):
    """
    Compress tensors with the ModelCompressor registered for the given format and
    save the result to disk with safetensors.

    :param tensors: dictionary of tensors to compress
    :param save_path: path to save compressed tensors
    :param compression_format: compression format used for the tensors, given as
        a CompressionFormat member or a format name; defaults to
        `dense_sparsity`. Names are matched after registry standardization, so
        e.g. underscores and hyphens are interchangeable
    :raises ValueError: if no tensors are given or the format is not registered
        with ModelCompressor
    """
    # imported locally: only the registry's name-standardization helper is needed
    from compressed_tensors.registry.registry import standardize_lookup_name

    if not tensors:
        raise ValueError("No tensors or empty tensors provided to compress")

    # if no compression_format specified, default to `dense_sparsity`
    compression_format = compression_format or CompressionFormat.dense_sparsity.value
    if isinstance(compression_format, CompressionFormat):
        # accept enum members as well as their string values
        compression_format = compression_format.value

    # registered names and aliases are stored standardized, so compare the
    # standardized spelling (the same normalization load_from_registry applies)
    known_formats = set(ModelCompressor.registered_names()) | set(
        ModelCompressor.registered_aliases()
    )
    if standardize_lookup_name(compression_format) not in known_formats:
        raise ValueError(
            f"Unknown compression format: {compression_format}. "
            f"Must be one of {sorted(known_formats)}"
        )

    # compress
    compressor = ModelCompressor.load_from_registry(compression_format)
    # save compressed tensors
    compressed_tensors = compressor.compress(tensors)
    save_file(compressed_tensors, save_path)
94
+
95
+
96
def load_compressed(
    compressed_tensors: Union[str, Path],
    compression_config: Optional[CompressionConfig] = None,
    device: Optional[str] = "cpu",
) -> Dict[str, Tensor]:
    """
    Load compressed tensors from disk and decompress them with the compressor
    matching the compression config's format.

    :param compressed_tensors: path to compressed tensors
    :param compression_config: compression config to use for decompressing
        tensors; defaults to DenseSparsityConfig()
    :param device: device to move tensors to. If None, tensors are loaded on CPU.
    :return: decompressed tensors
    :raises ValueError: if no path is given or the path does not exist
    """
    if compressed_tensors is None:
        raise ValueError("No compressed tensors provided to load")
    if not Path(compressed_tensors).exists():
        raise ValueError(
            f"Compressed tensors path does not exist: {compressed_tensors}"
        )

    # if no compression_config specified, default to `dense_sparsity`
    compression_config = compression_config or DenseSparsityConfig()
    # enforce the documented contract that None means CPU
    if device is None:
        device = "cpu"

    # decompress
    compressor = ModelCompressor.load_from_registry(
        compression_config.format, config=compression_config
    )
    return dict(compressor.decompress(compressed_tensors, device=device))
123
+
124
+
125
def save_compressed_model(
    model: torch.nn.Module,
    filename: str,
    compression_format: Optional[CompressionFormat] = None,
    force_contiguous: bool = True,
):
    """
    Save the state_dict of ``model`` to ``filename`` in compressed form via
    save_compressed.

    Note: The model is assumed to have a
    state_dict with unique entries

    :param model: model to save on disk
    :param filename: filename location to save the file
    :param compression_format: compression format used for the model
    :param force_contiguous: forcing the state_dict to be saved as contiguous tensors
    :raises ValueError: propagated from save_compressed; when force_contiguous is
        False the message is extended with a hint to retry with it enabled
    """
    state_dict = model.state_dict()
    if force_contiguous:
        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    try:
        save_compressed(state_dict, filename, compression_format=compression_format)
    except ValueError as e:
        if force_contiguous:
            # tensors were already made contiguous, so the retry hint below
            # would be misleading; surface the original error unchanged
            raise
        msg = str(e)
        msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats."  # noqa E501
        raise ValueError(msg) from e