compressed-tensors 0.9.5a20250521__py3-none-any.whl → 0.9.5a20250530__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- compressed_tensors/quantization/lifecycle/forward.py +16 -3
- compressed_tensors/quantization/lifecycle/initialize.py +44 -36
- compressed_tensors/quantization/quant_args.py +49 -12
- compressed_tensors/quantization/quant_config.py +2 -2
- compressed_tensors/quantization/quant_scheme.py +23 -1
- compressed_tensors/quantization/utils/helpers.py +31 -6
- compressed_tensors/transform/__init__.py +20 -0
- compressed_tensors/transform/transform_args.py +54 -0
- compressed_tensors/transform/transform_config.py +73 -0
- compressed_tensors/transform/transform_scheme.py +43 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/METADATA +1 -1
- {compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/RECORD +16 -12
- {compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/WHEEL +1 -1
- {compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
CHANGED

@@ -18,6 +18,7 @@ from typing import Optional
 
 import torch
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -189,7 +190,11 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        n_dims = x.shape
+        if len(n_dims) > 2:
+            x = x.squeeze(0)
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[1]
@@ -251,6 +256,9 @@ _process_quantization(
         if not is_column_order:
            output = safe_permute(output, torch.argsort(perm), dim=1)
 
+        if len(n_dims) > 2:
+            output = output.unsqueeze(0)
+
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
@@ -352,9 +360,11 @@ forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
     global_scale = getattr(module, f"{base_name}_global_scale", None)
 
-    if args.dynamic:
+    if args.dynamic in (True, DynamicType.LOCAL):
         # dynamic quantization - determine the scale/zp on the fly
-        scale, zero_point = compute_dynamic_scales_and_zp(
+        scale, zero_point = compute_dynamic_scales_and_zp(
+            value=value, args=args, module=module, global_scale=global_scale
+        )
     else:
         # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
@@ -388,6 +398,7 @@ _quantize(
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
+
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
@@ -398,6 +409,7 @@ _quantize(
         q_max,
     )
     quantized_value = round_to_quantized_type(clamped_value, args)
+
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
 
@@ -422,6 +434,7 @@ _dequantize(
 
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
+
     dequant_value = dequant_value * scale
 
     if dtype is not None:

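The `GROUP`/`TENSOR_GROUP` branch above now flattens a leading batch dimension of 1 before per-group processing and restores it afterwards. A minimal standalone sketch of that shape handling in plain PyTorch (the tensor sizes and `group_size` here are illustrative, not taken from the library):

```python
import torch

# illustrative activation with a leading batch dim of 1: (1, tokens, hidden)
x = torch.randn(1, 4, 32)
group_size = 16

n_dims = x.shape
if len(n_dims) > 2:
    x = x.squeeze(0)  # operate on (tokens, hidden), mirroring the change above

# one scale per (token, group): reshape columns into groups and reduce
grouped = x.reshape(x.shape[0], x.shape[1] // group_size, group_size)
scales = grouped.abs().amax(dim=-1)  # shape: (tokens, n_groups)

output = x  # placeholder for the per-group quantize/dequantize result

if len(n_dims) > 2:
    output = output.unsqueeze(0)  # restore the original 3-D shape

print(scales.shape, output.shape)  # torch.Size([4, 2]) torch.Size([1, 4, 32])
```
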
compressed_tensors/quantization/lifecycle/initialize.py
CHANGED

@@ -156,13 +156,33 @@ def _initialize_scale_zero_point(
     force_zero_point: bool = True,
     scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic:
+    if quantization_args.dynamic is True:
         return
 
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    #
+    # 1. Create global_scales for tensor_group
+    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        # TODO: should move to llmcompressor
+        if base_name == "weight":
+            # When applying weight-only FP4 quantization, generate a global_scale
+            # This scale is applied during runtime to ensure that the generated
+            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
+            # which is the expected dtype of NVFP4A16 scales
+            value = generate_global_scale(input_tensor=module.weight)
+            value = value.to(device)
+            init_global_scale = Parameter(value, requires_grad=False)
+        else:
+            init_global_scale = Parameter(
+                torch.empty(1, dtype=torch.float32, device=device),
+                requires_grad=False,
+            )
+        register_offload_parameter(
+            module, f"{base_name}_global_scale", init_global_scale
+        )
+
+    # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
     else:
@@ -172,47 +192,35 @@ def _initialize_scale_zero_point(
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy
+        elif quantization_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        ):
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
 
+    # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
-    # TODO: consider erroring out in the future as if the dtype if not one fo these,
-    # there is likely bug
-
-    if is_fp4(quantization_args=quantization_args) and base_name == "weight":
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # When applying weight-only FP4 quantization, generate a global_scale
-        # This scale is applied during runtime to ensure that the generated
-        # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-        # which is the expected dtype of NVFP4A16 scales
-        value = generate_global_scale(input_tensor=module.weight)
-        value = value.to(device)
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
-
-    if scale_dtype not in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ] and not is_fp4(quantization_args=quantization_args):
-        scale_dtype = torch.float16
 
-
-
-
-
-
-
+    if is_fp4(quantization_args=quantization_args):
+        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+    else:
+        # TODO: consider erroring out in the future as if the dtype if not one of these,
+        # there is likely bug
+        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+            scale_dtype = torch.float16
+        zp_dtype = quantization_args.pytorch_dtype()
+
+    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
+    # do not init scales for quantzation_args.dynamic == DynamicType.local
+    if not quantization_args.dynamic:
+        init_scale = Parameter(
+            torch.empty(expected_shape, dtype=scale_dtype, device=device),
+            requires_grad=False,
+        )
+        register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
-        if is_fp4(quantization_args=quantization_args):
-            zp_dtype = FP8_E4M3_DATA.dtype
-        else:
-            zp_dtype = quantization_args.pytorch_dtype()
-
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,

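For `TENSOR_GROUP` and `GROUP` weight schemes, step 2 above sizes the scale/zero-point tensors as one entry per output channel and per column group. A small arithmetic sketch of that shape inference (the weight dimensions and group size below are illustrative only):

```python
import math

# illustrative weight shape (output_channels, input_channels) and group size
output_channels, input_channels = 4096, 11008
group_size = 16

num_groups = math.ceil(input_channels / group_size)
expected_shape = (output_channels, max(num_groups, 1))
print(expected_shape)  # (4096, 688): one scale per (output channel, group of 16 columns)
```
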
compressed_tensors/quantization/quant_args.py
CHANGED

@@ -32,6 +32,7 @@ __all__ = [
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
+    "DynamicType",
 ]
 
 
@@ -98,6 +99,22 @@ class QuantizationStrategy(str, Enum):
     GROUP = "group"
     BLOCK = "block"
     TOKEN = "token"
+    TENSOR_GROUP = "tensor_group"
+
+
+class DynamicType(str, Enum):
+    """
+    Enum storing potential dynamic types.
+
+    1. If dynamic is True, all quantization parameters are generated on the fly.
+    2. If dynamic is False, all quantization parameters generated are static.
+    3. If "local" is provided, only local quantization parameters are dynamic.
+
+    Note: "local" is only currently supported for NVFP4.
+
+    """
+
+    LOCAL = "local"
 
 
 class ActivationOrdering(Aliasable, str, Enum):
@@ -152,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
-    dynamic: bool = False
+    dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
         default=None,
@@ -206,6 +223,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
         return value
 
+    @field_validator("dynamic", mode="before")
+    def validate_dynamic(cls, value) -> Union[DynamicType, bool]:
+        if isinstance(value, str):
+            return DynamicType(value.lower())
+        return value
+
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         # extract user-passed values from dictionary
@@ -239,7 +262,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         if (
             group_size is not None
             and group_size > 0
-            and strategy
+            and strategy
+            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
         ):
             raise ValueError("group_size requires strategy to be set to 'group'")
 
@@ -255,22 +279,35 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             if strategy not in (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
+                QuantizationStrategy.TENSOR_GROUP,
             ):
                 raise ValueError(
-                    f"One of {QuantizationStrategy.TOKEN}
-
-                    "quantization",
+                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
+                    "must be used for dynamic quantization",
                 )
+
+            if (
+                dynamic == DynamicType.LOCAL
+                and strategy != QuantizationStrategy.TENSOR_GROUP
+            ):
+                raise ValueError("local is only supported for strategy tensor_group")
+
             if observer is not None:
-                if
-
-
-                )
-
+                if dynamic is True:  # checking if dynamic is True, not "local"
+                    if (
+                        observer != "memoryless"
+                    ):  # avoid annoying users with old configs
+                        warnings.warn(
+                            "No observer is used for dynamic quantization, setting to None"
+                        )
+                    observer = None
+                else:
+                    if dynamic == DynamicType.LOCAL:
+                        observer = "minmax"
 
         elif observer is None:
-            # default to
-            observer = "
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
 
         # write back modified values
         model.strategy = strategy

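Taken together, these changes let `dynamic` be passed as the string `"local"`: the before-validator coerces it to `DynamicType.LOCAL`, and the model validator then requires the `tensor_group` strategy. A usage sketch, assuming the package version from this diff is installed (the argument values mirror the NVFP4 preset added in quant_scheme.py below):

```python
from compressed_tensors.quantization.quant_args import (
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

# "local" is coerced to DynamicType.LOCAL by the before-validator
args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR_GROUP,
    symmetric=True,
    dynamic="local",
    group_size=16,
)
assert args.dynamic == DynamicType.LOCAL

# a non-tensor_group strategy combined with dynamic="local" is rejected
try:
    QuantizationArgs(num_bits=8, strategy="group", group_size=128, dynamic="local")
except ValueError as err:
    print(err)
```
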
compressed_tensors/quantization/quant_config.py
CHANGED

@@ -16,7 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -251,7 +251,7 @@ class QuantizationConfig(BaseModel):
 
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
-                if
+                if scheme.input_activations.dynamic in (False, DynamicType.LOCAL):
                     return True
             if scheme.output_activations is not None:
                 if not scheme.output_activations.dynamic:

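The enclosing method's name and docstring fall outside this hunk, but the updated condition returns True for input activations whose `dynamic` is either `False` or `DynamicType.LOCAL`, presumably because locally-dynamic schemes still carry statically calibrated global scales. A hedged sketch of just that predicate, using a hypothetical helper name:

```python
from compressed_tensors.quantization.quant_args import DynamicType

def input_activations_need_static_params(dynamic) -> bool:
    # hypothetical helper mirroring the updated condition above:
    # fully dynamic (True) activations are excluded, while static (False)
    # and "local" activations are included
    return dynamic in (False, DynamicType.LOCAL)

print(input_activations_need_static_params(True))               # False
print(input_activations_need_static_params(False))              # True
print(input_activations_need_static_params(DynamicType.LOCAL))  # True
```
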
compressed_tensors/quantization/quant_scheme.py
CHANGED

@@ -16,6 +16,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -104,13 +105,33 @@ NVFP4A16 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
-        strategy=QuantizationStrategy.
+        strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
         dynamic=False,
         group_size=16,
     )
 )
 
+
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=DynamicType.LOCAL,
+        group_size=16,
+    ),
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -237,4 +258,5 @@ PRESET_SCHEMES = {
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
     "NVFP4A16": NVFP4A16,
+    "NVFP4": NVFP4,
 }

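The new `NVFP4` preset pairs statically scaled `tensor_group` FP4 weights with FP4 input activations whose per-group scales are computed on the fly (`DynamicType.LOCAL`). A quick look-up sketch, assuming the package version from this diff is installed:

```python
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

nvfp4 = PRESET_SCHEMES["NVFP4"]
print(nvfp4["weights"].strategy, nvfp4["weights"].dynamic)
# e.g. tensor_group False
print(nvfp4["input_activations"].strategy, nvfp4["input_activations"].dynamic)
# e.g. tensor_group local  (DynamicType.LOCAL)
```
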
compressed_tensors/quantization/utils/helpers.py
CHANGED

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 from typing import Generator, List, Optional, Tuple
 
 import torch
@@ -103,7 +104,9 @@ def calculate_qparams(
     if is_fp4(quantization_args=quantization_args) and global_scale is not None:
         # Conditionally scale the generated local scale by a global_scale
         scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+        scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
         scales = scales.to(FP8_E4M3_DATA.dtype)
+
     else:
         scales = max_val_pos / (float(bit_range) / 2)
 
@@ -143,7 +146,12 @@
     return scales, zero_points
 
 
-def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+def compute_dynamic_scales_and_zp(
+    value: Tensor,
+    args: QuantizationArgs,
+    module: torch.nn.Module,
+    global_scale: Optional[Tensor] = None,
+):
     """
     Returns the computed scales and zero points for dynamic activation
     quantization.
@@ -155,24 +163,41 @@ def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
         reduced dimensions
     :return: tuple of scale and zero point derived from the observed tensor
     """
+
+    keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
         dim = {1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
+    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        if len(value.shape) > 2:
+            value = value.squeeze(0)
+
+        dim = {0, 1}
+        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        keep_dims = False
+        value = torch.reshape(
+            value,
+            (
+                value.shape[0],
+                math.ceil(value.shape[1] / args.group_size),
+                args.group_size,
+            ),
+        )
     else:
         raise ValueError(
-
-            "
+            "Dynamic quantization is only supported for ",
+            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
         )
 
     if not reduce_dims:
         min_val, max_val = torch.aminmax(value)
     else:
-        min_val = torch.amin(value, dim=reduce_dims, keepdims=
-        max_val = torch.amax(value, dim=reduce_dims, keepdims=
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
 
-    return calculate_qparams(min_val, max_val, args)
+    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
 
 
 def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:

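For `TENSOR_GROUP`, the function now reshapes the (possibly squeezed) 2-D activation to `(rows, n_groups, group_size)` and reduces over the last dimension, yielding one min/max pair per row and group. A standalone shape check of that pattern in plain PyTorch (sizes are illustrative):

```python
import math
import torch

value = torch.randn(4, 64)  # e.g. (tokens, hidden) after squeezing a batch dim of 1
group_size = 16

grouped = torch.reshape(
    value, (value.shape[0], math.ceil(value.shape[1] / group_size), group_size)
)
reduce_dims = tuple(idx for idx in range(3) if idx not in {0, 1})  # -> (2,)

min_val = torch.amin(grouped, dim=reduce_dims, keepdim=False)
max_val = torch.amax(grouped, dim=reduce_dims, keepdim=False)
print(min_val.shape, max_val.shape)  # torch.Size([4, 4]) twice: one pair per (row, group)
```
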
compressed_tensors/transform/__init__.py
ADDED

@@ -0,0 +1,20 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .transform_args import *
+from .transform_scheme import *
+from .transform_config import *

compressed_tensors/transform/transform_args.py
ADDED

@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Any, List
+
+from pydantic import BaseModel, Field, field_validator
+
+
+__all__ = ["TransformArgs"]
+
+
+class TransformLocation(str, Enum):
+    INPUT = "input"
+    WEIGHT_INPUT = "weight_input"
+    WEIGHT_OUTPUT = "weight_output"
+    OUTPUT = "output"
+    K_CACHE = "k_cache"
+    Q_ATTN = "q_attn"
+
+
+class TransformArgs(BaseModel):
+    """
+    Arguments which define *how* and where a transform should be applied to a model
+
+    :param targets: list of modules to apply transforms to
+    :param location: where to apply transform on module, one of (`input`, `weight`,
+        `output`, `k_cache`, `q_attn`)
+    :param inverse: whether or not to apply the inverse of a transform
+    :param ignore: any modules which should be ignored from the targets list
+    """
+
+    targets: List[str]
+    location: TransformLocation
+    inverse: bool = Field(default=False)
+    ignore: List[str] = Field(default_factory=list)
+
+    @field_validator("targets", "ignore", mode="before")
+    @classmethod
+    def wrap_singleton(cls, value):
+        if isinstance(value, str):
+            return [value]
+        return value

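`TransformArgs` accepts a bare string for `targets`/`ignore` (wrapped into a list by `wrap_singleton`) and coerces string locations through `TransformLocation`. A small construction sketch, assuming the package version from this diff is installed:

```python
from compressed_tensors.transform import TransformArgs

# a bare string target is wrapped into a list by the "before" validator
args = TransformArgs(targets="Linear", location="weight_input", inverse=True)
print(args.targets)         # ['Linear']
print(args.location.value)  # weight_input
print(args.ignore)          # []
```
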
compressed_tensors/transform/transform_config.py
ADDED

@@ -0,0 +1,73 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from pydantic import BaseModel
+
+
+__all__ = ["TransformConfig"]
+
+
+class TransformConfig(BaseModel):
+    """
+    Configuration of transforms to be applied to a model. This config is to be
+    serialized within a model's `config.json` file
+
+    :param config_groups: A dictionary of `TransformSchemes` that should be applied
+        to a particular model. The keys can be any arbitrary string
+    """
+
+    config_groups: Dict[str, TransformScheme]
+
+
+# quip / quip sharp
+QUIP = TransformConfig(
+    config_groups={
+        "v": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="input",  # non-mergable
+                ),
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+            randomize_modules=True,
+        ),
+        "u": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=["Linear"], location="output", inverse=True  # non-mergable
+                ),
+            ],
+            randomize_modules=True,
+        ),
+    }
+)
+
+
+PRESET_CONFIGS = {
+    "QUIP": QUIP,
+}

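`TransformConfig` is a plain pydantic model keyed by arbitrary group names, so custom configs can be assembled the same way as the bundled `QUIP` preset. A minimal sketch, assuming the package version from this diff is installed:

```python
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

config = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="hadamard",
            apply=[TransformArgs(targets=["Linear"], location="input")],
        )
    }
)
print(config.model_dump())  # nested dict representation of the config
```
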
compressed_tensors/transform/transform_scheme.py
ADDED

@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from compressed_tensors.transform import TransformArgs
+from pydantic import BaseModel, Field
+
+
+__all__ = ["TransformScheme"]
+
+
+class TransformScheme(BaseModel):
+    """
+    Scheme used to parameterize a particular transform type and specify how and where it
+    should be applied to the model
+
+    :param type: string indicating the particular transform type that should be created
+        and applied. This should be one of the registered transform types
+        (see `Transforms.registered_names()`)
+    :param apply: list of TransformationArgs containing the information about the
+        modules that should be targeted by the specified transform
+    :param randomize_modules: True if unique transforms should be applied to each
+        unique module targeted by `apply`, otherwise reuse transform weights where
+        applicable
+    :param requires_grad: True if weights include gradients for training
+    """
+
+    type: str
+    apply: List[TransformArgs] = Field(default_factory=list)
+    randomize_modules: bool = Field(default=False)
+    requires_grad: bool = Field(default=False)

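By default a `TransformScheme` targets nothing (`apply=[]`), reuses transform weights across modules (`randomize_modules=False`), and keeps them frozen (`requires_grad=False`). A quick check of those defaults, assuming the package version from this diff is installed:

```python
from compressed_tensors.transform import TransformScheme

scheme = TransformScheme(type="hadamard")
print(scheme.apply)              # []
print(scheme.randomize_modules)  # False
print(scheme.requires_grad)      # False
```
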
compressed_tensors/version.py
CHANGED
{compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250521
+Version: 0.9.5a20250530
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

{compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=
+compressed_tensors/version.py,sha256=BwDcUUpFaOn_-cMqdBWktPf89WCzFmESpx94d8qAUZM,521
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -26,19 +26,23 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=
-compressed_tensors/quantization/quant_config.py,sha256=
-compressed_tensors/quantization/quant_scheme.py,sha256=
+compressed_tensors/quantization/quant_args.py,sha256=huROC8fbY899EYa2MnEmujvcBeHYLpn-e8ZEViEFASo,11804
+compressed_tensors/quantization/quant_config.py,sha256=aFi6PKqmEX9iP9O8GVn3mEUjRDEwk_hOCbmmiq-j9oU,10198
+compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF5hLz89J2PVYWBBnXRc,7102
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=-OKZ-FFFfIIoeGTrho8lXx6HVWZQp3Xkn3Q-G0hU-CM,18294
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=65USJEtsp_n8X36L5y4g4ftMnhrQyRWbwKJ8RZMMiBo,14797
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=976sZ45ywGVzH1n4pyVhG7hnUBP1wKEWoo9cHrmKHxU,12522
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=
+compressed_tensors/quantization/utils/helpers.py,sha256=I-bJcMdBFXjIUQEpnxMMN_FfQyXjojpe5w7ZIKSZ5UU,17588
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
+compressed_tensors/transform/__init__.py,sha256=oa5VdrE-GtDYYceXNSwj5X_ropoXLLukm6Aufcc9WhY,747
+compressed_tensors/transform/transform_args.py,sha256=Sazu_4kXL7IvIEgTaimgo8dV-qacXf_t1NLEfDvPJEU,1759
+compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
+compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
 compressed_tensors/utils/helpers.py,sha256=RrNvzD08naEjEiXdU-FdZjQVda1nQywu1hA_GCDj0vg,10415
 compressed_tensors/utils/offload.py,sha256=JNQ66_6vhSsizhlUaMgyEdBuFolYxbgUuT1mAZrCfKY,15436
@@ -46,8 +50,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.9.
-compressed_tensors-0.9.
-compressed_tensors-0.9.
-compressed_tensors-0.9.
-compressed_tensors-0.9.
+compressed_tensors-0.9.5a20250530.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.9.5a20250530.dist-info/METADATA,sha256=avjHgMxk1vnX09YKjerSCov-X8mTckulmJV1xQyLk5I,7004
+compressed_tensors-0.9.5a20250530.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.9.5a20250530.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.9.5a20250530.dist-info/RECORD,,

{compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/licenses/LICENSE
RENAMED
File without changes

{compressed_tensors-0.9.5a20250521.dist-info → compressed_tensors-0.9.5a20250530.dist-info}/top_level.txt
RENAMED
File without changes