compressed-tensors 0.9.5a20250424.tar.gz → 0.9.5a20250428.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/build.yml +1 -1
  2. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/test-check.yaml +5 -5
  3. {compressed_tensors-0.9.5a20250424/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250428}/PKG-INFO +1 -1
  4. compressed_tensors-0.9.5a20250428/pyproject.toml +7 -0
  5. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/initialize.py +3 -6
  6. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/offload.py +40 -1
  7. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/version.py +1 -1
  8. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  9. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_utils/test_offload.py +30 -0
  10. compressed_tensors-0.9.5a20250424/pyproject.toml +0 -3
  11. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/.gitkeep +0 -0
  12. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/actions/test/action.yml +0 -0
  13. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/scripts/step-status +0 -0
  14. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/build-test.yml +0 -0
  15. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/report.yml +0 -0
  16. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/test.yml +0 -0
  17. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/trigger-all.yml +0 -0
  18. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.github/workflows/upload.yml +0 -0
  19. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/.gitignore +0 -0
  20. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/LICENSE +0 -0
  21. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/Makefile +0 -0
  22. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/README.md +0 -0
  23. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  24. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/bit_packing/int4_config.json +0 -0
  25. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/bitmask_compression.ipynb +0 -0
  26. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  27. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  28. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/llama_1.1b/example_quant_config.json +0 -0
  29. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  30. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/examples/quantize_and_pack_int4.ipynb +0 -0
  31. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/setup.cfg +0 -0
  32. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/setup.py +0 -0
  33. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/__init__.py +0 -0
  34. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/README.md +0 -0
  35. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/__init__.py +0 -0
  36. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/base.py +0 -0
  37. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/__init__.py +0 -0
  38. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/base.py +0 -0
  39. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/helpers.py +0 -0
  40. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  41. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  42. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  43. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  44. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  45. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  46. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  47. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  48. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  49. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  50. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  51. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  52. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  53. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/config/__init__.py +0 -0
  54. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/config/base.py +0 -0
  55. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/config/dense.py +0 -0
  56. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  57. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  58. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/linear/__init__.py +0 -0
  59. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  60. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/__init__.py +0 -0
  61. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  62. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  63. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  64. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  65. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  66. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/quant_args.py +0 -0
  67. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/quant_config.py +0 -0
  68. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  69. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  70. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  71. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/registry/__init__.py +0 -0
  72. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/registry/registry.py +0 -0
  73. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/__init__.py +0 -0
  74. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/helpers.py +0 -0
  75. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/permutations_24.py +0 -0
  76. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/permute.py +0 -0
  77. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  78. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  79. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  80. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  81. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors.egg-info/requires.txt +0 -0
  82. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  83. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/__init__.py +0 -0
  84. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/conftest.py +0 -0
  85. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/__init__.py +0 -0
  86. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/model_compressors/__init__.py +0 -0
  87. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  88. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  89. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  90. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  91. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  92. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  93. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  94. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  95. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  96. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  97. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_configs/__init__.py +0 -0
  98. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_configs/test_base.py +0 -0
  99. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  100. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_linear/__init__.py +0 -0
  101. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_linear/test_compressed_linear.py +0 -0
  102. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/__init__.py +0 -0
  103. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/__init__.py +0 -0
  104. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/conftest.py +0 -0
  105. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  106. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  107. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  108. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  109. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  110. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  111. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  112. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_configs/__init__.py +0 -0
  113. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  114. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  115. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_quant_args.py +0 -0
  116. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_quant_config.py +0 -0
  117. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_quant_scheme.py +0 -0
  118. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  119. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_registry.py +0 -0
  120. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_utils/__init__.py +0 -0
  121. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_utils/test_helpers.py +0 -0
  122. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/test_utils/test_safetensors_load.py +0 -0
  123. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/tests/testing_utils.py +0 -0
  124. {compressed_tensors-0.9.5a20250424 → compressed_tensors-0.9.5a20250428}/utils/copyright.py +0 -0
@@ -76,7 +76,7 @@ jobs:
76
76
 
77
77
  - name: build
78
78
  id: build
79
- uses: neuralmagic/nm-actions/actions/build-ml-whl@fix-whl-checks
79
+ uses: neuralmagic/nm-actions/actions/build-ml-whl@c7e5a66c382104e1beadcb7dadf429f8ab15b344 # v1.20.0
80
80
  with:
81
81
  dev: false
82
82
  release: ${{ inputs.wf_category == 'RELEASE' }}
@@ -12,16 +12,16 @@ jobs:
12
12
  python-tests:
13
13
  runs-on: ubuntu-24.04
14
14
  steps:
15
- - uses: actions/setup-python@v4
15
+ - uses: actions/setup-python@v5
16
16
  with:
17
17
  python-version: '3.10'
18
- - uses: actions/checkout@v3
18
+ - uses: actions/checkout@v4
19
+ with:
20
+ fetch-depth: 0
21
+ fetch-tags: true
19
22
  - name: Set Env
20
23
  run: |
21
24
  pip3 install --upgrade pip && pip3 install --upgrade setuptools
22
- pip3 install virtualenv
23
- virtualenv venv
24
- source venv/bin/activate
25
25
  - name: "⚙️ Install dependencies"
26
26
  run: pip3 install .[dev,accelerate]
27
27
  - name: "🔬 Running tests"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.5a20250424
3
+ Version: 0.9.5a20250428
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -0,0 +1,7 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.black]
6
+ line-length = 88
7
+ target-version = ['py36']
@@ -31,7 +31,7 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
31
31
  from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
32
32
  from compressed_tensors.utils import (
33
33
  disable_hf_hook,
34
- has_offloaded_params,
34
+ get_execution_device,
35
35
  register_offload_parameter,
36
36
  )
37
37
  from torch.nn import Module, Parameter
@@ -148,11 +148,8 @@ def _initialize_scale_zero_point(
148
148
  if quantization_args.dynamic:
149
149
  return
150
150
 
151
- # begin on the same device as other parameters or cpu if offloaded.
152
- # in the offloaded case, there's no point moving tensors to the execution device
153
- # if they're going to be immediately offloaded by `register_offload_parameter`
154
- params_device = next(module.parameters()).device
155
- device = "cpu" if has_offloaded_params(module) else params_device
151
+ # initialize on execution device to avoid performing quantized ops on cpu
152
+ device = get_execution_device(module)
156
153
 
157
154
  # infer expected scale/zero point shape
158
155
  if quantization_args.strategy == QuantizationStrategy.TOKEN:
@@ -28,7 +28,7 @@ Utilities associated with offloading functionality provided by `accelerate`.
28
28
  import contextlib
29
29
  import warnings
30
30
  from functools import wraps
31
- from typing import Any, Callable, Dict, Literal, Optional, Union
31
+ from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
32
32
 
33
33
  import torch
34
34
 
@@ -67,6 +67,8 @@ __all__ = [
67
67
  "delete_offload_parameter",
68
68
  "has_offloaded_params",
69
69
  "disable_hf_hook",
70
+ "disable_offload",
71
+ "align_modules",
70
72
  "align_module_device",
71
73
  ]
72
74
 
@@ -344,6 +346,43 @@ def delete_from_weights_map(
344
346
  )
345
347
 
346
348
 
349
+ @contextlib.contextmanager
350
+ def disable_offload(module: torch.nn.Module):
351
+ """
352
+ Context manager to disable module onloading and offloading. Parameters will stay on
353
+ their current device
354
+
355
+ :param module: module to disable offloading for
356
+ """
357
+ if has_offloaded_params(module):
358
+ module._hf_hook.offload = False
359
+ yield
360
+ module._hf_hook.offload = True
361
+ else:
362
+ yield
363
+
364
+
365
+ @contextlib.contextmanager
366
+ def align_modules(
367
+ modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
368
+ execution_device: Optional[torch.device] = None,
369
+ ):
370
+ """
371
+ Context manager for onloading modules to a device, and disabling onload and offload
372
+ attempts triggered by forward calls. Used for sequential onloading of layers
373
+
374
+ :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
375
+ :param execution_device: device to onload to
376
+ """
377
+ modules = (modules,) if isinstance(modules, torch.nn.Module) else modules
378
+
379
+ with contextlib.ExitStack() as stack:
380
+ for module in modules:
381
+ stack.enter_context(align_module_device(module, execution_device))
382
+ stack.enter_context(disable_offload(module)) # disable redundant onloading
383
+ yield
384
+
385
+
347
386
  """ Upstreamed Functions """
348
387
 
349
388
 
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.9.5.a20250424'
20
+ __version__ = version = '0.9.5.a20250428'
21
21
  __version_tuple__ = version_tuple = (0, 9, 5)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.5a20250424
3
+ Version: 0.9.5a20250428
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -15,6 +15,7 @@ import pytest
15
15
  import torch
16
16
  from compressed_tensors.utils import (
17
17
  align_module_device,
18
+ align_modules,
18
19
  delete_offload_parameter,
19
20
  disable_hf_hook,
20
21
  get_execution_device,
@@ -248,6 +249,35 @@ def test_disable_hf_hook_model_recurse():
248
249
  assert hasattr(module2, "_hf_hook")
249
250
 
250
251
 
252
+ @requires_accelerate()
253
+ def test_align_modules():
254
+ from accelerate.hooks import attach_align_device_hook
255
+
256
+ module0 = ExampleModule()
257
+ module1 = ExampleModule()
258
+ module2 = ExampleModule()
259
+ model = torch.nn.Sequential(module0, torch.nn.Sequential(module1, module2))
260
+ attach_align_device_hook(
261
+ model,
262
+ execution_device=torch.device("cpu"),
263
+ offload=True,
264
+ weights_map=model.state_dict(),
265
+ )
266
+
267
+ assert module0.a.device == torch.device("meta")
268
+ assert module1.a.device == torch.device("meta")
269
+ assert module2.a.device == torch.device("meta")
270
+
271
+ with align_modules((module0, module1)):
272
+ assert module0.a.device != torch.device("meta")
273
+ assert module1.a.device != torch.device("meta")
274
+ assert module2.a.device == torch.device("meta")
275
+
276
+ assert module0.a.device == torch.device("meta")
277
+ assert module1.a.device == torch.device("meta")
278
+ assert module2.a.device == torch.device("meta")
279
+
280
+
251
281
  @requires_accelerate()
252
282
  def test_offload_to_weights_map():
253
283
  from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset
@@ -1,3 +0,0 @@
1
- [tool.black]
2
- line-length = 88
3
- target-version = ['py36']