compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
@@ -12,21 +12,24 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
-
16
15
  from typing import Optional
17
16
 
18
- from compressed_tensors.base import SPARSITY_CONFIG_NAME
19
- from compressed_tensors.compressors import ModelCompressor
20
- from compressed_tensors.config import CompressionConfig
17
+ import torch
21
18
  from transformers import AutoConfig
22
19
 
23
20
 
24
- __all__ = ["infer_compressor_from_model_config"]
21
+ __all__ = [
22
+ "infer_compressor_from_model_config",
23
+ "fix_fsdp_module_name",
24
+ "tensor_follows_mask_structure",
25
+ ]
26
+
27
+ FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
25
28
 
26
29
 
27
30
  def infer_compressor_from_model_config(
28
31
  pretrained_model_name_or_path: str,
29
- ) -> Optional[ModelCompressor]:
32
+ ) -> Optional["ModelCompressor"]: # noqa: F821
30
33
  """
31
34
  Given a path to a model config, extract a sparsity config if it exists and return
32
35
  the associated ModelCompressor
@@ -34,8 +37,11 @@ def infer_compressor_from_model_config(
34
37
  :param pretrained_model_name_or_path: path to model config on disk or HF hub
35
38
  :return: matching compressor if config contains a sparsity config
36
39
  """
40
+ from compressed_tensors.compressors import ModelCompressor
41
+ from compressed_tensors.config import CompressionConfig
42
+
37
43
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
38
- sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
44
+ sparsity_config = ModelCompressor.parse_sparsity_config(config)
39
45
  if sparsity_config is None:
40
46
  return None
41
47
 
@@ -43,3 +49,44 @@ def infer_compressor_from_model_config(
43
49
  sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
44
50
  compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
45
51
  return compressor
52
+
53
+
54
+ # TODO: There is already the same function in
55
+ # SparseML, should be moved to a shared location
56
+ # in the future
57
def fix_fsdp_module_name(name: str) -> str:
    """
    Strip FSDP wrapper segments out of a dotted module name.

    Two placements of FSDP_WRAPPER_NAME are handled: the wrapper followed by
    a dot (at the start or in the middle of the name) and the wrapper
    preceded by a dot (at the end of the name).

    :param name: module name that may contain FSDP wrapper segments
    :return: the name with the wrapper segments removed
    """
    # First drop every "<wrapper>." occurrence, then whatever ".<wrapper>"
    # occurrences are still left (i.e. a trailing wrapper segment)
    without_inner = name.replace(f"{FSDP_WRAPPER_NAME}.", "")
    return without_inner.replace(f".{FSDP_WRAPPER_NAME}", "")
68
+
69
+
70
def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
    """
    Check that a tensor satisfies an "n:m" structured-sparsity pattern, i.e.
    every consecutive chunk of m elements contains at least n zeros.

    Some weights can incidentally be zero, so chunks holding more than n
    zeros are accepted as well.

    :param tensor: tensor to check; it is flattened into chunks of size m, so
        its number of elements must be divisible by m
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if any chunk of size m holds fewer than n zeros
    """
    n, m = tuple(map(int, mask.split(":")))

    # reshape (instead of view) so that non-contiguous tensors, such as
    # transposed weights, can be split into chunks too
    chunks = tensor.reshape(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (chunks == 0).sum(dim=1)

    # ">=" rather than "==": extra, incidental zeros are still valid
    if not torch.all(zero_counts >= n).item():
        raise ValueError(
            f"Tensor does not follow the {mask} mask structure: at least {n} "
            f"zeros are required in every chunk of {m} elements"
        )

    return True
@@ -0,0 +1,104 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch.nn import Module
17
+
18
+
19
+ __all__ = [
20
+ "is_module_offloaded",
21
+ "get_execution_device",
22
+ "get_offloaded_device",
23
+ "update_prefix_dict",
24
+ "update_parameter_data",
25
+ ]
26
+
27
+
28
def is_module_offloaded(module: Module) -> bool:
    """
    Check whether a module carries a hook (``_hf_hook``) that offloads it.

    :param module: layer to check
    :return: True if the module has a ``_hf_hook`` attribute whose ``offload``
        flag is truthy, False otherwise (including when the module has no
        hook or the hook does not define ``offload``)
    """
    hook = getattr(module, "_hf_hook", None)
    # Not every hook type is guaranteed to expose `offload`; treat a missing
    # flag as "not offloaded" instead of raising AttributeError
    return bool(getattr(hook, "offload", False))
34
+
35
+
36
def get_execution_device(module: Module) -> torch.device:
    """
    Look up the device a module runs its forward pass on.

    :param module: layer to check
    :return: the hook's ``execution_device`` when the module is offloaded,
        otherwise the device of the module's first parameter
    """
    if not is_module_offloaded(module):
        first_param = next(module.parameters())
        return first_param.device

    return module._hf_hook.execution_device
44
+
45
+
46
def get_offloaded_device(module: Module) -> torch.device:
    """
    Look up the device a module's weights are kept on while offloaded.

    :param module: layer to check
    :return: for an offloaded module, the device of the first tensor listed
        in the hook's weights map; otherwise the device of the module's
        first parameter
    """
    if not is_module_offloaded(module):
        return next(module.parameters()).device

    weights_map = module._hf_hook.weights_map
    # keys() of the weights map are used to index the underlying dataset
    first_key = list(weights_map.keys())[0]
    return weights_map.dataset[first_key].device
56
+
57
+
58
def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
    """
    Replace the entry for parameter ``key`` in the offloaded state dict of a
    module with ``data``.

    Parameter updates on offloaded modules do not persist automatically
    between loads, hence the explicit write. Only the offloaded state dict is
    touched; the currently loaded module state is left as is.

    :param module: layer containing the parameter to update
    :param key: name of parameter to update
    :param data: tensor to store for the parameter in the offloaded state dict
    :raises ValueError: if the module is not offloaded
    """
    if is_module_offloaded(module):
        weights_map = module._hf_hook.weights_map
        # entries of the underlying dataset are stored under prefix + name
        full_key = f"{weights_map.prefix}{key}"
        weights_map.dataset[full_key] = data
        return

    raise ValueError("Prefix dict is only applicable to offloaded modules")
73
+
74
+
75
def update_parameter_data(
    module: Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Update the value of the parameter named ``param_name`` on a module.

    Both the currently loaded module state and, if the module is offloaded,
    the offloaded state dict are updated. The latter is necessary because
    parameter updates for offloaded modules do not persist automatically
    between loads.

    :param module: layer containing the parameter to update
    :param new_param_data: tensor to update parameter with; it is cast to the
        dtype of the existing parameter
    :param param_name: attribute name of the parameter on ``module``
    :raises AttributeError: if ``module`` has no parameter called
        ``param_name``
    """
    device = next(module.parameters()).device

    parameter = getattr(module, param_name, None)
    if parameter is None:
        # fail with an explicit message rather than an opaque
        # "'NoneType' object has no attribute 'dtype'"
        raise AttributeError(
            f"{module.__class__.__name__} has no parameter named {param_name!r}"
        )

    dtype = parameter.dtype
    parameter.data = new_param_data.to(device).to(dtype)

    if is_module_offloaded(module):
        # mirror the update into the offloaded copy, on its own device
        offload_device = get_offloaded_device(module)
        weights_map = module._hf_hook.weights_map
        weights_map.dataset[f"{weights_map.prefix}{param_name}"] = new_param_data.to(
            offload_device
        ).to(dtype)
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy
17
+ import torch
18
+
19
+
20
+ __all__ = ["get_permutations_24"]
21
+
22
+
23
+ # Precompute permutations for Marlin24 weight and scale shuffling
24
+ # Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py # noqa: E501
25
+ #
26
+ # Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
27
+ # data so that it is compatible with the tensor-core format that is described here:
28
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
29
+ #
30
+ # As a result of this reordering, the vector loads inside the kernel will get the data
31
+ # as it is needed for tensor-core (without the need to use ldmatrix instructions)
32
def get_permutations_24(num_bits):
    """
    Build the index permutations used for Marlin24 weight and scale shuffling.

    :param num_bits: bit width of the quantized weights, 4 or 8
    :return: tuple of (perm, scale_perm, scale_perm_single) where perm is a
        1-D torch tensor of 1024 indices and the two scale permutations are
        plain lists of 64 indices each
    :raises ValueError: if num_bits is neither 4 nor 8
    """
    perm_list = []
    for i in range(32):
        col, sub = divmod(i, 4)
        rows = (
            2 * sub,
            2 * sub + 1,
            2 * (sub + 4),
            2 * (sub + 4) + 1,
        )
        # part of the index that depends on the column only
        col_offset = (col // 2) * 256 + 8 * (col % 2)
        perm1 = [
            16 * row + col_offset + 4 * block for block in (0, 1) for row in rows
        ]
        for j in range(4):
            perm_list.extend(p + j for p in perm1)
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # apply the interleave pattern inside every group of len(interleave) indices
    interleaved = perm.reshape((-1, len(interleave)))[:, interleave]
    perm = torch.from_numpy(interleaved.ravel())

    scale_perm = [i * 8 + j for i in range(8) for j in (0, 4, 1, 5, 2, 6, 3, 7)]
    scale_perm_single = [8 * i + j for i in range(8) for j in range(8)]
    return perm, scale_perm, scale_perm_single
@@ -31,6 +31,7 @@ __all__ = [
31
31
  "get_weight_mappings",
32
32
  "get_nested_weight_mappings",
33
33
  "get_quantization_state_dict",
34
+ "is_quantization_param",
34
35
  ]
35
36
 
36
37
 
@@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
214
215
  weight_mappings = get_weight_mappings(model_path)
215
216
  state_dict = {}
216
217
  for weight_name, safe_path in weight_mappings.items():
217
- if not _is_quantization_weight(weight_name):
218
+ if not is_quantization_param(weight_name):
218
219
  continue
219
220
  with safe_open(safe_path, framework="pt", device="cpu") as f:
220
221
  state_dict[weight_name] = f.get_tensor(weight_name)
@@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
222
223
  return state_dict
223
224
 
224
225
 
225
- def _is_quantization_weight(name: str) -> bool:
226
+ def is_quantization_param(name: str) -> bool:
226
227
  """
227
228
  Checks is a parameter name is associated with a quantization parameter
228
229
 
@@ -0,0 +1,341 @@
1
+ #
2
+ # Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
3
+ # Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
4
+ #
5
+ # flake8: noqa
6
+ # isort: skip_file
7
+
8
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing,
17
+ # software distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import torch
23
+
24
+
25
+ __all__ = [
26
+ "sparse_semi_structured_from_dense_cutlass",
27
+ "sparse_semi_structured_to_dense_cutlass",
28
+ "mask_creator",
29
+ ]
30
+
31
+ # This is PyTorch implementation of main part of reorder_meta()
32
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
33
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
34
+ # GEMM decides upon layout of this matrix, and at the moment for the
35
+ # sparse GEMM executed on tensor cores, this is layout described by
36
+ # ColumnMajorInterleaved<2> data structure, in
37
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
38
+ # reordering of meta matrix into meta_reordered matrix calculated
39
+ # according to these segments of CUTLASS code is re-implemented here.
40
+ # Note that this calculation produces offsets for scattering metadata
41
+ # matrix elements into reordered metadata matrix elements (or,
42
+ # equivalently, for gathering reordered metadata matrix element back
43
+ # into metadata matrix elements).
44
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
45
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
46
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
47
+
48
+ # Reorder the rows, then swizzle the 2x2 blocks.
49
+ group_x = 64
50
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
51
+
52
+ dst_rows = (
53
+ dst_rows // group_x * group_x
54
+ + (dst_rows % 2) * 2
55
+ + (dst_rows % 8) // 4
56
+ + ((dst_rows % group_y) % 4) // 2 * 32
57
+ + ((dst_rows % group_x) // 8) * 4
58
+ )
59
+
60
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
61
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
62
+ dst_rows += topright - bottomleft
63
+ dst_cols -= topright - bottomleft
64
+
65
+ # Assumed that meta tensor is to be stored in CUTLASS
66
+ # InterleavedColumnMajor layout, and reverse engineered
67
+ # corresponding code to store values into this tensor.
68
+ interleave = 2
69
+ cols_maj = dst_cols // interleave
70
+ cols_min = dst_cols % interleave
71
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
72
+
73
+
74
+ # This function converts dense matrix into sparse semi-structured
75
+ # representation, producing "compressed" matrix, in the layout used by
76
+ # CUTLASS backend, and corresponding metadata matrix.
77
def sparse_semi_structured_from_dense_cutlass(dense):
    """
    Convert a dense 2-D matrix into a sparse semi-structured representation.

    :param dense: 2-D tensor of shape (m, k). Supported dtypes are int8, half,
        bfloat16, float and int32. m must be divisible by 16 for int8 input and
        by 32 otherwise; k must be divisible by 4 times the number of 4-bit
        groups that fit in one metadata element
    :return: tuple (sparse, meta_reordered), where sparse has shape (m, k // 2)
        and meta_reordered has shape (m, meta_ncols)
    :raises RuntimeError: if dense is not 2-dimensional, has an unsupported
        dtype, or its shape violates the divisibility requirements
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
        )

    m, k = dense.shape
    device = dense.device

    # Placeholder only: every branch below either reassigns meta_dtype or raises
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Number of 4-bit groups stored in one metadata element
    # (4 for int16 metadata, 8 for int32 metadata)
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
        )

    # Non-zero flags per group: groups of 4 values for every dtype except
    # float, which is processed in groups of 2 with each flag used twice
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are fewer than two True values, then False value or
    # values at some index or indices are considered True for the
    # encoding.  In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding.  The exact encodings used for these cases are as
    # follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits.  Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
    # case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    # idxs0 / idxs1: 2-bit indices of the two values picked from each group
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the selected values into the compressed (m, k // 2) matrix
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(
            -1, idxs0.unsqueeze(-1)
        )  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
            m, k // 2
        )  # type: ignore[possibly-undefined]

    # One 4-bit code per group: idxs0 in the low two bits, idxs1 in the high two
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    # Pack the 4-bit codes of each metadata element, lowest index in the
    # lowest bits
    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Reorder meta tensor elements.
    meta_reordered = meta.new_empty(
        (m * meta_ncols,)
    )  # type: ignore[possibly-undefined]
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
209
+
210
+
211
+ # This function performs reverse of the function above - it
212
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
213
+ # in the layout used by CUTLASS backend, and accompanying metadata
214
+ # matrix.
215
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
216
+ if sparse.dim() != 2:
217
+ raise RuntimeError(
218
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
219
+ )
220
+
221
+ m, k = sparse.shape
222
+ device = sparse.device
223
+
224
+ if meta_reordered.dim() != 2:
225
+ raise RuntimeError(
226
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
227
+ )
228
+ if meta_reordered.device != device:
229
+ raise RuntimeError(
230
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
231
+ )
232
+
233
+ meta_dtype = meta_reordered.dtype
234
+ if meta_dtype not in (torch.int16, torch.int32):
235
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
236
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
237
+
238
+ ksparse = 4 if sparse.dtype != torch.float else 2
239
+
240
+ meta_nrows, meta_ncols = meta_reordered.shape
241
+ if meta_nrows != m:
242
+ raise RuntimeError(
243
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
244
+ )
245
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
246
+ raise RuntimeError(
247
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
248
+ "expected according to the number of columns of meta matrix"
249
+ )
250
+
251
+ # Undo meta tensor elements reordering.
252
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
253
+ m, meta_ncols, meta_dtype, device
254
+ )
255
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
256
+
257
+ # Unpack sparse tensor back to original dense tensor, using
258
+ # information provided by meta tensor. Note that torch.float
259
+ # datatype is handled pretty much the same as
260
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
261
+ # value is encoded as if underlying 8 bytes contain four
262
+ # torch.half/torch.bfloat16 values, where either first two or last
263
+ # two are zeros.
264
+ meta_2 = torch.empty(
265
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
266
+ dtype=meta_dtype,
267
+ device=device,
268
+ )
269
+ if quadbits_per_meta_elem == 4:
270
+ meta_2[:, :, 0] = meta & 0b11
271
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
272
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
273
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
274
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
275
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
276
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
277
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
278
+ elif quadbits_per_meta_elem == 8:
279
+ meta_2[:, :, 0] = meta & 0b11
280
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
281
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
282
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
283
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
284
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
285
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
286
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
287
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
288
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
289
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
290
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
291
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
292
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
293
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
294
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
295
+
296
+ dense_offsets = meta_2.view(-1) + (
297
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
298
+ ).view(-1, 1).repeat(1, 2).view(-1)
299
+
300
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
301
+ if sparse.dtype != torch.float:
302
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
303
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
304
+ else:
305
+ dense.view(torch.half).scatter_(
306
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
307
+ )
308
+
309
+ return dense.view(m, 2 * k)
310
+
311
+
312
def mask_creator(tensor, N: int = 2, M: int = 4):
    """
    Create an N:M sparsity mask for the given tensor.

    The tensor is split into consecutive groups of M weights; in every group
    the N weights with the largest magnitude are kept and the remaining
    M - N weights are pruned.

    :param tensor: tensor to create the mask for; its number of elements must
        be divisible by M
    :param N: The number of weights in a group to keep
    :param M: The size of a weight group
    :return: tensor with the shape of ``tensor``, on the same device, holding
        1 for kept weights and 0 for pruned weights
    :raises ValueError: if the tensor cannot be evenly divided into groups of
        M elements
    """
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
        )

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    magnitudes = tensor.detach().abs().reshape(num_groups, M)
    # indices of the M - N smallest magnitudes of every group
    pruned_index = torch.argsort(magnitudes, dim=1)[:, : int(M - N)]

    keep_all = torch.ones(magnitudes.shape, device=magnitudes.device)
    return keep_all.scatter_(dim=1, index=pruned_index, value=0).reshape(tensor.shape)
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Functionality for storing and setting the version info for compressed-tensors
17
+ """
18
+
19
+
20
# Base "major.minor.patch" version string of the package
version_base = "0.5.0"
# True: the generated version is exactly version_base (release version);
# False: a YYYYMMDD build date is appended as a fourth component
is_release = True  # change to True to set the generated version as a release version
22
+
23
+
24
+ def _generate_version(
25
+ is_release: bool,
26
+ version_base: str,
27
+ ):
28
+ from datetime import date
29
+
30
+ if is_release:
31
+ return version_base
32
+ else:
33
+ return f"{version_base}.{date.today().strftime('%Y%m%d')}"
34
+
35
+
36
+ __all__ = [
37
+ "__version__",
38
+ "version_base",
39
+ "is_release",
40
+ "version",
41
+ "version_major",
42
+ "version_minor",
43
+ "version_patch",
44
+ "version_build",
45
+ "version_major_minor",
46
+ ]
47
# Full version string: version_base for releases, version_base plus a build
# date otherwise
__version__ = _generate_version(is_release, version_base)

version = __version__
# Split the version into its components (all strings). A 3-part version is
# padded with None so that version_build always exists.
version_major, version_minor, version_patch, version_build = version.split(".") + (
    [None] if len(version.split(".")) < 4 else []
)  # handle conditional for version being 3 parts or 4 (4 containing build date)
# "major.minor" shorthand
version_major_minor = f"{version_major}.{version_minor}"