compressed-tensors 0.9.5a20250602.tar.gz → 0.9.5a20250604.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. {compressed_tensors-0.9.5a20250602/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250604}/PKG-INFO +1 -1
  2. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +16 -18
  3. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/quantized_compressors/base.py +30 -3
  4. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -10
  5. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/forward.py +2 -3
  6. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/initialize.py +7 -113
  7. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/utils/helpers.py +10 -6
  8. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/offload.py +134 -1
  9. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/version.py +1 -1
  10. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  11. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_utils/test_helpers.py +4 -6
  12. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_utils/test_offload.py +95 -5
  13. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/.gitkeep +0 -0
  14. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/actions/test/action.yml +0 -0
  15. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/scripts/step-status +0 -0
  16. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/build-test.yml +0 -0
  17. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/build.yml +0 -0
  18. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/report.yml +0 -0
  19. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/test-check.yaml +0 -0
  20. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/test.yml +0 -0
  21. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/trigger-all.yml +0 -0
  22. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.github/workflows/upload.yml +0 -0
  23. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/.gitignore +0 -0
  24. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/LICENSE +0 -0
  25. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/Makefile +0 -0
  26. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/README.md +0 -0
  27. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  28. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/bit_packing/int4_config.json +0 -0
  29. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/bitmask_compression.ipynb +0 -0
  30. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  31. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  32. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/llama_1.1b/example_quant_config.json +0 -0
  33. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  34. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/examples/quantize_and_pack_int4.ipynb +0 -0
  35. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/pyproject.toml +0 -0
  36. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/setup.cfg +0 -0
  37. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/setup.py +0 -0
  38. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/__init__.py +0 -0
  39. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/README.md +0 -0
  40. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/__init__.py +0 -0
  41. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/base.py +0 -0
  42. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/__init__.py +0 -0
  43. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/base.py +0 -0
  44. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/helpers.py +0 -0
  45. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  46. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  47. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  48. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  49. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  50. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  51. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  52. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  53. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  54. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  55. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  56. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  57. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/config/__init__.py +0 -0
  58. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/config/base.py +0 -0
  59. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/config/dense.py +0 -0
  60. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  61. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  62. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/linear/__init__.py +0 -0
  63. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  64. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/__init__.py +0 -0
  65. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  66. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  67. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  68. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/quant_args.py +0 -0
  69. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/quant_config.py +0 -0
  70. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  71. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  72. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/registry/__init__.py +0 -0
  73. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/registry/registry.py +0 -0
  74. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/transform/__init__.py +0 -0
  75. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/transform/transform_args.py +0 -0
  76. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/transform/transform_config.py +0 -0
  77. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  78. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/__init__.py +0 -0
  79. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/helpers.py +0 -0
  80. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/permutations_24.py +0 -0
  81. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/permute.py +0 -0
  82. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  83. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  84. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  85. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  86. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors.egg-info/requires.txt +0 -0
  87. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  88. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/__init__.py +0 -0
  89. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/conftest.py +0 -0
  90. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/__init__.py +0 -0
  91. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/model_compressors/__init__.py +0 -0
  92. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  93. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  94. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  95. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  96. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  97. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  98. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  99. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  100. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  101. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  102. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  103. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_configs/__init__.py +0 -0
  104. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_configs/test_base.py +0 -0
  105. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  106. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_linear/__init__.py +0 -0
  107. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_linear/test_compressed_linear.py +0 -0
  108. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/__init__.py +0 -0
  109. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/__init__.py +0 -0
  110. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/conftest.py +0 -0
  111. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  112. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  113. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  114. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  115. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  116. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  117. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  118. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_configs/__init__.py +0 -0
  119. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  120. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  121. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_quant_args.py +0 -0
  122. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_quant_config.py +0 -0
  123. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_quantization/test_quant_scheme.py +0 -0
  124. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_registry.py +0 -0
  125. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_transform/test_transform_args.py +0 -0
  126. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_transform/test_transform_config.py +0 -0
  127. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_transform/test_transform_scheme.py +0 -0
  128. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_utils/__init__.py +0 -0
  129. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_utils/test_helpers.py +0 -0
  130. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/test_utils/test_safetensors_load.py +0 -0
  131. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/tests/testing_utils.py +0 -0
  132. {compressed_tensors-0.9.5a20250602 → compressed_tensors-0.9.5a20250604}/utils/copyright.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250602
+Version: 0.9.5a20250604
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/model_compressors/model_compressor.py

@@ -50,6 +50,7 @@ from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
+    get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -408,16 +409,17 @@ class ModelCompressor:
                     )

                 # remove any existing parameters
-                device = get_execution_device(module)
+                exec_device = get_execution_device(module)
+                offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
-                    delattr(module, name)
+                    delete_offload_parameter(module, name)

                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(device)
+                    value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param)
+                    register_offload_parameter(module, name, param, offload_device)

                 module.quantization_status = QuantizationStatus.COMPRESSED

@@ -460,30 +462,26 @@ class ModelCompressor:

                 # quantization second
                 if prefix in module_to_scheme:
-                    generator = self.quantization_compressor.decompress_from_state_dict(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
+                    state_dict = (
+                        self.quantization_compressor.decompress_module_from_state_dict(
+                            prefix,
+                            state_dict,
+                            scheme=module_to_scheme[prefix],
+                        )
                     )
-                    # generates (mod_path, {param_name, param_val})
-                    # of compressed params and used params, but not unused params
-                    # some used params are removed by get_unexpected_file_keys
-                    state_dict = {
-                        merge_names(module_path, param_name): param_value
-                        for module_path, compressed_data in generator
-                        for param_name, param_value in compressed_data.items()
-                    }

                 # remove any existing parameters
-                device = get_execution_device(module)
+                exec_device = get_execution_device(module)
+                offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
                     delete_offload_parameter(module, name)

                 # replace with decompressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(device)
+                    value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param)
+                    register_offload_parameter(module, name, param, offload_device)

                 module.quantization_status = QuantizationStatus.FROZEN

src/compressed_tensors/compressors/quantized_compressors/base.py

@@ -24,6 +24,7 @@ from compressed_tensors.utils import (
     get_nested_weight_mappings,
     merge_names,
 )
+from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -223,9 +224,7 @@ class BaseQuantizationCompressor(BaseCompressor):
             state_dict, self.compression_param_names
         )
         for module_path in weight_mappings.keys():
-            weight_data = {}
-            for param_name, param_value in weight_mappings[module_path].items():
-                weight_data[param_name] = param_value
+            weight_data = weight_mappings[module_path].copy()

             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[module_path].weights
@@ -234,3 +233,31 @@ class BaseQuantizationCompressor(BaseCompressor):
                 )
                 weight_data["weight"] = decompressed
             yield module_path, weight_data
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: QuantizationScheme,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Only used by in-memory decompression pathways to decompress the parameters of
+        one module
+
+        :param prefix: prefix of state_dict, typically the path to the module
+        :param state_dict: state dict containing module parameter values
+        :param scheme: quantization scheme of module to decompress
+        :return: state dict with weight decompressed if applicable
+        """
+        state_dict = {
+            key.removeprefix(f"{prefix}."): value for key, value in state_dict.items()
+        }
+
+        if "weight_scale" in state_dict:
+            state_dict["weight"] = self.decompress_weight(
+                compressed_data=state_dict, quantization_args=scheme.weights
+            )
+
+        state_dict = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+
+        return state_dict
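The new `decompress_module_from_state_dict` is the per-module counterpart of `decompress_from_state_dict`, and is what `ModelCompressor.decompress_model` (above) now calls once per quantized module instead of building a full-model generator. A minimal usage sketch, assuming `compressor` is a quantization compressor instance and `module_to_scheme` maps module paths to schemes as in the surrounding diff; the module path itself is illustrative:

    # hedged sketch of the new per-module decompression path
    prefix = "model.layers.0.self_attn.q_proj"           # illustrative module path
    compressed = module.state_dict(prefix=f"{prefix}.")  # weight_packed, weight_scale, ...

    decompressed = compressor.decompress_module_from_state_dict(
        prefix,
        compressed,
        scheme=module_to_scheme[prefix],
    )
    # keys keep their "{prefix}." prefix; "weight" now holds the dequantized tensor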
src/compressed_tensors/quantization/lifecycle/apply.py

@@ -27,14 +27,8 @@ from compressed_tensors.quantization.lifecycle.compressed import (
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
-    update_fused_layer_weight_global_scales,
-)
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    QuantizationArgs,
-    QuantizationType,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -272,9 +266,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         )
     )

-    if status == QuantizationStatus.INITIALIZED:
-        update_fused_layer_weight_global_scales(model)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

src/compressed_tensors/quantization/lifecycle/forward.py

@@ -21,7 +21,6 @@ from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -405,7 +404,7 @@ def _quantize(

     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

     scaled = x / scale
@@ -438,7 +437,7 @@ def _dequantize(

     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

     dequant_value = x_q.to(scale.dtype)
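Replacing `if global_scale:` with `if global_scale is not None:` matters because `global_scale` is a tensor, not a flag: truth-testing a one-element tensor uses its value, and truth-testing a multi-element tensor raises. A small illustration of the difference (not from the diff):

    import torch

    global_scale = torch.zeros(1)     # registered but not yet calibrated
    bool(global_scale)                # False -> the old check would silently skip scaling
    global_scale is not None          # True  -> the branch runs as intended

    # bool(torch.zeros(2))            # RuntimeError: Boolean value of Tensor ... is ambiguous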
src/compressed_tensors/quantization/lifecycle/initialize.py

@@ -23,26 +23,18 @@ from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter

@@ -51,7 +43,6 @@ __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]


@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)

-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-        # TODO: should move to llmcompressor
-        if base_name == "weight":
-            # When applying weight-only FP4 quantization, generate a global_scale
-            # This scale is applied during runtime to ensure that the generated
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
         register_offload_parameter(
             module, f"{base_name}_global_scale", init_global_scale
         )
@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
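With the eager weight-path generation removed, `weight_global_scale` for TENSOR_GROUP (NVFP4-style) schemes now starts as an empty one-element parameter, and filling it is left to a later calibration step outside this package. A hedged sketch of what such a step might look like using the renamed helper from the utils diff below; the call site is an assumption, not part of this diff:

    # hypothetical calibration-side update (e.g., in llm-compressor), not part of this diff
    import torch
    from compressed_tensors.quantization.utils import generate_gparam
    from compressed_tensors.utils import update_parameter_data

    min_val, max_val = torch.aminmax(module.weight)  # `module` is an initialized Linear
    global_scale = generate_gparam(min_val.data, max_val.data)
    update_parameter_data(module, global_scale, "weight_global_scale")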
src/compressed_tensors/quantization/utils/helpers.py

@@ -47,7 +47,7 @@ __all__ = [
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
-    "generate_global_scale",
+    "generate_gparam",
     "is_fp4",
 ]

@@ -110,6 +110,7 @@ def calculate_qparams(
         else:
             scales = max_val_pos / (float(bit_range) / 2)

+        # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
         if scales.dtype == FP8_E4M3_DATA.dtype:
             # torch.clamp not supported for FP8
             # use the next largest fp8 value from 0
@@ -475,8 +476,9 @@ def parse_out_kv_cache_args(
     return kv_cache_args, quant_scheme_to_layers


-def generate_global_scale(
-    input_tensor: torch.Tensor,
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
     scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
     quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
     dtype: Optional[torch.dtype] = torch.float32,
@@ -490,6 +492,8 @@
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
     """
-    tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
-    global_scale = scale_data.max * quant_data.max / tensor_amax
-    return global_scale.to(dtype)
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
+    return global_scale.to(dtype).reshape([1])
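Note that `generate_gparam` clamps the observed extrema through zero before taking the absolute maximum, so an all-positive (or all-negative) range is still anchored at zero, and the result is returned with shape [1] to match the registered parameter. A small worked example with made-up numbers (448 and 6 are the FP8 E4M3 and FP4 E2M1 maxima used by the defaults):

    import torch

    min_val, max_val = torch.tensor(2.0), torch.tensor(5.0)   # all-positive observed range
    min_c = torch.min(min_val, torch.zeros_like(min_val))     # -> 0.0
    max_c = torch.max(max_val, torch.zeros_like(max_val))     # -> 5.0
    max_val_pos = torch.max(min_c.abs(), max_c.abs())         # -> 5.0, not 2.0
    global_scale = 448.0 * 6.0 / max_val_pos                  # -> 537.6; the helper reshapes this to [1]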
src/compressed_tensors/utils/offload.py

@@ -28,15 +28,18 @@ Utilities associated with offloading functionality provided by `accelerate`.
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

 import torch


 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        attach_align_device_hook,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (
@@ -54,6 +57,9 @@ except ImportError:
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
+    attach_align_device_hook = None


 __all__ = [
@@ -70,6 +76,9 @@ __all__ = [
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "delete_offload_module",
+    "force_cpu_offload",
 ]


@@ -77,6 +86,11 @@ def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback
@@ -346,6 +360,7 @@ def delete_from_weights_map(
         )


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """
@@ -362,6 +377,7 @@ def disable_offload(module: torch.nn.Module):
         yield


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
@@ -383,6 +399,123 @@ def align_modules(
         yield


+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+    :param base: parent module to delete submodule from
+    :param name: name of submodule on parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(
+    module: torch.nn.Module, execution_device: torch.device
+) -> torch.nn.Module:
+    """
+    Force cpu offloading a module, primarily used for testing
+
+    :param module: module containing parameters to offload
+    :param execution_device: execution device submodules
+    :return: module with hooks to perform cpu offloading
+    """
+    # edge case: there is a bug in `dispatch_model` which causes
+    # the function to only work if the model contains submodules
+    if next(module.children(), None) is None:
+        attach_align_device_hook(
+            module,
+            execution_device=execution_device,
+            offload=True,
+            weights_map=module.state_dict(),
+            tied_params_map={},
+        )
+        return module
+
+    device_map = {}
+
+    def collect_device_map(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                collect_device_map(name, submodule)
+                name.pop()
+
+    collect_device_map([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """

src/compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.9.5.a20250602'
+__version__ = version = '0.9.5.a20250604'
 __version_tuple__ = version_tuple = (0, 9, 5)
src/compressed_tensors.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250602
+Version: 0.9.5a20250604
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
tests/test_quantization/test_utils/test_helpers.py

@@ -20,10 +20,7 @@ from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import (
-    calculate_qparams,
-    generate_global_scale,
-)
+from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam


 @pytest.mark.parametrize(
@@ -70,8 +67,9 @@ def test_fused_global_scales():
     layer = torch.nn.Linear(7, 8)
     max_tensor_value = torch.abs(layer.weight.data).max()
     # use defaults
-    global_scale = generate_global_scale(layer.weight)
+    min_val, max_val = torch.aminmax(layer.weight)
+    global_scale = generate_gparam(min_val.data, max_val.data)
     # max value should be = (448 * 6) / global_scale
-    assert max_tensor_value == pytest.approx(
+    assert max_tensor_value.item() == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
     )
tests/test_utils/test_offload.py

@@ -16,10 +16,13 @@ import torch
 from compressed_tensors.utils import (
     align_module_device,
     align_modules,
+    delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
+    force_cpu_offload,
     get_execution_device,
     has_offloaded_params,
+    register_offload_module,
     register_offload_parameter,
     update_offload_parameter,
 )
@@ -37,9 +40,17 @@ class ExampleModule(torch.nn.Module):
         return x * self.a + self.b


+class ExampleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(1, 2)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 @requires_accelerate()
 def test_has_offloaded_params():
-    from accelerate.big_modeling import cpu_offload_with_hook
     from accelerate.hooks import attach_align_device_hook, remove_hook_from_module

     module = ExampleModule()
@@ -48,10 +59,6 @@ def test_has_offloaded_params():
     attach_align_device_hook(module, offload=False)
     assert not has_offloaded_params(module)

-    remove_hook_from_module(module)
-    module, _ = cpu_offload_with_hook(module)
-    assert not has_offloaded_params(module)
-
     remove_hook_from_module(module)
     attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
     assert has_offloaded_params(module)
@@ -334,3 +341,86 @@ def test_offload_to_weights_map():
     weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix)
     offload_to_weights_map(weights_map, name, new_value)
     assert weights_map[name] == new_value
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_register_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    assert child in model.children()
+    assert child in model.linear.children()
+
+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    assert child in model.children()
+    assert child in model.linear.children()
+
+    # can run modules
+    model(torch.empty(1))
+    child(torch.empty(2, device=exec_device))
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_delete_offload_module(exec_device):
+    # no offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert not child in model.children()
+    assert not child in model.linear.children()
+
+    # with offloading
+    model = ExampleModel()
+    child = torch.nn.Linear(2, 3)
+    force_cpu_offload(model, exec_device)
+    register_offload_module(model, "child", child)
+    register_offload_module(model.linear, "child", child)
+    delete_offload_module(model, "child")
+    delete_offload_module(model.linear, "child")
+    assert not child in model.children()
+    assert not child in model.linear.children()
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
+def test_force_cpu_offload(exec_device):
+    # single module
+    module = torch.nn.Linear(1, 2)
+    module = force_cpu_offload(module, exec_device)
+    assert has_offloaded_params(module)
+    assert module._hf_hook.offload
+    assert module.weight.device == torch.device("meta")
+    assert "weight" in module._hf_hook.weights_map
+    assert module._hf_hook.tied_params_map is not None
+
+    # can run
+    module(torch.empty(1, device=exec_device))
+
+    # model
+    model = ExampleModel()
+    model = force_cpu_offload(model, exec_device)
+    assert not has_offloaded_params(model)
+
+    assert has_offloaded_params(model.linear)
+    assert model.linear._hf_hook.offload
+    assert model.linear.weight.device == torch.device("meta")
+    assert "weight" in model.linear._hf_hook.weights_map
+    assert model.linear._hf_hook.tied_params_map is not None
+
+    # can run
+    model(torch.empty(1, device=exec_device))
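These tests exercise the new offloading helpers added in src/compressed_tensors/utils/offload.py above. A condensed usage sketch (requires `accelerate`; the model and device choices are illustrative):

    import torch
    from compressed_tensors.utils import (
        delete_offload_module,
        force_cpu_offload,
        register_offload_module,
    )

    class TinyModel(torch.nn.Module):            # mirrors the tests' ExampleModel
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 2)

        def forward(self, x):
            return self.linear(x)

    exec_device = torch.device("cpu")            # or torch.device("cuda")
    model = force_cpu_offload(TinyModel(), exec_device)  # parameters kept offloaded on cpu
    model(torch.empty(1, device=exec_device))            # loaded onto exec_device per forward

    # submodules can be attached and detached while keeping the offload hooks consistent
    register_offload_module(model.linear, "child", torch.nn.Linear(2, 3))
    delete_offload_module(model.linear, "child")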