compressed-tensors 0.9.5a20250530__tar.gz → 0.9.5a20250603__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. {compressed_tensors-0.9.5a20250530/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250603}/PKG-INFO +1 -1
  2. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -10
  3. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/forward.py +35 -24
  4. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/initialize.py +7 -113
  5. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_args.py +1 -0
  6. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/utils/helpers.py +9 -7
  7. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/offload.py +134 -1
  8. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/version.py +1 -1
  9. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  10. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +3 -3
  11. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_int_quant.py +2 -2
  12. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_forward.py +12 -12
  13. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_utils/test_helpers.py +3 -5
  14. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_offload.py +95 -5
  15. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/.gitkeep +0 -0
  16. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/actions/test/action.yml +0 -0
  17. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/scripts/step-status +0 -0
  18. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/build-test.yml +0 -0
  19. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/build.yml +0 -0
  20. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/report.yml +0 -0
  21. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/test-check.yaml +0 -0
  22. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/test.yml +0 -0
  23. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/trigger-all.yml +0 -0
  24. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/upload.yml +0 -0
  25. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.gitignore +0 -0
  26. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/LICENSE +0 -0
  27. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/Makefile +0 -0
  28. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/README.md +0 -0
  29. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  30. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bit_packing/int4_config.json +0 -0
  31. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bitmask_compression.ipynb +0 -0
  32. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  33. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  34. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/example_quant_config.json +0 -0
  35. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  36. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/quantize_and_pack_int4.ipynb +0 -0
  37. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/pyproject.toml +0 -0
  38. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/setup.cfg +0 -0
  39. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/setup.py +0 -0
  40. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/__init__.py +0 -0
  41. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/README.md +0 -0
  42. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/__init__.py +0 -0
  43. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/base.py +0 -0
  44. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/__init__.py +0 -0
  45. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/base.py +0 -0
  46. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/helpers.py +0 -0
  47. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  48. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  49. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  50. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  51. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  52. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  53. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  54. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  55. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  56. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  57. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  58. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  59. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  60. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  61. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/__init__.py +0 -0
  62. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/base.py +0 -0
  63. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/dense.py +0 -0
  64. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  65. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  66. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/linear/__init__.py +0 -0
  67. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  68. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/__init__.py +0 -0
  69. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  70. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  71. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  72. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_config.py +0 -0
  73. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  74. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  75. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/registry/__init__.py +0 -0
  76. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/registry/registry.py +0 -0
  77. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/__init__.py +0 -0
  78. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_args.py +0 -0
  79. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_config.py +0 -0
  80. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  81. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/__init__.py +0 -0
  82. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/helpers.py +0 -0
  83. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/permutations_24.py +0 -0
  84. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/permute.py +0 -0
  85. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  86. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  87. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  88. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  89. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/requires.txt +0 -0
  90. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  91. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/__init__.py +0 -0
  92. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/conftest.py +0 -0
  93. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/__init__.py +0 -0
  94. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/model_compressors/__init__.py +0 -0
  95. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  96. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  97. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  98. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  99. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  100. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  101. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  102. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  103. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  104. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_configs/__init__.py +0 -0
  105. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_configs/test_base.py +0 -0
  106. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  107. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_linear/__init__.py +0 -0
  108. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_linear/test_compressed_linear.py +0 -0
  109. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/__init__.py +0 -0
  110. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/__init__.py +0 -0
  111. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/conftest.py +0 -0
  112. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  113. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  114. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  115. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  116. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  117. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  118. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/__init__.py +0 -0
  119. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  120. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  121. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_args.py +0 -0
  122. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_config.py +0 -0
  123. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_scheme.py +0 -0
  124. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_registry.py +0 -0
  125. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_args.py +0 -0
  126. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_config.py +0 -0
  127. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_scheme.py +0 -0
  128. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/__init__.py +0 -0
  129. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_helpers.py +0 -0
  130. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_safetensors_load.py +0 -0
  131. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/testing_utils.py +0 -0
  132. {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/utils/copyright.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250530
+Version: 0.9.5a20250603
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/quantization/lifecycle/apply.py

@@ -27,14 +27,8 @@ from compressed_tensors.quantization.lifecycle.compressed import (
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
-    update_fused_layer_weight_global_scales,
-)
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    QuantizationArgs,
-    QuantizationType,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -272,9 +266,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
            )
        )

-    if status == QuantizationStatus.INITIALIZED:
-        update_fused_layer_weight_global_scales(model)
-
    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
        model.apply(compress_quantized_weights)

src/compressed_tensors/quantization/lifecycle/forward.py

@@ -227,31 +227,42 @@ def _process_quantization(
            perm = torch.argsort(g_idx)
            x = safe_permute(x, perm, dim=1)

-        # TODO: experiment with vectorizing for loop for performance
-        end = 0
-        for index, group_count in enumerate(group_sizes):
-            sc = scale[:, index].view(-1, 1)
-            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
-
-            start = end
-            end = start + group_count
-            if do_quantize:
-                output[:, start:end] = _quantize(
-                    x=x[:, start:end],
-                    scale=sc,
-                    zero_point=zp,
-                    q_min=q_min,
-                    q_max=q_max,
-                    args=args,
-                    dtype=dtype,
-                    global_scale=global_scale,
-                )
+        x = torch.reshape(
+            x,
+            (
+                x.shape[0],
+                ceil(x.shape[1] / group_size),
+                group_size,
+            ),
+        )

-            if do_dequantize:
-                input = output[:, start:end] if do_quantize else x[:, start:end]
-                output[:, start:end] = _dequantize(
-                    x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
-                )
+        if do_quantize:
+            output = _quantize(
+                x=x,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                dtype=dtype,
+                global_scale=global_scale,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+            )
+
+        if do_dequantize:
+            input = output if do_quantize else x
+            output = _dequantize(
+                x_q=input,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                global_scale=global_scale,
+            )
+
+        output = torch.reshape(
+            output,
+            (output.shape[0], output.shape[1] * output.shape[2]),
+        )
+
+        output = output.to(output_dtype)

        if not is_column_order:
            output = safe_permute(output, torch.argsort(perm), dim=1)
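Note: the hunk above replaces the per-group Python loop with a single reshape to (rows, num_groups, group_size) followed by one broadcasted quantize/dequantize call. Below is a minimal standalone sketch of the same idea, using symmetric int8 fake-quantization with illustrative names rather than the library's _quantize/_dequantize helpers, and assuming the column count divides evenly by group_size (the library handles the remainder with ceil()).

import torch

def group_quant_dequant(x: torch.Tensor, scale: torch.Tensor, group_size: int) -> torch.Tensor:
    # x: (rows, cols), scale: (rows, cols // group_size)
    rows, cols = x.shape
    num_groups = cols // group_size
    # view each row as (num_groups, group_size) so one broadcasted op covers all groups
    x_grouped = x.reshape(rows, num_groups, group_size)
    sc = scale.unsqueeze(-1)  # (rows, num_groups, 1), mirroring scale.unsqueeze(-1) above
    q = torch.clamp(torch.round(x_grouped / sc), -128, 127)
    return (q * sc).reshape(rows, cols)  # fold groups back into columns

x = torch.randn(4, 256)
scale = x.reshape(4, 2, 128).abs().amax(dim=-1) / 127  # one scale per (row, group)
out = group_quant_dequant(x, scale, group_size=128)

This is also why the test fixtures later in this diff switch group scales and zero points from shape (512, 8, 1) to (512, 8): the trailing singleton dimension is now added inside _process_quantization via unsqueeze(-1).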
src/compressed_tensors/quantization/lifecycle/initialize.py

@@ -23,26 +23,18 @@ from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter

@@ -51,7 +43,6 @@ __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]


@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)

-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # TODO: should move to llmcompressor
-        if base_name == "weight":
-            # When applying weight-only FP4 quantization, generate a global_scale
-            # This scale is applied during runtime to ensure that the generated
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
        register_offload_parameter(
            module, f"{base_name}_global_scale", init_global_scale
        )
@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
        requires_grad=False,
    )
    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
src/compressed_tensors/quantization/quant_args.py

@@ -53,6 +53,7 @@ class FP4_E2M1_DATA(FloatArgs):
     min = -6.0

     @staticmethod
+    @torch.compile
     def cast_to_fp4(x):
         sign = torch.sign(x)
         x = torch.abs(x)
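For context, cast_to_fp4 (now wrapped with @torch.compile so its elementwise ops can be fused into a single kernel) rounds values onto the FP4 E2M1 grid, whose representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. The sketch below is an illustrative round-to-nearest reference built on a lookup over those values; it is not the library's branch-based implementation and does not reproduce its exact tie-breaking behavior.

import torch

# representable E2M1 magnitudes (sign handled separately)
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def cast_to_fp4_reference(x: torch.Tensor) -> torch.Tensor:
    sign = torch.sign(x)
    mag = torch.abs(x).clamp(max=6.0)
    # pick the nearest representable magnitude for every element
    idx = torch.argmin((mag.unsqueeze(-1) - FP4_VALUES).abs(), dim=-1)
    return sign * FP4_VALUES[idx]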
src/compressed_tensors/quantization/utils/helpers.py

@@ -47,7 +47,7 @@ __all__ = [
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
-    "generate_global_scale",
+    "generate_gparam",
     "is_fp4",
 ]

@@ -81,7 +81,7 @@ def calculate_qparams(
         currently only applied/supported for Fp4

     :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
-        scale if of dtype FP8
+        scale is of dtype FP8
     """
     # based on the implementations for consuming quantized values,
     # 0.0 must always be representable within the quantized range
@@ -475,8 +475,9 @@ def parse_out_kv_cache_args(
     return kv_cache_args, quant_scheme_to_layers


-def generate_global_scale(
-    input_tensor: torch.Tensor,
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
     scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
     quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
     dtype: Optional[torch.dtype] = torch.float32,
@@ -490,7 +491,8 @@ def generate_global_scale(
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
     """
-    scale_dtype = scale_data.dtype
-    tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
-    global_scale = scale_data.max * quant_data.max / tensor_amax
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
     return global_scale.to(dtype)
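With this rename, generate_gparam derives the global scale from already-observed min/max statistics instead of re-reading the whole tensor. A rough standalone sketch of the math, with the E4M3 (448) and E2M1 (6) maxima hard-coded for illustration rather than taken from the library's FloatArgs objects:

import torch

FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0

def generate_gparam_sketch(min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # clamp so 0.0 always lies inside the observed range, then take the absolute max
    min_vals = torch.min(min_val, torch.zeros_like(min_val))
    max_vals = torch.max(max_val, torch.zeros_like(max_val))
    amax = torch.max(min_vals.abs(), max_vals.abs())
    # map the observed amax onto the FP8 scale range times the FP4 value range
    return (FP8_E4M3_MAX * FP4_E2M1_MAX / amax).to(torch.float32)

weight = torch.randn(8, 7)
gscale = generate_gparam_sketch(*torch.aminmax(weight))
# round-trip check mirroring test_fused_global_scales at the end of this diff
assert torch.isclose(weight.abs().max(), FP8_E4M3_MAX * FP4_E2M1_MAX / gscale)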
src/compressed_tensors/utils/offload.py

@@ -28,15 +28,18 @@ Utilities associated with offloading functionality provided by `accelerate`.
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

 import torch


 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        attach_align_device_hook,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (
@@ -54,6 +57,9 @@ except ImportError:
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
+    attach_align_device_hook = None


 __all__ = [
@@ -70,6 +76,9 @@ __all__ = [
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "delete_offload_module",
+    "force_cpu_offload",
 ]


@@ -77,6 +86,11 @@ def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback
@@ -346,6 +360,7 @@ def delete_from_weights_map(
        )


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """
@@ -362,6 +377,7 @@ def disable_offload(module: torch.nn.Module):
     yield


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
@@ -383,6 +399,123 @@ def align_modules(
     yield


+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+    :param base: parent module to delete submodule from
+    :param name: name of submodule on parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(
+    module: torch.nn.Module, execution_device: torch.device
+) -> torch.nn.Module:
+    """
+    Force cpu offloading a module, primarily used for testing
+
+    :param module: module containing parameters to offload
+    :param execution_device: execution device submodules
+    :return: module with hooks to perform cpu offloading
+    """
+    # edge case: there is a bug in `dispatch_model` which causes
+    # the function to only work if the model contains submodules
+    if next(module.children(), None) is None:
+        attach_align_device_hook(
+            module,
+            execution_device=execution_device,
+            offload=True,
+            weights_map=module.state_dict(),
+            tied_params_map={},
+        )
+        return module
+
+    device_map = {}
+
+    def collect_device_map(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                collect_device_map(name, submodule)
+                name.pop()
+
+    collect_device_map([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """

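A hedged usage sketch for the three new offload helpers, assuming `accelerate` is installed and that the functions are importable from the offload module shown above (they are also added to `__all__`). force_cpu_offload is described in the diff as primarily a testing utility: parameters live on CPU and are aligned to the execution device around each forward pass.

import torch
from compressed_tensors.utils.offload import (
    delete_offload_module,
    force_cpu_offload,
    register_offload_module,
)

# dispatch a small model with CPU offloading hooks
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
model = force_cpu_offload(model, execution_device=torch.device("cpu"))

# attach a new submodule; if the parent is offloaded, its parameters are added
# to the parent's weights map and the submodule receives an AlignDevicesHook
register_offload_module(model, "extra", torch.nn.Linear(16, 16))

# remove it again, deleting its offloaded parameters first
delete_offload_module(model, "extra")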
src/compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.9.5.a20250530'
+__version__ = version = '0.9.5.a20250603'
 __version_tuple__ = version_tuple = (0, 9, 5)
src/compressed_tensors.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250530
+Version: 0.9.5a20250603
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
tests/test_compressors/quantized_compressors/test_fp8_quant.py

@@ -61,8 +61,8 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
        [
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1), dtype=torch.int8),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8), dtype=torch.int8),
        ],
        [
            QuantizationStrategy.CHANNEL,
@@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp):
        "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.float32),
    }
    if group_size is not None:
-        dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(512, group_size)
+        dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(1024, group_size)

    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)

tests/test_compressors/quantized_compressors/test_int_quant.py

@@ -53,8 +53,8 @@ def get_dummy_quant_config(strategy, group_size=None, symmetric=True):
            QuantizationStrategy.GROUP,
            True,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1), dtype=torch.int8),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8), dtype=torch.int8),
        ],
        [
            QuantizationStrategy.CHANNEL,
tests/test_quantization/lifecycle/test_forward.py

@@ -108,8 +108,8 @@ def test_forward_quantize(
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (
@@ -117,8 +117,8 @@ def test_forward_quantize(
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
        (
@@ -135,8 +135,8 @@ def test_forward_quantize(
            "float",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (
@@ -144,8 +144,8 @@ def test_forward_quantize(
            "float",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
    ],
@@ -174,8 +174,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (
@@ -183,8 +183,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8, 1)) * 0.01,
-            torch.zeros((512, 8, 1)),
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
    ],
tests/test_quantization/test_utils/test_helpers.py

@@ -20,10 +20,7 @@ from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import (
-    calculate_qparams,
-    generate_global_scale,
-)
+from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam


 @pytest.mark.parametrize(
@@ -70,7 +67,8 @@ def test_fused_global_scales():
     layer = torch.nn.Linear(7, 8)
     max_tensor_value = torch.abs(layer.weight.data).max()
     # use defaults
-    global_scale = generate_global_scale(layer.weight)
+    min_val, max_val = torch.aminmax(layer.weight)
+    global_scale = generate_gparam(min_val.data, max_val.data)
     # max value should be = (448 * 6) / global_scale
     assert max_tensor_value == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001