compressed-tensors 0.12.3a20251013__tar.gz → 0.12.3a20251023__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/build-test.yml +0 -24
  2. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/build.yml +5 -5
  3. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/test-check.yaml +17 -3
  4. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/test.yml +6 -7
  5. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/trigger-all.yml +4 -13
  6. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/PKG-INFO +5 -5
  7. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/README.md +3 -3
  8. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/setup.py +1 -1
  9. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +1 -1
  10. compressed_tensors-0.12.3a20251013/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py → compressed_tensors-0.12.3a20251023/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +9 -0
  11. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +4 -4
  12. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +3 -3
  13. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/base.py +1 -0
  14. compressed_tensors-0.12.3a20251023/src/compressed_tensors/modeling/__init__.py +18 -0
  15. compressed_tensors-0.12.3a20251023/src/compressed_tensors/modeling/attention.py +147 -0
  16. compressed_tensors-0.12.3a20251023/src/compressed_tensors/modeling/kvcache.py +183 -0
  17. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/apply.py +48 -103
  18. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/initialize.py +83 -28
  19. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/quant_args.py +1 -6
  20. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/quant_config.py +59 -45
  21. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/quant_scheme.py +2 -0
  22. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/utils/helpers.py +2 -33
  23. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/helpers.py +63 -1
  24. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/match.py +29 -0
  25. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/version.py +1 -1
  26. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors.egg-info/PKG-INFO +5 -5
  27. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors.egg-info/SOURCES.txt +6 -4
  28. compressed_tensors-0.12.3a20251013/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py → compressed_tensors-0.12.3a20251023/tests/test_compressors/quantized_compressors/test_fp4_quant.py +1 -1
  29. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +5 -5
  30. compressed_tensors-0.12.3a20251023/tests/test_modeling/test_attention_and_cache.py +108 -0
  31. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_apply.py +114 -1
  32. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/test_match.py +81 -0
  33. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/testing_utils.py +3 -3
  34. compressed_tensors-0.12.3a20251013/.github/workflows/report.yml +0 -128
  35. compressed_tensors-0.12.3a20251013/.github/workflows/upload.yml +0 -158
  36. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/.gitkeep +0 -0
  37. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/actions/test/action.yml +0 -0
  38. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/scripts/step-status +0 -0
  39. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/post-release-nightly-build.yml +0 -0
  40. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.github/workflows/quality-check.yaml +0 -0
  41. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/.gitignore +0 -0
  42. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/LICENSE +0 -0
  43. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/Makefile +0 -0
  44. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  45. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/bit_packing/int4_config.json +0 -0
  46. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/bitmask_compression.ipynb +0 -0
  47. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  48. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  49. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/llama_1.1b/example_quant_config.json +0 -0
  50. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  51. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/examples/quantize_and_pack_int4.ipynb +0 -0
  52. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/pyproject.toml +0 -0
  53. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/setup.cfg +0 -0
  54. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/__init__.py +0 -0
  55. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/README.md +0 -0
  56. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/__init__.py +0 -0
  57. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/base.py +0 -0
  58. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/__init__.py +0 -0
  59. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/base.py +0 -0
  60. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/helpers.py +0 -0
  61. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  62. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  63. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  64. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  65. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  66. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  67. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  68. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  69. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  70. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  71. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/__init__.py +0 -0
  72. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/dense.py +0 -0
  73. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/format.py +0 -0
  74. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  75. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  76. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/linear/__init__.py +0 -0
  77. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  78. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/logger.py +0 -0
  79. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/__init__.py +0 -0
  80. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  81. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  82. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  83. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  84. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  85. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  86. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/registry/__init__.py +0 -0
  87. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/registry/registry.py +0 -0
  88. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/__init__.py +0 -0
  89. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/apply.py +0 -0
  90. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  91. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/factory/base.py +0 -0
  92. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  93. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  94. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  95. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/transform_args.py +0 -0
  96. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/transform_config.py +0 -0
  97. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  98. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  99. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  100. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  101. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  102. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/__init__.py +0 -0
  103. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/internal.py +0 -0
  104. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/offload.py +0 -0
  105. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/permutations_24.py +0 -0
  106. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  107. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  108. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors/utils/type.py +0 -0
  109. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  110. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors.egg-info/requires.txt +0 -0
  111. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  112. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/__init__.py +0 -0
  113. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/conftest.py +0 -0
  114. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/mock_observer.py +0 -0
  115. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/__init__.py +0 -0
  116. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/model_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  118. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  119. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  120. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  121. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  122. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  123. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  124. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  125. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  126. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_configs/__init__.py +0 -0
  127. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_configs/test_base.py +0 -0
  128. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_configs/test_infer_quant.py +0 -0
  129. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  130. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_linear/__init__.py +0 -0
  131. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_linear/test_compressed_linear.py +0 -0
  132. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/__init__.py +0 -0
  133. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/__init__.py +0 -0
  134. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/conftest.py +0 -0
  135. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  136. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  137. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  138. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  139. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  140. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/lifecycle/test_static_lifecycle.py +0 -0
  141. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_configs/__init__.py +0 -0
  142. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  143. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  144. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_quant_args.py +0 -0
  145. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_quant_config.py +0 -0
  146. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_quant_scheme.py +0 -0
  147. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  148. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_registry.py +0 -0
  149. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/conftest.py +0 -0
  150. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/factory/test_correctness.py +0 -0
  151. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/factory/test_memory.py +0 -0
  152. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/factory/test_serialization.py +0 -0
  153. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/test_transform_args.py +0 -0
  154. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/test_transform_config.py +0 -0
  155. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/test_transform_scheme.py +0 -0
  156. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_transform/utils/test_hadamard.py +0 -0
  157. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/__init__.py +0 -0
  158. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/test_helpers.py +0 -0
  159. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/test_offload.py +0 -0
  160. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/test_safetensors_load.py +0 -0
  161. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/tests/test_utils/test_type.py +0 -0
  162. {compressed_tensors-0.12.3a20251013 → compressed_tensors-0.12.3a20251023}/utils/copyright.py +0 -0
@@ -55,27 +55,3 @@ jobs:
55
55
  whl: ${{ needs.BUILD.outputs.whl }}
56
56
  code_coverage: ${{ matrix.test_config.code_coverage || false }}
57
57
  secrets: inherit
58
-
59
- UPLOAD:
60
- needs: [TEST]
61
- uses: ./.github/workflows/upload.yml
62
- with:
63
- label: gcp-k8s-util
64
- timeout: 40
65
- run_id: ${{ github.run_id }}
66
- push_to_pypi: ${{ inputs.push_to_pypi }}
67
- secrets: inherit
68
-
69
- REPORT:
70
- needs: [BUILD, TEST]
71
- if: success() || failure()
72
- uses: ./.github/workflows/report.yml
73
- with:
74
- label: rh-reporter
75
- timeout: 40
76
- run_id: ${{ github.run_id }}
77
- run_name: compressed-tensors
78
- wheel: ${{ needs.BUILD.outputs.whl }}
79
- wf_category: ${{ inputs.wf_category }}
80
- gitref: ${{ inputs.gitref }}
81
- secrets: inherit
@@ -86,9 +86,9 @@ jobs:
86
86
  id: auth
87
87
  uses: google-github-actions/auth@v2.1.3
88
88
  with:
89
- project_id: ${{ secrets.GCP_PROJECT }}
90
- workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
91
- service_account: ${{ secrets.GCP_GHA_SA }}
89
+ project_id: ${{ secrets.GCP_VLLM_PROJECT }}
90
+ workload_identity_provider: ${{ secrets.GCP_VLLM_PROJECT_WORKLOAD_IDENTITY_PROVIDER }}
91
+ service_account: ${{ secrets.GCP_VLLM_PROJECT_GHA_SA }}
92
92
 
93
93
  - name: 'Set up Cloud SDK'
94
94
  uses: 'google-github-actions/setup-gcloud@v2'
@@ -97,8 +97,8 @@ jobs:
97
97
 
98
98
  - name: copy whl and source distribution
99
99
  run: |
100
- gcloud storage cp dist/${{ steps.build.outputs.whlname }} ${{ secrets.GCP_BUILD_ML_ASSETS2 }}/${{ github.run_id }}/${{ steps.build.outputs.whlname }}
101
- gcloud storage cp dist/${{ steps.build.outputs.tarname }} ${{ secrets.GCP_BUILD_ML_ASSETS2 }}/${{ github.run_id }}/${{ steps.build.outputs.tarname }}
100
+ gcloud storage cp dist/${{ steps.build.outputs.whlname }} ${{ secrets.GCP_VLLM_PROJECT_BUILD_ASSETS }}/${{ github.run_id }}/${{ steps.build.outputs.whlname }}
101
+ gcloud storage cp dist/${{ steps.build.outputs.tarname }} ${{ secrets.GCP_VLLM_PROJECT_BUILD_ASSETS }}/${{ github.run_id }}/${{ steps.build.outputs.tarname }}
102
102
 
103
103
  - name: upload whl
104
104
  uses: actions/upload-artifact@v4
@@ -12,10 +12,9 @@ on:
12
12
 
13
13
  jobs:
14
14
  python-tests:
15
- runs-on: k8s-util
15
+ runs-on: ubuntu-22.04
16
16
  env:
17
- HF_HOME: /model-cache
18
- HF_TOKEN: ${{ secrets.NM_HF_TOKEN_READ_ONLY }}
17
+ HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
19
18
  steps:
20
19
  - uses: actions/setup-python@v5
21
20
  with:
@@ -32,5 +31,20 @@ jobs:
32
31
  run: pip3 install --upgrade pip setuptools
33
32
  - name: "⚙️ Install dependencies"
34
33
  run: pip3 install .[dev,accelerate]
34
+ - name: clean up
35
+ run: |
36
+ echo "cleaning up disk space as GHA runner has limited disk size."
37
+ python3 -m pip cache purge
38
+ sudo rm -rf /usr/local/.ghcup
39
+ sudo rm -rf /opt/hostedtoolcache/CodeQL
40
+ sudo rm -rf /usr/local/lib/android/sdk/ndk
41
+ sudo rm -rf /usr/share/dotnet
42
+ sudo rm -rf /opt/ghc
43
+ sudo rm -rf /usr/local/share/boost
44
+ if [[ "$(cat /etc/issue)" =~ Ubuntu ]]; then
45
+ sudo apt-get clean
46
+ fi
47
+ df -h
48
+ shell: bash
35
49
  - name: "🔬 Running tests"
36
50
  run: make test
@@ -72,8 +72,7 @@ jobs:
72
72
  id-token: 'write'
73
73
  pages: 'write'
74
74
  env:
75
- HF_HOME: /model-cache
76
- HF_TOKEN: ${{ secrets.NM_HF_TOKEN_READ_ONLY }}
75
+ HF_TOKEN: ${{ secrets.HF_RED_HAT_READ_ONLY }}
77
76
  environment:
78
77
  name: github-pages
79
78
  url: ${{ steps.coverage.outputs.page_url }}
@@ -123,9 +122,9 @@ jobs:
123
122
  id: auth
124
123
  uses: google-github-actions/auth@v2.1.3
125
124
  with:
126
- project_id: ${{ secrets.GCP_PROJECT }}
127
- workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
128
- service_account: ${{ secrets.GCP_GHA_SA }}
125
+ project_id: ${{ secrets.GCP_VLLM_PROJECT }}
126
+ workload_identity_provider: ${{ secrets.GCP_VLLM_PROJECT_WORKLOAD_IDENTITY_PROVIDER }}
127
+ service_account: ${{ secrets.GCP_VLLM_PROJECT_GHA_SA }}
129
128
 
130
129
  - name: 'Set up Cloud SDK'
131
130
  uses: 'google-github-actions/setup-gcloud@v2'
@@ -136,7 +135,7 @@ jobs:
136
135
  if: ${{ inputs.run_id != '' }}
137
136
  uses: neuralmagic/nm-actions/actions/gcp-download-assets@v1.1.0
138
137
  with:
139
- bucket_source: ${{ secrets.GCP_BUILD_ML_ASSETS2 }}
138
+ bucket_source: ${{ secrets.GCP_VLLM_PROJECT_BUILD_ASSETS }}
140
139
  run_id: ${{ inputs.run_id }}
141
140
 
142
141
  - name: run tests
@@ -165,7 +164,7 @@ jobs:
165
164
 
166
165
  - name: copy results to GCP
167
166
  run: |
168
- gcloud storage cp test-results/report.xml ${{ secrets.GCP_BUILD_ML_ASSETS2 }}/${{ github.run_id }}/test-results/report-${{ inputs.test_label }}.xml
167
+ gcloud storage cp test-results/report.xml ${{ secrets.GCP_VLLM_PROJECT_BUILD_ASSETS }}/${{ github.run_id }}/test-results/report-${{ inputs.test_label }}.xml
169
168
 
170
169
  - name: upload results
171
170
  uses: actions/upload-artifact@v4
@@ -11,10 +11,6 @@ on:
11
11
  description: "workflow category, must be 'NIGHTLY' or 'RELEASE' (default: NIGHTLY)"
12
12
  type: string
13
13
  default: NIGHTLY
14
- push_to_pypi:
15
- description: "when set and tests pass, then '.whl' & '.tar.gz' will be pushed to public pypi"
16
- type: boolean
17
- default: false
18
14
  gitref:
19
15
  description: "git commit hash or tag name"
20
16
  type: string
@@ -29,10 +25,6 @@ on:
29
25
  - NIGHTLY
30
26
  - RELEASE
31
27
  default: NIGHTLY
32
- push_to_pypi:
33
- description: "when set and tests pass, then '.whl' & '.tar.gz' will be pushed to public pypi"
34
- type: boolean
35
- default: false
36
28
  gitref:
37
29
  description: "git commit hash or tag name"
38
30
  type: string
@@ -46,9 +38,8 @@ jobs:
46
38
  with:
47
39
  wf_category: ${{ inputs.wf_category || 'NIGHTLY' }}
48
40
  gitref: ${{ inputs.gitref || 'main' }}
49
- push_to_pypi: ${{ (github.event.schedule == '30 0 * * *') || inputs.push_to_pypi || false }}
50
- test_configs: '[{"python":"3.11.4","label":"k8s-util","timeout":"40","code_coverage":true},
51
- {"python":"3.10.12","label":"k8s-util","timeout":"40"},
52
- {"python":"3.13","label":"k8s-h100-solo","timeout":"40"},
53
- {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'
41
+ test_configs: '[{"python":"3.11.4","label":"ubuntu-latest","timeout":"40","code_coverage":true},
42
+ {"python":"3.10.12","label":"ubuntu-latest","timeout":"40"},
43
+ {"python":"3.13","label":"ubuntu-24.04","timeout":"40"},
44
+ {"python":"3.12.6","label":"ubuntu-22.04","timeout":"40"}]'
54
45
  secrets: inherit
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.12.3a20251013
3
+ Version: 0.12.3a20251023
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
- Home-page: https://github.com/neuralmagic/compressed-tensors
5
+ Home-page: https://github.com/vllm-project/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
7
7
  Author-email: support@neuralmagic.com
8
8
  License: Apache 2.0
@@ -71,7 +71,7 @@ pip install --pre compressed-tensors
71
71
  ### From Source
72
72
 
73
73
  ```bash
74
- git clone https://github.com/neuralmagic/compressed-tensors
74
+ git clone https://github.com/vllm-project/compressed-tensors
75
75
  cd compressed-tensors
76
76
  pip install -e .
77
77
  ```
@@ -112,7 +112,7 @@ We can apply bitmask compression to a whole model. For more detailed example see
112
112
  from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig
113
113
  from transformers import AutoModelForCausalLM
114
114
 
115
- model_name = "neuralmagic/llama2.c-stories110M-pruned50"
115
+ model_name = "RedHatAI/llama2.c-stories110M-pruned50"
116
116
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
117
117
 
118
118
  original_state_dict = model.state_dict()
@@ -126,7 +126,7 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
126
126
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
127
127
  ```
128
128
 
129
- For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
129
+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/vllm-project/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
130
130
 
131
131
 
132
132
  ## Saving a Compressed Model with PTQ
@@ -37,7 +37,7 @@ pip install --pre compressed-tensors
37
37
  ### From Source
38
38
 
39
39
  ```bash
40
- git clone https://github.com/neuralmagic/compressed-tensors
40
+ git clone https://github.com/vllm-project/compressed-tensors
41
41
  cd compressed-tensors
42
42
  pip install -e .
43
43
  ```
@@ -78,7 +78,7 @@ We can apply bitmask compression to a whole model. For more detailed example see
78
78
  from compressed_tensors import save_compressed_model, load_compressed, BitmaskConfig
79
79
  from transformers import AutoModelForCausalLM
80
80
 
81
- model_name = "neuralmagic/llama2.c-stories110M-pruned50"
81
+ model_name = "RedHatAI/llama2.c-stories110M-pruned50"
82
82
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
83
83
 
84
84
  original_state_dict = model.state_dict()
@@ -92,7 +92,7 @@ save_compressed_model(model, "compressed_model.safetensors", compression_format=
92
92
  state_dict = dict(load_compressed("compressed_model.safetensors", compression_config))
93
93
  ```
94
94
 
95
- For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/neuralmagic/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
95
+ For more in-depth tutorial on bitmask compression, refer to the [notebook](https://github.com/vllm-project/compressed-tensors/blob/d707c5b84bc3fef164aebdcd97cb6eaa571982f8/examples/bitmask_compression.ipynb).
96
96
 
97
97
 
98
98
  ## Saving a Compressed Model with PTQ
@@ -109,7 +109,7 @@ setup(
109
109
  description="Library for utilization of compressed safetensors of neural network models",
110
110
  long_description=_setup_long_description()[0],
111
111
  long_description_content_type=_setup_long_description()[1],
112
- url="https://github.com/neuralmagic/compressed-tensors",
112
+ url="https://github.com/vllm-project/compressed-tensors",
113
113
  extras_require=_setup_extras(),
114
114
  install_requires=_setup_install_requires(),
115
115
  package_dir={"": "src"},
@@ -14,6 +14,6 @@
14
14
  # flake8: noqa
15
15
 
16
16
  from .base import *
17
+ from .fp4_quantized import *
17
18
  from .naive_quantized import *
18
- from .nvfp4_quantized import *
19
19
  from .pack_quantized import *
@@ -123,6 +123,15 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
123
123
  return decompressed_weight
124
124
 
125
125
 
126
+ @BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
127
+ class MXFP4PackedCompressor(NVFP4PackedCompressor):
128
+ """
129
+ Alias for mxfp4 quantized models
130
+ """
131
+
132
+ pass
133
+
134
+
126
135
  @torch.compile(fullgraph=True, dynamic=True)
127
136
  def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
128
137
  """
@@ -19,7 +19,7 @@ import torch
19
19
  from compressed_tensors.compressors.base import BaseCompressor
20
20
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
21
21
  from compressed_tensors.config import CompressionFormat, SparsityStructure
22
- from compressed_tensors.quantization import FP8_DTYPE
22
+ from compressed_tensors.quantization import FP8_E4M3_DATA
23
23
  from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
24
24
  from torch import Tensor
25
25
 
@@ -189,11 +189,11 @@ def sparse24_bitmask_compress(
189
189
 
190
190
  bytemasks = get_24_bytemasks(tensor=tensor)
191
191
 
192
- if tensor.dtype == FP8_DTYPE:
192
+ if tensor.dtype == FP8_E4M3_DATA.dtype:
193
193
  # access raw bytes of the tensor
194
194
  tensor_view = tensor.view(torch.int8)
195
195
  values = tensor_view[bytemasks]
196
- values = values.view(FP8_DTYPE)
196
+ values = values.view(FP8_E4M3_DATA.dtype)
197
197
  else:
198
198
  values = tensor[bytemasks]
199
199
 
@@ -241,7 +241,7 @@ def get_24_bytemasks(tensor):
241
241
  multiple of 4.
242
242
  """
243
243
  original_dtype = tensor.dtype
244
- if tensor.dtype == FP8_DTYPE:
244
+ if tensor.dtype == FP8_E4M3_DATA.dtype:
245
245
  tensor = tensor.view(torch.int8)
246
246
  original_shape = tensor.shape
247
247
  num_elements = tensor.numel()
@@ -18,7 +18,7 @@ import torch
18
18
  from compressed_tensors.compressors.base import BaseCompressor
19
19
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
20
20
  from compressed_tensors.config import CompressionFormat
21
- from compressed_tensors.quantization import FP8_DTYPE
21
+ from compressed_tensors.quantization import FP8_E4M3_DATA
22
22
  from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
23
23
  from torch import Tensor
24
24
 
@@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
138
138
  bytemasks = tensor != 0
139
139
  row_counts = bytemasks.sum(dim=-1)
140
140
  row_offsets = torch.cumsum(row_counts, 0) - row_counts
141
- if tensor.dtype == FP8_DTYPE:
141
+ if tensor.dtype == FP8_E4M3_DATA.dtype:
142
142
  # access raw bytes of the tensor
143
143
  tensor_view = tensor.view(torch.int8)
144
144
  values = tensor_view[bytemasks]
145
- values = values.view(FP8_DTYPE)
145
+ values = values.view(FP8_E4M3_DATA.dtype)
146
146
  else:
147
147
  values = tensor[bytemasks]
148
148
  bitmasks_packed = pack_bitmasks(bytemasks)
@@ -34,6 +34,7 @@ class CompressionFormat(Enum):
34
34
  marlin_24 = "marlin-24"
35
35
  mixed_precision = "mixed-precision"
36
36
  nvfp4_pack_quantized = "nvfp4-pack-quantized"
37
+ mxfp4_pack_quantized = "mxfp4-pack-quantized"
37
38
 
38
39
 
39
40
  @unique
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # flake8: noqa
16
+ # isort: off
17
+ from .kvcache import *
18
+ from .attention import *
@@ -0,0 +1,147 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Callable, Optional
17
+
18
+ from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
19
+ from compressed_tensors.quantization.lifecycle.forward import forward_quantize
20
+ from compressed_tensors.utils import getattr_chain
21
+ from compressed_tensors.utils.internal import InternalModule
22
+ from torch import Tensor
23
+ from torch.nn import Module
24
+ from torch.utils.hooks import RemovableHandle
25
+ from transformers import PretrainedConfig, PreTrainedModel
26
+ from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
27
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
28
+
29
+
30
+ __all__ = [
31
+ "QuantizedAttentionImpl",
32
+ "initialize_hooked_attention",
33
+ "register_query_hook",
34
+ "IMPL_ATTR",
35
+ ]
36
+
37
+
38
+ IMPL_ATTR = "impl"
39
+ HOOKED_ATTENTION_NAME = "ct_hooked_attention"
40
+
41
+
42
+ class QuantizedAttentionImpl(InternalModule):
43
+ """
44
+ QuantizedAttentionImpl module which wraps the functionality of the original
45
+ attention implementation. Unlike the original attention function, this
46
+ implementation is a `torch.nn.Module` which can be hooked to trigger
47
+ transforms and calibration hooks.
48
+
49
+ This module works by being registered as a submodule to attention modules via
50
+ `initialize_hooked_attention`, registering a new attention implementation function
51
+ which calls this module, then setting the model attention implementation to the new
52
+ function. After triggering hooks and quantization, this module calls the original
53
+ attention implementation function.
54
+ """
55
+
56
+ _original_impl = "eager"
57
+
58
+ def __init__(self, config: PretrainedConfig):
59
+ super().__init__()
60
+ self.config = config
61
+
62
+ def forward(
63
+ self,
64
+ module: Module,
65
+ query: Tensor,
66
+ key: Tensor,
67
+ value: Tensor,
68
+ *args,
69
+ **kwargs,
70
+ ):
71
+ # quantization
72
+ quant_args_attr = "quantization_scheme.input_activations"
73
+ quant_args = getattr_chain(module, quant_args_attr, None)
74
+ quant_enabled = getattr(module, "quantization_enabled", True)
75
+ if quant_args is not None and quant_enabled:
76
+ query = forward_quantize(module, query, "q", quant_args)
77
+
78
+ # original attention
79
+ return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl](
80
+ module,
81
+ query,
82
+ key,
83
+ value,
84
+ *args,
85
+ **kwargs,
86
+ )
87
+
88
+
89
+ # ----- initialize ----- #
90
+
91
+
92
+ def _hooked_attention(module: Module, *args, **kwargs):
93
+ assert hasattr(module, IMPL_ATTR), (
94
+ f"Using {HOOKED_ATTENTION_NAME} attention implementation, "
95
+ f"but attention module does not have {IMPL_ATTR} submodule."
96
+ )
97
+
98
+ return getattr(module, IMPL_ATTR)(module, *args, **kwargs)
99
+
100
+
101
+ def initialize_hooked_attention(model: PreTrainedModel, module: Module):
102
+ """
103
+ Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
104
+ attached to attention. Assumes that only one model is hooked at a time.
105
+
106
+ :param model: parent model of attention module
107
+ :param module: attention module to initialize with
108
+ """
109
+ if not hasattr(module, IMPL_ATTR):
110
+ module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config))
111
+
112
+ if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
113
+ QuantizedAttentionImpl._original_impl = model.config._attn_implementation
114
+ original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation]
115
+
116
+ ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention)
117
+ ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask)
118
+ model.set_attn_implementation(HOOKED_ATTENTION_NAME)
119
+ assert model.config._attn_implementation == HOOKED_ATTENTION_NAME
120
+
121
+ initialize_hooked_kv_cache(model, module)
122
+
123
+
124
+ # ----- hooks ----- #
125
+
126
+
127
+ def register_query_hook(
128
+ module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
129
+ ) -> RemovableHandle:
130
+ """
131
+ Register a hook which takes post-rope query states as an argument and
132
+ returns the modified query states or `None`
133
+
134
+ :param module: attention module to add hook to
135
+ :param hook: query hook function
136
+ """
137
+ impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
138
+
139
+ def _hook(impl: QuantizedAttentionImpl, args, kwargs):
140
+ bound = inspect.signature(impl.forward).bind(*args, **kwargs)
141
+ value = hook(module, bound.arguments["query"])
142
+ if value is not None:
143
+ bound.arguments["query"] = value
144
+
145
+ return bound.args, bound.kwargs
146
+
147
+ return impl.register_forward_pre_hook(_hook, with_kwargs=True)