compressed-tensors 0.10.2a20250606__tar.gz → 0.10.2a20250611__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. {compressed_tensors-0.10.2a20250606/src/compressed_tensors.egg-info → compressed_tensors-0.10.2a20250611}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +7 -1
  3. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_compressors/dense.py +19 -1
  4. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -3
  5. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/__init__.py +5 -0
  6. compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/factory/base.py +164 -0
  7. compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/factory/hadamard.py +79 -0
  8. compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/factory/matrix_multiply.py +90 -0
  9. compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/factory/random_hadamard.py +34 -0
  10. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/offload.py +3 -0
  11. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/version.py +1 -1
  12. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  13. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors.egg-info/SOURCES.txt +7 -0
  14. compressed_tensors-0.10.2a20250611/tests/test_transform/factory/test_correctness.py +116 -0
  15. compressed_tensors-0.10.2a20250611/tests/test_transform/factory/test_memory.py +112 -0
  16. compressed_tensors-0.10.2a20250611/tests/test_utils/__init__.py +13 -0
  17. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/.gitkeep +0 -0
  18. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/actions/test/action.yml +0 -0
  19. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/scripts/step-status +0 -0
  20. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/build-test.yml +0 -0
  21. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/build.yml +0 -0
  22. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/report.yml +0 -0
  23. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/test-check.yaml +0 -0
  24. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/test.yml +0 -0
  25. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/trigger-all.yml +0 -0
  26. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.github/workflows/upload.yml +0 -0
  27. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/.gitignore +0 -0
  28. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/LICENSE +0 -0
  29. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/Makefile +0 -0
  30. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/README.md +0 -0
  31. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  32. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/bit_packing/int4_config.json +0 -0
  33. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/bitmask_compression.ipynb +0 -0
  34. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  35. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  36. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/llama_1.1b/example_quant_config.json +0 -0
  37. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  38. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/examples/quantize_and_pack_int4.ipynb +0 -0
  39. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/pyproject.toml +0 -0
  40. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/setup.cfg +0 -0
  41. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/setup.py +0 -0
  42. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/__init__.py +0 -0
  43. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/README.md +0 -0
  44. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/__init__.py +0 -0
  45. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/base.py +0 -0
  46. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/__init__.py +0 -0
  47. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/base.py +0 -0
  48. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/helpers.py +0 -0
  49. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  50. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  51. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  52. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  53. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  54. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  55. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  56. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  57. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  58. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  59. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  60. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  61. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/config/__init__.py +0 -0
  62. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/config/base.py +0 -0
  63. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/config/dense.py +0 -0
  64. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  65. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  66. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/linear/__init__.py +0 -0
  67. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  68. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/__init__.py +0 -0
  69. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  70. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  71. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  72. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  73. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  74. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/quant_args.py +0 -0
  75. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/quant_config.py +0 -0
  76. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  77. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  78. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  79. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/registry/__init__.py +0 -0
  80. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/registry/registry.py +0 -0
  81. {compressed_tensors-0.10.2a20250606/src/compressed_tensors/transform/utils → compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/factory}/__init__.py +0 -0
  82. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/transform_args.py +0 -0
  83. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/transform_config.py +0 -0
  84. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  85. {compressed_tensors-0.10.2a20250606/tests → compressed_tensors-0.10.2a20250611/src/compressed_tensors/transform/utils}/__init__.py +0 -0
  86. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  87. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/transform/utils/utils.py +0 -0
  88. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/__init__.py +0 -0
  89. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/helpers.py +0 -0
  90. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/permutations_24.py +0 -0
  91. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/permute.py +0 -0
  92. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  93. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  94. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  95. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors.egg-info/requires.txt +0 -0
  96. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  97. {compressed_tensors-0.10.2a20250606/tests/test_compressors → compressed_tensors-0.10.2a20250611/tests}/__init__.py +0 -0
  98. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/conftest.py +0 -0
  99. {compressed_tensors-0.10.2a20250606/tests/test_compressors/model_compressors → compressed_tensors-0.10.2a20250611/tests/test_compressors}/__init__.py +0 -0
  100. {compressed_tensors-0.10.2a20250606/tests/test_compressors/quantized_compressors → compressed_tensors-0.10.2a20250611/tests/test_compressors/model_compressors}/__init__.py +0 -0
  101. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  102. {compressed_tensors-0.10.2a20250606/tests/test_compressors/sparse_compressors → compressed_tensors-0.10.2a20250611/tests/test_compressors/quantized_compressors}/__init__.py +0 -0
  103. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  104. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  105. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  106. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  107. {compressed_tensors-0.10.2a20250606/tests/test_compressors/sparse_quantized_compressors → compressed_tensors-0.10.2a20250611/tests/test_compressors/sparse_compressors}/__init__.py +0 -0
  108. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  109. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  110. {compressed_tensors-0.10.2a20250606/tests/test_configs → compressed_tensors-0.10.2a20250611/tests/test_compressors/sparse_quantized_compressors}/__init__.py +0 -0
  111. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  112. {compressed_tensors-0.10.2a20250606/tests/test_linear → compressed_tensors-0.10.2a20250611/tests/test_configs}/__init__.py +0 -0
  113. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_configs/test_base.py +0 -0
  114. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  115. {compressed_tensors-0.10.2a20250606/tests/test_quantization → compressed_tensors-0.10.2a20250611/tests/test_linear}/__init__.py +0 -0
  116. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_linear/test_compressed_linear.py +0 -0
  117. {compressed_tensors-0.10.2a20250606/tests/test_quantization/lifecycle → compressed_tensors-0.10.2a20250611/tests/test_quantization}/__init__.py +0 -0
  118. {compressed_tensors-0.10.2a20250606/tests/test_quantization/test_configs → compressed_tensors-0.10.2a20250611/tests/test_quantization/lifecycle}/__init__.py +0 -0
  119. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/conftest.py +0 -0
  120. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  121. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  122. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  123. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  124. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  125. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  126. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  127. {compressed_tensors-0.10.2a20250606/tests/test_utils → compressed_tensors-0.10.2a20250611/tests/test_quantization/test_configs}/__init__.py +0 -0
  128. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  129. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  130. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_quant_args.py +0 -0
  131. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_quant_config.py +0 -0
  132. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_quant_scheme.py +0 -0
  133. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  134. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_registry.py +0 -0
  135. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_transform/test_transform_args.py +0 -0
  136. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_transform/test_transform_config.py +0 -0
  137. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_transform/test_transform_scheme.py +0 -0
  138. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_transform/utils/test_hadamard.py +0 -0
  139. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_utils/test_helpers.py +0 -0
  140. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_utils/test_offload.py +0 -0
  141. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/test_utils/test_safetensors_load.py +0 -0
  142. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/tests/testing_utils.py +0 -0
  143. {compressed_tensors-0.10.2a20250606 → compressed_tensors-0.10.2a20250611}/utils/copyright.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.2a20250606
+Version: 0.10.2a20250611
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -68,6 +68,10 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.compressors import BaseQuantizationCompressor
+
+
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -257,7 +261,9 @@ class ModelCompressor:
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, DenseCompressor]
+        ] = None
 
         if sparsity_config is not None:
            self.sparsity_compressor = BaseCompressor.load_from_registry(
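
For readers unfamiliar with the guard used above: `typing.TYPE_CHECKING` is `False` at runtime and `True` only under static type checkers, so the import feeds the annotation without triggering (possibly circular) imports at runtime. A minimal sketch of the pattern, not taken from this diff:

import torch
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # resolved by mypy/pyright only; never executed at runtime
    from compressed_tensors.compressors import BaseQuantizationCompressor

def use(compressor: "BaseQuantizationCompressor") -> None:
    # the quoted annotation defers evaluation to type-checking time
    ...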
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple
 
+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor
 
 
+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """
@@ -47,3 +52,16 @@ class DenseCompressor(BaseCompressor):
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
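
As the docstring says, the dense implementation is deliberately a pass-through. A hedged sketch of the observable behavior (the prefix is illustrative, and `scheme` is unused by this no-op, so `None` suffices at runtime):

import torch
from compressed_tensors.compressors.sparse_compressors.dense import DenseCompressor

compressor = DenseCompressor()
state_dict = {"weight": torch.rand(4, 8)}
out = compressor.decompress_module_from_state_dict(
    prefix="model.layers.0.mlp.",  # illustrative module prefix
    state_dict=state_dict,
    scheme=None,  # annotated as QuantizationScheme, but ignored by the dense no-op
)
assert out is not state_dict                   # shallow copy: new dict ...
assert out["weight"] is state_dict["weight"]   # ... holding the same tensors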
@@ -183,9 +183,7 @@ def apply_quantization_config(
183
183
  replace_module(model, name, compressed_linear)
184
184
 
185
185
  # target matched - add layer and scheme to target list
186
- submodule.quantization_scheme = _scheme_from_targets(
187
- target_to_scheme, targets, name
188
- )
186
+ submodule.quantization_scheme = scheme
189
187
 
190
188
  names_to_scheme[name] = submodule.quantization_scheme
191
189
 
@@ -18,3 +18,8 @@
 from .transform_args import *
 from .transform_scheme import *
 from .transform_config import *
+
+from .factory.base import *
+from .factory.hadamard import *
+from .factory.matrix_multiply import *
+from .factory.random_hadamard import *
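
With these wildcard re-exports, the factory classes resolve from the package root, which is exactly how the new `random_hadamard.py` imports its own parent class:

from compressed_tensors.transform import HadamardFactory, TransformFactory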
@@ -0,0 +1,164 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn.utils.parametrize as P
+from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
+from compressed_tensors.registry.registry import RegistryMixin, T
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformLocation,
+    TransformScheme,
+)
+from compressed_tensors.utils import (
+    align_module_device,
+    has_offloaded_params,
+    patch_attr,
+    register_offload_module,
+    update_offload_parameter,
+)
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+__all__ = ["TransformFactory", "TransformBase"]
+
+
+class TransformFactory(RegistryMixin, ABC):
+    """
+    Abstract factory base used to create and apply transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used for transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        self.name = name
+        self.scheme = scheme
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+
+    @classmethod
+    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
+        """
+        Create a transform factory from a scheme
+
+        :param scheme: defines how transforms should be created
+        :param kwargs: TransformFactory constructor arguments
+        :return: subclass of `TransformFactory` corresponding to the scheme type
+        """
+        constructor = cls.get_value_from_registry(name=scheme.type)
+        return constructor(scheme=scheme, **kwargs)
+
+    @abstractmethod
+    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
+        """
+        Abstract method which defines how a transform should be created. May utilize
+        caching to maximize shared memory
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        :return: instance of TransformBase
+        """
+        raise NotImplementedError()
+
+    def apply_to_model(self, model: Module):
+        """
+        Create transforms and apply them to the model
+
+        :param model: module to apply transforms to
+        """
+        for arg in self.scheme.apply:
+            for name, module in list(model.named_modules()):
+                if is_target(name, module, arg.targets, arg.ignore):
+                    self._apply_to_module(module, arg)
+
+    def _apply_to_module(self, module: Module, args: TransformArgs):
+        """
+        Create transforms and apply them to the module
+
+        :param module: target module to apply transforms to
+        :param args: defines how the transform will be applied to the target module
+        """
+        # create transform as submodule
+        transform_name = f"{self.name}_{args.location.value}"
+        transform = self.create_transform(module, args)
+        register_offload_module(module, transform_name, transform)  # (1)
+
+        # register input transformation hook
+        if args.location == TransformLocation.INPUT:
+
+            def input_hook(_, args):
+                input = args[0]
+                return transform(input)
+
+            module.register_forward_pre_hook(input_hook, prepend=True)
+
+        # eagerly apply transformation to weight
+        elif args.location in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        ):
+            assert isinstance(module, torch.nn.Linear)
+            assert module.bias is None
+
+            with torch.no_grad(), align_module_device(module):
+                update_offload_parameter(module, "weight", transform(module.weight))
+
+            if self.scheme.requires_grad:
+                # for training, the weight changes with every forward pass
+                # so we can leverage parametrization to propagate the gradient
+                if has_offloaded_params(module):
+                    raise ValueError("Offloaded training is not supported")
+                P.register_parametrization(module, "weight", transform)
+
+        # register output transformation hook
+        elif args.location == TransformLocation.OUTPUT:
+
+            def output_hook(_, _input, output):
+                return transform(output)
+
+            module.register_forward_hook(output_hook)
+
+        # other locations such as q_attn and k_attn have not been implemented
+        else:
+            raise NotImplementedError()
+
+    # (1) even in the `weight` cases, this submodule attachment is needed in order
+    # to support saving in the frozen state
+
+
+class TransformBase(Module, ABC):
+    """
+    Represents the application of a transform according to TransformArgs
+    """
+
+    args: TransformArgs
+    weight: Parameter
+
+    @abstractmethod
+    def forward(self, value: Tensor) -> Tensor:
+        raise NotImplementedError()
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        with patch_attr(self.args, "inverse", not self.args.inverse):
+            return self.forward(value)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
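
Putting the two classes together: the intended flow (mirroring the new tests further down) is to describe transforms in a `TransformScheme`, resolve the registered factory with `from_scheme`, and let `apply_to_model` rewrite weights and register hooks. A sketch patterned on `test_correctness_model`; the model and names here are illustrative:

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformFactory,
    TransformScheme,
)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fcs = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8, bias=False), torch.nn.Linear(8, 8, bias=False)]
        )

    def forward(self, x):
        for fc in self.fcs:
            x = fc(x)
        return x

model = Net()
scheme = TransformScheme(
    type="hadamard",  # registry key; "random-hadamard" / "random-matrix" also resolve
    apply=[
        # rotate fcs.0's weight output, undo the rotation at fcs.1's input
        TransformArgs(targets="fcs.0", location="weight_output"),
        TransformArgs(targets="fcs.1", location="input", inverse=True),
    ],
)
factory = TransformFactory.from_scheme(scheme, name="hadamard_0")
factory.apply_to_model(model)  # rewrites weights and registers hooks in place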
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("hadamard")
+class HadamardFactory(TransformFactory):
+    """
+    Factory used to apply hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used for transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a HadamardTransform for applying to a module. Transforms with the same
+        size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        return HadamardTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = deterministic_hadamard_matrix(size)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+
+class HadamardTransform(TransformBase):
+    def __init__(self, weight: Parameter, args: TransformArgs):
+        super().__init__()
+        self.weight = weight
+        self.args = args
+
+    def forward(self, value: Tensor) -> Tensor:
+        if not self.args.inverse:
+            weight = self.weight
+        else:
+            weight = self.weight.T
+
+        return apply_transform_weight(weight, value, self.args.location)
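
Note the caching: `self.weights` is keyed on `(size, dtype, device)`, so two linears with matching dimensions on the same location side share one Hadamard `Parameter` rather than materializing a copy each. Roughly (an illustrative sketch, assuming equal `in_features`):

import torch
from compressed_tensors.transform import HadamardFactory, TransformArgs, TransformScheme

factory = HadamardFactory(name="h", scheme=TransformScheme(type="hadamard"))
a = torch.nn.Linear(8, 16, bias=False)
b = torch.nn.Linear(8, 32, bias=False)
t_a = factory.create_transform(a, TransformArgs(targets="Linear", location="input"))
t_b = factory.create_transform(b, TransformArgs(targets="Linear", location="input"))
assert t_a.weight is t_b.weight  # same (size, dtype, device) key -> shared Parameter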
@@ -0,0 +1,90 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
+from compressed_tensors.transform.utils.utils import (
+    apply_transform_weight,
+    get_matrix_size,
+)
+from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils.helpers import ParameterizedDefaultDict
+from torch import Tensor, device, dtype
+from torch.nn import Linear, Module, Parameter
+
+
+@TransformFactory.register("random-matrix")
+class RandomMatrixFactory(TransformFactory):
+    """
+    Factory used to apply random matrix transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used for transform weight randomization
+    """
+
+    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
+        super().__init__(name, scheme, seed)
+        self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.inverses = ParameterizedDefaultDict(self._create_inverse)
+
+    def create_transform(self, module: Module, args: TransformArgs):
+        """
+        Create a RandomMatrixTransform for applying to a module. Transforms with the
+        same size, dtype, and device are cached
+
+        :param module: parent module that transform will be applied to
+        :param args: defines how the transform will be applied to the module
+        """
+        assert isinstance(module, Linear)
+        size = get_matrix_size(module, args.location)
+        dtype = module.weight.dtype
+        device = get_offloaded_device(module)
+
+        weight = self.weights[size, dtype, device]
+        if args.inverse:
+            weight = self.inverses[weight]
+
+        return RandomMatrixTransform(weight, args)
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = torch.rand(
+            (size, size), generator=self.generator, dtype=dtype, device=device
+        )
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
+
+    def _create_inverse(self, weight: Parameter) -> Parameter:
+        data = high_precision_invert(weight.data)
+        return Parameter(data, requires_grad=False)
+
+
+class RandomMatrixTransform(TransformBase):
+    def __init__(self, weight: Tensor, args: TransformArgs):
+        super().__init__()
+        self.weight = weight  # is an inverse if args.inverse
+        self.args = args
+
+    def forward(self, value: Tensor) -> Parameter:
+        return apply_transform_weight(self.weight, value, self.args.location)
+
+    def right_inverse(self, value: Tensor) -> Tensor:
+        inverse = high_precision_invert(self.weight)
+        return apply_transform_weight(inverse, value, self.args.location)
+
+
+def high_precision_invert(weight: Tensor) -> Tensor:
+    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
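
`high_precision_invert` exists because inverting in a low-precision storage dtype (e.g. bf16) would compound rounding error: the matrix is promoted to fp32 for `torch.linalg.inv` and cast back. A quick numeric check of the round trip (the shift by the identity just keeps this toy matrix well conditioned):

import torch
from compressed_tensors.transform.factory.matrix_multiply import high_precision_invert

w = torch.rand(8, 8) + torch.eye(8)  # well-conditioned example matrix
w_inv = high_precision_invert(w)
assert torch.allclose(w @ w_inv, torch.eye(8), atol=1e-4)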
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from compressed_tensors.transform import HadamardFactory, TransformFactory
+from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
+from torch import device, dtype
+from torch.nn import Parameter
+
+
+@TransformFactory.register("random-hadamard")
+class RandomHadamardFactory(HadamardFactory):
+    """
+    Factory used to apply random hadamard transforms to a model
+
+    :param name: name associated with transform scheme
+    :param scheme: transform scheme which defines how transforms should be created
+    :param seed: random seed used for transform weight randomization
+    """
+
+    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
+        data = random_hadamard_matrix(size, self.generator)
+        data = data.to(dtype=dtype, device=device)
+        return Parameter(data, requires_grad=self.scheme.requires_grad)
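
Since the subclass overrides only `_create_weight`, it inherits the caching, `create_transform`, and hook wiring from `HadamardFactory`; the registry string is what selects it, and the factory's seeded generator makes the random matrix reproducible. A short sketch:

from compressed_tensors.transform import TransformFactory, TransformScheme

factory = TransformFactory.from_scheme(
    TransformScheme(type="random-hadamard"), name="rh", seed=42
)
# factory is a RandomHadamardFactory resolved via the "random-hadamard" registry key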
@@ -87,12 +87,15 @@ def check_accelerate(fallback: Any):
        if not _has_accelerate:
 
            if fallback == "error":
+
                @wraps(func)
                def fallback_fn(*args, **kwargs):
                    raise ValueError(
                        "Please install `accelerate` in order to use this function"
                    )
+
            else:
+
                @wraps(func)
                def fallback_fn(*args, **kwargs):
                    return fallback
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.2.a20250606'
+__version__ = version = '0.10.2.a20250611'
 __version_tuple__ = version_tuple = (0, 10, 2)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.2a20250606
+Version: 0.10.2a20250611
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -75,6 +75,11 @@ src/compressed_tensors/transform/__init__.py
 src/compressed_tensors/transform/transform_args.py
 src/compressed_tensors/transform/transform_config.py
 src/compressed_tensors/transform/transform_scheme.py
+src/compressed_tensors/transform/factory/__init__.py
+src/compressed_tensors/transform/factory/base.py
+src/compressed_tensors/transform/factory/hadamard.py
+src/compressed_tensors/transform/factory/matrix_multiply.py
+src/compressed_tensors/transform/factory/random_hadamard.py
 src/compressed_tensors/transform/utils/__init__.py
 src/compressed_tensors/transform/utils/hadamard.py
 src/compressed_tensors/transform/utils/utils.py
@@ -127,6 +132,8 @@ tests/test_quantization/test_utils/test_helpers.py
 tests/test_transform/test_transform_args.py
 tests/test_transform/test_transform_config.py
 tests/test_transform/test_transform_scheme.py
+tests/test_transform/factory/test_correctness.py
+tests/test_transform/factory/test_memory.py
 tests/test_transform/utils/test_hadamard.py
 tests/test_utils/__init__.py
 tests/test_utils/test_helpers.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformFactory,
+    TransformScheme,
+)
+from compressed_tensors.utils import align_modules, force_cpu_offload
+from tests.testing_utils import requires_accelerate, requires_gpu
+
+
+class TransformableModel(torch.nn.Module):
+    def __init__(self, *sizes):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList([])
+        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
+        for index in range(1, len(sizes) - 1):
+            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+
+    def forward(self, x):
+        for layer in self.fcs:
+            x = layer(x)
+        return x
+
+
+@pytest.mark.parametrize(
+    "scheme",
+    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
+)
+def test_correctness_linear(scheme):
+    size = (4, 8)
+    module = torch.nn.Linear(*size, bias=True)
+    factory = TransformFactory.from_scheme(scheme, name="")
+
+    input_tfm = factory.create_transform(
+        module, TransformArgs(targets="Linear", location="input", inverse=True)
+    )
+    w_in_tfm = factory.create_transform(
+        module, TransformArgs(targets="Linear", location="weight_input")
+    )
+    w_out_tfm = factory.create_transform(
+        module, TransformArgs(targets="Linear", location="weight_output")
+    )
+    output_tfm = factory.create_transform(
+        module, TransformArgs(targets="Linear", location="output", inverse=True)
+    )
+
+    input = torch.rand((17, size[0]))
+    true_output = input @ module.weight.T
+    input_transformed = input_tfm(input)
+    weight_transformed = w_out_tfm(w_in_tfm(module.weight))
+    output = output_tfm(input_transformed @ weight_transformed.T)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@pytest.mark.parametrize(
+    "scheme",
+    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
+)
+def test_correctness_model(scheme, offload=False):
+    # load model
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    if offload:
+        model = force_cpu_offload(model, torch.device("cuda"))
+
+    # create factory
+    scheme.apply = [
+        # weight output -> input
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        # output -> weight input
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        # output -> input
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        # weight output -> weight input
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+    factory = TransformFactory.from_scheme(scheme, name="")
+
+    # create inputs
+    input = torch.rand((17, model.fcs[0].in_features))
+    if offload:
+        input = input.to(torch.device("cuda"))
+
+    # compare outputs
+    true_output = model(input)
+    factory.apply_to_model(model)
+    output = model(input)
+    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize(
+    "scheme",
+    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
+)
+def test_correctness_model_offload(scheme):
+    test_correctness_model(scheme, offload=True)
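
To exercise the new suite locally, the standard pytest invocation applies; the offload variant additionally needs a GPU and `accelerate` installed:

pytest tests/test_transform/factory/test_correctness.py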