compressed-tensors 0.9.4a20250421.tar.gz → 0.9.5a20250424.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/build.yml +1 -1
  2. {compressed_tensors-0.9.4a20250421/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250424}/PKG-INFO +1 -1
  3. compressed_tensors-0.9.5a20250424/pyproject.toml +3 -0
  4. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/setup.py +1 -0
  5. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/offload.py +20 -16
  6. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/version.py +2 -2
  7. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  8. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_utils/test_offload.py +46 -1
  9. compressed_tensors-0.9.4a20250421/pyproject.toml +0 -10
  10. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/.gitkeep +0 -0
  11. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/actions/test/action.yml +0 -0
  12. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/scripts/step-status +0 -0
  13. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/build-test.yml +0 -0
  14. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/report.yml +0 -0
  15. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/test-check.yaml +0 -0
  16. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/test.yml +0 -0
  17. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/trigger-all.yml +0 -0
  18. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.github/workflows/upload.yml +0 -0
  19. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/.gitignore +0 -0
  20. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/LICENSE +0 -0
  21. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/Makefile +0 -0
  22. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/README.md +0 -0
  23. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  24. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/bit_packing/int4_config.json +0 -0
  25. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/bitmask_compression.ipynb +0 -0
  26. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  27. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  28. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/llama_1.1b/example_quant_config.json +0 -0
  29. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  30. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/examples/quantize_and_pack_int4.ipynb +0 -0
  31. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/setup.cfg +0 -0
  32. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/__init__.py +0 -0
  33. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/README.md +0 -0
  34. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/__init__.py +0 -0
  35. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/base.py +0 -0
  36. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/__init__.py +0 -0
  37. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/base.py +0 -0
  38. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/helpers.py +0 -0
  39. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  40. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  41. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  42. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  43. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  44. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  45. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  46. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  47. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  48. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  49. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  50. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  51. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  52. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/config/__init__.py +0 -0
  53. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/config/base.py +0 -0
  54. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/config/dense.py +0 -0
  55. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  56. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  57. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/linear/__init__.py +0 -0
  58. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  59. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/__init__.py +0 -0
  60. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  61. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  62. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  63. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  64. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  65. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  66. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/quant_args.py +0 -0
  67. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/quant_config.py +0 -0
  68. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  69. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  70. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  71. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/registry/__init__.py +0 -0
  72. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/registry/registry.py +0 -0
  73. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/__init__.py +0 -0
  74. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/helpers.py +0 -0
  75. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/permutations_24.py +0 -0
  76. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/permute.py +0 -0
  77. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  78. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  79. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  80. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  81. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors.egg-info/requires.txt +0 -0
  82. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  83. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/__init__.py +0 -0
  84. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/conftest.py +0 -0
  85. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/__init__.py +0 -0
  86. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/model_compressors/__init__.py +0 -0
  87. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  88. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  89. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  90. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  91. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  92. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  93. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  94. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  95. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  96. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  97. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_configs/__init__.py +0 -0
  98. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_configs/test_base.py +0 -0
  99. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  100. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_linear/__init__.py +0 -0
  101. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_linear/test_compressed_linear.py +0 -0
  102. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/__init__.py +0 -0
  103. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/__init__.py +0 -0
  104. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/conftest.py +0 -0
  105. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  106. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  107. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  108. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  109. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  110. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  111. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  112. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_configs/__init__.py +0 -0
  113. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  114. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  115. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_quant_args.py +0 -0
  116. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_quant_config.py +0 -0
  117. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_quant_scheme.py +0 -0
  118. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  119. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_registry.py +0 -0
  120. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_utils/__init__.py +0 -0
  121. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_utils/test_helpers.py +0 -0
  122. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/test_utils/test_safetensors_load.py +0 -0
  123. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/tests/testing_utils.py +0 -0
  124. {compressed_tensors-0.9.4a20250421 → compressed_tensors-0.9.5a20250424}/utils/copyright.py +0 -0
@@ -76,7 +76,7 @@ jobs:
76
76
 
77
77
  - name: build
78
78
  id: build
79
- uses: neuralmagic/nm-actions/actions/build-ml-whl@v1.18.0
79
+ uses: neuralmagic/nm-actions/actions/build-ml-whl@fix-whl-checks
80
80
  with:
81
81
  dev: false
82
82
  release: ${{ inputs.wf_category == 'RELEASE' }}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.4a20250421
3
+ Version: 0.9.5a20250424
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -0,0 +1,3 @@
1
+ [tool.black]
2
+ line-length = 88
3
+ target-version = ['py36']
@@ -101,6 +101,7 @@ setup(
101
101
  use_scm_version={
102
102
  "version_scheme": version_func,
103
103
  "local_scheme": localversion_func,
104
+ "version_file": "src/compressed_tensors/version.py",
104
105
  },
105
106
  author="Neuralmagic, Inc.",
106
107
  author_email="support@neuralmagic.com",
@@ -94,22 +94,6 @@ def is_module_offloaded(module: torch.nn.Module) -> bool:
94
94
  return has_offloaded_params(module)
95
95
 
96
96
 
97
- def get_execution_device(module: torch.nn.Module) -> torch.device:
98
- """
99
- :param module: module to check
100
- :return: device module is loaded onto during forward pass
101
- """
102
- if has_offloaded_params(module):
103
- return module._hf_hook.execution_device
104
- device = next(module.parameters()).device
105
-
106
- # offload only gets set for leaf modules, fallback to checking for device type
107
- if device.type == "meta":
108
- return module._hf_hook.execution_device
109
-
110
- return device
111
-
112
-
113
97
  def get_offloaded_device(module: torch.nn.Module) -> torch.device:
114
98
  """
115
99
  :param module: module to check
@@ -158,6 +142,26 @@ def update_parameter_data(
158
142
  """ Candidates for Upstreaming """
159
143
 
160
144
 
145
+ def get_execution_device(module: torch.nn.Module) -> torch.device:
146
+ """
147
+ Get the device which inputs should be moved to before module execution
148
+
149
+ :param module: module to check, may be offloaded
150
+ :return: onload device of module
151
+ """
152
+ if has_offloaded_params(module):
153
+ return module._hf_hook.execution_device
154
+
155
+ first_param = next(module.parameters(), None)
156
+ if first_param is None:
157
+ warnings.warn(
158
+ f"Unable able to infer execution device of {module}, falling back to CPU"
159
+ )
160
+ return torch.device("cpu")
161
+
162
+ return first_param.device
163
+
164
+
161
165
  def register_offload_parameter(
162
166
  module: torch.nn.Module,
163
167
  name: str,
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.9.4.a20250421'
21
- __version_tuple__ = version_tuple = (0, 9, 4, 'a20250421')
20
+ __version__ = version = '0.9.5.a20250424'
21
+ __version_tuple__ = version_tuple = (0, 9, 5)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.4a20250421
3
+ Version: 0.9.5a20250424
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -17,12 +17,13 @@ from compressed_tensors.utils import (
17
17
  align_module_device,
18
18
  delete_offload_parameter,
19
19
  disable_hf_hook,
20
+ get_execution_device,
20
21
  has_offloaded_params,
21
22
  register_offload_parameter,
22
23
  update_offload_parameter,
23
24
  )
24
25
  from compressed_tensors.utils.offload import offload_to_weights_map
25
- from tests.testing_utils import requires_accelerate
26
+ from tests.testing_utils import requires_accelerate, requires_gpu
26
27
 
27
28
 
28
29
  class ExampleModule(torch.nn.Module):
@@ -55,8 +56,46 @@ def test_has_offloaded_params():
55
56
  assert has_offloaded_params(module)
56
57
 
57
58
 
59
+ @requires_gpu
60
+ @requires_accelerate()
61
+ def test_get_execution_device():
62
+ from accelerate import init_empty_weights
63
+ from accelerate.big_modeling import attach_align_device_hook
64
+
65
+ # no offloading
66
+ module = ExampleModule()
67
+ assert get_execution_device(module) == torch.device("cpu")
68
+
69
+ # with offloading
70
+ attach_align_device_hook(module, torch.device("cuda:0"))
71
+ assert get_execution_device(module) == torch.device("cuda:0")
72
+
73
+ # in meta context
74
+ with torch.device("meta"):
75
+ module = ExampleModule()
76
+ assert get_execution_device(module) == torch.device("meta")
77
+
78
+ # offloaded in meta context
79
+ module = ExampleModule()
80
+ attach_align_device_hook(module, torch.device("cuda:0"))
81
+ with torch.device("meta"):
82
+ assert get_execution_device(module) == torch.device("cuda:0")
83
+
84
+ # in empty weights context
85
+ with init_empty_weights():
86
+ module = ExampleModule()
87
+ assert get_execution_device(module) == torch.device("meta")
88
+
89
+ # offloaded in empty weights context
90
+ module = ExampleModule()
91
+ attach_align_device_hook(module, torch.device("cuda:0"))
92
+ with init_empty_weights():
93
+ assert get_execution_device(module) == torch.device("cuda:0")
94
+
95
+
58
96
  @requires_accelerate()
59
97
  def test_register_offload_parameter():
98
+ from accelerate import init_empty_weights
60
99
  from accelerate.hooks import attach_align_device_hook
61
100
 
62
101
  module = ExampleModule()
@@ -94,6 +133,12 @@ def test_register_offload_parameter():
94
133
  assert module.f.device == torch.device("cpu")
95
134
  assert module._hf_hook.weights_map["f"].device == torch.device("cpu")
96
135
 
136
+ # parameters registered in the empty init context are still empty
137
+ with init_empty_weights():
138
+ module = ExampleModule()
139
+ register_offload_parameter(module, "c", parameter)
140
+ assert module.a.device == module.b.device == module.c.device == torch.device("meta")
141
+
97
142
 
98
143
  @requires_accelerate()
99
144
  def test_update_offload_parameter():
@@ -1,10 +0,0 @@
1
- [build-system]
2
- requires = ["setuptools", "wheel", "setuptools_scm>8"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [tool.setuptools_scm]
6
- version_file = "src/compressed_tensors/version.py"
7
-
8
- [tool.black]
9
- line-length = 88
10
- target-version = ['py36']