compressed-tensors 0.12.3a20251030.tar.gz → 0.12.3a20251110.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.github/workflows/test-check.yaml +1 -1
  2. {compressed_tensors-0.12.3a20251030/src/compressed_tensors.egg-info → compressed_tensors-0.12.3a20251110}/PKG-INFO +1 -1
  3. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/base.py +8 -1
  4. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +3 -2
  5. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/forward.py +6 -9
  6. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/initialize.py +11 -17
  7. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_args.py +73 -8
  8. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_config.py +0 -1
  9. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_scheme.py +7 -0
  10. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/helpers.py +45 -43
  11. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/version.py +1 -1
  12. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  13. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/SOURCES.txt +0 -5
  14. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/model_compressors/test_model_compressor.py +13 -2
  15. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_apply.py +40 -7
  16. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +1 -1
  17. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_initialize.py +15 -3
  18. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_lifecycle.py +1 -1
  19. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_static_lifecycle.py +5 -0
  20. compressed_tensors-0.12.3a20251030/.github/workflows/build-test.yml +0 -57
  21. compressed_tensors-0.12.3a20251030/.github/workflows/build.yml +0 -134
  22. compressed_tensors-0.12.3a20251030/.github/workflows/post-release-nightly-build.yml +0 -15
  23. compressed_tensors-0.12.3a20251030/.github/workflows/test.yml +0 -187
  24. compressed_tensors-0.12.3a20251030/.github/workflows/trigger-all.yml +0 -45
  25. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.github/.gitkeep +0 -0
  26. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.github/actions/test/action.yml +0 -0
  27. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.github/scripts/step-status +0 -0
  28. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.github/workflows/quality-check.yaml +0 -0
  29. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/.gitignore +0 -0
  30. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/LICENSE +0 -0
  31. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/Makefile +0 -0
  32. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/README.md +0 -0
  33. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  34. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/bit_packing/int4_config.json +0 -0
  35. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/bitmask_compression.ipynb +0 -0
  36. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  37. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  38. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/example_quant_config.json +0 -0
  39. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  40. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/examples/quantize_and_pack_int4.ipynb +0 -0
  41. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/pyproject.toml +0 -0
  42. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/setup.cfg +0 -0
  43. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/setup.py +0 -0
  44. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/__init__.py +0 -0
  45. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/README.md +0 -0
  46. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/__init__.py +0 -0
  47. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/base.py +0 -0
  48. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/__init__.py +0 -0
  49. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/base.py +0 -0
  50. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/helpers.py +0 -0
  51. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  52. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  53. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  54. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  55. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  56. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  57. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  58. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  59. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  60. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  61. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  62. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  63. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/__init__.py +0 -0
  64. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/base.py +0 -0
  65. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/dense.py +0 -0
  66. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/format.py +0 -0
  67. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  68. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  69. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/linear/__init__.py +0 -0
  70. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  71. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/logger.py +0 -0
  72. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/__init__.py +0 -0
  73. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/attention.py +0 -0
  74. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/kvcache.py +0 -0
  75. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/__init__.py +0 -0
  76. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  77. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  78. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  79. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  80. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  81. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  82. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/mxfp4_utils.py +0 -0
  83. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/registry/__init__.py +0 -0
  84. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/registry/registry.py +0 -0
  85. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/__init__.py +0 -0
  86. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/apply.py +0 -0
  87. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  88. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/base.py +0 -0
  89. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  90. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  91. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  92. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_args.py +0 -0
  93. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_config.py +0 -0
  94. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  95. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  96. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  97. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  98. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  99. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/__init__.py +0 -0
  100. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/helpers.py +0 -0
  101. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/internal.py +0 -0
  102. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/match.py +0 -0
  103. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/offload.py +0 -0
  104. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/permutations_24.py +0 -0
  105. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  106. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  107. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/type.py +0 -0
  108. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  109. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/requires.txt +0 -0
  110. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  111. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/__init__.py +0 -0
  112. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/conftest.py +0 -0
  113. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/mock_observer.py +0 -0
  114. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/__init__.py +0 -0
  115. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/model_compressors/__init__.py +0 -0
  116. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_fp4_quant.py +0 -0
  118. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  119. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  120. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  121. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +0 -0
  122. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  123. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  124. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  125. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  126. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  127. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_configs/__init__.py +0 -0
  128. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_configs/test_base.py +0 -0
  129. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_configs/test_infer_quant.py +0 -0
  130. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  131. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_linear/__init__.py +0 -0
  132. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_linear/test_compressed_linear.py +0 -0
  133. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_modeling/test_attention_and_cache.py +0 -0
  134. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/__init__.py +0 -0
  135. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/__init__.py +0 -0
  136. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/conftest.py +0 -0
  137. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  138. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  139. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/__init__.py +0 -0
  140. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  141. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  142. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_args.py +0 -0
  143. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_config.py +0 -0
  144. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_scheme.py +0 -0
  145. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  146. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_utils/test_mxfp4_utils.py +0 -0
  147. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_registry.py +0 -0
  148. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/conftest.py +0 -0
  149. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_correctness.py +0 -0
  150. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_memory.py +0 -0
  151. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_serialization.py +0 -0
  152. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_args.py +0 -0
  153. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_config.py +0 -0
  154. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_scheme.py +0 -0
  155. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_transform/utils/test_hadamard.py +0 -0
  156. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/__init__.py +0 -0
  157. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_helpers.py +0 -0
  158. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_match.py +0 -0
  159. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_offload.py +0 -0
  160. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_safetensors_load.py +0 -0
  161. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_type.py +0 -0
  162. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/tests/testing_utils.py +0 -0
  163. {compressed_tensors-0.12.3a20251030 → compressed_tensors-0.12.3a20251110}/utils/copyright.py +0 -0
.github/workflows/test-check.yaml
@@ -12,7 +12,7 @@ on:
 
 jobs:
   python-tests:
-    runs-on: ubuntu-22.04
+    runs-on: ibm-wdc-k8s-vllm-h100-solo
     env:
       HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
     steps:
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251030
+Version: 0.12.3a20251110
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/vllm-project/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -90,7 +90,6 @@ class BaseQuantizationCompressor(BaseCompressor):
         desc = "Compressing with quantization"
         for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]
-
             # compress weights
             if name.endswith("weight"):
                 prefix = name.removesuffix("weight")
@@ -129,10 +128,18 @@ class BaseQuantizationCompressor(BaseCompressor):
             if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
                 continue
 
+            if name.endswith("weight_scale") and self._skip_scale():
+                continue
+
             compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
+    def _skip_scale(self):
+        from compressed_tensors.compressors import NVFP4PackedCompressor
+
+        return isinstance(self, NVFP4PackedCompressor)
+
     def _skip_zp(
         self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.lifecycle.forward import dequantize, quanti
 from torch import Tensor
 
 
-__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]
+__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8", "NVFP4PackedCompressor"]
 
 FLOAT_TO_E2M1 = [
     0.0,
@@ -103,6 +103,7 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
+        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
         return compressed_dict
 
     def decompress_weight(
@@ -111,8 +112,8 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
         weight = compressed_data["weight_packed"]
-        scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
+        scale = compressed_data["weight_scale"]
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
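
Editor's note: a quick, editorial sketch of the FP4 packing round trip that `weight_packed` relies on, using only the helpers exported above. The call shape follows the `unpack_fp4_from_uint8(weight, m, n * 2)` usage in `decompress_weight`; the sample values and the exact return dtype are assumptions, not package fixtures.

    import torch
    from compressed_tensors.compressors.quantized_compressors.fp4_quantized import (
        pack_fp4_to_uint8,
        unpack_fp4_from_uint8,
    )

    w = torch.tensor([[0.0, 0.5, -1.0, 6.0]])       # values already on the E2M1 grid
    packed = pack_fp4_to_uint8(w)                    # two 4-bit codes per uint8 byte
    restored = unpack_fp4_from_uint8(packed, 1, 4)   # m rows, n unpacked columns
    print(packed.shape, restored.to(w.dtype))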
src/compressed_tensors/quantization/lifecycle/forward.py
@@ -21,7 +21,7 @@ from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    round_to_quantized_type,
+    round_to_quantized_type_args,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -466,20 +466,17 @@ def _quantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     scaled = x / scale
 
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
-    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-    clamped_value = torch.clamp(
-        scaled,
-        q_min,
-        q_max,
+    # clamp and round
+    quantized_value = round_to_quantized_type_args(
+        tensor=scaled, args=args, min=q_min, max=q_max
     )
-    quantized_value = round_to_quantized_type(clamped_value, args)
 
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
@@ -499,7 +496,7 @@ def _dequantize(
     # if a global scale is optionally provided, use it
    # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     dequant_value = x_q.to(scale.dtype)
 
src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -24,7 +24,6 @@ from compressed_tensors.modeling import (
     QuantizedKVCache,
 )
 from compressed_tensors.quantization import (
-    FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
     QuantizationArgs,
@@ -36,7 +35,7 @@ from compressed_tensors.quantization import (
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -250,20 +249,13 @@ def initialize_qparams(
 
     # 2. Identify quantization scale and zp dtype
     scale_dtype = observed_dtype
-
-    if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        # TODO: consider erroring out in the future as if the dtype if not one of these,
-        # there is likely bug
-        if scale_dtype not in [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-            torch.float64,
-        ]:
-            scale_dtype = torch.bfloat16
-        zp_dtype = quantization_args.pytorch_dtype()
+    if scale_dtype not in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    ]:
+        scale_dtype = torch.float16
 
     # 3. Initializes scale/zp for the module
     init_scale = Parameter(
@@ -274,7 +266,9 @@ def initialize_qparams(
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
-            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            torch.zeros(
+                expected_shape, device=device, dtype=quantization_args.zp_dtype
+            ),
             requires_grad=False,
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
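
Editor's note: a minimal sketch (not package code) of the scale-dtype fallback now used in `initialize_qparams`. Observed dtypes outside the four listed float types fall back to `torch.float16` (previously `torch.bfloat16`), and the FP4 special case is gone because FP8 scale and zero-point dtypes now arrive through `quantization_args.scale_dtype` / `zp_dtype`.

    import torch

    def pick_scale_dtype(observed_dtype: torch.dtype) -> torch.dtype:
        # mirrors the new branch: keep standard float dtypes, otherwise fall back
        if observed_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
            return observed_dtype
        return torch.float16

    print(pick_scale_dtype(torch.bfloat16))       # torch.bfloat16 (kept as-is)
    print(pick_scale_dtype(torch.float8_e4m3fn))  # torch.float16 (fallback)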
src/compressed_tensors/quantization/quant_args.py
@@ -19,7 +19,15 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from compressed_tensors.utils.type import TorchDtype
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 
 
 __all__ = [
@@ -30,7 +38,8 @@ __all__ = [
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
-    "round_to_quantized_type",
+    "round_to_quantized_type_args",
+    "round_to_quantized_type_dtype",
     "ActivationOrdering",
     "DynamicType",
 ]
@@ -174,6 +183,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
+    scale_dtype: Optional[TorchDtype] = None
+    zp_dtype: Optional[TorchDtype] = None
     observer: Optional[str] = Field(
         default=None,
         description=(
@@ -189,6 +200,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         ),
     )
 
+    @field_serializer("zp_dtype")
+    def serialize_dtype(self, dtype: torch.dtype):
+        if self.symmetric:
+            return None
+        return str(dtype)
+
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
         if isinstance(value, str):
@@ -266,6 +283,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         dynamic = model.dynamic
         observer = model.observer
         dynamic = model.dynamic
+        zp_dtype = model.zp_dtype
 
         # infer strategy
         if strategy is None:
@@ -353,9 +371,16 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             # default to minmax for non-dynamic cases
            observer = "minmax"
 
+        if zp_dtype is None:
+            if model.num_bits == 4 and model.type == QuantizationType.FLOAT:
+                zp_dtype = FP8_E4M3_DATA.dtype
+            else:
+                zp_dtype = model.pytorch_dtype()
+
         # write back modified values
         model.strategy = strategy
         model.observer = observer
+        model.zp_dtype = zp_dtype
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
@@ -381,18 +406,56 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     model_config = ConfigDict(extra="forbid")
 
 
-def round_to_quantized_type(
-    tensor: torch.Tensor, args: QuantizationArgs
+def round_to_quantized_type_dtype(
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
-    Rounds each element of the input tensor to the nearest quantized representation,
-    keeping to original dtype
+    Rounds an input tensor to the nearest quantized representation given a dtype.
+    The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
-    :param args: QuantizationArgs to pull appropriate dtype from
+    :param dtype: dtype to use for rounding
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
     original_dtype = tensor.dtype
+    if torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        finfo = torch.finfo(dtype)
+        rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
+    else:
+        iinfo = torch.iinfo(dtype)
+        rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)).to(dtype)
+
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
+
+
+def round_to_quantized_type_args(
+    tensor: torch.Tensor,
+    args: QuantizationArgs,
+    min: torch.Tensor,
+    max: torch.Tensor,
+    cast_to_original_dtype: Optional[bool] = True,
+) -> torch.Tensor:
+    """
+    Rounds an input tensor to the nearest quantized representation given
+    qunatization args. The original dtype is kept post-rounding.
+
+    :param tensor: tensor to round
+    :param args: quantization args to use for rounding
+    :param min: min value to use for clamping
+    :param max: max value to use for clamping
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
+    :return: rounded tensor
+    """
+
+    original_dtype = tensor.dtype
+    tensor = torch.clamp(tensor, min, max)
     if args.type == QuantizationType.FLOAT:
         if args.num_bits == 8:
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
@@ -405,4 +468,6 @@ def round_to_quantized_type(
     else:
         raise ValueError(f"Invalid quantization type {args.type}")
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
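
Editor's note: a minimal usage sketch of the two helpers that replace `round_to_quantized_type`, based only on the signatures added above; the sample tensor and the int8 case are illustrative, not from the package tests.

    import torch
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        round_to_quantized_type_args,
        round_to_quantized_type_dtype,
    )

    x = torch.tensor([-3.7, 0.2, 150.0])

    # dtype-based rounding: clamp to the dtype's range, round/cast, then (by default)
    # cast back to the original dtype
    print(round_to_quantized_type_dtype(x, dtype=torch.int8))  # tensor([-4., 0., 127.])
    print(round_to_quantized_type_dtype(x, dtype=torch.int8, cast_to_original_dtype=False).dtype)

    # args-based rounding: clamp to the provided q_min/q_max first, then round according
    # to the quantization type carried by QuantizationArgs
    args = QuantizationArgs(num_bits=8, type="int", symmetric=True)
    q_min, q_max = torch.tensor(-128.0), torch.tensor(127.0)
    print(round_to_quantized_type_args(x, args=args, min=q_min, max=q_max))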
src/compressed_tensors/quantization/quant_config.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union
src/compressed_tensors/quantization/quant_scheme.py
@@ -18,6 +18,7 @@ from typing import List, Optional
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
+    FP8_E4M3_DATA,
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
@@ -160,6 +161,8 @@ NVFP4A16 = dict(
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     )
 )
 
@@ -173,6 +176,8 @@ NVFP4 = dict(
         dynamic=False,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -182,6 +187,8 @@ NVFP4 = dict(
         dynamic=DynamicType.LOCAL,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
    ),
 )
 
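
Editor's note: a small editorial sketch of what these preset additions imply. `FP8_E4M3_DATA.dtype` is torch's FP8 E4M3 dtype, and, per the `quant_args.py` validator above, a 4-bit float `QuantizationArgs` resolves `zp_dtype` to that same FP8 dtype even when it is not passed explicitly; the NVFP4 presets simply pin both dtypes up front.

    import torch
    from compressed_tensors.quantization import FP8_E4M3_DATA, QuantizationArgs

    print(FP8_E4M3_DATA.dtype)  # torch.float8_e4m3fn

    # pinned explicitly, as NVFP4 / NVFP4A16 now do
    pinned = QuantizationArgs(
        num_bits=4,
        type="float",
        symmetric=True,
        group_size=16,
        scale_dtype=FP8_E4M3_DATA.dtype,
        zp_dtype=FP8_E4M3_DATA.dtype,
    )

    # left unset: the model validator falls back to FP8 for 4-bit float args
    inferred = QuantizationArgs(num_bits=4, type="float", symmetric=True, group_size=16)
    print(inferred.zp_dtype == FP8_E4M3_DATA.dtype)  # True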
src/compressed_tensors/quantization/utils/helpers.py
@@ -24,6 +24,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
+    round_to_quantized_type_dtype,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.utils import deprecated
@@ -46,7 +47,6 @@ __all__ = [
     "calculate_range",
     "calculate_qparams",
     "generate_gparam",
-    "is_fp4",
     "strategy_cdiv",
 ]
 
@@ -57,13 +57,6 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-def is_fp4(quantization_args: QuantizationArgs):
-    return (
-        quantization_args.num_bits == 4
-        and quantization_args.type == QuantizationType.FLOAT
-    )
-
-
 def calculate_qparams(
     min_vals: Tensor,
     max_vals: Tensor,
@@ -92,52 +85,50 @@ def calculate_qparams(
     bit_min, bit_max = calculate_range(quantization_args, device)
     bit_range = bit_max - bit_min
 
-    if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        zp_dtype = quantization_args.pytorch_dtype()
-
+    # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
-
-        else:
-            scales = max_val_pos / (float(bit_range) / 2)
-
-        # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-
+        scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
-        if is_fp4(quantization_args=quantization_args):
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+        ):
             raise NotImplementedError(
                 "Asymmetric Quantization is not supported for FP4"
            )
-
         scales = (max_vals - min_vals) / float(bit_range)
-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - (min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max)
 
-    # match zero-points to quantized type
-    # if casting to int, use round instead of truncate
-    if quantization_args.type == QuantizationType.INT:
-        zero_points = torch.round(zero_points)
-    zero_points = zero_points.to(zp_dtype)
+    # 2. Conditionally scale the generated local scale by a global_scale
+    if global_scale is not None:
+        scales = global_scale * scales
+
+    # 3. Conditionally round the scale to the quantized dtype, if scale_dtype is set
+    if quantization_args.scale_dtype is not None:
+        scales = round_to_quantized_type_dtype(
+            scales, dtype=quantization_args.scale_dtype
+        )
+
+    # 4. Update any 0s with small values to
+    # prevent div by 0
+    eps = _get_dtype_eps(
+        dtype=quantization_args.scale_dtype
+        if quantization_args.scale_dtype is not None
+        else scales.dtype
+    )
+    scales = torch.where(
+        scales == 0,
+        torch.tensor(eps, dtype=scales.dtype, device=device),
+        scales,
+    )
+
+    # 5. Round the zp to zp_dtype
+    zero_points = round_to_quantized_type_dtype(
+        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
+    )
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
@@ -455,3 +446,14 @@ def strategy_cdiv(
         logger.bind(log_once=True).warning(message)
 
     return dividend
+
+
+def _get_dtype_eps(dtype: torch.dtype) -> float:
+    if dtype == FP8_E4M3_DATA.dtype:
+        return 0.125
+    elif dtype == FP4_E2M1_DATA.dtype:
+        return 0.25
+    elif torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        return torch.finfo(dtype).eps
+    else:
+        return 1
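
Editor's note: a minimal sketch (editorial, not package code) of the zero-scale guard added above. The epsilon used to replace zero scales depends on the scale dtype: 0.125 for FP8 E4M3, 0.25 for FP4 E2M1, `torch.finfo(dtype).eps` for other float dtypes, and 1 otherwise. The simplified table below omits the FP4 branch, since that dtype representation is library-specific.

    import torch

    def dtype_eps(dtype: torch.dtype) -> float:
        # illustrative reimplementation of the eps lookup, FP4 branch omitted
        if dtype == torch.float8_e4m3fn:          # FP8_E4M3_DATA.dtype
            return 0.125
        if torch.is_floating_point(torch.tensor([], dtype=dtype)):
            return torch.finfo(dtype).eps
        return 1

    scales = torch.tensor([0.0, 0.03, 2.0])
    eps = dtype_eps(scales.dtype)
    scales = torch.where(scales == 0, torch.tensor(eps, dtype=scales.dtype), scales)
    print(scales)  # zero entries replaced so later division by the scale is safe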
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.12.3.a20251030'
+__version__ = version = '0.12.3.a20251110'
 __version_tuple__ = version_tuple = (0, 12, 3)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251030
+Version: 0.12.3a20251110
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/vllm-project/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -8,13 +8,8 @@ setup.py
 .github/.gitkeep
 .github/actions/test/action.yml
 .github/scripts/step-status
-.github/workflows/build-test.yml
-.github/workflows/build.yml
-.github/workflows/post-release-nightly-build.yml
 .github/workflows/quality-check.yaml
 .github/workflows/test-check.yaml
-.github/workflows/test.yml
-.github/workflows/trigger-all.yml
 examples/bitmask_compression.ipynb
 examples/quantize_and_pack_int4.ipynb
 examples/bit_packing/ex_quantize_and_pack.py
tests/test_compressors/model_compressors/test_model_compressor.py
@@ -22,6 +22,7 @@ import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
+    FP8_E4M3_DATA,
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
@@ -425,8 +426,18 @@ def test_multiple_quant_compressors():
         format=CompressionFormat.float_quantized.value,
     )
 
-    input_activations = QuantizationArgs(num_bits=4, type="float")
-    weights = QuantizationArgs(num_bits=4, type="float")
+    input_activations = QuantizationArgs(
+        num_bits=4,
+        type="float",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
+    )
+    weights = QuantizationArgs(
+        num_bits=4,
+        type="float",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
+    )
 
     scheme_nvfp4 = QuantizationScheme(
         targets=["Linear"],
tests/test_quantization/lifecycle/test_apply.py
@@ -22,6 +22,7 @@ import torch
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
+    FP8_E4M3_DATA,
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
@@ -153,7 +154,11 @@ def test_apply_quantization_config_tinyllama():
             "linear": QuantizationScheme(
                 targets=["Linear"],
                 input_activations=QuantizationArgs(
-                    num_bits=8, type="float", strategy="tensor"
+                    num_bits=8,
+                    type="float",
+                    strategy="tensor",
+                    scale_dtype=FP8_E4M3_DATA.dtype,
+                    zp_dtype=torch.float,
                 ),
             )
         }
@@ -163,7 +168,11 @@ def test_apply_quantization_config_tinyllama():
             "linear": QuantizationScheme(
                 targets=["Linear"],
                 input_activations=QuantizationArgs(
-                    num_bits=8, type="float", strategy="tensor"
+                    num_bits=8,
+                    type="float",
+                    strategy="tensor",
+                    scale_dtype=FP8_E4M3_DATA.dtype,
+                    zp_dtype=torch.float,
                 ),
            )
        },
@@ -176,7 +185,11 @@ def test_apply_quantization_config_tinyllama():
         QuantizationConfig(
             config_groups={},
             kv_cache_scheme=QuantizationArgs(
-                num_bits=8, type="float", strategy="tensor"
+                num_bits=8,
+                type="float",
+                strategy="tensor",
+                scale_dtype=FP8_E4M3_DATA.dtype,
+                zp_dtype=torch.float,
             ),
         ),
         QuantizationConfig(
@@ -184,12 +197,20 @@ def test_apply_quantization_config_tinyllama():
             "attention": QuantizationScheme(
                 targets=["LlamaAttention"],
                 input_activations=QuantizationArgs(
-                    num_bits=8, type="float", strategy="tensor"
+                    num_bits=8,
+                    type="float",
+                    strategy="tensor",
+                    scale_dtype=FP8_E4M3_DATA.dtype,
+                    zp_dtype=torch.float,
                 ),
             )
         },
         kv_cache_scheme=QuantizationArgs(
-            num_bits=8, type="float", strategy="tensor"
+            num_bits=8,
+            type="float",
+            strategy="tensor",
+            scale_dtype=FP8_E4M3_DATA.dtype,
+            zp_dtype=torch.float,
         ),
     ),
 ],
@@ -448,7 +469,13 @@ def test_apply_kv_cache():
     with init_empty_weights():
         model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
 
-    args = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
+    args = QuantizationArgs(
+        num_bits=8,
+        type="float",
+        strategy="tensor",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=torch.float,
+    )
     config = QuantizationConfig(config_groups={}, kv_cache_scheme=args)
 
     apply_quantization_config(model, config)
@@ -468,7 +495,13 @@ def test_apply_attention():
 
     scheme = QuantizationScheme(
         targets=["LlamaAttention"],
-        input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+        input_activations=QuantizationArgs(
+            num_bits=8,
+            type="float",
+            strategy="tensor",
+            scale_dtype=FP8_E4M3_DATA.dtype,
+            zp_dtype=torch.float,
+        ),
     )
     config = QuantizationConfig(config_groups={"attention": scheme})
 
tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
@@ -79,7 +79,7 @@ def _test_layer_dynamic_quantization_status(
 def get_tinyllama_model():
     return AutoModelForCausalLM.from_pretrained(
         "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
-        torch_dtype="auto",
+        torch_dtype=torch.bfloat16,
     )
 