compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +200 -8
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +101 -13
  7. compressed_tensors/compressors/naive_quantized.py +140 -0
  8. compressed_tensors/compressors/pack_quantized.py +128 -132
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +8 -1
  11. compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
  12. compressed_tensors/linear/compressed_linear.py +87 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +204 -44
  15. compressed_tensors/quantization/lifecycle/calibration.py +22 -2
  16. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  17. compressed_tensors/quantization/lifecycle/forward.py +139 -61
  18. compressed_tensors/quantization/lifecycle/helpers.py +80 -0
  19. compressed_tensors/quantization/lifecycle/initialize.py +77 -13
  20. compressed_tensors/quantization/observers/__init__.py +1 -0
  21. compressed_tensors/quantization/observers/base.py +93 -14
  22. compressed_tensors/quantization/observers/helpers.py +64 -11
  23. compressed_tensors/quantization/observers/min_max.py +8 -0
  24. compressed_tensors/quantization/observers/mse.py +162 -0
  25. compressed_tensors/quantization/quant_args.py +139 -23
  26. compressed_tensors/quantization/quant_config.py +35 -2
  27. compressed_tensors/quantization/quant_scheme.py +112 -13
  28. compressed_tensors/quantization/utils/helpers.py +68 -2
  29. compressed_tensors/utils/__init__.py +5 -0
  30. compressed_tensors/utils/helpers.py +44 -2
  31. compressed_tensors/utils/offload.py +116 -0
  32. compressed_tensors/utils/permute.py +70 -0
  33. compressed_tensors/utils/safetensors_load.py +2 -0
  34. compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
  35. compressed_tensors/version.py +1 -1
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
  37. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  38. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  39. compressed_tensors/compressors/int_quantized.py +0 -126
  40. compressed_tensors/compressors/utils/helpers.py +0 -43
  41. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  42. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  43. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  44. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_config.py
@@ -16,6 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -25,6 +26,7 @@ from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -117,7 +119,18 @@ class QuantizationConfig(BaseModel):
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
-        assumed all layers are in the same state.
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+            - targets the `q_proj` and `k_proj` modules of the model. The outputs
+              of those modules are the keys and values that might be cached
+            - quantizes the outputs of the aformentioned layers, so that
+              keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
         compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
@@ -126,6 +139,7 @@ class QuantizationConfig(BaseModel):
 
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
     format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
@@ -154,7 +168,7 @@ class QuantizationConfig(BaseModel):
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each quanitzed module
+        QuantizationScheme attached to each quantized module
 
         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
@@ -195,6 +209,13 @@ class QuantizationConfig(BaseModel):
             # else we leave it off the ignore list, doesn't fall under any of the
             # existing quantization schemes so it won't be quantized
 
+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
         config_groups = {}
         for idx, scheme in enumerate(quant_scheme_to_layers):
             group_name = "group_" + str(idx)
@@ -213,7 +234,19 @@ class QuantizationConfig(BaseModel):
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
             global_compression_ratio=compression_ratio,
             format=format,
             ignore=consolidated_ignore,
         )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
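A minimal usage sketch of the new kv_cache_scheme field and requires_calibration_data() helper; the scheme contents below are illustrative, not taken from this diff:

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_config import QuantizationConfig
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme

    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=8, symmetric=True),
            )
        },
        # new field: quantize the outputs of the k_proj/v_proj modules before caching
        kv_cache_scheme=QuantizationArgs(num_bits=8, symmetric=True),
    )

    # weights-only with no static activation quantization -> no calibration data needed
    assert config.requires_calibration_data() is False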
compressed_tensors/quantization/quant_scheme.py
@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional
 
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel
 
 
@@ -53,15 +57,9 @@ class QuantizationScheme(BaseModel):
             # default to quantizing all Linear layers
             targets = ["Linear"]
 
-        # default to 8 bit integer symmetric quantization
-        # for weights
-        weights = QuantizationArgs(num_bits=8, symmetric=True)
-
-        # default to 8 bit integer asymmetric quantization
-        input_activations = QuantizationArgs(num_bits=8, symmetric=True)
-
-        # Do not quantize the output activations
-        # by default
+        # by default, activations and weights are left unquantized
+        weights = None
+        input_activations = None
         output_activations = None
 
         return cls(
@@ -107,13 +105,114 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES
 
 
+UNQUANTIZED = dict()
+
+# 8 bit integer weights and 8 bit activations quantization
 W8A8 = dict(
-    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True)
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
 )
 
-W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+# 8 bit integer weights only quantization
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights only quantization
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights and 8 bit activations quantization
+W4A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        group_size=128,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# FP8 weights and FP8 activations quantization
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# FP8 weights and FP8 dynamic activations quantization
+FP8_DYNAMIC = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
 
 PRESET_SCHEMES = {
-    "W8A8": W8A8,
+    # Unquantized (no-op)
+    "UNQUANTIZED": UNQUANTIZED,
+    # Integer weight only schemes
+    "W8A16": W8A16,
     "W4A16": W4A16,
+    # Integer weight and activation schemes
+    "W8A8": W8A8,
+    "W4A8": W4A8,
+    # Float weight and activation schemes
+    "FP8": FP8,
+    "FP8_DYNAMIC": FP8_DYNAMIC,
 }
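A short sketch of resolving one of the new named presets into a QuantizationScheme; it assumes the preset_name_to_scheme(name, targets) helper that quant_config.py imports from this module, and the targets list is illustrative:

    from compressed_tensors.quantization.quant_scheme import (
        is_preset_scheme,
        preset_name_to_scheme,
    )

    assert is_preset_scheme("FP8_DYNAMIC")
    scheme = preset_name_to_scheme("FP8_DYNAMIC", targets=["Linear"])
    print(scheme.weights.type)               # float weights
    print(scheme.input_activations.dynamic)  # True: per-token dynamic activations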
compressed_tensors/quantization/utils/helpers.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Tuple
+import re
+from typing import List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -30,8 +33,12 @@ __all__ = [
     "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]
 
+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -174,7 +181,7 @@ def calculate_compression_ratio(model: Module) -> float:
         for parameter in model.parameters():
             uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
-            if is_module_quantized(submodule):
+            if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
 
             num_weights = parameter.numel()
@@ -182,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
             total_uncompressed += uncompressed_bits * num_weights
 
     return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+      or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+      outputs from the KV_CACHE_TARGETS, as their correspond to the
+      keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
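A sketch of how the two new kv cache helpers interact; the scheme below mirrors KV_CACHE_TARGETS and is illustrative, and the import path assumes these names are re-exported from compressed_tensors.quantization.utils as they are in quant_config.py:

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme
    from compressed_tensors.quantization.utils import (
        is_kv_cache_quant_scheme,
        parse_out_kv_cache_args,
    )

    kv_scheme = QuantizationScheme(
        targets=["re:.*k_proj", "re:.*v_proj"],
        output_activations=QuantizationArgs(num_bits=8, symmetric=True),
    )
    assert is_kv_cache_quant_scheme(kv_scheme)

    # the kv cache scheme is split out; the remaining list is empty here
    kv_cache_args, remaining = parse_out_kv_cache_args([kv_scheme])
    print(kv_cache_args.num_bits, remaining)  # 8 []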
compressed_tensors/utils/__init__.py
@@ -13,4 +13,9 @@
 # limitations under the License.
 # flake8: noqa
 
+from .helpers import *
+from .offload import *
+from .permutations_24 import *
+from .permute import *
 from .safetensors_load import *
+from .semi_structured_conversions import *
compressed_tensors/utils/helpers.py
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Optional
 
+import torch
 from transformers import AutoConfig
 
 
-__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]
+__all__ = [
+    "infer_compressor_from_model_config",
+    "fix_fsdp_module_name",
+    "tensor_follows_mask_structure",
+    "replace_module",
+]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
 
@@ -61,3 +66,40 @@ def fix_fsdp_module_name(name: str) -> str:
     return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
         "." + FSDP_WRAPPER_NAME, ""
     )
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        atleast n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check if the number of zeros in each chunk atleast n
+    # Greater than sign is needed as some weights can incidentally
+    # be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError()
+
+    return True
+
+
+def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model
+        child_name = name
+    setattr(parent, child_name, new_module)
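A sketch of the 2:4 structure check added above; note the helper raises ValueError on a violation rather than returning False as its docstring suggests:

    import torch

    from compressed_tensors.utils.helpers import tensor_follows_mask_structure

    t = torch.tensor([[0.0, 1.0, 0.0, 2.0],
                      [3.0, 0.0, 0.0, 4.0]])
    assert tensor_follows_mask_structure(t, mask="2:4")  # >= 2 zeros per group of 4

    try:
        tensor_follows_mask_structure(torch.ones(1, 4), mask="2:4")
    except ValueError:
        print("tensor is not 2:4 structured")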
compressed_tensors/utils/offload.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.nn import Module
+
+
+__all__ = [
+    "is_module_offloaded",
+    "get_execution_device",
+    "get_offloaded_device",
+    "update_prefix_dict",
+    "update_parameter_data",
+]
+
+
+def is_module_offloaded(module: Module) -> bool:
+    """
+    :param module: layer to check
+    :return: True if layer is offloaded from GPU, False otherwise
+    """
+    return hasattr(module, "_hf_hook") and module._hf_hook.offload
+
+
+def get_execution_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is loaded onto during forward pass
+    """
+    if is_module_offloaded(module):
+        return module._hf_hook.execution_device
+    device = next(module.parameters()).device
+
+    # offload only gets set for leaf modules, fallback to checking for device type
+    if device.type == "meta":
+        return module._hf_hook.execution_device
+
+    return device
+
+
+def get_offloaded_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is offloaded to onto after forward pass
+    """
+    if is_module_offloaded(module):
+        first_key = list(module._hf_hook.weights_map.keys())[0]
+        prefix_dataset = module._hf_hook.weights_map.dataset
+        return prefix_dataset[first_key].device
+    return next(module.parameters()).device
+
+
+def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+    """
+    Updates the offloaded state dict for a given module. Parameter named key is replaced
+    by data. This is neccesary because parameter updates for offloaded modules do not
+    persist automatically between loads. This function only affects the offloaded
+    state dict and not the current state of the loaded module.
+
+    :param module: layer containing the parameter to update
+    :param key: name of parameter to update
+    :param data: tensor to update parameter with in the offloaded state dict
+    """
+    if not is_module_offloaded(module):
+        raise ValueError("Prefix dict is only applicable to offloaded modules")
+    prefix_dict = module._hf_hook.weights_map
+    prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+
+def update_parameter_data(
+    module: Module, new_param_data: torch.Tensor, param_name: str
+):
+    """
+    Updates the paramter value named param_name for a given module. This function
+    updates both the current loaded module state and the offloaded state dict if
+    the module is offloaded. This is neccesary because parameter updates for offloaded
+    modules do not persist automatically between loads.
+
+    :param module: layer containing the parameter to update
+    :param new_param_data: tensor to update parameter with
+    :param param_name: name of layer parameter to update
+    """
+    if not hasattr(module, param_name):
+        return
+
+    device = next(module.parameters()).device
+
+    offloaded = False
+    if is_module_offloaded(module):
+        offload_device = get_offloaded_device(module)
+        offloaded = True
+
+    parameter = getattr(module, param_name, None)
+    if parameter is None:
+        raise ValueError("Attempted to update uninitialized parameter")
+
+    dtype = parameter.dtype
+    parameter.data = new_param_data.to(device).to(dtype)
+
+    if offloaded:
+        prefix_dict = module._hf_hook.weights_map.dataset
+        prefix = module._hf_hook.weights_map.prefix
+        prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
+            dtype
+        )
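A sketch of update_parameter_data on a plain (non-offloaded) module, where it reduces to a dtype- and device-preserving copy; with an accelerate-offloaded module it would also sync the offloaded state dict:

    import torch

    from compressed_tensors.utils.offload import update_parameter_data

    linear = torch.nn.Linear(4, 4)
    new_weight = torch.zeros(4, 4)

    # copies onto the module's device and dtype, in place of the old weight data
    update_parameter_data(linear, new_weight, "weight")
    assert torch.equal(linear.weight.data, new_weight)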
compressed_tensors/utils/permute.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Set, Tuple
+
+import torch
+
+
+__all__ = ["safe_permute"]
+
+
+# these datatypes are missing implementations required for standard permutation
+_EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
+
+
+def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
+    """
+    Perform out-of-place permutation without using torch.Tensor.index_put_,
+    whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    dtype_tuple = (value.dtype, value.device)
+
+    if dtype_tuple in _EXPERIMENTAL_DTYPES:
+        return _fallback_permute(value, perm, dim)
+
+    try:
+        return value[tuple([slice(None)] * dim + [perm])]
+    except RuntimeError:
+        # Mark dtype as experimental if advanced indexing fails
+        _EXPERIMENTAL_DTYPES.add(dtype_tuple)
+        return _fallback_permute(value, perm, dim)
+
+
+def _fallback_permute(
+    value: torch.Tensor, perm: torch.Tensor, dim: int
+) -> torch.Tensor:
+    """
+    Fallback permutation method for experimental dtypes.
+
+    :param value: tensor to permute
+    :param perm: permutation map
+    :param dim: dimension along which to apply permutation
+    :return: permuted value
+    """
+    value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
+    orig_slices = [slice(None)] * (dim + 1)
+    perm_slices = [slice(None)] * (dim + 1)
+
+    for index, perm_index in enumerate(perm):
+        orig_slices[dim] = index
+        perm_slices[dim] = perm_index
+        value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
+
+    return value_ret
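A sketch of safe_permute; for ordinary dtypes it is just advanced indexing, while experimental dtypes such as torch.float8_e4m3fn fall back to the element-wise path:

    import torch

    from compressed_tensors.utils.permute import safe_permute

    value = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    perm = torch.tensor([2, 0, 1])

    permuted = safe_permute(value, perm, dim=0)  # rows reordered to 2, 0, 1
    print(permuted[0])  # tensor([ 8.,  9., 10., 11.])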
compressed_tensors/utils/safetensors_load.py
@@ -234,5 +234,7 @@ def is_quantization_param(name: str) -> bool:
         return True
     if name.endswith("zero_point"):
         return True
+    if name.endswith("g_idx"):
+        return True
 
     return False
compressed_tensors/utils/semi_structured_conversions.py
@@ -28,6 +28,7 @@ __all__ = [
     "mask_creator",
 ]
 
+
 # This is PyTorch implementation of main part of reorder_meta()
 # function, from tools/util/include/cutlass/util/host_reorder.h file
 # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
compressed_tensors/version.py
@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
 """
 
 
-version_base = "0.4.0"
+version_base = "0.6.0"
 is_release = True  # change to True to set the generated version as a release version
 