compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +21 -0
- compressed_tensors/base.py +17 -0
- compressed_tensors/compressors/__init__.py +22 -0
- compressed_tensors/compressors/base.py +59 -0
- compressed_tensors/compressors/dense.py +34 -0
- compressed_tensors/compressors/helpers.py +137 -0
- compressed_tensors/compressors/int_quantized.py +95 -0
- compressed_tensors/compressors/model_compressor.py +264 -0
- compressed_tensors/compressors/sparse_bitmask.py +239 -0
- compressed_tensors/config/__init__.py +18 -0
- compressed_tensors/config/base.py +43 -0
- compressed_tensors/config/dense.py +36 -0
- compressed_tensors/config/sparse_bitmask.py +36 -0
- compressed_tensors/quantization/__init__.py +21 -0
- compressed_tensors/quantization/lifecycle/__init__.py +23 -0
- compressed_tensors/quantization/lifecycle/apply.py +196 -0
- compressed_tensors/quantization/lifecycle/calibration.py +51 -0
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +333 -0
- compressed_tensors/quantization/lifecycle/frozen.py +50 -0
- compressed_tensors/quantization/lifecycle/initialize.py +99 -0
- compressed_tensors/quantization/observers/__init__.py +21 -0
- compressed_tensors/quantization/observers/base.py +130 -0
- compressed_tensors/quantization/observers/helpers.py +54 -0
- compressed_tensors/quantization/observers/memoryless.py +48 -0
- compressed_tensors/quantization/observers/min_max.py +80 -0
- compressed_tensors/quantization/quant_args.py +125 -0
- compressed_tensors/quantization/quant_config.py +210 -0
- compressed_tensors/quantization/quant_scheme.py +39 -0
- compressed_tensors/quantization/utils/__init__.py +16 -0
- compressed_tensors/quantization/utils/helpers.py +131 -0
- compressed_tensors/registry/__init__.py +17 -0
- compressed_tensors/registry/registry.py +360 -0
- compressed_tensors/utils/__init__.py +16 -0
- compressed_tensors/utils/helpers.py +45 -0
- compressed_tensors/utils/safetensors_load.py +237 -0
- compressed_tensors/version.py +50 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
- compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# flake8: noqa
|
16
|
+
from .helpers import *
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Iterator, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from compressed_tensors.quantization.observers.base import Observer
|
19
|
+
from torch.nn import Module
|
20
|
+
from tqdm import tqdm
|
21
|
+
|
22
|
+
|
23
|
+
# Public helpers of this module; these are the names a star import exposes.
__all__ = [
    "is_module_quantized",
    "is_model_quantized",
    "iter_named_leaf_modules",
    "module_type",
    "calculate_compression_ratio",
]
|
30
|
+
|
31
|
+
|
32
|
+
def is_module_quantized(module: Module) -> bool:
    """
    Check if a module is quantized, based on the existence of a non-empty quantization
    scheme

    A scheme is non-empty when at least one of its ``weights``,
    ``input_activations`` or ``output_activations`` entries is not None. A module
    without a ``quantization_scheme`` attribute, or whose attribute is None, is
    treated as not quantized.

    :param module: pytorch module to check
    :return: True if module is quantized, False otherwise
    """
    scheme = getattr(module, "quantization_scheme", None)
    if scheme is None:
        # no scheme attached (or explicitly cleared) -> nothing is quantized
        return False

    return (
        scheme.weights is not None
        or scheme.input_activations is not None
        or scheme.output_activations is not None
    )
|
53
|
+
|
54
|
+
|
55
|
+
def is_model_quantized(model: Module) -> bool:
    """
    Report whether at least one leaf module of ``model`` carries a non-empty
    quantization scheme.

    Leaf modules are enumerated with ``iter_named_leaf_modules`` and each one is
    tested with ``is_module_quantized``; the scan stops at the first hit.

    :param model: pytorch model
    :return: True if model is quantized, False otherwise
    """
    return any(
        is_module_quantized(leaf) for _, leaf in iter_named_leaf_modules(model)
    )
|
69
|
+
|
70
|
+
|
71
|
+
def module_type(module: Module) -> str:
    """
    Return the name of the class of ``module``, e.g. ``"Linear"`` for a
    ``torch.nn.Linear`` instance.

    :param module: pytorch module to get the type name of
    :return: module type as a string
    """
    module_class = type(module)
    return module_class.__name__
|
79
|
+
|
80
|
+
|
81
|
+
def iter_named_leaf_modules(model: Module) -> Iterator[Tuple[str, Module]]:
    """
    Yields modules that do not have any submodules except observers. The observers
    themselves are not yielded

    A module with no children at all, or whose only children are ``Observer``
    instances, counts as a leaf.

    :param model: model to get leaf modules of
    :returns: generator of (name, leaf_submodule) tuples, in ``named_modules`` order
    """
    for name, submodule in model.named_modules():
        # observers are never leaves in their own right; without this check a
        # childless observer would satisfy the leaf test below and be yielded
        if isinstance(submodule, Observer):
            continue
        # all() is True for a module with no children, so plain leaves are
        # yielded as well as modules whose only children are observers
        if all(isinstance(child, Observer) for child in submodule.children()):
            yield name, submodule
|
101
|
+
|
102
|
+
|
103
|
+
def calculate_compression_ratio(model: Module) -> float:
    """
    Calculates the quantization compression ratio of a pytorch model, based on the
    number of bits needed to represent the total weights in compressed form. Does not
    take into account activation quantizations.

    Every parameter held directly by a leaf module (see ``iter_named_leaf_modules``)
    is counted once. Its uncompressed size is the bit width of its dtype; its
    compressed size is the owning module's weight quantization ``num_bits`` when
    that module has a weight quantization scheme, otherwise it is counted as
    uncompressed. Parameters owned by non-leaf modules are not counted.

    :param model: pytorch module to calculate compression ratio for
    :return: compression ratio of the whole model (uncompressed bits divided by
        compressed bits); raises ZeroDivisionError if no parameters are counted
    """
    total_compressed = 0.0
    total_uncompressed = 0.0
    for _, submodule in tqdm(
        iter_named_leaf_modules(model),
        desc="Calculating quantization compression ratio",
    ):
        # only weight quantization changes storage size; a scheme that quantizes
        # activations only has ``weights is None`` and leaves parameters as-is
        weight_args = None
        if is_module_quantized(submodule):
            weight_args = submodule.quantization_scheme.weights

        # recurse=False: count only the parameters this leaf owns directly (not
        # those of any observer children), and never the whole model's parameters
        for parameter in submodule.parameters(recurse=False):
            try:
                uncompressed_bits = torch.finfo(parameter.dtype).bits
            except TypeError:
                # integer dtypes are not covered by finfo
                uncompressed_bits = torch.iinfo(parameter.dtype).bits

            compressed_bits = uncompressed_bits
            if weight_args is not None:
                compressed_bits = weight_args.num_bits

            num_weights = parameter.numel()
            total_compressed += compressed_bits * num_weights
            total_uncompressed += uncompressed_bits * num_weights

    return total_uncompressed / total_compressed
|
@@ -0,0 +1,17 @@
|
|
1
|
+
# flake8: noqa
|
2
|
+
|
3
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
4
|
+
#
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6
|
+
# you may not use this file except in compliance with the License.
|
7
|
+
# You may obtain a copy of the License at
|
8
|
+
#
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10
|
+
#
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
12
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
|
+
# See the License for the specific language governing permissions and
|
15
|
+
# limitations under the License.
|
16
|
+
|
17
|
+
from .registry import *
|
@@ -0,0 +1,360 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Universal registry to support registration and loading of child classes and plugins
|
17
|
+
of neuralmagic utilities
|
18
|
+
"""
|
19
|
+
|
20
|
+
import importlib
|
21
|
+
from collections import defaultdict
|
22
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
23
|
+
|
24
|
+
|
25
|
+
# Public registry API; these are the names a star import of this module exposes.
__all__ = [
    "RegistryMixin",
    "register",
    "get_from_registry",
    "registered_names",
    "registered_aliases",
    "standardize_lookup_name",
]
|
33
|
+
|
34
|
+
|
35
|
+
# parent class -> {standardized registered name: registered value}
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
# parent class -> {standardized alias: standardized registered name}; every
# registered name is also stored here as an alias of itself by register_alias
_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
|
37
|
+
|
38
|
+
|
39
|
+
def standardize_lookup_name(name: str) -> str:
    """
    Normalize a name into the canonical key used by the registry.

    Every underscore and space becomes a hyphen and the result is lowercased, so
    differently spelled variants of one name map to the same key.

    example:
    ```
    standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
    ```

    :param name: name to standardize
    :return: standardized name
    """
    separators_to_hyphen = str.maketrans({"_": "-", " ": "-"})
    hyphenated = name.translate(separators_to_hyphen)
    return hyphenated.lower()
|
54
|
+
|
55
|
+
|
56
|
+
def standardize_alias_name(
    name: Union[None, str, List[str]]
) -> Union[None, str, List[str]]:
    """
    Apply ``standardize_lookup_name`` to an alias specification.

    :param name: None, a single alias string, or a list of alias strings
    :return: None when ``name`` is None, the standardized string for a single
        alias, otherwise a new list with every alias standardized
    """
    if isinstance(name, str):
        return standardize_lookup_name(name)
    if name is None:
        return None
    # anything else is treated as an iterable of alias strings
    return list(map(standardize_lookup_name, name))
|
65
|
+
|
66
|
+
|
67
|
+
class RegistryMixin:
    """
    Universal registry to support registration and loading of child classes and plugins
    of neuralmagic utilities.

    Classes that require a registry or plugins may add the `RegistryMixin` and use
    `register` and `load_from_registry` as the main entrypoints for adding new
    implementations and loading requested values from its registry.

    If a class should only have its child classes in its registry, the class should
    set the static attribute `registry_requires_subclass` to True

    example
    ```python
    class Dataset(RegistryMixin):
        pass


    # register with default name
    @Dataset.register()
    class ImageNetDataset(Dataset):
        pass

    # load as "ImageNetDataset"
    imagenet = Dataset.load_from_registry("ImageNetDataset")

    # register with custom name
    @Dataset.register(name="cifar-dataset")
    class Cifar(Dataset):
        pass

    # Note: the name will be standardized for lookup in the registry.
    # For example, if a class is registered as "cifar_dataset" or
    # "cifar dataset", it will be stored as "cifar-dataset". The user
    # will be able to load the class with any of the three name variants.

    # register with multiple aliases
    @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
    class Cifar(Dataset):
        pass

    # load as "cifar-dataset"
    cifar = Dataset.load_from_registry("cifar-dataset")

    # load from custom file that implements a dataset
    mnist = Dataset.load_from_registry("/path/to/mnist_dataset.py:MnistDataset")
    ```
    """

    # set to True in child class to add check that registered/retrieved values
    # implement the class it is registered to
    registry_requires_subclass: bool = False

    @classmethod
    def register(
        cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
    ):
        """
        Decorator for registering a value (ie class or function) wrapped by this
        decorator to the base class (class that .register is called from)

        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the wrapped value as,
            defaults to None
        :return: register decorator; it registers the decorated value and returns
            it unchanged
        """

        def decorator(value: Any):
            cls.register_value(value, name=name, alias=alias)
            return value

        return decorator

    @classmethod
    def register_value(
        cls,
        value: Any,
        name: Optional[str] = None,
        alias: Union[str, List[str], None] = None,
    ):
        """
        Registers the given value to the class `.register_value` is called from

        :param value: value to register
        :param name: name to register the value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the value as,
            defaults to None
        """
        register(
            parent_class=cls,
            value=value,
            name=name,
            alias=alias,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
        """
        :param name: name of registered class to load
        :param constructor_kwargs: arguments to pass to the constructor retrieved
            from the registry
        :return: loaded object registered to this class under the given name,
            constructed with the given kwargs. Raises error if the name is
            not found in the registry
        """
        constructor = cls.get_value_from_registry(name=name)
        return constructor(**constructor_kwargs)

    @classmethod
    def get_value_from_registry(cls, name: str):
        """
        :param name: name to retrieve from the registry
        :return: value retrieved from the registry for the given name, raises
            an error if not found
        """
        return get_from_registry(
            parent_class=cls,
            name=name,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def registered_names(cls) -> List[str]:
        """
        :return: list of all names registered to this class
        """
        return registered_names(cls)

    @classmethod
    def registered_aliases(cls) -> List[str]:
        """
        :return: list of all aliases registered to this class
        """
        return registered_aliases(cls)
|
200
|
+
|
201
|
+
|
202
|
+
def register(
    parent_class: Type,
    value: Any,
    name: Optional[str] = None,
    alias: Union[List[str], str, None] = None,
    require_subclass: bool = False,
):
    """
    Register ``value`` in the registry of ``parent_class``.

    The name and aliases are standardized with ``standardize_lookup_name`` before
    being stored, and the name is recorded through ``register_alias``.

    :param parent_class: class to register the name under
    :param value: the value to register
    :param name: name to register the wrapped value as, defaults to value.__name__
    :param alias: alias or list of aliases to register the wrapped value as,
        defaults to None
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :raises ValueError: if require_subclass is set and value is not a subclass of
        parent_class (raised before any registry state is modified)
    :raises KeyError: if the name or an alias is already registered under
        parent_class (raised by register_alias)
    :raises RuntimeError: defensive check if the name is already bound in the
        value registry to a different value
    """
    if name is None:
        # default name
        name = value.__name__

    name = standardize_lookup_name(name)
    alias = standardize_alias_name(alias)

    # validate before touching any registry state, so a rejected value does not
    # leave aliases behind that would block a later, valid registration
    if require_subclass:
        _validate_subclass(parent_class, value)

    register_alias(name=name, alias=alias, parent_class=parent_class)

    if name in _REGISTRY[parent_class]:
        # name already exists - raise error if two different values are attempting
        # to share the same name
        registered_value = _REGISTRY[parent_class][name]
        if registered_value is not value:
            raise RuntimeError(
                f"Attempting to register name {name} as {value} "
                f"however {name} has already been registered as {registered_value}"
            )
    else:
        _REGISTRY[parent_class][name] = value
|
240
|
+
|
241
|
+
|
242
|
+
def get_from_registry(
    parent_class: Type, name: str, require_subclass: bool = False
) -> Any:
    """
    Retrieve a value registered under ``parent_class``.

    ``name`` is either a registered name or alias (standardized before lookup), or
    a ``"path/to/file.py:ValueName"`` reference, in which case the file is
    imported and ``ValueName`` is read from it. Path references are used verbatim.

    :param parent_class: class that the name is registered under
    :param name: name to retrieve from the registry of the class
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :return: value retrieved from the registry for the given name
    :raises KeyError: if the name is not registered under parent_class
    """
    if ":" in name:
        # user specifying specific module to load and value to import. The path
        # and attribute name must not be standardized: lowercasing/hyphenating
        # would point at a different file and attribute. rsplit keeps any
        # drive-letter colon (e.g. "C:\\...") inside the path.
        module_path, value_name = name.rsplit(":", 1)
        retrieved_value = _import_and_get_value_from_module(module_path, value_name)
    else:
        lookup_name = standardize_lookup_name(name)
        # resolve alias -> registered name, then registered name -> value
        registered_name = _ALIAS_REGISTRY[parent_class].get(lookup_name)
        retrieved_value = _REGISTRY[parent_class].get(registered_name)

    if retrieved_value is None:
        # report the name the caller asked for, not the (None) failed lookup
        raise KeyError(
            f"Unable to find {name} registered under type {parent_class}.\n"
            f"Registered values for {parent_class}: "
            f"{registered_names(parent_class)}\n"
            f"Registered aliases for {parent_class}: "
            f"{registered_aliases(parent_class)}"
        )

    if require_subclass:
        _validate_subclass(parent_class, retrieved_value)

    return retrieved_value
|
277
|
+
|
278
|
+
|
279
|
+
def registered_names(parent_class: Type) -> List[str]:
    """
    List the standardized names registered for ``parent_class``.

    :param parent_class: class to look up the registry of
    :return: all names registered to the given class, in registration order
    """
    names_to_values = _REGISTRY[parent_class]
    return [*names_to_values]
|
285
|
+
|
286
|
+
|
287
|
+
def registered_aliases(parent_class: Type) -> List[str]:
    """
    List the aliases registered for ``parent_class``, excluding the registered
    names themselves (which are also stored in the alias registry).

    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class, in registration order
    """
    # a set for O(1) membership; iterating the alias dict (instead of taking a
    # set difference) keeps the result order deterministic
    names = set(registered_names(parent_class))
    return [
        alias_name
        for alias_name in _ALIAS_REGISTRY[parent_class]
        if alias_name not in names
    ]
|
297
|
+
|
298
|
+
|
299
|
+
def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Updates the mapping from the alias(es) to the given name.

    The name itself is always added as an alias of itself. All aliases are
    checked before any is written, so a conflict leaves the alias registry
    unchanged.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    :raises KeyError: if an alias is identical to the name, is repeated, or is
        already registered under parent_class
    """
    if alias is None:
        aliases = []
    elif isinstance(alias, list):
        # copy so the caller's list is not mutated by the append below
        aliases = list(alias)
    else:
        aliases = [alias]

    if name in aliases:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    aliases.append(name)

    alias_registry = _ALIAS_REGISTRY[parent_class]
    pending = {}
    for alias_name in aliases:
        # conflicts with an existing entry or with an earlier alias in this call
        existing = alias_registry.get(alias_name, pending.get(alias_name))
        if existing is not None:
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                f"{existing}"
            )
        pending[alias_name] = name

    alias_registry.update(pending)
|
332
|
+
|
333
|
+
|
334
|
+
def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
    """
    Import the python file at ``module_path`` and return its attribute
    ``value_name``.

    NOTE: this executes the file's code; only pass paths from trusted sources.

    :param module_path: filesystem path of the python file to import
    :param value_name: name of the attribute to read from the imported module
    :return: the attribute value (falsy values such as 0 are returned as-is)
    :raises ImportError: if no import spec can be created for ``module_path``
        (e.g. an unrecognized file extension)
    :raises RuntimeError: if the module has no attribute named ``value_name``
    """
    # importlib.util is a submodule; a bare `import importlib` does not
    # guarantee it has been loaded
    import importlib.util

    # load module
    spec = importlib.util.spec_from_file_location(
        f"plugin_module_for_{value_name}", module_path
    )
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load a python module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # get value from module; a sentinel (not truthiness) detects a missing
    # attribute, so legitimately falsy values are still returned
    missing = object()
    value = getattr(module, value_name, missing)
    if value is missing:
        raise RuntimeError(
            f"Unable to find attribute {value_name} in module {module_path}"
        )
    return value
|
353
|
+
|
354
|
+
|
355
|
+
def _validate_subclass(parent_class: Type, child_class: Type):
    """
    Ensure ``child_class`` is a class deriving from ``parent_class``.

    :param parent_class: class the value is registered for
    :param child_class: value being registered or retrieved
    :raises ValueError: if ``child_class`` is not a class, or is not a subclass of
        ``parent_class``
    """
    # issubclass raises TypeError for non-class values (e.g. a registered
    # function); report those with the same ValueError as any other mismatch
    if not (isinstance(child_class, type) and issubclass(child_class, parent_class)):
        raise ValueError(
            f"class {child_class} is not a subclass of the class it is "
            f"registered for: {parent_class}."
        )
|
@@ -0,0 +1,16 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# flake8: noqa
|
15
|
+
|
16
|
+
from .safetensors_load import *
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from typing import Optional
|
17
|
+
|
18
|
+
from compressed_tensors.base import SPARSITY_CONFIG_NAME
|
19
|
+
from compressed_tensors.compressors import ModelCompressor
|
20
|
+
from compressed_tensors.config import CompressionConfig
|
21
|
+
from transformers import AutoConfig
|
22
|
+
|
23
|
+
|
24
|
+
# Public API of this module; this is the name a star import exposes.
__all__ = [
    "infer_compressor_from_model_config",
]
|
25
|
+
|
26
|
+
|
27
|
+
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional[ModelCompressor]:
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    The sparsity config is read from the ``SPARSITY_CONFIG_NAME`` attribute of the
    config loaded by ``AutoConfig``. Its ``"format"`` entry selects both the
    ``CompressionConfig`` and the ``ModelCompressor`` from their registries.

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config, else None
    :raises ValueError: if the sparsity config has no "format" entry
    """
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
    if sparsity_config is None:
        return None

    # named to avoid shadowing the builtin `format`
    compression_format = sparsity_config.get("format")
    if compression_format is None:
        # fail clearly here rather than deep inside the registry lookup
        raise ValueError(
            f"Sparsity config of {pretrained_model_name_or_path} has no 'format' "
            "entry; unable to infer a compressor"
        )

    compression_config = CompressionConfig.load_from_registry(
        compression_format, **sparsity_config
    )
    return ModelCompressor.load_from_registry(
        compression_format, config=compression_config
    )
|