compressed-tensors 0.10.3a20250715__py3-none-any.whl → 0.10.3a20250721__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/quantization/lifecycle/forward.py +68 -5
- compressed_tensors/quantization/quant_args.py +31 -8
- compressed_tensors/quantization/quant_scheme.py +41 -0
- compressed_tensors/quantization/utils/helpers.py +11 -2
- compressed_tensors/transform/factory/base.py +1 -3
- compressed_tensors/transform/factory/hadamard.py +17 -8
- compressed_tensors/transform/factory/matrix_multiply.py +18 -8
- compressed_tensors/transform/transform_scheme.py +2 -1
- compressed_tensors/transform/utils/hadamard.py +2 -2
- compressed_tensors/transform/utils/matrix.py +179 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/METADATA +1 -1
- {compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/RECORD +16 -16
- compressed_tensors/transform/utils/utils.py +0 -91
- {compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
CHANGED
@@ -111,11 +111,18 @@ def dequantize(
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-            else:
+            # Scale height matches input or is 1 -> group quantization across columns
+            #
+            # Example 1: scale.shape[0] == 1
+            #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+            #
+            # Example 2: scale.shape[0] == x_q.shape[0]
+            #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
-                args = QuantizationArgs(
-                    strategy=QuantizationStrategy.GROUP, group_size=group_size
-                )
+                args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+            else:
+                args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
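As a sanity check of the inference rule above, here is a small standalone sketch (plain torch, not a call into the library) that maps 2D scale shapes to the strategy the new branches would pick:

```python
import torch

x_q = torch.zeros(4, 8)
for scale_shape in [(4, 1), (1, 4), (4, 4), (2, 2)]:
    scale = torch.ones(scale_shape)
    if scale.shape[1] == 1:
        strategy = "CHANNEL"
    elif scale.shape[0] in (1, x_q.shape[0]):
        # one scale per group of columns (per row if scale has full height)
        strategy = f"GROUP(group_size={x_q.shape[1] // scale.shape[1]})"
    else:
        # one scale per 2D tile of the tensor
        strategy = f"BLOCK(block_structure={list(scale.shape)})"
    print(scale_shape, "->", strategy)
```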
@@ -189,7 +196,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
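To make the reshape concrete, here is a self-contained sketch of the same block round trip, assuming a float32 2D weight and the FP8-e4m3 max of 448 used elsewhere in the library; unlike `_quantize`, it skips rounding, so the round trip is exact up to float error:

```python
import torch

x = torch.randn(256, 512)
bh, bw = 128, 128
# (rows//bh, cols//bw, bh, bw): each (bh, bw) tile becomes one contiguous block
blocks = x.reshape(256 // bh, bh, 512 // bw, bw).transpose(1, 2)
# one scale per block, broadcastable against the block dims
scale = blocks.abs().amax(dim=(-1, -2), keepdim=True) / 448.0
x_q = (blocks / scale).clamp(-448.0, 448.0)   # fake-quantize (no FP8 rounding)
x_dq = (x_q * scale).transpose(1, 2).reshape(x.shape)
assert torch.allclose(x, x_dq, atol=1e-5)
```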
compressed_tensors/quantization/quant_args.py
CHANGED
@@ -14,7 +14,7 @@
 
 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy, must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
         return value
 
+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                return [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
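Assuming an installed build with this change, the validator accepts both the legacy string form and the new list form:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

legacy = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure="128x128")
modern = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=[128, 128])
assert legacy.block_structure == modern.block_structure == [128, 128]
```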
@@ -277,14 +299,15 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
         # infer observer w.r.t. dynamic
         if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )
 
         if (
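With `QuantizationStrategy.GROUP` added to the supported set, dynamic per-group activation arguments like the following now validate instead of raising (a hedged sketch; remaining fields take their defaults):

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

args = QuantizationArgs(
    num_bits=8,
    strategy=QuantizationStrategy.GROUP,
    group_size=128,
    dynamic=True,
    observer=None,  # dynamic quantization observes ranges at inference time
)
```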
compressed_tensors/quantization/quant_scheme.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
+        weights = model.weights
 
         if inputs is not None:
             if inputs.actorder is not None:
@@ -61,6 +63,21 @@ class QuantizationScheme(BaseModel):
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")
 
+        if (
+            inputs and weights
+            and weights.strategy == QuantizationStrategy.GROUP
+            and inputs.strategy == QuantizationStrategy.GROUP
+            and weights.group_size != inputs.group_size
+        ):
+            warnings.warn(
+                "Using GROUP strategy for both weights and input_activations "
+                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+                "may complicate fused kernel implementations. Consider using "
+                "TENSOR_GROUP strategy for both or matching group sizes.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         return model
 
 
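A hedged way to see the new warning fire: build a scheme whose weight and activation group sizes disagree and capture the `UserWarning` (import paths assume the package's public `compressed_tensors.quantization` namespace):

```python
import warnings

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=128),
        input_activations=QuantizationArgs(
            strategy=QuantizationStrategy.GROUP, group_size=64, dynamic=True, observer=None
        ),
    )
    assert any("group sizes" in str(w.message) for w in caught)
```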
@@ -243,6 +260,29 @@ FP8_DYNAMIC = dict(
     ),
 )
 
+# Block-wise FP8 (deepseekv3-style quantization):
+# static 128x128 per-block weights and
+# dynamic per-token-group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +297,7 @@ PRESET_SCHEMES = {
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }
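Assuming the existing `preset_name_to_scheme` helper in `quant_scheme.py` (which reads `PRESET_SCHEMES`), the new preset resolves by name like any other:

```python
from compressed_tensors.quantization import preset_name_to_scheme

scheme = preset_name_to_scheme("FP8_BLOCK", targets=["Linear"])
print(scheme.weights.block_structure)       # [128, 128]
print(scheme.input_activations.group_size)  # 128
```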
compressed_tensors/quantization/utils/helpers.py
CHANGED
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)
 
@@ -187,9 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)}",
+            f"{supported_strategies}",
        )
 
     if not reduce_dims:
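For intuition, a rough standalone sketch of the reduction a dynamic GROUP scale performs; the library helper also handles zero points, observers, and dtype clamping, and 448.0 is the assumed FP8-e4m3 max:

```python
import torch

value = torch.randn(4, 256)          # (tokens, hidden)
group_size = 128
grouped = value.reshape(value.shape[0], -1, group_size)   # (tokens, groups, group_size)
scales = grouped.abs().amax(dim=-1, keepdim=True) / 448.0  # one scale per token per group
print(scales.shape)                  # torch.Size([4, 2, 1])
```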
compressed_tensors/transform/factory/base.py
CHANGED
@@ -117,10 +117,8 @@ class TransformFactory(RegistryMixin, ABC):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
 
compressed_tensors/transform/factory/hadamard.py
CHANGED
@@ -14,13 +14,14 @@
 
 from typing import Optional, Union
 
+import math
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -51,8 +52,8 @@ class HadamardFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, torch.nn.Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
@@ -60,7 +61,7 @@ class HadamardFactory(TransformFactory):
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))
 
     def _create_weight(
         self,
def _create_weight(
|
66
67
|
self,
|
@@ -81,12 +82,18 @@ class HadamardFactory(TransformFactory):
|
|
81
82
|
|
82
83
|
class HadamardTransform(TransformBase):
|
83
84
|
def __init__(
|
84
|
-
self,
|
85
|
+
self,
|
86
|
+
weight: Parameter,
|
87
|
+
perm: Optional[Parameter],
|
88
|
+
args: TransformArgs,
|
89
|
+
module_type: type[torch.nn.Module],
|
85
90
|
):
|
86
91
|
super().__init__()
|
87
92
|
self.weight = weight
|
88
93
|
self.perm = perm
|
89
94
|
self.args = args
|
95
|
+
self.module_type = module_type
|
96
|
+
self._scale = math.sqrt(weight.size(0))
|
90
97
|
|
91
98
|
def forward(self, value: Tensor) -> Tensor:
|
92
99
|
weight = self.weight
|
@@ -96,5 +103,7 @@ class HadamardTransform(TransformBase):
 
         if self.args.inverse:
             weight = weight.T
-
-        return apply_transform_weight(weight, value, self.args.location)
+
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        ) / self._scale
compressed_tensors/transform/factory/matrix_multiply.py
CHANGED
@@ -17,9 +17,9 @@ from typing import Optional
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -50,8 +50,8 @@ class RandomMatrixFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, torch.nn.Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
 
@@ -59,7 +59,7 @@ class RandomMatrixFactory(TransformFactory):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@ class RandomMatrixFactory(TransformFactory):
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
compressed_tensors/transform/transform_scheme.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    head_dim: Optional[int] = Field(default=None)
compressed_tensors/transform/utils/hadamard.py
CHANGED
@@ -59,7 +59,7 @@ def deterministic_hadamard_matrix(
     for _ in range(log2):
         H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
 
-    return H / math.sqrt(size)
+    return H
 
 
 def random_hadamard_matrix(
@@ -86,7 +86,7 @@ def random_hadamard_matrix(
     Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
-    return _matmul_hadU(Q) / math.sqrt(size)
+    return _matmul_hadU(Q)
 
 
 def is_pow2(n: int) -> bool:
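Both returns drop the creation-time normalization, which now lives in `HadamardTransform._scale` and is applied once per forward. A standalone check that the factored normalization still composes to the identity, since `H @ H.T = n * I` for an unnormalized Hadamard matrix:

```python
import math

import torch

n = 8
H = torch.tensor([[1.0]])
while H.shape[0] < n:  # Sylvester construction, unnormalized
    H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

x = torch.randn(4, n)
y = (x @ H) / math.sqrt(n)        # forward transform with deferred scale
x_back = (y @ H.T) / math.sqrt(n)  # inverse transform
assert torch.allclose(x, x_back, atol=1e-5)
```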
|
@@ -0,0 +1,179 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Callable, Optional, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from compressed_tensors.transform import TransformLocation
|
19
|
+
|
20
|
+
|
21
|
+
__all__ = ["get_transform_size", "apply_transform_weight"]
|
22
|
+
|
23
|
+
|
24
|
+
def get_transform_size(
|
25
|
+
module: torch.nn.Module,
|
26
|
+
location: TransformLocation,
|
27
|
+
head_dim: Optional[int] = None,
|
28
|
+
) -> int:
|
29
|
+
"""
|
30
|
+
Determine the size of a transform matrix given its location on the module
|
31
|
+
|
32
|
+
:param module: module that matrix will be applied to
|
33
|
+
:param location: location on module
|
34
|
+
:param head_dim: size of head when transform is applied to mha
|
35
|
+
:return: size of matrix
|
36
|
+
"""
|
37
|
+
if isinstance(module, torch.nn.Linear):
|
38
|
+
if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
|
39
|
+
size = module.in_features
|
40
|
+
else:
|
41
|
+
size = module.out_features
|
42
|
+
elif isinstance(module, torch.nn.Embedding):
|
43
|
+
if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
|
44
|
+
size = module.num_embeddings
|
45
|
+
else:
|
46
|
+
size = module.embedding_dim
|
47
|
+
else:
|
48
|
+
raise NotImplementedError(f"Transforms on {type(module)} are not supported")
|
49
|
+
|
50
|
+
if head_dim is not None:
|
51
|
+
if size % head_dim != 0:
|
52
|
+
raise ValueError(
|
53
|
+
f"{head_dim} must divide {size} for {type(module)} at {location}"
|
54
|
+
)
|
55
|
+
|
56
|
+
size = head_dim
|
57
|
+
|
58
|
+
return size
|
59
|
+
|
60
|
+
|
61
|
+
def apply_transform_weight(
|
62
|
+
transform_weight: torch.Tensor,
|
63
|
+
value: torch.Tensor,
|
64
|
+
location: TransformLocation,
|
65
|
+
module_type: type[torch.nn.Module],
|
66
|
+
) -> torch.Tensor:
|
67
|
+
"""
|
68
|
+
Using the transform location, apply the transform_weight to the
|
69
|
+
given value wrt linear weights. For more info on input and output transforms,
|
70
|
+
see `TransformLocation`
|
71
|
+
|
72
|
+
The following explains how weights should be applied to values according to location
|
73
|
+
|
74
|
+
let x be input activation
|
75
|
+
W be weight,
|
76
|
+
yh, xh, Wh be transformed output, input, weight
|
77
|
+
|
78
|
+
note that
|
79
|
+
y = (x W.T) // torch.nn.Linear
|
80
|
+
|
81
|
+
Choose values for yh, xh, and Wh which incorporate matrix transforms
|
82
|
+
|
83
|
+
let V, Vi be transform matrices on input side
|
84
|
+
U, Ui be transform matrices on output side
|
85
|
+
|
86
|
+
pick xh = (x V)
|
87
|
+
Wh = (U.T W Vi.T)
|
88
|
+
yh = (y U)
|
89
|
+
|
90
|
+
The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
|
91
|
+
|
92
|
+
(xh) (Wh).T = (x V) (U.T W Vi.T).T
|
93
|
+
= (x V) (Vi W.T U) // transpose matrix product identity
|
94
|
+
= (x W.T) U
|
95
|
+
= y U
|
96
|
+
= yh
|
97
|
+
|
98
|
+
:param transform_weight: transform weight to apply
|
99
|
+
:param value: value to apply transform_weight to
|
100
|
+
:param location: determines how weight should be applied
|
101
|
+
:param model_type: result of type(module), passed in to determine application of
|
102
|
+
weight transform
|
103
|
+
:return: value after transform_weight has been applied
|
104
|
+
"""
|
105
|
+
|
106
|
+
assert transform_weight.shape[0] == transform_weight.shape[1]
|
107
|
+
|
108
|
+
if module_type == torch.nn.Linear:
|
109
|
+
if location == TransformLocation.INPUT:
|
110
|
+
return _multihead_matmul(value, transform_weight)
|
111
|
+
|
112
|
+
elif location == TransformLocation.WEIGHT_INPUT:
|
113
|
+
# equivalent to (transform_weight @ value.T).T
|
114
|
+
return _multihead_matmul(value, transform_weight.T)
|
115
|
+
|
116
|
+
elif location == TransformLocation.WEIGHT_OUTPUT:
|
117
|
+
# equivalent to (value.T @ transform_weight).T
|
118
|
+
return _multihead_matmul(transform_weight.T, value)
|
119
|
+
|
120
|
+
elif location == TransformLocation.OUTPUT:
|
121
|
+
return _multihead_matmul(value, transform_weight)
|
122
|
+
|
123
|
+
# similar derivation to torch.nn.Linear, but `y = (x W)`
|
124
|
+
elif module_type == torch.nn.Embedding:
|
125
|
+
if location == TransformLocation.INPUT:
|
126
|
+
return _multihead_matmul(value, transform_weight)
|
127
|
+
|
128
|
+
elif location == TransformLocation.WEIGHT_INPUT:
|
129
|
+
return _multihead_matmul(
|
130
|
+
transform_weight,
|
131
|
+
value,
|
132
|
+
)
|
133
|
+
|
134
|
+
elif location == TransformLocation.WEIGHT_OUTPUT:
|
135
|
+
return _multihead_matmul(value, transform_weight)
|
136
|
+
|
137
|
+
elif location == TransformLocation.OUTPUT:
|
138
|
+
return _multihead_matmul(value, transform_weight)
|
139
|
+
|
140
|
+
raise NotImplementedError(
|
141
|
+
f"Applying transforms to {module_type} {location} is not supported"
|
142
|
+
)
|
143
|
+
|
144
|
+
|
145
|
+
def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
|
146
|
+
"""
|
147
|
+
Performs A @ B for last two dims of two matrices A and B that possibly
|
148
|
+
have different shapes, as is the case in multi-headed dimension. If
|
149
|
+
shapes are different, this is equivalent to converting the last two dims
|
150
|
+
of the smaller matrix into a block-diagonal matrix with the same shape as
|
151
|
+
the last two dims of the larger matrix.
|
152
|
+
|
153
|
+
E.g. if A is half the size of B, this function will perform
|
154
|
+
[[A ] @ B
|
155
|
+
[ A]]
|
156
|
+
|
157
|
+
If B is a third of the size of A, this function will perform
|
158
|
+
A @ [[B ]
|
159
|
+
[ B ]
|
160
|
+
[ B]]
|
161
|
+
|
162
|
+
This function will error out if the shapes are not evenly divisble
|
163
|
+
|
164
|
+
:param A: left-hand tensor
|
165
|
+
:param B: right-hand tensor
|
166
|
+
:return: result
|
167
|
+
"""
|
168
|
+
if A.shape[-1] > B.shape[-2]:
|
169
|
+
head_dim = B.shape[-2]
|
170
|
+
num_heads = A.shape[-1] // head_dim
|
171
|
+
A = A.unflatten(-1, (num_heads, head_dim))
|
172
|
+
return (A @ B).flatten(-2, -1)
|
173
|
+
elif A.shape[-1] < B.shape[-2]:
|
174
|
+
head_dim = A.shape[-1]
|
175
|
+
num_heads = B.shape[-2] // head_dim
|
176
|
+
B = B.unflatten(-2, (num_heads, head_dim))
|
177
|
+
return (A @ B).flatten(-3, -2)
|
178
|
+
else:
|
179
|
+
return A @ B
|
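A quick numeric check (plain torch, not library code) of the block-diagonal equivalence described in the `_multihead_matmul` docstring: applying a `head_dim`-sized transform per head matches multiplying by an explicit block-diagonal matrix.

```python
import torch

num_heads, head_dim = 3, 4
A = torch.randn(2, num_heads * head_dim)   # e.g. activations, last dim spans all heads
B = torch.randn(head_dim, head_dim)        # per-head transform

# per-head application, as _multihead_matmul does for A.shape[-1] > B.shape[-2]
per_head = (A.unflatten(-1, (num_heads, head_dim)) @ B).flatten(-2, -1)
# explicit block-diagonal equivalent
block_diag = A @ torch.block_diag(*[B] * num_heads)
assert torch.allclose(per_head, block_diag, atol=1e-6)
```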
{compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250715
+Version: 0.10.3a20250721
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=
+compressed_tensors/version.py,sha256=d45NnXPlUzem78cnmH1dUdhwbxzgKieO5KjFyfBwSZA,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -26,33 +26,33 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=
+compressed_tensors/quantization/quant_args.py,sha256=yKTj_4lAy_pnXeTCyUADpyz2qAzJXYJU2P03NF_TP68,12835
 compressed_tensors/quantization/quant_config.py,sha256=w6sEEZGVGIF0Ub2r_cqRfZwbkBT8WzfY3ug52olmjGY,10049
-compressed_tensors/quantization/quant_scheme.py,sha256=
+compressed_tensors/quantization/quant_scheme.py,sha256=qApRLsPxELe5S2qFv8OVyAZ5TpRL7gT35i4U3c9PAwI,8461
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=jT70Mbbu9pH10vu5ALVD7VWGoFdMEUpxmihGrf4frjM,17432
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=D7yxua1zELmsBYlQiJUTiClBOMIe2J0-IrN2d-jLFPk,8653
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=
+compressed_tensors/quantization/utils/helpers.py,sha256=Je96Wai9SOizbdE5ph0nsJ86zS96lE4fkf_9q9o2tpA,17212
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
 compressed_tensors/transform/__init__.py,sha256=v2wfl4CMfA6KbD7Hxx_MbRev63y_6QLDlccZq-WTtdw,907
 compressed_tensors/transform/apply.py,sha256=Cnc7Q8d8FzpLGtXixvdPzqApfjAXpfShxvVl_7nNJ4E,1259
 compressed_tensors/transform/transform_args.py,sha256=jJY-Qt996w45LWQ10AHd7tUtNrnV9mjD9M5D4SZ5B3E,3199
 compressed_tensors/transform/transform_config.py,sha256=A3RuLNDqBNEByQNeu40Kg7sItwE6kWgnX18Umg1uONI,2128
-compressed_tensors/transform/transform_scheme.py,sha256=
+compressed_tensors/transform/transform_scheme.py,sha256=uGLC4avdbhrVqNC3-Eo0p7WzNRQK92Fpg0N9hWiuCRQ,1752
 compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/factory/base.py,sha256=
-compressed_tensors/transform/factory/hadamard.py,sha256=
-compressed_tensors/transform/factory/matrix_multiply.py,sha256=
+compressed_tensors/transform/factory/base.py,sha256=w9ic5eSxfSNn2Xju-xZvG4_iXAIsJCU56qik8w---aI,5994
+compressed_tensors/transform/factory/hadamard.py,sha256=iJ2OyKitR2Duw0z5Jqj69GTih2C1WtHRXQCTtATaTtw,4180
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=LdoV2E12HTucmUWcw7UKOpRNnL8QhOOIUnNVlpOpGiI,3925
 compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Umme71pnjMPgwYoGlwjKlU27UHZ4,1634
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/utils/hadamard.py,sha256=
+compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
 compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
-compressed_tensors/transform/utils/utils.py,sha256=
+compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiPAAjCM9Ejz77wb-w,6181
 compressed_tensors/utils/__init__.py,sha256=QFQzF6MpV3yStajPzYktZkmvZsxvfpKUZq2oGbd1Cvw,832
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
@@ -61,8 +61,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.3a20250715.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250715.dist-info/METADATA,sha256=
-compressed_tensors-0.10.3a20250715.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250715.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250715.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250721.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250721.dist-info/METADATA,sha256=SZk9s5bqtGXwCSxZgsD1tg5GrTH4M8x0d8A3WQkrtlw,7031
+compressed_tensors-0.10.3a20250721.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250721.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250721.dist-info/RECORD,,
compressed_tensors/transform/utils/utils.py
DELETED
@@ -1,91 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from compressed_tensors.transform import TransformLocation
-
-
-__all__ = ["get_matrix_size", "apply_transform_weight"]
-
-
-def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
-    """
-    Determine the size of a matrix given its location on the module
-
-    :param module: module that matrix will be applied to
-    :param location: location on module
-    :return: size of matrix
-    """
-    assert isinstance(module, torch.nn.Linear)
-    if location in ("input", TransformLocation.WEIGHT_INPUT):
-        return module.in_features
-    else:
-        return module.out_features
-
-
-def apply_transform_weight(
-    weight: torch.Tensor,
-    value: torch.Tensor,
-    location: TransformLocation,
-) -> torch.Tensor:
-    """
-    Using the transform location, determine how to apply the transform weight to the
-    given value. For more info on input and output transforms, see `TransformLocation`
-
-    The following explains how weights should be applied to values according to location
-
-    let  x  be input activation
-         W  be weight,
-         yh, xh, Wh be transformed output, input, weight
-
-    note that
-        y  = (x W.T)       // torch.nn.Linear
-
-    Choose values for yh, xh, and Wh which incorporate matrix transforms
-
-    let  V, Vi be transform matrices on input side
-         U, Ui be transform matrices on output side
-
-    pick xh = (x V)
-         Wh = (U.T W Vi.T)
-         yh = (y U)
-
-    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
-
-    (xh) (Wh).T = (x V) (U.T W Vi.T).T
-                = (x V) (Vi W.T U)    // transpose matrix product identity
-                = (x W.T) U
-                = y U
-                = yh
-
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
-    """
-
-    if location == TransformLocation.INPUT:
-        return value @ weight
-
-    elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
-
-    elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
-
-    elif location == TransformLocation.OUTPUT:
-        return value @ weight
-
-    else:
-        raise NotImplementedError(f"{location} has not been implemented yet")
{compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/WHEEL
RENAMED
File without changes
{compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/licenses/LICENSE
RENAMED
File without changes
{compressed_tensors-0.10.3a20250715.dist-info → compressed_tensors-0.10.3a20250721.dist-info}/top_level.txt
RENAMED
File without changes