compressed-tensors 0.12.3a20251028__tar.gz → 0.12.3a20251110__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.github/workflows/test-check.yaml +1 -1
  2. {compressed_tensors-0.12.3a20251028/src/compressed_tensors.egg-info → compressed_tensors-0.12.3a20251110}/PKG-INFO +1 -1
  3. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/base.py +8 -1
  4. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +3 -2
  5. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -9
  6. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/forward.py +7 -10
  7. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/initialize.py +11 -17
  8. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_args.py +73 -8
  9. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_config.py +0 -1
  10. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_scheme.py +7 -0
  11. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/helpers.py +45 -43
  12. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/base.py +34 -3
  13. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/hadamard.py +0 -1
  14. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/matrix_multiply.py +2 -3
  15. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_args.py +11 -4
  16. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/matrix.py +13 -21
  17. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/version.py +1 -1
  18. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  19. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/SOURCES.txt +1 -5
  20. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/model_compressors/test_model_compressor.py +13 -2
  21. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_pack_quant.py +95 -5
  22. compressed_tensors-0.12.3a20251110/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +172 -0
  23. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_apply.py +40 -7
  24. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +1 -1
  25. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_initialize.py +15 -3
  26. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_lifecycle.py +1 -1
  27. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_static_lifecycle.py +5 -0
  28. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/conftest.py +21 -1
  29. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_correctness.py +74 -7
  30. compressed_tensors-0.12.3a20251028/.github/workflows/build-test.yml +0 -57
  31. compressed_tensors-0.12.3a20251028/.github/workflows/build.yml +0 -134
  32. compressed_tensors-0.12.3a20251028/.github/workflows/post-release-nightly-build.yml +0 -15
  33. compressed_tensors-0.12.3a20251028/.github/workflows/test.yml +0 -187
  34. compressed_tensors-0.12.3a20251028/.github/workflows/trigger-all.yml +0 -45
  35. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.github/.gitkeep +0 -0
  36. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.github/actions/test/action.yml +0 -0
  37. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.github/scripts/step-status +0 -0
  38. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.github/workflows/quality-check.yaml +0 -0
  39. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/.gitignore +0 -0
  40. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/LICENSE +0 -0
  41. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/Makefile +0 -0
  42. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/README.md +0 -0
  43. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  44. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/bit_packing/int4_config.json +0 -0
  45. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/bitmask_compression.ipynb +0 -0
  46. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  47. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  48. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/example_quant_config.json +0 -0
  49. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  50. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/examples/quantize_and_pack_int4.ipynb +0 -0
  51. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/pyproject.toml +0 -0
  52. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/setup.cfg +0 -0
  53. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/setup.py +0 -0
  54. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/__init__.py +0 -0
  55. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/README.md +0 -0
  56. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/__init__.py +0 -0
  57. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/base.py +0 -0
  58. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/__init__.py +0 -0
  59. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/base.py +0 -0
  60. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/helpers.py +0 -0
  61. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  62. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  63. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  64. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  65. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  66. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  67. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  68. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  69. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  70. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  71. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  72. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/__init__.py +0 -0
  73. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/base.py +0 -0
  74. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/dense.py +0 -0
  75. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/format.py +0 -0
  76. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  77. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  78. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/linear/__init__.py +0 -0
  79. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  80. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/logger.py +0 -0
  81. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/__init__.py +0 -0
  82. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/attention.py +0 -0
  83. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/modeling/kvcache.py +0 -0
  84. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/__init__.py +0 -0
  85. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  86. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  87. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  88. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  89. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  90. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  91. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/quantization/utils/mxfp4_utils.py +0 -0
  92. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/registry/__init__.py +0 -0
  93. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/registry/registry.py +0 -0
  94. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/__init__.py +0 -0
  95. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/apply.py +0 -0
  96. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  97. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  98. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_config.py +0 -0
  99. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  100. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  101. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  102. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  103. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/__init__.py +0 -0
  104. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/helpers.py +0 -0
  105. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/internal.py +0 -0
  106. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/match.py +0 -0
  107. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/offload.py +0 -0
  108. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/permutations_24.py +0 -0
  109. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  110. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  111. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors/utils/type.py +0 -0
  112. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  113. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/requires.txt +0 -0
  114. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  115. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/__init__.py +0 -0
  116. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/conftest.py +0 -0
  117. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/mock_observer.py +0 -0
  118. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/__init__.py +0 -0
  119. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/model_compressors/__init__.py +0 -0
  120. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  121. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_fp4_quant.py +0 -0
  122. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  123. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  124. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  125. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  126. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  127. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  128. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  129. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_configs/__init__.py +0 -0
  130. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_configs/test_base.py +0 -0
  131. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_configs/test_infer_quant.py +0 -0
  132. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  133. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_linear/__init__.py +0 -0
  134. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_linear/test_compressed_linear.py +0 -0
  135. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_modeling/test_attention_and_cache.py +0 -0
  136. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/__init__.py +0 -0
  137. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/__init__.py +0 -0
  138. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/conftest.py +0 -0
  139. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  140. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  141. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/__init__.py +0 -0
  142. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  143. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  144. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_args.py +0 -0
  145. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_config.py +0 -0
  146. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_quant_scheme.py +0 -0
  147. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  148. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_quantization/test_utils/test_mxfp4_utils.py +0 -0
  149. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_registry.py +0 -0
  150. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_memory.py +0 -0
  151. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/factory/test_serialization.py +0 -0
  152. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_args.py +0 -0
  153. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_config.py +0 -0
  154. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/test_transform_scheme.py +0 -0
  155. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_transform/utils/test_hadamard.py +0 -0
  156. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/__init__.py +0 -0
  157. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_helpers.py +0 -0
  158. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_match.py +0 -0
  159. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_offload.py +0 -0
  160. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_safetensors_load.py +0 -0
  161. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/test_utils/test_type.py +0 -0
  162. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/tests/testing_utils.py +0 -0
  163. {compressed_tensors-0.12.3a20251028 → compressed_tensors-0.12.3a20251110}/utils/copyright.py +0 -0
.github/workflows/test-check.yaml

@@ -12,7 +12,7 @@ on:
 
 jobs:
   python-tests:
-    runs-on: ubuntu-22.04
+    runs-on: ibm-wdc-k8s-vllm-h100-solo
     env:
       HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
     steps:
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251028
+Version: 0.12.3a20251110
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/vllm-project/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/quantized_compressors/base.py

@@ -90,7 +90,6 @@ class BaseQuantizationCompressor(BaseCompressor):
         desc = "Compressing with quantization"
         for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]
-
             # compress weights
             if name.endswith("weight"):
                 prefix = name.removesuffix("weight")
@@ -129,10 +128,18 @@ class BaseQuantizationCompressor(BaseCompressor):
             if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
                 continue
 
+            if name.endswith("weight_scale") and self._skip_scale():
+                continue
+
             compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
+    def _skip_scale(self):
+        from compressed_tensors.compressors import NVFP4PackedCompressor
+
+        return isinstance(self, NVFP4PackedCompressor)
+
     def _skip_zp(
         self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

@@ -26,7 +26,7 @@ from compressed_tensors.quantization.lifecycle.forward import dequantize, quanti
 from torch import Tensor
 
 
-__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]
+__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8", "NVFP4PackedCompressor"]
 
 FLOAT_TO_E2M1 = [
     0.0,
@@ -103,6 +103,7 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
+        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
         return compressed_dict
 
     def decompress_weight(
@@ -111,8 +112,8 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
         weight = compressed_data["weight_packed"]
-        scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
+        scale = compressed_data["weight_scale"]
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
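
Illustrative note, not part of the package diff: under the NVFP4 presets added in quant_scheme.py further down, the new compress_weight line amounts to casting the per-group scale to FP8 E4M3 before it is serialized. A minimal sketch of that cast, with hypothetical tensor shapes:

    import torch

    # hypothetical per-group NVFP4 scales produced during calibration
    scale = torch.rand(128, 16, dtype=torch.bfloat16)

    # roughly what the added compress_weight line does when
    # quantization_args.scale_dtype is FP8 E4M3 (see the NVFP4 presets below)
    weight_scale = scale.to(torch.float8_e4m3fn)
    assert weight_scale.dtype == torch.float8_e4m3fn
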
src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

@@ -134,8 +134,6 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -143,7 +141,7 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict
 
     def decompress_weight(
@@ -166,16 +164,13 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
+            assert (
+                zero_point is not None
+            ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
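
Together these hunks make packed asymmetric zero points survive the round trip; the new tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py (file 22 in the list above) exercises that path end to end. A rough sketch of the pack/unpack symmetry the decompressor now relies on, assuming pack_to_int32 and unpack_from_int32 are importable from the module shown above and accept in-range int8 values (shapes and values here are made up):

    import torch
    # assumed import path, matching the module in the hunks above
    from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
        pack_to_int32,
        unpack_from_int32,
    )

    num_bits = 4
    # hypothetical group-wise zero points: (output_dim, num_groups), int4 range
    zero_point = torch.randint(-8, 8, (64, 4), dtype=torch.int8)

    packed_zp = pack_to_int32(zero_point, num_bits, packed_dim=0).contiguous()
    restored = unpack_from_int32(packed_zp, num_bits, zero_point.shape, packed_dim=0)
    # in-range values are expected to survive the round trip
    assert torch.equal(restored.to(torch.int8), zero_point)
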
src/compressed_tensors/quantization/lifecycle/forward.py

@@ -21,7 +21,7 @@ from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    round_to_quantized_type,
+    round_to_quantized_type_args,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -278,7 +278,7 @@ def _process_quantization(
         if columns % group_size != 0:
             raise ValueError(
                 "tensor column shape must be divisble "
-                f"by the given group_size {group_size}"
+                f"by the given group_size {group_size} but got {columns}"
             )
 
         # support column-order (default) quantization as well as other orderings
@@ -466,20 +466,17 @@ def _quantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     scaled = x / scale
 
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
-    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
-    clamped_value = torch.clamp(
-        scaled,
-        q_min,
-        q_max,
+    # clamp and round
+    quantized_value = round_to_quantized_type_args(
+        tensor=scaled, args=args, min=q_min, max=q_max
     )
-    quantized_value = round_to_quantized_type(clamped_value, args)
 
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
@@ -499,7 +496,7 @@ def _dequantize(
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
     if global_scale is not None:
-        scale = scale.to(global_scale.dtype) / global_scale
+        scale = scale / global_scale
 
     dequant_value = x_q.to(scale.dtype)
 
src/compressed_tensors/quantization/lifecycle/initialize.py

@@ -24,7 +24,6 @@ from compressed_tensors.modeling import (
     QuantizedKVCache,
 )
 from compressed_tensors.quantization import (
-    FP8_E4M3_DATA,
     ActivationOrdering,
     DynamicType,
     QuantizationArgs,
@@ -36,7 +35,7 @@ from compressed_tensors.quantization import (
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -250,20 +249,13 @@ def initialize_qparams(
 
     # 2. Identify quantization scale and zp dtype
     scale_dtype = observed_dtype
-
-    if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        # TODO: consider erroring out in the future as if the dtype if not one of these,
-        # there is likely bug
-        if scale_dtype not in [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-            torch.float64,
-        ]:
-            scale_dtype = torch.bfloat16
-        zp_dtype = quantization_args.pytorch_dtype()
+    if scale_dtype not in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    ]:
+        scale_dtype = torch.float16
 
     # 3. Initializes scale/zp for the module
     init_scale = Parameter(
@@ -274,7 +266,9 @@ def initialize_qparams(
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
-            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            torch.zeros(
+                expected_shape, device=device, dtype=quantization_args.zp_dtype
+            ),
             requires_grad=False,
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
src/compressed_tensors/quantization/quant_args.py

@@ -19,7 +19,15 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+from compressed_tensors.utils.type import TorchDtype
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 
 
 __all__ = [
@@ -30,7 +38,8 @@ __all__ = [
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
-    "round_to_quantized_type",
+    "round_to_quantized_type_args",
+    "round_to_quantized_type_dtype",
     "ActivationOrdering",
     "DynamicType",
 ]
@@ -174,6 +183,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
+    scale_dtype: Optional[TorchDtype] = None
+    zp_dtype: Optional[TorchDtype] = None
     observer: Optional[str] = Field(
         default=None,
         description=(
@@ -189,6 +200,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         ),
     )
 
+    @field_serializer("zp_dtype")
+    def serialize_dtype(self, dtype: torch.dtype):
+        if self.symmetric:
+            return None
+        return str(dtype)
+
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
         if isinstance(value, str):
@@ -266,6 +283,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         dynamic = model.dynamic
         observer = model.observer
         dynamic = model.dynamic
+        zp_dtype = model.zp_dtype
 
         # infer strategy
         if strategy is None:
@@ -353,9 +371,16 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             # default to minmax for non-dynamic cases
             observer = "minmax"
 
+        if zp_dtype is None:
+            if model.num_bits == 4 and model.type == QuantizationType.FLOAT:
+                zp_dtype = FP8_E4M3_DATA.dtype
+            else:
+                zp_dtype = model.pytorch_dtype()
+
         # write back modified values
         model.strategy = strategy
         model.observer = observer
+        model.zp_dtype = zp_dtype
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
@@ -381,18 +406,56 @@
     model_config = ConfigDict(extra="forbid")
 
 
-def round_to_quantized_type(
-    tensor: torch.Tensor, args: QuantizationArgs
+def round_to_quantized_type_dtype(
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
-    Rounds each element of the input tensor to the nearest quantized representation,
-    keeping to original dtype
+    Rounds an input tensor to the nearest quantized representation given a dtype.
+    The original dtype is kept post-rounding.
 
     :param tensor: tensor to round
-    :param args: QuantizationArgs to pull appropriate dtype from
+    :param dtype: dtype to use for rounding
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
    :return: rounded tensor
    """
    original_dtype = tensor.dtype
+    if torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        finfo = torch.finfo(dtype)
+        rounded = torch.clamp(tensor, finfo.min, finfo.max).to(dtype)
+    else:
+        iinfo = torch.iinfo(dtype)
+        rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)).to(dtype)
+
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
+
+
+def round_to_quantized_type_args(
+    tensor: torch.Tensor,
+    args: QuantizationArgs,
+    min: torch.Tensor,
+    max: torch.Tensor,
+    cast_to_original_dtype: Optional[bool] = True,
+) -> torch.Tensor:
+    """
+    Rounds an input tensor to the nearest quantized representation given
+    qunatization args. The original dtype is kept post-rounding.
+
+    :param tensor: tensor to round
+    :param args: quantization args to use for rounding
+    :param min: min value to use for clamping
+    :param max: max value to use for clamping
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
+    :return: rounded tensor
+    """
+
+    original_dtype = tensor.dtype
+    tensor = torch.clamp(tensor, min, max)
     if args.type == QuantizationType.FLOAT:
         if args.num_bits == 8:
             rounded = tensor.to(FP8_E4M3_DATA.dtype)
@@ -405,4 +468,6 @@ def round_to_quantized_type(
     else:
         raise ValueError(f"Invalid quantization type {args.type}")
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
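
Illustrative sketch, not part of the package diff: how the new scale_dtype/zp_dtype fields and the clamp-and-round helper above might be exercised. Import paths follow the hunks; the argument values and ranges below are made up, and the asserted defaults simply follow the validator logic shown above.

    import torch
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        round_to_quantized_type_args,
    )

    # zp_dtype left unset: the model validator above falls back to pytorch_dtype()
    int4_args = QuantizationArgs(num_bits=4, type="int", symmetric=False, group_size=128)
    assert int4_args.zp_dtype == torch.int8

    # FP4 args can pin both dtypes explicitly, as the NVFP4 presets below do
    fp4_args = QuantizationArgs(
        num_bits=4,
        type="float",
        symmetric=True,
        group_size=16,
        scale_dtype=torch.float8_e4m3fn,
        zp_dtype=torch.float8_e4m3fn,
    )

    # the single clamp-and-round call that forward.py now uses in place of the old
    # torch.clamp + round_to_quantized_type pair (made-up int8 range and values)
    args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
    scaled = torch.randn(4, 4) * 300
    quantized = round_to_quantized_type_args(
        tensor=scaled, args=args, min=torch.tensor(-128.0), max=torch.tensor(127.0)
    )
    assert quantized.dtype == scaled.dtype  # cast_to_original_dtype defaults to True
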
src/compressed_tensors/quantization/quant_config.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union
src/compressed_tensors/quantization/quant_scheme.py

@@ -18,6 +18,7 @@ from typing import List, Optional
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
+    FP8_E4M3_DATA,
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
@@ -160,6 +161,8 @@ NVFP4A16 = dict(
         symmetric=True,
         dynamic=False,
         group_size=16,
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     )
 )
 
@@ -173,6 +176,8 @@ NVFP4 = dict(
         dynamic=False,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -182,6 +187,8 @@ NVFP4 = dict(
         dynamic=DynamicType.LOCAL,
         group_size=16,
         observer="static_minmax",
+        scale_dtype=FP8_E4M3_DATA.dtype,
+        zp_dtype=FP8_E4M3_DATA.dtype,
     ),
 )
 
src/compressed_tensors/quantization/utils/helpers.py

@@ -24,6 +24,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
+    round_to_quantized_type_dtype,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.utils import deprecated
@@ -46,7 +47,6 @@ __all__ = [
     "calculate_range",
     "calculate_qparams",
     "generate_gparam",
-    "is_fp4",
     "strategy_cdiv",
 ]
 
@@ -57,13 +57,6 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-def is_fp4(quantization_args: QuantizationArgs):
-    return (
-        quantization_args.num_bits == 4
-        and quantization_args.type == QuantizationType.FLOAT
-    )
-
-
 def calculate_qparams(
     min_vals: Tensor,
     max_vals: Tensor,
@@ -92,52 +85,50 @@ def calculate_qparams(
     bit_min, bit_max = calculate_range(quantization_args, device)
     bit_range = bit_max - bit_min
 
-    if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
-    else:
-        zp_dtype = quantization_args.pytorch_dtype()
-
+    # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
-
-        else:
-            scales = max_val_pos / (float(bit_range) / 2)
-
-        # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-
+        scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
-        if is_fp4(quantization_args=quantization_args):
+        if (
+            quantization_args.num_bits == 4
+            and quantization_args.type == QuantizationType.FLOAT
+        ):
             raise NotImplementedError(
                 "Asymmetric Quantization is not supported for FP4"
             )
-
         scales = (max_vals - min_vals) / float(bit_range)
-        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - (min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max)
 
-    # match zero-points to quantized type
-    # if casting to int, use round instead of truncate
-    if quantization_args.type == QuantizationType.INT:
-        zero_points = torch.round(zero_points)
-    zero_points = zero_points.to(zp_dtype)
+    # 2. Conditionally scale the generated local scale by a global_scale
+    if global_scale is not None:
+        scales = global_scale * scales
+
+    # 3. Conditionally round the scale to the quantized dtype, if scale_dtype is set
+    if quantization_args.scale_dtype is not None:
+        scales = round_to_quantized_type_dtype(
+            scales, dtype=quantization_args.scale_dtype
+        )
+
+    # 4. Update any 0s with small values to
+    # prevent div by 0
+    eps = _get_dtype_eps(
+        dtype=quantization_args.scale_dtype
+        if quantization_args.scale_dtype is not None
+        else scales.dtype
+    )
+    scales = torch.where(
+        scales == 0,
+        torch.tensor(eps, dtype=scales.dtype, device=device),
+        scales,
+    )
+
+    # 5. Round the zp to zp_dtype
+    zero_points = round_to_quantized_type_dtype(
+        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
+    )
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
@@ -455,3 +446,14 @@ def strategy_cdiv(
         logger.bind(log_once=True).warning(message)
 
     return dividend
+
+
+def _get_dtype_eps(dtype: torch.dtype) -> float:
+    if dtype == FP8_E4M3_DATA.dtype:
+        return 0.125
+    elif dtype == FP4_E2M1_DATA.dtype:
+        return 0.25
+    elif torch.is_floating_point(torch.tensor([], dtype=dtype)):
+        return torch.finfo(dtype).eps
+    else:
+        return 1
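
For reference, a stand-alone mirror of the epsilon selection introduced above (illustrative only; it reimplements the logic rather than importing the private _get_dtype_eps helper). FP8 E4M3 scales use the fixed constant 0.125 and FP4 E2M1 scales use 0.25, while other float dtypes fall back to torch.finfo(...).eps:

    import torch

    def dtype_eps(dtype: torch.dtype) -> float:
        # mirror of _get_dtype_eps for ordinary torch dtypes
        if dtype == torch.float8_e4m3fn:
            return 0.125
        if torch.is_floating_point(torch.tensor([], dtype=dtype)):
            return torch.finfo(dtype).eps
        return 1

    # replace zero scales with the dtype-appropriate epsilon, as step 4 above does
    scales = torch.tensor([0.0, 0.5, 0.0], dtype=torch.bfloat16)
    eps = dtype_eps(scales.dtype)
    scales = torch.where(scales == 0, torch.tensor(eps, dtype=scales.dtype), scales)
    assert (scales != 0).all()
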
src/compressed_tensors/transform/factory/base.py

@@ -18,6 +18,14 @@ from typing import List, Optional
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@ from compressed_tensors.utils import (
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel
 
 
 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ class TransformFactory(RegistryMixin, ABC):
 
         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)
 
-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module
 
+        :param model: model which module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ class TransformFactory(RegistryMixin, ABC):
 
             module.register_forward_hook(output_hook)
 
-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module)
+            register_query_hook(module, query_hook)
+
+        # register key hook to kvcache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()
 
src/compressed_tensors/transform/factory/hadamard.py

@@ -51,7 +51,6 @@ class HadamardFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
src/compressed_tensors/transform/factory/matrix_multiply.py

@@ -50,7 +50,6 @@ class RandomMatrixFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
@@ -68,8 +67,8 @@ class RandomMatrixFactory(TransformFactory):
             (size, size),
             generator=self.generator,
             dtype=precision,
-            device=device,
-        )
+            device=self.generator.device,
+        ).to(device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
src/compressed_tensors/transform/transform_args.py

@@ -45,6 +45,16 @@ class TransformLocation(str, Enum):
     K_CACHE = "k_cache"
     Q_ATTN = "q_attn"
 
+    def is_online(self) -> bool:
+        """
+        Returns True if the transform location is online
+        (applied at runtime), False otherwise
+        """
+        return self not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+
 
 class TransformArgs(BaseModel, use_enum_values=True):
     """
@@ -70,9 +80,6 @@ class TransformArgs(BaseModel, use_enum_values=True):
         return value
 
     def is_online(self) -> bool:
-        return self.location not in (
-            TransformLocation.WEIGHT_INPUT,
-            TransformLocation.WEIGHT_OUTPUT,
-        )
+        return TransformLocation(self.location).is_online()
 
     model_config = ConfigDict(extra="forbid")
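
A small illustrative sketch of the relocated helper (assuming the public import path compressed_tensors.transform is unchanged): is_online() can now be queried directly on the TransformLocation enum, and TransformArgs.is_online() simply delegates to it.

    from compressed_tensors.transform import TransformArgs, TransformLocation

    # weight locations are folded into the weight ahead of time ("offline");
    # everything else is applied at runtime ("online")
    assert not TransformLocation.WEIGHT_INPUT.is_online()
    assert TransformLocation.Q_ATTN.is_online()

    # TransformArgs.is_online() now just delegates to the enum
    args = TransformArgs(targets=["re:.*q_proj$"], location="input")
    assert args.is_online()
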