compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_args.py

@@ -13,10 +13,10 @@
  # limitations under the License.

  from enum import Enum
- from typing import Any, Dict, Optional
+ from typing import Any, Dict, Optional, Union

  import torch
- from pydantic import BaseModel, Field, validator
+ from pydantic import BaseModel, Field, field_validator, model_validator


  __all__ = [
@@ -25,6 +25,7 @@ __all__ = [
      "QuantizationStrategy",
      "QuantizationArgs",
      "round_to_quantized_type",
+     "ActivationOrdering",
  ]

  FP8_DTYPE = torch.float8_e4m3fn
@@ -51,6 +52,19 @@ class QuantizationStrategy(str, Enum):
      TOKEN = "token"


+ class ActivationOrdering(str, Enum):
+     """
+     Enum storing strategies for activation ordering
+
+     Group: reorder groups and weight\n
+     Weight: only reorder weight, not groups. Slightly lower latency and
+     accuracy compared to group actorder\n
+     """
+
+     GROUP = "group"
+     WEIGHT = "weight"
+
+
  class QuantizationArgs(BaseModel, use_enum_values=True):
      """
      User facing arguments used to define a quantization config for weights or
@@ -68,15 +82,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
          ranges will be observed with every sample. Defaults to False for static
          quantization. Note that enabling dynamic quantization will change the default
          observer to a memoryless one
+     :param actorder: whether to apply group quantization in decreasing order of
+         activation. Defaults to None for arbitrary ordering
      """

      num_bits: int = 8
-     type: QuantizationType = QuantizationType.INT.value
+     type: QuantizationType = QuantizationType.INT
      symmetric: bool = True
      group_size: Optional[int] = None
      strategy: Optional[QuantizationStrategy] = None
      block_structure: Optional[str] = None
      dynamic: bool = False
+     actorder: Union[ActivationOrdering, bool, None] = None
      observer: str = Field(
          default="minmax",
          description=(
@@ -98,41 +115,102 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
          """
          from compressed_tensors.quantization.observers.base import Observer

-         if self.observer == "minmax" and self.dynamic:
+         if self.dynamic:
              # override defualt observer for dynamic, you never want minmax which
              # keeps state across samples for dynamic
              self.observer = "memoryless"

          return Observer.load_from_registry(self.observer, quantization_args=self)

-     @validator("strategy", pre=True, always=True)
-     def validate_strategy(cls, value, values):
-         group_size = values.get("group_size")
-
-         # use group_size to determinine strategy if not given explicity
-         if group_size is not None and value is None:
-             if group_size > 0:
-                 return QuantizationStrategy.GROUP
+     def get_kv_cache(self):
+         """Get the singleton KV Cache"""
+         from compressed_tensors.quantization.cache import QuantizedKVParameterCache

-             elif group_size == -1:
-                 return QuantizationStrategy.CHANNEL
+         return QuantizedKVParameterCache(self)

-             else:
-                 raise ValueError(
-                     f"group_size={group_size} with strategy {value} is invald. "
-                     "group_size > 0 for strategy='group' and "
-                     "group_size = -1 for 'channel'"
-                 )
+     @field_validator("type", mode="before")
+     def validate_type(cls, value) -> QuantizationType:
+         if isinstance(value, str):
+             return QuantizationType(value.lower())

-         if value == QuantizationStrategy.GROUP:
-             if group_size is None:
-                 raise ValueError(f"strategy {value} requires group_size to be set.")
+         return value

+     @field_validator("group_size", mode="before")
+     def validate_group(cls, value) -> Union[int, None]:
          if value is None:
-             return QuantizationStrategy.TENSOR
+             return value
+
+         if value < -1:
+             raise ValueError(
+                 f"Invalid group size {value}. Use group_size > 0 for "
+                 "strategy='group' and group_size = -1 for 'channel'"
+             )
+
+         return value
+
+     @field_validator("strategy", mode="before")
+     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
+         if isinstance(value, str):
+             return QuantizationStrategy(value.lower())
+
+         return value
+
+     @field_validator("actorder", mode="before")
+     def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+         if isinstance(value, bool):
+             return ActivationOrdering.GROUP if value else None
+
+         if isinstance(value, str):
+             return ActivationOrdering(value.lower())

          return value

+     @model_validator(mode="after")
+     def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+         # extract user-passed values from dictionary
+         strategy = model.strategy
+         group_size = model.group_size
+         actorder = model.actorder
+
+         # infer strategy
+         if strategy is None:
+             if group_size is None:
+                 strategy = QuantizationStrategy.TENSOR
+             elif group_size > 0:
+                 strategy = QuantizationStrategy.GROUP
+             elif group_size == -1:
+                 strategy = QuantizationStrategy.CHANNEL
+             else:
+                 raise ValueError(
+                     f"Invalid group size {group_size}. Use group_size > 0 for "
+                     "strategy='group' and group_size = -1 for 'channel'"
+                 )
+
+         # validate strategy and group
+         if strategy == QuantizationStrategy.GROUP:
+             if group_size is None or group_size <= 0:
+                 raise ValueError(
+                     f"strategy {strategy} requires group_size to be "
+                     "set to a positive value"
+                 )
+         if (
+             group_size is not None
+             and group_size > 0
+             and strategy != QuantizationStrategy.GROUP
+         ):
+             raise ValueError("group_size requires strategy to be set to 'group'")
+
+         # validate activation ordering and strategy
+         if actorder is not None and strategy != QuantizationStrategy.GROUP:
+             raise ValueError(
+                 "Must use group quantization strategy in order to apply "
+                 "activation ordering"
+             )
+
+         # write back modified values
+         model.strategy = strategy
+         return model
+
      def pytorch_dtype(self) -> torch.dtype:
          if self.type == QuantizationType.FLOAT:
              return FP8_DTYPE
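
Usage sketch (not part of the package diff): how the reworked validators behave, assuming QuantizationArgs and QuantizationStrategy are still re-exported from compressed_tensors.quantization as in 0.5.0; the field values are illustrative.

    from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

    # strategy is now inferred by the model validator rather than a field validator
    args = QuantizationArgs(num_bits=4, group_size=128, actorder=True)
    assert args.strategy == QuantizationStrategy.GROUP
    # booleans are normalized by validate_actorder: True -> ActivationOrdering.GROUP ("group")
    assert args.actorder == "group"

    # actorder with a non-group strategy is rejected by validate_model_after
    # QuantizationArgs(num_bits=8, actorder="weight")  # raises ValueError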
compressed_tensors/quantization/quant_config.py

@@ -24,7 +24,7 @@ from compressed_tensors.quantization.quant_scheme import (
  from compressed_tensors.quantization.utils import (
      calculate_compression_ratio,
      is_module_quantized,
-     iter_named_leaf_modules,
+     iter_named_quantizable_modules,
      module_type,
      parse_out_kv_cache_args,
  )
@@ -177,7 +177,9 @@ class QuantizationConfig(BaseModel):
          quantization_status = None
          ignore = {}
          quantization_type_names = set()
-         for name, submodule in iter_named_leaf_modules(model):
+         for name, submodule in iter_named_quantizable_modules(
+             model, include_children=True, include_attn=True
+         ):
              layer_type = module_type(submodule)
              if not is_module_quantized(submodule):
                  if layer_type not in ignore:
@@ -199,6 +201,13 @@ class QuantizationConfig(BaseModel):
          if len(quant_scheme_to_layers) == 0:  # No quantized layers
              return None

+         # kv-cache only, no weight/activation quantization
+         if (
+             len(quantization_type_names) == 1
+             and "attention" in list(quantization_type_names)[0].lower()
+         ):
+             quantization_type_names.add("Linear")
+
          # clean up ignore list, we can leave out layers types if none of the
          # instances are quantized
          consolidated_ignore = []
@@ -241,6 +250,9 @@ class QuantizationConfig(BaseModel):
          )

      def requires_calibration_data(self):
+         if self.kv_cache_scheme is not None:
+             return True
+
          for _, scheme in self.config_groups.items():
              if scheme.input_activations is not None:
                  if not scheme.input_activations.dynamic:
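
Usage sketch (not part of the package diff) of the new kv-cache short-circuit, assuming config_groups remains the only required field of QuantizationConfig:

    from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig

    # a config that only quantizes the kv cache now always requires calibration data
    config = QuantizationConfig(
        config_groups={},
        kv_cache_scheme=QuantizationArgs(num_bits=8, type="float"),
    )
    assert config.requires_calibration_data()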
compressed_tensors/quantization/quant_scheme.py

@@ -57,15 +57,9 @@ class QuantizationScheme(BaseModel):
          # default to quantizing all Linear layers
          targets = ["Linear"]

-         # default to 8 bit integer symmetric quantization
-         # for weights
-         weights = QuantizationArgs(num_bits=8, symmetric=True)
-
-         # default to 8 bit integer asymmetric quantization
-         input_activations = QuantizationArgs(num_bits=8, symmetric=True)
-
-         # Do not quantize the output activations
-         # by default
+         # by default, activations and weights are left unquantized
+         weights = None
+         input_activations = None
          output_activations = None

          return cls(
@@ -111,8 +105,10 @@ def is_preset_scheme(name: str) -> bool:
      return name.upper() in PRESET_SCHEMES


+ UNQUANTIZED = dict()
+
  # 8 bit integer weights and 8 bit activations quantization
- W8A8 = dict(
+ INT8_W8A8 = dict(
      weights=QuantizationArgs(
          num_bits=8,
          type=QuantizationType.INT,
@@ -153,7 +149,7 @@ W4A16 = dict(
  )

  # 4 bit integer weights and 8 bit activations quantization
- W4A8 = dict(
+ INT8_W4A8 = dict(
      weights=QuantizationArgs(
          num_bits=4,
          type=QuantizationType.INT,
@@ -208,12 +204,15 @@ FP8_DYNAMIC = dict(
  )

  PRESET_SCHEMES = {
+     # Unquantized (no-op)
+     "UNQUANTIZED": UNQUANTIZED,
      # Integer weight only schemes
      "W8A16": W8A16,
      "W4A16": W4A16,
      # Integer weight and activation schemes
-     "W8A8": W8A8,
-     "W4A8": W4A8,
+     "W8A8": INT8_W8A8,
+     "INT8": INT8_W8A8,  # alias for W8A8
+     "W4A8": INT8_W4A8,
      # Float weight and activation schemes
      "FP8": FP8,
      "FP8_DYNAMIC": FP8_DYNAMIC,
compressed_tensors/quantization/utils/helpers.py

@@ -13,8 +13,7 @@
  # limitations under the License.

  import logging
- import re
- from typing import List, Optional, Tuple
+ from typing import Generator, List, Optional, Tuple

  import torch
  from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@ __all__ = [
      "infer_quantization_status",
      "is_module_quantized",
      "is_model_quantized",
-     "iter_named_leaf_modules",
      "module_type",
      "calculate_compression_ratio",
      "get_torch_bit_depth",
@@ -36,9 +34,14 @@ __all__ = [
      "parse_out_kv_cache_args",
      "KV_CACHE_TARGETS",
      "is_kv_cache_quant_scheme",
+     "iter_named_leaf_modules",
+     "iter_named_quantizable_modules",
  ]

- KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+ # target the self_attn layer
+ # QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+ KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
  _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
      return type(module).__name__


- def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
      """
      Yields modules that do not have any submodules except observers. The observers
      themselves are not yielded
-
      :param model: model to get leaf modules of
      :returns: generator tuple of (name, leaf_submodule)
      """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
              yield name, submodule


+ def iter_named_quantizable_modules(
+     model: Module, include_children: bool = True, include_attn: bool = False
+ ) -> Generator[Tuple[str, Module], None, None]:
+     """
+     Yield name and submodule of
+     - leaf modules, set by include_children
+     - attention modyles, set by include_attn
+
+     :param model: model to get leaf modules of
+     :param include_children: flag to get the leaf modules
+     :param inlcude_attn: flag to get the attention modules
+     :returns: generator tuple of (name, submodule)
+     """
+     for name, submodule in model.named_modules():
+         if include_children:
+             children = list(submodule.children())
+             if len(children) == 0 and not isinstance(submodule, Observer):
+                 yield name, submodule
+             else:
+                 has_non_observer_children = False
+                 for child in children:
+                     if not isinstance(child, Observer):
+                         has_non_observer_children = True
+
+                 if not has_non_observer_children:
+                     yield name, submodule
+         if include_attn:
+             if name.endswith("self_attn"):
+                 yield name, submodule
+
+
  def get_torch_bit_depth(value: torch.Tensor) -> int:
      """
      Determine the number of bits used to represent the dtype of a tensor
@@ -181,7 +214,7 @@ def calculate_compression_ratio(model: Module) -> float:
          for parameter in model.parameters():
              uncompressed_bits = get_torch_bit_depth(parameter)
              compressed_bits = uncompressed_bits
-             if is_module_quantized(submodule):
+             if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
                  compressed_bits = submodule.quantization_scheme.weights.num_bits

              num_weights = parameter.numel()
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
      :param scheme: The QuantizationScheme to investigate
      :return: boolean flag
      """
-     if len(scheme.targets) == 1:
-         # match on the KV_CACHE_TARGETS regex pattern
-         # if there is only one target
-         is_match_targets = any(
-             [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-         )
-     else:
-         # match on the exact KV_CACHE_TARGETS
-         # if there are multiple targets
-         is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+     for target in scheme.targets:
+         if target in KV_CACHE_TARGETS:
+             return True

-     is_match_output_activations = scheme.output_activations is not None
-     return is_match_targets and is_match_output_activations
+     return False


  def parse_out_kv_cache_args(
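
Usage sketch (not part of the package diff) of iter_named_quantizable_modules on a toy model; the class and attribute names are illustrative, chosen so the attribute name ends in "self_attn" like the new KV_CACHE_TARGETS pattern.

    import torch
    from compressed_tensors.quantization.utils import iter_named_quantizable_modules

    class ToyAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.k_proj = torch.nn.Linear(8, 8)
            self.v_proj = torch.nn.Linear(8, 8)

    class ToyDecoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = ToyAttention()  # attribute name ends in "self_attn"

    model = ToyDecoder()
    names = [n for n, _ in iter_named_quantizable_modules(model, include_children=True, include_attn=True)]
    # names == ["self_attn", "self_attn.k_proj", "self_attn.v_proj"]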
compressed_tensors/utils/__init__.py

@@ -16,5 +16,6 @@
  from .helpers import *
  from .offload import *
  from .permutations_24 import *
+ from .permute import *
  from .safetensors_load import *
  from .semi_structured_conversions import *
compressed_tensors/utils/helpers.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Optional
+ from typing import Any, Optional

  import torch
  from transformers import AutoConfig
@@ -22,6 +22,8 @@ __all__ = [
      "infer_compressor_from_model_config",
      "fix_fsdp_module_name",
      "tensor_follows_mask_structure",
+     "replace_module",
+     "is_compressed_tensors_config",
  ]

  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -90,3 +92,30 @@ def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
          raise ValueError()

      return True
+
+
+ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
+     if "." in name:
+         parent_name = name.rsplit(".", 1)[0]
+         child_name = name[len(parent_name) + 1 :]
+         parent = model.get_submodule(parent_name)
+     else:
+         parent_name = ""
+         parent = model
+         child_name = name
+     setattr(parent, child_name, new_module)
+
+
+ def is_compressed_tensors_config(compression_config: Any) -> bool:
+     """
+     Returns True if CompressedTensorsConfig is available from transformers and
+     compression_config is an instance of CompressedTensorsConfig
+
+     See: https://github.com/huggingface/transformers/pull/31704
+     """
+     try:
+         from transformers.utils.quantization_config import CompressedTensorsConfig
+
+         return isinstance(compression_config, CompressedTensorsConfig)
+     except ImportError:
+         return False
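
Usage sketch (not part of the package diff) of the new replace_module helper; the module sizes are arbitrary.

    import torch
    from compressed_tensors.utils import replace_module

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    # replace a submodule by its dotted name; here "0" is the first child of the Sequential
    replace_module(model, "0", torch.nn.Linear(16, 32))
    assert model[0].out_features == 32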
compressed_tensors/utils/offload.py

@@ -40,7 +40,13 @@ def get_execution_device(module: Module) -> torch.device:
      """
      if is_module_offloaded(module):
          return module._hf_hook.execution_device
-     return next(module.parameters()).device
+     device = next(module.parameters()).device
+
+     # offload only gets set for leaf modules, fallback to checking for device type
+     if device.type == "meta":
+         return module._hf_hook.execution_device
+
+     return device


  def get_offloaded_device(module: Module) -> torch.device:
@@ -83,8 +89,11 @@ def update_parameter_data(

      :param module: layer containing the parameter to update
      :param new_param_data: tensor to update parameter with
-     :param param_name:
+     :param param_name: name of layer parameter to update
      """
+     if not hasattr(module, param_name):
+         return
+
      device = next(module.parameters()).device

      offloaded = False
@@ -93,6 +102,9 @@ def update_parameter_data(
          offloaded = True

      parameter = getattr(module, param_name, None)
+     if parameter is None:
+         raise ValueError("Attempted to update uninitialized parameter")
+
      dtype = parameter.dtype
      parameter.data = new_param_data.to(device).to(dtype)

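
Usage sketch (not part of the package diff) of the revised update_parameter_data, assuming the (module, new_param_data, param_name) argument order implied by the docstring above.

    import torch
    from compressed_tensors.utils import update_parameter_data

    layer = torch.nn.Linear(4, 4)

    # existing parameter: data is cast to the layer's device/dtype and written in place
    update_parameter_data(layer, torch.zeros(4, 4), "weight")
    assert torch.all(layer.weight == 0)

    # missing parameter name: the call is now a silent no-op instead of an AttributeError
    update_parameter_data(layer, torch.zeros(4), "weight_scale")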
compressed_tensors/utils/permute.py (new file)

@@ -0,0 +1,70 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Set, Tuple
+
+ import torch
+
+
+ __all__ = ["safe_permute"]
+
+
+ # these datatypes are missing implementations required for standard permutation
+ _EXPERIMENTAL_DTYPES: Set[Tuple[torch.dtype, torch.device]] = set()
+
+
+ def safe_permute(value: torch.Tensor, perm: torch.Tensor, dim: int = 0) -> torch.Tensor:
+     """
+     Perform out-of-place permutation without using torch.Tensor.index_put_,
+     whose implementation is missing for datatypes such as `torch.float8_e4m3fn`
+
+     :param value: tensor to permute
+     :param perm: permutation map
+     :param dim: dimension along which to apply permutation
+     :return: permuted value
+     """
+     dtype_tuple = (value.dtype, value.device)
+
+     if dtype_tuple in _EXPERIMENTAL_DTYPES:
+         return _fallback_permute(value, perm, dim)
+
+     try:
+         return value[tuple([slice(None)] * dim + [perm])]
+     except RuntimeError:
+         # Mark dtype as experimental if advanced indexing fails
+         _EXPERIMENTAL_DTYPES.add(dtype_tuple)
+         return _fallback_permute(value, perm, dim)
+
+
+ def _fallback_permute(
+     value: torch.Tensor, perm: torch.Tensor, dim: int
+ ) -> torch.Tensor:
+     """
+     Fallback permutation method for experimental dtypes.
+
+     :param value: tensor to permute
+     :param perm: permutation map
+     :param dim: dimension along which to apply permutation
+     :return: permuted value
+     """
+     value_ret = value.clone()  # cannot use zeros_like b/c of missing impl.
+     orig_slices = [slice(None)] * (dim + 1)
+     perm_slices = [slice(None)] * (dim + 1)
+
+     for index, perm_index in enumerate(perm):
+         orig_slices[dim] = index
+         perm_slices[dim] = perm_index
+         value_ret[tuple(orig_slices)] = value[tuple(perm_slices)]
+
+     return value_ret
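
Usage sketch (not part of the package diff) for safe_permute; the tensor values are illustrative.

    import torch
    from compressed_tensors.utils import safe_permute

    weight = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
    perm = torch.tensor([2, 0, 1])

    # rows are reordered out of place; for dtypes whose advanced indexing raises
    # (e.g. torch.float8_e4m3fn on some devices) the slower fallback loop is used
    reordered = safe_permute(weight, perm, dim=0)
    # reordered == [[3., 3.], [1., 1.], [2., 2.]]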
compressed_tensors/utils/safetensors_load.py

@@ -234,5 +234,7 @@ def is_quantization_param(name: str) -> bool:
          return True
      if name.endswith("zero_point"):
          return True
+     if name.endswith("g_idx"):
+         return True

      return False
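
Usage sketch (not part of the package diff): with the g_idx suffix added, GPTQ-style activation-ordering indices are detected alongside scales and zero points; the tensor names below are illustrative.

    from compressed_tensors.utils.safetensors_load import is_quantization_param

    assert is_quantization_param("model.layers.0.self_attn.k_proj.weight_g_idx")
    assert is_quantization_param("model.layers.0.self_attn.k_proj.weight_zero_point")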
compressed_tensors/utils/semi_structured_conversions.py

@@ -28,6 +28,7 @@ __all__ = [
      "mask_creator",
  ]

+
  # This is PyTorch implementation of main part of reorder_meta()
  # function, from tools/util/include/cutlass/util/host_reorder.h file
  # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
compressed_tensors/version.py

@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
  """


- version_base = "0.5.0"
+ version_base = "0.7.0"
  is_release = True  # change to True to set the generated version as a release version
