compressed-tensors 0.10.3a20250806__tar.gz → 0.10.3a20250812__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.10.3a20250806/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250812}/PKG-INFO +1 -1
- compressed_tensors-0.10.3a20250812/pyproject.toml +16 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/base.py +8 -3
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +58 -35
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_args.py +3 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_config.py +8 -2
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/quant_scheme.py +4 -2
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/apply.py +4 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/base.py +2 -2
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/hadamard.py +15 -8
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/matrix_multiply.py +17 -8
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_args.py +9 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_config.py +2 -40
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/transform_scheme.py +8 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/__init__.py +1 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/offload.py +15 -1
- compressed_tensors-0.10.3a20250812/src/compressed_tensors/utils/type.py +74 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_memory.py +1 -1
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_config.py +14 -11
- compressed_tensors-0.10.3a20250812/tests/test_utils/test_type.py +79 -0
- compressed_tensors-0.10.3a20250806/pyproject.toml +0 -7
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/.gitkeep +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/.gitignore +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/LICENSE +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/Makefile +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/README.md +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/setup.cfg +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/setup.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/transform/utils/matrix.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/internal.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/match.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_forward.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_quantization/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_registry.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_correctness.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/factory/test_serialization.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_args.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/test_transform_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_transform/utils/test_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_match.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_offload.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.10.3a20250806 → compressed_tensors-0.10.3a20250812}/utils/copyright.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: compressed-tensors
|
3
|
-
Version: 0.10.3a20250806
|
3
|
+
Version: 0.10.3a20250812
|
4
4
|
Summary: Library for utilization of compressed safetensors of neural network models
|
5
5
|
Home-page: https://github.com/neuralmagic/compressed-tensors
|
6
6
|
Author: Neuralmagic, Inc.
|
@@ -0,0 +1,16 @@
|
|
1
|
+
[build-system]
|
2
|
+
requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
|
3
|
+
build-backend = "setuptools.build_meta"
|
4
|
+
|
5
|
+
[tool.black]
|
6
|
+
line-length = 88
|
7
|
+
target-version = ['py36']
|
8
|
+
|
9
|
+
[tool.pytest.ini_options]
|
10
|
+
markers = [
|
11
|
+
"unit: tests to ensure code correctness and regression test functionality",
|
12
|
+
"smoke: quick tests to check basic functionality",
|
13
|
+
"sanity: tests to ensure that new changes do not break existing functionality",
|
14
|
+
"regression: detailed tests to ensure major functions work correctly",
|
15
|
+
"integration: tests which integrate with a third party service such as HF",
|
16
|
+
]
|
@@ -12,9 +12,14 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
15
|
+
# configs
|
16
16
|
QUANTIZATION_CONFIG_NAME = "quantization_config"
|
17
|
-
|
18
|
-
|
17
|
+
SPARSITY_CONFIG_NAME = "sparsity_config"
|
18
|
+
TRANSFORM_CONFIG_NAME = "transform_config"
|
19
|
+
|
20
|
+
# required fields
|
19
21
|
COMPRESSION_VERSION_NAME = "version"
|
20
22
|
QUANTIZATION_METHOD_NAME = "quant_method"
|
23
|
+
|
24
|
+
# auxillary configs
|
25
|
+
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
|
@@ -29,6 +29,7 @@ from compressed_tensors.base import (
|
|
29
29
|
QUANTIZATION_CONFIG_NAME,
|
30
30
|
QUANTIZATION_METHOD_NAME,
|
31
31
|
SPARSITY_CONFIG_NAME,
|
32
|
+
TRANSFORM_CONFIG_NAME,
|
32
33
|
)
|
33
34
|
from compressed_tensors.compressors.base import BaseCompressor
|
34
35
|
from compressed_tensors.compressors.sparse_compressors import DenseCompressor
|
@@ -43,6 +44,7 @@ from compressed_tensors.quantization import (
|
|
43
44
|
)
|
44
45
|
from compressed_tensors.quantization.lifecycle import expand_target_names
|
45
46
|
from compressed_tensors.quantization.utils import is_module_quantized
|
47
|
+
from compressed_tensors.transform import TransformConfig
|
46
48
|
from compressed_tensors.utils import (
|
47
49
|
align_module_device,
|
48
50
|
delete_offload_parameter,
|
@@ -105,6 +107,7 @@ class ModelCompressor:
|
|
105
107
|
|
106
108
|
sparsity_config: Optional[SparsityCompressionConfig] = None
|
107
109
|
quantization_config: Optional[QuantizationConfig] = None
|
110
|
+
transform_config: Optional[TransformConfig] = None
|
108
111
|
|
109
112
|
@classmethod
|
110
113
|
def from_pretrained(
|
@@ -144,6 +147,8 @@ class ModelCompressor:
|
|
144
147
|
|
145
148
|
sparsity_config = cls.parse_sparsity_config(compression_config)
|
146
149
|
quantization_config = cls.parse_quantization_config(compression_config)
|
150
|
+
# TODO: transform config is not support by CompressedTensorsConfig yet
|
151
|
+
|
147
152
|
if sparsity_config is None and quantization_config is None:
|
148
153
|
return None
|
149
154
|
|
@@ -177,25 +182,32 @@ class ModelCompressor:
|
|
177
182
|
algorithm
|
178
183
|
:return: compressor for the configs, or None if model is not compressed
|
179
184
|
"""
|
185
|
+
# reconstruct config from schemes attached to modules
|
180
186
|
quantization_config = QuantizationConfig.from_pretrained(
|
181
187
|
model, format=quantization_format
|
182
188
|
)
|
183
189
|
|
190
|
+
# use config passed as argument
|
184
191
|
if isinstance(sparsity_config, str): # we passed in a sparsity format
|
185
192
|
sparsity_config = SparsityCompressionConfig.load_from_registry(
|
186
193
|
sparsity_config
|
187
194
|
)
|
188
195
|
|
189
|
-
|
196
|
+
# use config attached to model
|
197
|
+
transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
|
198
|
+
|
199
|
+
if not any((quantization_config, sparsity_config, transform_config)):
|
190
200
|
return None
|
191
201
|
|
192
202
|
return cls(
|
193
|
-
sparsity_config=sparsity_config,
|
203
|
+
sparsity_config=sparsity_config,
|
204
|
+
quantization_config=quantization_config,
|
205
|
+
transform_config=transform_config,
|
194
206
|
)
|
195
207
|
|
196
208
|
@staticmethod
|
197
209
|
def parse_sparsity_config(
|
198
|
-
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
|
210
|
+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
|
199
211
|
) -> Union[Dict[str, Any], None]:
|
200
212
|
"""
|
201
213
|
Parse sparsity config from quantization/compression config. Sparsity
|
@@ -215,7 +227,7 @@ class ModelCompressor:
|
|
215
227
|
|
216
228
|
@staticmethod
|
217
229
|
def parse_quantization_config(
|
218
|
-
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
|
230
|
+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
|
219
231
|
) -> Union[Dict[str, Any], None]:
|
220
232
|
"""
|
221
233
|
Parse quantization config from quantization/compression config. The
|
@@ -234,6 +246,7 @@ class ModelCompressor:
|
|
234
246
|
|
235
247
|
quantization_config = deepcopy(compression_config)
|
236
248
|
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
|
249
|
+
quantization_config.pop(TRANSFORM_CONFIG_NAME, None)
|
237
250
|
|
238
251
|
# some fields are required, even if a qconfig is not present
|
239
252
|
# pop them off and if nothing remains, then there is no qconfig
|
@@ -254,13 +267,17 @@ class ModelCompressor:
|
|
254
267
|
self,
|
255
268
|
sparsity_config: Optional[SparsityCompressionConfig] = None,
|
256
269
|
quantization_config: Optional[QuantizationConfig] = None,
|
270
|
+
transform_config: Optional[TransformConfig] = None,
|
257
271
|
):
|
258
272
|
self.sparsity_config = sparsity_config
|
259
273
|
self.quantization_config = quantization_config
|
274
|
+
self.transform_config = transform_config
|
275
|
+
|
260
276
|
self.sparsity_compressor = None
|
261
277
|
self.quantization_compressor: Optional[
|
262
278
|
Union[BaseQuantizationCompressor, DenseCompressor]
|
263
279
|
] = None
|
280
|
+
# no transform compressor is required
|
264
281
|
|
265
282
|
if sparsity_config is not None:
|
266
283
|
self.sparsity_compressor = BaseCompressor.load_from_registry(
|
@@ -640,43 +657,49 @@ class ModelCompressor:
|
|
640
657
|
|
641
658
|
:param save_directory: path to a folder containing a HF model config
|
642
659
|
"""
|
643
|
-
|
660
|
+
# this check is also done in `from_pretrained_model`,
|
661
|
+
# but not in `from_pretrained`` or `from_compression_config``
|
662
|
+
if not any(
|
663
|
+
(self.quantization_config, self.sparsity_config, self.transform_config)
|
664
|
+
):
|
644
665
|
return
|
645
666
|
|
667
|
+
# write to config.json file, regardless of whether it exists already
|
668
|
+
# overwrite previous config and version if already existing
|
646
669
|
config_file_path = os.path.join(save_directory, CONFIG_NAME)
|
647
|
-
if
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
return
|
670
|
+
if os.path.exists(config_file_path):
|
671
|
+
with open(config_file_path, "r") as file:
|
672
|
+
config_data = json.load(file)
|
673
|
+
else:
|
674
|
+
config_data = {}
|
653
675
|
|
654
|
-
|
655
|
-
|
676
|
+
# serialize configs into json
|
677
|
+
qconfig_data = (
|
678
|
+
self.quantization_config.model_dump(exclude=["quant_method"])
|
679
|
+
if self.quantization_config is not None
|
680
|
+
else {}
|
681
|
+
)
|
682
|
+
sconfig_data = (
|
683
|
+
self.sparsity_config.model_dump()
|
684
|
+
if self.sparsity_config is not None
|
685
|
+
else {}
|
686
|
+
)
|
687
|
+
tconfig_data = (
|
688
|
+
self.transform_config.model_dump()
|
689
|
+
if self.transform_config is not None
|
690
|
+
else {}
|
691
|
+
)
|
656
692
|
|
657
|
-
#
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
else:
|
666
|
-
config_data[QUANTIZATION_CONFIG_NAME][
|
667
|
-
QUANTIZATION_METHOD_NAME
|
668
|
-
] = DEFAULT_QUANTIZATION_METHOD
|
669
|
-
|
670
|
-
# quantization and sparsity configs
|
671
|
-
if self.quantization_config is not None:
|
672
|
-
quant_config_data = self.quantization_config.model_dump()
|
673
|
-
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
|
674
|
-
if self.sparsity_config is not None:
|
675
|
-
sparsity_config_data = self.sparsity_config.model_dump()
|
676
|
-
config_data[QUANTIZATION_CONFIG_NAME][
|
677
|
-
SPARSITY_CONFIG_NAME
|
678
|
-
] = sparsity_config_data
|
693
|
+
# construct compression (quantization) config
|
694
|
+
config_data[QUANTIZATION_CONFIG_NAME] = {
|
695
|
+
COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
|
696
|
+
QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
|
697
|
+
SPARSITY_CONFIG_NAME: sconfig_data,
|
698
|
+
TRANSFORM_CONFIG_NAME: tconfig_data,
|
699
|
+
**qconfig_data,
|
700
|
+
}
|
679
701
|
|
702
|
+
# write results to config.json file
|
680
703
|
with open(config_file_path, "w") as config_file:
|
681
704
|
json.dump(config_data, config_file, indent=2, sort_keys=True)
|
682
705
|
|
@@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
19
19
|
import torch
|
20
20
|
from compressed_tensors.utils import Aliasable
|
21
21
|
from compressed_tensors.utils.helpers import deprecated
|
22
|
-
from pydantic import BaseModel, Field, field_validator, model_validator
|
22
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
23
23
|
|
24
24
|
|
25
25
|
__all__ = [
|
@@ -358,6 +358,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
358
358
|
def get_observer(self) -> str:
|
359
359
|
return self.observer
|
360
360
|
|
361
|
+
model_config = ConfigDict(extra="forbid")
|
362
|
+
|
361
363
|
|
362
364
|
def round_to_quantized_type(
|
363
365
|
tensor: torch.Tensor, args: QuantizationArgs
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from enum import Enum
|
16
|
-
from typing import Dict, List, Optional, Union
|
16
|
+
from typing import Annotated, Any, Dict, List, Optional, Union
|
17
17
|
|
18
18
|
from compressed_tensors.config import CompressionFormat
|
19
19
|
from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
|
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.utils import (
|
|
26
26
|
module_type,
|
27
27
|
parse_out_kv_cache_args,
|
28
28
|
)
|
29
|
-
from pydantic import BaseModel, Field
|
29
|
+
from pydantic import BaseModel, ConfigDict, Field
|
30
30
|
from torch.nn import Module
|
31
31
|
|
32
32
|
|
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
|
|
142
142
|
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
|
143
143
|
global_compression_ratio: Optional[float] = None
|
144
144
|
ignore: Optional[List[str]] = Field(default_factory=list)
|
145
|
+
# `run_compressed` is a dummy, unused arg for backwards compatibility
|
146
|
+
# see: https://github.com/huggingface/transformers/pull/39324
|
147
|
+
run_compressed: Annotated[Any, Field(exclude=True)] = None
|
145
148
|
|
146
149
|
def model_post_init(self, __context):
|
147
150
|
"""
|
@@ -254,3 +257,6 @@ class QuantizationConfig(BaseModel):
|
|
254
257
|
return True
|
255
258
|
|
256
259
|
return False
|
260
|
+
|
261
|
+
# TODO set `extra="forbid"` when upstream transformers is compatible
|
262
|
+
model_config = ConfigDict(extra="ignore")
|
@@ -14,7 +14,7 @@
|
|
14
14
|
|
15
15
|
import warnings
|
16
16
|
from copy import deepcopy
|
17
|
-
from typing import
|
17
|
+
from typing import List, Optional
|
18
18
|
|
19
19
|
from compressed_tensors.quantization.quant_args import (
|
20
20
|
DynamicType,
|
@@ -22,7 +22,7 @@ from compressed_tensors.quantization.quant_args import (
|
|
22
22
|
QuantizationStrategy,
|
23
23
|
QuantizationType,
|
24
24
|
)
|
25
|
-
from pydantic import BaseModel, model_validator
|
25
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
26
26
|
|
27
27
|
|
28
28
|
__all__ = [
|
@@ -81,6 +81,8 @@ class QuantizationScheme(BaseModel):
|
|
81
81
|
|
82
82
|
return model
|
83
83
|
|
84
|
+
model_config = ConfigDict(extra="forbid")
|
85
|
+
|
84
86
|
|
85
87
|
"""
|
86
88
|
Pre-Set Quantization Scheme Args
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import torch
|
16
|
+
from compressed_tensors import TRANSFORM_CONFIG_NAME
|
16
17
|
from compressed_tensors.transform import TransformConfig, TransformFactory
|
17
18
|
|
18
19
|
|
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
|
|
30
31
|
for name, scheme in config.config_groups.items():
|
31
32
|
factory = TransformFactory.from_scheme(scheme, name=name)
|
32
33
|
factory.apply_to_model(model)
|
34
|
+
|
35
|
+
# attach config to model for compression/serialization
|
36
|
+
setattr(model, TRANSFORM_CONFIG_NAME, config)
|
@@ -14,11 +14,10 @@
|
|
14
14
|
|
15
15
|
from abc import ABC, abstractmethod
|
16
16
|
from collections import defaultdict
|
17
|
-
from typing import List, Optional,
|
17
|
+
from typing import List, Optional, Set, Tuple
|
18
18
|
|
19
19
|
import torch
|
20
20
|
import torch.nn.utils.parametrize as P
|
21
|
-
from compressed_tensors import InternalModule
|
22
21
|
from compressed_tensors.registry.registry import RegistryMixin, T
|
23
22
|
from compressed_tensors.transform import (
|
24
23
|
TransformArgs,
|
@@ -34,6 +33,7 @@ from compressed_tensors.utils import (
|
|
34
33
|
register_offload_module,
|
35
34
|
update_offload_parameter,
|
36
35
|
)
|
36
|
+
from compressed_tensors.utils.internal import InternalModule
|
37
37
|
from torch import Tensor
|
38
38
|
from torch.nn import Module, Parameter
|
39
39
|
|
@@ -12,8 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import
|
16
|
-
from typing import Optional, Union
|
15
|
+
from typing import Optional
|
17
16
|
|
18
17
|
import torch
|
19
18
|
from compressed_tensors.transform import TransformArgs, TransformScheme
|
@@ -26,7 +25,7 @@ from compressed_tensors.transform.utils.matrix import (
|
|
26
25
|
from compressed_tensors.utils import get_execution_device, get_offloaded_device
|
27
26
|
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
|
28
27
|
from torch import Tensor, device, dtype
|
29
|
-
from torch.nn import
|
28
|
+
from torch.nn import Module, Parameter
|
30
29
|
|
31
30
|
|
32
31
|
@TransformFactory.register("hadamard")
|
@@ -54,14 +53,14 @@ class HadamardFactory(TransformFactory):
|
|
54
53
|
"""
|
55
54
|
assert hasattr(module, "weight")
|
56
55
|
size = get_transform_size(module, args.location, self.scheme.head_dim)
|
57
|
-
dtype =
|
56
|
+
dtype = self.scheme.precision
|
58
57
|
device = get_offloaded_device(module)
|
59
58
|
exec_device = get_execution_device(module)
|
60
59
|
|
61
60
|
factory_kwargs = {"construct_device": exec_device}
|
62
61
|
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
|
63
62
|
perm = self.perms[weight] if self.scheme.randomize else None
|
64
|
-
return HadamardTransform(weight, perm, args, type(module))
|
63
|
+
return HadamardTransform(weight, perm, self.scheme, args, type(module))
|
65
64
|
|
66
65
|
def _create_weight(
|
67
66
|
self,
|
@@ -85,15 +84,18 @@ class HadamardTransform(TransformBase):
|
|
85
84
|
self,
|
86
85
|
weight: Parameter,
|
87
86
|
perm: Optional[Parameter],
|
87
|
+
scheme: TransformScheme,
|
88
88
|
args: TransformArgs,
|
89
89
|
module_type: type[torch.nn.Module],
|
90
90
|
):
|
91
91
|
super().__init__()
|
92
92
|
self.weight = weight
|
93
93
|
self.perm = perm
|
94
|
+
self.scheme = scheme
|
94
95
|
self.args = args
|
95
96
|
self.module_type = module_type
|
96
|
-
self._scale =
|
97
|
+
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
|
98
|
+
self._precision = scheme.precision if args.is_online() else torch.float64
|
97
99
|
|
98
100
|
def forward(self, value: Tensor) -> Tensor:
|
99
101
|
weight = self.weight
|
@@ -105,6 +107,11 @@ class HadamardTransform(TransformBase):
|
|
105
107
|
weight = weight.T
|
106
108
|
|
107
109
|
return (
|
108
|
-
apply_transform_weight(
|
110
|
+
apply_transform_weight(
|
111
|
+
weight.to(self._precision),
|
112
|
+
value.to(self._precision),
|
113
|
+
self.args.location,
|
114
|
+
self.module_type,
|
115
|
+
)
|
109
116
|
/ self._scale
|
110
|
-
)
|
117
|
+
).to(value.dtype)
|
@@ -24,7 +24,7 @@ from compressed_tensors.transform.utils.matrix import (
|
|
24
24
|
from compressed_tensors.utils import get_offloaded_device
|
25
25
|
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
|
26
26
|
from torch import Tensor, device, dtype
|
27
|
-
from torch.nn import
|
27
|
+
from torch.nn import Module, Parameter
|
28
28
|
|
29
29
|
|
30
30
|
@TransformFactory.register("random-matrix")
|
@@ -52,14 +52,14 @@ class RandomMatrixFactory(TransformFactory):
|
|
52
52
|
"""
|
53
53
|
assert hasattr(module, "weight")
|
54
54
|
size = get_transform_size(module, args.location, self.scheme.head_dim)
|
55
|
-
dtype =
|
55
|
+
dtype = self.scheme.precision
|
56
56
|
device = get_offloaded_device(module)
|
57
57
|
|
58
58
|
weight = self.weights[size, dtype, device]
|
59
59
|
if args.inverse:
|
60
60
|
weight = self.inverses[weight]
|
61
61
|
|
62
|
-
return RandomMatrixTransform(weight, args, type(module))
|
62
|
+
return RandomMatrixTransform(weight, self.scheme, args, type(module))
|
63
63
|
|
64
64
|
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
|
65
65
|
# TODO: verify that weight is invertible (has non-zero determinant)
|
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
|
|
78
78
|
def __init__(
|
79
79
|
self,
|
80
80
|
weight: Tensor,
|
81
|
+
scheme: TransformScheme,
|
81
82
|
args: TransformArgs,
|
82
83
|
module_type: type[torch.nn.Module],
|
83
84
|
):
|
84
85
|
super().__init__()
|
85
86
|
self.weight = weight # is an inverse if args.inverse
|
87
|
+
self.scheme = scheme
|
86
88
|
self.args = args
|
87
89
|
self.module_type = module_type
|
90
|
+
self._precision = scheme.precision if args.is_online() else torch.float64
|
88
91
|
|
89
92
|
def forward(self, value: Tensor) -> Parameter:
|
90
93
|
return apply_transform_weight(
|
91
|
-
self.weight
|
92
|
-
|
94
|
+
self.weight.to(self._precision),
|
95
|
+
value.to(self._precision),
|
96
|
+
self.args.location,
|
97
|
+
self.module_type,
|
98
|
+
).to(value.dtype)
|
93
99
|
|
94
100
|
def right_inverse(self, value: Tensor) -> Tensor:
|
95
101
|
inverse = high_precision_invert(self.weight)
|
96
102
|
return apply_transform_weight(
|
97
|
-
inverse
|
98
|
-
|
103
|
+
inverse.to(self._precision),
|
104
|
+
value.to(self._precision),
|
105
|
+
self.args.location,
|
106
|
+
self.module_type,
|
107
|
+
).to(value.dtype)
|
99
108
|
|
100
109
|
|
101
110
|
def high_precision_invert(weight: Tensor) -> Tensor:
|
102
|
-
return torch.linalg.inv(weight.to(torch.
|
111
|
+
return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
|
@@ -15,7 +15,7 @@
|
|
15
15
|
from enum import Enum
|
16
16
|
from typing import List
|
17
17
|
|
18
|
-
from pydantic import BaseModel, Field, field_validator
|
18
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
19
19
|
|
20
20
|
|
21
21
|
__all__ = ["TransformArgs", "TransformLocation"]
|
@@ -68,3 +68,11 @@ class TransformArgs(BaseModel, use_enum_values=True):
|
|
68
68
|
if isinstance(value, str):
|
69
69
|
return [value]
|
70
70
|
return value
|
71
|
+
|
72
|
+
def is_online(self) -> bool:
|
73
|
+
return self.location not in (
|
74
|
+
TransformLocation.WEIGHT_INPUT,
|
75
|
+
TransformLocation.WEIGHT_OUTPUT,
|
76
|
+
)
|
77
|
+
|
78
|
+
model_config = ConfigDict(extra="forbid")
|
@@ -15,7 +15,7 @@
|
|
15
15
|
from typing import Dict
|
16
16
|
|
17
17
|
from compressed_tensors.transform import TransformArgs, TransformScheme
|
18
|
-
from pydantic import BaseModel
|
18
|
+
from pydantic import BaseModel, ConfigDict
|
19
19
|
|
20
20
|
|
21
21
|
__all__ = ["TransformConfig"]
|
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
|
|
32
32
|
|
33
33
|
config_groups: Dict[str, TransformScheme]
|
34
34
|
|
35
|
-
|
36
|
-
# quip / quip sharp
|
37
|
-
QUIP = TransformConfig(
|
38
|
-
config_groups={
|
39
|
-
"v": TransformScheme(
|
40
|
-
type="hadamard",
|
41
|
-
apply=[
|
42
|
-
TransformArgs(
|
43
|
-
targets=["Linear"],
|
44
|
-
location="input", # non-mergable
|
45
|
-
),
|
46
|
-
TransformArgs(
|
47
|
-
targets=["Linear"],
|
48
|
-
location="weight_input",
|
49
|
-
inverse=True,
|
50
|
-
),
|
51
|
-
],
|
52
|
-
randomize=True,
|
53
|
-
),
|
54
|
-
"u": TransformScheme(
|
55
|
-
type="hadamard",
|
56
|
-
apply=[
|
57
|
-
TransformArgs(
|
58
|
-
targets=["Linear"],
|
59
|
-
location="weight_output",
|
60
|
-
),
|
61
|
-
TransformArgs(
|
62
|
-
targets=["Linear"], location="output", inverse=True # non-mergable
|
63
|
-
),
|
64
|
-
],
|
65
|
-
randomize=True,
|
66
|
-
),
|
67
|
-
}
|
68
|
-
)
|
69
|
-
|
70
|
-
|
71
|
-
PRESET_CONFIGS = {
|
72
|
-
"QUIP": QUIP,
|
73
|
-
}
|
35
|
+
model_config = ConfigDict(extra="forbid")
|
@@ -14,8 +14,10 @@
|
|
14
14
|
|
15
15
|
from typing import List, Optional
|
16
16
|
|
17
|
+
import torch
|
17
18
|
from compressed_tensors.transform import TransformArgs
|
18
|
-
from
|
19
|
+
from compressed_tensors.utils import TorchDtype
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field
|
19
21
|
|
20
22
|
|
21
23
|
__all__ = ["TransformScheme"]
|
@@ -34,6 +36,8 @@ class TransformScheme(BaseModel):
|
|
34
36
|
:param randomize: True if uniquely randomized transform weights should be used,
|
35
37
|
otherwise use identical transform weights where applicable
|
36
38
|
:param requires_grad: True if weights include gradients for training
|
39
|
+
:param precision: Precision at which this transform should be applied during online
|
40
|
+
rotations. Fused (offline) rotations are always performed in float64
|
37
41
|
"""
|
38
42
|
|
39
43
|
type: str
|
@@ -41,3 +45,6 @@ class TransformScheme(BaseModel):
|
|
41
45
|
randomize: bool = Field(default=False)
|
42
46
|
requires_grad: bool = Field(default=False)
|
43
47
|
head_dim: Optional[int] = Field(default=None)
|
48
|
+
precision: TorchDtype = Field(default=torch.float32)
|
49
|
+
|
50
|
+
model_config = ConfigDict(extra="forbid")
|
@@ -86,6 +86,7 @@ __all__ = [
|
|
86
86
|
"offloaded_dispatch",
|
87
87
|
"disable_offloading",
|
88
88
|
"remove_dispatch",
|
89
|
+
"cast_to_device",
|
89
90
|
]
|
90
91
|
|
91
92
|
|
@@ -169,6 +170,19 @@ def update_parameter_data(
|
|
169
170
|
""" Candidates for Upstreaming """
|
170
171
|
|
171
172
|
|
173
|
+
def cast_to_device(device_spec: Union[int, torch.device]) -> torch.device:
|
174
|
+
"""
|
175
|
+
Convert an integer device index or torch.device into a torch.device object.
|
176
|
+
|
177
|
+
:param device_spec: Device index (int) or torch.device object.
|
178
|
+
Negative integers map to CPU.
|
179
|
+
:return: torch.device corresponding to the given device specification.
|
180
|
+
"""
|
181
|
+
if isinstance(device_spec, int):
|
182
|
+
return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
|
183
|
+
return device_spec
|
184
|
+
|
185
|
+
|
172
186
|
def get_execution_device(module: torch.nn.Module) -> torch.device:
|
173
187
|
"""
|
174
188
|
Get the device which inputs should be moved to before module execution.
|
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
|
|
179
193
|
"""
|
180
194
|
for submodule in module.modules():
|
181
195
|
if has_offloaded_params(submodule):
|
182
|
-
return submodule._hf_hook.execution_device
|
196
|
+
return cast_to_device(submodule._hf_hook.execution_device)
|
183
197
|
|
184
198
|
param = next(submodule.parameters(recurse=False), None)
|
185
199
|
if param is not None:
|