compressed-tensors-nightly 0.7.1.20241031__py3-none-any.whl → 0.7.1.20241102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/quantization/__init__.py +0 -1
- compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- compressed_tensors/quantization/lifecycle/apply.py +1 -16
- compressed_tensors/quantization/lifecycle/forward.py +15 -109
- compressed_tensors/quantization/lifecycle/initialize.py +18 -21
- compressed_tensors/quantization/quant_args.py +11 -22
- compressed_tensors/quantization/utils/helpers.py +125 -8
- compressed_tensors/registry/registry.py +1 -1
- {compressed_tensors_nightly-0.7.1.20241031.dist-info → compressed_tensors_nightly-0.7.1.20241102.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.7.1.20241031.dist-info → compressed_tensors_nightly-0.7.1.20241102.dist-info}/RECORD +13 -21
- compressed_tensors/quantization/cache.py +0 -200
- compressed_tensors/quantization/lifecycle/calibration.py +0 -80
- compressed_tensors/quantization/lifecycle/frozen.py +0 -50
- compressed_tensors/quantization/observers/__init__.py +0 -21
- compressed_tensors/quantization/observers/base.py +0 -213
- compressed_tensors/quantization/observers/helpers.py +0 -149
- compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed_tensors/quantization/observers/mse.py +0 -164
- {compressed_tensors_nightly-0.7.1.20241031.dist-info → compressed_tensors_nightly-0.7.1.20241102.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.7.1.20241031.dist-info → compressed_tensors_nightly-0.7.1.20241102.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.7.1.20241031.dist-info → compressed_tensors_nightly-0.7.1.20241102.dist-info}/top_level.txt +0 -0
--- a/compressed_tensors/quantization/lifecycle/apply.py
+++ b/compressed_tensors/quantization/lifecycle/apply.py
@@ -22,13 +22,9 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
@@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
 
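The surviving guard keeps the chained-comparison idiom used throughout `apply_quantization_status`, which reads oddly at first glance. A minimal sketch with a toy `IntEnum` standing in for `QuantizationStatus` (the real enum, defined in `compressed_tensors.quantization.quant_config`, is ordered and comparable in the same way):

```python
from enum import IntEnum


# toy stand-in for QuantizationStatus, only to illustrate the chained comparison
class Status(IntEnum):
    INITIALIZED = 1
    CALIBRATION = 2
    FROZEN = 3
    COMPRESSED = 4


current, target = Status.INITIALIZED, Status.COMPRESSED

# `current < target >= Status.COMPRESSED > current` expands to:
# (current < target) and (target >= Status.COMPRESSED) and (Status.COMPRESSED > current)
# i.e. "we are moving up, at least as far as COMPRESSED, from below COMPRESSED"
print(current < target >= Status.COMPRESSED > current)  # True
```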
--- a/compressed_tensors/quantization/lifecycle/forward.py
+++ b/compressed_tensors/quantization/lifecycle/forward.py
@@ -14,14 +14,9 @@
 
 from functools import wraps
 from math import ceil
-from typing import Callable, Optional
+from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import (
-    calculate_range,
-    compute_dynamic_scales_and_zp,
-)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -29,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import safe_permute, update_parameter_data
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -39,7 +38,6 @@ __all__ = [
     "fake_quantize",
     "wrap_module_forward_quantized",
     "forward_quantize",
-    "calibrate_activations",
 ]
 
 
@@ -276,19 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
-            #
-            # NOTE: will be moved out of compressed-tensors
-            if (
-                module.quantization_status == QuantizationStatus.CALIBRATION
-                and not scheme.input_activations.dynamic
-            ):
-                calibrate_activations(
-                    module=module,
-                    value=input_,
-                    base_name="input",
-                    quantization_args=scheme.input_activations,
-                )
-
+            # prehook should calibrate activations before forward call
             input_ = forward_quantize(module, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
@@ -302,31 +288,22 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
 
-
-
-
+        # restore back to unquantized_value
+        if scheme.weights is not None and not compressed:
+            self.weight.data = unquantized_weight
 
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
             if (
                 module.quantization_status == QuantizationStatus.CALIBRATION
                 and not scheme.output_activations.dynamic
             ):
-                calibrate_activations(
-                    module=module,
-                    value=output,
-                    base_name="output",
-                    quantization_args=scheme.ouput_activations,
-                )
+                return output
 
             output = forward_quantize(
                 module, output, "output", scheme.output_activations
             )
-
-        # restore back to unquantized_value
-        if scheme.weights is not None and not compressed:
-            self.weight.data = unquantized_weight
-
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
@@ -335,77 +312,6 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
-    # expects a module already initialized and injected with the parameters in
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
-            )
-
-            return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
-
-    # bind wrapped forward to module class so reference to `self` is correct
-    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
-    # set forward to wrapped forward
-    setattr(module, "forward", bound_wrapped_forward)
-
-
-def calibrate_activations(
-    module: Module,
-    value: torch.Tensor,
-    base_name: str,
-    quantization_args: QuantizationArgs,
-):
-    # If empty tensor, can't update zp/scale
-    # Case for MoEs
-    if value.numel() == 0:
-        return
-    # calibration mode - get new quant params from observer
-    if not hasattr(module, f"{base_name}_observer"):
-        from compressed_tensors.quantization.lifecycle import initialize_observers
-
-        initialize_observers(
-            module=module, base_name=base_name, quantization_args=quantization_args
-        )
-
-    observer = getattr(module, f"{base_name}_observer")
-
-    updated_scale, updated_zero_point = observer(value)
-
-    # update scale and zero point
-    update_parameter_data(module, updated_scale, f"{base_name}_scale")
-    update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-
 def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
@@ -426,10 +332,10 @@ def forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization -
+        # dynamic quantization - determine the scale/zp on the fly
        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
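The deleted in-forward calibration paths are replaced by comments deferring to hooks ("prehook should calibrate activations before forward call"). A minimal sketch of that hook pattern in plain PyTorch; `calibrate` here is a hypothetical stand-in for whatever observer update the calling framework registers, not an API of this package:

```python
import torch
from torch import nn


def calibrate(module: nn.Module, value: torch.Tensor, base_name: str) -> None:
    # hypothetical stand-in: would update f"{base_name}_scale" / f"{base_name}_zero_point"
    print(f"calibrating {base_name}, shape={tuple(value.shape)}")


def input_calibration_pre_hook(module, args):
    calibrate(module, args[0], "input")  # fires before module.forward


def output_calibration_hook(module, args, output):
    calibrate(module, output, "output")  # fires after module.forward


layer = nn.Linear(8, 8)
layer.register_forward_pre_hook(input_calibration_pre_hook)
layer.register_forward_hook(output_calibration_hook)
layer(torch.randn(2, 8))  # both hooks run around the forward call
```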
--- a/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,13 +14,12 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -34,12 +33,21 @@ from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
 
-__all__ = ["initialize_module_for_quantization"]
+__all__ = [
+    "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
+]
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
@@ -64,9 +72,7 @@ def initialize_module_for_quantization(
         return
 
     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)
 
     else:
@@ -107,6 +113,7 @@ def initialize_module_for_quantization(
     module.quantization_status = QuantizationStatus.INITIALIZED
 
     offloaded = False
+    # What is this doing/why isn't this in the attn case?
    if is_module_offloaded(module):
         try:
             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -144,14 +151,12 @@ def initialize_module_for_quantization(
         module._hf_hook.weights_map = new_prefix_dict
 
 
-def initialize_observers(
-    module: Module,
-    base_name: str,
-    quantization_args: QuantizationArgs,
-):
-
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
 
 
 def _initialize_scale_zero_point(
@@ -209,14 +214,6 @@ def _initialize_scale_zero_point(
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
 
 
-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
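`is_attention_module` is now public (exported in `__all__`) and is a pure name-plus-attribute heuristic, so it can be sanity-checked with toy modules. A sketch assuming the 0.7.1.20241102 layout shown above:

```python
from torch import nn

from compressed_tensors.quantization.lifecycle.initialize import is_attention_module


class ToyAttention(nn.Module):
    """Class name contains "attention" and exposes k_proj/v_proj."""

    def __init__(self):
        super().__init__()
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)


print(is_attention_module(ToyAttention()))   # True -> gets k_scale/v_scale params
print(is_attention_module(nn.Linear(8, 8)))  # False: no "attention" in class name
```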
--- a/compressed_tensors/quantization/quant_args.py
+++ b/compressed_tensors/quantization/quant_args.py
@@ -114,20 +114,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-        from compressed_tensors.quantization.observers.base import Observer
-
-        # No observer required for the dynamic case
-        if self.dynamic:
-            self.observer = None
-            return self.observer
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer
 
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -210,6 +197,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        # infer observer w.r.t. dynamic
         if dynamic:
             if strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -221,18 +209,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "quantization",
             )
             if observer is not None:
-
-
-
-
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None
 
-
-
-
-            model.observer = "minmax"
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
 
         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
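The net effect of the validator change is that `observer` is now inferred from `dynamic` and written back onto the model. A short sketch of the inferred values, assuming the field names shown in the diff and the package's public `QuantizationArgs` export:

```python
from compressed_tensors.quantization import QuantizationArgs

# static quantization with no observer specified: defaults to "minmax"
static_args = QuantizationArgs(num_bits=8, dynamic=False)
print(static_args.observer)  # "minmax"

# dynamic quantization computes scales/zero points on the fly, so no observer
dynamic_args = QuantizationArgs(num_bits=8, dynamic=True, strategy="token")
print(dynamic_args.observer)  # None
```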
--- a/compressed_tensors/quantization/utils/helpers.py
+++ b/compressed_tensors/quantization/utils/helpers.py
@@ -16,9 +16,14 @@ import logging
 from typing import Generator, List, Optional, Tuple
 
 import torch
-from compressed_tensors.quantization.observers.base import Observer
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -36,6 +41,9 @@ __all__ = [
     "is_kv_cache_quant_scheme",
     "iter_named_leaf_modules",
     "iter_named_quantizable_modules",
+    "compute_dynamic_scales_and_zp",
+    "calculate_range",
+    "calculate_qparams",
 ]
 
 # target the self_attn layer
@@ -45,6 +53,105 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+def calculate_qparams(
+    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+) -> Tuple[FloatTensor, IntTensor]:
+    """
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
+        from
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
+        from
+    :param quantization_args: settings to quantization
+    :return: tuple of the calculated scale(s) and zero point(s)
+    """
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
+
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
+    if quantization_args.symmetric:
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+        scales = max_val_pos / (float(bit_range) / 2)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
+    else:
+        scales = (max_vals - min_vals) / float(bit_range)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
+
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
+    return scales, zero_points
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    qunatization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :param reduce_dims: optional tuple of dimensions to reduce along,
+        returned scale and zero point will be shaped (1,) along the
+        reduced dimensions
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
+            "must be used for dynamic quantization",
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
+
+
 def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
     """
     Checks the quantization status of a model. Assumes all modules in the model have
@@ -118,12 +225,17 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
     """
     for name, submodule in model.named_modules():
         children = list(submodule.children())
-        if len(children) == 0 and not isinstance(submodule, Observer):
+        # TODO: verify if an observer would ever be attached in this case/remove check
+        if len(children) == 0 and "observer" in name:
             yield name, submodule
         else:
+            if len(children) > 0:
+                named_children, children = zip(*list(submodule.named_children()))
             has_non_observer_children = False
-            for child in children:
-                if not isinstance(child, Observer):
+            for i in range(len(children)):
+                child_name = named_children[i]
+
+                if "observer" not in child_name:
                     has_non_observer_children = True
 
             if not has_non_observer_children:
@@ -144,14 +256,19 @@ def iter_named_quantizable_modules(
     :returns: generator tuple of (name, submodule)
     """
     for name, submodule in model.named_modules():
+        # TODO: verify if an observer would ever be attached in this case/remove check
         if include_children:
             children = list(submodule.children())
-            if len(children) == 0 and not isinstance(submodule, Observer):
+            if len(children) == 0 and "observer" not in name:
                 yield name, submodule
             else:
+                if len(children) > 0:
+                    named_children, children = zip(*list(submodule.named_children()))
                 has_non_observer_children = False
-                for child in children:
-                    if not isinstance(child, Observer):
+                for i in range(len(children)):
+                    child_name = named_children[i]
+
+                    if "observer" not in child_name:
                         has_non_observer_children = True
 
                 if not has_non_observer_children:
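With this move, callers that previously imported these helpers from `compressed_tensors.quantization.observers.helpers` now find them under `compressed_tensors.quantization.utils`. A short usage sketch against the new import path, assuming `QuantizationArgs` defaults as shown in the diff:

```python
import torch

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.utils import (
    calculate_qparams,
    calculate_range,
    compute_dynamic_scales_and_zp,
)

args = QuantizationArgs(num_bits=8, symmetric=True)

# effective int8 range: (-128, 127)
q_min, q_max = calculate_range(args, device="cpu")

# scale/zero point from observed min/max statistics
scale, zero_point = calculate_qparams(torch.tensor(-1.5), torch.tensor(2.0), args)

# per-tensor dynamic scale/zero point straight from an activation tensor
dyn_args = QuantizationArgs(num_bits=8, dynamic=True, strategy="tensor")
dyn_scale, dyn_zp = compute_dynamic_scales_and_zp(torch.randn(2, 16), dyn_args)
```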
--- a/compressed_tensors/registry/registry.py
+++ b/compressed_tensors/registry/registry.py
@@ -258,7 +258,7 @@ def get_from_registry(
         retrieved_value = _import_and_get_value_from_module(module_path, value_name)
     else:
         # look up name in alias registry
-        name = _ALIAS_REGISTRY[parent_class].get(name)
+        name = _ALIAS_REGISTRY[parent_class].get(name, name)
         # look up name in registry
         retrieved_value = _REGISTRY[parent_class].get(name)
         if retrieved_value is None:
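The registry fix is one character-level but behavior-changing: `dict.get(name)` returns `None` when a name has no alias, which then poisoned the follow-up `_REGISTRY` lookup; `dict.get(name, name)` lets unaliased names fall through unchanged. In isolation, with a hypothetical alias table:

```python
aliases = {"min-max": "minmax"}  # hypothetical alias table

name = "mse"  # registered directly, no alias entry
resolved_old = aliases.get(name)        # None -> later registry lookup fails
resolved_new = aliases.get(name, name)  # "mse" -> falls through to the registry
print(resolved_old, resolved_new)
```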
--- a/compressed_tensors_nightly-0.7.1.20241031.dist-info/METADATA
+++ b/compressed_tensors_nightly-0.7.1.20241102.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.7.1.20241031
+Version: 0.7.1.20241102
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- a/compressed_tensors_nightly-0.7.1.20241031.dist-info/RECORD
+++ b/compressed_tensors_nightly-0.7.1.20241102.dist-info/RECORD
@@ -22,28 +22,20 @@ compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74j
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=0jTTf6XxOAjAYs3tvFtgiNMAO4W10sSeR-pdH2M413g,3218
-compressed_tensors/quantization/__init__.py,sha256=
-compressed_tensors/quantization/
-compressed_tensors/quantization/quant_args.py,sha256=k7NuZn8OqjgzmAVaN2-jHPQ1bgDkMuUoLJtLnhkvIOI,9085
+compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
+compressed_tensors/quantization/quant_args.py,sha256=osjNwCSB6tcyH9Qeg5sHEiB-bHyi3XJ8TzkGVJuGTc4,8711
 compressed_tensors/quantization/quant_config.py,sha256=NCiMvUMnnz5kTyAkDylxjtEGQnjgsIYIeNR2zyHEdTQ,10371
 compressed_tensors/quantization/quant_scheme.py,sha256=5ggPz5sqEfTUgvJJeiPIINA74QtO-08hb3szsm7UHGE,6000
-compressed_tensors/quantization/lifecycle/__init__.py,sha256=
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
-compressed_tensors/quantization/lifecycle/calibration.py,sha256=fJ2RDL3E4hmWR8v8nYhq_tv31K8WV00o_4Y3xr7c37Y,3041
+compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
+compressed_tensors/quantization/lifecycle/apply.py,sha256=pdCqxXnVw7HoDDanaOtek13g8x_nb54CBUlfuMdhFG4,14993
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
-compressed_tensors/quantization/lifecycle/frozen.py,sha256=71TsgS0Uxku0NomdWOBJsVfXCGTne-Gx9zUEMsCmw5Q,1764
+compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
-compressed_tensors/quantization/observers/__init__.py,sha256=DYrttzq-8MHLZUzpX-xzzm4hrw6HcXkMkux82KBKb1M,738
-compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
-compressed_tensors/quantization/observers/helpers.py,sha256=nUFdNEIACiPBfFwNYDGCXOvw6tf7j6jfTvDwImHKMPg,5506
-compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
-compressed_tensors/quantization/observers/mse.py,sha256=G5Y9v4MqXUVcKxBSmCFFW3p_7rlu-6scqLIN88ng-sE,6080
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=C41hKA5VANyEwkB5FxzEn3Z0Da5tfxF1I07P8rUcyS0,8537
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=
+compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
-compressed_tensors/registry/registry.py,sha256=
+compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
 compressed_tensors/utils/helpers.py,sha256=hWGIR0W7ENHwdC7wW2SQJJiCF9-xOu_u3fY2RzLyYg4,4101
 compressed_tensors/utils/offload.py,sha256=d9q8LNe8HyF8tOjgjA7QGLD3HRysmNp0d8eBbdqBgIM,4089
@@ -51,8 +43,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors_nightly-0.7.1.
-compressed_tensors_nightly-0.7.1.
-compressed_tensors_nightly-0.7.1.
-compressed_tensors_nightly-0.7.1.
-compressed_tensors_nightly-0.7.1.
+compressed_tensors_nightly-0.7.1.20241102.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.7.1.20241102.dist-info/METADATA,sha256=pQ8FXKctjUHKkisrXYyeDUuunknVPkjHnHvS-uJ89oI,6799
+compressed_tensors_nightly-0.7.1.20241102.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+compressed_tensors_nightly-0.7.1.20241102.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.7.1.20241102.dist-info/RECORD,,