compressed-tensors 0.11.1a20250820__py3-none-any.whl → 0.11.1a20250821__py3-none-any.whl
This diff shows the content of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +172 -153
- compressed_tensors/compressors/quantized_compressors/base.py +2 -2
- compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +4 -5
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -3
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +1 -1
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- compressed_tensors/quantization/lifecycle/apply.py +40 -129
- compressed_tensors/quantization/lifecycle/forward.py +5 -4
- compressed_tensors/quantization/lifecycle/initialize.py +7 -6
- compressed_tensors/quantization/quant_args.py +7 -5
- compressed_tensors/quantization/quant_scheme.py +4 -3
- compressed_tensors/quantization/utils/helpers.py +0 -1
- compressed_tensors/registry/registry.py +1 -1
- compressed_tensors/transform/transform_config.py +1 -1
- compressed_tensors/transform/utils/matrix.py +1 -1
- compressed_tensors/utils/match.py +57 -8
- compressed_tensors/utils/offload.py +0 -1
- compressed_tensors/utils/safetensors_load.py +0 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/METADATA +1 -1
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/RECORD +24 -24
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressors/model_compressor.py
CHANGED
@@ -42,8 +42,6 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
@@ -60,6 +58,7 @@ from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.match import match_named_modules
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -309,7 +308,7 @@ class ModelCompressor:
         if quantization_config is not None:
             # If a list of compression_format is not provided, we resolve the
             # relevant quantization formats using the config groups from the config
-            # and if those are not defined, we fall-back to the global quantization
+            # and if those are not defined, we fall-back to the global quantization fmt
             if not self.compression_formats:
                 self.compression_formats = self._fetch_unique_quantization_formats()
 
@@ -342,13 +341,15 @@ class ModelCompressor:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
+
             missing_keys.update(
-                merge_names(…
+                merge_names(target_name, "weight")
+                for target_name, _module in sparse_targets
             )
 
         # Determine missing keys due to pack quantization
@@ -358,13 +359,14 @@ class ModelCompressor:
             == CompressionFormat.pack_quantized.value
         ):
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 missing_keys.update(
-                    merge_names(…
+                    merge_names(target_name, "weight")
+                    for target_name, _module in quant_targets
                 )
 
         return list(missing_keys)
@@ -395,29 +397,29 @@ class ModelCompressor:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
             unexpected_keys.update(
-                merge_names(…
-                for …
+                merge_names(target_name, param)
+                for target_name, _module in sparse_targets
                 for param in self.sparsity_compressor.compression_param_names
             )
 
         # Identify unexpected keys from quantization compression
         if self.quantization_compressor:
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 for quant_compressor in self.quantization_compressor.values():
                     unexpected_keys.update(
-                        merge_names(…
-                        for …
+                        merge_names(target_name, param)
+                        for target_name, _module in quant_targets
                         for param in quant_compressor.compression_param_names
                         if param != "weight"
                     )
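The hunks above replace `expand_target_names`, which returned a set of module names, with `match_named_modules`, which yields `(name, module)` pairs; that is why the generator expressions now unpack two values. A minimal sketch of the new call-site pattern (toy model; targets may be exact names, `re:` regexes, or class names, per the match.py hunks later in this diff):

```python
import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# Yields (name, module) pairs rather than a bare set of names, so call sites
# unpack tuples; here module "1" is excluded by exact-name ignore.
for name, _module in match_named_modules(model, targets=["Linear"], ignore=["1"]):
    print(name)  # -> "0"
```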
@@ -434,73 +436,79 @@ class ModelCompressor:
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets …
-        … (old lines 438-488: previous implementation, truncated in this view)
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Compressing model",
+        ):
+            module_device = get_execution_device(module)
+            is_meta = module_device.type == "meta"
+
+            exec_device = "meta" if is_meta else "cpu"
+            onloading_device = "meta" if is_meta else module_device
+
+            # in the future, support compression on same device
+            with align_module_device(module, execution_device=exec_device):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # quantization first
+            if prefix in module_to_scheme:
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if len(self.compression_formats) > 1:
+                        raise ValueError(
+                            "Applying multiple compressors without defining "
+                            "per module formats is not supported "
+                        )
+                    format = self.compression_formats[0]
+                else:
+                    format = module.quantization_scheme.format
+
+                quant_compressor = self.quantization_compressor.get(format)
+                state_dict = quant_compressor.compress(
+                    state_dict,
+                    names_to_scheme=module_to_scheme,
+                    show_progress=False,
+                    compression_device=exec_device,
+                )
 
-        … (old lines 490-493 truncated in this view)
+            # sparsity second
+            if prefix in sparse_compression_targets:
+                state_dict = self.sparsity_compressor.compress(
+                    state_dict,
+                    compression_targets=sparse_compression_targets,
+                    show_progress=False,
+                )
 
-        … (old lines 495-498 truncated in this view)
-            param = torch.nn.Parameter(value, requires_grad=False)
-            register_offload_parameter(module, name, param, offload_device)
+            # remove any existing parameters
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-        … (old line 502 truncated in this view)
+            # replace with compressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(onloading_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
+            module.quantization_status = QuantizationStatus.COMPRESSED
         # TODO: consider sparse compression to also be compression
         if (
             self.quantization_config is not None
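The rewritten `compress_model` works module-by-module: onload, run the quantization compressor, run the sparsity compressor, then swap the module's parameters in place. The swap relies on the offload utilities imported by this file; a hedged sketch of that idiom (the toy module and packed tensor below are illustrative):

```python
import torch
from compressed_tensors.utils.offload import (
    delete_offload_parameter,
    get_offloaded_device,
    register_offload_parameter,
)

module = torch.nn.Linear(4, 4)
offload_device = get_offloaded_device(module)

# drop the dense parameters (list() avoids mutating while iterating)...
for name, _ in list(module.named_parameters(recurse=False)):
    delete_offload_parameter(module, name)

# ...then register compressed tensors under their new names
packed = torch.zeros(4, 2, dtype=torch.int32)  # stand-in for a packed weight
param = torch.nn.Parameter(packed, requires_grad=False)
register_offload_parameter(module, "weight_packed", param, offload_device)
```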
@@ -516,67 +524,75 @@ class ModelCompressor:
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets …
-        … (old lines 520-564: previous implementation, truncated in this view)
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Decompressing model",
+        ):
+            # in the future, support decompression on same device
+            with align_module_device(module, execution_device="cpu"):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # sparsity first
+            if prefix in sparse_compression_targets:
+                # sparse_compression_targets are automatically inferred by this fn
+                generator = self.sparsity_compressor.decompress_from_state_dict(
+                    state_dict,
+                )
+                # generates (param_path, param_val)
+                # of compressed and unused params
+                state_dict = {key: value for key, value in generator}
+
+            # quantization second
+            if prefix in module_to_scheme:
+                if (
+                    not hasattr(module.quantization_scheme, "format")
+                    or module.quantization_scheme.format is None
+                ):
+                    if len(self.compression_formats) > 1:
+                        raise ValueError(
+                            "Applying multiple compressors without defining "
+                            "per module formats is not supported "
+                        )
+                    format = self.compression_formats[0]
+                else:
+                    format = module.quantization_scheme.format
+                quant_compressor = self.quantization_compressor.get(format)
+                state_dict = quant_compressor.decompress_module_from_state_dict(
+                    prefix,
+                    state_dict,
+                    scheme=module_to_scheme[prefix],
+                )
 
-        … (old lines 566-570 truncated in this view)
+            # remove any existing parameters
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-        … (old lines 572-577 truncated in this view)
+            # replace with decompressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(exec_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
-        … (old line 579 truncated in this view)
+            module.quantization_status = QuantizationStatus.FROZEN
 
     # ----- state dict compression pathways ----- #
 
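`decompress_model` mirrors `compress_model` but inverts the order (sparsity is undone before quantization) and marks modules `FROZEN` rather than `COMPRESSED`. A hypothetical round trip through the two pathways (the checkpoint path and placeholder model below are stand-ins; `ModelCompressor.from_pretrained` is the existing constructor in this module):

```python
import torch
from compressed_tensors.compressors import ModelCompressor

model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # placeholder model
compressor = ModelCompressor.from_pretrained("/path/to/compressed-model")

compressor.compress_model(model)    # quantization first, then sparsity
compressor.decompress_model(model)  # sparsity first, then quantization
```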
@@ -614,11 +630,14 @@ class ModelCompressor:
             )
 
         if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_target_names(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
+            sparse_compression_targets: Set[str] = {
+                module_name
+                for module_name, _module in match_named_modules(
+                    model=model,
+                    targets=self.sparsity_config.targets,
+                    ignore=self.sparsity_config.ignore,
+                )
+            }
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
@@ -641,11 +660,12 @@ class ModelCompressor:
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
 
-        Note: decompress makes use of both _replace_sparsity_weights and …
-        The variations in these methods are a result of the subtle …
-        and quantization compressors. Specifically, …
-        …
-        …
+        Note: decompress makes use of both _replace_sparsity_weights and
+        _replace_weights. The variations in these methods are a result of the subtle
+        variations between the sparsity and quantization compressors. Specifically,
+        quantization compressors return not just the decompressed weight, but the
+        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
+        only return the decompressed weight.
 
         """
         model_path = get_safetensors_folder(model_path)
@@ -683,18 +703,17 @@ class ModelCompressor:
                 with override_quantization_status(
                     self.quantization_config, QuantizationStatus.FROZEN
                 ):
-
                     names_to_scheme = apply_quantization_config(
                         model, self.quantization_config
                     )
                     # Load activation scales/zp or any other quantization parameters
-                    # Conditionally load the weight quantization parameters if we have a
-                    # …
+                    # Conditionally load the weight quantization parameters if we have a
+                    # dense compressor or if a sparsity compressor has already been applied
                     load_pretrained_quantization_parameters(
                         model,
                         model_path,
-                        # TODO: all weight quantization params will be moved to the
-                        # including initialization
+                        # TODO: all weight quantization params will be moved to the
+                        # compressor in a follow-up including initialization
                         load_weight_quantization=(
                             sparse_decompressed
                             or isinstance(quant_compressor, DenseCompressor)
@@ -786,7 +805,6 @@ class ModelCompressor:
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
@@ -822,9 +840,10 @@ class ModelCompressor:
             for param_name, param_data in data.items():
                 if hasattr(module, param_name):
                     # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in …
-                    # to be consistent with loading which happens …
-                    # however, update_data does a good shape check - …
+                    # TODO: we can also just skip initialization of scales/zp if in
+                    # decompression in init to be consistent with loading which happens
+                    # later as well however, update_data does a good shape check -
+                    # should be moved to the compressor
                     if param_name == "weight":
                         delattr(module, param_name)
                         requires_grad = param_data.dtype in (
compressed_tensors/compressors/quantized_compressors/base.py
CHANGED
@@ -24,7 +24,6 @@ from compressed_tensors.utils import (
     get_nested_weight_mappings,
     merge_names,
 )
-from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -107,7 +106,8 @@ class BaseQuantizationCompressor(BaseCompressor):
                 compressed_dict[name] = value.to(compression_device)
                 continue
 
-            # compress values on meta if loading from meta otherwise on cpu (memory …
+            # compress values on meta if loading from meta otherwise on cpu (memory
+            # movement too expensive)
             module_path = prefix[:-1] if prefix.endswith(".") else prefix
             quant_args = names_to_scheme[module_path].weights
             compressed_values = self.compress_weight(
compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
CHANGED
@@ -15,7 +15,6 @@
 
 from typing import Dict, Optional, Tuple
 
-import numpy
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -92,7 +91,6 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         zero_point: Optional[torch.Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
-
         quantized_weight = quantize(
             x=weight,
             scale=scale,
@@ -112,7 +110,6 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         compressed_data: Dict[str, Tensor],
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
-
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
@@ -175,14 +172,16 @@ kE2M1ToFloat = torch.tensor(
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )
 
+
 # reference: : https://github.com/vllm-project/vllm/pull/16362
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
     """
     Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
-    (i.e. first four bits correspond to one fp4 value, last four …
-    fp4 value). The bits represent an index, which are mapped to an fp4 …
+    (i.e. first four bits correspond to one fp4 value, last four correspond to a
+    consecutive fp4 value). The bits represent an index, which are mapped to an fp4
+    value.
 
     :param a: tensor to unpack
     :param m: original dim 0 size of the unpacked tensor
compressed_tensors/compressors/quantized_compressors/pack_quantized.py
CHANGED
@@ -14,7 +14,6 @@
 import math
 from typing import Dict, Literal, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -135,7 +134,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor …
+        # We typically don't compress zp; apart from when using the packed_compressor
+        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -166,7 +166,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on …
+        # NOTE: this will fail decompression as we don't currently handle packed zp on
+        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Dict, …
+from typing import Dict, List, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
CHANGED
@@ -48,7 +48,7 @@ class Marlin24Compressor(BaseCompressor):
 
     @staticmethod
     def validate_quant_compatability(
-        names_to_scheme: Dict[str, QuantizationScheme]
+        names_to_scheme: Dict[str, QuantizationScheme],
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24
compressed_tensors/quantization/lifecycle/apply.py
CHANGED
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 import logging
-import re
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,8 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import …
+from compressed_tensors.utils.helpers import deprecated, replace_module
+from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -51,8 +51,6 @@ __all__ = [
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "expand_target_names",
-    "is_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -73,14 +71,14 @@ def load_pretrained_quantization_parameters(
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
     a model that has already been initialized with a quantization config.
 
-    NOTE: Will always load inputs/output parameters. …
-    …
+    NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
+    parameters, if load_weight_quantization is set to True.
 
     :param model: model to load pretrained quantization parameters to
     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
         model, which is used to load quantization parameters
-    :param load_weight_quantization: whether or not the weight quantization parameters
-        be …
+    :param load_weight_quantization: whether or not the weight quantization parameters
+        should be loaded
     """
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)
@@ -147,47 +145,30 @@ def apply_quantization_config(
     if run_compressed:
         from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in …
-    … (old lines 154-175 truncated in this view)
-        )
-        replace_module(model, name, compressed_linear)
-
-        # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = scheme
-
-        names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    for name, submodule in match_named_modules(
+        model, target_to_scheme, config.ignore, warn_on_fail=True
+    ):
+        # mark modules to be quantized by adding
+        # quant scheme to the matching layers
+        matched_targets = match_targets(name, submodule, target_to_scheme)
+        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+        if run_compressed:
+            format = config.format
+            if format != CompressionFormat.dense.value:
+                if isinstance(submodule, torch.nn.Linear):
+                    # TODO: expand to more module types
+                    compressed_linear = CompressedLinear.from_linear(
+                        submodule,
+                        quantization_scheme=scheme,
+                        quantization_format=format,
+                    )
+                    replace_module(model, name, compressed_linear)
+
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        names_to_scheme[name] = submodule.quantization_scheme
 
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
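`apply_quantization_config` now resolves schemes through `match_targets`, whose output ordering (exact name, then `re:` regex, then class name, as documented in the match.py hunks below) determines how schemes are merged. A toy check of that ordering:

```python
import torch
from compressed_tensors.utils.match import match_targets

layer = torch.nn.Linear(8, 8)
targets = ["Linear", "re:.*_proj", "q_proj"]
print(match_targets("q_proj", layer, targets))
# expected: ['q_proj', 're:.*_proj', 'Linear']
```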
@@ -262,54 +243,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
-def expand_target_names(
-    model: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> Set[str]:
-    """
-    Finds all unique module names in the model that match the given
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param model: model to search for targets in
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: set of all targets that match the given targets and should
-        not be ignored
-    """
-    return {
-        name
-        for name, module in model.named_modules()
-        if is_target(name, module, targets, ignore)
-    }
-
-
-def is_target(
-    name: str,
-    module: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> bool:
-    """
-    Determines if a module should be included in the targets based on the
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param name: name of the module
-    :param module: the module itself
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: True if the module is a target and not ignored, False otherwise
-    """
-    return bool(
-        find_name_or_class_matches(name, module, targets or [])
-        and not find_name_or_class_matches(name, module, ignore or [])
-    )
-
-
+@deprecated(
+    message="This function is deprecated and will be removed in a future release."
+    "Please use `match_targets` from `compressed_tensors.utils.match` instead."
+)
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
@@ -322,38 +259,13 @@ def find_name_or_class_matches(
     2. matches on regex patterns
     3. matches on module names
     """
-    … (old lines 325-329 truncated in this view)
-    targets = sorted(targets, key=lambda x: ("re:" in x, x))
-    if isinstance(targets, Iterable):
-        matches = _find_matches(name, targets) + _find_matches(
-            module.__class__.__name__, targets, check_contains
+    if check_contains:
+        raise NotImplementedError(
+            "This function is deprecated, and the check_contains=True option has been"
+            " removed."
         )
-        matches = [match for match in matches if match is not None]
-        return matches
 
-
-def _find_matches(
-    value: str, targets: Iterable[str], check_contains: bool = False
-) -> List[str]:
-    # returns all the targets that match value either
-    # exactly or as a regex after 're:'. if check_contains is set to True,
-    # additionally checks if the target string is contained with value.
-    matches = []
-    for target in targets:
-        if target.startswith("re:"):
-            pattern = target[3:]
-            if re.match(pattern, value):
-                matches.append(target)
-        elif check_contains:
-            if target.lower() in value.lower():
-                matches.append(target)
-        elif target == value:
-            matches.append(target)
-    return matches
+    return match_targets(name, module, targets)
 
 
 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
@@ -429,7 +341,6 @@ def _scheme_from_targets(
 def _merge_schemes(
     schemes_to_merge: List[QuantizationScheme], name: str
 ) -> QuantizationScheme:
-
     kv_cache_quantization_scheme = [
         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
     ]
compressed_tensors/quantization/lifecycle/forward.py
CHANGED
@@ -205,7 +205,8 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    # blockwise FP8: quantize per 2D block, supports block_structure for static block …
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block
+    # quantization
     if args.strategy == QuantizationStrategy.BLOCK:
         original_shape = x.shape
         rows, cols = x.shape[-2], x.shape[-1]
@@ -214,8 +215,8 @@ def _process_quantization(
         # Ensure exact division (tensor dimensions must be divisible by block size)
         if rows % block_height != 0:
             raise ValueError(
-                f"Tensor height {rows} is not divisible by block_height {block_height}. "
-                f"Block quantization requires exact division."
+                f"Tensor height {rows} is not divisible by block_height {block_height}."
+                f" Block quantization requires exact division."
             )
         if cols % block_width != 0:
             raise ValueError(
@@ -295,7 +296,7 @@ def _process_quantization(
     perm = torch.argsort(g_idx)
     x = safe_permute(x, perm, dim=1)
 
-    # Maintain all dimensions …
+    # Maintain all dimensions except the last dim, which is divided by group_size
     reshaped_dims = (
         ceil(x.shape[-1] / group_size),
         group_size,
compressed_tensors/quantization/lifecycle/initialize.py
CHANGED
@@ -17,7 +17,7 @@ import logging
 import math
 import warnings
 from enum import Enum
-from typing import …
+from typing import Optional
 
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
@@ -87,7 +87,6 @@ def initialize_module_for_quantization(
         _initialize_attn_scales(module)
 
     else:
-
         if scheme.input_activations is not None:
             _initialize_scale_zero_point(
                 module,
@@ -183,7 +182,8 @@ def _initialize_scale_zero_point(
         num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
         expected_shape = (weight_shape[0], max(num_groups, 1))
     elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        # For block quantization, scale shape should match number of blocks - only …
+        # For block quantization, scale shape should match number of blocks - only
+        # for weights
         if quantization_args.block_structure is None:
             raise ValueError(
                 "Block quantization requires block_structure to be specified"
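For the BLOCK strategy the scale holds one entry per (block_height x block_width) tile. A short sketch of the shape arithmetic, mirroring the ceil-division this function uses:

```python
import math

weight_shape = (4096, 4096)
block_height, block_width = 128, 128
expected_scale_shape = (
    math.ceil(weight_shape[0] / block_height),
    math.ceil(weight_shape[1] / block_width),
)
assert expected_scale_shape == (32, 32)
```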
@@ -196,9 +196,10 @@ def _initialize_scale_zero_point(
         # Warn if dimensions don't divide evenly
         if rows % block_height != 0 or cols % block_width != 0:
             warnings.warn(
-                f"Block quantization: tensor shape {weight_shape} does not divide …
-                f"by block structure {quantization_args.block_structure}. "
-                f"Some blocks will be incomplete which may affect quantization …
+                f"Block quantization: tensor shape {weight_shape} does not divide"
+                f"evenly by block structure {quantization_args.block_structure}. "
+                f"Some blocks will be incomplete which may affect quantization"
+                "quality.",
                 UserWarning,
             )
 
compressed_tensors/quantization/quant_args.py
CHANGED
@@ -217,16 +217,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 return [int(x) for x in value.split("x")]
             except Exception:
                 raise ValueError(
-                    f"Invalid block_structure '{value}'. Must be a list of …
+                    f"Invalid block_structure '{value}'. Must be a list of ints "
+                    "[rows, cols]."
                 )
         if isinstance(value, (list, tuple)):
             if len(value) != 2 or not all(isinstance(v, int) for v in value):
                 raise ValueError(
-                    f"Invalid block_structure '{value}'. Must be a list of …
+                    f"Invalid block_structure '{value}'. Must be a list of ints "
+                    "[rows, cols]."
                 )
             return list(value)
         raise ValueError(
-            f"Invalid block_structure '{value}'. Must be a list of …
+            f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]."
        )
 
     @field_validator("strategy", mode="before")
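The validator accepts either an "RxC" string or a two-int list and normalizes both to `[rows, cols]`. A quick sketch (fields beyond these are left at their defaults, which is assumed to be acceptable for a minimal example):

```python
from compressed_tensors.quantization import QuantizationArgs

a = QuantizationArgs(num_bits=8, strategy="block", block_structure="128x128")
b = QuantizationArgs(num_bits=8, strategy="block", block_structure=[128, 128])
assert a.block_structure == b.block_structure == [128, 128]
```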
@@ -307,7 +309,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         )
         if strategy not in supported_strategies:
             raise ValueError(
-                f"One of {supported_strategies} must be used for dynamic …
+                f"One of {supported_strategies} must be used for dynamic quant."
             )
 
         if (
@@ -322,7 +324,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             observer != "memoryless"
         ):  # avoid annoying users with old configs
             warnings.warn(
-                "No observer is used for dynamic …
+                "No observer is used for dynamic quant., setting to None"
             )
             observer = None
         else:
compressed_tensors/quantization/quant_scheme.py
CHANGED
@@ -81,9 +81,10 @@ class QuantizationScheme(BaseModel):
         ):
             warnings.warn(
                 "Using GROUP strategy for both weights and input_activations "
-                f"with different group sizes ({weights.group_size} vs …
-                "may complicate fused kernel implementations. …
-                "TENSOR_GROUP strategy for both or matching group …
+                f"with different group sizes ({weights.group_size} vs "
+                f"{inputs.group_size}) may complicate fused kernel implementations. "
+                "Consider using TENSOR_GROUP strategy for both or matching group"
+                " sizes.",
                 UserWarning,
                 stacklevel=2,
             )
compressed_tensors/transform/utils/matrix.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import …
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformLocation
compressed_tensors/utils/match.py
CHANGED
@@ -27,6 +27,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 __all__ = [
     "match_named_modules",
     "match_named_parameters",
+    "match_targets",
     "match_modules_set",
     "is_match",
 ]
@@ -37,8 +38,8 @@ FusedMappping = Mapping[str, Iterable[str]]
 
 def match_named_modules(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = …
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
@@ -54,14 +55,18 @@ def match_named_modules(
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
+
     for name, module in model.named_modules():
         for target in targets:
             if is_match(name, module, target, fused=fused):
                 unmatched_targets -= {target}
-
                 if not is_match(name, module, ignore, fused=fused):
                     yield name, module
+                break
 
     if warn_on_fail:
         for target in unmatched_targets:
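The added `break` means a module that matches several targets is yielded at most once, instead of once per matching target; the `targets or []` / `ignore or []` defaults make both arguments optional. A toy check of the dedup behavior (both targets match module "0"):

```python
import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
names = [name for name, _ in match_named_modules(model, ["Linear", "re:0"])]
assert names == ["0"]  # yielded once, not twice
```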
@@ -72,8 +77,8 @@
 
 def match_named_parameters(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = …
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
@@ -89,6 +94,9 @@ def match_named_parameters(
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
     for module_name, module in model.named_modules():
         if isinstance(module, InternalModule):
@@ -110,16 +118,54 @@ def match_named_parameters(
         )
 
 
+def match_targets(
+    name: str, module: torch.nn.Module, targets: Optional[Iterable[str]]
+) -> List[str]:
+    """
+    Returns the targets that match the given name and module.
+
+    :param name: the name of the module
+    :param module: the module to match
+    :param targets: the target strings, potentially containing "re:" prefixes
+    :return: the targets that match the given name and module
+
+    Outputs are ordered by type: exact name match, regex name match, class name match
+    """
+    targets = targets or []
+
+    if isinstance(module, InternalModule):
+        return []
+
+    # The order of the output `matches` list matters, the are arranged from most
+    # specific to least specific, and this order will be used when merging configs.
+    # The entries are sorted in the following order:
+    # 1. matches on exact strings
+    # 2. matches on regex patterns
+    # 3. matches on module names
+
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    matched_targets = []
+    for target in targets:
+        if _match_name(name, target):
+            matched_targets.append(target)
+
+    for target in targets:
+        if _match_class(module, target) and target not in matched_targets:
+            matched_targets.append(target)
+
+    return matched_targets
+
+
 def match_modules_set(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = …
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
 ) -> Generator[Iterable[torch.nn.Module]]:
     """
     Yields modules grouped with the same order and size as `targets`.
     Values are returned in order of `model.named_modules()`
 
-    …
+    E.g. the following targets would yield module belonging to the following layers:
     ```python3
     match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
         (
@@ -151,6 +197,9 @@ def match_modules_set(
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
     """
+    targets = targets or []
+    ignore = ignore or []
+
     matches = dict.fromkeys(targets, None)
     for name, module in model.named_modules():
         # match until we get a full set
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.11.1a20250820
+Version: 0.11.1a20250821
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/RECORD
RENAMED
@@ -1,23 +1,23 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=-gxWvDF4LCkyeDP8YlGzvBBKxo4Dk9h4NINPD61drFU,921
-compressed_tensors/version.py,sha256=…
+compressed_tensors/version.py,sha256=QiPWK4b5m-LXWHE8_W5EK7VPtKZvorPc5Opz7BYczvA,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=…
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=x2AS1NAPQx51O8uxyLf3wItnp2-_0qU2fI6eQVFBBfY,37388
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
-compressed_tensors/compressors/quantized_compressors/base.py,sha256=…
+compressed_tensors/compressors/quantized_compressors/base.py,sha256=_mqTG_HjAIbHqDGucA3ZR_01OXU3CMFxtrDjfM-kY0g,10301
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
-compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=…
-compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=…
+compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Z8k2gi5a1F_36DiI0GJsXGc03Gh0qwBRMwMxuKIWkj8,7136
+compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=D8h9ltxSIYi1XEKYgbYu1ebbXzCibhPi-eZsBUi0NOg,11245
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
 compressed_tensors/compressors/sparse_compressors/dense.py,sha256=-OujJ1e0iXBvxYVULrIGvAZ9l-IC0mXczZRnimQdgo4,2314
-compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py,sha256=…
+compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py,sha256=U6oJz_BYbHi3qtB8RUo5YKxF7hHL1NJQzGBQKjTVJnQ,9251
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0FI9ep_XtUQOxj0P5utJt3vKEYOHjWEPp-Xd9aY,5820
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
-compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=…
+compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=U-zfkUYvQb1owXit8irlRINhlGcjevYwwjtPjb2S2I8,10100
 compressed_tensors/config/__init__.py,sha256=8sOoZ6xvYSC79mBvEtO8l6xk4PC80d29AnnJiGMrY2M,737
 compressed_tensors/config/base.py,sha256=FaImUwb5G93en2BHUKDs76L_tO8NFpdxlfwAgQL7mNM,3569
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
@@ -26,23 +26,23 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=…
+compressed_tensors/quantization/quant_args.py,sha256=5AxYKqCSlg7CDgz2N8G4ZRVIiSUKvIm-SCQa-Bq_SF0,12916
 compressed_tensors/quantization/quant_config.py,sha256=2NgDwKuQn0f-ojiHC8c6tXtYX_zQlk26Rj-bU71QKvA,10598
-compressed_tensors/quantization/quant_scheme.py,sha256=…
+compressed_tensors/quantization/quant_scheme.py,sha256=X5Z7oXMLPXnX8g-UvWXlRjn4YnD_qTk5mXfGzu20k9o,8903
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
-compressed_tensors/quantization/lifecycle/apply.py,sha256=…
+compressed_tensors/quantization/lifecycle/apply.py,sha256=yc9xCuQIcdhy-MGFh8OmBrB45dzJ8TzZju4mBa3AONg,14909
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=…
+compressed_tensors/quantization/lifecycle/forward.py,sha256=xcLTgaff1wYUWzvQqYKmhWYkshWVI-PhLPtBOyyZro0,17576
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=…
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=f05UF6NaUGvR9qyxes_AgRcvg3KWgk5JeM_-NL1EQG0,10285
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=…
+compressed_tensors/quantization/utils/helpers.py,sha256=-pfSmxqHkrB-RnjF0VYz8lMe9CVnB7IJrONf9Y9fjCo,17014
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
-compressed_tensors/registry/registry.py,sha256=…
+compressed_tensors/registry/registry.py,sha256=cWnlwZ66lgG0w9OAUEAgq5XVxqsgFm1o8ZYdNhkNvJY,11957
 compressed_tensors/transform/__init__.py,sha256=v2wfl4CMfA6KbD7Hxx_MbRev63y_6QLDlccZq-WTtdw,907
 compressed_tensors/transform/apply.py,sha256=nCJvhHleIyWPNYPr-SZvXhmTKpqHVpJrG8VfIW-K6d8,1422
 compressed_tensors/transform/transform_args.py,sha256=rVgReFp7wMXcYugkfd325e2tTFh8pGV3FnYTGCEv5jY,3429
-compressed_tensors/transform/transform_config.py,sha256=…
+compressed_tensors/transform/transform_config.py,sha256=3YdtGcau3qkcapX9GMUiLuhQHFQZKFYT3eLgJGj1L6s,1204
 compressed_tensors/transform/transform_scheme.py,sha256=S7vYLnuv7xZ_bwphkpCiGqZLjnnTnb4lj1T8a6WwnE0,2094
 compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/transform/factory/base.py,sha256=Txkr1nWKtlMU1MmBcQ85-JqJzD356Z9nYbaF24tJ5rw,7755
@@ -52,19 +52,19 @@ compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Um
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
 compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
-compressed_tensors/transform/utils/matrix.py,sha256=…
+compressed_tensors/transform/utils/matrix.py,sha256=3sPatOCzcLRE8ROLCGTKHr2c51DubJOFgmuNCgYdJP4,6164
 compressed_tensors/utils/__init__.py,sha256=spzbjUO4-hZ2jXGST27r3MIt2yzIXsjdbEaYyaMcizo,873
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
-compressed_tensors/utils/match.py,sha256=…
-compressed_tensors/utils/offload.py,sha256=…
+compressed_tensors/utils/match.py,sha256=y03xJyWTXV8bjIPN5Z4S0_w797qMnh-Z4aiPEGQ4zNE,11239
+compressed_tensors/utils/offload.py,sha256=jE9xj3VewMc85iOLWSikqdyjNL9JB3oZpO1uDKKCLUE,24444
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
-compressed_tensors/utils/safetensors_load.py,sha256=…
+compressed_tensors/utils/safetensors_load.py,sha256=Vql34aCTDHwmTZXJHzCyBISJo7iA7EQ78LdTlMjdpZo,12023
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
 compressed_tensors/utils/type.py,sha256=bNwoo_FWlvLuDpYAGGzZJITRg0JA_Ngk9LGPo-kvjeU,2554
-compressed_tensors-0.11.1a20250820.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.11.1a20250820.dist-info/METADATA,sha256=…
-compressed_tensors-0.11.1a20250820.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.11.1a20250820.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.11.1a20250820.dist-info/RECORD,,
+compressed_tensors-0.11.1a20250821.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.11.1a20250821.dist-info/METADATA,sha256=jpkjjAiWJwPLa19Ej2tIJm5MEHJ9gwYsPPfvkhF6YYg,7031
+compressed_tensors-0.11.1a20250821.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.11.1a20250821.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.11.1a20250821.dist-info/RECORD,,
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/WHEEL
RENAMED
File without changes
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/licenses/LICENSE
RENAMED
File without changes
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250821.dist-info}/top_level.txt
RENAMED
File without changes