compressed-tensors 0.12.3a20251008__tar.gz → 0.12.3a20251009__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/PKG-INFO +1 -1
  2. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/forward.py +1 -1
  3. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/initialize.py +9 -2
  4. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/quant_args.py +1 -0
  5. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/quant_scheme.py +1 -0
  6. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/version.py +1 -1
  7. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
  8. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
  9. compressed_tensors-0.12.3a20251009/tests/mock_observer.py +173 -0
  10. compressed_tensors-0.12.3a20251009/tests/test_quantization/lifecycle/test_static_lifecycle.py +388 -0
  11. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/.gitkeep +0 -0
  12. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/actions/test/action.yml +0 -0
  13. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/scripts/step-status +0 -0
  14. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/build-test.yml +0 -0
  15. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/build.yml +0 -0
  16. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/post-release-nightly-build.yml +0 -0
  17. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/quality-check.yaml +0 -0
  18. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/report.yml +0 -0
  19. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/test-check.yaml +0 -0
  20. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/test.yml +0 -0
  21. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/trigger-all.yml +0 -0
  22. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.github/workflows/upload.yml +0 -0
  23. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/.gitignore +0 -0
  24. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/LICENSE +0 -0
  25. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/Makefile +0 -0
  26. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/README.md +0 -0
  27. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  28. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/bit_packing/int4_config.json +0 -0
  29. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/bitmask_compression.ipynb +0 -0
  30. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  31. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  32. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/llama_1.1b/example_quant_config.json +0 -0
  33. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  34. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/examples/quantize_and_pack_int4.ipynb +0 -0
  35. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/pyproject.toml +0 -0
  36. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/setup.cfg +0 -0
  37. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/setup.py +0 -0
  38. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/__init__.py +0 -0
  39. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/README.md +0 -0
  40. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/__init__.py +0 -0
  41. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/base.py +0 -0
  42. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/__init__.py +0 -0
  43. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/base.py +0 -0
  44. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/helpers.py +0 -0
  45. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  46. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  47. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  48. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  49. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  50. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  51. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  52. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  53. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  54. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  55. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  56. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  57. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  59. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/__init__.py +0 -0
  60. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/base.py +0 -0
  61. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/dense.py +0 -0
  62. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/format.py +0 -0
  63. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  64. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  65. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/linear/__init__.py +0 -0
  66. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  67. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/logger.py +0 -0
  68. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/__init__.py +0 -0
  69. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  70. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  71. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  72. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  73. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/quant_config.py +0 -0
  74. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  75. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  76. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  77. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/registry/__init__.py +0 -0
  78. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/registry/registry.py +0 -0
  79. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/__init__.py +0 -0
  80. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/apply.py +0 -0
  81. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  82. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/factory/base.py +0 -0
  83. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  84. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  85. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  86. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/transform_args.py +0 -0
  87. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/transform_config.py +0 -0
  88. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  89. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  90. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  91. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  92. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  93. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/__init__.py +0 -0
  94. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/helpers.py +0 -0
  95. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/internal.py +0 -0
  96. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/match.py +0 -0
  97. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/offload.py +0 -0
  98. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/permutations_24.py +0 -0
  99. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  100. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  101. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors/utils/type.py +0 -0
  102. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  103. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors.egg-info/requires.txt +0 -0
  104. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  105. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/__init__.py +0 -0
  106. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/conftest.py +0 -0
  107. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/__init__.py +0 -0
  108. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/model_compressors/__init__.py +0 -0
  109. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  110. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  111. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  112. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  113. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  114. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  115. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  116. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  117. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  118. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  119. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  120. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_configs/__init__.py +0 -0
  121. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_configs/test_base.py +0 -0
  122. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_configs/test_infer_quant.py +0 -0
  123. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  124. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_linear/__init__.py +0 -0
  125. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_linear/test_compressed_linear.py +0 -0
  126. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/__init__.py +0 -0
  127. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/__init__.py +0 -0
  128. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/conftest.py +0 -0
  129. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  130. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  131. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  132. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  133. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  134. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  135. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_configs/__init__.py +0 -0
  136. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  137. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  138. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_quant_args.py +0 -0
  139. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_quant_config.py +0 -0
  140. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_quant_scheme.py +0 -0
  141. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  142. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_registry.py +0 -0
  143. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/conftest.py +0 -0
  144. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/factory/test_correctness.py +0 -0
  145. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/factory/test_memory.py +0 -0
  146. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/factory/test_serialization.py +0 -0
  147. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/test_transform_args.py +0 -0
  148. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/test_transform_config.py +0 -0
  149. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/test_transform_scheme.py +0 -0
  150. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_transform/utils/test_hadamard.py +0 -0
  151. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/__init__.py +0 -0
  152. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/test_helpers.py +0 -0
  153. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/test_match.py +0 -0
  154. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/test_offload.py +0 -0
  155. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/test_safetensors_load.py +0 -0
  156. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/test_utils/test_type.py +0 -0
  157. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/tests/testing_utils.py +0 -0
  158. {compressed_tensors-0.12.3a20251008 → compressed_tensors-0.12.3a20251009}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251008
+Version: 0.12.3a20251009
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/quantization/lifecycle/forward.py
@@ -330,7 +330,7 @@ def _process_quantization(
         inv_perm = torch.argsort(perm)
         output = output.index_select(-1, inv_perm)
 
-    else:  # covers channel, token and tensor strategies
+    else:  # covers tensor, channel, token, and attn_head strategies
         if do_quantize:
             output = _quantize(
                 x=x,
src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,7 +14,7 @@
 
 
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import (
@@ -152,7 +152,7 @@ def initialize_qparams(
     module: Module,
     base_name: str,
    quantization_args: QuantizationArgs,
-    observed_shape: Tuple[int],
+    observed_shape: Tuple[Union[int, None]],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
@@ -234,6 +234,13 @@ def initialize_qparams(
         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
+    elif strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size, num_attention_heads, seq_len, head_dim)
+        if len(observed_shape) < 3:
+            raise ValueError("Attention quant requires at least 3 observed dimensions")
+
+        expected_shape = (observed_shape[-3], 1, 1)
+
     else:
         assert False, f"Unknown strategy {strategy}"
 
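For orientation, a minimal sketch of what the new ATTN_HEAD branch produces, mirroring the MockAttention setup in the test file added below; the module, head count, and head dim are illustrative values, not part of this diff:

import torch
from compressed_tensors.quantization import QuantizationArgs, initialize_qparams


class MockAttention(torch.nn.Module):
    pass


# 2 attention heads of dim 4; seq_len is unknown ahead of time, so it is
# passed as None -- hence the widened Tuple[Union[int, None]] annotation above
attn = MockAttention()
args = QuantizationArgs(num_bits=4, type="int", symmetric=True, strategy="attn_head")
initialize_qparams(
    attn, "k", args, observed_shape=(2, None, 4), observed_dtype=torch.bfloat16
)

# expected_shape = (observed_shape[-3], 1, 1): one scale per attention head,
# broadcastable against (batch, heads, seq, head_dim) inputs
assert attn.k_scale.shape == (2, 1, 1)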
src/compressed_tensors/quantization/quant_args.py
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):
src/compressed_tensors/quantization/quant_scheme.py
@@ -65,6 +65,7 @@ class QuantizationScheme(BaseModel):
             QuantizationStrategy.TENSOR,
             QuantizationStrategy.GROUP,
             QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.ATTN_HEAD,
         ):
             if (
                 inputs.strategy == QuantizationStrategy.GROUP
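Together with the new enum value in quant_args.py above, this whitelisting means a scheme with per-head input quantization now validates. A minimal sketch; the target pattern is illustrative and not taken from this diff:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["re:.*self_attn$"],  # illustrative target pattern
    input_activations=QuantizationArgs(
        num_bits=4,
        type="int",
        symmetric=True,
        strategy="attn_head",
    ),
)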
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.12.3.a20251008'
+__version__ = version = '0.12.3.a20251009'
 __version_tuple__ = version_tuple = (0, 12, 3)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.12.3a20251008
+Version: 0.12.3a20251009
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -101,6 +101,7 @@ src/compressed_tensors/utils/semi_structured_conversions.py
 src/compressed_tensors/utils/type.py
 tests/__init__.py
 tests/conftest.py
+tests/mock_observer.py
 tests/test_registry.py
 tests/testing_utils.py
 tests/test_compressors/__init__.py
@@ -134,6 +135,7 @@ tests/test_quantization/lifecycle/test_enabled.py
 tests/test_quantization/lifecycle/test_forward.py
 tests/test_quantization/lifecycle/test_initialize.py
 tests/test_quantization/lifecycle/test_lifecycle.py
+tests/test_quantization/lifecycle/test_static_lifecycle.py
 tests/test_quantization/test_configs/__init__.py
 tests/test_quantization/test_configs/test_bit_depths.py
 tests/test_quantization/test_configs/test_strategies.py
tests/mock_observer.py (new file)
@@ -0,0 +1,173 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+from weakref import ref
+
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.utils import (
+    calculate_qparams,
+    generate_gparam,
+    strategy_cdiv,
+)
+
+
+class MockMinMaxObserver(torch.nn.Module):
+    def __init__(self, base_name: str, args: QuantizationArgs, module: torch.nn.Module):
+        super().__init__()
+        self.parent = ref(module)
+        self.base_name = base_name
+        self.args = args
+
+        # used for testing
+        self.min_vals = None
+        self.max_vals = None
+
+    def get_min_max(self, observed: torch.Tensor):
+        min_vals = torch.amin(observed, dim=(0, -1))
+        max_vals = torch.amax(observed, dim=(0, -1))
+
+        return min_vals, max_vals
+
+    def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        observed = flatten_for_quantization(observed, self.base_name, self.args)
+
+        self.min_vals, self.max_vals = self.get_min_max(observed)
+
+        scales, zero_points = calculate_qparams(
+            min_vals=self.min_vals,
+            max_vals=self.max_vals,
+            quantization_args=self.args,
+            global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None),
+        )
+
+        return scales, zero_points
+
+    def get_global_scale(self, observed: torch.Tensor):
+        observed = observed.reshape((1, 1, -1))  # per tensor reshape
+        min_vals, max_vals = self.get_min_max(observed)
+        global_scale = generate_gparam(min_vals, max_vals)
+
+        return global_scale
+
+
+def flatten_for_quantization(
+    value: torch.Tensor, base_name: str, args: QuantizationArgs
+) -> torch.Tensor:
+    if base_name == "weight":
+        return flatten_weight_for_quantization(value, args)
+    elif base_name in ("input", "output"):
+        return flatten_activation_for_quantization(value, args)
+    elif base_name in ("q", "k", "v"):
+        return flatten_attention_for_quantization(value, args)
+    else:
+        raise ValueError(f"Unknown quantization base name: {base_name}")
+
+
+def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (num_rows, num_cols)
+
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (1, 1, num_weight_elems)
+        return value.reshape((1, 1, -1))
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        raise ValueError("Token quantization cannot be applied to weights")
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        # (1, num_rows, 1, num_cols)
+        return value.unsqueeze(-2).unsqueeze(0)
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        # (1, num_rows, num_groups, group_size)
+        return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        # (1, num_block_rows, num_block_cols, block_width * block_height)
+        block_height, block_width = args.block_structure
+        num_rows, num_cols = value.shape
+        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
+        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        return (
+            value.reshape(
+                num_block_rows,
+                block_height,
+                num_block_cols,
+                block_width,
+            )
+            .transpose(1, 2)
+            .flatten(-2, -1)
+            .unsqueeze(0)
+        )
+
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
+    assert False, f"Unknown strategy {args.strategy}"
+
+
+def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (batch_size * seq_len, 1, hidden_dim)
+        return value.reshape((-1, 1, value.size(-1)))
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        # (batch_size, seq_len, hidden_dim)
+        # warning: token quantization uses `compute_dynamic_scales_and_zp`
+        return value.flatten(2, -1)
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        raise ValueError("Channel quantization cannot be applied to activations")
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        # (batch_size * seq_len, num_groups, group_size)
+        # warning: group activation quantization uses compute_dynamic_scales_and_zp
+        return value.flatten(0, 1).unflatten(-1, (-1, args.group_size))
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        raise ValueError("Block quantization cannot be applied to activations")
+
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
+    assert False, f"Unknown strategy {args.strategy}"
+
+
+def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
+    if args.strategy == QuantizationStrategy.TENSOR:
+        # (batch_size * seq_len, 1, num_heads * head_dim)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+
+    if args.strategy == QuantizationStrategy.TOKEN:
+        raise ValueError("Token quantization cannot be applied to attention")
+
+    if args.strategy == QuantizationStrategy.CHANNEL:
+        raise ValueError("Channel quantization cannot be applied to attention")
+
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        raise ValueError("Group quantization cannot be applied to attention")
+
+    if args.strategy == QuantizationStrategy.BLOCK:
+        raise ValueError("Block quantization cannot be applied to attention")
+
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
+
+    assert False, f"Unknown strategy {args.strategy}"
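As a quick sanity check of the attn_head reshape above, a short sketch of the shape bookkeeping; the tensor values are arbitrary, and the shapes follow the comments in the file:

import torch
from compressed_tensors.quantization import QuantizationArgs
from tests.mock_observer import flatten_attention_for_quantization

# (batch_size, num_heads, seq_len, head_dim) = (1, 2, 3, 4)
value = torch.arange(24, dtype=torch.bfloat16).reshape(1, 2, 3, 4)
args = QuantizationArgs(num_bits=4, type="int", symmetric=True, strategy="attn_head")

flat = flatten_attention_for_quantization(value, args)
assert flat.shape == (3, 2, 1, 1, 4)  # (batch * seq, heads, 1, 1, head_dim)

# MockMinMaxObserver.get_min_max reduces over dims (0, -1), so min/max come
# out with shape (2, 1, 1): one value per head, matching the expected_shape
# computed by initialize_qparams for this strategy
mins = torch.amin(flat, dim=(0, -1))
assert mins.shape == (2, 1, 1)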
tests/test_quantization/lifecycle/test_static_lifecycle.py (new file)
@@ -0,0 +1,388 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    forward_quantize,
+    initialize_module_for_quantization,
+    initialize_qparams,
+)
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from tests.mock_observer import MockMinMaxObserver
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([23.0]),
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+        # token is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+            ),
+            torch.tensor([[0], [6], [12], [18]]),
+            torch.tensor([[5], [11], [17], [23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875],
+                    [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+            ),
+            torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]),
+            torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",  # requires float4
+                group_size=3,
+            ),
+            torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]),
+            torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375],
+                    [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875],
+                    [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750],
+                    [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000],
+                ],
+            ),
+            1.1,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+            ),
+            torch.tensor([[0, 3], [12, 15]]),
+            torch.tensor([[8, 11], [20, 23]]),
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.5,
+        ),
+    ],
+)
+def test_static_weight_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    weight = tensor([[ 0,  1,  2,  3,  4,  5],
+                     [ 6,  7,  8,  9, 10, 11],
+                     [12, 13, 14, 15, 16, 17],
+                     [18, 19, 20, 21, 22, 23]])
+    """
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+    linear.weight_observer = MockMinMaxObserver("weight", args, linear)
+
+    # calibrate global scale
+    if hasattr(linear, "weight_global_scale"):
+        global_scale = linear.weight_observer.get_global_scale(linear.weight)
+        linear.weight_global_scale.data = global_scale
+
+    # calibrate quantization parameters
+    scale, zero_point = linear.weight_observer(linear.weight)
+    linear.weight_scale.data = scale
+    linear.weight_zero_point.data = zero_point
+    assert torch.equal(linear.weight_observer.min_vals, exp_min_val)
+    assert torch.equal(linear.weight_observer.max_vals, exp_max_val)
+
+    # forward pass
+    input = torch.eye(input_size, dtype=torch.bfloat16)
+    output = linear(input)
+
+    assert torch.allclose(output.T, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([11.0]),
+            torch.tensor(
+                [
+                    [
+                        [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062],
+                        [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    ]
+                ]
+            ),
+            0.2,
+        ),
+        # static token is not supported
+        # channel is not supported
+        # group is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # must be fp4
+                symmetric=True,
+                strategy="tensor_group",
+                dynamic="local",
+                group_size=3,
+            ),
+            None,
+            None,
+            torch.tensor(
+                [
+                    [
+                        [0.0000, 0.9844, 1.9688, 3.4062, 3.4062, 5.1250],
+                        [5.2500, 7.8750, 7.8750, 7.3438, 11.0000, 11.0000],
+                    ]
+                ]
+            ),
+            0.5,
+        ),
+        # block is not supported
+        # head is not supported
+    ],
+)
+def test_static_activation_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[ 0,  1,  2,  3,  4,  5]
+                    [ 6,  7,  8,  9, 10, 11]])
+    """
+    # set up activation (and identity weight)
+    batch_size, seq_len, input_size = 1, 2, 6
+    input = torch.arange(
+        (batch_size * seq_len * input_size), dtype=torch.bfloat16
+    ).reshape((batch_size, seq_len, input_size))
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+    linear.input_observer = MockMinMaxObserver("input", args, linear)
+
+    # calibrate quantization parameters
+    def calibrate_input_hook(_, args):
+        if hasattr(linear, "input_global_scale"):
+            global_scale = linear.input_observer.get_global_scale(args[0])
+            linear.input_global_scale.data = global_scale
+
+        if linear.quantization_scheme.input_activations.dynamic is False:
+            scale, zero_point = linear.input_observer(args[0])
+            linear.input_scale.data = scale
+            linear.input_zero_point.data = zero_point
+
+    linear.register_forward_pre_hook(calibrate_input_hook)
+
+    # calibration forward pass
+    output = linear(input)
+
+    # check calibration
+    if exp_min_val is not None:
+        assert torch.equal(linear.input_observer.min_vals, exp_min_val)
+    if exp_max_val is not None:
+        assert torch.equal(linear.input_observer.max_vals, exp_max_val)
+
+    # check forward pass
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss
+
+
+class MockAttention(torch.nn.Module):
+    pass
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning")
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([23.0]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 0.0000, 3.0625, 3.0625],
+                            [3.0625, 6.1250, 6.1250, 6.1250],
+                            [9.1875, 9.1875, 9.1875, 12.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
+                    ]
+                ]
+            ),
+            0.81,
+        ),
+        # static token is not supported
+        # channel is not supported
+        # group is not supported
+        # tensor group is not supported
+        # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
+                    ]
+                ]
+            ),
+            0.55,
+        ),
+    ],
+)
+def test_static_attention_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[[[ 0.,  1.,  2.,  3.],
+                      [ 4.,  5.,  6.,  7.],
+                      [ 8.,  9., 10., 11.]],
+
+                     [[12., 13., 14., 15.],
+                      [16., 17., 18., 19.],
+                      [20., 21., 22., 23.]]]])
+    """
+    # set up attention
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
+    input = torch.arange(
+        (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, num_heads, seq_len, head_dim))
+    attention = MockAttention()
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_qparams(
+        attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
+    )
+    attention.quantization_scheme = scheme
+    attention.quantization_status = QuantizationStatus.INITIALIZED
+    attention.k_observer = MockMinMaxObserver("k", args, attention)
+
+    # calibrate quantization parameters
+    if scheme.input_activations.dynamic is False:
+        scale, zero_point = attention.k_observer(input)
+        attention.k_scale.data = scale
+        attention.k_zero_point.data = zero_point
+
+    # calibration forward pass
+    output = forward_quantize(attention, input, "k", scheme.input_activations)
+
+    # check calibration
+    if exp_min_val is not None:
+        assert torch.equal(attention.k_observer.min_vals, exp_min_val)
+    if exp_max_val is not None:
+        assert torch.equal(attention.k_observer.max_vals, exp_max_val)
+
+    # check forward pass
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss
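The new tests depend only on the package itself plus pytest; a minimal way to run just this file from a source checkout (invocation is illustrative, not part of this diff):

import pytest

pytest.main(["tests/test_quantization/lifecycle/test_static_lifecycle.py", "-v"])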