compressed-tensors 0.13.1a20260115__tar.gz → 0.13.1a20260123__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. compressed_tensors-0.13.1a20260123/.github/mergify.yml +64 -0
  2. compressed_tensors-0.13.1a20260123/.github/workflows/stale.yml +44 -0
  3. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/Makefile +2 -2
  4. {compressed_tensors-0.13.1a20260115/src/compressed_tensors.egg-info → compressed_tensors-0.13.1a20260123}/PKG-INFO +2 -2
  5. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/setup.py +1 -1
  6. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/__init__.py +10 -1
  7. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +7 -0
  8. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/format.py +7 -0
  9. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/__init__.py +197 -0
  10. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/cache/__init__.py +17 -0
  11. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/cache/base.py +231 -0
  12. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/cache/cpu.py +43 -0
  13. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/cache/device.py +48 -0
  14. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/dispatch.py +228 -0
  15. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/module.py +103 -0
  16. compressed_tensors-0.13.1a20260123/src/compressed_tensors/offload/utils.py +158 -0
  17. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/forward.py +8 -10
  18. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/quant_metadata.py +24 -1
  19. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/utils/mxfp4_utils.py +1 -1
  20. compressed_tensors-0.13.1a20260123/src/compressed_tensors/utils/binary_search.py +52 -0
  21. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/offload.py +10 -1
  22. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/version.py +1 -1
  23. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123/src/compressed_tensors.egg-info}/PKG-INFO +2 -2
  24. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors.egg-info/SOURCES.txt +15 -0
  25. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors.egg-info/requires.txt +1 -1
  26. compressed_tensors-0.13.1a20260123/tests/test_offload/cache/test_cpu.py +138 -0
  27. compressed_tensors-0.13.1a20260123/tests/test_offload/test_dispatch.py +215 -0
  28. compressed_tensors-0.13.1a20260123/tests/test_offload/test_interface.py +174 -0
  29. compressed_tensors-0.13.1a20260123/tests/test_offload/test_module.py +213 -0
  30. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_utils/test_mxfp4_utils.py +1 -1
  31. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.github/.gitkeep +0 -0
  32. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.github/actions/test/action.yml +0 -0
  33. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.github/scripts/step-status +0 -0
  34. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.github/workflows/quality-check.yaml +0 -0
  35. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.github/workflows/test-check.yaml +0 -0
  36. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/.gitignore +0 -0
  37. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/LICENSE +0 -0
  38. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/README.md +0 -0
  39. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  40. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/bit_packing/int4_config.json +0 -0
  41. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/bitmask_compression.ipynb +0 -0
  42. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  43. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  44. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/llama_1.1b/example_quant_config.json +0 -0
  45. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  46. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/examples/quantize_and_pack_int4.ipynb +0 -0
  47. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/pyproject.toml +0 -0
  48. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/setup.cfg +0 -0
  49. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/__init__.py +0 -0
  50. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/README.md +0 -0
  51. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/base.py +0 -0
  52. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/__init__.py +0 -0
  53. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/base.py +0 -0
  54. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/helpers.py +0 -0
  55. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  56. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  57. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  59. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +0 -0
  60. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  61. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  62. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  63. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  64. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  65. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  66. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  67. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  68. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/__init__.py +0 -0
  69. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/base.py +0 -0
  70. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/dense.py +0 -0
  71. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  72. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  73. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/linear/__init__.py +0 -0
  74. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  75. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/logger.py +0 -0
  76. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/modeling/__init__.py +0 -0
  77. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/modeling/attention.py +0 -0
  78. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/modeling/kvcache.py +0 -0
  79. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/__init__.py +0 -0
  80. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  81. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  82. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  83. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  84. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  85. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/quant_args.py +0 -0
  86. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/quant_config.py +0 -0
  87. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  88. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  89. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  90. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/registry/__init__.py +0 -0
  91. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/registry/registry.py +0 -0
  92. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/__init__.py +0 -0
  93. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/apply.py +0 -0
  94. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  95. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/factory/base.py +0 -0
  96. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  97. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  98. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  99. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/transform_args.py +0 -0
  100. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/transform_config.py +0 -0
  101. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  102. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  103. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  104. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  105. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  106. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/__init__.py +0 -0
  107. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/helpers.py +0 -0
  108. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/internal.py +0 -0
  109. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/match.py +0 -0
  110. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/permutations_24.py +0 -0
  111. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  112. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  113. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors/utils/type.py +0 -0
  114. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  115. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  116. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/__init__.py +0 -0
  117. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/conftest.py +0 -0
  118. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/mock_observer.py +0 -0
  119. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/__init__.py +0 -0
  120. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/model_compressors/__init__.py +0 -0
  121. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  122. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  123. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/test_fp4_quant.py +0 -0
  124. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  125. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  126. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  127. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +0 -0
  128. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  129. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  130. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  131. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  132. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  133. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_configs/__init__.py +0 -0
  134. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_configs/test_base.py +0 -0
  135. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_configs/test_infer_quant.py +0 -0
  136. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  137. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_linear/__init__.py +0 -0
  138. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_linear/test_compressed_linear.py +0 -0
  139. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_modeling/test_attention_and_cache.py +0 -0
  140. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/__init__.py +0 -0
  141. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/__init__.py +0 -0
  142. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/conftest.py +0 -0
  143. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  144. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  145. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  146. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  147. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  148. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  149. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/lifecycle/test_static_lifecycle.py +0 -0
  150. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_configs/__init__.py +0 -0
  151. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  152. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  153. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_quant_args.py +0 -0
  154. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_quant_config.py +0 -0
  155. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_quant_scheme.py +0 -0
  156. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  157. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_registry.py +0 -0
  158. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/conftest.py +0 -0
  159. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/factory/test_correctness.py +0 -0
  160. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/factory/test_memory.py +0 -0
  161. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/factory/test_serialization.py +0 -0
  162. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/test_transform_args.py +0 -0
  163. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/test_transform_config.py +0 -0
  164. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/test_transform_scheme.py +0 -0
  165. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_transform/utils/test_hadamard.py +0 -0
  166. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/__init__.py +0 -0
  167. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/test_helpers.py +0 -0
  168. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/test_match.py +0 -0
  169. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/test_offload.py +0 -0
  170. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/test_safetensors_load.py +0 -0
  171. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/test_utils/test_type.py +0 -0
  172. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/tests/testing_utils.py +0 -0
  173. {compressed_tensors-0.13.1a20260115 → compressed_tensors-0.13.1a20260123}/utils/copyright.py +0 -0
@@ -0,0 +1,64 @@
+ pull_request_rules:
+ - name: label-documentation
+   description: Automatically apply documentation label
+   conditions:
+     - label != stale
+     - -closed
+     - or:
+       - files~=^[^/]+\.md$
+       - files~=^docs/
+       - files~=^examples/
+   actions:
+     label:
+       add:
+         - documentation
+
+ - name: ping author on conflicts and add 'needs-rebase' label
+   conditions:
+     - label != stale
+     - conflict
+     - -closed
+   actions:
+     label:
+       add:
+         - needs-rebase
+     comment:
+       message: |
+         This pull request has merge conflicts that must be resolved before it can be
+         merged. Please rebase the PR, @{{author}}.
+
+         https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
+
+ - name: remove 'needs-rebase' label when conflict is resolved
+   conditions:
+     - -conflict
+     - -closed
+   actions:
+     label:
+       remove:
+         - needs-rebase
+
+ - name: add quality-failed label
+   conditions:
+     - label != stale
+     - check-failure = quality-check
+     - -closed
+   actions:
+     label:
+       add:
+         - quality-failed
+     comment:
+       message: |
+         The quality checks have failed. Please run `make style` and `make quality` under
+         the root directory to adddress the lint failures. You will need to install the
+         dev optional install to get the required linting packages.
+
+ - name: remove quality-failed label
+   conditions:
+     - label != stale
+     - -check-failure = quality-check
+     - -closed
+   actions:
+     label:
+       remove:
+         - quality-failed
@@ -0,0 +1,44 @@
+ name: 'Close inactive PRs'
+
+ on:
+   schedule:
+     - cron: '0 17 * * *'
+
+ jobs:
+   close-pull-requests:
+     if: github.repository == 'vllm-project/compressed-tensors'
+     permissions:
+       issues: write
+       pull-requests: write
+       actions: write
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d
+         with:
+           operations-per-run: 1000
+           exempt-draft-pr: true
+           exempt-issue-labels: 'keep-open'
+           exempt-pr-labels: 'keep-open'
+
+           days-before-issue-stale: 90
+           days-before-issue-close: 30
+           stale-issue-label: 'stale'
+           stale-issue-message: >
+             This issue has been automatically marked as stale because it has not
+             had any activity within 90 days. It will be automatically closed if no
+             further activity occurs within 30 days. Leave a comment if
+             you feel this issue should remain open. Thank you!
+           close-issue-message: >
+             This issue has been automatically closed due to inactivity. Please
+             feel free to reopen if you feel it is still relevant. Thank you!
+
+           days-before-pr-stale: 90
+           days-before-pr-close: 30
+           stale-pr-label: 'stale'
+           stale-pr-message: >
+             This pull request has been automatically marked as stale because it
+             has not had any activity within 90 days. It will be automatically
+             closed if no further activity occurs within 30 days.
+           close-pr-message: >
+             This pull request has been automatically closed due to inactivity.
+             Please feel free to reopen if you intend to continue working on it.
@@ -8,7 +8,7 @@ quality:
  @echo "Running copyright checks";
  python utils/copyright.py quality $(PYCHECKGLOBS)
  @echo "Running python quality checks";
- black --check $(PYCHECKDIRS);
+ black --target-version py310 --check $(PYCHECKDIRS);
  isort --check-only $(PYCHECKDIRS);
  flake8 $(PYCHECKDIRS);

@@ -17,7 +17,7 @@ style:
  @echo "Running copyright style";
  python utils/copyright.py style $(PYCHECKGLOBS)
  @echo "Running python styling";
- black $(PYCHECKDIRS);
+ black --target-version py310 $(PYCHECKDIRS);
  isort $(PYCHECKDIRS);

  # run tests for the repo
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.13.1a20260115
+ Version: 0.13.1a20260123
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/vllm-project/compressed-tensors
  Author: Neuralmagic, Inc.
@@ -8,7 +8,7 @@ Author-email: support@neuralmagic.com
  License: Apache 2.0
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch>=1.7.0
+ Requires-Dist: torch<=2.9.1,>=1.7.0
  Requires-Dist: transformers
  Requires-Dist: pydantic>=2.0
  Requires-Dist: loguru
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
      )

  def _setup_install_requires() -> List:
-     return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "loguru"]
+     return ["torch>=1.7.0,<=2.9.1", "transformers", "pydantic>=2.0", "loguru"]

  def _setup_extras() -> Dict:
      return {
@@ -20,5 +20,14 @@ from .base import *
  from .compressors import *
  from .config import *
  from .quantization import QuantizationConfig, QuantizationStatus
- from .utils import *
+
+ # avoid resolving compressed_tensors.offload as compressed_tensors.utils.offload
+ from .utils.offload import *
+ from .utils.helpers import *
+ from .utils.internal import *
+ from .utils.match import *
+ from .utils.permutations_24 import *
+ from .utils.safetensors_load import *
+ from .utils.semi_structured_conversions import *
+ from .utils.type import *
  from .version import *
@@ -13,6 +13,7 @@
  # limitations under the License.

  import logging
+ import warnings
  from typing import Dict, Generator, Tuple

  import numpy as np
@@ -138,6 +139,12 @@ class Marlin24Compressor(BaseCompressor):
          :param show_progress: whether to show tqdm progress
          :return: compressed state dict
          """
+         warnings.warn(
+             "The marlin24 format is deprecated and will be removed in a "
+             "future release. vLLM no longer supports marlin24 models.",
+             DeprecationWarning,
+             stacklevel=2,
+         )
          self.validate_quant_compatability(names_to_scheme)

          compressed_dict = {}
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import warnings
  from typing import List, Optional

  import torch
@@ -68,6 +69,12 @@ def _get_quant_compression_format(
      ):
          # marlin24 kernel only applicable for channel/group quantization
          # Note: vLLM may only support group quant for marlin24
+         warnings.warn(
+             "The marlin24 format is deprecated and will be removed in a "
+             "future release. vLLM no longer supports marlin24 models.",
+             DeprecationWarning,
+             stacklevel=2,
+         )
          return CompressionFormat.marlin_24
      return CompressionFormat.pack_quantized

@@ -0,0 +1,197 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ from typing import Iterable, Optional
+
+ import torch
+ from compressed_tensors.offload.cache import OffloadCache
+ from compressed_tensors.offload.dispatch import (  # noqa: F401
+     dispatch_model,
+     offload_model,
+     remove_dispatch,
+ )
+ from compressed_tensors.offload.module import offload_module, unwrap_offload_forward
+ from compressed_tensors.offload.utils import get_module_device, move_module_tensor
+ from compressed_tensors.utils.helpers import patch_attr
+
+
+ __all__ = [
+     # dispatch models
+     "offload_model",
+     "dispatch_model",
+     "remove_dispatch",
+     # control movement
+     "disable_onloading",
+     "disable_offloading",
+     # manipulate parameters
+     "update_offload_parameter",
+     "get_execution_device",
+     "get_offloaded_device",
+     "register_offload_module",
+     # manipulate forward
+     "unwrap_offload_forward",
+     # backwards compatibility: should be deprecated
+     "align_modules",
+     "align_module_device",
+ ]
+
+
+ @contextlib.contextmanager
+ def disable_offloading():
+     """
+     When offloading is disabled, onloaded tensors remain onloaded in memory until exit
+
+     ```
+     with OffloadCache.disable_offloading():
+         ... = cache["weight"]
+         ... = cache["weight"] # cache hit
+         ... = cache["weight"] # cache hit
+
+         # upon exit, all onloaded weights are released
+     ```
+     """
+     with OffloadCache.disable_offloading():
+         yield
+
+
+ @contextlib.contextmanager
+ def disable_onloading():
+     """
+     When onloading is disabled, tensors are not offloaded on access, and assignments do
+     not trigger offloading. This is mostly used to disable device movement for debugging
+
+     ```
+     with OffloadCache.disable_onloading():
+         tensor = ...
+         cache["weight"] = tensor # assignments do not trigger onloading
+         cache["weight"] is tensor # tensor remains offloaded
+     ```
+     """
+     with OffloadCache.disable_onloading():
+         yield
+
+
+ def update_offload_parameter(module: torch.nn.Module, name: str, data: torch.Tensor):
+     """
+     Update the data of an existing parameter and its offload dict. Supports both
+     parameters of offloaded modules and non-offloaded modules
+
+     :param module: module containing the parameter to update
+     :param name: name of module parameter to update
+     :param data: tensor to update parameter with
+     """
+     if isinstance(module._parameters, OffloadCache):
+         with module._parameters.disable_onloading():
+             value = getattr(module, name)
+             value.copy_(module._parameters.offload(data))
+             setattr(module, name, value)
+
+     else:
+         getattr(module, name).copy_(data)
+
+
+ def get_execution_device(module: torch.nn.Module) -> torch.device | str:
+     """
+     Get the device which inputs should be moved to before module execution.
+
+     :param module: module to check, may be offloaded
+     :return: onload device of module
+     """
+     if isinstance(module._parameters, OffloadCache):
+         return module._parameters.onload_device
+
+     else:
+         return get_module_device(module)
+
+
+ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
+     """
+     :param module: module to check
+     :return: device module is offloaded to onto after forward pass
+     """
+     with disable_onloading():
+         return get_module_device(module)
+
+
+ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+     """
+     Register a submodule with offloading if the parent module is offloaded
+
+     :param base: module to attach submodule to
+     :param name: name of submodule
+     :param module: submodule to attach
+     """
+     cache = base._parameters
+     if isinstance(cache, OffloadCache):
+         offload_module(
+             module, cache.onload_device, cache.offload_device, no_split=False
+         )
+
+     base.register_module(name, module)
+
+
+ """ Implemented for backwards compatibility """
+
+
+ @contextlib.contextmanager
+ def align_modules(
+     modules: torch.nn.Module | Iterable[torch.nn.Module],
+     execution_device: Optional[torch.device] = None,
+ ):
+     """
+     Context manager for onloading modules to a device, and disabling onload and offload
+     attempts triggered by forward calls. Used for sequential onloading of layers
+
+     :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
+     :param execution_device: device to onload to
+     """
+     with contextlib.ExitStack() as stack:
+         for module in modules:
+             stack.enter_context(align_module_device(module, execution_device))
+         yield
+
+
+ @contextlib.contextmanager
+ def align_module_device(
+     module: torch.nn.Module, execution_device: Optional[torch.device] = None
+ ):
+     """
+     Context manager that moves a module's parameters to the specified execution device.
+
+     :param module: Module with parameters to align
+     :param execution_device: If provided, overrides the module's execution device
+         within the context. Otherwise, use hook execution device or pass
+     """
+
+     if isinstance(module._parameters, OffloadCache):
+         assert isinstance(module._buffers, OffloadCache)
+         with module._parameters.disable_offloading():
+             with patch_attr(
+                 module._parameters, "onload_device", execution_device
+             ), patch_attr(module._buffers, "onload_device", execution_device):
+                 yield
+
+     else:
+         original_device = {}
+         for name, param in module.named_parameters(recurse=False):
+             original_device[name] = param.device
+             move_module_tensor(module, name, execution_device)
+
+         try:
+             yield
+         finally:
+             for name, param in module.named_parameters(recurse=False):
+                 device = original_device[name]
+                 move_module_tensor(module, name, device)
@@ -0,0 +1,17 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # flake8: noqa
+
+ from .base import OffloadCache
+ from .cpu import CPUCache
@@ -0,0 +1,231 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ from abc import ABC, abstractmethod
+ from collections.abc import MutableMapping
+ from typing import ClassVar, Literal, Optional
+
+ import torch
+ import torch.distributed as dist
+
+
+ class OffloadCache(MutableMapping, ABC):
+     """
+     Base class for offload caches. Subclasses must implement `offload` and `onload`.
+     Instances have similar behavior to dicts, except that tensors are offloaded when
+     assigned and onloaded when accessed.
+
+     Typical usage:
+     ```
+     module._parameters = cache_cls.from_mapping(module._parameters, onload_device)
+     tensor = ...
+     module._parameters["name"] = tensor # tensor is offloaded
+     onloaded_tensor = module._parameters["name"] # tensor is onloaded
+     ```
+
+     This class implements two contexts for more fine-grained control of device movement:
+     `OffloadCache.disable_offloading` and `OffloadCache.disable_onloading`. For more
+     info, see `compressed_tensors.offload::(disable_offloading|disable_onloading)`
+     """
+
+     onload_device: torch.device | str
+     offload_device: Optional[torch.device | str]
+
+     # global flags for disabling
+     offloading_disabled: ClassVar[bool] = False
+     onloading_disabled: ClassVar[bool] = False
+
+     # names -> offloaded tensors (populated from _parameters or _buffers)
+     offloaded_values: dict[str, torch.Tensor]
+
+     # offloaded tensors -> onloaded tensors (only when offloading is disabled)
+     keep_onloaded_values: ClassVar[dict[torch.Tensor, torch.Tensor]] = dict()
+
+     @classmethod
+     def cls_from_device(
+         cls,
+         device: Optional[torch.device | str | Literal["disk"]] = None,
+     ) -> type["OffloadCache"]:
+         """
+         Get the subclass which implements offloading for the given `offload_device`.
+         Use `torch.distributed` to detect if the environment is distributed
+
+         :param device: offload device used to find subclass
+         :return: subclass of `OffloadCache`
+         """
+         from compressed_tensors.offload.cache.cpu import CPUCache
+         from compressed_tensors.offload.cache.device import DeviceCache
+
+         device_type = torch.device(device).type if device != "disk" else "disk"
+         distributed = dist.is_available() and dist.is_initialized()
+
+         match (device_type, distributed):
+             case ("cpu", False):
+                 return CPUCache
+             case ("cuda", False):
+                 return DeviceCache
+             case _:
+                 raise NotImplementedError(
+                     f"Offload of type {device} and "
+                     f"distributed={distributed} has not been implemented"
+                 )
+
+     @classmethod
+     def from_mapping(
+         cls,
+         mapping: MutableMapping[str, torch.Tensor | None],
+         onload_device: torch.device | str,
+     ):
+         """
+         Initialize an instance from a given mapping, typically `Module._parameters` or
+         `Module._buffers`. Mapping values will be offloaded
+
+         :param mapping: mapping used to populate cache
+         :param onload_device: device which tensors will be onloaded to
+         """
+         instance = cls(onload_device=onload_device)
+         instance.offloaded_values = {
+             name: instance.offload(tensor) for name, tensor in mapping.items()
+         }
+
+         return instance
+
+     def __init__(self, onload_device: torch.device | str):
+         super().__init__()
+         self.onload_device = onload_device
+         self.offloaded_values = dict()
+
+     @abstractmethod
+     def onload(self, offloaded: torch.Tensor | None) -> torch.Tensor:
+         """
+         Given an offloaded tensor, returns that tensor after onloading
+
+         :param offloaded: offloaded tensor
+         :return: onloaded tensor
+         """
+         raise NotImplementedError()
+
+     @abstractmethod
+     def offload(self, tensor: torch.Tensor | None) -> torch.Tensor:
+         """
+         Given a tensor, returns that tensor after offloading
+
+         :param tensor: tensor to offload
+         :return: offloaded tensor
+         """
+         raise NotImplementedError()
+
+     def __getitem__(self, key: str) -> torch.Tensor:
+         """
+         Onload a tensor
+
+         If called within the `disable_offloading` context, a strong reference of the
+         onloaded tensor is kept so that future accesses will not require device movement
+
+         :param key: name of tensor to access
+         :return: onloaded tensor
+         """
+         offloaded = self.offloaded_values[key]
+
+         # when onloading is disabled, offloaded tensors can be accessed directly
+         if offloaded is None or self.onloading_disabled:
+             return offloaded
+
+         # check for cache hit
+         if offloaded in self.keep_onloaded_values:
+             return self.keep_onloaded_values[offloaded]
+
+         # onload value
+         onloaded = self.onload(offloaded)
+
+         # when offloading is disabled, populate cache
+         if self.offloading_disabled:
+             self.keep_onloaded_values[offloaded] = onloaded
+
+         return onloaded
+
+     def __setitem__(self, key: str, value: torch.Tensor | None):
+         """
+         Offload a tensor and add it to the cache.
+
+         If called within the `disable_onloading` context, the tensor is not offloaded
+         and is instead assigned directly
+
+         :param key: name of tensor
+         :param value: tensor value to offload
+         """
+         if key in self:
+             del self[key]
+
+         # when onloading is disabled, parameters can be access and assigned directly
+         if self.onloading_disabled:
+             self.offloaded_values[key] = value
+             return
+
+         self.offloaded_values[key] = self.offload(value)
+
+     def __delitem__(self, key: str):
+         """
+         Remove the offloaded tensor associated with `key`. Any references to its
+         onloaded tensors held by this class are invalidated.
+
+         :param key: name of tensor to invalidate
+         """
+         offloaded = self.offloaded_values[key]
+         del self.offloaded_values[key]
+
+         # remove strong ref
+         if offloaded in self.keep_onloaded_values:
+             del self.keep_onloaded_values[offloaded]
+
+     def __contains__(self, key) -> bool:
+         return key in self.offloaded_values
+
+     def __iter__(self):
+         return iter(self.offloaded_values)
+
+     def __len__(self):
+         return len(self.offloaded_values)
+
+     @classmethod
+     @contextlib.contextmanager
+     def disable_offloading(cls):
+         """
+         Context to disable all offloading for offloaded modules which share this cache.
+         After a weight has been fetched once, that onloaded value is cached and
+         subsequent fetches will leverage the cache, reducing device movement
+         """
+         if not OffloadCache.offloading_disabled:
+             OffloadCache.offloading_disabled = True
+             yield
+             OffloadCache.offloading_disabled = False
+             OffloadCache.keep_onloaded_values.clear()
+         else:
+             yield
+
+     @classmethod
+     @contextlib.contextmanager
+     def disable_onloading(cls):
+         """
+         Context to disable all onloading for offloaded modules which share this cache.
+         This is mostly used for debugging purposes, and allows the caller to directly
+         inspect offloaded tensors and directly assign offloaded tensors without copying
+         """
+         if not OffloadCache.onloading_disabled:
+             OffloadCache.onloading_disabled = True
+             yield
+             OffloadCache.onloading_disabled = False
+         else:
+             yield
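For readers skimming the new `offload` subpackage, the following is a minimal, illustrative sketch of the dict-like contract that `OffloadCache` defines above: tensors are offloaded when assigned, onloaded when accessed, and `disable_offloading` keeps onloaded values cached across repeated accesses. The `ToyCache` subclass, device choices, and variable names are assumptions for illustration only; the package ships its own `CPUCache` and `DeviceCache` implementations.

```python
import torch
from compressed_tensors.offload.cache import OffloadCache


class ToyCache(OffloadCache):
    """Illustrative (hypothetical) subclass: offload to CPU, onload to the configured device."""

    offload_device = torch.device("cpu")

    def offload(self, tensor):
        # store assigned tensors on the offload device (None passes through)
        return tensor.to(self.offload_device) if tensor is not None else None

    def onload(self, offloaded):
        # return accessed tensors on the onload device
        return offloaded.to(self.onload_device)


linear = torch.nn.Linear(4, 4)
linear._parameters = ToyCache.from_mapping(linear._parameters, onload_device="cpu")

w1 = linear._parameters["weight"]  # onloaded on access
with OffloadCache.disable_offloading():
    a = linear._parameters["weight"]
    b = linear._parameters["weight"]  # cache hit: the same onloaded tensor is reused
    assert a is b
```

The same pattern appears to underpin `offload_model` and `dispatch_model` in `offload/dispatch.py`, which swap a module's `_parameters` and `_buffers` mappings for a cache subclass chosen via `OffloadCache.cls_from_device`.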