compressed-tensors 0.12.3a20251023__tar.gz → 0.12.3a20251030__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. {compressed_tensors-0.12.3a20251023/src/compressed_tensors.egg-info → compressed_tensors-0.12.3a20251030}/PKG-INFO +1 -1
  2. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -9
  3. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/forward.py +1 -1
  4. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/quant_args.py +9 -3
  5. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/utils/__init__.py +1 -0
  6. compressed_tensors-0.12.3a20251030/src/compressed_tensors/quantization/utils/mxfp4_utils.py +97 -0
  7. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/factory/base.py +34 -3
  8. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/factory/hadamard.py +0 -1
  9. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/factory/matrix_multiply.py +2 -3
  10. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/transform_args.py +11 -4
  11. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/utils/matrix.py +13 -21
  12. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/version.py +1 -1
  13. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  14. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors.egg-info/SOURCES.txt +3 -0
  15. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/quantized_compressors/test_pack_quant.py +95 -5
  16. compressed_tensors-0.12.3a20251030/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +172 -0
  17. compressed_tensors-0.12.3a20251030/tests/test_quantization/test_utils/test_mxfp4_utils.py +79 -0
  18. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/conftest.py +21 -1
  19. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/factory/test_correctness.py +74 -7
  20. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/.gitkeep +0 -0
  21. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/actions/test/action.yml +0 -0
  22. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/scripts/step-status +0 -0
  23. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/build-test.yml +0 -0
  24. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/build.yml +0 -0
  25. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/post-release-nightly-build.yml +0 -0
  26. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/quality-check.yaml +0 -0
  27. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/test-check.yaml +0 -0
  28. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/test.yml +0 -0
  29. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.github/workflows/trigger-all.yml +0 -0
  30. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/.gitignore +0 -0
  31. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/LICENSE +0 -0
  32. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/Makefile +0 -0
  33. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/README.md +0 -0
  34. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  35. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/bit_packing/int4_config.json +0 -0
  36. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/bitmask_compression.ipynb +0 -0
  37. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  38. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  39. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/llama_1.1b/example_quant_config.json +0 -0
  40. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  41. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/examples/quantize_and_pack_int4.ipynb +0 -0
  42. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/pyproject.toml +0 -0
  43. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/setup.cfg +0 -0
  44. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/setup.py +0 -0
  45. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/__init__.py +0 -0
  46. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/README.md +0 -0
  47. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/__init__.py +0 -0
  48. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/base.py +0 -0
  49. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/__init__.py +0 -0
  50. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/base.py +0 -0
  51. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/helpers.py +0 -0
  52. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  53. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  54. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  55. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  56. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +0 -0
  57. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  58. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  59. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  60. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  61. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  62. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  63. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  64. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  65. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/__init__.py +0 -0
  66. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/base.py +0 -0
  67. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/dense.py +0 -0
  68. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/format.py +0 -0
  69. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  70. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  71. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/linear/__init__.py +0 -0
  72. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  73. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/logger.py +0 -0
  74. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/modeling/__init__.py +0 -0
  75. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/modeling/attention.py +0 -0
  76. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/modeling/kvcache.py +0 -0
  77. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/__init__.py +0 -0
  78. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  79. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  80. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  81. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  82. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  83. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/quant_config.py +0 -0
  84. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  85. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  86. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  87. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/registry/__init__.py +0 -0
  88. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/registry/registry.py +0 -0
  89. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/__init__.py +0 -0
  90. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/apply.py +0 -0
  91. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  92. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  93. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/transform_config.py +0 -0
  94. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  95. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  96. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  97. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  98. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/__init__.py +0 -0
  99. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/helpers.py +0 -0
  100. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/internal.py +0 -0
  101. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/match.py +0 -0
  102. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/offload.py +0 -0
  103. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/permutations_24.py +0 -0
  104. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  105. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  106. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors/utils/type.py +0 -0
  107. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  108. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors.egg-info/requires.txt +0 -0
  109. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  110. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/__init__.py +0 -0
  111. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/conftest.py +0 -0
  112. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/mock_observer.py +0 -0
  113. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/__init__.py +0 -0
  114. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/model_compressors/__init__.py +0 -0
  115. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  116. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/quantized_compressors/test_fp4_quant.py +0 -0
  118. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  119. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  120. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  121. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  122. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  123. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  124. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  125. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_configs/__init__.py +0 -0
  126. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_configs/test_base.py +0 -0
  127. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_configs/test_infer_quant.py +0 -0
  128. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  129. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_linear/__init__.py +0 -0
  130. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_linear/test_compressed_linear.py +0 -0
  131. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_modeling/test_attention_and_cache.py +0 -0
  132. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/__init__.py +0 -0
  133. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/__init__.py +0 -0
  134. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/conftest.py +0 -0
  135. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  136. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  137. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  138. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  139. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  140. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  141. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/lifecycle/test_static_lifecycle.py +0 -0
  142. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_configs/__init__.py +0 -0
  143. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  144. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  145. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_quant_args.py +0 -0
  146. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_quant_config.py +0 -0
  147. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_quant_scheme.py +0 -0
  148. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  149. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_registry.py +0 -0
  150. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/factory/test_memory.py +0 -0
  151. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/factory/test_serialization.py +0 -0
  152. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/test_transform_args.py +0 -0
  153. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/test_transform_config.py +0 -0
  154. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/test_transform_scheme.py +0 -0
  155. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_transform/utils/test_hadamard.py +0 -0
  156. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/__init__.py +0 -0
  157. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/test_helpers.py +0 -0
  158. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/test_match.py +0 -0
  159. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/test_offload.py +0 -0
  160. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/test_safetensors_load.py +0 -0
  161. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/test_utils/test_type.py +0 -0
  162. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/tests/testing_utils.py +0 -0
  163. {compressed_tensors-0.12.3a20251023 → compressed_tensors-0.12.3a20251030}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251023
+Version: 0.12.3a20251030
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/vllm-project/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -134,8 +134,6 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight

-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -143,7 +141,7 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict

     def decompress_weight(
@@ -166,16 +164,13 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)

-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
+            assert (
+                zero_point is not None
+            ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
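
With the ValueError removed, packed zero points now survive a full save/load round trip. A minimal sketch of that round trip, using the pack_to_int32 and unpack_from_int32 helpers from this file, mirroring how the new tests at the end of this diff exercise them:

    import torch
    from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
        pack_to_int32,
        unpack_from_int32,
    )

    num_bits = 4
    zp = torch.randint(-8, 8, (512, 8), dtype=torch.int8)  # one int4 zp per (row, group)

    # packed_dim=0 packs along rows: 32 // 4 = 8 int4 values per int32
    packed = pack_to_int32(zp, num_bits, packed_dim=0)
    restored = unpack_from_int32(packed, num_bits, zp.shape, packed_dim=0)

    assert torch.equal(zp, restored)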
src/compressed_tensors/quantization/lifecycle/forward.py
@@ -278,7 +278,7 @@ def _process_quantization(
        if columns % group_size != 0:
            raise ValueError(
                "tensor column shape must be divisble "
-                f"by the given group_size {group_size}"
+                f"by the given group_size {group_size} but got {columns}"
            )

    # support column-order (default) quantization as well as other orderings
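
The added detail makes the failure self-diagnosing. The check itself is plain modular arithmetic, illustrated here with hypothetical values:

    # Illustrative only: group quantization requires columns % group_size == 0
    columns, group_size = 1024, 128
    assert columns % group_size == 0   # 8 groups per row, OK
    # group_size = 100 would raise the ValueError above: 1024 % 100 == 24 left over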
src/compressed_tensors/quantization/quant_args.py
@@ -25,6 +25,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
 __all__ = [
     "FP8_E4M3_DATA",
     "FP4_E2M1_DATA",
+    "BFLOAT16_DATA",
     "FloatArgs",
     "QuantizationType",
     "QuantizationStrategy",
@@ -38,9 +39,9 @@ __all__ = [
 class FloatArgs:
     exponent: int
     mantissa: int
-    bits: int
-    max: float
-    min: float
+    bits: Optional[int] = None
+    max: Optional[float] = None
+    min: Optional[float] = None
     dtype: Optional[torch.dtype] = None


@@ -76,6 +77,11 @@ class FP8_E4M3_DATA(FloatArgs):
     dtype = torch.float8_e4m3fn


+class BFLOAT16_DATA(FloatArgs):
+    exponent = 8
+    mantissa = 7
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
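
BFLOAT16_DATA only needs the bit layout, which is why bits/max/min become optional on FloatArgs. A quick sanity check of the constants (illustrative, not from the package):

    # bfloat16 layout: 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits
    exponent, mantissa, sign = 8, 7, 1
    assert sign + exponent + mantissa == 16

    # FP4 E2M1, by contrast, has exponent=2 and mantissa=1; mxfp4_utils below
    # uses the difference in mantissa widths when rounding block maxima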
src/compressed_tensors/quantization/utils/__init__.py
@@ -14,3 +14,4 @@

 # flake8: noqa
 from .helpers import *
+from .mxfp4_utils import *
src/compressed_tensors/quantization/utils/mxfp4_utils.py (new file)
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+
+
+__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+
+# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
+
+
+def convert_mxfp4_exp_scale(
+    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+) -> torch.Tensor:
+    """
+    Converts mxfp4 scales. Scales are powers of 2, with the
+    exponents stored in uint8. Converts to dense dtype so that
+    they can be applied to the weights and activations during QDQ
+
+    :param scale: uint8 exponent scale
+    :param dtype: dense dtype
+    """
+    assert scale.dtype == torch.uint8
+    scale_exp = scale.to(torch.int32) - 127
+    scale = 2.00 ** (scale_exp.to(torch.float))
+    return scale.to(dtype)
+
+
+def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
+    """
+    Round values to the closest power of 2.
+    This is done by masking the values with BFLOAT16_SIGN_EXPONENT_MASK
+    which essentially removes the mantissa and keeps the exponent.
+    i.e the closest power of 2 for the input_value.
+
+    E.g:
+    0.0825 = 1.32 (mantissa) x 2**-4 (exponent)
+    0.0825 ==> -4 (exponent) + 127 = 123 = 01111011 (8 bits for bfloat16)
+    0.0825 ==> 0.32 (mantissa) = 0101001 (7 bits for bfloat16)
+    0.0825 == 0b01111011_0101001 (bfloat16)
+    0b01111011_0101001 & 111111111_0000000 == 0b01111011_0000000
+    Keep the exponent + sign bit to give you the closest power of 2, 0.0625
+
+    :param x: tensor to round to closest power of 2
+    """
+    assert x.dtype == torch.bfloat16
+    x = x.view(torch.uint16).to(torch.int32)
+
+    # Find closest power of 2
+    BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)
+    # Add value to push the value to the next exponent
+    BFLOAT16_SIGN_EXPONENT_MASK = (
+        (1 << (BFLOAT16_DATA.exponent + 1)) - 1
+    ) << BFLOAT16_DATA.mantissa
+    # mask to only keep exponent - we conservatively round down
+    # to better represent smaller numbers / prevent overflow
+    block_max_uint = torch.bitwise_and(
+        x + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK
+    )
+    return block_max_uint.to(torch.uint16).view(torch.bfloat16)
+
+
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
+    """
+    Generate mxfp4 scales. The scales require the following steps
+    1. Round to the closest power of 2
+    2. Convert to exponent
+    3. Store in uint8
+
+    Called when calculating qparams using observers.
+
+    :param x: tensor to round to closest power of 2
+    :returns uint8 scales as exponents
+    """
+    # Round to closest power of 2
+    scale_power_2 = round_to_power_2(x)
+    # Convert to exponent
+    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
+    # Clamp and store in uint8, as expected by mxfp4
+    scale_exp = torch.clamp(
+        scale_exp,
+        max=torch.iinfo(torch.uint8).max,
+        min=torch.iinfo(torch.uint8).min,
+    )
+    return scale_exp.to(torch.uint8)
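
A short worked example of the new scale pipeline, reusing the docstring's 0.0825 value. The expected outputs are hand-derived from the code above; the helpers are re-exported through compressed_tensors.quantization.utils by the __init__.py change:

    import torch
    from compressed_tensors.quantization.utils import (
        convert_mxfp4_exp_scale,
        generate_mxfp4_scales,
        round_to_power_2,
    )

    # block maxima must be bfloat16, per the assert in round_to_power_2
    block_max = torch.tensor([0.0825, 4.0], dtype=torch.bfloat16)

    # 0.0825 rounds down to 2**-4 = 0.0625, matching the docstring walk-through
    print(round_to_power_2(block_max))       # tensor([0.0625, 4.], dtype=torch.bfloat16)

    # exponent-encoded uint8 scales: 127 + log2(power_of_2) - 2
    scales = generate_mxfp4_scales(block_max)
    print(scales)                            # tensor([121, 127], dtype=torch.uint8)

    # back to a dense dtype for QDQ: 2**(121 - 127) == 0.015625
    print(convert_mxfp4_exp_scale(scales))   # tensor([0.0156, 1.], dtype=torch.bfloat16)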
src/compressed_tensors/transform/factory/base.py
@@ -18,6 +18,14 @@ from typing import List, Optional
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -36,6 +44,7 @@ from compressed_tensors.utils import (
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel


 __all__ = ["TransformFactory", "TransformBase"]
@@ -97,12 +106,13 @@ class TransformFactory(RegistryMixin, ABC):

         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(model, module, arg)

-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module

+        :param model: model which module belongs to
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
@@ -156,7 +166,28 @@ class TransformFactory(RegistryMixin, ABC):

             module.register_forward_hook(output_hook)

-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            initialize_hooked_attention(model, module)
+            register_query_hook(module, query_hook)
+
+        # register key hook to kvcache
+        elif args.location == TransformLocation.K_CACHE:
+            if not isinstance(model, PreTrainedModel):
+                raise ValueError(f"Cannot hook attention of model: {model}")
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            initialize_hooked_kv_cache(model, module)
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()

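
Both new branches follow the same pattern as the existing input/output hooks: make the attention implementation hookable, then register a callback that rewrites the intermediate states. A simplified sketch of the Q_ATTN path in isolation, using only the helpers imported at the top of base.py (offloading, progress reporting, and error handling omitted; model, attn_module, and transform are assumed to exist):

    from compressed_tensors.modeling.attention import (
        initialize_hooked_attention,
        register_query_hook,
    )

    def apply_query_transform(model, attn_module, transform):
        # swap in a hookable attention implementation for this module
        initialize_hooked_attention(model, attn_module)

        # the hook receives the query states and returns the transformed tensor
        def query_hook(_module, query_states):
            return transform(query_states)

        register_query_hook(attn_module, query_hook)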
src/compressed_tensors/transform/factory/hadamard.py
@@ -51,7 +51,6 @@ class HadamardFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -50,7 +50,6 @@ class RandomMatrixFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
@@ -68,8 +67,8 @@
             (size, size),
             generator=self.generator,
             dtype=precision,
-            device=device,
-        )
+            device=self.generator.device,
+        ).to(device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_inverse(self, weight: Parameter) -> Parameter:
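
The device change sidesteps a torch constraint: torch.rand refuses to sample onto a device that differs from its generator's device, so the matrix is now drawn on self.generator.device and moved to the (possibly offloaded) target afterwards. A standalone illustration of the constraint, with "cpu" standing in for the offload device:

    import torch

    gen = torch.Generator(device="cpu").manual_seed(0)

    # torch.rand((4, 4), generator=gen, device="cuda")  # raises: generator/device mismatch
    data = torch.rand((4, 4), generator=gen, device=gen.device).to("cpu")  # sample, then move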
src/compressed_tensors/transform/transform_args.py
@@ -45,6 +45,16 @@ class TransformLocation(str, Enum):
     K_CACHE = "k_cache"
     Q_ATTN = "q_attn"

+    def is_online(self) -> bool:
+        """
+        Returns True if the transform location is online
+        (applied at runtime), False otherwise
+        """
+        return self not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+

 class TransformArgs(BaseModel, use_enum_values=True):
     """
@@ -70,9 +80,6 @@ TransformArgs(BaseModel, use_enum_values=True):
         return value

     def is_online(self) -> bool:
-        return self.location not in (
-            TransformLocation.WEIGHT_INPUT,
-            TransformLocation.WEIGHT_OUTPUT,
-        )
+        return TransformLocation(self.location).is_online()

     model_config = ConfigDict(extra="forbid")
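
Moving the predicate onto the enum lets call sites that only hold a location value (such as apply_transform_weight in matrix.py below) reuse it without constructing a TransformArgs. A quick illustration, assuming TransformLocation is exported from compressed_tensors.transform alongside TransformArgs:

    from compressed_tensors.transform import TransformLocation

    # runtime (online) locations wrap activations; weight locations are fused offline
    assert TransformLocation.INPUT.is_online()
    assert TransformLocation.Q_ATTN.is_online()
    assert not TransformLocation.WEIGHT_INPUT.is_online()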
src/compressed_tensors/transform/utils/matrix.py
@@ -34,6 +34,8 @@ def get_transform_size(
     :param head_dim: size of head when transform is applied to mha
     :return: size of matrix
     """
+    size = None
+
     if isinstance(module, torch.nn.Linear):
         if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
             size = module.in_features
@@ -44,11 +46,13 @@
             size = module.num_embeddings
         else:
             size = module.embedding_dim
-    else:
-        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+    elif head_dim is None:
+        raise NotImplementedError(
+            f"Transforms on {type(module)} are not supported without head_dim"
+        )

     if head_dim is not None:
-        if size % head_dim != 0:
+        if size is not None and size % head_dim != 0:
             raise ValueError(
                 f"{head_dim} must divide {size} for {type(module)} at {location}"
             )
@@ -105,11 +109,11 @@

     assert transform_weight.shape[0] == transform_weight.shape[1]

-    if module_type == torch.nn.Linear:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
+    if TransformLocation(location).is_online():
+        return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.WEIGHT_INPUT:
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.WEIGHT_INPUT:
             # equivalent to (transform_weight @ value.T).T
             return _multihead_matmul(value, transform_weight.T)
@@ -117,26 +121,14 @@
         # equivalent to (value.T @ transform_weight).T
         return _multihead_matmul(transform_weight.T, value)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     # similar derivation to torch.nn.Linear, but `y = (x W)`
     elif module_type == torch.nn.Embedding:
-        if location == TransformLocation.INPUT:
-            return _multihead_matmul(value, transform_weight)
-
-        elif location == TransformLocation.WEIGHT_INPUT:
-            return _multihead_matmul(
-                transform_weight,
-                value,
-            )
+        if location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(transform_weight, value)

         elif location == TransformLocation.WEIGHT_OUTPUT:
             return _multihead_matmul(value, transform_weight)

-        elif location == TransformLocation.OUTPUT:
-            return _multihead_matmul(value, transform_weight)
-
     raise NotImplementedError(
         f"Applying transforms to {module_type} {location} is not supported"
     )
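
The refactor collapses every online location into a single _multihead_matmul(value, transform_weight) fast path; only the offline weight fusions still branch on module type. The inline comment's transpose identity for the WEIGHT_INPUT branch is easy to verify numerically:

    import torch

    T = torch.randn(4, 4)   # square transform weight
    W = torch.randn(3, 4)   # Linear weight (out_features x in_features)

    # comment above: (T @ W.T).T is computed as W @ T.T
    assert torch.allclose((T @ W.T).T, W @ T.T, atol=1e-6)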
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.12.3.a20251023'
+__version__ = version = '0.12.3.a20251030'
 __version_tuple__ = version_tuple = (0, 12, 3)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251023
+Version: 0.12.3a20251030
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/vllm-project/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -75,6 +75,7 @@ src/compressed_tensors/quantization/lifecycle/helpers.py
 src/compressed_tensors/quantization/lifecycle/initialize.py
 src/compressed_tensors/quantization/utils/__init__.py
 src/compressed_tensors/quantization/utils/helpers.py
+src/compressed_tensors/quantization/utils/mxfp4_utils.py
 src/compressed_tensors/registry/__init__.py
 src/compressed_tensors/registry/registry.py
 src/compressed_tensors/transform/__init__.py
@@ -113,6 +114,7 @@ tests/test_compressors/quantized_compressors/test_fp4_quant.py
 tests/test_compressors/quantized_compressors/test_fp8_quant.py
 tests/test_compressors/quantized_compressors/test_int_quant.py
 tests/test_compressors/quantized_compressors/test_pack_quant.py
+tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py
 tests/test_compressors/sparse_compressors/__init__.py
 tests/test_compressors/sparse_compressors/test_bitmask.py
 tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py
@@ -142,6 +144,7 @@ tests/test_quantization/test_configs/__init__.py
 tests/test_quantization/test_configs/test_bit_depths.py
 tests/test_quantization/test_configs/test_strategies.py
 tests/test_quantization/test_utils/test_helpers.py
+tests/test_quantization/test_utils/test_mxfp4_utils.py
 tests/test_transform/conftest.py
 tests/test_transform/test_transform_args.py
 tests/test_transform/test_transform_config.py
tests/test_compressors/quantized_compressors/test_pack_quant.py
@@ -15,6 +15,7 @@

 import math
 import shutil
+import tempfile
 from collections import OrderedDict

 import pytest
@@ -170,12 +171,13 @@ def test_reload_match(tmp_path, num_bits):
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")

-    reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
     reconstructed_dense = {}
-    for name, value in reconstructed_dense_gen:
-        reconstructed_dense[name] = value
+    with tempfile.TemporaryDirectory():
+        reconstructed_dense_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+        for name, value in reconstructed_dense_gen:
+            reconstructed_dense[name] = value

     fake_quant_dummy = fake_quantize(
         dense_state_dict["dummy.weight"],
@@ -473,3 +475,91 @@ def test_unpack_from_int32(num_bits, values, expected_tensor):
     unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
     assert torch.equal(unpacked_tensor, unpacked_tensor)
     assert unpacked_tensor.dtype == unpacked_tensor.dtype
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        (QuantizationStrategy.GROUP, 128),
+        (QuantizationStrategy.CHANNEL, None),
+    ],
+)
+def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path):
+    """
+    Test that zero-point packing and unpacking works correctly for asymmetric
+    quantization with GROUP and CHANNEL strategies.
+    """
+    shape = (512, 1024)
+
+    if strategy == QuantizationStrategy.CHANNEL:
+        expected_zp_shape = (shape[0], 1)
+    elif strategy == QuantizationStrategy.GROUP:
+        num_groups = shape[1] // group_size
+        expected_zp_shape = (shape[0], max(num_groups, 1))
+
+    dense_state_dict = {
+        "dummy.weight": torch.randn(shape),
+        "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32),
+        "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(
+            torch.int8
+        ),
+    }
+
+    quant_config = get_dummy_quant_config(
+        num_bits=4, strategy=strategy.value, symmetric=False, group_size=group_size
+    )
+
+    compressor = PackedQuantizationCompressor(config=quant_config)
+    quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
+    compressed_state_dict = compressor.compress(
+        dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
+    )
+
+    assert "dummy.weight_zero_point" in compressed_state_dict
+    assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32
+
+    save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+    reconstructed_dense_gen = compressor.decompress(
+        tmp_path, names_to_scheme=quantized_modules_to_scheme
+    )
+    reconstructed_dense = {}
+    for name, value in reconstructed_dense_gen:
+        reconstructed_dense[name] = value
+
+    assert "dummy" in reconstructed_dense
+    assert "weight" in reconstructed_dense["dummy"]
+
+    assert reconstructed_dense["dummy"]["weight"].shape == shape
+
+    shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "num_bits,strategy",
+    [
+        (4, QuantizationStrategy.GROUP),
+        (4, QuantizationStrategy.CHANNEL),
+        (8, QuantizationStrategy.GROUP),
+        (8, QuantizationStrategy.CHANNEL),
+    ],
+)
+def test_zero_point_pack_unpack_consistency(num_bits, strategy):
+    """
+    Test that packing and unpacking zero-points preserves values correctly.
+    """
+    if strategy == QuantizationStrategy.GROUP:
+        shape = (512, 8)
+    else:
+        shape = (512, 1)
+
+    max_val = (1 << (num_bits - 1)) - 1
+    min_val = -(1 << (num_bits - 1))
+    original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8)
+
+    packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0)
+
+    unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0)
+
+    assert torch.equal(original_zp, unpacked_zp)
+    assert unpacked_zp.dtype == torch.int8