compressed-tensors-0.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed-tensors-0.3.0/PKG-INFO +13 -0
  2. compressed-tensors-0.3.0/README.md +82 -0
  3. compressed-tensors-0.3.0/pyproject.toml +3 -0
  4. compressed-tensors-0.3.0/setup.cfg +20 -0
  5. compressed-tensors-0.3.0/setup.py +40 -0
  6. compressed-tensors-0.3.0/src/compressed_tensors/__init__.py +21 -0
  7. compressed-tensors-0.3.0/src/compressed_tensors/base.py +16 -0
  8. compressed-tensors-0.3.0/src/compressed_tensors/compressors/__init__.py +25 -0
  9. compressed-tensors-0.3.0/src/compressed_tensors/compressors/base.py +79 -0
  10. compressed-tensors-0.3.0/src/compressed_tensors/compressors/dense.py +34 -0
  11. compressed-tensors-0.3.0/src/compressed_tensors/compressors/helpers.py +161 -0
  12. compressed-tensors-0.3.0/src/compressed_tensors/compressors/sparse_bitmask.py +238 -0
  13. compressed-tensors-0.3.0/src/compressed_tensors/config/__init__.py +18 -0
  14. compressed-tensors-0.3.0/src/compressed_tensors/config/base.py +42 -0
  15. compressed-tensors-0.3.0/src/compressed_tensors/config/dense.py +36 -0
  16. compressed-tensors-0.3.0/src/compressed_tensors/config/sparse_bitmask.py +36 -0
  17. compressed-tensors-0.3.0/src/compressed_tensors/quantization/__init__.py +21 -0
  18. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/__init__.py +22 -0
  19. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/apply.py +173 -0
  20. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  21. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/forward.py +136 -0
  22. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/frozen.py +46 -0
  23. compressed-tensors-0.3.0/src/compressed_tensors/quantization/lifecycle/initialize.py +96 -0
  24. compressed-tensors-0.3.0/src/compressed_tensors/quantization/observers/__init__.py +21 -0
  25. compressed-tensors-0.3.0/src/compressed_tensors/quantization/observers/base.py +69 -0
  26. compressed-tensors-0.3.0/src/compressed_tensors/quantization/observers/helpers.py +53 -0
  27. compressed-tensors-0.3.0/src/compressed_tensors/quantization/observers/memoryless.py +48 -0
  28. compressed-tensors-0.3.0/src/compressed_tensors/quantization/observers/min_max.py +65 -0
  29. compressed-tensors-0.3.0/src/compressed_tensors/quantization/quant_args.py +85 -0
  30. compressed-tensors-0.3.0/src/compressed_tensors/quantization/quant_config.py +171 -0
  31. compressed-tensors-0.3.0/src/compressed_tensors/quantization/quant_scheme.py +39 -0
  32. compressed-tensors-0.3.0/src/compressed_tensors/quantization/utils/__init__.py +16 -0
  33. compressed-tensors-0.3.0/src/compressed_tensors/quantization/utils/helpers.py +115 -0
  34. compressed-tensors-0.3.0/src/compressed_tensors/registry/__init__.py +17 -0
  35. compressed-tensors-0.3.0/src/compressed_tensors/registry/registry.py +360 -0
  36. compressed-tensors-0.3.0/src/compressed_tensors/utils/__init__.py +16 -0
  37. compressed-tensors-0.3.0/src/compressed_tensors/utils/safetensors_load.py +237 -0
  38. compressed-tensors-0.3.0/src/compressed_tensors.egg-info/PKG-INFO +13 -0
  39. compressed-tensors-0.3.0/src/compressed_tensors.egg-info/SOURCES.txt +43 -0
  40. compressed-tensors-0.3.0/src/compressed_tensors.egg-info/dependency_links.txt +1 -0
  41. compressed-tensors-0.3.0/src/compressed_tensors.egg-info/requires.txt +11 -0
  42. compressed-tensors-0.3.0/src/compressed_tensors.egg-info/top_level.txt +1 -0
  43. compressed-tensors-0.3.0/tests/test_bitmask.py +120 -0
  44. compressed-tensors-0.3.0/tests/test_registry.py +53 -0
compressed-tensors-0.3.0/PKG-INFO
@@ -0,0 +1,13 @@
+ Metadata-Version: 2.1
+ Name: compressed-tensors
+ Version: 0.3.0
+ Summary: Library for utilization of compressed safetensors of neural network models
+ Home-page: UNKNOWN
+ Author: Neuralmagic, Inc.
+ Author-email: support@neuralmagic.com
+ License: UNKNOWN
+ Platform: UNKNOWN
+ Provides-Extra: dev
+
+ UNKNOWN
+
compressed-tensors-0.3.0/README.md
@@ -0,0 +1,82 @@
+ # compressed-tensors
+
+ This repository extends the [safetensors](https://github.com/huggingface/safetensors) format to efficiently store sparse and/or quantized tensors on disk. The `compressed-tensors` format supports multiple compression types to minimize disk space and facilitate tensor manipulation.
+
+ ## Motivation
+
+ ### Reduce disk space by saving sparse tensors in a compressed format
+
+ The compressed format stores the data much more efficiently by taking advantage of two properties of tensors:
+
+ - Sparse tensors -> a large number of entries are equal to zero.
+ - Quantized tensors -> values are stored in a low-precision representation.
+
+ ### Introduce an elegant interface to save/load compressed tensors
+
+ The library provides the ability to compress/decompress tensors. The properties of the tensors are defined by human-readable configs, allowing users to understand the compression format at a glance.
+
+ ## Installation
+
+ ### Pip
+
+ ```bash
+ pip install compressed-tensors
+ ```
+
+ ### From source
+
+ ```bash
+ git clone https://github.com/neuralmagic/compressed-tensors
+ cd compressed-tensors
+ pip install -e .
+ ```
+
+ ## Getting started
+
+ ### Saving/Loading Compressed Tensors (Bitmask Compression)
+
+ The function `save_compressed` uses the `compression_format` argument to apply compression to tensors.
+ The function `load_compressed` reverses the process: it converts the compressed weights on disk back to decompressed weights in device memory.
+
+ ```python
+ from compressed_tensors import save_compressed, load_compressed, BitmaskConfig
+ from torch import Tensor
+ from typing import Dict
+
+ # the example BitmaskConfig efficiently compresses
+ # tensors with a large number of zero entries
+ compression_config = BitmaskConfig()
+
+ tensors: Dict[str, Tensor] = {"tensor_1": Tensor(
+     [[0.0, 0.0, 0.0],
+      [1.0, 1.0, 1.0]]
+ )}
+ # compress tensors using the BitmaskConfig compression format (save them efficiently on disk)
+ save_compressed(tensors, "model.safetensors", compression_format=compression_config.format)
+
+ # decompress tensors (load_compressed returns a generator for memory efficiency)
+ decompressed_tensors = {}
+ for tensor_name, tensor in load_compressed("model.safetensors", compression_config=compression_config):
+     decompressed_tensors[tensor_name] = tensor
+ ```
+
+ ### Saving/Loading Compressed Models (Bitmask Compression)
+
+ We can apply bitmask compression to a whole model. For a more detailed example, see the `example` directory.
+ ```python
+ from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig
+ from transformers import AutoModelForCausalLM
+
+ model_name = "neuralmagic/llama2.c-stories110M-pruned50"
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ original_state_dict = model.state_dict()
+
+ compression_config = BitmaskConfig()
+
+ # save compressed model weights
+ save_compressed_model(model, "compressed_model.safetensors", compression_format=compression_config.format)
+
+ # load compressed model weights (`dict` turns the generator into a dictionary)
+ state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
+ ```
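
For reference, the bitmask round trip above can be checked end to end with a few extra lines. This is a minimal sketch that assumes only `save_compressed`, `load_compressed`, and `BitmaskConfig` as exported by the package; the tensor values and file name are illustrative.

```python
# Round-trip sanity check for the bitmask format (illustrative values/paths).
import torch
from compressed_tensors import BitmaskConfig, load_compressed, save_compressed

compression_config = BitmaskConfig()
original = {"layer.weight": torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])}

# compress and write to disk, then stream the decompressed tensors back
save_compressed(original, "model.safetensors", compression_format=compression_config.format)
restored = dict(load_compressed("model.safetensors", compression_config=compression_config))

assert torch.equal(restored["layer.weight"], original["layer.weight"])
```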
compressed-tensors-0.3.0/pyproject.toml
@@ -0,0 +1,3 @@
+ [tool.black]
+ line-length = 88
+ target-version = ['py36']
compressed-tensors-0.3.0/setup.cfg
@@ -0,0 +1,20 @@
+ [isort]
+ profile = black
+ default_section = FIRSTPARTY
+ ensure_newline_before_comments = True
+ force_grid_wrap = 0
+ include_trailing_comma = True
+ sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
+ line_length = 88
+ lines_after_imports = 2
+ multi_line_output = 3
+ use_parentheses = True
+
+ [flake8]
+ ignore = E203, E251, E701, W503
+ max-line-length = 88
+
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
compressed-tensors-0.3.0/setup.py
@@ -0,0 +1,40 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from setuptools import setup, find_packages
+ from typing import List, Dict
+
+ def _setup_packages() -> List:
+     return find_packages(
+         "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
+     )
+
+ def _setup_install_requires() -> List:
+     return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]
+
+ def _setup_extras() -> Dict:
+     return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]}
+
+ setup(
+     name="compressed-tensors",
+     version="0.3.0",
+     author="Neuralmagic, Inc.",
+     author_email="support@neuralmagic.com",
+     description="Library for utilization of compressed safetensors of neural network models",
+     extras_require=_setup_extras(),
+     install_requires=_setup_install_requires(),
+     package_dir={"": "src"},
+     packages=_setup_packages(),
+ )
compressed-tensors-0.3.0/src/compressed_tensors/__init__.py
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .base import *
+
+ # flake8: noqa
+ from .compressors import *
+ from .config import *
+ from .quantization import QuantizationConfig, QuantizationStatus
+ from .utils import *
compressed-tensors-0.3.0/src/compressed_tensors/base.py
@@ -0,0 +1,16 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ SPARSITY_CONFIG_NAME = "sparsity_config"
+ QUANTIZATION_CONFIG_NAME = "sparseml_quantization_config"
compressed-tensors-0.3.0/src/compressed_tensors/compressors/__init__.py
@@ -0,0 +1,25 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+
+ from .base import ModelCompressor
+ from .dense import DenseCompressor
+ from .helpers import (
+     infer_compressor_from_model_config,
+     load_compressed,
+     save_compressed,
+     save_compressed_model,
+ )
+ from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
+ from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import operator
16
+ from typing import Dict, Generator, Optional, Tuple
17
+
18
+ from compressed_tensors.base import SPARSITY_CONFIG_NAME
19
+ from compressed_tensors.config import CompressionConfig
20
+ from compressed_tensors.registry import RegistryMixin
21
+ from compressed_tensors.utils import get_safetensors_folder
22
+ from torch import Tensor
23
+ from torch.nn import Module, Parameter
24
+ from tqdm import tqdm
25
+
26
+
27
+ __all__ = ["ModelCompressor"]
28
+
29
+
30
+ class ModelCompressor(RegistryMixin):
31
+ """
32
+ Base class representing a model compression algorithm.
33
+
34
+ :param config: config specifying compression parameters
35
+ """
36
+
37
+ def __init__(self, config: Optional[CompressionConfig] = None):
38
+ self.config = config
39
+
40
+ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
41
+ """
42
+ Compresses a dense state dict
43
+
44
+ :param model_state: state dict of uncompressed model
45
+ :return: compressed state dict
46
+ """
47
+ raise NotImplementedError()
48
+
49
+ def decompress(
50
+ self, path_to_model_or_tensors: str
51
+ ) -> Generator[Tuple[str, Tensor], None, None]:
52
+ """
53
+ Reads a compressed state dict located at path_to_model_or_tensors
54
+ and returns a generator for sequentially decompressing back to a
55
+ dense state dict
56
+
57
+ :param model_path: path to compressed safetensors model (directory with
58
+ one or more safetensors files) or compressed tensors file
59
+ :return: compressed state dict
60
+ """
61
+ raise NotImplementedError()
62
+
63
+ def overwrite_weights(self, model_path: str, model: Module):
64
+ """
65
+ Overwrites the weights in model with weights decompressed from model_path
66
+
67
+ :param model_path: path to compressed weights
68
+ :param model: pytorch model to load decompressed weights into
69
+ """
70
+ model_path = get_safetensors_folder(model_path)
71
+ dense_gen = self.decompress(model_path)
72
+ for name, data in tqdm(dense_gen, desc="Decompressing model"):
73
+ # loading the decompressed weights into the model
74
+ model_device = operator.attrgetter(name)(model).device
75
+ data_new = Parameter(data.to(model_device))
76
+ data_old = operator.attrgetter(name)(model)
77
+ data_old.data = data_new.data
78
+
79
+ setattr(model, SPARSITY_CONFIG_NAME, self.config)
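
The `compress`/`decompress` hooks above, together with the `RegistryMixin` machinery, are how new formats plug in. A minimal hypothetical subclass is sketched below; the `"identity-example"` name and its pass-through behavior are invented for illustration and are not part of this release (the formats registered in 0.3.0 are the dense and sparse-bitmask compressors that follow).

```python
# Hypothetical compressor sketch; "identity-example" is not a real format in 0.3.0.
from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
from torch import Tensor


@ModelCompressor.register(name="identity-example")
class IdentityCompressor(ModelCompressor):
    """Stores tensors unchanged; exists only to show the two required hooks."""

    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # a real compressor would shrink the tensors here
        return model_state

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu"
    ) -> Generator[Tuple[str, Tensor], None, None]:
        # walk the safetensors file(s) and yield dense tensors one at a time
        for name, file in get_weight_mappings(path_to_model_or_tensors).items():
            with safe_open(file, framework="pt", device=device) as f:
                yield name, f.get_tensor(name)
```

Once registered, `ModelCompressor.load_from_registry("identity-example")` would resolve an instance, which is the same mechanism `save_compressed` and `load_compressed` use to pick a compressor from a format string.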
compressed-tensors-0.3.0/src/compressed_tensors/compressors/dense.py
@@ -0,0 +1,34 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, Generator, Tuple
+
+ from compressed_tensors.compressors import ModelCompressor
+ from compressed_tensors.config import CompressionFormat
+ from torch import Tensor
+
+
+ @ModelCompressor.register(name=CompressionFormat.dense_sparsity.value)
+ class DenseCompressor(ModelCompressor):
+     """
+     Identity compressor for dense models, returns the original state_dict
+     """
+
+     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+         return model_state
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         return iter([])
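
As the helpers in the following file show, `dense_sparsity` is also the default when `save_compressed` is called without a format, and `load_compressed` without a config reads the file as plain safetensors. A minimal sketch of that no-op round trip (file name illustrative):

```python
# Dense (no-op) path: no compression_format on save, no config on load.
import torch
from compressed_tensors import load_compressed, save_compressed

tensors = {"w": torch.ones(2, 3)}
save_compressed(tensors, "dense.safetensors")           # falls back to dense_sparsity
restored = dict(load_compressed("dense.safetensors"))   # read back as ordinary safetensors
assert torch.equal(restored["w"], tensors["w"])
```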
compressed-tensors-0.3.0/src/compressed_tensors/compressors/helpers.py
@@ -0,0 +1,161 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+ from typing import Dict, Generator, Optional, Tuple, Union
+
+ import torch
+ from compressed_tensors.base import SPARSITY_CONFIG_NAME
+ from compressed_tensors.compressors import ModelCompressor
+ from compressed_tensors.config import CompressionConfig, CompressionFormat
+ from compressed_tensors.utils.safetensors_load import get_weight_mappings
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from torch import Tensor
+ from transformers import AutoConfig
+
+
+ __all__ = [
+     "infer_compressor_from_model_config",
+     "load_compressed",
+     "save_compressed",
+     "save_compressed_model",
+ ]
+
+
+ def infer_compressor_from_model_config(
+     pretrained_model_name_or_path: str,
+ ) -> Optional[ModelCompressor]:
+     """
+     Given a path to a model config, extract a sparsity config if it exists and return
+     the associated ModelCompressor
+
+     :param pretrained_model_name_or_path: path to model config on disk or HF hub
+     :return: matching compressor if config contains a sparsity config
+     """
+     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+     sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
+     if sparsity_config is None:
+         return None
+
+     format = sparsity_config.get("format")
+     sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
+     compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
+     return compressor
+
+
+ def save_compressed(
+     tensors: Dict[str, Tensor],
+     save_path: Union[str, Path],
+     compression_format: Optional[CompressionFormat] = None,
+ ):
+     """
+     Save compressed tensors to disk. If tensors are not compressed,
+     save them as is.
+
+     :param tensors: dictionary of tensors to compress
+     :param save_path: path to save compressed tensors
+     :param compression_format: compression format used for the tensors
+     :return: compression config, if tensors were compressed - None otherwise
+     """
+     if tensors is None or len(tensors) == 0:
+         raise ValueError("No tensors or empty tensors provided to compress")
+
+     # if no compression_format specified, default to `dense_sparsity`
+     compression_format = compression_format or CompressionFormat.dense_sparsity.value
+
+     if not (
+         compression_format in ModelCompressor.registered_names()
+         or compression_format in ModelCompressor.registered_aliases()
+     ):
+         raise ValueError(
+             f"Unknown compression format: {compression_format}. "
+             f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}"  # noqa E501
+         )
+
+     # compress
+     compressor = ModelCompressor.load_from_registry(compression_format)
+     # save compressed tensors
+     compressed_tensors = compressor.compress(tensors)
+     save_file(compressed_tensors, save_path)
+
+
+ def load_compressed(
+     compressed_tensors: Union[str, Path],
+     compression_config: CompressionConfig = None,
+     device: Optional[str] = "cpu",
+ ) -> Generator[Tuple[str, Tensor], None, None]:
+     """
+     Load compressed tensors from disk.
+     If tensors are not compressed, load them as is.
+
+     :param compressed_tensors: path to compressed tensors.
+         This can be a path to a file or a directory containing
+         one or multiple safetensor files (if multiple - in the format
+         assumed by huggingface)
+     :param compression_config: compression config to use for decompressing tensors.
+     :param device: device to move tensors to. If None, tensors are loaded on CPU.
+     :return: a generator that yields the name and tensor of each decompressed tensor
+     """
+     if compressed_tensors is None or not Path(compressed_tensors).exists():
+         raise ValueError("No compressed tensors provided to load")
+
+     if (
+         compression_config is None
+         or compression_config.format == CompressionFormat.dense_sparsity.value
+     ):
+         # if no compression_config specified, or `dense_sparsity` format specified,
+         # assume tensors are not compressed on disk
+         weight_mappings = get_weight_mappings(compressed_tensors)
+         for weight_name, file_with_weight_name in weight_mappings.items():
+             with safe_open(file_with_weight_name, framework="pt", device=device) as f:
+                 weight = f.get_tensor(weight_name)
+             yield weight_name, weight
+     else:
+         # decompress tensors
+         compression_format = compression_config.format
+         compressor = ModelCompressor.load_from_registry(
+             compression_format, config=compression_config
+         )
+         yield from compressor.decompress(compressed_tensors, device=device)
+
+
+ def save_compressed_model(
+     model: torch.nn.Module,
+     filename: str,
+     compression_format: Optional[CompressionFormat] = None,
+     force_contiguous: bool = True,
+ ):
+     """
+     Wrapper around safetensors `save_model` helper function, which allows for
+     saving a compressed model to disk.
+
+     Note: The model is assumed to have a
+     state_dict with unique entries
+
+     :param model: model to save on disk
+     :param filename: filename location to save the file
+     :param compression_format: compression format used for the model
+     :param force_contiguous: forcing the state_dict to be saved as contiguous tensors
+     """
+     state_dict = model.state_dict()
+     if force_contiguous:
+         state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+     try:
+         save_compressed(state_dict, filename, compression_format=compression_format)
+     except ValueError as e:
+         msg = str(e)
+         msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats."  # noqa E501
+         raise ValueError(msg)
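
Finally, `infer_compressor_from_model_config` ties the sparsity config stored in a checkpoint's `config.json` back to a registered compressor. A sketch of that flow, assuming a local checkpoint directory whose config actually contains a `sparsity_config` entry (the path below is a placeholder):

```python
# Config-driven decompression sketch; the checkpoint path is a placeholder.
from compressed_tensors.compressors import infer_compressor_from_model_config

model_path = "./compressed-checkpoint"  # local dir with config.json + safetensors files
compressor = infer_compressor_from_model_config(model_path)

if compressor is not None:
    # stream (name, dense tensor) pairs back out of the compressed files
    for name, tensor in compressor.decompress(model_path, device="cpu"):
        print(name, tuple(tensor.shape))
```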