compressed-tensors 0.10.3a20250715__tar.gz → 0.10.3a20250716__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. {compressed_tensors-0.10.3a20250715/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250716}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/factory/base.py +1 -3
  3. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/factory/hadamard.py +17 -8
  4. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/factory/matrix_multiply.py +18 -8
  5. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/transform_scheme.py +2 -1
  6. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/utils/hadamard.py +2 -2
  7. compressed_tensors-0.10.3a20250716/src/compressed_tensors/transform/utils/matrix.py +179 -0
  8. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/version.py +1 -1
  9. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  10. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors.egg-info/SOURCES.txt +1 -1
  11. compressed_tensors-0.10.3a20250716/tests/test_transform/conftest.py +115 -0
  12. compressed_tensors-0.10.3a20250716/tests/test_transform/factory/test_correctness.py +168 -0
  13. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_transform/utils/test_hadamard.py +2 -2
  14. compressed_tensors-0.10.3a20250715/src/compressed_tensors/transform/utils/utils.py +0 -91
  15. compressed_tensors-0.10.3a20250715/tests/test_transform/conftest.py +0 -54
  16. compressed_tensors-0.10.3a20250715/tests/test_transform/factory/test_correctness.py +0 -89
  17. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/.gitkeep +0 -0
  18. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/actions/test/action.yml +0 -0
  19. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/scripts/step-status +0 -0
  20. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/build-test.yml +0 -0
  21. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/build.yml +0 -0
  22. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/report.yml +0 -0
  23. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/test-check.yaml +0 -0
  24. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/test.yml +0 -0
  25. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/trigger-all.yml +0 -0
  26. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.github/workflows/upload.yml +0 -0
  27. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/.gitignore +0 -0
  28. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/LICENSE +0 -0
  29. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/Makefile +0 -0
  30. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/README.md +0 -0
  31. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  32. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/bit_packing/int4_config.json +0 -0
  33. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/bitmask_compression.ipynb +0 -0
  34. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  35. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  36. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/llama_1.1b/example_quant_config.json +0 -0
  37. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  38. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/examples/quantize_and_pack_int4.ipynb +0 -0
  39. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/pyproject.toml +0 -0
  40. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/setup.cfg +0 -0
  41. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/setup.py +0 -0
  42. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/__init__.py +0 -0
  43. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/README.md +0 -0
  44. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/__init__.py +0 -0
  45. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/base.py +0 -0
  46. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/__init__.py +0 -0
  47. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/base.py +0 -0
  48. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/helpers.py +0 -0
  49. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  50. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  51. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  52. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  53. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  54. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  55. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  56. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  57. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  58. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  59. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  60. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  61. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  62. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  63. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/config/__init__.py +0 -0
  64. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/config/base.py +0 -0
  65. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/config/dense.py +0 -0
  66. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  67. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  68. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/linear/__init__.py +0 -0
  69. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  70. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/__init__.py +0 -0
  71. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  72. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  73. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  74. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  75. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  76. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  77. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/quant_args.py +0 -0
  78. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/quant_config.py +0 -0
  79. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  80. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  81. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  82. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/registry/__init__.py +0 -0
  83. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/registry/registry.py +0 -0
  84. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/__init__.py +0 -0
  85. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/apply.py +0 -0
  86. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  87. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  88. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/transform_args.py +0 -0
  89. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/transform_config.py +0 -0
  90. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  91. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  92. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/__init__.py +0 -0
  93. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/helpers.py +0 -0
  94. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/internal.py +0 -0
  95. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/offload.py +0 -0
  96. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/permutations_24.py +0 -0
  97. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/permute.py +0 -0
  98. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  99. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  100. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  101. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors.egg-info/requires.txt +0 -0
  102. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  103. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/__init__.py +0 -0
  104. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/conftest.py +0 -0
  105. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/__init__.py +0 -0
  106. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/model_compressors/__init__.py +0 -0
  107. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  108. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  109. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  110. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  111. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  112. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  113. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  114. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  115. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  116. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  118. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_configs/__init__.py +0 -0
  119. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_configs/test_base.py +0 -0
  120. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  121. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_linear/__init__.py +0 -0
  122. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_linear/test_compressed_linear.py +0 -0
  123. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/__init__.py +0 -0
  124. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/__init__.py +0 -0
  125. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/conftest.py +0 -0
  126. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  127. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  128. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  129. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  130. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  131. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  132. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  133. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_configs/__init__.py +0 -0
  134. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  135. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  136. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_quant_args.py +0 -0
  137. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_quant_config.py +0 -0
  138. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_quant_scheme.py +0 -0
  139. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  140. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_registry.py +0 -0
  141. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_transform/factory/test_memory.py +0 -0
  142. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_transform/test_transform_args.py +0 -0
  143. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_transform/test_transform_config.py +0 -0
  144. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_transform/test_transform_scheme.py +0 -0
  145. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_utils/__init__.py +0 -0
  146. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_utils/test_helpers.py +0 -0
  147. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_utils/test_offload.py +0 -0
  148. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/test_utils/test_safetensors_load.py +0 -0
  149. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/tests/testing_utils.py +0 -0
  150. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250716}/utils/copyright.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.3a20250715
3
+ Version: 0.10.3a20250716
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -117,10 +117,8 @@ class TransformFactory(RegistryMixin, ABC):
117
117
  TransformLocation.WEIGHT_INPUT,
118
118
  TransformLocation.WEIGHT_OUTPUT,
119
119
  ):
120
- assert isinstance(module, torch.nn.Linear)
121
- assert module.bias is None
122
-
123
120
  # fuse transform into weight
121
+ assert hasattr(module, "weight")
124
122
  with torch.no_grad(), align_module_device(module):
125
123
  update_offload_parameter(module, "weight", transform(module.weight))
126
124
 
@@ -14,13 +14,14 @@
14
14
 
15
15
  from typing import Optional, Union
16
16
 
17
+ import math
17
18
  import torch
18
19
  from compressed_tensors.transform import TransformArgs, TransformScheme
19
20
  from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
20
21
  from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
21
- from compressed_tensors.transform.utils.utils import (
22
+ from compressed_tensors.transform.utils.matrix import (
22
23
  apply_transform_weight,
23
- get_matrix_size,
24
+ get_transform_size,
24
25
  )
25
26
  from compressed_tensors.utils import get_execution_device, get_offloaded_device
26
27
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -51,8 +52,8 @@ class HadamardFactory(TransformFactory):
51
52
  :param module: parent module that transform will be applied to
52
53
  :param args: defines how the transform will be applied to the module
53
54
  """
54
- assert isinstance(module, Linear)
55
- size = get_matrix_size(module, args.location)
55
+ assert hasattr(module, "weight")
56
+ size = get_transform_size(module, args.location, self.scheme.head_dim)
56
57
  dtype = module.weight.dtype
57
58
  device = get_offloaded_device(module)
58
59
  exec_device = get_execution_device(module)
@@ -60,7 +61,7 @@ class HadamardFactory(TransformFactory):
60
61
  factory_kwargs = {"construct_device": exec_device}
61
62
  weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
62
63
  perm = self.perms[weight] if self.scheme.randomize else None
63
- return HadamardTransform(weight, perm, args)
64
+ return HadamardTransform(weight, perm, args, type(module))
64
65
 
65
66
  def _create_weight(
66
67
  self,
@@ -81,12 +82,18 @@ class HadamardFactory(TransformFactory):
81
82
 
82
83
  class HadamardTransform(TransformBase):
83
84
  def __init__(
84
- self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
85
+ self,
86
+ weight: Parameter,
87
+ perm: Optional[Parameter],
88
+ args: TransformArgs,
89
+ module_type: type[torch.nn.Module],
85
90
  ):
86
91
  super().__init__()
87
92
  self.weight = weight
88
93
  self.perm = perm
89
94
  self.args = args
95
+ self.module_type = module_type
96
+ self._scale = math.sqrt(weight.size(0))
90
97
 
91
98
  def forward(self, value: Tensor) -> Tensor:
92
99
  weight = self.weight
@@ -96,5 +103,7 @@ class HadamardTransform(TransformBase):
96
103
 
97
104
  if self.args.inverse:
98
105
  weight = weight.T
99
-
100
- return apply_transform_weight(weight, value, self.args.location)
106
+
107
+ return apply_transform_weight(
108
+ weight, value, self.args.location, self.module_type
109
+ ) / self._scale
@@ -17,9 +17,9 @@ from typing import Optional
17
17
  import torch
18
18
  from compressed_tensors.transform import TransformArgs, TransformScheme
19
19
  from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
20
- from compressed_tensors.transform.utils.utils import (
20
+ from compressed_tensors.transform.utils.matrix import (
21
21
  apply_transform_weight,
22
- get_matrix_size,
22
+ get_transform_size,
23
23
  )
24
24
  from compressed_tensors.utils import get_offloaded_device
25
25
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -50,8 +50,8 @@ class RandomMatrixFactory(TransformFactory):
50
50
  :param module: parent module that transform will be applied to
51
51
  :param args: defines how the transform will be applied to the module
52
52
  """
53
- assert isinstance(module, Linear)
54
- size = get_matrix_size(module, args.location)
53
+ assert hasattr(module, "weight")
54
+ size = get_transform_size(module, args.location, self.scheme.head_dim)
55
55
  dtype = module.weight.dtype
56
56
  device = get_offloaded_device(module)
57
57
 
@@ -59,7 +59,7 @@ class RandomMatrixFactory(TransformFactory):
59
59
  if args.inverse:
60
60
  weight = self.inverses[weight]
61
61
 
62
- return RandomMatrixTransform(weight, args)
62
+ return RandomMatrixTransform(weight, args, type(module))
63
63
 
64
64
  def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65
65
  # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@ class RandomMatrixFactory(TransformFactory):
74
74
 
75
75
 
76
76
  class RandomMatrixTransform(TransformBase):
77
- def __init__(self, weight: Tensor, args: TransformArgs):
77
+ def __init__(
78
+ self,
79
+ weight: Tensor,
80
+ args: TransformArgs,
81
+ module_type: type[torch.nn.Module],
82
+ ):
78
83
  super().__init__()
79
84
  self.weight = weight # is an inverse if args.inverse
80
85
  self.args = args
86
+ self.module_type = module_type
81
87
 
82
88
  def forward(self, value: Tensor) -> Parameter:
83
- return apply_transform_weight(self.weight, value, self.args.location)
89
+ return apply_transform_weight(
90
+ self.weight, value, self.args.location, self.module_type
91
+ )
84
92
 
85
93
  def right_inverse(self, value: Tensor) -> Tensor:
86
94
  inverse = high_precision_invert(self.weight)
87
- return apply_transform_weight(inverse, value, self.args.location)
95
+ return apply_transform_weight(
96
+ inverse, value, self.args.location, self.module_type
97
+ )
88
98
 
89
99
 
90
100
  def high_precision_invert(weight: Tensor) -> Tensor:
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import List
15
+ from typing import List, Optional
16
16
 
17
17
  from compressed_tensors.transform import TransformArgs
18
18
  from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
40
40
  apply: List[TransformArgs] = Field(default_factory=list)
41
41
  randomize: bool = Field(default=False)
42
42
  requires_grad: bool = Field(default=False)
43
+ head_dim: Optional[int] = Field(default=None)
@@ -59,7 +59,7 @@ def deterministic_hadamard_matrix(
59
59
  for _ in range(log2):
60
60
  H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
61
61
 
62
- return H / math.sqrt(size)
62
+ return H
63
63
 
64
64
 
65
65
  def random_hadamard_matrix(
@@ -86,7 +86,7 @@ def random_hadamard_matrix(
86
86
  Q = Q.to(device=device)
87
87
  Q = Q * 2 - 1
88
88
  Q = torch.diag(Q)
89
- return _matmul_hadU(Q) / math.sqrt(size)
89
+ return _matmul_hadU(Q)
90
90
 
91
91
 
92
92
  def is_pow2(n: int) -> bool:
@@ -0,0 +1,179 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Callable, Optional, Tuple
16
+
17
+ import torch
18
+ from compressed_tensors.transform import TransformLocation
19
+
20
+
21
+ __all__ = ["get_transform_size", "apply_transform_weight"]
22
+
23
+
24
+ def get_transform_size(
25
+ module: torch.nn.Module,
26
+ location: TransformLocation,
27
+ head_dim: Optional[int] = None,
28
+ ) -> int:
29
+ """
30
+ Determine the size of a transform matrix given its location on the module
31
+
32
+ :param module: module that matrix will be applied to
33
+ :param location: location on module
34
+ :param head_dim: size of head when transform is applied to mha
35
+ :return: size of matrix
36
+ """
37
+ if isinstance(module, torch.nn.Linear):
38
+ if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
39
+ size = module.in_features
40
+ else:
41
+ size = module.out_features
42
+ elif isinstance(module, torch.nn.Embedding):
43
+ if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
44
+ size = module.num_embeddings
45
+ else:
46
+ size = module.embedding_dim
47
+ else:
48
+ raise NotImplementedError(f"Transforms on {type(module)} are not supported")
49
+
50
+ if head_dim is not None:
51
+ if size % head_dim != 0:
52
+ raise ValueError(
53
+ f"{head_dim} must divide {size} for {type(module)} at {location}"
54
+ )
55
+
56
+ size = head_dim
57
+
58
+ return size
59
+
60
+
61
+ def apply_transform_weight(
62
+ transform_weight: torch.Tensor,
63
+ value: torch.Tensor,
64
+ location: TransformLocation,
65
+ module_type: type[torch.nn.Module],
66
+ ) -> torch.Tensor:
67
+ """
68
+ Using the transform location, apply the transform_weight to the
69
+ given value wrt linear weights. For more info on input and output transforms,
70
+ see `TransformLocation`
71
+
72
+ The following explains how weights should be applied to values according to location
73
+
74
+ let x be input activation
75
+ W be weight,
76
+ yh, xh, Wh be transformed output, input, weight
77
+
78
+ note that
79
+ y = (x W.T) // torch.nn.Linear
80
+
81
+ Choose values for yh, xh, and Wh which incorporate matrix transforms
82
+
83
+ let V, Vi be transform matrices on input side
84
+ U, Ui be transform matrices on output side
85
+
86
+ pick xh = (x V)
87
+ Wh = (U.T W Vi.T)
88
+ yh = (y U)
89
+
90
+ The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
91
+
92
+ (xh) (Wh).T = (x V) (U.T W Vi.T).T
93
+ = (x V) (Vi W.T U) // transpose matrix product identity
94
+ = (x W.T) U
95
+ = y U
96
+ = yh
97
+
98
+ :param transform_weight: transform weight to apply
99
+ :param value: value to apply transform_weight to
100
+ :param location: determines how weight should be applied
101
+ :param model_type: result of type(module), passed in to determine application of
102
+ weight transform
103
+ :return: value after transform_weight has been applied
104
+ """
105
+
106
+ assert transform_weight.shape[0] == transform_weight.shape[1]
107
+
108
+ if module_type == torch.nn.Linear:
109
+ if location == TransformLocation.INPUT:
110
+ return _multihead_matmul(value, transform_weight)
111
+
112
+ elif location == TransformLocation.WEIGHT_INPUT:
113
+ # equivalent to (transform_weight @ value.T).T
114
+ return _multihead_matmul(value, transform_weight.T)
115
+
116
+ elif location == TransformLocation.WEIGHT_OUTPUT:
117
+ # equivalent to (value.T @ transform_weight).T
118
+ return _multihead_matmul(transform_weight.T, value)
119
+
120
+ elif location == TransformLocation.OUTPUT:
121
+ return _multihead_matmul(value, transform_weight)
122
+
123
+ # similar derivation to torch.nn.Linear, but `y = (x W)`
124
+ elif module_type == torch.nn.Embedding:
125
+ if location == TransformLocation.INPUT:
126
+ return _multihead_matmul(value, transform_weight)
127
+
128
+ elif location == TransformLocation.WEIGHT_INPUT:
129
+ return _multihead_matmul(
130
+ transform_weight,
131
+ value,
132
+ )
133
+
134
+ elif location == TransformLocation.WEIGHT_OUTPUT:
135
+ return _multihead_matmul(value, transform_weight)
136
+
137
+ elif location == TransformLocation.OUTPUT:
138
+ return _multihead_matmul(value, transform_weight)
139
+
140
+ raise NotImplementedError(
141
+ f"Applying transforms to {module_type} {location} is not supported"
142
+ )
143
+
144
+
145
+ def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
146
+ """
147
+ Performs A @ B for last two dims of two matrices A and B that possibly
148
+ have different shapes, as is the case in multi-headed dimension. If
149
+ shapes are different, this is equivalent to converting the last two dims
150
+ of the smaller matrix into a block-diagonal matrix with the same shape as
151
+ the last two dims of the larger matrix.
152
+
153
+ E.g. if A is half the size of B, this function will perform
154
+ [[A ] @ B
155
+ [ A]]
156
+
157
+ If B is a third of the size of A, this function will perform
158
+ A @ [[B ]
159
+ [ B ]
160
+ [ B]]
161
+
162
+ This function will error out if the shapes are not evenly divisble
163
+
164
+ :param A: left-hand tensor
165
+ :param B: right-hand tensor
166
+ :return: result
167
+ """
168
+ if A.shape[-1] > B.shape[-2]:
169
+ head_dim = B.shape[-2]
170
+ num_heads = A.shape[-1] // head_dim
171
+ A = A.unflatten(-1, (num_heads, head_dim))
172
+ return (A @ B).flatten(-2, -1)
173
+ elif A.shape[-1] < B.shape[-2]:
174
+ head_dim = A.shape[-1]
175
+ num_heads = B.shape[-2] // head_dim
176
+ B = B.unflatten(-2, (num_heads, head_dim))
177
+ return (A @ B).flatten(-3, -2)
178
+ else:
179
+ return A @ B
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.10.3.a20250715'
20
+ __version__ = version = '0.10.3.a20250716'
21
21
  __version_tuple__ = version_tuple = (0, 10, 3)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.3a20250715
3
+ Version: 0.10.3a20250716
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -84,7 +84,7 @@ src/compressed_tensors/transform/factory/random_hadamard.py
84
84
  src/compressed_tensors/transform/utils/__init__.py
85
85
  src/compressed_tensors/transform/utils/hadamard.py
86
86
  src/compressed_tensors/transform/utils/hadamards.safetensors
87
- src/compressed_tensors/transform/utils/utils.py
87
+ src/compressed_tensors/transform/utils/matrix.py
88
88
  src/compressed_tensors/utils/__init__.py
89
89
  src/compressed_tensors/utils/helpers.py
90
90
  src/compressed_tensors/utils/internal.py
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pytest
16
+ import torch
17
+ from compressed_tensors.transform import TransformArgs
18
+
19
+
20
class TransformableModel(torch.nn.Module):
    """Bias-free MLP used as a simple target for transform tests.

    Given sizes (s0, s1, ..., sn), builds n linear layers mapping
    s0 -> s1 -> ... -> sn.
    """

    def __init__(self, *sizes):
        super().__init__()
        self.fcs = torch.nn.ModuleList(
            [
                torch.nn.Linear(in_size, out_size, bias=False)
                for in_size, out_size in zip(sizes, sizes[1:])
            ]
        )

    def forward(self, x):
        # feed the activation through each linear layer in turn
        for fc in self.fcs:
            x = fc(x)
        return x
+ return x
34
+
35
+
36
class MockAttention(torch.nn.Module):
    """Minimal multi-head attention (with grouped KV heads) for transform tests.

    Mirrors the projection layout of transformer attention: separate q/k/v
    projections, KV-head repetition for grouped-query attention, scaled
    dot-product weights, and an output projection. No masking or dropout.
    """

    def __init__(
        self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # how many times each KV head is shared across attention heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling = self.head_dim**-0.5
        assert hidden_size >= num_attention_heads * self.head_dim

        self.q_proj = torch.nn.Linear(
            hidden_size, num_attention_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            hidden_size, num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = torch.nn.Linear(
            hidden_size, num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = torch.nn.Linear(
            num_attention_heads * self.head_dim, hidden_size, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        # (batch, seq, heads, head_dim) before moving heads ahead of seq
        split_shape = (batch_size, seq_len, -1, self.head_dim)

        queries = self.q_proj(hidden_states).view(split_shape).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(split_shape).transpose(1, 2)
        values = self.v_proj(hidden_states).view(split_shape).transpose(1, 2)

        # expand grouped KV heads to match the number of attention heads
        keys = self.repeat_kv(keys, self.num_key_value_groups)
        values = self.repeat_kv(values, self.num_key_value_groups)

        scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scaling

        # softmax in float32 for stability, then cast back to activation dtype
        probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(
            queries.dtype
        )
        context = torch.matmul(probs, values)
        context = context.transpose(1, 2).contiguous()

        context = context.reshape((batch_size, seq_len, -1)).contiguous()

        return self.o_proj(context)

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # replicate each KV head n_rep times along the head dimension
        if n_rep == 1:
            return hidden_states
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        expanded = hidden_states[:, :, None, :, :].expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
96
+
97
+ @pytest.fixture(scope="function")
98
+ def model_apply():
99
+ model = TransformableModel(2, 4, 8, 16, 32, 64)
100
+ apply = [
101
+ # weight output -> input
102
+ TransformArgs(targets="fcs.0", location="weight_output"),
103
+ TransformArgs(targets="fcs.1", location="input", inverse=True),
104
+ # output -> weight input
105
+ TransformArgs(targets="fcs.1", location="output"),
106
+ TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
107
+ # output -> input
108
+ TransformArgs(targets="fcs.2", location="output"),
109
+ TransformArgs(targets="fcs.3", location="input", inverse=True),
110
+ # weight output -> weight input
111
+ TransformArgs(targets="fcs.3", location="weight_output"),
112
+ TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
113
+ ]
114
+
115
+ return model, apply
@@ -0,0 +1,168 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pytest
16
+ import torch
17
+ from compressed_tensors.transform import (
18
+ TransformArgs,
19
+ TransformConfig,
20
+ TransformFactory,
21
+ TransformScheme,
22
+ apply_transform_config,
23
+ )
24
+ from compressed_tensors.utils import offloaded_dispatch
25
+ from tests.test_transform.conftest import MockAttention
26
+ from tests.testing_utils import requires_accelerate, requires_gpu
27
+
28
+
29
+ @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
30
+ @pytest.mark.parametrize("randomized", (True, False))
31
+ @pytest.mark.parametrize("head_dim", (None, 2, 4))
32
+ @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
33
+ def test_correctness_linear(type, randomized, head_dim, input_batch_size):
34
+ size = (4, 8)
35
+ module = torch.nn.Linear(*size, bias=False)
36
+ scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
37
+ factory = TransformFactory.from_scheme(scheme, name="")
38
+
39
+ input_tfm = factory.create_transform(
40
+ module, TransformArgs(targets="Linear", location="input", inverse=True)
41
+ )
42
+ w_in_tfm = factory.create_transform(
43
+ module, TransformArgs(targets="Linear", location="weight_input")
44
+ )
45
+ w_out_tfm = factory.create_transform(
46
+ module, TransformArgs(targets="Linear", location="weight_output")
47
+ )
48
+ output_tfm = factory.create_transform(
49
+ module, TransformArgs(targets="Linear", location="output", inverse=True)
50
+ )
51
+
52
+ input = torch.rand((input_batch_size, 5, size[0]))
53
+ true_output = input @ module.weight.T
54
+ input_transformed = input_tfm(input)
55
+ weight_transformed = w_out_tfm(w_in_tfm(module.weight))
56
+ output = output_tfm(input_transformed @ weight_transformed.T)
57
+ assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
58
+
59
+
60
+ @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
61
+ @pytest.mark.parametrize("randomized", (True, False))
62
+ @pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
63
+ @pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
64
+ def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
65
+ model = torch.nn.Sequential(
66
+ torch.nn.Embedding(2, 4),
67
+ torch.nn.Linear(4, 8, bias=False),
68
+ )
69
+
70
+ input = torch.randint(high=1, low=0, size=(17, 5, 2))
71
+ true_output = model(input)
72
+
73
+ config = TransformConfig(
74
+ config_groups={
75
+ "": TransformScheme(
76
+ type=type,
77
+ randomized=randomized,
78
+ apply=[
79
+ TransformArgs(targets="Embedding", location=embed_loc),
80
+ TransformArgs(targets="Linear", location=linear_loc, inverse=True),
81
+ ],
82
+ )
83
+ }
84
+ )
85
+ apply_transform_config(model, config)
86
+
87
+ # compare outputs
88
+ output = model(input)
89
+ assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
90
+
91
+
92
+ @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
93
+ @pytest.mark.parametrize("randomized", (True, False))
94
+ @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
95
+ def test_correctness_model(
96
+ type, randomized, input_batch_size, model_apply, offload=False
97
+ ):
98
+ # load model
99
+ model = model_apply[0]
100
+ if offload:
101
+ model = offloaded_dispatch(model, torch.device("cuda"))
102
+
103
+ # get output
104
+ input = torch.rand((input_batch_size, 5, model.fcs[0].in_features))
105
+ if offload:
106
+ input = input.to(torch.device("cuda"))
107
+ true_output = model(input)
108
+
109
+ # apply transforms
110
+ config = TransformConfig(
111
+ config_groups={
112
+ "": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
113
+ }
114
+ )
115
+ apply_transform_config(model, config)
116
+
117
+ # compare outputs
118
+ output = model(input)
119
+ assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
120
+
121
+
122
@requires_gpu
@requires_accelerate()
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@pytest.mark.parametrize("randomized", (True, False))
@pytest.mark.parametrize("input_batch_size", (1, 5, 17))
def test_correctness_model_offload(type, randomized, input_batch_size, model_apply):
    """Re-run the model correctness check with the model offloaded to GPU."""
    test_correctness_model(
        type, randomized, input_batch_size, model_apply, offload=True
    )
131
+
132
+
133
+ @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
134
+ @pytest.mark.parametrize("randomized", (True, False))
135
+ @pytest.mark.parametrize("head_dim", (4, 8))
136
+ @pytest.mark.parametrize("input_batch_size", (1, 5, 17))
137
+ def test_correctness_attention_heads(type, randomized, head_dim, input_batch_size):
138
+ hidden_size = 64
139
+ num_attention_heads = 8
140
+
141
+ attention = MockAttention(
142
+ hidden_size=hidden_size,
143
+ num_attention_heads=num_attention_heads,
144
+ num_key_value_heads=head_dim,
145
+ )
146
+
147
+ input = torch.rand(input_batch_size, 5, hidden_size)
148
+ true_output = attention(input)
149
+
150
+ config = TransformConfig(
151
+ config_groups={
152
+ "": TransformScheme(
153
+ type=type,
154
+ randomized=randomized,
155
+ head_dim=head_dim,
156
+ apply=[
157
+ TransformArgs(targets="v_proj", location="weight_output"),
158
+ TransformArgs(
159
+ targets="o_proj", location="weight_input", inverse=True
160
+ ),
161
+ ],
162
+ )
163
+ }
164
+ )
165
+ apply_transform_config(attention, config)
166
+
167
+ output = attention(input)
168
+ assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)