compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between these two publicly released package versions as they appear in their respective public registries.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/utils/helpers.py
@@ -12,21 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Optional
 
-
-from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import CompressionConfig
+import torch
 from transformers import AutoConfig
 
 
-__all__ = [
+__all__ = [
+    "infer_compressor_from_model_config",
+    "fix_fsdp_module_name",
+    "tensor_follows_mask_structure",
+]
+
+FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
 
 
 def infer_compressor_from_model_config(
     pretrained_model_name_or_path: str,
-) -> Optional[ModelCompressor]:
+) -> Optional["ModelCompressor"]:  # noqa: F821
     """
     Given a path to a model config, extract a sparsity config if it exists and return
     the associated ModelCompressor
@@ -34,8 +37,11 @@ def infer_compressor_from_model_config(
     :param pretrained_model_name_or_path: path to model config on disk or HF hub
     :return: matching compressor if config contains a sparsity config
     """
+    from compressed_tensors.compressors import ModelCompressor
+    from compressed_tensors.config import CompressionConfig
+
     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-    sparsity_config =
+    sparsity_config = ModelCompressor.parse_sparsity_config(config)
     if sparsity_config is None:
         return None
 
@@ -43,3 +49,44 @@ def infer_compressor_from_model_config(
     sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
     compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
     return compressor
+
+
+# TODO: There is already the same function in
+# SparseML, should be moved to a shared location
+# in the future
+def fix_fsdp_module_name(name: str) -> str:
+    """
+    Remove FSDP wrapper prefixes from a module name
+    Accounts for scenario where FSDP_WRAPPER_NAME is
+    at the end of the name, as well as in the middle.
+    :param name: name to strip
+    :return: stripped name
+    """
+    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
+        "." + FSDP_WRAPPER_NAME, ""
+    )
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        atleast n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check if the number of zeros in each chunk atleast n
+    # Greater than sign is needed as some weights can incidentally
+    # be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError()
+
+    return True
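The new `tensor_follows_mask_structure` helper validates that a pruned weight satisfies an n:m sparsity pattern. A minimal sketch of how it behaves, assuming the helpers are importable from `compressed_tensors.utils.helpers` as diffed above (note that a violation raises `ValueError` rather than returning `False`, despite the docstring):

```python
import torch

from compressed_tensors.utils.helpers import tensor_follows_mask_structure

# Every chunk of 4 consecutive values contains at least 2 zeros -> returns True
structured = torch.tensor([0.0, 0.0, 1.5, -0.2, 0.7, 0.0, 0.0, 0.3])
assert tensor_follows_mask_structure(structured, mask="2:4")

# A fully dense tensor violates the 2:4 structure -> raises ValueError
try:
    tensor_follows_mask_structure(torch.ones(8), mask="2:4")
except ValueError:
    print("tensor is not 2:4 structured")
```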
compressed_tensors/utils/offload.py (new file)
@@ -0,0 +1,104 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.nn import Module
+
+
+__all__ = [
+    "is_module_offloaded",
+    "get_execution_device",
+    "get_offloaded_device",
+    "update_prefix_dict",
+    "update_parameter_data",
+]
+
+
+def is_module_offloaded(module: Module) -> bool:
+    """
+    :param module: layer to check
+    :return: True if layer is offloaded from GPU, False otherwise
+    """
+    return hasattr(module, "_hf_hook") and module._hf_hook.offload
+
+
+def get_execution_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is loaded onto during forward pass
+    """
+    if is_module_offloaded(module):
+        return module._hf_hook.execution_device
+    return next(module.parameters()).device
+
+
+def get_offloaded_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is offloaded to onto after forward pass
+    """
+    if is_module_offloaded(module):
+        first_key = list(module._hf_hook.weights_map.keys())[0]
+        prefix_dataset = module._hf_hook.weights_map.dataset
+        return prefix_dataset[first_key].device
+    return next(module.parameters()).device
+
+
+def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+    """
+    Updates the offloaded state dict for a given module. Parameter named key is replaced
+    by data. This is neccesary because parameter updates for offloaded modules do not
+    persist automatically between loads. This function only affects the offloaded
+    state dict and not the current state of the loaded module.
+
+    :param module: layer containing the parameter to update
+    :param key: name of parameter to update
+    :param data: tensor to update parameter with in the offloaded state dict
+    """
+    if not is_module_offloaded(module):
+        raise ValueError("Prefix dict is only applicable to offloaded modules")
+    prefix_dict = module._hf_hook.weights_map
+    prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+
+def update_parameter_data(
+    module: Module, new_param_data: torch.Tensor, param_name: str
+):
+    """
+    Updates the paramter value named param_name for a given module. This function
+    updates both the current loaded module state and the offloaded state dict if
+    the module is offloaded. This is neccesary because parameter updates for offloaded
+    modules do not persist automatically between loads.
+
+    :param module: layer containing the parameter to update
+    :param new_param_data: tensor to update parameter with
+    :param param_name:
+    """
+    device = next(module.parameters()).device
+
+    offloaded = False
+    if is_module_offloaded(module):
+        offload_device = get_offloaded_device(module)
+        offloaded = True
+
+    parameter = getattr(module, param_name, None)
+    dtype = parameter.dtype
+    parameter.data = new_param_data.to(device).to(dtype)
+
+    if offloaded:
+        prefix_dict = module._hf_hook.weights_map.dataset
+        prefix = module._hf_hook.weights_map.prefix
+        prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
+            dtype
+        )
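`update_parameter_data` is the entry point most callers will use: it keeps an accelerate-offloaded module's `weights_map` in sync with in-place parameter edits. A minimal sketch on a plain (non-offloaded) module, using only the functions diffed above; for an offloaded module the same call would also write through to the offloaded state dict:

```python
import torch
from torch.nn import Linear

from compressed_tensors.utils.offload import (
    get_execution_device,
    is_module_offloaded,
    update_parameter_data,
)

layer = Linear(16, 16)
assert not is_module_offloaded(layer)  # no accelerate _hf_hook attached
assert get_execution_device(layer) == torch.device("cpu")

# Replaces layer.weight in place, cast to the parameter's device and dtype
update_parameter_data(layer, torch.randn(16, 16), "weight")
```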
compressed_tensors/utils/permutations_24.py (new file)
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy
+import torch
+
+
+__all__ = ["get_permutations_24"]
+
+
+# Precompute permutations for Marlin24 weight and scale shuffling
+# Originally implemented in nm-vllm/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py  # noqa: E501
+#
+# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight
+# data so that it is compatible with the tensor-core format that is described here:
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
+#
+# As a result of this reordering, the vector loads inside the kernel will get the data
+# as it is needed for tensor-core (without the need to use ldmatrix instructions)
+def get_permutations_24(num_bits):
+    perm_list = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        col_o = col // 2
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
+        for j in range(4):
+            perm_list.extend([p + 1 * j for p in perm1])
+    perm = numpy.array(perm_list)
+
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
+    scale_perm_single = []
+    for i in range(8):
+        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
+    return perm, scale_perm, scale_perm_single
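As a quick sanity check of the shapes involved (derived from the loop bounds above: 32 outer iterations, 8 tile positions, 4 packed copies), a minimal sketch assuming the module path as diffed:

```python
from compressed_tensors.utils.permutations_24 import get_permutations_24

perm, scale_perm, scale_perm_single = get_permutations_24(num_bits=4)
assert perm.shape == (1024,)          # 32 iterations x 8 rows x 4 packed values
assert len(scale_perm) == 64          # grouped scales: 8 groups of 8, interleaved
assert len(scale_perm_single) == 64   # per-channel scales keep natural order
```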
compressed_tensors/utils/safetensors_load.py
@@ -31,6 +31,7 @@ __all__ = [
     "get_weight_mappings",
     "get_nested_weight_mappings",
     "get_quantization_state_dict",
+    "is_quantization_param",
 ]
 
 
@@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
     weight_mappings = get_weight_mappings(model_path)
     state_dict = {}
     for weight_name, safe_path in weight_mappings.items():
-        if not
+        if not is_quantization_param(weight_name):
             continue
         with safe_open(safe_path, framework="pt", device="cpu") as f:
            state_dict[weight_name] = f.get_tensor(weight_name)
@@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
     return state_dict
 
 
-def
+def is_quantization_param(name: str) -> bool:
     """
     Checks is a parameter name is associated with a quantization parameter
 
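`get_quantization_state_dict` now delegates its filtering to the new `is_quantization_param` predicate; the predicate's body is cut off in this hunk, so the exact matching rule is not visible here. A hypothetical illustration of the intended filtering, assuming names with quantization suffixes such as `weight_scale` and `weight_zero_point` are the ones the predicate accepts:

```python
from compressed_tensors.utils.safetensors_load import is_quantization_param

names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.weight_scale",       # hypothetical name
    "model.layers.0.self_attn.q_proj.weight_zero_point",  # hypothetical name
]

# get_quantization_state_dict keeps only the names this predicate accepts;
# under the suffix assumption above, the raw weight would be skipped while
# the scale/zero_point tensors are loaded
quant_names = [n for n in names if is_quantization_param(n)]
```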
compressed_tensors/utils/semi_structured_conversions.py (new file)
@@ -0,0 +1,341 @@
+#
+# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
+# Pulled from nm-vllm/vllm/model_executor/layers/quantization/utils/format_24.py
+#
+# flake8: noqa
+# isort: skip_file
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+__all__ = [
+    "sparse_semi_structured_from_dense_cutlass",
+    "sparse_semi_structured_to_dense_cutlass",
+    "mask_creator",
+]
+
+# This is PyTorch implementation of main part of reorder_meta()
+# function, from tools/util/include/cutlass/util/host_reorder.h file
+# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
+# GEMM decides upon layout of this matrix, and at the moment for the
+# sparse GEMM executed on tensor cores, this is layout described by
+# ColumnMajorInterleaved<2> data structure, in
+# include/cutlass/layout/matrix.h of CUTLASS source tree. The
+# reordering of meta matrix into meta_reordered matrix calculated
+# according to these segments of CUTLASS code is re-implemented here.
+# Note that this calculation produces offsets for scattering metadata
+# matrix elements into reordered metadata matrix elements (or,
+# equivalently, for gathering reordered metadata matrix element back
+# into metadata matrix elements).
+def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
+    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
+    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
+
+    # Reorder the rows, then swizzle the 2x2 blocks.
+    group_x = 64
+    group_y = 32 if meta_dtype.itemsize == 2 else 16
+
+    dst_rows = (
+        dst_rows // group_x * group_x
+        + (dst_rows % 2) * 2
+        + (dst_rows % 8) // 4
+        + ((dst_rows % group_y) % 4) // 2 * 32
+        + ((dst_rows % group_x) // 8) * 4
+    )
+
+    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
+    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
+    dst_rows += topright - bottomleft
+    dst_cols -= topright - bottomleft
+
+    # Assumed that meta tensor is to be stored in CUTLASS
+    # InterleavedColumnMajor layout, and reverse engineered
+    # corresponding code to store values into this tensor.
+    interleave = 2
+    cols_maj = dst_cols // interleave
+    cols_min = dst_cols % interleave
+    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
+
+
+# This function converts dense matrix into sparse semi-structured
+# representation, producing "compressed" matrix, in the layout used by
+# CUTLASS backend, and corresponding metadata matrix.
+def sparse_semi_structured_from_dense_cutlass(dense):
+    if dense.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
+        )
+
+    m, k = dense.shape
+    device = dense.device
+
+    meta_dtype = torch.int8
+    if dense.dtype == torch.int8:
+        meta_dtype = torch.int32
+    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
+        meta_dtype = torch.int16
+    else:
+        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
+    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+    if quadbits_per_meta_elem not in (4, 8):
+        raise RuntimeError("Invalid number of elements per meta element calculated")
+
+    if meta_dtype == torch.int32:
+        if m % 16 != 0:
+            raise RuntimeError(
+                f"Number of rows of dense matrix {m} must be divisible by 16"
+            )
+    else:
+        if m % 32 != 0:
+            raise RuntimeError(
+                f"Number of rows of dense matrix {m} must be divisible by 32"
+            )
+    if k % (4 * quadbits_per_meta_elem) != 0:
+        raise RuntimeError(
+            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
+        )
+
+    if dense.dtype != torch.float:
+        ksparse = 4
+        dense_4 = dense.view(-1, k // ksparse, ksparse)
+        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
+    else:
+        ksparse = 2
+        dense_2 = dense.view(-1, k // ksparse, ksparse)
+        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
+    meta_ncols = k // (ksparse * quadbits_per_meta_elem)
+
+    # Encoding quadruples of True/False values as follows:
+    #     [True, True, False, False] -> 0b0100
+    #     [True, False, True, False] -> 0b1000
+    #     [False, True, True, False] -> 0b1001
+    #     [True, False, False, True ] -> 0b1100
+    #     [False, True, False, True ] -> 0b1101
+    #     [False, False, True, True ] -> 0b1110
+    # Thus, lower two bits in the encoding are index of the True value
+    # at the lowest index in the quadruple, and the higher two bits in
+    # the encoding are index of the other True value in the quadruple.
+    # In case there are less than two True values, than False value or
+    # values at some index or indices are considered True for the
+    # encoding. In case there are more than two True values, then the
+    # excess True value(s) at some indices are considered False for
+    # the encoding. The exact encodings used for these cases are as
+    # follows:
+    #     [False, False, False, False] -> 0b1110
+    #     [False, False, False, True ] -> 0b1110
+    #     [False, False, True, False] -> 0b1110
+    #     [False, True, False, False] -> 0b1001
+    #     [False, True, True, True ] -> 0b1101
+    #     [True, False, False, False] -> 0b1000
+    #     [True, False, True, True ] -> 0b1100
+    #     [True, True, False, True ] -> 0b0100
+    #     [True, True, True, False] -> 0b0100
+    #     [True, True, True, True ] -> 0b0100
+    # These particular encodings are chosen, with the help of Espresso
+    # logic minimizer software, for the purpose of minimization of
+    # corresponding Boolean functions, that translate non-zero flags
+    # into encoding bits. Note also possible choices for the first
+    # and last of these encodings were limited only to (0b0100,
+    # 0b1110), in order to produce valid encodings for 1:2 sparsity
+    # case.
+
+    expr0 = m0 & m1
+    expr1 = ~m0 & m1
+    expr2 = ~m0 & ~m1
+    bit0 = expr1
+    bit1 = expr2
+    bit2 = expr0 | expr2 | m3
+    bit3 = expr1 | ~m1
+    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
+    idxs1 = bit2 | (bit3.to(torch.int64) << 1)
+
+    if dense.dtype != torch.float:
+        sparse0 = dense_4.gather(
+            -1, idxs0.unsqueeze(-1)
+        )  # type: ignore[possibly-undefined]
+        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
+        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
+    else:
+        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
+            m, k // 2
+        )  # type: ignore[possibly-undefined]
+
+    meta_4 = idxs0 | (idxs1 << 2)
+    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
+
+    if quadbits_per_meta_elem == 4:
+        meta = (
+            meta_n[:, :, 0]
+            | (meta_n[:, :, 1] << 4)
+            | (meta_n[:, :, 2] << 8)
+            | (meta_n[:, :, 3] << 12)
+        )
+    elif quadbits_per_meta_elem == 8:
+        meta = (
+            meta_n[:, :, 0]
+            | (meta_n[:, :, 1] << 4)
+            | (meta_n[:, :, 2] << 8)
+            | (meta_n[:, :, 3] << 12)
+            | (meta_n[:, :, 4] << 16)
+            | (meta_n[:, :, 5] << 20)
+            | (meta_n[:, :, 6] << 24)
+            | (meta_n[:, :, 7] << 28)
+        )
+
+    # Reorder meta tensor elements.
+    meta_reordered = meta.new_empty(
+        (m * meta_ncols,)
+    )  # type: ignore[possibly-undefined]
+    meta_offsets = _calculate_meta_reordering_scatter_offsets(
+        m, meta_ncols, meta_dtype, device
+    )
+    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
+
+    return (sparse, meta_reordered.view(m, meta_ncols))
+
+
+# This function performs reverse of the function above - it
+# reconstructs dense matrix from a pair of "compressed" matrix, given
+# in the layout used by CUTLASS backend, and accompanying metadata
+# matrix.
+def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
+    if sparse.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
+        )
+
+    m, k = sparse.shape
+    device = sparse.device
+
+    if meta_reordered.dim() != 2:
+        raise RuntimeError(
+            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
+        )
+    if meta_reordered.device != device:
+        raise RuntimeError(
+            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
+        )
+
+    meta_dtype = meta_reordered.dtype
+    if meta_dtype not in (torch.int16, torch.int32):
+        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
+    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
+
+    ksparse = 4 if sparse.dtype != torch.float else 2
+
+    meta_nrows, meta_ncols = meta_reordered.shape
+    if meta_nrows != m:
+        raise RuntimeError(
+            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
+        )
+    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
+        raise RuntimeError(
+            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
+            "expected according to the number of columns of meta matrix"
+        )
+
+    # Undo meta tensor elements reordering.
+    meta_offsets = _calculate_meta_reordering_scatter_offsets(
+        m, meta_ncols, meta_dtype, device
+    )
+    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
+
+    # Unpack sparse tensor back to original dense tensor, using
+    # information provided by meta tensor. Note that torch.float
+    # datatype is handled pretty much the same as
+    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
+    # value is encoded as if underlying 8 bytes contain four
+    # torch.half/torch.bfloat16 values, where either first two or last
+    # two are zeros.
+    meta_2 = torch.empty(
+        (m, meta_ncols, 2 * quadbits_per_meta_elem),
+        dtype=meta_dtype,
+        device=device,
+    )
+    if quadbits_per_meta_elem == 4:
+        meta_2[:, :, 0] = meta & 0b11
+        meta_2[:, :, 1] = (meta >> 2) & 0b11
+        meta_2[:, :, 2] = (meta >> 4) & 0b11
+        meta_2[:, :, 3] = (meta >> 6) & 0b11
+        meta_2[:, :, 4] = (meta >> 8) & 0b11
+        meta_2[:, :, 5] = (meta >> 10) & 0b11
+        meta_2[:, :, 6] = (meta >> 12) & 0b11
+        meta_2[:, :, 7] = (meta >> 14) & 0b11
+    elif quadbits_per_meta_elem == 8:
+        meta_2[:, :, 0] = meta & 0b11
+        meta_2[:, :, 1] = (meta >> 2) & 0b11
+        meta_2[:, :, 2] = (meta >> 4) & 0b11
+        meta_2[:, :, 3] = (meta >> 6) & 0b11
+        meta_2[:, :, 4] = (meta >> 8) & 0b11
+        meta_2[:, :, 5] = (meta >> 10) & 0b11
+        meta_2[:, :, 6] = (meta >> 12) & 0b11
+        meta_2[:, :, 7] = (meta >> 14) & 0b11
+        meta_2[:, :, 8] = (meta >> 16) & 0b11
+        meta_2[:, :, 9] = (meta >> 18) & 0b11
+        meta_2[:, :, 10] = (meta >> 20) & 0b11
+        meta_2[:, :, 11] = (meta >> 22) & 0b11
+        meta_2[:, :, 12] = (meta >> 24) & 0b11
+        meta_2[:, :, 13] = (meta >> 26) & 0b11
+        meta_2[:, :, 14] = (meta >> 28) & 0b11
+        meta_2[:, :, 15] = (meta >> 30) & 0b11
+
+    dense_offsets = meta_2.view(-1) + (
+        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
+    ).view(-1, 1).repeat(1, 2).view(-1)
+
+    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
+    if sparse.dtype != torch.float:
+        # dense.scatter_(0, dense_offsets, sparse.view(-1))
+        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
+    else:
+        dense.view(torch.half).scatter_(
+            0, dense_offsets, sparse.view(torch.half).view(-1)
+        )
+
+    return dense.view(m, 2 * k)
+
+
+def mask_creator(tensor):
+    """
+    Class for creating N:M sparsity masks.
+    Masks will be created using the N:M ratio, where for every block of
+    M weights, N will be pruned based on ranked weight value. Each mask
+    will correspond to the given tensor.
+
+    :param N: The number of weights in a group to keep
+    :param M: The size of a weight group
+    """
+    N = 2
+    M = 4
+
+    mask = None
+    # for i, tensor in enumerate(tensors):
+    if tensor.numel() % M != 0:
+        raise ValueError(
+            f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
+        )
+
+    num_groups = tensor.numel() // M
+
+    # N:M sparsity for linear layers
+    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
+    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
+
+    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
+    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
+
+    return mask
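The three exports compose into a round trip: prune a matrix to 2:4 with `mask_creator`, compress it into the CUTLASS layout, then reconstruct it. A minimal sketch, assuming a CPU float16 tensor whose shape satisfies the divisibility checks above:

```python
import torch

from compressed_tensors.utils.semi_structured_conversions import (
    mask_creator,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

# Prune to 2:4 (two of every four consecutive weights zeroed by magnitude)
dense = torch.randn(64, 64, dtype=torch.float16)
dense = dense * mask_creator(dense).to(dense.dtype)

# Compress: half the columns survive, plus an int16 metadata matrix
sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
assert sparse.shape == (64, 32) and meta.dtype == torch.int16

# Decompress and verify the round trip is lossless for a 2:4 tensor
restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(restored, dense)
```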
compressed_tensors/version.py (new file)
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Functionality for storing and setting the version info for SparseML
+"""
+
+
+version_base = "0.5.0"
+is_release = True  # change to True to set the generated version as a release version
+
+
+def _generate_version(
+    is_release: bool,
+    version_base: str,
+):
+    from datetime import date
+
+    if is_release:
+        return version_base
+    else:
+        return f"{version_base}.{date.today().strftime('%Y%m%d')}"
+
+
+__all__ = [
+    "__version__",
+    "version_base",
+    "is_release",
+    "version",
+    "version_major",
+    "version_minor",
+    "version_patch",
+    "version_build",
+    "version_major_minor",
+]
+__version__ = _generate_version(is_release, version_base)
+
+version = __version__
+version_major, version_minor, version_patch, version_build = version.split(".") + (
+    [None] if len(version.split(".")) < 4 else []
+)  # handle conditional for version being 3 parts or 4 (4 containing build date)
+version_major_minor = f"{version_major}.{version_minor}"