compressed-tensors 0.10.3a20250806__tar.gz → 0.10.3a20250812__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {compressed_tensors-0.10.3a20250806/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250812}/PKG-INFO +1 -1
  2. compressed_tensors-0.10.3a20250812/pyproject.toml +16 -0
  3. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/base.py +8 -3
  4. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +58 -35
  5. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_args.py +3 -1
  6. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_config.py +8 -2
  7. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_scheme.py +4 -2
  8. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/apply.py +4 -0
  9. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/base.py +2 -2
  10. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/hadamard.py +15 -8
  11. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/matrix_multiply.py +17 -8
  12. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_args.py +9 -1
  13. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_config.py +2 -40
  14. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_scheme.py +8 -1
  15. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/__init__.py +1 -0
  16. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/offload.py +15 -1
  17. compressed_tensors-0.10.3a20250812/src/compressed_tensors/utils/type.py +74 -0
  18. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/version.py +1 -1
  19. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  20. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
  21. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_memory.py +1 -1
  22. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_config.py +14 -11
  23. compressed_tensors-0.10.3a20250812/tests/test_utils/test_type.py +79 -0
  24. compressed_tensors-0.10.3a20250806/pyproject.toml +0 -7
  25. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/.gitkeep +0 -0
  26. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/actions/test/action.yml +0 -0
  27. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/scripts/step-status +0 -0
  28. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/build-test.yml +0 -0
  29. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/build.yml +0 -0
  30. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/report.yml +0 -0
  31. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/test-check.yaml +0 -0
  32. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/test.yml +0 -0
  33. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/trigger-all.yml +0 -0
  34. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/upload.yml +0 -0
  35. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.gitignore +0 -0
  36. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/LICENSE +0 -0
  37. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/Makefile +0 -0
  38. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/README.md +0 -0
  39. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  40. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bit_packing/int4_config.json +0 -0
  41. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bitmask_compression.ipynb +0 -0
  42. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  43. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  44. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/example_quant_config.json +0 -0
  45. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  46. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/quantize_and_pack_int4.ipynb +0 -0
  47. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/setup.cfg +0 -0
  48. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/setup.py +0 -0
  49. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/__init__.py +0 -0
  50. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/README.md +0 -0
  51. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/__init__.py +0 -0
  52. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/__init__.py +0 -0
  53. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/base.py +0 -0
  54. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/helpers.py +0 -0
  55. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  56. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  57. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  58. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  59. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  60. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  61. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  62. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  63. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  64. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  65. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  66. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  67. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  68. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/__init__.py +0 -0
  69. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/base.py +0 -0
  70. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/dense.py +0 -0
  71. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  72. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  73. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/linear/__init__.py +0 -0
  74. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  75. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/__init__.py +0 -0
  76. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  77. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  78. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  79. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  80. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  81. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  82. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  83. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  84. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/registry/__init__.py +0 -0
  85. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/registry/registry.py +0 -0
  86. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/__init__.py +0 -0
  87. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  88. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  89. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  90. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  91. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  92. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  93. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/helpers.py +0 -0
  94. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/internal.py +0 -0
  95. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/match.py +0 -0
  96. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/permutations_24.py +0 -0
  97. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/permute.py +0 -0
  98. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  99. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  100. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  101. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/requires.txt +0 -0
  102. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  103. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/__init__.py +0 -0
  104. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/conftest.py +0 -0
  105. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/__init__.py +0 -0
  106. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/model_compressors/__init__.py +0 -0
  107. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  108. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  109. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  110. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  111. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  112. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  113. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  114. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  115. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  116. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  118. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_configs/__init__.py +0 -0
  119. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_configs/test_base.py +0 -0
  120. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  121. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_linear/__init__.py +0 -0
  122. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_linear/test_compressed_linear.py +0 -0
  123. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/__init__.py +0 -0
  124. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/__init__.py +0 -0
  125. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/conftest.py +0 -0
  126. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  127. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  128. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  129. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  130. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  131. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  132. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  133. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/__init__.py +0 -0
  134. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  135. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  136. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_args.py +0 -0
  137. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_config.py +0 -0
  138. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_scheme.py +0 -0
  139. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  140. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_registry.py +0 -0
  141. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/conftest.py +0 -0
  142. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_correctness.py +0 -0
  143. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_serialization.py +0 -0
  144. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_args.py +0 -0
  145. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_scheme.py +0 -0
  146. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/utils/test_hadamard.py +0 -0
  147. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/__init__.py +0 -0
  148. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_helpers.py +0 -0
  149. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_match.py +0 -0
  150. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_offload.py +0 -0
  151. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_safetensors_load.py +0 -0
  152. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/testing_utils.py +0 -0
  153. {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250806
+Version: 0.10.3a20250812
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
pyproject.toml (new file)
@@ -0,0 +1,16 @@
+[build-system]
+requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 88
+target-version = ['py36']
+
+[tool.pytest.ini_options]
+markers = [
+    "unit: tests to ensure code correctness and regression test functionality",
+    "smoke: quick tests to check basic functionality",
+    "sanity: tests to ensure that new changes do not break existing functionality",
+    "regression: detailed tests to ensure major functions work correctly",
+    "integration: tests which integrate with a third party service such as HF",
+]
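Example (not part of the diff): the [tool.pytest.ini_options] table registers the markers used across the test suite, so subsets can be selected with pytest's -m flag. A minimal sketch running the same selection from Python:

import pytest

# run only tests marked @pytest.mark.smoke (equivalent to: pytest -m smoke)
raise SystemExit(pytest.main(["-m", "smoke"]))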
src/compressed_tensors/base.py
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxillary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -29,6 +29,7 @@ from compressed_tensors.base import (
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@ from compressed_tensors.quantization import (
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ class ModelCompressor:
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # TODO: transform config is not support by CompressedTensorsConfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,25 +182,32 @@ class ModelCompressor:
         algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
+        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
 
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        if sparsity_config is None and quantization_config is None:
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
     def parse_sparsity_config(
-        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ) -> Union[Dict[str, Any], None]:
         """
         Parse sparsity config from quantization/compression config. Sparsity
@@ -215,7 +227,7 @@ class ModelCompressor:
 
     @staticmethod
     def parse_quantization_config(
-        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
    ) -> Union[Dict[str, Any], None]:
         """
         Parse quantization config from quantization/compression config. The
@@ -234,6 +246,7 @@ class ModelCompressor:
 
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+        quantization_config.pop(TRANSFORM_CONFIG_NAME, None)
 
         # some fields are required, even if a qconfig is not present
         # pop them off and if nothing remains, then there is no qconfig
@@ -254,13 +267,17 @@ class ModelCompressor:
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +657,49 @@ class ModelCompressor:
 
         :param save_directory: path to a folder containing a HF model config
         """
-        if self.quantization_config is None and self.sparsity_config is None:
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained`` or `from_compression_config``
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
+        # write to config.json file, regardless of whether it exists already
+        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if not os.path.exists(config_file_path):
-            _LOGGER.warning(
-                f"Could not find a valid model config file in "
-                f"{save_directory}. Compression config will not be saved."
-            )
-            return
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
+        else:
+            config_data = {}
 
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
-        # required metadata whenever a quantization or sparsity config is present
-        # overwrite previous config and version if already existing
-        config_data[QUANTIZATION_CONFIG_NAME] = {}
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
-        if self.quantization_config is not None:
-            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
-        else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
 
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
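Example (not part of the diff): with this rewrite, update_config always emits a single "quantization_config" block that nests the sparsity and transform configs alongside the version and method fields. A hedged sketch of the resulting config.json shape (assuming DEFAULT_QUANTIZATION_METHOD is "compressed-tensors"; absent sub-configs serialize as empty dicts):

import json

import compressed_tensors

# mirrors the keys written by the new update_config; quantization fields
# (config_groups, ignore, ...) are splatted in from model_dump()
config_data = {
    "quantization_config": {
        "version": compressed_tensors.__version__,
        "quant_method": "compressed-tensors",
        "sparsity_config": {},
        "transform_config": {},
    }
}
print(json.dumps(config_data, indent=2, sort_keys=True))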
src/compressed_tensors/quantization/quant_args.py
@@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -358,6 +358,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     def get_observer(self) -> str:
         return self.observer
 
+    model_config = ConfigDict(extra="forbid")
+
 
 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
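Example (not part of the diff): extra="forbid" makes pydantic reject unrecognized fields instead of silently dropping them, so typos in serialized quantization args now fail fast. A minimal standalone sketch of the behavior (stand-in model, not the library class):

from pydantic import BaseModel, ConfigDict, ValidationError

class StrictArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    num_bits: int = 8

try:
    StrictArgs(num_bits=4, num_bitz=4)  # misspelled field name
except ValidationError as err:
    print(err)  # "num_bitz: Extra inputs are not permitted"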
src/compressed_tensors/quantization/quant_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.utils import (
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None
 
     def model_post_init(self, __context):
         """
@@ -254,3 +257,6 @@ class QuantizationConfig(BaseModel):
             return True
 
         return False
+
+    # TODO set `extra="forbid"` when upstream transformers is compatible
+    model_config = ConfigDict(extra="ignore")
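Example (not part of the diff): the combination of extra="ignore" and an excluded run_compressed field lets configs written by older transformers versions load cleanly while keeping the dummy key out of new serializations. A standalone sketch of the same pattern (stand-in model, not the library class):

from typing import Annotated, Any

from pydantic import BaseModel, ConfigDict, Field

class Cfg(BaseModel):
    model_config = ConfigDict(extra="ignore")
    run_compressed: Annotated[Any, Field(exclude=True)] = None

cfg = Cfg(run_compressed=True, legacy_key=1)  # unknown key is ignored, not an error
assert "run_compressed" not in cfg.model_dump()  # excluded on the way back out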
src/compressed_tensors/quantization/quant_scheme.py
@@ -14,7 +14,7 @@
 
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
@@ -22,7 +22,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [
@@ -81,6 +81,8 @@ class QuantizationScheme(BaseModel):
 
         return model
 
+    model_config = ConfigDict(extra="forbid")
+
 
 """
 Pre-Set Quantization Scheme Args
src/compressed_tensors/transform/apply.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
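Example (not part of the diff): because apply_transform_config now attaches the config to the model, ModelCompressor.from_pretrained_model can pick it up via getattr(model, TRANSFORM_CONFIG_NAME, None) and serialize it into config.json. A hedged usage sketch:

import torch
from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
config = TransformConfig(
    config_groups={
        "u": TransformScheme(
            type="hadamard",
            apply=[TransformArgs(targets=["Linear"], location="weight_output")],
        )
    }
)
apply_transform_config(model, config)
assert getattr(model, TRANSFORM_CONFIG_NAME) is config  # ready for serialization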
src/compressed_tensors/transform/factory/base.py
@@ -14,11 +14,10 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -34,6 +33,7 @@ from compressed_tensors.utils import (
     register_offload_module,
     update_offload_parameter,
 )
+from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
 
src/compressed_tensors/transform/factory/hadamard.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@ from compressed_tensors.transform.utils.matrix import (
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ class HadamardFactory(TransformFactory):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,15 +84,18 @@ class HadamardTransform(TransformBase):
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale = math.sqrt(weight.size(0))
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +107,11 @@ class HadamardTransform(TransformBase):
             weight = weight.T
 
         return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            apply_transform_weight(
+                weight.to(self._precision),
+                value.to(self._precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
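Example (not part of the diff): the forward pass now upcasts both the Hadamard weight and the activation to the scheme's precision, divides by sqrt(n) to keep the rotation orthonormal, and casts back to the caller's dtype. A simplified sketch of that arithmetic (the location-specific multiply done by apply_transform_weight is reduced to a plain matmul here):

import torch

def hadamard_rotate(weight: torch.Tensor, value: torch.Tensor,
                    precision: torch.dtype = torch.float32) -> torch.Tensor:
    # compute in `precision`, then return in the activation's original dtype
    scale = torch.tensor(weight.size(0), dtype=precision).sqrt()
    out = (value.to(precision) @ weight.to(precision)) / scale
    return out.to(value.dtype)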
src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -24,7 +24,7 @@ from compressed_tensors.transform.utils.matrix import (
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ class RandomMatrixFactory(TransformFactory):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args, type(module))
+        return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
     def __init__(
         self,
         weight: Tensor,
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Parameter:
         return apply_transform_weight(
-            self.weight, value, self.args.location, self.module_type
-        )
+            self.weight.to(self._precision),
+            value.to(self._precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse, value, self.args.location, self.module_type
-        )
+            inverse.to(self._precision),
+            value.to(self._precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
-    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
+    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
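Example (not part of the diff): high_precision_invert now inverts in float64 rather than float32, tightening how closely weight @ inverse reconstructs the identity before the result is cast back down. A quick sketch comparing the two paths (the library function additionally casts the inverse back to weight.dtype):

import torch

torch.manual_seed(0)
w = torch.randn(64, 64, dtype=torch.bfloat16)
for precision in (torch.float32, torch.float64):  # old vs. new behavior
    inv = torch.linalg.inv(w.to(precision))
    err = (w.to(precision) @ inv - torch.eye(64, dtype=precision)).abs().max()
    print(precision, float(err))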
src/compressed_tensors/transform/transform_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import List
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 __all__ = ["TransformArgs", "TransformLocation"]
@@ -68,3 +68,11 @@ class TransformArgs(BaseModel, use_enum_values=True):
         if isinstance(value, str):
             return [value]
         return value
+
+    def is_online(self) -> bool:
+        return self.location not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+
+    model_config = ConfigDict(extra="forbid")
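Example (not part of the diff): is_online() distinguishes transforms applied to activations at runtime ("online") from those fused into module weights offline; the new precision handling runs online rotations in scheme.precision and offline ones in float64. A hedged sketch:

from compressed_tensors.transform import TransformArgs

online = TransformArgs(targets=["Linear"], location="input")
offline = TransformArgs(targets=["Linear"], location="weight_input", inverse=True)
assert online.is_online()       # applied to activations during inference
assert not offline.is_online()  # folded into the weights ahead of time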
src/compressed_tensors/transform/transform_config.py
@@ -15,7 +15,7 @@
 from typing import Dict
 
 from compressed_tensors.transform import TransformArgs, TransformScheme
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
 
     config_groups: Dict[str, TransformScheme]
 
-
-# quip / quip sharp
-QUIP = TransformConfig(
-    config_groups={
-        "v": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="input",  # non-mergable
-                ),
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_input",
-                    inverse=True,
-                ),
-            ],
-            randomize=True,
-        ),
-        "u": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_output",
-                ),
-                TransformArgs(
-                    targets=["Linear"], location="output", inverse=True  # non-mergable
-                ),
-            ],
-            randomize=True,
-        ),
-    }
-)
-
-
-PRESET_CONFIGS = {
-    "QUIP": QUIP,
-}
+    model_config = ConfigDict(extra="forbid")
src/compressed_tensors/transform/transform_scheme.py
@@ -14,8 +14,10 @@
 
 from typing import List, Optional
 
+import torch
 from compressed_tensors.transform import TransformArgs
-from pydantic import BaseModel, Field
+from compressed_tensors.utils import TorchDtype
+from pydantic import BaseModel, ConfigDict, Field
 
 
 __all__ = ["TransformScheme"]
@@ -34,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param precision: Precision at which this transform should be applied during online
+        rotations. Fused (offline) rotations are always performed in float64
     """
 
     type: str
@@ -41,3 +45,6 @@ class TransformScheme(BaseModel):
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
+    precision: TorchDtype = Field(default=torch.float32)
+
+    model_config = ConfigDict(extra="forbid")
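Example (not part of the diff): a scheme can now pin the dtype used for its online rotations via the new precision field, serialized through the TorchDtype annotation added in utils/type.py. A hedged construction sketch:

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["Linear"], location="input")],
    precision=torch.bfloat16,  # online rotations run in bf16; fused ones stay float64
)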
src/compressed_tensors/utils/__init__.py
@@ -21,3 +21,4 @@ from .permutations_24 import *
 from .permute import *
 from .safetensors_load import *
 from .semi_structured_conversions import *
+from .type import *
src/compressed_tensors/utils/offload.py
@@ -86,6 +86,7 @@ __all__ = [
     "offloaded_dispatch",
     "disable_offloading",
     "remove_dispatch",
+    "cast_to_device",
 ]
 
 
@@ -169,6 +170,19 @@ def update_parameter_data(
 """ Candidates for Upstreaming """
 
 
+def cast_to_device(device_spec: Union[int, torch.device]) -> torch.device:
+    """
+    Convert an integer device index or torch.device into a torch.device object.
+
+    :param device_spec: Device index (int) or torch.device object.
+        Negative integers map to CPU.
+    :return: torch.device corresponding to the given device specification.
+    """
+    if isinstance(device_spec, int):
+        return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
+    return device_spec
+
+
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
     for submodule in module.modules():
         if has_offloaded_params(submodule):
-            return submodule._hf_hook.execution_device
+            return cast_to_device(submodule._hf_hook.execution_device)
 
         param = next(submodule.parameters(recurse=False), None)
         if param is not None:
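Example (not part of the diff): cast_to_device normalizes accelerate-style integer execution-device indices into torch.device objects, which is what get_execution_device now returns in all cases. A small usage sketch:

import torch
from compressed_tensors.utils import cast_to_device

assert cast_to_device(0) == torch.device("cuda:0")  # non-negative index -> cuda
assert cast_to_device(-1) == torch.device("cpu")    # negative index -> cpu
assert cast_to_device(torch.device("cpu")) == torch.device("cpu")  # passthrough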