compressed-tensors 0.11.1a20250820__py3-none-any.whl → 0.11.1a20250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +178 -156
  2. compressed_tensors/compressors/quantized_compressors/base.py +2 -2
  3. compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +9 -9
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -3
  5. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +1 -1
  6. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  7. compressed_tensors/quantization/lifecycle/apply.py +48 -142
  8. compressed_tensors/quantization/lifecycle/forward.py +5 -4
  9. compressed_tensors/quantization/lifecycle/initialize.py +7 -6
  10. compressed_tensors/quantization/quant_args.py +7 -5
  11. compressed_tensors/quantization/quant_scheme.py +4 -3
  12. compressed_tensors/quantization/utils/helpers.py +0 -1
  13. compressed_tensors/registry/registry.py +1 -1
  14. compressed_tensors/transform/transform_config.py +1 -1
  15. compressed_tensors/transform/utils/matrix.py +1 -1
  16. compressed_tensors/utils/match.py +57 -8
  17. compressed_tensors/utils/offload.py +0 -1
  18. compressed_tensors/utils/safetensors_load.py +0 -1
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/METADATA +1 -1
  21. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/RECORD +24 -24
  22. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/WHEEL +0 -0
  23. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/licenses/LICENSE +0 -0
  24. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/top_level.txt +0 -0
@@ -42,8 +42,6 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
@@ -60,6 +58,7 @@ from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.match import match_named_modules
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -309,7 +308,7 @@ class ModelCompressor:
         if quantization_config is not None:
             # If a list of compression_format is not provided, we resolve the
             # relevant quantization formats using the config groups from the config
-            # and if those are not defined, we fall-back to the global quantization format
+            # and if those are not defined, we fall-back to the global quantization fmt
             if not self.compression_formats:
                 self.compression_formats = self._fetch_unique_quantization_formats()

@@ -342,13 +341,15 @@ class ModelCompressor:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
+
             missing_keys.update(
-                merge_names(target, "weight") for target in sparse_targets
+                merge_names(target_name, "weight")
+                for target_name, _module in sparse_targets
             )

         # Determine missing keys due to pack quantization
@@ -358,13 +359,14 @@ class ModelCompressor:
             == CompressionFormat.pack_quantized.value
         ):
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 missing_keys.update(
-                    merge_names(target, "weight") for target in quant_targets
+                    merge_names(target_name, "weight")
+                    for target_name, _module in quant_targets
                 )

         return list(missing_keys)
@@ -395,29 +397,29 @@ class ModelCompressor:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets: Set[str] = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
             unexpected_keys.update(
-                merge_names(target, param)
-                for target in sparse_targets
+                merge_names(target_name, param)
+                for target_name, _module in sparse_targets
                 for param in self.sparsity_compressor.compression_param_names
             )

         # Identify unexpected keys from quantization compression
         if self.quantization_compressor:
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets: Set[str] = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 for quant_compressor in self.quantization_compressor.values():
                     unexpected_keys.update(
-                        merge_names(target, param)
-                        for target in quant_targets
+                        merge_names(target_name, param)
+                        for target_name, _module in quant_targets
                         for param in quant_compressor.compression_param_names
                         if param != "weight"
                     )
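Across the hunks above, ModelCompressor's key-resolution paths switch from expand_target_names, which returned a set of module names, to match_named_modules, which yields (name, module) pairs, so each call site now unpacks the tuple and keeps only the name. A minimal sketch of the new pattern, assuming match_named_modules matches plain class names such as "Linear" the way quantization targets do elsewhere in the library (the toy Sequential model, targets, and ignore values are illustrative, not from this diff):

    import torch
    from compressed_tensors.utils import merge_names
    from compressed_tensors.utils.match import match_named_modules

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

    # match_named_modules yields (name, module) pairs; keep only the name when
    # building the expected checkpoint keys, as the updated call sites do above
    expected_weight_keys = {
        merge_names(name, "weight")
        for name, _module in match_named_modules(
            model=model, targets=["Linear"], ignore=[]
        )
    }
    print(expected_weight_keys)  # expected contents: "0.weight" and "1.weight"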
@@ -434,73 +436,79 @@ class ModelCompressor:
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
-
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module)
-                is_meta = module_device.type == "meta"
-
-                exec_device = "meta" if is_meta else "cpu"
-                onloading_device = "meta" if is_meta else module_device
-
-                # in the future, support compression on same device
-                with align_module_device(module, execution_device=exec_device):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                    # quantization first
-                    if prefix in module_to_scheme:
-                        if (
-                            not hasattr(module.quantization_scheme, "format")
-                            or module.quantization_scheme.format is None
-                        ):
-                            if len(self.compression_formats) > 1:
-                                raise ValueError(
-                                    "Applying multiple compressors without defining "
-                                    "per module formats is not supported "
-                                )
-                            format = self.compression_formats[0]
-                        else:
-                            format = module.quantization_scheme.format
-
-                        quant_compressor = self.quantization_compressor.get(format)
-                        state_dict = quant_compressor.compress(
-                            state_dict,
-                            names_to_scheme=module_to_scheme,
-                            show_progress=False,
-                            compression_device=exec_device,
-                        )
-
-                    # sparsity second
-                    if prefix in sparse_compression_targets:
-                        state_dict = self.sparsity_compressor.compress(
-                            state_dict,
-                            compression_targets=sparse_compression_targets,
-                            show_progress=False,
-                        )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Compressing model",
+        ):
+            module_device = get_execution_device(module)
+            is_meta = module_device.type == "meta"
+
+            exec_device = "meta" if is_meta else "cpu"
+            onloading_device = "meta" if is_meta else module_device
+
+            # in the future, support compression on same device
+            with align_module_device(module, execution_device=exec_device):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if len(self.compression_formats) > 1:
+                            raise ValueError(
+                                "Applying multiple compressors without defining "
+                                "per module formats is not supported "
+                            )
+                        format = self.compression_formats[0]
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                        compression_device=exec_device,
+                    )

-                # remove any existing parameters
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )

-                # replace with compressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(onloading_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # remove any existing parameters
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)

-                module.quantization_status = QuantizationStatus.COMPRESSED
+            # replace with compressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(onloading_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)

+            module.quantization_status = QuantizationStatus.COMPRESSED
         # TODO: consider sparse compression to also be compression
         if (
             self.quantization_config is not None
@@ -516,67 +524,75 @@ class ModelCompressor:
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                # in the future, support decompression on same device
-                with align_module_device(module, execution_device="cpu"):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                    # sparsity first
-                    if prefix in sparse_compression_targets:
-                        # sparse_compression_targets are automatically inferred by this fn
-                        generator = self.sparsity_compressor.decompress_from_state_dict(
-                            state_dict,
-                        )
-                        # generates (param_path, param_val)
-                        # of compressed and unused params
-                        state_dict = {key: value for key, value in generator}
-
-                    # quantization second
-                    if prefix in module_to_scheme:
-
-                        if (
-                            not hasattr(module.quantization_scheme, "format")
-                            or module.quantization_scheme.format is None
-                        ):
-                            if len(self.compression_formats) > 1:
-                                raise ValueError(
-                                    "Applying multiple compressors without defining "
-                                    "per module formats is not supported "
-                                )
-                            format = self.compression_formats[0]
-                        else:
-                            format = module.quantization_scheme.format
-                        quant_compressor = self.quantization_compressor.get(format)
-                        state_dict = quant_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Decompressing model",
+        ):
+            # in the future, support decompression on same device
+            with align_module_device(module, execution_device="cpu"):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if len(self.compression_formats) > 1:
+                            raise ValueError(
+                                "Applying multiple compressors without defining "
+                                "per module formats is not supported "
+                            )
+                        format = self.compression_formats[0]
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
+                    )

-                # remove any existing parameters
-                exec_device = get_execution_device(module)
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)

-                # replace with decompressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with decompressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(exec_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)

-                module.quantization_status = QuantizationStatus.FROZEN
+            module.quantization_status = QuantizationStatus.FROZEN

     # ----- state dict compression pathways ----- #

@@ -614,11 +630,14 @@ class ModelCompressor:
         )

        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_target_names(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
+            sparse_compression_targets: Set[str] = {
+                module_name
+                for module_name, _module in match_named_modules(
+                    model=model,
+                    targets=self.sparsity_config.targets,
+                    ignore=self.sparsity_config.ignore,
+                )
+            }
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
@@ -641,11 +660,12 @@ class ModelCompressor:
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into

-        Note: decompress makes use of both _replace_sparsity_weights and _replace_weights
-        The variations in these methods are a result of the subtle variations between the sparsity
-        and quantization compressors. Specifically, quantization compressors return not just the
-        decompressed weight, but the quantization parameters (e.g scales, zero_point) whereas sparsity
-        compressors only return the decompressed weight.
+        Note: decompress makes use of both _replace_sparsity_weights and
+        _replace_weights. The variations in these methods are a result of the subtle
+        variations between the sparsity and quantization compressors. Specifically,
+        quantization compressors return not just the decompressed weight, but the
+        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
+        only return the decompressed weight.

         """
         model_path = get_safetensors_folder(model_path)
@@ -683,18 +703,20 @@ class ModelCompressor:
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
-
-                names_to_scheme = apply_quantization_config(
-                    model, self.quantization_config
-                )
+                apply_quantization_config(model, self.quantization_config)
+                names_to_scheme: Set[QuantizationScheme] = {
+                    name: getattr(module, "quantization_scheme")
+                    for name, module in model.named_modules()
+                    if getattr(module, "quantization_scheme", None) is not None
+                }
                 # Load activation scales/zp or any other quantization parameters
-                # Conditionally load the weight quantization parameters if we have a dense compressor
-                # Or if a sparsity compressor has already been applied
+                # Conditionally load the weight quantization parameters if we have a
+                # dense compressor or if a sparsity compressor has already been applied
                 load_pretrained_quantization_parameters(
                     model,
                     model_path,
-                    # TODO: all weight quantization params will be moved to the compressor in a follow-up
-                    # including initialization
+                    # TODO: all weight quantization params will be moved to the
+                    # compressor in a follow-up including initialization
                     load_weight_quantization=(
                         sparse_decompressed
                         or isinstance(quant_compressor, DenseCompressor)
@@ -786,7 +808,6 @@ class ModelCompressor:
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
@@ -822,9 +843,10 @@ class ModelCompressor:
             for param_name, param_data in data.items():
                 if hasattr(module, param_name):
                     # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in decompression in init
-                    # to be consistent with loading which happens later as well
-                    # however, update_data does a good shape check - should be moved to the compressor
+                    # TODO: we can also just skip initialization of scales/zp if in
+                    # decompression in init to be consistent with loading which happens
+                    # later as well however, update_data does a good shape check -
+                    # should be moved to the compressor
                     if param_name == "weight":
                         delattr(module, param_name)
                         requires_grad = param_data.dtype in (
@@ -24,7 +24,6 @@ from compressed_tensors.utils import (
     get_nested_weight_mappings,
     merge_names,
 )
-from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -107,7 +106,8 @@ class BaseQuantizationCompressor(BaseCompressor):
                     compressed_dict[name] = value.to(compression_device)
                     continue

-                # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive)
+                # compress values on meta if loading from meta otherwise on cpu (memory
+                # movement too expensive)
                 module_path = prefix[:-1] if prefix.endswith(".") else prefix
                 quant_args = names_to_scheme[module_path].weights
                 compressed_values = self.compress_weight(
@@ -15,7 +15,6 @@

 from typing import Dict, Optional, Tuple

-import numpy
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -92,7 +91,6 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         zero_point: Optional[torch.Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
-
         quantized_weight = quantize(
             x=weight,
             scale=scale,
@@ -112,7 +110,6 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         compressed_data: Dict[str, Tensor],
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
-
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
@@ -126,6 +123,7 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         return decompressed_weight


+@torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.
@@ -148,12 +146,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:

     # Find closest valid FP4 value index for each element
     abs_x = torch.abs(x)
-    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
-    for i, val in enumerate(kE2M1):
-        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
+    abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1)  # [m, n, 8]
+    abs_indices = torch.argmin(abs_diff_x, dim=-1)  # [m, n]

     # Apply sign bit (bit 3) to get final 4-bit representation
-    indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
+    indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

     # Reshape to prepare for packing pairs of values
     indices = indices.reshape(-1)
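The pack_fp4_to_uint8 change above replaces a Python loop of torch.isclose probes with one broadcasted distance computation plus argmin, which also keeps the function traceable for the new @torch.compile decorator. A small, self-contained check (input values are illustrative) that the two index selections agree for exact FP4 code values:

    import torch

    # the eight non-negative E2M1 (FP4) magnitudes used by the compressor
    kE2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    x = torch.tensor([[0.5, -6.0], [1.5, -0.0]])
    abs_x = torch.abs(x)

    # old approach: probe each code value with torch.isclose in a Python loop
    loop_idx = torch.zeros_like(abs_x, dtype=torch.long)
    for i, val in enumerate(kE2M1):
        loop_idx = torch.where(torch.isclose(abs_x, val), i, loop_idx)

    # new approach: nearest code value via one broadcasted subtraction and argmin
    argmin_idx = torch.argmin(torch.abs(abs_x.unsqueeze(-1) - kE2M1), dim=-1)

    assert torch.equal(loop_idx, argmin_idx)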
@@ -175,14 +172,17 @@ kE2M1ToFloat = torch.tensor(
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )

+
 # reference: : https://github.com/vllm-project/vllm/pull/16362
+@torch.compile(fullgraph=True, dynamic=True)
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
     """
     Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
-    (i.e. first four bits correspond to one fp4 value, last four corresond to a consecutive
-    fp4 value). The bits represent an index, which are mapped to an fp4 value.
+    (i.e. first four bits correspond to one fp4 value, last four correspond to a
+    consecutive fp4 value). The bits represent an index, which are mapped to an fp4
+    value.

     :param a: tensor to unpack
     :param m: original dim 0 size of the unpacked tensor
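As the corrected docstring above describes, each packed byte carries two 4-bit indices into the FP4 code table; per pack_fp4_to_uint8, the low three bits of a nibble select the magnitude and bit 3 the sign. A rough sketch of reading both nibbles back out, ignoring the sign bit and the library's exact nibble-ordering convention (the packed byte here is illustrative):

    import torch

    # FP4 (E2M1) magnitudes indexed by the low three bits of each nibble
    kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    packed = torch.tensor([0x31], dtype=torch.uint8)  # two nibbles: 0x1 and 0x3
    low = (packed & 0x0F).to(torch.long)
    high = (packed >> 4).to(torch.long)

    print(kE2M1ToFloat[low & 0x7], kE2M1ToFloat[high & 0x7])  # -> 0.5 and 1.5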
@@ -14,7 +14,6 @@
 import math
 from typing import Dict, Literal, Optional, Tuple, Union

-import numpy as np
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -135,7 +134,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight

-        # We typically don't compress zp; apart from when using the packed_compressor and when storing group/channel zp
+        # We typically don't compress zp; apart from when using the packed_compressor
+        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -166,7 +166,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)

-        # NOTE: this will fail decompression as we don't currently handle packed zp on decompression
+        # NOTE: this will fail decompression as we don't currently handle packed zp on
+        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -48,7 +48,7 @@ class Marlin24Compressor(BaseCompressor):

     @staticmethod
     def validate_quant_compatability(
-        names_to_scheme: Dict[str, QuantizationScheme]
+        names_to_scheme: Dict[str, QuantizationScheme],
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24