compressed-tensors 0.12.3a20251003__tar.gz → 0.12.3a20251007__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/PKG-INFO +1 -1
  2. compressed_tensors-0.12.3a20251007/src/compressed_tensors/quantization/lifecycle/initialize.py +291 -0
  3. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/utils/helpers.py +25 -0
  4. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/version.py +1 -1
  5. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
  6. compressed_tensors-0.12.3a20251003/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -270
  7. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/.gitkeep +0 -0
  8. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/actions/test/action.yml +0 -0
  9. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/scripts/step-status +0 -0
  10. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/build-test.yml +0 -0
  11. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/build.yml +0 -0
  12. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/post-release-nightly-build.yml +0 -0
  13. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/quality-check.yaml +0 -0
  14. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/report.yml +0 -0
  15. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/test-check.yaml +0 -0
  16. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/test.yml +0 -0
  17. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/trigger-all.yml +0 -0
  18. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.github/workflows/upload.yml +0 -0
  19. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/.gitignore +0 -0
  20. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/LICENSE +0 -0
  21. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/Makefile +0 -0
  22. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/README.md +0 -0
  23. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  24. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/bit_packing/int4_config.json +0 -0
  25. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/bitmask_compression.ipynb +0 -0
  26. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  27. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  28. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/llama_1.1b/example_quant_config.json +0 -0
  29. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  30. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/examples/quantize_and_pack_int4.ipynb +0 -0
  31. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/pyproject.toml +0 -0
  32. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/setup.cfg +0 -0
  33. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/setup.py +0 -0
  34. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/__init__.py +0 -0
  35. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/README.md +0 -0
  36. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/__init__.py +0 -0
  37. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/base.py +0 -0
  38. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/__init__.py +0 -0
  39. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/base.py +0 -0
  40. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/helpers.py +0 -0
  41. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  42. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  43. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  44. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  45. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  46. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  47. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  48. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  49. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  50. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  51. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  52. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  53. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  54. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  55. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/__init__.py +0 -0
  56. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/base.py +0 -0
  57. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/dense.py +0 -0
  58. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/format.py +0 -0
  59. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  60. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  61. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/linear/__init__.py +0 -0
  62. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  63. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/logger.py +0 -0
  64. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/__init__.py +0 -0
  65. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  66. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  67. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  68. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  69. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  70. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/quant_args.py +0 -0
  71. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/quant_config.py +0 -0
  72. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/quant_metadata.py +0 -0
  73. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  74. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  75. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/registry/__init__.py +0 -0
  76. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/registry/registry.py +0 -0
  77. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/__init__.py +0 -0
  78. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/apply.py +0 -0
  79. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  80. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/factory/base.py +0 -0
  81. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
  82. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
  83. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  84. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/transform_args.py +0 -0
  85. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/transform_config.py +0 -0
  86. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/transform_scheme.py +0 -0
  87. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  88. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  89. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  90. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  91. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/__init__.py +0 -0
  92. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/helpers.py +0 -0
  93. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/internal.py +0 -0
  94. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/match.py +0 -0
  95. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/offload.py +0 -0
  96. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/permutations_24.py +0 -0
  97. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  98. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  99. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors/utils/type.py +0 -0
  100. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
  101. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  102. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors.egg-info/requires.txt +0 -0
  103. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  104. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/__init__.py +0 -0
  105. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/conftest.py +0 -0
  106. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/__init__.py +0 -0
  107. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/model_compressors/__init__.py +0 -0
  108. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  109. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  110. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  111. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  112. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  113. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  114. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  115. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  116. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  117. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  118. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  119. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_configs/__init__.py +0 -0
  120. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_configs/test_base.py +0 -0
  121. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_configs/test_infer_quant.py +0 -0
  122. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  123. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_linear/__init__.py +0 -0
  124. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_linear/test_compressed_linear.py +0 -0
  125. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/__init__.py +0 -0
  126. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/__init__.py +0 -0
  127. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/conftest.py +0 -0
  128. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  129. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  130. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  131. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  132. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  133. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  134. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_configs/__init__.py +0 -0
  135. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  136. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  137. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_quant_args.py +0 -0
  138. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_quant_config.py +0 -0
  139. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_quant_scheme.py +0 -0
  140. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  141. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_registry.py +0 -0
  142. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/conftest.py +0 -0
  143. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/factory/test_correctness.py +0 -0
  144. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/factory/test_memory.py +0 -0
  145. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/factory/test_serialization.py +0 -0
  146. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/test_transform_args.py +0 -0
  147. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/test_transform_config.py +0 -0
  148. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/test_transform_scheme.py +0 -0
  149. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_transform/utils/test_hadamard.py +0 -0
  150. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/__init__.py +0 -0
  151. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/test_helpers.py +0 -0
  152. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/test_match.py +0 -0
  153. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/test_offload.py +0 -0
  154. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/test_safetensors_load.py +0 -0
  155. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/test_utils/test_type.py +0 -0
  156. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/tests/testing_utils.py +0 -0
  157. {compressed_tensors-0.12.3a20251003 → compressed_tensors-0.12.3a20251007}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.12.3a20251003
+ Version: 0.12.3a20251007
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
src/compressed_tensors/quantization/lifecycle/initialize.py (added)
@@ -0,0 +1,291 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+ from typing import Optional, Tuple
+
+ import torch
+ from compressed_tensors.quantization import (
+     FP8_E4M3_DATA,
+     ActivationOrdering,
+     DynamicType,
+     KVCacheScaleType,
+     QuantizationArgs,
+     QuantizationMetadata,
+     QuantizationScheme,
+     QuantizationStatus,
+     QuantizationStrategy,
+ )
+ from compressed_tensors.quantization.lifecycle.forward import (
+     wrap_module_forward_quantized,
+ )
+ from compressed_tensors.quantization.utils import (
+     is_fp4,
+     is_kv_cache_quant_scheme,
+     strategy_cdiv,
+ )
+ from compressed_tensors.utils import (
+     disable_hf_hook,
+     get_execution_device,
+     register_offload_parameter,
+ )
+ from torch.nn import Module, Parameter
+
+
+ __all__ = [
+     "initialize_module_for_quantization",
+     "is_attention_module",
+     "initialize_qparams",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def initialize_module_for_quantization(
+     module: Module,
+     scheme: Optional[QuantizationScheme] = None,
+     force_zero_point: bool = True,
+ ):
+     """
+     Attaches appropriate scales, zero points, and observers to a layer
+     given its target quantization scheme.
+
+     Previously initialized scales and zero points will be removed from
+     the module if they no longer apply to the scheme.
+
+     :param module: module to set for calibration
+     :param scheme: scheme to use for quantization. If None is provided,
+         will attempt to use the scheme stored in the module under
+         `quantization_scheme`; if not found, the layer will be skipped
+     :param force_zero_point: whether to force initialization of a zero point for
+         symmetric quantization
+     """
+     scheme = scheme or getattr(module, "quantization_scheme", None)
+     if scheme is None:
+         return
+
+     QuantizationMetadata.clear_all_qparams(module)
+
+     if is_attention_module(module):
+         # quantized actions based on calltime status
+         _initialize_attn_scales(module)
+
+     else:
+         if not isinstance(module, torch.nn.Linear):
+             _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+         # use weight to determine observed shapes and dtype
+         if hasattr(module, "weight"):
+             weight = module.weight
+             assert isinstance(weight, torch.Tensor)
+         else:
+             # Note that a weight is required for both weight and activation
+             # quantization in order to know the dtype of activation scales
+             _LOGGER.warning(
+                 f"module type {type(module)} targeted for quantization but "
+                 f"has no attribute weight, skipping quantization for {type(module)}"
+             )
+             return
+
+         if scheme.input_activations is not None:
+             initialize_qparams(
+                 module,
+                 "input",
+                 scheme.input_activations,
+                 observed_shape=weight.shape[-1:],
+                 observed_dtype=weight.dtype,
+                 force_zero_point=force_zero_point,
+             )
+
+         if scheme.weights is not None:
+             initialize_qparams(
+                 module,
+                 "weight",
+                 scheme.weights,
+                 observed_shape=weight.shape,
+                 observed_dtype=weight.dtype,
+                 force_zero_point=force_zero_point,
+             )
+
+         output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
+         if scheme.output_activations is not None and not output_is_kv_cache:
+             initialize_qparams(
+                 module,
+                 "output",
+                 scheme.output_activations,
+                 observed_shape=weight.shape[:-1],
+                 observed_dtype=weight.dtype,
+                 force_zero_point=force_zero_point,
+             )
+
+     module.quantization_scheme = scheme
+     module.quantization_status = QuantizationStatus.INITIALIZED
+
+     with disable_hf_hook(module):
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized(module, scheme)
+
+
+ def is_attention_module(module: Module):
+     return "attention" in module.__class__.__name__.lower() and (
+         hasattr(module, "k_proj")
+         or hasattr(module, "v_proj")
+         or hasattr(module, "qkv_proj")
+     )
+
+
+ def initialize_qparams(
+     module: Module,
+     base_name: str,
+     quantization_args: QuantizationArgs,
+     observed_shape: Tuple[int],
+     observed_dtype: torch.dtype,
+     force_zero_point: bool = True,
+ ):
+     """
+     Initialize quantization parameters for a given base name according to the passed
+     quantization args. The shape and dtype of the observed weight/activation must also
+     be provided.
+
+     Scales will always be initialized. Global scales are initialized depending on args.
+     Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+     :param module: module to register qparams to
+     :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+     :param quantization_args: arguments for quantization
+     :param observed_shape: last (right-most) known dimensions of the observed weight/act
+     :param observed_dtype: dtype of the observed weight/act
+     :param force_zero_point: force the zero_point parameter to be initialized
+     """
+     strategy = quantization_args.strategy
+     dynamic = quantization_args.dynamic
+     actorder = quantization_args.actorder
+     device = get_execution_device(module)  # avoid performing initialization ops on cpu
+
+     # Skip all initialization for fully dynamic quantization
+     if dynamic is True:
+         return
+
+     # 0. Create global scale for tensor-group quantization
+     if strategy == QuantizationStrategy.TENSOR_GROUP:
+         init_global_scale = Parameter(
+             torch.empty(1, dtype=torch.float32, device=device),
+             requires_grad=False,
+         )
+         register_offload_parameter(
+             module, f"{base_name}_global_scale", init_global_scale
+         )
+
+     # Skip scale/zp initialization for locally dynamic quantization
+     if dynamic == DynamicType.LOCAL:
+         return
+
+     # 1. Infer expected scale/zp shape
+     if strategy == QuantizationStrategy.TENSOR:
+         expected_shape = (1,)
+
+     elif strategy == QuantizationStrategy.TOKEN:
+         expected_shape = (1, 1)
+
+     elif strategy == QuantizationStrategy.CHANNEL:
+         if len(observed_shape) < 2:
+             raise ValueError("Channel quant requires at least 2 observed dimensions")
+
+         expected_shape = (observed_shape[-2], 1)
+
+     elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+         assert quantization_args.group_size is not None
+         if len(observed_shape) < 1:
+             raise ValueError("Group quant requires at least 1 observed dimension")
+
+         group_size = quantization_args.group_size
+         num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+         expected_shape = (*observed_shape[:-1], num_groups)
+
+         # initialize activation ordering if applicable
+         if actorder == ActivationOrdering.GROUP:
+             init_g_idx = Parameter(
+                 torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                 requires_grad=False,
+             )
+             register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+     elif strategy == QuantizationStrategy.BLOCK:
+         assert quantization_args.block_structure is not None
+         if len(observed_shape) < 2:
+             raise ValueError("Block quant requires at least 2 observed dimensions")
+
+         block_structure = quantization_args.block_structure
+         num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
+         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
+         expected_shape = (num_rows, num_cols)
+
+     else:
+         assert False, f"Unknown strategy {strategy}"
+
+     # 2. Identify quantization scale and zp dtype
+     scale_dtype = observed_dtype
+
+     if is_fp4(quantization_args=quantization_args):
+         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+     else:
+         # TODO: consider erroring out in the future; if the dtype is not one of these,
+         # there is likely a bug
+         if scale_dtype not in [
+             torch.float16,
+             torch.bfloat16,
+             torch.float32,
+             torch.float64,
+         ]:
+             scale_dtype = torch.bfloat16
+         zp_dtype = quantization_args.pytorch_dtype()
+
+     # 3. Initialize scale/zp for the module
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+     register_offload_parameter(module, f"{base_name}_scale", init_scale)
+
+     if force_zero_point or not quantization_args.symmetric:
+         init_zero_point = Parameter(
+             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+             requires_grad=False,
+         )
+         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
+
+
+ def _initialize_attn_scales(module: Module) -> None:
+     """Initialize k_scale, v_scale for self_attn"""
+
+     expected_shape = 1  # per tensor
+
+     param = next(module.parameters())
+     scale_dtype = param.dtype
+     device = param.device
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+     register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
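For orientation, a minimal usage sketch of the new flow (not part of the diff itself): it assumes a weight-only 4-bit group scheme and that `QuantizationScheme`/`QuantizationArgs` accept the fields shown below; the shapes follow the `initialize_qparams` logic above.

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle import (
    initialize_module_for_quantization,
)

# Illustrative module and scheme; the field values here are assumptions,
# not something this diff prescribes
linear = torch.nn.Linear(4096, 11008, bias=False)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4,
        symmetric=True,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
    ),
)
initialize_module_for_quantization(linear, scheme)

# initialize_qparams derives qparam shapes from weight.shape and group_size:
# one scale per 128-column group of each output row, plus a zero point
# because force_zero_point defaults to True
assert linear.weight_scale.shape == (11008, 4096 // 128)
assert linear.weight_zero_point.shape == (11008, 32)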
src/compressed_tensors/quantization/utils/helpers.py
@@ -27,6 +27,7 @@ from compressed_tensors.quantization.quant_args import (
  )
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
  from compressed_tensors.utils import deprecated
+ from loguru import logger
  from torch import FloatTensor, IntTensor, Tensor
  from torch.nn import Module

@@ -47,6 +48,7 @@ __all__ = [
      "calculate_qparams",
      "generate_gparam",
      "is_fp4",
+     "strategy_cdiv",
  ]

  # target the self_attn layer
@@ -461,3 +463,26 @@ def generate_gparam(
      max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
      global_scale = scale_data.max * quant_data.max / max_val_pos
      return global_scale.to(dtype).reshape([1])
+
+
+ def strategy_cdiv(
+     value: int,
+     divisor: int,
+     strategy: Optional[QuantizationStrategy],
+     strict: bool = False,
+ ) -> int:
+     dividend = math.ceil(value / divisor)
+     if dividend * divisor != value:
+         message = (
+             f"{strategy} quantization strategy requires strict division of "
+             f"weight/activation size {value} and group/block size {divisor}. "
+             "Consider reducing the group/block size or ignoring modules with "
+             f"weights not divisible by {divisor}"
+         )
+         if strict:
+             raise ValueError(message)
+
+         else:
+             logger.bind(log_once=True).warning(message)
+
+     return dividend
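In short, `strategy_cdiv` is a ceiling division that flags inexact splits: by default it logs a one-time warning (note the `logger.bind(log_once=True)` call), and with `strict=True` it raises instead. A quick illustrative check, assuming only the definition added above:

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.utils import strategy_cdiv

# exact split: 4096 columns at group_size 128 -> 32 groups, no warning
assert strategy_cdiv(4096, 128, QuantizationStrategy.GROUP) == 32

# inexact split: warns once and rounds up to 33 groups
assert strategy_cdiv(4097, 128, QuantizationStrategy.GROUP) == 33

# strict=True turns the inexact case into an error
try:
    strategy_cdiv(4097, 128, QuantizationStrategy.GROUP, strict=True)
except ValueError as err:
    print(err)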
src/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.12.3.a20251003'
+ __version__ = version = '0.12.3.a20251007'
  __version_tuple__ = version_tuple = (0, 12, 3)
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.12.3a20251003
+ Version: 0.12.3a20251007
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
@@ -1,270 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import logging
17
- import math
18
- import warnings
19
- from typing import Optional
20
-
21
- import torch
22
- from compressed_tensors.quantization import (
23
- FP8_E4M3_DATA,
24
- ActivationOrdering,
25
- KVCacheScaleType,
26
- QuantizationArgs,
27
- QuantizationMetadata,
28
- QuantizationScheme,
29
- QuantizationStatus,
30
- QuantizationStrategy,
31
- )
32
- from compressed_tensors.quantization.lifecycle.forward import (
33
- wrap_module_forward_quantized,
34
- )
35
- from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
36
- from compressed_tensors.utils import (
37
- disable_hf_hook,
38
- get_execution_device,
39
- register_offload_parameter,
40
- )
41
- from torch.nn import Module, Parameter
42
-
43
-
44
- __all__ = [
45
- "initialize_module_for_quantization",
46
- "is_attention_module",
47
- ]
48
-
49
-
50
- _LOGGER = logging.getLogger(__name__)
51
-
52
-
53
- def initialize_module_for_quantization(
54
- module: Module,
55
- scheme: Optional[QuantizationScheme] = None,
56
- force_zero_point: bool = True,
57
- ):
58
- """
59
- Attaches appropriate scales, zero points, and observers to a layer
60
- given its target quantization scheme.
61
-
62
- Previously initialized scales and zero points will be removed from
63
- module if they no longer apply to the scheme
64
-
65
- :param module: module to set for calibration
66
- :param scheme: scheme to use for quantization. if None is provided,
67
- will attempt to use scheme stored in the module under `quantization_scheme`,
68
- if not provided, the layer will be skipped
69
- :param force_zero_point: whether to force initialization of a zero point for
70
- symmetric quantization
71
- """
72
- # TODO: don't initialize parameters when running decompression
73
- scheme = scheme or getattr(module, "quantization_scheme", None)
74
- if scheme is None:
75
- # no scheme passed and layer not targeted for quantization - skip
76
- return
77
-
78
- QuantizationMetadata.clear_all_qparams(module)
79
-
80
- if is_attention_module(module):
81
- # quantized actions based on calltime status
82
- _initialize_attn_scales(module)
83
-
84
- else:
85
- if scheme.input_activations is not None:
86
- _initialize_scale_zero_point(
87
- module,
88
- "input",
89
- scheme.input_activations,
90
- force_zero_point=force_zero_point,
91
- )
92
-
93
- if scheme.weights is not None:
94
- if hasattr(module, "weight"):
95
- weight_shape = None
96
- if isinstance(module, torch.nn.Linear):
97
- weight_shape = module.weight.shape
98
- _initialize_scale_zero_point(
99
- module,
100
- "weight",
101
- scheme.weights,
102
- weight_shape=weight_shape,
103
- force_zero_point=force_zero_point,
104
- )
105
- else:
106
- _LOGGER.warning(
107
- f"module type {type(module)} targeted for weight quantization but "
108
- "has no attribute weight, skipping weight quantization "
109
- f"for {type(module)}"
110
- )
111
-
112
- if scheme.output_activations is not None:
113
- if not is_kv_cache_quant_scheme(scheme):
114
- _initialize_scale_zero_point(
115
- module, "output", scheme.output_activations
116
- )
117
-
118
- module.quantization_scheme = scheme
119
- module.quantization_status = QuantizationStatus.INITIALIZED
120
-
121
- with disable_hf_hook(module):
122
- # wrap forward call of module to perform
123
- # quantized actions based on calltime status
124
- wrap_module_forward_quantized(module, scheme)
125
-
126
-
127
- def is_attention_module(module: Module):
128
- return "attention" in module.__class__.__name__.lower() and (
129
- hasattr(module, "k_proj")
130
- or hasattr(module, "v_proj")
131
- or hasattr(module, "qkv_proj")
132
- )
133
-
134
-
135
- def _initialize_scale_zero_point(
136
- module: Module,
137
- base_name: str,
138
- quantization_args: QuantizationArgs,
139
- weight_shape: Optional[torch.Size] = None,
140
- force_zero_point: bool = True,
141
- ):
142
- if quantization_args.dynamic is True:
143
- return
144
-
145
- # initialize on execution device to avoid performing quantized ops on cpu
146
- device = get_execution_device(module)
147
-
148
- # 1. Create global_scales for tensor_group - generates
149
- # a per tensor scale
150
- if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
151
- init_global_scale = Parameter(
152
- torch.empty(1, dtype=torch.float32, device=device),
153
- requires_grad=False,
154
- )
155
- register_offload_parameter(
156
- module, f"{base_name}_global_scale", init_global_scale
157
- )
158
-
159
- # 2. Infer expected scale/zero point shape
160
- if quantization_args.strategy == QuantizationStrategy.TOKEN:
161
- expected_shape = (1, 1)
162
- else:
163
- expected_shape = 1
164
-
165
- if base_name == "weight" and weight_shape is not None:
166
- if quantization_args.strategy == QuantizationStrategy.CHANNEL:
167
- # (output_channels, 1) - only for weights
168
- expected_shape = (weight_shape[0], 1)
169
- elif quantization_args.strategy in (
170
- QuantizationStrategy.TENSOR_GROUP,
171
- QuantizationStrategy.GROUP,
172
- ):
173
- # GROUP/TENSOR_GROUP for both weights and activations
174
- num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
175
- expected_shape = (weight_shape[0], max(num_groups, 1))
176
- elif quantization_args.strategy == QuantizationStrategy.BLOCK:
177
- # For block quantization, scale shape should match number of blocks - only
178
- # for weights
179
- if quantization_args.block_structure is None:
180
- raise ValueError(
181
- "Block quantization requires block_structure to be specified"
182
- )
183
- block_height, block_width = quantization_args.block_structure
184
- rows, cols = weight_shape[-2], weight_shape[-1]
185
- num_rows_blocks = math.ceil(rows / block_height)
186
- num_cols_blocks = math.ceil(cols / block_width)
187
-
188
- # Warn if dimensions don't divide evenly
189
- if rows % block_height != 0 or cols % block_width != 0:
190
- warnings.warn(
191
- f"Block quantization: tensor shape {weight_shape} does not divide"
192
- f"evenly by block structure {quantization_args.block_structure}. "
193
- f"Some blocks will be incomplete which may affect quantization"
194
- "quality.",
195
- UserWarning,
196
- )
197
-
198
- expected_shape = (num_rows_blocks, num_cols_blocks)
199
- elif quantization_args.strategy == QuantizationStrategy.BLOCK:
200
- warnings.warn(
201
- f"BLOCK quantization not supported for {base_name} activations. "
202
- f"Falling back to tensor-level quantization.",
203
- UserWarning,
204
- )
205
- expected_shape = 1
206
-
207
- # 3. Identify quantization scale and zp dtype
208
- scale_dtype = module.weight.dtype
209
-
210
- if is_fp4(quantization_args=quantization_args):
211
- scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
212
- else:
213
- # TODO: consider erroring out in the future as if the dtype if not one of these,
214
- # there is likely bug
215
- if scale_dtype not in [
216
- torch.float16,
217
- torch.bfloat16,
218
- torch.float32,
219
- torch.float64,
220
- ]:
221
- scale_dtype = torch.bfloat16
222
- zp_dtype = quantization_args.pytorch_dtype()
223
-
224
- # 4. Initializes empty scale, zero point, and g_idx parameters for the module
225
- # do not init scales for quantzation_args.dynamic == DynamicType.local
226
- if not quantization_args.dynamic:
227
- init_scale = Parameter(
228
- torch.empty(expected_shape, dtype=scale_dtype, device=device),
229
- requires_grad=False,
230
- )
231
- register_offload_parameter(module, f"{base_name}_scale", init_scale)
232
-
233
- if force_zero_point or not quantization_args.symmetric:
234
- init_zero_point = Parameter(
235
- torch.zeros(expected_shape, device=device, dtype=zp_dtype),
236
- requires_grad=False,
237
- )
238
- register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
239
-
240
- # only grouped activation ordering has g_idx
241
- if quantization_args.actorder == ActivationOrdering.GROUP:
242
- g_idx_shape = (weight_shape[1],)
243
- g_idx_dtype = torch.int
244
- init_g_idx = Parameter(
245
- torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
246
- requires_grad=False,
247
- )
248
- register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
249
-
250
-
251
- def _initialize_attn_scales(module: Module) -> None:
252
- """Initlaize k_scale, v_scale for self_attn"""
253
-
254
- expected_shape = 1 # per tensor
255
-
256
- param = next(module.parameters())
257
- scale_dtype = param.dtype
258
- device = param.device
259
-
260
- init_scale = Parameter(
261
- torch.empty(expected_shape, dtype=scale_dtype, device=device),
262
- requires_grad=False,
263
- )
264
- register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
265
-
266
- init_scale = Parameter(
267
- torch.empty(expected_shape, dtype=scale_dtype, device=device),
268
- requires_grad=False,
269
- )
270
- register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)