compressed-tensors 0.10.2a20250613__tar.gz → 0.10.2a20250617__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {compressed_tensors-0.10.2a20250613/src/compressed_tensors.egg-info → compressed_tensors-0.10.2a20250617}/PKG-INFO +1 -1
  2. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/offload.py +55 -8
  3. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/version.py +1 -1
  4. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  5. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_utils/test_offload.py +41 -0
  6. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/.gitkeep +0 -0
  7. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/actions/test/action.yml +0 -0
  8. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/scripts/step-status +0 -0
  9. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/build-test.yml +0 -0
  10. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/build.yml +0 -0
  11. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/report.yml +0 -0
  12. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/test-check.yaml +0 -0
  13. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/test.yml +0 -0
  14. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/trigger-all.yml +0 -0
  15. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.github/workflows/upload.yml +0 -0
  16. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/.gitignore +0 -0
  17. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/LICENSE +0 -0
  18. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/Makefile +0 -0
  19. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/README.md +0 -0
  20. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  21. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/bit_packing/int4_config.json +0 -0
  22. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/bitmask_compression.ipynb +0 -0
  23. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  24. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  25. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/llama_1.1b/example_quant_config.json +0 -0
  26. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  27. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/examples/quantize_and_pack_int4.ipynb +0 -0
  28. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/pyproject.toml +0 -0
  29. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/setup.cfg +0 -0
  30. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/setup.py +0 -0
  31. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/__init__.py +0 -0
  32. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/README.md +0 -0
  33. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/__init__.py +0 -0
  34. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/base.py +0 -0
  35. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/__init__.py +0 -0
  36. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/base.py +0 -0
  37. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/helpers.py +0 -0
  38. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  39. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  40. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  41. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  42. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  43. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  44. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  45. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  46. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  47. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  48. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  49. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  50. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  51. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  52. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/config/__init__.py +0 -0
  53. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/config/base.py +0 -0
  54. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/config/dense.py +0 -0
  55. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  56. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  57. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/linear/__init__.py +0 -0
  58. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  59. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/__init__.py +0 -0
  60. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  61. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  62. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  63. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  64. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  65. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  66. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/quant_args.py +0 -0
  67. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/quant_config.py +0 -0
  68. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  69. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  70. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  71. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/registry/__init__.py +0 -0
  72. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/registry/registry.py +0 -0
  73. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/__init__.py +0 -0
  74. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  75. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/factory/base.py +0 -0
  76. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  77. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  78. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  79. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/transform_args.py +0 -0
  80. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/transform_config.py +0 -0
  81. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  82. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  83. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  84. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  85. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/transform/utils/utils.py +0 -0
  86. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/__init__.py +0 -0
  87. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/helpers.py +0 -0
  88. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/permutations_24.py +0 -0
  89. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/permute.py +0 -0
  90. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  91. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  92. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  93. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  94. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors.egg-info/requires.txt +0 -0
  95. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  96. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/__init__.py +0 -0
  97. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/conftest.py +0 -0
  98. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/__init__.py +0 -0
  99. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/model_compressors/__init__.py +0 -0
  100. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  101. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  102. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  103. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  104. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  105. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  106. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  107. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  108. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  109. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  110. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  111. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_configs/__init__.py +0 -0
  112. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_configs/test_base.py +0 -0
  113. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  114. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_linear/__init__.py +0 -0
  115. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_linear/test_compressed_linear.py +0 -0
  116. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/__init__.py +0 -0
  117. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/__init__.py +0 -0
  118. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/conftest.py +0 -0
  119. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  120. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  121. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  122. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  123. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  124. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  125. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  126. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_configs/__init__.py +0 -0
  127. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  128. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  129. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_quant_args.py +0 -0
  130. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_quant_config.py +0 -0
  131. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_quant_scheme.py +0 -0
  132. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  133. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_registry.py +0 -0
  134. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/factory/test_correctness.py +0 -0
  135. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/factory/test_memory.py +0 -0
  136. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/test_transform_args.py +0 -0
  137. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/test_transform_config.py +0 -0
  138. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/test_transform_scheme.py +0 -0
  139. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_transform/utils/test_hadamard.py +0 -0
  140. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_utils/__init__.py +0 -0
  141. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_utils/test_helpers.py +0 -0
  142. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/test_utils/test_safetensors_load.py +0 -0
  143. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/tests/testing_utils.py +0 -0
  144. {compressed_tensors-0.10.2a20250613 → compressed_tensors-0.10.2a20250617}/utils/copyright.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.2a20250613
3
+ Version: 0.10.2a20250617
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -85,6 +85,7 @@ __all__ = [
85
85
  "delete_offload_module",
86
86
  "offloaded_dispatch",
87
87
  "disable_offloading",
88
+ "remove_dispatch",
88
89
  ]
89
90
 
90
91
 
@@ -206,9 +207,24 @@ def register_offload_parameter(
206
207
  has_onload = any(p.device != torch.device("meta") for p in module.parameters())
207
208
  module.register_parameter(name, parameter)
208
209
 
210
+ # do everything AlignDevicesHook.init_hook does
211
+ # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
209
212
  if has_offloaded_params(module):
210
- weights_map = module._hf_hook.weights_map
211
- offload_to_weights_map(weights_map, name, parameter.data, offload_device)
213
+ hook: AlignDevicesHook = module._hf_hook
214
+ assert hook.weights_map is not None
215
+
216
+ # append to original_devices
217
+ hook.original_devices[name] = parameter.device
218
+
219
+ # append to weights map
220
+ offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
221
+
222
+ # append to tied_params_map
223
+ offloaded = hook.weights_map[name]
224
+ if hook.tied_params_map is not None:
225
+ hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
226
+
227
+ # perform offloading
212
228
  if not has_onload:
213
229
  set_module_tensor_to_device(module, name, "meta")
214
230
 
@@ -422,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
422
438
  hook: AlignDevicesHook = base._hf_hook
423
439
  assert hook.offload
424
440
  assert hook.weights_map is not None
425
- assert hook.tied_params_map is not None
426
441
 
427
442
  # offloading kwargs for submodule
428
443
  place_submodules = False
@@ -437,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
437
452
  module, include_buffers=offload_buffers, recurse=place_submodules
438
453
  ):
439
454
  offloaded = param.to(offload_device)
440
- hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
455
+ if hook.tied_params_map is not None:
456
+ hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
441
457
  offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
442
458
 
443
459
  # if the parent places submodules, offload here
@@ -465,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
465
481
 
466
482
  base.register_module(name, module)
467
483
 
468
- # (1): Since we cannot know which pointers are shared when we add parameters in an
469
- # online way, assume that all pointers are shared. This comes at no runtime cost
470
-
471
484
 
472
485
  def delete_offload_module(base: torch.nn.Module, name: str):
473
486
  """
@@ -502,6 +515,9 @@ def offloaded_dispatch(
502
515
  if offload_device == "disk":
503
516
  raise NotImplementedError("Disk offloading is not currently supported")
504
517
 
518
+ # remove any existing hooks
519
+ remove_dispatch(module)
520
+
505
521
  # create weights map
506
522
  state_dict = module.state_dict()
507
523
  state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
@@ -523,6 +539,33 @@ def offloaded_dispatch(
523
539
  weights_map=weights_map,
524
540
  tied_params_map=tied_params_map,
525
541
  )
542
+
543
+ # when saving a model, `PretrainedModel.save_pretrained` will only
544
+ # onload weights if the following requirements are met
545
+ # if (
546
+ # hasattr(self, "hf_device_map")
547
+ # and len(set(self.hf_device_map.values())) > 1
548
+ # and ("cpu" in self.hf_device_map.values()
549
+ # or "disk" in self.hf_device_map.values())
550
+ # ):
551
+ # because this function always offloads, disregard actual devices and
552
+ # always use `cpu` and `cuda:0` to guarantee this condition passes
553
+ setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
554
+
555
+ return module
556
+
557
+
558
+ def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
559
+ """
560
+ Remove any existing dispatches from module
561
+
562
+ :param module: module which may be dispatched with hf hooks
563
+ :return: module without dispatch
564
+ """
565
+ remove_hook_from_module(module, recurse=True)
566
+ if hasattr(module, "hf_device_map"):
567
+ delattr(module, "hf_device_map")
568
+
526
569
  return module
527
570
 
528
571
 
@@ -551,7 +594,7 @@ def disable_offloading():
551
594
  # update any parameters which may have changed
552
595
  for module, (hook, offload) in onloaded_modules.items():
553
596
  hook.offload = offload
554
- for name, param in module.named_parameters():
597
+ for name, param in module.named_parameters(recurse=False):
555
598
  update_offload_parameter(module, name, param.data)
556
599
  hook.post_forward(module, None)
557
600
 
@@ -623,3 +666,7 @@ def align_module_device(
623
666
 
624
667
  else:
625
668
  yield
669
+
670
+
671
+ # (1): Since we cannot know which pointers are shared when we add parameters in an
672
+ # online way, assume that all pointers are shared. This has virtually no runtime cost
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.10.2.a20250613'
20
+ __version__ = version = '0.10.2.a20250617'
21
21
  __version_tuple__ = version_tuple = (0, 10, 2)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.2a20250613
3
+ Version: 0.10.2a20250617
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -149,6 +149,47 @@ def test_register_offload_parameter():
149
149
  assert module.a.device == module.b.device == module.c.device == torch.device("meta")
150
150
 
151
151
 
152
+ @requires_accelerate()
153
+ @requires_gpu
154
+ def test_register_offload_parameter_hook_replacement():
155
+ module = ExampleModule()
156
+ parameter_c = torch.nn.Parameter(torch.tensor(1.0, device="cuda"))
157
+ parameter_d = torch.nn.Parameter(torch.tensor(1.0, device="cpu"))
158
+
159
+ offloaded_dispatch(module, "cuda")
160
+ register_offload_parameter(module, "c", parameter_c)
161
+ register_offload_parameter(module, "d", parameter_d)
162
+
163
+ with disable_hf_hook(module):
164
+ assert module.a.device == torch.device("cpu")
165
+ assert module.b.device == torch.device("cpu")
166
+ assert module.c.device == torch.device("cuda:0")
167
+ assert module.d.device == torch.device("cpu")
168
+
169
+ assert module.a.device == torch.device("meta")
170
+ assert module.b.device == torch.device("meta")
171
+ assert module.c.device == torch.device("meta")
172
+ assert module.d.device == torch.device("meta")
173
+ assert module._hf_hook.weights_map["a"].device == torch.device("cpu")
174
+ assert module._hf_hook.weights_map["b"].device == torch.device("cpu")
175
+ assert module._hf_hook.weights_map["c"].device == torch.device("cpu")
176
+ assert module._hf_hook.weights_map["d"].device == torch.device("cpu")
177
+
178
+
179
+ @requires_accelerate()
180
+ @requires_gpu
181
+ def test_register_offload_parameter_shared():
182
+ module = ExampleModule()
183
+ parameter = torch.nn.Parameter(torch.tensor(1.0))
184
+
185
+ offloaded_dispatch(module, "cuda")
186
+ register_offload_parameter(module, "c", parameter)
187
+ register_offload_parameter(module, "d", parameter)
188
+
189
+ with align_module_device(module):
190
+ assert module.c is module.d
191
+
192
+
152
193
  @requires_accelerate()
153
194
  def test_update_offload_parameter():
154
195
  from accelerate.hooks import attach_align_device_hook