compressed-tensors 0.12.3a20251203__tar.gz → 0.12.3a20251212__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {compressed_tensors-0.12.3a20251203/src/compressed_tensors.egg-info → compressed_tensors-0.12.3a20251212}/PKG-INFO +1 -1
  2. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/base.py +33 -1
  3. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/quantized_compressors/base.py +24 -39
  4. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +40 -5
  5. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +35 -7
  6. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/format.py +2 -0
  7. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/quant_scheme.py +40 -1
  8. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/utils/helpers.py +14 -3
  9. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/utils/mxfp4_utils.py +25 -19
  10. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/version.py +1 -1
  11. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  12. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_utils/test_mxfp4_utils.py +21 -3
  13. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.github/.gitkeep +0 -0
  14. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.github/actions/test/action.yml +0 -0
  15. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.github/scripts/step-status +0 -0
  16. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.github/workflows/quality-check.yaml +0 -0
  17. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.github/workflows/test-check.yaml +0 -0
  18. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/.gitignore +0 -0
  19. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/LICENSE +0 -0
  20. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/Makefile +0 -0
  21. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/README.md +0 -0
  22. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  23. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/bit_packing/int4_config.json +0 -0
  24. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/bitmask_compression.ipynb +0 -0
  25. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  26. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  27. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/llama_1.1b/example_quant_config.json +0 -0
  28. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  29. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/examples/quantize_and_pack_int4.ipynb +0 -0
  30. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/pyproject.toml +0 -0
  31. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/setup.cfg +0 -0
  32. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/setup.py +0 -0
  33. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/__init__.py +0 -0
  34. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/README.md +0 -0
  35. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/__init__.py +0 -0
  36. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/base.py +0 -0
  37. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/__init__.py +0 -0
  38. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/helpers.py +0 -0
  39. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  40. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  41. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  42. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  43. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  44. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  45. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  46. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  47. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  48. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  49. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  50. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/__init__.py +0 -0
  51. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/base.py +0 -0
  52. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/dense.py +0 -0
  53. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  54. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  55. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/linear/__init__.py +0 -0
  56. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  57. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/logger.py +0 -0
  58. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/modeling/__init__.py +0 -0
  59. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/modeling/attention.py +0 -0
  60. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/modeling/kvcache.py +0 -0
  61. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/__init__.py +0 -0
  62. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  63. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  64. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  65. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  66. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  67. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  68. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/quant_args.py +0 -0
  69. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/quant_config.py +0 -0
  70. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  71. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  72. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/registry/__init__.py +0 -0
  73. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/registry/registry.py +0 -0
  74. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/__init__.py +0 -0
  75. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/apply.py +0 -0
  76. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  77. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/factory/base.py +0 -0
  78. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  79. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  80. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  81. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/transform_args.py +0 -0
  82. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/transform_config.py +0 -0
  83. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  84. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  85. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  86. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  87. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  88. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/__init__.py +0 -0
  89. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/helpers.py +0 -0
  90. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/internal.py +0 -0
  91. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/match.py +0 -0
  92. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/offload.py +0 -0
  93. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/permutations_24.py +0 -0
  94. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  95. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  96. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors/utils/type.py +0 -0
  97. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  98. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  99. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors.egg-info/requires.txt +0 -0
  100. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  101. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/__init__.py +0 -0
  102. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/conftest.py +0 -0
  103. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/mock_observer.py +0 -0
  104. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/__init__.py +0 -0
  105. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/model_compressors/__init__.py +0 -0
  106. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  107. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  108. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/test_fp4_quant.py +0 -0
  109. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  110. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  111. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  112. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +0 -0
  113. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  114. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  115. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  116. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  118. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_configs/__init__.py +0 -0
  119. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_configs/test_base.py +0 -0
  120. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_configs/test_infer_quant.py +0 -0
  121. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  122. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_linear/__init__.py +0 -0
  123. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_linear/test_compressed_linear.py +0 -0
  124. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_modeling/test_attention_and_cache.py +0 -0
  125. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/__init__.py +0 -0
  126. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/__init__.py +0 -0
  127. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/conftest.py +0 -0
  128. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  129. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  130. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  131. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  132. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  133. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  134. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/lifecycle/test_static_lifecycle.py +0 -0
  135. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_configs/__init__.py +0 -0
  136. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  137. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  138. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_quant_args.py +0 -0
  139. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_quant_config.py +0 -0
  140. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_quant_scheme.py +0 -0
  141. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  142. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_registry.py +0 -0
  143. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/conftest.py +0 -0
  144. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/factory/test_correctness.py +0 -0
  145. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/factory/test_memory.py +0 -0
  146. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/factory/test_serialization.py +0 -0
  147. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/test_transform_args.py +0 -0
  148. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/test_transform_config.py +0 -0
  149. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/test_transform_scheme.py +0 -0
  150. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_transform/utils/test_hadamard.py +0 -0
  151. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/__init__.py +0 -0
  152. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/test_helpers.py +0 -0
  153. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/test_match.py +0 -0
  154. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/test_offload.py +0 -0
  155. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/test_safetensors_load.py +0 -0
  156. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/test_utils/test_type.py +0 -0
  157. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/tests/testing_utils.py +0 -0
  158. {compressed_tensors-0.12.3a20251203 → compressed_tensors-0.12.3a20251212}/utils/copyright.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.12.3a20251203
+ Version: 0.12.3a20251212
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/vllm-project/compressed-tensors
  Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/base.py

@@ -20,6 +20,11 @@ from compressed_tensors.config import SparsityCompressionConfig
  from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
  from compressed_tensors.registry import RegistryMixin
  from compressed_tensors.utils import has_offloaded_params
+ from compressed_tensors.utils.offload import (
+     delete_offload_parameter,
+     get_offloaded_device,
+     register_offload_parameter,
+ )
  from torch import Tensor
  from torch.nn import Module

@@ -185,10 +190,37 @@ class BaseCompressor(RegistryMixin, ABC):
          for name, parameter in module.named_parameters():
              compressed_data[name] = parameter

-         return self.decompress_weight(
+         # Save references to original parameters before decompression
+         original_scale = compressed_data.get("weight_scale")
+         original_zp = compressed_data.get("weight_zero_point")
+
+         # NOTE: decompress_weight may modify compressed_data dict in-place
+         # This is subtle but allows us to update the module's qparams with
+         # the unpacked values.
+         # TODO: Consider refactoring to return modified qparams explicitly
+         result = self.decompress_weight(
              compressed_data=compressed_data, quantization_args=quantization_args
          ).to(device)

+         # Update module's parameters only if they were modified
+         for param_name, original_param in [
+             ("weight_scale", original_scale),
+             ("weight_zero_point", original_zp),
+         ]:
+             if (
+                 param_name in compressed_data
+                 and compressed_data[param_name] is not original_param
+             ):
+                 # Delete the old parameter and register the updated one
+                 delete_offload_parameter(module, param_name)
+                 offload_device = get_offloaded_device(module)
+                 param = torch.nn.Parameter(
+                     compressed_data[param_name], requires_grad=False
+                 )
+                 register_offload_parameter(module, param_name, param, offload_device)
+
+         return result
+
      def decompress_weight(
          self, compressed_data: Dict[str, Tensor], **kwargs
      ) -> torch.Tensor:
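
For context on the re-registration above: once decompress_weight unpacks a quantization parameter (a packed zero point, or a scale stored as an exponent byte), the tensor held by the module no longer matches what dequantize used, so the module's Parameter has to be swapped out rather than mutated. The diff does this with the offload-aware helpers; the sketch below shows the same pattern with plain PyTorch and a hypothetical replace_param helper, purely as an illustration, not the library's API.

import torch

def replace_param(module: torch.nn.Module, name: str, new_value: torch.Tensor) -> None:
    # Skip if decompression left the original tensor object untouched
    old = getattr(module, name, None)
    if old is new_value:
        return
    if name in dict(module.named_parameters()):
        delattr(module, name)  # drop the stale Parameter
    module.register_parameter(
        name, torch.nn.Parameter(new_value, requires_grad=False)
    )

layer = torch.nn.Linear(8, 8)
layer.register_parameter(
    "weight_scale",
    torch.nn.Parameter(torch.ones(8, 1, dtype=torch.uint8), requires_grad=False),
)
# pretend decompression produced an upcast float scale; re-register it on the module
replace_param(layer, "weight_scale", torch.full((8, 1), 0.5))
print(layer.weight_scale.dtype)  # torch.float32
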
src/compressed_tensors/compressors/quantized_compressors/base.py

@@ -18,7 +18,7 @@ from typing import Any, Dict, Generator, Tuple, Union

  import torch
  from compressed_tensors.compressors.base import BaseCompressor
- from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+ from compressed_tensors.quantization import QuantizationScheme
  from compressed_tensors.utils import (
      get_nested_mappings_from_state_dict,
      get_nested_weight_mappings,
@@ -85,6 +85,7 @@ class BaseQuantizationCompressor(BaseCompressor):
          """
          uncompressed_names = list(model_state.keys())
          compressed_dict = {}
+         compressed_param_names = set()

          # compress values
          desc = "Compressing with quantization"
@@ -119,54 +120,38 @@ class BaseQuantizationCompressor(BaseCompressor):
                      device=compression_device,
                  )

-                 # update state dict
+                 # update state dict and track which params were added
                  for key, value in compressed_values.items():
-                     compressed_dict[prefix + key] = value.to(compression_device)
+                     full_name = prefix + key
+                     compressed_dict[full_name] = value.to(compression_device)
+                     compressed_param_names.add(full_name)

              else:
-                 # omit saving zero points for symmetric or packed quantization
-                 if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                 # Skip qparams already added by compress_weight
+                 if name in compressed_param_names:
                      continue

-                 if name.endswith("weight_scale") and self._skip_scale():
-                     continue
+                 # for symmetric quantization, omit zero_point
+                 # manually because it wasn't handled in compress_weight
+                 if name.endswith("weight_zero_point"):
+                     module_path = name.rsplit(".", 1)[0]
+                     if (
+                         module_path in names_to_scheme
+                         and names_to_scheme[module_path].weights.symmetric
+                     ):
+                         continue
+                     # Call compress_zp if available (for PackedQuantizationCompressor)
+                     if module_path in names_to_scheme and hasattr(self, "compress_zp"):
+                         value = self.compress_zp(
+                             value, names_to_scheme[module_path].weights
+                         )
+                         if value is None:
+                             continue

                  compressed_dict[name] = value.to(compression_device)

          return compressed_dict

-     def _skip_scale(self):
-         from compressed_tensors.compressors import NVFP4PackedCompressor
-
-         return isinstance(self, NVFP4PackedCompressor)
-
-     def _skip_zp(
-         self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-     ) -> bool:
-         from compressed_tensors.compressors import PackedQuantizationCompressor
-
-         module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-         scheme = names_to_scheme[module_name]
-
-         if zp_name == "weight_zero_point":
-             args = scheme.weights
-         if zp_name == "input_zero_point":
-             args = scheme.input_activations
-         if zp_name == "output_zero_point":
-             args = scheme.output_activations
-
-         symmetric = args.symmetric
-         packable_strategies = [
-             QuantizationStrategy.GROUP.value,
-             QuantizationStrategy.CHANNEL.value,
-         ]
-         packed = (
-             isinstance(self, PackedQuantizationCompressor)
-             and args.strategy in packable_strategies
-         )
-
-         return symmetric or packed
-
      def decompress(
          self,
          path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

@@ -56,7 +56,6 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
          return (
              "weight_packed",
              "weight_scale",
-             "weight_zero_point",
              "weight_global_scale",
          )
@@ -73,13 +72,20 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
          :param quantization_args: quantization parameters for the weight
          :return: dictionary mapping compressed parameter names to shape and dtype
          """
-         output = {
+         return {
              "weight_packed": (
                  torch.Size((weight_shape[0], weight_shape[1] // 2)),
                  torch.uint8,
              ),
          }
-         return output
+
+     def compress_scale(
+         self,
+         scale: Tensor,
+         quantization_args: QuantizationArgs,
+     ) -> Dict[str, torch.Tensor]:
+         assert quantization_args.scale_dtype is not None
+         return scale.to(quantization_args.scale_dtype)

      def compress_weight(
          self,
@@ -103,7 +109,16 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
          if device is not None:
              weight_packed = weight_packed.to(device)
          compressed_dict["weight_packed"] = weight_packed
-         compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
+         compressed_dict["weight_scale"] = self.compress_scale(
+             scale=scale, quantization_args=quantization_args
+         )
+
+         if global_scale is None:
+             raise ValueError(
+                 "NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
+                 "Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
+             )
+
          return compressed_dict

      def decompress_weight(
@@ -117,6 +132,12 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
          m, n = weight.shape
          # TODO: use a user provided dequant dtype
          unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+
+         # cast scale dtype to match unpacked dtype for dequantization
+         if scale.dtype != unpacked.dtype:
+             scale = scale.to(unpacked.dtype)
+             compressed_data["weight_scale"] = scale
+
          decompressed_weight = dequantize(
              x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
          )
@@ -130,7 +151,21 @@ class MXFP4PackedCompressor(NVFP4PackedCompressor):
      Alias for mxfp4 quantized models
      """

-     pass
+     def compress_scale(
+         self,
+         scale: Tensor,
+         quantization_args: QuantizationArgs,
+     ) -> Dict[str, torch.Tensor]:
+         assert quantization_args.scale_dtype is not None
+         scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
+         return scale_exp.to(quantization_args.scale_dtype)
+
+     def decompress_weight(
+         self,
+         compressed_data: Dict[str, Tensor],
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> torch.Tensor:
+         raise NotImplementedError("MXFP4 Decompression is currently not supported")


  @torch.compile(fullgraph=True, dynamic=True)
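
As a reading aid for the MXFP4 scale handling above: compress_scale stores each group scale as a biased power-of-two exponent (127 + floor(log2(scale)) - 2, cast to the configured scale_dtype, uint8 for MXFP4), and maybe_convert_from_mxfp4_exp in mxfp4_utils reverses the bias with 2 ** (exp - 127). A minimal round-trip sketch using only these formulas (not the library functions themselves):

import torch

# per-group scales, already rounded to powers of two
scale = torch.tensor([0.125, 2.0, 0.03125])

# encode: biased exponent byte, as in MXFP4PackedCompressor.compress_scale
scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
stored = scale_exp.to(torch.uint8)            # tensor([122, 126, 120], dtype=torch.uint8)

# decode: 2 ** (exp - 127), as in maybe_convert_from_mxfp4_exp
restored = 2.0 ** (stored.to(torch.int32) - 127).to(torch.float)
print(restored)                               # tensor([0.0312, 0.5000, 0.0078]) == scale / 4, from the -2 bias term
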
src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

@@ -64,25 +64,34 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
          """
          pack_factor = 32 // quantization_args.num_bits
          packed_size = math.ceil(weight_shape[1] / pack_factor)
-         packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
          output = {
              "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
              "weight_shape": (torch.Size((2,)), torch.int32),
          }
-         if not quantization_args.symmetric and quantization_args.strategy in [
+
+         # Add weight_scale - always needed for quantization
+         if quantization_args.strategy in [
              QuantizationStrategy.GROUP.value,
              QuantizationStrategy.CHANNEL.value,
          ]:
-             zp_factor = (
+             shape_factor = (
                  quantization_args.group_size
                  if quantization_args.strategy == QuantizationStrategy.GROUP.value
                  else weight_shape[-1]
              )
-
-             output["weight_zero_point"] = (
-                 torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
-                 torch.int32,
+             scale_cols = math.ceil(weight_shape[-1] / shape_factor)
+             output["weight_scale"] = (
+                 torch.Size((weight_shape[0], scale_cols)),
+                 quantization_args.scale_dtype,
              )
+
+             # Add weight_zero_point for asymmetric quantization
+             if not quantization_args.symmetric:
+                 output["weight_zero_point"] = (
+                     torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
+                     torch.int32,
+                 )
+
          return output

      def compress_weight(
@@ -175,6 +184,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
              zero_point = unpack_from_int32(
                  zero_point, num_bits, original_zp_shape, packed_dim=0
              )
+             # Update the compressed_data dict with the unpacked zero_point
+             compressed_data["weight_zero_point"] = zero_point

          decompressed_weight = dequantize(
              x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
@@ -182,6 +193,20 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):

          return decompressed_weight

+     def compress_zp(
+         self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
+     ) -> Optional[Tensor]:
+         if zero_point is None or quantization_args.symmetric:
+             return None
+         if zero_point.dtype == torch.int32:
+             return zero_point
+         if quantization_args.strategy in [
+             QuantizationStrategy.GROUP.value,
+             QuantizationStrategy.CHANNEL.value,
+         ]:
+             return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
+         return zero_point
+

  def pack_to_int32(
      value: torch.Tensor,
@@ -226,6 +251,9 @@ def pack_to_int32(
      if packed_dim == 0:
          value = value.transpose(0, 1)

+     # Ensure contiguous memory for .view() operation
+     value = value.contiguous()
+
      rows, cols = value.shape
      padded_cols = math.ceil(cols / pack_factor) * pack_factor
      pad_len = padded_cols - cols
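
One note on the .contiguous() fix to pack_to_int32 above: packing along packed_dim=0 (now used for group/channel zero points via compress_zp) starts with a transpose, and transpose returns a non-contiguous view on which the later .view() call fails. A small, self-contained sketch of the failure mode and of packing eight 4-bit values per int32 word (illustrative only, not the library routine):

import torch

num_bits = 4
pack_factor = 32 // num_bits                      # 8 nibbles per int32 word

vals = torch.randint(0, 16, (16, 16), dtype=torch.int32)
t = vals.transpose(0, 1)                          # non-contiguous view
assert not t.is_contiguous()                      # .view() would raise a RuntimeError here
t = t.contiguous()                                # the fix added in pack_to_int32

rows, cols = t.shape
grouped = t.view(rows, cols // pack_factor, pack_factor)
packed = torch.zeros(rows, cols // pack_factor, dtype=torch.int32)
for i in range(pack_factor):
    packed |= grouped[..., i] << (num_bits * i)   # place each nibble in its bit slot
print(packed.shape)                               # torch.Size([16, 2])
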
src/compressed_tensors/config/format.py

@@ -50,6 +50,8 @@ def _get_quant_compression_format(
      is_weight_only = weight_args is not None and input_args is None

      if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+         if weight_args.group_size == 32:
+             return CompressionFormat.mxfp4_pack_quantized
          return CompressionFormat.nvfp4_pack_quantized

      if is_weight_only:  # w4a16 and w8a16
src/compressed_tensors/quantization/quant_scheme.py

@@ -11,11 +11,11 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
  import warnings
  from copy import deepcopy
  from typing import List, Optional

+ import torch
  from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization.quant_args import (
      FP8_E4M3_DATA,
@@ -192,6 +192,43 @@ NVFP4 = dict(
      ),
  )

+ MXFP4A16 = dict(
+     weights=QuantizationArgs(
+         num_bits=4,
+         type=QuantizationType.FLOAT,
+         strategy=QuantizationStrategy.GROUP,
+         symmetric=True,
+         dynamic=False,
+         group_size=32,
+         scale_dtype=torch.uint8,
+         zp_dtype=torch.uint8,
+     )
+ )
+
+ MXFP4 = dict(
+     weights=QuantizationArgs(
+         num_bits=4,
+         type=QuantizationType.FLOAT,
+         strategy=QuantizationStrategy.GROUP,
+         symmetric=True,
+         dynamic=False,
+         group_size=32,
+         scale_dtype=torch.uint8,
+         zp_dtype=torch.uint8,
+     ),
+     input_activations=QuantizationArgs(
+         num_bits=4,
+         type=QuantizationType.FLOAT,
+         strategy=QuantizationStrategy.GROUP,
+         dynamic=True,
+         symmetric=True,
+         group_size=32,
+         scale_dtype=torch.uint8,
+         zp_dtype=torch.uint8,
+     ),
+ )
+
+
  # 8 bit integer weights and 8 bit activations quantization
  INT8_W8A8 = dict(
      weights=QuantizationArgs(
@@ -343,4 +380,6 @@ PRESET_SCHEMES = {
      "FP8_BLOCK": FP8_BLOCK,
      "NVFP4A16": NVFP4A16,
      "NVFP4": NVFP4,
+     "MXFP4A16": MXFP4A16,
+     "MXFP4": MXFP4,
  }
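
The new presets are plain dicts of QuantizationArgs, so they can be expanded into a QuantizationScheme the same way the existing NVFP4 presets are. A hedged usage sketch (assuming, as for the other presets in this file, that PRESET_SCHEMES values are QuantizationScheme keyword arguments):

from copy import deepcopy

from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES, QuantizationScheme

# weight-only MXFP4 scheme applied to all Linear layers
mxfp4_w4a16 = QuantizationScheme(targets=["Linear"], **deepcopy(PRESET_SCHEMES["MXFP4A16"]))
print(mxfp4_w4a16.weights.group_size)   # 32
print(mxfp4_w4a16.weights.scale_dtype)  # torch.uint8
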
src/compressed_tensors/quantization/utils/helpers.py

@@ -27,6 +27,11 @@ from compressed_tensors.quantization.quant_args import (
      round_to_quantized_type_dtype,
  )
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.quantization.utils.mxfp4_utils import (
+     generate_mxfp4_scales,
+     maybe_convert_from_mxfp4_exp,
+     should_generatre_mxfp4_scales,
+ )
  from compressed_tensors.utils import deprecated
  from loguru import logger
  from torch import FloatTensor, IntTensor, Tensor
@@ -88,7 +93,10 @@ def calculate_qparams(
      # 1. Generate scale and zero-point
      if quantization_args.symmetric:
          max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-         scales = max_val_pos / (float(bit_range) / 2)
+         if should_generatre_mxfp4_scales(args=quantization_args):
+             scales = generate_mxfp4_scales(x=max_val_pos)
+         else:
+             scales = max_val_pos / (float(bit_range) / 2)
          zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
      else:
          if (
@@ -112,7 +120,10 @@ def calculate_qparams(
          scales, dtype=quantization_args.scale_dtype
      )

-     # 4. Update any 0s with small values to
+     # 4. Optionally remove exponent
+     scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
+
+     # 5. Update any 0s with small values to
      # prevent div by 0
      eps = _get_dtype_eps(
          dtype=quantization_args.scale_dtype
@@ -125,7 +136,7 @@ def calculate_qparams(
          scales,
      )

-     # 5. Round the zp to zp_dtype
+     # 6. Round the zp to zp_dtype
      zero_points = round_to_quantized_type_dtype(
          zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
      )
src/compressed_tensors/quantization/utils/mxfp4_utils.py

@@ -13,16 +13,29 @@
  # limitations under the License.

  import torch
- from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+ from compressed_tensors.quantization.quant_args import (
+     BFLOAT16_DATA,
+     FP4_E2M1_DATA,
+     QuantizationArgs,
+ )


- __all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+ __all__ = [
+     "maybe_convert_from_mxfp4_exp",
+     "generate_mxfp4_scales",
+     "round_to_power_2",
+     "should_generatre_mxfp4_scales",
+ ]

  # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501


- def convert_mxfp4_exp_scale(
-     scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+ def should_generatre_mxfp4_scales(args: QuantizationArgs):
+     return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+ def maybe_convert_from_mxfp4_exp(
+     args: QuantizationArgs, scale: torch.Tensor
  ) -> torch.Tensor:
      """
      Converts mxfp4 scales. Scales are powers of 2, with the
@@ -32,10 +45,12 @@ def convert_mxfp4_exp_scale(
      :param scale: uint8 exponent scale
      :param dtype: dense dtype
      """
-     assert scale.dtype == torch.uint8
-     scale_exp = scale.to(torch.int32) - 127
-     scale = 2.00 ** (scale_exp.to(torch.float))
-     return scale.to(dtype)
+     original_dtype = scale.dtype
+     if should_generatre_mxfp4_scales(args):
+         scale_exp = scale.to(torch.int32) - 127
+         scale = 2.00 ** (scale_exp.to(torch.float))
+         return scale.to(original_dtype)
+     return scale


  def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
@@ -77,21 +92,12 @@ def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
      Generate mxfp4 scales. The scales require the following steps
      1. Round to the closest power of 2
      2. Convert to exponent
-     3. Store in uint8

      Called when calculating qparams using observers.

      :param x: tensor to round to closest power of 2
-     :returns uint8 scales as exponents
+     :returns scales as exponents
      """
      # Round to closest power of 2
      scale_power_2 = round_to_power_2(x)
-     # Convert to exponent
-     scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-     # Clamp and store in uint8, as expected by mxfp4
-     scale_exp = torch.clamp(
-         scale_exp,
-         max=torch.iinfo(torch.uint8).max,
-         min=torch.iinfo(torch.uint8).min,
-     )
-     return scale_exp.to(torch.uint8)
+     return 127 + torch.floor(torch.log2(scale_power_2)) - 2
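
For orientation, the predicate above is what routes calculate_qparams between the two FP4 flavors: num_bits=4, float type, and group_size=32 selects the MXFP4 power-of-two exponent scales, while the NVFP4 presets (group_size=16, TENSOR_GROUP) keep the regular max/qmax scales. A hedged sketch, with argument values mirroring the presets added in quant_scheme.py:

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils.mxfp4_utils import should_generatre_mxfp4_scales

mxfp4_args = QuantizationArgs(
    num_bits=4,
    type="float",
    strategy="group",
    group_size=32,
    scale_dtype=torch.uint8,
    zp_dtype=torch.uint8,
)
nvfp4_args = QuantizationArgs(
    num_bits=4,
    type="float",
    strategy="tensor_group",
    group_size=16,
)

print(should_generatre_mxfp4_scales(mxfp4_args))  # True  -> exponent scales
print(should_generatre_mxfp4_scales(nvfp4_args))  # False -> regular max/qmax scales
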
src/compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.12.3.a20251203'
+ __version__ = version = '0.12.3.a20251212'
  __version_tuple__ = version_tuple = (0, 12, 3)
src/compressed_tensors.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.12.3a20251203
+ Version: 0.12.3a20251212
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/vllm-project/compressed-tensors
  Author: Neuralmagic, Inc.
tests/test_quantization/test_utils/test_mxfp4_utils.py

@@ -13,9 +13,10 @@
  # limitations under the License.

  import torch
+ from compressed_tensors.quantization import round_to_quantized_type_dtype
  from compressed_tensors.quantization.utils import (
-     convert_mxfp4_exp_scale,
      generate_mxfp4_scales,
+     maybe_convert_from_mxfp4_exp,
      round_to_power_2,
  )

@@ -61,6 +62,12 @@ def test_round_power_2():


  def test_mxfp4_scales_e2e():
+     from compressed_tensors.quantization.quant_args import (
+         QuantizationArgs,
+         QuantizationStrategy,
+         QuantizationType,
+     )
+
      mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))

      x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
@@ -71,8 +78,19 @@ def test_mxfp4_scales_e2e():
      max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
      block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))

-     scales_generated = generate_mxfp4_scales(block_max)
-     converted_ct = convert_mxfp4_exp_scale(scales_generated)
+     args = QuantizationArgs(
+         num_bits=4,
+         type=QuantizationType.FLOAT,
+         strategy=QuantizationStrategy.GROUP,
+         group_size=32,
+         scale_dtype=torch.uint8,
+         zp_dtype=torch.uint8,
+     )
+
+     scales = generate_mxfp4_scales(block_max)
+     scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)
+
+     converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)

      scales_exp = torch.log2(converted_ct)
      block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2