compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +17 -0
  3. compressed_tensors/compressors/__init__.py +22 -0
  4. compressed_tensors/compressors/base.py +59 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +137 -0
  7. compressed_tensors/compressors/int_quantized.py +95 -0
  8. compressed_tensors/compressors/model_compressor.py +264 -0
  9. compressed_tensors/compressors/sparse_bitmask.py +239 -0
  10. compressed_tensors/config/__init__.py +18 -0
  11. compressed_tensors/config/base.py +43 -0
  12. compressed_tensors/config/dense.py +36 -0
  13. compressed_tensors/config/sparse_bitmask.py +36 -0
  14. compressed_tensors/quantization/__init__.py +21 -0
  15. compressed_tensors/quantization/lifecycle/__init__.py +23 -0
  16. compressed_tensors/quantization/lifecycle/apply.py +196 -0
  17. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  18. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  19. compressed_tensors/quantization/lifecycle/forward.py +333 -0
  20. compressed_tensors/quantization/lifecycle/frozen.py +50 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +99 -0
  22. compressed_tensors/quantization/observers/__init__.py +21 -0
  23. compressed_tensors/quantization/observers/base.py +130 -0
  24. compressed_tensors/quantization/observers/helpers.py +54 -0
  25. compressed_tensors/quantization/observers/memoryless.py +48 -0
  26. compressed_tensors/quantization/observers/min_max.py +80 -0
  27. compressed_tensors/quantization/quant_args.py +125 -0
  28. compressed_tensors/quantization/quant_config.py +210 -0
  29. compressed_tensors/quantization/quant_scheme.py +39 -0
  30. compressed_tensors/quantization/utils/__init__.py +16 -0
  31. compressed_tensors/quantization/utils/helpers.py +131 -0
  32. compressed_tensors/registry/__init__.py +17 -0
  33. compressed_tensors/registry/registry.py +360 -0
  34. compressed_tensors/utils/__init__.py +16 -0
  35. compressed_tensors/utils/helpers.py +45 -0
  36. compressed_tensors/utils/safetensors_load.py +237 -0
  37. compressed_tensors/version.py +50 -0
  38. compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
  39. compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
  40. compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
  41. compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
  42. compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # flake8: noqa
16
+ from .helpers import *
@@ -0,0 +1,131 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple
16
+
17
+ import torch
18
+ from compressed_tensors.quantization.observers.base import Observer
19
+ from torch.nn import Module
20
+ from tqdm import tqdm
21
+
22
+
23
+ __all__ = [
24
+ "is_module_quantized",
25
+ "is_model_quantized",
26
+ "iter_named_leaf_modules",
27
+ "module_type",
28
+ "calculate_compression_ratio",
29
+ ]
30
+
31
+
32
def is_module_quantized(module: Module) -> bool:
    """
    Report whether a module carries a non-empty quantization scheme

    A module counts as quantized when it has a `quantization_scheme` attribute
    and at least one of that scheme's `weights`, `input_activations` or
    `output_activations` entries is set (not None).

    :param module: pytorch module to inspect
    :return: True if any part of the module's scheme is set, False otherwise
    """
    if not hasattr(module, "quantization_scheme"):
        return False

    scheme = module.quantization_scheme
    # evaluated left to right and short-circuited on the first entry that is set
    return (
        scheme.weights is not None
        or scheme.input_activations is not None
        or scheme.output_activations is not None
    )
53
+
54
+
55
def is_model_quantized(model: Module) -> bool:
    """
    Report whether at least one leaf module of a model is quantized, i.e. has a
    non-empty quantization scheme attached

    :param model: pytorch model to inspect
    :return: True if any leaf module is quantized, False otherwise
    """
    # any() stops at the first quantized leaf, like an early return from a loop
    return any(
        is_module_quantized(leaf) for _, leaf in iter_named_leaf_modules(model)
    )
69
+
70
+
71
def module_type(module: Module) -> str:
    """
    Return the class name of a module as a string

    :param module: pytorch module to get the type of
    :return: unqualified name of the module's class
    """
    module_class = type(module)
    return module_class.__name__
79
+
80
+
81
def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
    """
    Yields modules that do not have any submodules except observers. The observers
    themselves are not yielded

    Note that this is a generator: the return annotation describes each yielded
    item rather than the generator object itself.

    :param model: model to get leaf modules of
    :returns: generator of (name, leaf_submodule) tuples
    """
    for name, submodule in model.named_modules():
        # observers are never reported as leaves, even when they have no
        # children of their own
        if isinstance(submodule, Observer):
            continue

        # a leaf has no children at all (all() of an empty iterable is True)
        # or only observer children
        if all(isinstance(child, Observer) for child in submodule.children()):
            yield name, submodule
101
+
102
+
103
def calculate_compression_ratio(model: Module) -> float:
    """
    Calculates the quantization compression ratio of a pytorch model, based on the
    number of bits needed to represent the total weights in compressed form. Does not
    take into account activation quantizations.

    Each leaf module contributes only its own parameters. Parameters of a leaf whose
    quantization scheme defines weight quantization are counted at that scheme's
    `num_bits`; all other parameters are counted at the bit width of their dtype.

    :param model: pytorch module to calculate compression ratio for
    :return: compression ratio of the whole model (uncompressed bits divided by
        compressed bits), or 1.0 if the model has no parameters
    """
    total_compressed = 0.0
    total_uncompressed = 0.0
    for _, submodule in tqdm(
        iter_named_leaf_modules(model),
        desc="Calculating quantization compression ratio",
    ):
        # resolve the quantized weight width once per leaf; a scheme that only
        # quantizes activations has no weights entry and leaves the parameters
        # at full precision
        quantized_bits = None
        if is_module_quantized(submodule):
            weight_args = submodule.quantization_scheme.weights
            if weight_args is not None:
                quantized_bits = weight_args.num_bits

        # iterate the leaf's own parameters, not the whole model's, so each
        # parameter is paired with the scheme of the module that owns it
        for parameter in submodule.parameters():
            try:
                uncompressed_bits = torch.finfo(parameter.dtype).bits
            except TypeError:
                # torch.finfo rejects non floating point dtypes
                uncompressed_bits = torch.iinfo(parameter.dtype).bits

            if quantized_bits is None:
                compressed_bits = uncompressed_bits
            else:
                compressed_bits = quantized_bits

            num_weights = parameter.numel()
            total_compressed += compressed_bits * num_weights
            total_uncompressed += uncompressed_bits * num_weights

    if total_compressed == 0:
        # nothing to compress; avoid dividing by zero
        return 1.0

    return total_uncompressed / total_compressed
@@ -0,0 +1,17 @@
1
+ # flake8: noqa
2
+
3
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .registry import *
@@ -0,0 +1,360 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Universal registry to support registration and loading of child classes and plugins
17
+ of neuralmagic utilities
18
+ """
19
+
20
+ import importlib
21
+ from collections import defaultdict
22
+ from typing import Any, Dict, List, Optional, Type, Union
23
+
24
+
25
+ __all__ = [
26
+ "RegistryMixin",
27
+ "register",
28
+ "get_from_registry",
29
+ "registered_names",
30
+ "registered_aliases",
31
+ "standardize_lookup_name",
32
+ ]
33
+
34
+
35
+ _ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
36
+ _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
37
+
38
+
39
def standardize_lookup_name(name: str) -> str:
    """
    Normalize a name into the form used as a registry key.
    Underscores and spaces are turned into hyphens and the result is
    lower-cased.

    example:
    ```
    standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
    ```

    :param name: name to standardize
    :return: standardized name
    """
    separators_to_hyphen = str.maketrans({"_": "-", " ": "-"})
    hyphenated = name.translate(separators_to_hyphen)
    return hyphenated.lower()
54
+
55
+
56
def standardize_alias_name(
    name: Union[None, str, List[str]]
) -> Union[None, str, List[str]]:
    """
    Standardize an alias, or each alias in a collection, for registry lookup

    :param name: single alias, collection of aliases, or None
    :return: None for None, the standardized string for a string, otherwise a
        new list holding the standardized form of every entry
    """
    if name is None:
        return None

    if isinstance(name, str):
        return standardize_lookup_name(name)

    # anything else is treated as an iterable of names
    standardized = []
    for entry in name:
        standardized.append(standardize_lookup_name(entry))
    return standardized
65
+
66
+
67
class RegistryMixin:
    """
    Universal registry to support registration and loading of child classes and plugins
    of neuralmagic utilities.

    Classes that require a registry or plugins may add the `RegistryMixin` and use
    `register` and `load_from_registry` as the main entrypoints for adding new
    implementations and loading requested values from its registry.

    If a class should only have its child classes in its registry, the class should
    set the static attribute `registry_requires_subclass` to True

    example
    ```python
    class Dataset(RegistryMixin):
        pass


    # register with default name
    @Dataset.register()
    class ImageNetDataset(Dataset):
        pass

    # load as "ImageNetDataset"
    imagenet = Dataset.load_from_registry("ImageNetDataset")

    # register with custom name
    @Dataset.register(name="cifar-dataset")
    class Cifar(Dataset):
        pass

    # Note: the name will be standardized for lookup in the registry.
    # For example, if a class is registered as "cifar_dataset" or
    # "cifar dataset", it will be stored as "cifar-dataset". The user
    # will be able to load the class with any of the three name variants.

    # register with multiple aliases
    @Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
    class Cifar(Dataset):
        pass

    # load as "cifar-dataset"
    cifar = Dataset.load_from_registry("cifar-dataset")

    # load from custom file that implements a dataset
    mnist = Dataset.load_from_registry("/path/to/mnist_dataset.py:MnistDataset")
    ```
    """

    # set to True in child class to add check that registered/retrieved values
    # implement the class it is registered to
    registry_requires_subclass: bool = False

    @classmethod
    def register(
        cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
    ):
        """
        Decorator for registering a value (ie class or function) wrapped by this
        decorator to the base class (class that .register is called from)

        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the wrapped value as,
            defaults to None
        :return: register decorator, which returns the wrapped value unchanged
        """

        def decorator(value: Any):
            cls.register_value(value, name=name, alias=alias)
            return value

        return decorator

    @classmethod
    def register_value(
        cls,
        value: Any,
        name: Optional[str] = None,
        alias: Union[str, List[str], None] = None,
    ):
        """
        Registers the given value to the class `.register_value` is called from

        :param value: value to register
        :param name: name to register the wrapped value as,
            defaults to value.__name__
        :param alias: alias or list of aliases to register the wrapped value as,
            defaults to None
        """
        register(
            parent_class=cls,
            value=value,
            name=name,
            alias=alias,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
        """
        :param name: name of registered class to load
        :param constructor_kwargs: arguments to pass to the constructor retrieved
            from the registry
        :return: loaded object registered to this class under the given name,
            constructed with the given kwargs. Raises error if the name is
            not found in the registry
        """
        constructor = cls.get_value_from_registry(name=name)
        return constructor(**constructor_kwargs)

    @classmethod
    def get_value_from_registry(cls, name: str):
        """
        :param name: name to retrieve from the registry
        :return: value retrieved from the registry for the given name, raises
            error if not found
        """
        return get_from_registry(
            parent_class=cls,
            name=name,
            require_subclass=cls.registry_requires_subclass,
        )

    @classmethod
    def registered_names(cls) -> List[str]:
        """
        :return: list of all names registered to this class
        """
        return registered_names(cls)

    @classmethod
    def registered_aliases(cls) -> List[str]:
        """
        :return: list of all aliases registered to this class
        """
        return registered_aliases(cls)
200
+
201
+
202
def register(
    parent_class: Type,
    value: Any,
    name: Optional[str] = None,
    alias: Union[List[str], str, None] = None,
    require_subclass: bool = False,
):
    """
    Registers a value under a name, and optional aliases, for the given class

    :param parent_class: class to register the name under
    :param value: the value to register
    :param name: name to register the wrapped value as, defaults to value.__name__
    :param alias: alias or list of aliases to register the wrapped value as,
        defaults to None
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :raises ValueError: if require_subclass is set and value is not a subclass of
        parent_class
    :raises KeyError: if the name or one of the aliases is already registered as an
        alias for parent_class
    :raises RuntimeError: if the name is already registered to a different value
    """
    # validate before touching any registry so a rejected value cannot leave
    # alias entries behind that point at a name which was never registered
    if require_subclass:
        _validate_subclass(parent_class, value)

    if name is None:
        # default name
        name = value.__name__

    name = standardize_lookup_name(name)
    alias = standardize_alias_name(alias)
    register_alias(name=name, alias=alias, parent_class=parent_class)

    # NOTE(review): register_alias also records the name itself as an alias, so
    # a repeated name is normally rejected there with a KeyError before this
    # check runs
    if name in _REGISTRY[parent_class]:
        # name already exists - raise error if two different values are attempting
        # to share the same name
        registered_value = _REGISTRY[parent_class][name]
        if registered_value is not value:
            raise RuntimeError(
                f"Attempting to register name {name} as {value} "
                f"however {name} has already been registered as {registered_value}"
            )
    else:
        _REGISTRY[parent_class][name] = value
240
+
241
+
242
def get_from_registry(
    parent_class: Type, name: str, require_subclass: bool = False
) -> Any:
    """
    Retrieves a registered value by name or alias, or imports it from a python file
    when the name has the form "/path/to/file.py:attribute_name"

    :param parent_class: class that the name is registered under
    :param name: name to retrieve from the registry of the class
    :param require_subclass: require that value is a subclass of the class this
        method is called from
    :return: value retrieved from the registry for the given name
    :raises KeyError: if the name is not registered under parent_class
    :raises ValueError: if require_subclass is set and the retrieved value is not
        a subclass of parent_class
    """
    if ":" in name:
        # user specifying specific module to load and value to import. The raw
        # name is used here: standardizing it would lower-case the file path and
        # attribute name and replace their underscores. Split on the last colon
        # only, so paths that themselves contain ":" keep working
        module_path, value_name = name.rsplit(":", 1)
        retrieved_value = _import_and_get_value_from_module(module_path, value_name)
    else:
        lookup_name = standardize_lookup_name(name)
        # look up name in alias registry
        registered_name = _ALIAS_REGISTRY[parent_class].get(lookup_name)
        # look up name in registry
        retrieved_value = _REGISTRY[parent_class].get(registered_name)
        if retrieved_value is None:
            # report the name that was asked for, not the failed alias lookup
            raise KeyError(
                f"Unable to find {lookup_name} registered under type "
                f"{parent_class}.\n"
                f"Registered values for {parent_class}: "
                f"{registered_names(parent_class)}\n"
                f"Registered aliases for {parent_class}: "
                f"{registered_aliases(parent_class)}"
            )

    if require_subclass:
        _validate_subclass(parent_class, retrieved_value)

    return retrieved_value
277
+
278
+
279
def registered_names(parent_class: Type) -> List[str]:
    """
    List the names stored in the registry of a class

    :param parent_class: class to look up the registry of
    :return: all names registered to the given class
    """
    class_registry = _REGISTRY[parent_class]
    return [*class_registry]
285
+
286
+
287
def registered_aliases(parent_class: Type) -> List[str]:
    """
    List the aliases stored for a class, leaving out the registered names

    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class
    """
    # the alias registry holds the names as well; subtract them
    alias_keys = set(_ALIAS_REGISTRY[parent_class].keys())
    name_keys = set(registered_names(parent_class))
    return list(alias_keys - name_keys)
297
+
298
+
299
def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Updates the mapping from the alias(es) to the given name.
    The name itself is always recorded as an alias of itself, so when the alias
    is None only the name is added.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    :raises KeyError: if an alias is identical to the name, or if the name or an
        alias is already present in the alias registry of parent_class
    """
    if alias is None:
        alias_names = []
    elif isinstance(alias, list):
        # copy so that appending the name below does not mutate the caller's list
        alias_names = list(alias)
    else:
        alias_names = [alias]

    if name in alias_names:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    alias_names.append(name)

    class_aliases = _ALIAS_REGISTRY[parent_class]

    # check every alias before writing any of them so a conflict cannot leave
    # the registry partially updated
    for alias_name in alias_names:
        if alias_name in class_aliases:
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                f"{class_aliases[alias_name]}"
            )

    for alias_name in alias_names:
        class_aliases[alias_name] = name
332
+
333
+
334
def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
    """
    Imports the python file at the given path and returns one of its attributes

    :param module_path: path to the python file to execute as a module
    :param value_name: name of the attribute to read from the loaded module
    :return: the attribute of the loaded module named value_name
    :raises RuntimeError: if no module can be loaded from module_path, or if the
        module has no attribute named value_name (or that attribute is None)
    """
    # `import importlib` alone does not guarantee that the `util` submodule has
    # been loaded, so import it explicitly
    import importlib.util

    # load module
    spec = importlib.util.spec_from_file_location(
        f"plugin_module_for_{value_name}", module_path
    )
    if spec is None or spec.loader is None:
        # no loader could be determined for this path
        raise RuntimeError(f"Unable to load a python module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # get value from module; compare against None rather than truthiness so
    # falsy attributes are still returned
    value = getattr(module, value_name, None)

    if value is None:
        raise RuntimeError(
            f"Unable to find attribute {value_name} in module {module_path}"
        )
    return value
353
+
354
+
355
def _validate_subclass(parent_class: Type, child_class: Type):
    """
    Checks that a value is a class deriving from (or identical to) parent_class

    :param parent_class: class the value is registered for
    :param child_class: value to check
    :raises ValueError: if child_class is not a class, or is a class that does
        not subclass parent_class
    """
    # issubclass raises a TypeError for non-class values (e.g. functions);
    # rule those out first so every rejection is reported the same way
    is_subclass = isinstance(child_class, type) and issubclass(
        child_class, parent_class
    )
    if not is_subclass:
        raise ValueError(
            f"class {child_class} is not a subclass of the class it is "
            f"registered for: {parent_class}."
        )
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+ from .safetensors_load import *
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional
17
+
18
+ from compressed_tensors.base import SPARSITY_CONFIG_NAME
19
+ from compressed_tensors.compressors import ModelCompressor
20
+ from compressed_tensors.config import CompressionConfig
21
+ from transformers import AutoConfig
22
+
23
+
24
+ __all__ = ["infer_compressor_from_model_config"]
25
+
26
+
27
def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional[ModelCompressor]:
    """
    Load a model config and, if it carries a sparsity config, build the
    ModelCompressor registered for that config's format

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config,
        otherwise None
    """
    model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    raw_sparsity_config = getattr(model_config, SPARSITY_CONFIG_NAME, None)
    if raw_sparsity_config is None:
        return None

    # the same format string selects both the config class and the compressor
    compression_format = raw_sparsity_config.get("format")
    loaded_config = CompressionConfig.load_from_registry(
        compression_format, **raw_sparsity_config
    )
    return ModelCompressor.load_from_registry(
        compression_format, config=loaded_config
    )