compressed-tensors 0.10.3a20250715.tar.gz → 0.10.3a20250721.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. {compressed_tensors-0.10.3a20250715/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250721}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/forward.py +68 -5
  3. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/quant_args.py +31 -8
  4. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/quant_scheme.py +41 -0
  5. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/utils/helpers.py +11 -2
  6. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/factory/base.py +1 -3
  7. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/factory/hadamard.py +17 -8
  8. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/factory/matrix_multiply.py +18 -8
  9. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/transform_scheme.py +2 -1
  10. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/utils/hadamard.py +2 -2
  11. compressed_tensors-0.10.3a20250721/src/compressed_tensors/transform/utils/matrix.py +179 -0
  12. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/version.py +1 -1
  13. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  14. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors.egg-info/SOURCES.txt +1 -1
  15. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_examples/test_bitmask_compression_ipynb.py +3 -1
  16. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_forward.py +50 -0
  17. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_quant_args.py +2 -1
  18. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_utils/test_helpers.py +28 -1
  19. compressed_tensors-0.10.3a20250721/tests/test_transform/conftest.py +115 -0
  20. compressed_tensors-0.10.3a20250721/tests/test_transform/factory/test_correctness.py +168 -0
  21. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_transform/utils/test_hadamard.py +2 -2
  22. compressed_tensors-0.10.3a20250715/src/compressed_tensors/transform/utils/utils.py +0 -91
  23. compressed_tensors-0.10.3a20250715/tests/test_transform/conftest.py +0 -54
  24. compressed_tensors-0.10.3a20250715/tests/test_transform/factory/test_correctness.py +0 -89
  25. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/.gitkeep +0 -0
  26. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/actions/test/action.yml +0 -0
  27. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/scripts/step-status +0 -0
  28. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/build-test.yml +0 -0
  29. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/build.yml +0 -0
  30. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/report.yml +0 -0
  31. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/test-check.yaml +0 -0
  32. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/test.yml +0 -0
  33. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/trigger-all.yml +0 -0
  34. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.github/workflows/upload.yml +0 -0
  35. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/.gitignore +0 -0
  36. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/LICENSE +0 -0
  37. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/Makefile +0 -0
  38. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/README.md +0 -0
  39. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  40. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/bit_packing/int4_config.json +0 -0
  41. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/bitmask_compression.ipynb +0 -0
  42. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  43. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  44. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/llama_1.1b/example_quant_config.json +0 -0
  45. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  46. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/examples/quantize_and_pack_int4.ipynb +0 -0
  47. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/pyproject.toml +0 -0
  48. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/setup.cfg +0 -0
  49. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/setup.py +0 -0
  50. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/__init__.py +0 -0
  51. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/README.md +0 -0
  52. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/__init__.py +0 -0
  53. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/base.py +0 -0
  54. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/__init__.py +0 -0
  55. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/base.py +0 -0
  56. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/helpers.py +0 -0
  57. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  59. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  60. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  61. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  62. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  63. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  64. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  65. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  66. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  67. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  68. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  69. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  70. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  71. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/config/__init__.py +0 -0
  72. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/config/base.py +0 -0
  73. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/config/dense.py +0 -0
  74. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  75. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  76. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/linear/__init__.py +0 -0
  77. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  78. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/__init__.py +0 -0
  79. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  80. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  81. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  82. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  83. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  84. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/quant_config.py +0 -0
  85. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  86. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/registry/__init__.py +0 -0
  87. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/registry/registry.py +0 -0
  88. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/__init__.py +0 -0
  89. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/apply.py +0 -0
  90. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  91. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  92. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/transform_args.py +0 -0
  93. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/transform_config.py +0 -0
  94. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  95. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  96. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/__init__.py +0 -0
  97. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/helpers.py +0 -0
  98. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/internal.py +0 -0
  99. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/offload.py +0 -0
  100. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/permutations_24.py +0 -0
  101. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/permute.py +0 -0
  102. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  103. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  104. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  105. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors.egg-info/requires.txt +0 -0
  106. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  107. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/__init__.py +0 -0
  108. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/conftest.py +0 -0
  109. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/__init__.py +0 -0
  110. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/model_compressors/__init__.py +0 -0
  111. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  112. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  113. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  114. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  115. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  116. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  117. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  118. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  119. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  120. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  121. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  122. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_configs/__init__.py +0 -0
  123. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_configs/test_base.py +0 -0
  124. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_linear/__init__.py +0 -0
  125. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_linear/test_compressed_linear.py +0 -0
  126. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/__init__.py +0 -0
  127. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/__init__.py +0 -0
  128. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/conftest.py +0 -0
  129. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  130. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  131. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  132. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  133. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  134. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  135. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_configs/__init__.py +0 -0
  136. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  137. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  138. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_quant_config.py +0 -0
  139. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_quantization/test_quant_scheme.py +0 -0
  140. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_registry.py +0 -0
  141. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_transform/factory/test_memory.py +0 -0
  142. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_transform/test_transform_args.py +0 -0
  143. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_transform/test_transform_config.py +0 -0
  144. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_transform/test_transform_scheme.py +0 -0
  145. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_utils/__init__.py +0 -0
  146. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_utils/test_helpers.py +0 -0
  147. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_utils/test_offload.py +0 -0
  148. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/test_utils/test_safetensors_load.py +0 -0
  149. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/tests/testing_utils.py +0 -0
  150. {compressed_tensors-0.10.3a20250715 → compressed_tensors-0.10.3a20250721}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250715
+Version: 0.10.3a20250721
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,11 +111,18 @@ def dequantize(
     elif scale.ndim == 2:
         if scale.shape[1] == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        else:
+        # Scale height matches input or is 1 -> group quantization across columns
+        #
+        # Example 1: scale.shape[0] == 1
+        #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+        #
+        # Example 2: scale.shape[0] == x_q.shape[0]
+        #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+        elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
             group_size = int(x_q.shape[1] / scale.shape[1])
-            args = QuantizationArgs(
-                strategy=QuantizationStrategy.GROUP, group_size=group_size
-            )
+            args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+        else:
+            args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "

@@ -189,7 +196,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)

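The BLOCK branch above tiles the tensor so that one scale broadcasts over each (block_height, block_width) tile. A minimal standalone sketch of that reshaping in plain PyTorch, with toy sizes and a hypothetical FP8-style clamp (an illustration of the tiling, not the library's exact code path):

    import torch

    rows, cols = 4, 6
    block_height, block_width = 2, 3
    x = torch.arange(rows * cols, dtype=torch.float32).reshape(rows, cols)

    # carve x into (rows/bh) x (cols/bw) tiles, each of shape (bh, bw)
    x_blocks = x.reshape(
        rows // block_height, block_height, cols // block_width, block_width
    ).transpose(1, 2)  # -> (row blocks, col blocks, bh, bw)

    # one scale per tile, broadcast over the trailing (bh, bw) dims
    scale = x_blocks.abs().amax(dim=(-2, -1)) / 448.0      # 448 = FP8 e4m3 max
    sb = scale.clamp(min=1e-12).unsqueeze(-1).unsqueeze(-1)
    x_q = (x_blocks / sb).round().clamp(-448, 448)

    # undo the tiling to recover the original (rows, cols) layout
    x_dq = (x_q * sb).transpose(1, 2).reshape(rows, cols)
    assert x_dq.shape == x.shape
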
src/compressed_tensors/quantization/quant_args.py
@@ -14,7 +14,7 @@

 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from compressed_tensors.utils import Aliasable

@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy, must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static

@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(

@@ -207,6 +207,28 @@ class QuantizationArgs(BaseModel, use_enum_values=True):

         return value

+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                return [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):

@@ -277,14 +299,15 @@ class QuantizationArgs(BaseModel, use_enum_values=True):

         # infer observer w.r.t. dynamic
         if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )

             if (

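A quick sketch of what the new validator accepts, assuming the package is installed and these import paths resolve as in the source tree; the legacy "RxC" string form is normalized to a two-int list for backward compatibility:

    from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

    # canonical form: a [rows, cols] list
    args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=[128, 128])

    # legacy string form still parses via validate_block_structure
    legacy = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure="128x128")
    assert legacy.block_structure == [128, 128]

    # anything else (e.g. a three-element list) raises a ValueError
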
src/compressed_tensors/quantization/quant_scheme.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional

@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
+        weights = model.weights

         if inputs is not None:
             if inputs.actorder is not None:

@@ -61,6 +63,21 @@ class QuantizationScheme(BaseModel):
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")

+        if (
+            inputs and weights
+            and weights.strategy == QuantizationStrategy.GROUP
+            and inputs.strategy == QuantizationStrategy.GROUP
+            and weights.group_size != inputs.group_size
+        ):
+            warnings.warn(
+                "Using GROUP strategy for both weights and input_activations "
+                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+                "may complicate fused kernel implementations. Consider using "
+                "TENSOR_GROUP strategy for both or matching group sizes.",
+                UserWarning,
+                stacklevel=2
+            )
+
         return model

@@ -243,6 +260,29 @@ FP8_DYNAMIC = dict(
     ),
 )

+# Block-wise FP8 (deepseekv3-style quantization):
+#   static 128x128 per-block weights and
+#   dynamic per-token-group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,

@@ -257,6 +297,7 @@ PRESET_SCHEMES = {
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }

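As a usage sketch, the new preset is reachable through PRESET_SCHEMES just like "FP8" and "FP8_DYNAMIC"; direct dictionary access is assumed here, and any helper that resolves preset names should pick it up the same way:

    from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

    fp8_block = PRESET_SCHEMES["FP8_BLOCK"]
    weights_args = fp8_block["weights"]
    act_args = fp8_block["input_activations"]

    # static 128x128 blocks on weights, dynamic 128-wide groups on activations
    assert weights_args.block_structure == [128, 128]
    assert act_args.group_size == 128 and act_args.dynamic
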
src/compressed_tensors/quantization/utils/helpers.py
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)

@@ -187,9 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
+            f"{supported_strategies}",
        )

     if not reduce_dims:

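For intuition, allowing GROUP here means dynamic activation scales are reduced over each group of columns rather than over the whole row. A standalone sketch of that reduction (illustrative only, not the library's exact implementation):

    import torch

    tokens, hidden, group_size = 4, 512, 128
    value = torch.randn(tokens, hidden)

    # view each row as (num_groups, group_size) and reduce within the last dim
    groups = value.reshape(tokens, hidden // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True)         # (tokens, num_groups, 1)
    scale = (amax / 448.0).clamp(min=1e-12)                 # symmetric FP8 e4m3 range
    value_q = (groups / scale).round().clamp(-448, 448)     # one scale per token per group
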
src/compressed_tensors/transform/factory/base.py
@@ -117,10 +117,8 @@ class TransformFactory(RegistryMixin, ABC):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
             with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))

src/compressed_tensors/transform/factory/hadamard.py
@@ -14,13 +14,14 @@

 from typing import Optional, Union

+import math
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict

@@ -51,8 +52,8 @@ class HadamardFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

@@ -60,7 +61,7 @@ class HadamardFactory(TransformFactory):
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))

     def _create_weight(
         self,

@@ -81,12 +82,18 @@ class HadamardFactory(TransformFactory):

 class HadamardTransform(TransformBase):
     def __init__(
-        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+        self,
+        weight: Parameter,
+        perm: Optional[Parameter],
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.module_type = module_type
+        self._scale = math.sqrt(weight.size(0))

     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight

@@ -96,5 +103,7 @@ class HadamardTransform(TransformBase):

         if self.args.inverse:
             weight = weight.T
-
-        return apply_transform_weight(weight, value, self.args.location)
+
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        ) / self._scale

src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -17,9 +17,9 @@ from typing import Optional

 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict

@@ -50,8 +50,8 @@ class RandomMatrixFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)

@@ -59,7 +59,7 @@ class RandomMatrixFactory(TransformFactory):
         if args.inverse:
             weight = self.inverses[weight]

-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)

@@ -74,17 +74,27 @@


 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type

     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )

     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )


 def high_precision_invert(weight: Tensor) -> Tensor:

src/compressed_tensors/transform/transform_scheme.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List
+from typing import List, Optional

 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field

@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    head_dim: Optional[int] = Field(default=None)

src/compressed_tensors/transform/utils/hadamard.py
@@ -59,7 +59,7 @@ def deterministic_hadamard_matrix(
     for _ in range(log2):
         H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

-    return H / math.sqrt(size)
+    return H


 def random_hadamard_matrix(

@@ -86,7 +86,7 @@ def random_hadamard_matrix(
     Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
-    return _matmul_hadU(Q) / math.sqrt(size)
+    return _matmul_hadU(Q)


 def is_pow2(n: int) -> bool:

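With the 1/sqrt(size) factor dropped here, these helpers now return unnormalized Hadamard matrices satisfying H @ H.T == size * I; the normalization is instead applied once in HadamardTransform.forward through the new _scale attribute. A quick standalone check of that property, mirroring the Sylvester doubling loop above rather than importing the library helpers:

    import torch

    H = torch.tensor([[1.0]])
    for _ in range(2):  # two doublings: 1 -> 2 -> 4
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

    n = H.shape[0]
    assert torch.allclose(H @ H.T, n * torch.eye(n))                     # unnormalized: H @ H.T == n * I
    assert torch.allclose((H / n**0.5) @ (H / n**0.5).T, torch.eye(n))   # orthonormal after scaling
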
src/compressed_tensors/transform/utils/matrix.py (new file)
@@ -0,0 +1,179 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_transform_size", "apply_transform_weight"]
+
+
+def get_transform_size(
+    module: torch.nn.Module,
+    location: TransformLocation,
+    head_dim: Optional[int] = None,
+) -> int:
+    """
+    Determine the size of a transform matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :param head_dim: size of head when transform is applied to mha
+    :return: size of matrix
+    """
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.in_features
+        else:
+            size = module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.num_embeddings
+        else:
+            size = module.embedding_dim
+    else:
+        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+
+    if head_dim is not None:
+        if size % head_dim != 0:
+            raise ValueError(
+                f"{head_dim} must divide {size} for {type(module)} at {location}"
+            )
+
+        size = head_dim
+
+    return size
+
+
+def apply_transform_weight(
+    transform_weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+    module_type: type[torch.nn.Module],
+) -> torch.Tensor:
+    """
+    Using the transform location, apply the transform_weight to the
+    given value wrt linear weights. For more info on input and output transforms,
+    see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x  be input activation
+         W  be weight,
+         yh, xh, Wh be transformed output, input, weight
+
+    note that
+        y = (x W.T)   // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi  be transform matrices on input side
+         U, Ui  be transform matrices on output side
+
+    pick xh = (x V)
+         Wh = (U.T W Vi.T)
+         yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)    // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
+    :param location: determines how weight should be applied
+    :param module_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after transform_weight has been applied
+    """
+
+    assert transform_weight.shape[0] == transform_weight.shape[1]
+
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            # equivalent to (transform_weight @ value.T).T
+            return _multihead_matmul(value, transform_weight.T)
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            # equivalent to (value.T @ transform_weight).T
+            return _multihead_matmul(transform_weight.T, value)
+
+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    # similar derivation to torch.nn.Linear, but `y = (x W)`
+    elif module_type == torch.nn.Embedding:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(
+                transform_weight,
+                value,
+            )
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    raise NotImplementedError(
+        f"Applying transforms to {module_type} {location} is not supported"
+    )
+
+
+def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    """
+    Performs A @ B for last two dims of two matrices A and B that possibly
+    have different shapes, as is the case in multi-headed dimension. If
+    shapes are different, this is equivalent to converting the last two dims
+    of the smaller matrix into a block-diagonal matrix with the same shape as
+    the last two dims of the larger matrix.
+
+    E.g. if A is half the size of B, this function will perform
+        [[A   ]  @ B
+         [   A]]
+
+    If B is a third of the size of A, this function will perform
+        A @  [[B      ]
+              [   B   ]
+              [      B]]
+
+    This function will error out if the shapes are not evenly divisible
+
+    :param A: left-hand tensor
+    :param B: right-hand tensor
+    :return: result
+    """
+    if A.shape[-1] > B.shape[-2]:
+        head_dim = B.shape[-2]
+        num_heads = A.shape[-1] // head_dim
+        A = A.unflatten(-1, (num_heads, head_dim))
+        return (A @ B).flatten(-2, -1)
+    elif A.shape[-1] < B.shape[-2]:
+        head_dim = A.shape[-1]
+        num_heads = B.shape[-2] // head_dim
+        B = B.unflatten(-2, (num_heads, head_dim))
+        return (A @ B).flatten(-3, -2)
+    else:
+        return A @ B

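The block-diagonal equivalence described in the _multihead_matmul docstring can be verified with plain PyTorch; this sketch re-implements the A-wider-than-B branch directly rather than importing the private helper:

    import torch

    num_heads, head_dim = 3, 4
    A = torch.randn(2, num_heads * head_dim)   # e.g. activations spanning 3 fused heads
    B = torch.randn(head_dim, head_dim)        # per-head transform weight

    # branch where A.shape[-1] > B.shape[-2]: fold heads out, matmul, fold back
    out = (A.unflatten(-1, (num_heads, head_dim)) @ B).flatten(-2, -1)

    # equivalent to multiplying by a block-diagonal expansion of B
    assert torch.allclose(out, A @ torch.block_diag(B, B, B), atol=1e-6)
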
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.10.3.a20250715'
+__version__ = version = '0.10.3.a20250721'
 __version_tuple__ = version_tuple = (0, 10, 3)

src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250715
+Version: 0.10.3a20250721
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

src/compressed_tensors.egg-info/SOURCES.txt
@@ -84,7 +84,7 @@ src/compressed_tensors/transform/factory/random_hadamard.py
 src/compressed_tensors/transform/utils/__init__.py
 src/compressed_tensors/transform/utils/hadamard.py
 src/compressed_tensors/transform/utils/hadamards.safetensors
-src/compressed_tensors/transform/utils/utils.py
+src/compressed_tensors/transform/utils/matrix.py
 src/compressed_tensors/utils/__init__.py
 src/compressed_tensors/utils/helpers.py
 src/compressed_tensors/utils/internal.py

tests/test_examples/test_bitmask_compression_ipynb.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import nbformat
 import pytest
+
+
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor