compressed-tensors 0.10.2a20250612__tar.gz → 0.10.2a20250613__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. {compressed_tensors-0.10.2a20250612/src/compressed_tensors.egg-info → compressed_tensors-0.10.2a20250613}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/setup.py +1 -0
  3. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/factory/hadamard.py +1 -1
  4. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/factory/random_hadamard.py +1 -1
  5. compressed_tensors-0.10.2a20250613/src/compressed_tensors/transform/utils/hadamard.py +160 -0
  6. compressed_tensors-0.10.2a20250613/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  7. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/offload.py +39 -5
  8. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/version.py +1 -1
  9. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  10. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors.egg-info/SOURCES.txt +1 -0
  11. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/utils/test_hadamard.py +38 -32
  12. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_utils/test_offload.py +56 -8
  13. compressed_tensors-0.10.2a20250612/src/compressed_tensors/transform/utils/hadamard.py +0 -161
  14. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/.gitkeep +0 -0
  15. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/actions/test/action.yml +0 -0
  16. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/scripts/step-status +0 -0
  17. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/build-test.yml +0 -0
  18. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/build.yml +0 -0
  19. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/report.yml +0 -0
  20. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/test-check.yaml +0 -0
  21. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/test.yml +0 -0
  22. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/trigger-all.yml +0 -0
  23. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.github/workflows/upload.yml +0 -0
  24. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/.gitignore +0 -0
  25. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/LICENSE +0 -0
  26. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/Makefile +0 -0
  27. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/README.md +0 -0
  28. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  29. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/bit_packing/int4_config.json +0 -0
  30. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/bitmask_compression.ipynb +0 -0
  31. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  32. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  33. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/llama_1.1b/example_quant_config.json +0 -0
  34. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  35. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/examples/quantize_and_pack_int4.ipynb +0 -0
  36. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/pyproject.toml +0 -0
  37. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/setup.cfg +0 -0
  38. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/__init__.py +0 -0
  39. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/README.md +0 -0
  40. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/__init__.py +0 -0
  41. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/base.py +0 -0
  42. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/__init__.py +0 -0
  43. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/base.py +0 -0
  44. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/helpers.py +0 -0
  45. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  46. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  47. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  48. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  49. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  50. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  51. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  52. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  53. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  54. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  55. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  56. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  57. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  59. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/config/__init__.py +0 -0
  60. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/config/base.py +0 -0
  61. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/config/dense.py +0 -0
  62. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  63. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  64. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/linear/__init__.py +0 -0
  65. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  66. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/__init__.py +0 -0
  67. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  68. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  69. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  70. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  71. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  72. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  73. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/quant_args.py +0 -0
  74. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/quant_config.py +0 -0
  75. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  76. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  77. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  78. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/registry/__init__.py +0 -0
  79. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/registry/registry.py +0 -0
  80. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/__init__.py +0 -0
  81. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  82. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/factory/base.py +0 -0
  83. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  84. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/transform_args.py +0 -0
  85. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/transform_config.py +0 -0
  86. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  87. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  88. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/transform/utils/utils.py +0 -0
  89. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/__init__.py +0 -0
  90. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/helpers.py +0 -0
  91. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/permutations_24.py +0 -0
  92. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/permute.py +0 -0
  93. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  94. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  95. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  96. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors.egg-info/requires.txt +0 -0
  97. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  98. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/__init__.py +0 -0
  99. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/conftest.py +0 -0
  100. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/__init__.py +0 -0
  101. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/model_compressors/__init__.py +0 -0
  102. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  103. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  104. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  105. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  106. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  107. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  108. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  109. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  110. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  111. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  112. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  113. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_configs/__init__.py +0 -0
  114. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_configs/test_base.py +0 -0
  115. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  116. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_linear/__init__.py +0 -0
  117. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_linear/test_compressed_linear.py +0 -0
  118. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/__init__.py +0 -0
  119. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/__init__.py +0 -0
  120. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/conftest.py +0 -0
  121. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  122. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  123. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  124. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  125. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  126. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  127. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  128. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_configs/__init__.py +0 -0
  129. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  130. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  131. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_quant_args.py +0 -0
  132. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_quant_config.py +0 -0
  133. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_quant_scheme.py +0 -0
  134. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  135. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_registry.py +0 -0
  136. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/factory/test_correctness.py +0 -0
  137. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/factory/test_memory.py +0 -0
  138. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/test_transform_args.py +0 -0
  139. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/test_transform_config.py +0 -0
  140. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_transform/test_transform_scheme.py +0 -0
  141. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_utils/__init__.py +0 -0
  142. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_utils/test_helpers.py +0 -0
  143. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/test_utils/test_safetensors_load.py +0 -0
  144. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/tests/testing_utils.py +0 -0
  145. {compressed_tensors-0.10.2a20250612 → compressed_tensors-0.10.2a20250613}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.2a20250612
+Version: 0.10.2a20250613
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
setup.py
@@ -113,5 +113,6 @@ setup(
     extras_require=_setup_extras(),
     install_requires=_setup_install_requires(),
     package_dir={"": "src"},
+    package_data={"": ["transform/utils/hadamards.safetensors"]},
     packages=_setup_packages(),
 )
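
The new `package_data` entry is what actually ships `hadamards.safetensors` inside built wheels and sdists; the module then resolves the file relative to its own installed location. A minimal sketch of that lookup pattern (the `REPO_PATH` constant is taken from the new `hadamard.py` below; the assert is illustrative only):

    from pathlib import Path

    # resolved relative to the installed module, so the file must be
    # packaged via package_data or it will be missing at runtime
    REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
    assert REPO_PATH.exists(), "hadamards.safetensors was not packaged"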
src/compressed_tensors/transform/factory/hadamard.py
@@ -59,7 +59,7 @@ class HadamardFactory(TransformFactory):
         return HadamardTransform(weight, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size)
+        data = deterministic_hadamard_matrix(size, dtype, device)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
src/compressed_tensors/transform/factory/random_hadamard.py
@@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
     """
 
    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, self.generator)
+        data = random_hadamard_matrix(size, dtype, device, self.generator)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
compressed_tensors-0.10.2a20250613/src/compressed_tensors/transform/utils/hadamard.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from pathlib import Path
+from typing import Optional
+
+import torch
+from safetensors import safe_open
+
+
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
+
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """
+    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
+    `n` must be a power of 2.
+
+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
+    :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :return: hadamard matrix of size `size`
+    """
+    if size <= 0:
+        raise ValueError("Cannot construct deterministic hadamard of size <= 0")
+
+    log2 = int(math.log2(size))
+    if size != 2**log2:
+        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+
+    H = torch.tensor([[1]], dtype=dtype, device=device)
+
+    # Sylvester's construction
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
+
+    return H / math.sqrt(size)
+
+
+def random_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+    """
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non-powers of 2
+    and randomization using a seeded generator
+
+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501
+
+    :param size: the dimension of the hadamard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
+    :param gen: optional generator for random values
+    :return: randomly generated hadamard matrix
+    """
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
+    Q = Q * 2 - 1
+    Q = torch.diag(Q)
+    return _matmul_hadU(Q) / math.sqrt(size)
+
+
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
+
+
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
+
+    # Check if we have the determined hadamard matrix
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
+    # Reshape diag matrix with randomized -1/+1
+    input = X.clone().view(-1, size, 1)
+    output = input.clone()
+    while input.shape[1] > K:
+        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+        output = output.view(input.shape)
+        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+        output = output.view(input.shape[0], input.shape[1], -1)
+        (input, output) = (output, input)
+    assert input.shape[1] == K
+    del output
+
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input
+
+    # normalize
+    return input.view(X.shape)
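
A short usage sketch of the new module, mirroring the orthonormality check in the updated tests further down (float32 and the tolerance here are illustrative choices, not values from the package):

    import torch
    from compressed_tensors.transform.utils.hadamard import (
        deterministic_hadamard_matrix,
        random_hadamard_matrix,
    )

    # power-of-2 size: built via Sylvester's construction
    H = deterministic_hadamard_matrix(1024, torch.float32, torch.device("cpu"))

    # non-power-of-2 size: 768 = 12 * 2^6, so a known Hadamard matrix from
    # hadamards.safetensors is expanded up to 768x768
    R = random_hadamard_matrix(768, torch.float32, torch.device("cpu"))

    # both are scaled by 1/sqrt(n), so H @ H.T should approximate identity
    assert torch.allclose(H @ H.T, torch.eye(1024), atol=1e-3)
    assert torch.allclose(R @ R.T, torch.eye(768), atol=1e-3)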
src/compressed_tensors/utils/offload.py
@@ -31,9 +31,10 @@ import contextlib
 import warnings
 from functools import wraps
 from operator import attrgetter
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
 
 import torch
+from compressed_tensors.utils import patch_attr
 
 
 try:
@@ -83,6 +84,7 @@ __all__ = [
     "register_offload_module",
     "delete_offload_module",
     "offloaded_dispatch",
+    "disable_offloading",
 ]
 
 
@@ -214,7 +216,7 @@ def register_offload_parameter(
 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -227,7 +229,7 @@ def update_offload_parameter(
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
@@ -235,7 +237,7 @@ def update_offload_parameter(
         )
 
     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)
 
     # update offload dict
@@ -501,7 +503,9 @@ def offloaded_dispatch(
         raise NotImplementedError("Disk offloading is not currently supported")
 
     # create weights map
-    weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
 
     # create tied params map
     tied_params = find_tied_parameters(module)
@@ -522,6 +526,36 @@ def offloaded_dispatch(
     return module
 
 
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters():
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)
+
+
 """ Upstreamed Functions """
 
 
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.2.a20250612'
+__version__ = version = '0.10.2.a20250613'
 __version_tuple__ = version_tuple = (0, 10, 2)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.2a20250612
+Version: 0.10.2a20250613
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -82,6 +82,7 @@ src/compressed_tensors/transform/factory/matrix_multiply.py
 src/compressed_tensors/transform/factory/random_hadamard.py
 src/compressed_tensors/transform/utils/__init__.py
 src/compressed_tensors/transform/utils/hadamard.py
+src/compressed_tensors/transform/utils/hadamards.safetensors
 src/compressed_tensors/transform/utils/utils.py
 src/compressed_tensors/utils/__init__.py
 src/compressed_tensors/utils/helpers.py
tests/test_transform/utils/test_hadamard.py
@@ -13,46 +13,48 @@
 # limitations under the License.
 
 
-import numpy
 import pytest
 import torch
 from compressed_tensors.transform.utils.hadamard import (
-    _get_had12,
-    _get_had20,
     deterministic_hadamard_matrix,
+    is_pow2,
     random_hadamard_matrix,
 )
+from tests.testing_utils import requires_gpu
 
 
-@pytest.mark.parametrize(
-    "had_func",
-    [
-        _get_had12,
-        _get_had20,
-    ],
-)
-def test_packed_hadamard_compliant(had_func):
-    had_matrix = had_func()
-    size = had_matrix.size(0)
-    # HH.T == nI
-    product = had_matrix @ had_matrix.T
-    assert torch.equal(product, size * torch.eye(size))
+_sizes_to_test = [
+    768,  # gpt2 small
+    1024,  # gpt2 medium
+    1280,  # qwen_2_5_vl vision
+    1600,  # gpt2 xl
+    2048,  # gpt3 small
+    3584,  # qwen_2_5_vl
+    3840,  # qwen_2_5_vl vision qkv
+    4096,  # llama3
+    7168,  # deepseek_v3
+    14336,  # llama3 intermediate
+    18432,  # deepseek_v3 intermediate
+    18944,  # qwen_2_5_vl intermediate
+]
+_atol = 1e-1  # bfloat16 is low precision for large matrices
 
 
-@pytest.mark.parametrize(
-    "size",
-    [4096, 2048],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
-    had_matrix = random_hadamard_matrix(size)
-    product = torch.round(had_matrix @ had_matrix.T)
-    assert torch.equal(product, torch.eye(size))
+    # (H / sqrt(n))(H.T / sqrt(n)) == I
+    matrix = random_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
 
 
 def test_random_hadamard_generator():
+    # check that generation is deterministic with a seed
     generator = torch.Generator().manual_seed(42)
-    one = random_hadamard_matrix(2048, generator)
-    two = random_hadamard_matrix(2048, generator)
+    one = random_hadamard_matrix(2048, gen=generator)
+    two = random_hadamard_matrix(2048, gen=generator)
 
     one_true = torch.tensor(
         [
@@ -73,12 +75,16 @@ def test_random_hadamard_generator():
     assert torch.all(two[:3, :3].sign() == two_true.sign())
 
 
-@pytest.mark.parametrize(
-    "size",
-    [1024],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_deterministic_hadamard_compliant(size):
-    had_matrix = deterministic_hadamard_matrix(size)
+    if not is_pow2(size):
+        with pytest.raises(ValueError):
+            matrix = deterministic_hadamard_matrix(size, device="cuda")
+        return
+
     # (H / sqrt(n))(H.T / sqrt(n)) == I
-    product = had_matrix @ had_matrix.T
-    assert numpy.array_equal(product, numpy.eye(size))
+    matrix = deterministic_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
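
As a concrete check of the divisor logic these sizes exercise: 768 (gpt2 small) is not a power of 2, but it factors as 12 * 2^6, so an order-12 matrix (the counterpart of `_get_had12` in the removed implementation below) can seed `_matmul_hadU`. A hedged arithmetic sketch:

    from compressed_tensors.transform.utils.hadamard import is_pow2

    # 768 = 12 * 64 and 64 is a power of two, so a stored 12x12 Hadamard
    # matrix can be expanded to 768x768 by repeated doubling
    assert 768 % 12 == 0 and is_pow2(768 // 12)
    # a power-of-2 size like 1024 needs no stored divisor at all
    assert is_pow2(1024)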
tests/test_utils/test_offload.py
@@ -19,6 +19,7 @@ from compressed_tensors.utils import (
     delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
+    disable_offloading,
     get_execution_device,
     has_offloaded_params,
     offloaded_dispatch,
@@ -397,15 +398,23 @@ def test_delete_offload_module(exec_device):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
-def test_offloaded_dispatch(exec_device):
+@pytest.mark.parametrize(
+    "exec_device,offload_device",
+    [
+        (torch.device("cpu"), torch.device("cpu")),
+        (torch.device("cpu"), torch.device("cuda:0")),
+        (torch.device("cuda:0"), torch.device("cpu")),
+        (torch.device("cuda:0"), torch.device("cuda:0")),
+    ],
+)
+def test_offloaded_dispatch(exec_device, offload_device):
     # single module
-    module = torch.nn.Linear(1, 2)
-    module = offloaded_dispatch(module, exec_device)
+    module = torch.nn.Linear(1, 2, device=offload_device)
+    module = offloaded_dispatch(module, exec_device, offload_device)
     assert has_offloaded_params(module)
     assert module._hf_hook.offload
     assert module.weight.device == torch.device("meta")
-    assert "weight" in module._hf_hook.weights_map
+    assert module._hf_hook.weights_map["weight"].device == offload_device
     assert module._hf_hook.tied_params_map is not None
 
     # can run
@@ -413,13 +422,13 @@
 
     # model
     model = ExampleModel()
-    model = offloaded_dispatch(model, exec_device)
+    model = offloaded_dispatch(model, exec_device, offload_device)
     assert not has_offloaded_params(model)
 
     assert has_offloaded_params(model.linear)
     assert model.linear._hf_hook.offload
     assert model.linear.weight.device == torch.device("meta")
-    assert "weight" in model.linear._hf_hook.weights_map
+    assert model.linear._hf_hook.weights_map["weight"].device == offload_device
     assert model.linear._hf_hook.tied_params_map is not None
 
     # can run
@@ -429,4 +438,43 @@
     parameter = torch.nn.Parameter(torch.tensor(1.0))
     register_offload_parameter(module, "new_param", parameter)
     assert module.new_param.device == torch.device("meta")
-    assert module._hf_hook.weights_map["new_param"].device == torch.device("cpu")
+    assert module._hf_hook.weights_map["new_param"].device == offload_device
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize(
+    "exec_device,offload_device",
+    [
+        (torch.device("cpu"), torch.device("cpu")),
+        (torch.device("cpu"), torch.device("cuda:0")),
+        (torch.device("cuda:0"), torch.device("cpu")),
+        (torch.device("cuda:0"), torch.device("cuda:0")),
+    ],
+)
+def test_disable_offloading(exec_device, offload_device):
+    module = torch.nn.Linear(1, 2, device=exec_device)
+
+    # non-offloaded modules are unaffected
+    with disable_offloading():
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+    # offloaded modules stay on device until context exit
+    offloaded_dispatch(module, exec_device, offload_device)
+    assert module.weight.device == torch.device("meta")
+    assert module._hf_hook.weights_map["weight"].device == offload_device
+
+    with disable_offloading():
+        assert module.weight.device == torch.device("meta")
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+    assert module.weight.device == torch.device("meta")
+    assert module._hf_hook.weights_map["weight"].device == offload_device
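
For context, a sketch of how `disable_offloading` is meant to be used outside the test suite, keeping weights onloaded across repeated forward passes (the devices and loop here are illustrative):

    import torch
    from compressed_tensors.utils import disable_offloading, offloaded_dispatch

    module = torch.nn.Linear(1, 2)
    offloaded_dispatch(module, torch.device("cuda:0"), torch.device("cpu"))

    with disable_offloading():
        # the first forward onloads the weights; later calls reuse them
        # instead of paying an offload/onload round trip per call
        for _ in range(10):
            module(torch.empty(1, device="cuda:0"))
    # on exit, weights are offloaded again and any updates are synced back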
compressed_tensors-0.10.2a20250612/src/compressed_tensors/transform/utils/hadamard.py
@@ -1,161 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-from typing import Optional, Tuple
-
-import numpy
-import torch
-
-
-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
-
-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
-    """
-    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
-    `n` must be a power of 2.
-
-    :param size: order of the matrix, must be a power of 2
-    :return: hadamard matrix of size `size`
-    """
-    if size <= 0:
-        raise ValueError("Cannot construct deterministic hadamard of size <= 0")
-
-    log2 = int(math.log(size, 2))
-    if size != 2**log2:
-        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
-
-    H = numpy.array([[1]], dtype=int)
-
-    # Sylvester's construction
-    for i in range(0, log2):
-        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
-
-    return torch.from_numpy(H / math.sqrt(size))
-
-
-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
-
-
-def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
-) -> torch.Tensor:
-    """
-    Produces a randomly generated Hadamard matrix.
-    See https://cornell-relaxml.github.io/quip-sharp/,
-    Section "Randomized Hadamard Transformation"
-
-    :param size: the dimension of the hadamard matrix
-    :param gen: optional generator for random values
-    :return: randomly generated hadamard matrix
-    """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
-    Q = Q * 2 - 1
-    Q = torch.diag(Q)
-    return _matmul_hadU(Q) / math.sqrt(size)
-
-
-def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
-    # NOTE: we can easily extend the list of supported shapes/sizes
-    # by adding to these methods
-    hadK, K = None, None
-    if n % 20 == 0:
-        assert _is_pow2(n // 20)
-        K = 20
-        hadK = _get_had20().T if transpose else _get_had20()
-    elif n % 12 == 0:
-        assert _is_pow2(n // 12)
-        K = 12
-        hadK = _get_had12().T if transpose else _get_had12()
-    else:
-        assert _is_pow2(n)
-        K = 1
-
-    return hadK, K
-
-
-def _matmul_hadU(X, transpose=False) -> torch.Tensor:
-    n = X.shape[-1]
-    # Check if we have the determined hadamard matrix
-    hadK, K = _get_hadK(n, transpose)
-    # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
-    output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
-    while input.shape[1] > K:
-        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
-        output = output.view(input.shape)
-        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
-        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
-        output = output.view(input.shape[0], input.shape[1], -1)
-        (input, output) = (output, input)
-    del output
-
-    # K == 1 when hadK is None; this happens when the size dim (n)
-    # is not compatible with any of the maintained hadamard matrices
-
-    if K > 1:
-        # Do not explicitly repeat - OOM
-        # input = torch.bmm(
-        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-        # Use bcast instead
-
-        # for cases when hadK is pre-determined
-        input = hadK.view(1, K, K).to(input) @ input
-
-    # normalize
-    return input.view(X.shape)
-
-
-def _is_pow2(n: int) -> bool:
-    return (n & (n - 1) == 0) and (n > 0)
-
-
-def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
-    had_unpacked = numpy.unpackbits(packed_bits)
-    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
-    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
-    return had_unpacked
-
-
-# http://www.neilsloane.com/hadamard/index.html
-def _get_had12() -> torch.Tensor:
-    # fmt: off
-    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
-        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_12_unpacked = _reshape_bits(had_12, original_size=12)
-    return torch.tensor(had_12_unpacked)
-
-
-def _get_had20() -> torch.Tensor:
-    # fmt: off
-    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
-        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
-        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
-        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_20_unpacked = _reshape_bits(had_20, original_size=20)
-    return torch.tensor(had_20_unpacked)