compressed-tensors 0.10.3a20250805__tar.gz → 0.10.3a20250811__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. {compressed_tensors-0.10.3a20250805/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250811}/PKG-INFO +1 -1
  2. compressed_tensors-0.10.3a20250811/pyproject.toml +16 -0
  3. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/base.py +8 -3
  4. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +55 -33
  5. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_args.py +3 -1
  6. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_config.py +8 -2
  7. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_scheme.py +4 -2
  8. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/apply.py +4 -0
  9. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/base.py +49 -4
  10. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/hadamard.py +15 -8
  11. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/matrix_multiply.py +18 -8
  12. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_args.py +9 -1
  13. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_config.py +2 -40
  14. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_scheme.py +8 -1
  15. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/__init__.py +1 -0
  16. compressed_tensors-0.10.3a20250811/src/compressed_tensors/utils/type.py +74 -0
  17. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/version.py +1 -1
  18. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
  19. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/SOURCES.txt +3 -0
  20. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/conftest.py +4 -3
  21. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/factory/test_correctness.py +15 -17
  22. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/factory/test_memory.py +6 -6
  23. compressed_tensors-0.10.3a20250811/tests/test_transform/factory/test_serialization.py +54 -0
  24. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_config.py +14 -11
  25. compressed_tensors-0.10.3a20250811/tests/test_utils/test_type.py +79 -0
  26. compressed_tensors-0.10.3a20250805/pyproject.toml +0 -7
  27. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/.gitkeep +0 -0
  28. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/actions/test/action.yml +0 -0
  29. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/scripts/step-status +0 -0
  30. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/build-test.yml +0 -0
  31. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/build.yml +0 -0
  32. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/report.yml +0 -0
  33. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/test-check.yaml +0 -0
  34. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/test.yml +0 -0
  35. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/trigger-all.yml +0 -0
  36. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/upload.yml +0 -0
  37. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.gitignore +0 -0
  38. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/LICENSE +0 -0
  39. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/Makefile +0 -0
  40. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/README.md +0 -0
  41. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
  42. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bit_packing/int4_config.json +0 -0
  43. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bitmask_compression.ipynb +0 -0
  44. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/ex_config_quantization.py +0 -0
  45. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
  46. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/example_quant_config.json +0 -0
  47. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
  48. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/quantize_and_pack_int4.ipynb +0 -0
  49. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/setup.cfg +0 -0
  50. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/setup.py +0 -0
  51. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/__init__.py +0 -0
  52. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/README.md +0 -0
  53. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/__init__.py +0 -0
  54. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/__init__.py +0 -0
  55. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/base.py +0 -0
  56. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/helpers.py +0 -0
  57. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  58. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  59. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  60. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  61. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
  62. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  63. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  64. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  65. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  66. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
  67. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  68. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  69. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  70. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/__init__.py +0 -0
  71. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/base.py +0 -0
  72. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/dense.py +0 -0
  73. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
  74. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  75. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/linear/__init__.py +0 -0
  76. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  77. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/__init__.py +0 -0
  78. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  79. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  80. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  81. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  82. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  83. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  84. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  85. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  86. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/registry/__init__.py +0 -0
  87. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/registry/registry.py +0 -0
  88. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/__init__.py +0 -0
  89. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/__init__.py +0 -0
  90. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
  91. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/__init__.py +0 -0
  92. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
  93. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
  94. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/matrix.py +0 -0
  95. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/helpers.py +0 -0
  96. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/internal.py +0 -0
  97. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/match.py +0 -0
  98. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/offload.py +0 -0
  99. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/permutations_24.py +0 -0
  100. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/permute.py +0 -0
  101. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  102. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  103. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  104. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/requires.txt +0 -0
  105. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/top_level.txt +0 -0
  106. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/__init__.py +0 -0
  107. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/conftest.py +0 -0
  108. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/__init__.py +0 -0
  109. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/model_compressors/__init__.py +0 -0
  110. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
  111. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
  112. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
  113. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
  114. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
  115. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
  116. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
  117. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
  118. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
  119. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
  120. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
  121. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_configs/__init__.py +0 -0
  122. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_configs/test_base.py +0 -0
  123. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
  124. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_linear/__init__.py +0 -0
  125. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_linear/test_compressed_linear.py +0 -0
  126. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/__init__.py +0 -0
  127. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/__init__.py +0 -0
  128. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/conftest.py +0 -0
  129. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_apply.py +0 -0
  130. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
  131. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
  132. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_forward.py +0 -0
  133. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
  134. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
  135. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
  136. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/__init__.py +0 -0
  137. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
  138. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/test_strategies.py +0 -0
  139. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_args.py +0 -0
  140. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_config.py +0 -0
  141. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_scheme.py +0 -0
  142. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_utils/test_helpers.py +0 -0
  143. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_registry.py +0 -0
  144. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_args.py +0 -0
  145. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_scheme.py +0 -0
  146. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/utils/test_hadamard.py +0 -0
  147. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/__init__.py +0 -0
  148. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_helpers.py +0 -0
  149. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_match.py +0 -0
  150. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_offload.py +0 -0
  151. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_safetensors_load.py +0 -0
  152. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/testing_utils.py +0 -0
  153. {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/utils/copyright.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.3a20250805
3
+ Version: 0.10.3a20250811
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -0,0 +1,16 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.black]
6
+ line-length = 88
7
+ target-version = ['py36']
8
+
9
+ [tool.pytest.ini_options]
10
+ markers = [
11
+ "unit: tests to ensure code correctness and regression test functionality",
12
+ "smoke: quick tests to check basic functionality",
13
+ "sanity: tests to ensure that new changes do not break existing functionality",
14
+ "regression: detailed tests to ensure major functions work correctly",
15
+ "integration: tests which integrate with a third party service such as HF",
16
+ ]
@@ -12,9 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- SPARSITY_CONFIG_NAME = "sparsity_config"
15
+ # configs
16
16
  QUANTIZATION_CONFIG_NAME = "quantization_config"
17
- COMPRESSION_CONFIG_NAME = "compression_config"
18
- KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
17
+ SPARSITY_CONFIG_NAME = "sparsity_config"
18
+ TRANSFORM_CONFIG_NAME = "transform_config"
19
+
20
+ # required fields
19
21
  COMPRESSION_VERSION_NAME = "version"
20
22
  QUANTIZATION_METHOD_NAME = "quant_method"
23
+
24
+ # auxillary configs
25
+ KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
@@ -29,6 +29,7 @@ from compressed_tensors.base import (
29
29
  QUANTIZATION_CONFIG_NAME,
30
30
  QUANTIZATION_METHOD_NAME,
31
31
  SPARSITY_CONFIG_NAME,
32
+ TRANSFORM_CONFIG_NAME,
32
33
  )
33
34
  from compressed_tensors.compressors.base import BaseCompressor
34
35
  from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@ from compressed_tensors.quantization import (
43
44
  )
44
45
  from compressed_tensors.quantization.lifecycle import expand_target_names
45
46
  from compressed_tensors.quantization.utils import is_module_quantized
47
+ from compressed_tensors.transform import TransformConfig
46
48
  from compressed_tensors.utils import (
47
49
  align_module_device,
48
50
  delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
105
107
 
106
108
  sparsity_config: Optional[SparsityCompressionConfig] = None
107
109
  quantization_config: Optional[QuantizationConfig] = None
110
+ transform_config: Optional[TransformConfig] = None
108
111
 
109
112
  @classmethod
110
113
  def from_pretrained(
@@ -144,6 +147,8 @@ class ModelCompressor:
144
147
 
145
148
  sparsity_config = cls.parse_sparsity_config(compression_config)
146
149
  quantization_config = cls.parse_quantization_config(compression_config)
150
+ # TODO: transform config is not support by CompressedTensorsConfig yet
151
+
147
152
  if sparsity_config is None and quantization_config is None:
148
153
  return None
149
154
 
@@ -177,20 +182,27 @@ class ModelCompressor:
177
182
  algorithm
178
183
  :return: compressor for the configs, or None if model is not compressed
179
184
  """
185
+ # reconstruct config from schemes attached to modules
180
186
  quantization_config = QuantizationConfig.from_pretrained(
181
187
  model, format=quantization_format
182
188
  )
183
189
 
190
+ # use config passed as argument
184
191
  if isinstance(sparsity_config, str): # we passed in a sparsity format
185
192
  sparsity_config = SparsityCompressionConfig.load_from_registry(
186
193
  sparsity_config
187
194
  )
188
195
 
189
- if sparsity_config is None and quantization_config is None:
196
+ # use config attached to model
197
+ transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
198
+
199
+ if not any((quantization_config, sparsity_config, transform_config)):
190
200
  return None
191
201
 
192
202
  return cls(
193
- sparsity_config=sparsity_config, quantization_config=quantization_config
203
+ sparsity_config=sparsity_config,
204
+ quantization_config=quantization_config,
205
+ transform_config=transform_config,
194
206
  )
195
207
 
196
208
  @staticmethod
@@ -254,13 +266,17 @@ class ModelCompressor:
254
266
  self,
255
267
  sparsity_config: Optional[SparsityCompressionConfig] = None,
256
268
  quantization_config: Optional[QuantizationConfig] = None,
269
+ transform_config: Optional[TransformConfig] = None,
257
270
  ):
258
271
  self.sparsity_config = sparsity_config
259
272
  self.quantization_config = quantization_config
273
+ self.transform_config = transform_config
274
+
260
275
  self.sparsity_compressor = None
261
276
  self.quantization_compressor: Optional[
262
277
  Union[BaseQuantizationCompressor, DenseCompressor]
263
278
  ] = None
279
+ # no transform compressor is required
264
280
 
265
281
  if sparsity_config is not None:
266
282
  self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +656,49 @@ class ModelCompressor:
640
656
 
641
657
  :param save_directory: path to a folder containing a HF model config
642
658
  """
643
- if self.quantization_config is None and self.sparsity_config is None:
659
+ # this check is also done in `from_pretrained_model`,
660
+ # but not in `from_pretrained`` or `from_compression_config``
661
+ if not any(
662
+ (self.quantization_config, self.sparsity_config, self.transform_config)
663
+ ):
644
664
  return
645
665
 
666
+ # write to config.json file, regardless of whether it exists already
667
+ # overwrite previous config and version if already existing
646
668
  config_file_path = os.path.join(save_directory, CONFIG_NAME)
647
- if not os.path.exists(config_file_path):
648
- _LOGGER.warning(
649
- f"Could not find a valid model config file in "
650
- f"{save_directory}. Compression config will not be saved."
651
- )
652
- return
669
+ if os.path.exists(config_file_path):
670
+ with open(config_file_path, "r") as file:
671
+ config_data = json.load(file)
672
+ else:
673
+ config_data = {}
653
674
 
654
- with open(config_file_path, "r") as config_file:
655
- config_data = json.load(config_file)
675
+ # serialize configs into json
676
+ qconfig_data = (
677
+ self.quantization_config.model_dump(exclude=["quant_method", "format"])
678
+ if self.quantization_config is not None
679
+ else {}
680
+ )
681
+ sconfig_data = (
682
+ self.sparsity_config.model_dump()
683
+ if self.sparsity_config is not None
684
+ else {}
685
+ )
686
+ tconfig_data = (
687
+ self.transform_config.model_dump()
688
+ if self.transform_config is not None
689
+ else {}
690
+ )
656
691
 
657
- # required metadata whenever a quantization or sparsity config is present
658
- # overwrite previous config and version if already existing
659
- config_data[QUANTIZATION_CONFIG_NAME] = {}
660
- config_data[QUANTIZATION_CONFIG_NAME][
661
- COMPRESSION_VERSION_NAME
662
- ] = compressed_tensors.__version__
663
- if self.quantization_config is not None:
664
- self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
665
- else:
666
- config_data[QUANTIZATION_CONFIG_NAME][
667
- QUANTIZATION_METHOD_NAME
668
- ] = DEFAULT_QUANTIZATION_METHOD
669
-
670
- # quantization and sparsity configs
671
- if self.quantization_config is not None:
672
- quant_config_data = self.quantization_config.model_dump()
673
- config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
674
- if self.sparsity_config is not None:
675
- sparsity_config_data = self.sparsity_config.model_dump()
676
- config_data[QUANTIZATION_CONFIG_NAME][
677
- SPARSITY_CONFIG_NAME
678
- ] = sparsity_config_data
692
+ # construct compression (quantization) config
693
+ config_data[QUANTIZATION_CONFIG_NAME] = {
694
+ COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
695
+ QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
696
+ SPARSITY_CONFIG_NAME: sconfig_data,
697
+ TRANSFORM_CONFIG_NAME: tconfig_data,
698
+ **qconfig_data,
699
+ }
679
700
 
701
+ # write results to config.json file
680
702
  with open(config_file_path, "w") as config_file:
681
703
  json.dump(config_data, config_file, indent=2, sort_keys=True)
682
704
 
@@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
19
19
  import torch
20
20
  from compressed_tensors.utils import Aliasable
21
21
  from compressed_tensors.utils.helpers import deprecated
22
- from pydantic import BaseModel, Field, field_validator, model_validator
22
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
23
23
 
24
24
 
25
25
  __all__ = [
@@ -358,6 +358,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
358
358
  def get_observer(self) -> str:
359
359
  return self.observer
360
360
 
361
+ model_config = ConfigDict(extra="forbid")
362
+
361
363
 
362
364
  def round_to_quantized_type(
363
365
  tensor: torch.Tensor, args: QuantizationArgs
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from enum import Enum
16
- from typing import Dict, List, Optional, Union
16
+ from typing import Annotated, Any, Dict, List, Optional, Union
17
17
 
18
18
  from compressed_tensors.config import CompressionFormat
19
19
  from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.utils import (
26
26
  module_type,
27
27
  parse_out_kv_cache_args,
28
28
  )
29
- from pydantic import BaseModel, Field
29
+ from pydantic import BaseModel, ConfigDict, Field
30
30
  from torch.nn import Module
31
31
 
32
32
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
142
142
  quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
143
143
  global_compression_ratio: Optional[float] = None
144
144
  ignore: Optional[List[str]] = Field(default_factory=list)
145
+ # `run_compressed` is a dummy, unused arg for backwards compatibility
146
+ # see: https://github.com/huggingface/transformers/pull/39324
147
+ run_compressed: Annotated[Any, Field(exclude=True)] = None
145
148
 
146
149
  def model_post_init(self, __context):
147
150
  """
@@ -254,3 +257,6 @@ class QuantizationConfig(BaseModel):
254
257
  return True
255
258
 
256
259
  return False
260
+
261
+ # TODO set `extra="forbid"` when upstream transformers is compatible
262
+ model_config = ConfigDict(extra="ignore")
@@ -14,7 +14,7 @@
14
14
 
15
15
  import warnings
16
16
  from copy import deepcopy
17
- from typing import Any, Dict, List, Optional
17
+ from typing import List, Optional
18
18
 
19
19
  from compressed_tensors.quantization.quant_args import (
20
20
  DynamicType,
@@ -22,7 +22,7 @@ from compressed_tensors.quantization.quant_args import (
22
22
  QuantizationStrategy,
23
23
  QuantizationType,
24
24
  )
25
- from pydantic import BaseModel, model_validator
25
+ from pydantic import BaseModel, ConfigDict, model_validator
26
26
 
27
27
 
28
28
  __all__ = [
@@ -81,6 +81,8 @@ class QuantizationScheme(BaseModel):
81
81
 
82
82
  return model
83
83
 
84
+ model_config = ConfigDict(extra="forbid")
85
+
84
86
 
85
87
  """
86
88
  Pre-Set Quantization Scheme Args
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import torch
16
+ from compressed_tensors import TRANSFORM_CONFIG_NAME
16
17
  from compressed_tensors.transform import TransformConfig, TransformFactory
17
18
 
18
19
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
30
31
  for name, scheme in config.config_groups.items():
31
32
  factory = TransformFactory.from_scheme(scheme, name=name)
32
33
  factory.apply_to_model(model)
34
+
35
+ # attach config to model for compression/serialization
36
+ setattr(model, TRANSFORM_CONFIG_NAME, config)
@@ -13,11 +13,11 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from abc import ABC, abstractmethod
16
- from typing import Optional
16
+ from collections import defaultdict
17
+ from typing import List, Optional, Set, Tuple
17
18
 
18
19
  import torch
19
20
  import torch.nn.utils.parametrize as P
20
- from compressed_tensors import InternalModule
21
21
  from compressed_tensors.registry.registry import RegistryMixin, T
22
22
  from compressed_tensors.transform import (
23
23
  TransformArgs,
@@ -33,6 +33,7 @@ from compressed_tensors.utils import (
33
33
  register_offload_module,
34
34
  update_offload_parameter,
35
35
  )
36
+ from compressed_tensors.utils.internal import InternalModule
36
37
  from torch import Tensor
37
38
  from torch.nn import Module, Parameter
38
39
 
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
49
50
  :param seed: random seed used to transform weight randomization
50
51
  """
51
52
 
53
+ transforms: List["TransformBase"]
54
+
52
55
  def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
53
56
  self.name = name
54
57
  self.scheme = scheme
55
58
  self.generator = torch.Generator()
59
+ self.transforms = list()
56
60
  if seed is not None:
57
61
  self.generator.manual_seed(seed)
58
62
 
@@ -90,6 +94,8 @@ class TransformFactory(RegistryMixin, ABC):
90
94
  for _, module in match_named_modules(model, arg.targets, arg.ignore):
91
95
  self._apply_to_module(module, arg)
92
96
 
97
+ self._update_tied_weights()
98
+
93
99
  def _apply_to_module(self, module: Module, args: TransformArgs):
94
100
  """
95
101
  Create transforms and apply them to the module
@@ -97,9 +103,17 @@ class TransformFactory(RegistryMixin, ABC):
97
103
  :param module: target module to apply transforms to
98
104
  :param args: defines how the transform will be applied to the target module
99
105
  """
106
+ if has_offloaded_params(module):
107
+ if module._hf_hook.place_submodules:
108
+ raise NotImplementedError(
109
+ "Applying transforms to offloaded submodules with "
110
+ "`place_submodules=True` is not supported"
111
+ )
112
+
100
113
  # create transform as submodule
101
114
  transform_name = f"{self.name}_{args.location}"
102
115
  transform = self.create_transform(module, args)
116
+ self.transforms.append(transform)
103
117
  register_offload_module(module, transform_name, transform)
104
118
 
105
119
  # register input transformation hook
@@ -128,8 +142,9 @@ class TransformFactory(RegistryMixin, ABC):
128
142
  raise ValueError("Offloaded training is not supported")
129
143
  P.register_parametrization(module, "weight", transform)
130
144
 
131
- # transform is no longer needed (unfusing is not supported)
132
- delete_offload_module(module, transform_name)
145
+ else:
146
+ # transform is no longer needed (unfusing is not supported)
147
+ delete_offload_module(module, transform_name)
133
148
 
134
149
  # register output transformation hook
135
150
  elif args.location == TransformLocation.OUTPUT:
@@ -143,6 +158,31 @@ class TransformFactory(RegistryMixin, ABC):
143
158
  else:
144
159
  raise NotImplementedError()
145
160
 
161
+ def _update_tied_weights(self):
162
+ """
163
+ Populate the `_dynamic_tied_weights_keys` attribute of transforms,
164
+ which is used by transformers to detect and remove shared pointers
165
+ during saving
166
+ """
167
+ # map from data_ptrs to keys
168
+ ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
169
+ for transform in self.transforms:
170
+ for name, param in transform.named_parameters(recurse=False):
171
+ # NOTE: previously asserted that parent._hf_hook.place_submodules=False
172
+ if has_offloaded_params(transform):
173
+ param = transform._hf_hook.weights_map[name]
174
+ ptr_to_keys[param.data_ptr()].append((transform, name))
175
+
176
+ # populate `_dynamic_tied_weights_keys` if there is more than one key
177
+ # and ensure that they share tensors
178
+ for shared_keys in ptr_to_keys.values():
179
+ if len(shared_keys) > 1:
180
+ tensor = getattr(shared_keys[0][0], shared_keys[0][1])
181
+
182
+ for transform, name in shared_keys:
183
+ transform._dynamic_tied_weights_keys.add(name)
184
+ setattr(transform, name, tensor)
185
+
146
186
 
147
187
  class TransformBase(InternalModule, ABC):
148
188
  """
@@ -151,6 +191,11 @@ class TransformBase(InternalModule, ABC):
151
191
 
152
192
  args: TransformArgs
153
193
  weight: Parameter
194
+ _dynamic_tied_weights_keys: Set[str]
195
+
196
+ def __init__(self):
197
+ super().__init__()
198
+ self._dynamic_tied_weights_keys = set()
154
199
 
155
200
  @abstractmethod
156
201
  def forward(self, value: Tensor) -> Tensor:
@@ -12,8 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import math
16
- from typing import Optional, Union
15
+ from typing import Optional
17
16
 
18
17
  import torch
19
18
  from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@ from compressed_tensors.transform.utils.matrix import (
26
25
  from compressed_tensors.utils import get_execution_device, get_offloaded_device
27
26
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
28
27
  from torch import Tensor, device, dtype
29
- from torch.nn import Linear, Module, Parameter
28
+ from torch.nn import Module, Parameter
30
29
 
31
30
 
32
31
  @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ class HadamardFactory(TransformFactory):
54
53
  """
55
54
  assert hasattr(module, "weight")
56
55
  size = get_transform_size(module, args.location, self.scheme.head_dim)
57
- dtype = module.weight.dtype
56
+ dtype = self.scheme.precision
58
57
  device = get_offloaded_device(module)
59
58
  exec_device = get_execution_device(module)
60
59
 
61
60
  factory_kwargs = {"construct_device": exec_device}
62
61
  weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
63
62
  perm = self.perms[weight] if self.scheme.randomize else None
64
- return HadamardTransform(weight, perm, args, type(module))
63
+ return HadamardTransform(weight, perm, self.scheme, args, type(module))
65
64
 
66
65
  def _create_weight(
67
66
  self,
@@ -85,15 +84,18 @@ class HadamardTransform(TransformBase):
85
84
  self,
86
85
  weight: Parameter,
87
86
  perm: Optional[Parameter],
87
+ scheme: TransformScheme,
88
88
  args: TransformArgs,
89
89
  module_type: type[torch.nn.Module],
90
90
  ):
91
91
  super().__init__()
92
92
  self.weight = weight
93
93
  self.perm = perm
94
+ self.scheme = scheme
94
95
  self.args = args
95
96
  self.module_type = module_type
96
- self._scale = math.sqrt(weight.size(0))
97
+ self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
98
+ self._precision = scheme.precision if args.is_online() else torch.float64
97
99
 
98
100
  def forward(self, value: Tensor) -> Tensor:
99
101
  weight = self.weight
@@ -105,6 +107,11 @@ class HadamardTransform(TransformBase):
105
107
  weight = weight.T
106
108
 
107
109
  return (
108
- apply_transform_weight(weight, value, self.args.location, self.module_type)
110
+ apply_transform_weight(
111
+ weight.to(self._precision),
112
+ value.to(self._precision),
113
+ self.args.location,
114
+ self.module_type,
115
+ )
109
116
  / self._scale
110
- )
117
+ ).to(value.dtype)
@@ -24,7 +24,7 @@ from compressed_tensors.transform.utils.matrix import (
24
24
  from compressed_tensors.utils import get_offloaded_device
25
25
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
26
26
  from torch import Tensor, device, dtype
27
- from torch.nn import Linear, Module, Parameter
27
+ from torch.nn import Module, Parameter
28
28
 
29
29
 
30
30
  @TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ class RandomMatrixFactory(TransformFactory):
52
52
  """
53
53
  assert hasattr(module, "weight")
54
54
  size = get_transform_size(module, args.location, self.scheme.head_dim)
55
- dtype = module.weight.dtype
55
+ dtype = self.scheme.precision
56
56
  device = get_offloaded_device(module)
57
57
 
58
58
  weight = self.weights[size, dtype, device]
59
59
  if args.inverse:
60
60
  weight = self.inverses[weight]
61
61
 
62
- return RandomMatrixTransform(weight, args, type(module))
62
+ return RandomMatrixTransform(weight, self.scheme, args, type(module))
63
63
 
64
64
  def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65
65
  # TODO: verify that weight is invertible (has non-zero determinant)
@@ -70,6 +70,7 @@ class RandomMatrixFactory(TransformFactory):
70
70
 
71
71
  def _create_inverse(self, weight: Parameter) -> Parameter:
72
72
  data = high_precision_invert(weight.data)
73
+ data = data.contiguous() # ensure proper serialization
73
74
  return Parameter(data, requires_grad=False)
74
75
 
75
76
 
@@ -77,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
77
78
  def __init__(
78
79
  self,
79
80
  weight: Tensor,
81
+ scheme: TransformScheme,
80
82
  args: TransformArgs,
81
83
  module_type: type[torch.nn.Module],
82
84
  ):
83
85
  super().__init__()
84
86
  self.weight = weight # is an inverse if args.inverse
87
+ self.scheme = scheme
85
88
  self.args = args
86
89
  self.module_type = module_type
90
+ self._precision = scheme.precision if args.is_online() else torch.float64
87
91
 
88
92
  def forward(self, value: Tensor) -> Parameter:
89
93
  return apply_transform_weight(
90
- self.weight, value, self.args.location, self.module_type
91
- )
94
+ self.weight.to(self._precision),
95
+ value.to(self._precision),
96
+ self.args.location,
97
+ self.module_type,
98
+ ).to(value.dtype)
92
99
 
93
100
  def right_inverse(self, value: Tensor) -> Tensor:
94
101
  inverse = high_precision_invert(self.weight)
95
102
  return apply_transform_weight(
96
- inverse, value, self.args.location, self.module_type
97
- )
103
+ inverse.to(self._precision),
104
+ value.to(self._precision),
105
+ self.args.location,
106
+ self.module_type,
107
+ ).to(value.dtype)
98
108
 
99
109
 
100
110
  def high_precision_invert(weight: Tensor) -> Tensor:
101
- return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
111
+ return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
@@ -15,7 +15,7 @@
15
15
  from enum import Enum
16
16
  from typing import List
17
17
 
18
- from pydantic import BaseModel, Field, field_validator
18
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
19
19
 
20
20
 
21
21
  __all__ = ["TransformArgs", "TransformLocation"]
@@ -68,3 +68,11 @@ class TransformArgs(BaseModel, use_enum_values=True):
68
68
  if isinstance(value, str):
69
69
  return [value]
70
70
  return value
71
+
72
+ def is_online(self) -> bool:
73
+ return self.location not in (
74
+ TransformLocation.WEIGHT_INPUT,
75
+ TransformLocation.WEIGHT_OUTPUT,
76
+ )
77
+
78
+ model_config = ConfigDict(extra="forbid")
@@ -15,7 +15,7 @@
15
15
  from typing import Dict
16
16
 
17
17
  from compressed_tensors.transform import TransformArgs, TransformScheme
18
- from pydantic import BaseModel
18
+ from pydantic import BaseModel, ConfigDict
19
19
 
20
20
 
21
21
  __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
32
32
 
33
33
  config_groups: Dict[str, TransformScheme]
34
34
 
35
-
36
- # quip / quip sharp
37
- QUIP = TransformConfig(
38
- config_groups={
39
- "v": TransformScheme(
40
- type="hadamard",
41
- apply=[
42
- TransformArgs(
43
- targets=["Linear"],
44
- location="input", # non-mergable
45
- ),
46
- TransformArgs(
47
- targets=["Linear"],
48
- location="weight_input",
49
- inverse=True,
50
- ),
51
- ],
52
- randomize=True,
53
- ),
54
- "u": TransformScheme(
55
- type="hadamard",
56
- apply=[
57
- TransformArgs(
58
- targets=["Linear"],
59
- location="weight_output",
60
- ),
61
- TransformArgs(
62
- targets=["Linear"], location="output", inverse=True # non-mergable
63
- ),
64
- ],
65
- randomize=True,
66
- ),
67
- }
68
- )
69
-
70
-
71
- PRESET_CONFIGS = {
72
- "QUIP": QUIP,
73
- }
35
+ model_config = ConfigDict(extra="forbid")
@@ -14,8 +14,10 @@
14
14
 
15
15
  from typing import List, Optional
16
16
 
17
+ import torch
17
18
  from compressed_tensors.transform import TransformArgs
18
- from pydantic import BaseModel, Field
19
+ from compressed_tensors.utils import TorchDtype
20
+ from pydantic import BaseModel, ConfigDict, Field
19
21
 
20
22
 
21
23
  __all__ = ["TransformScheme"]
@@ -34,6 +36,8 @@ class TransformScheme(BaseModel):
34
36
  :param randomize: True if uniquely randomized transform weights should be used,
35
37
  otherwise use identical transform weights where applicable
36
38
  :param requires_grad: True if weights include gradients for training
39
+ :param precision: Precision at which this transform should be applied during online
40
+ rotations. Fused (offline) rotations are always performed in float64
37
41
  """
38
42
 
39
43
  type: str
@@ -41,3 +45,6 @@ class TransformScheme(BaseModel):
41
45
  randomize: bool = Field(default=False)
42
46
  requires_grad: bool = Field(default=False)
43
47
  head_dim: Optional[int] = Field(default=None)
48
+ precision: TorchDtype = Field(default=torch.float32)
49
+
50
+ model_config = ConfigDict(extra="forbid")
@@ -21,3 +21,4 @@ from .permutations_24 import *
21
21
  from .permute import *
22
22
  from .safetensors_load import *
23
23
  from .semi_structured_conversions import *
24
+ from .type import *