compressed-tensors 0.10.3a20250805__tar.gz → 0.10.3a20250811__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.10.3a20250805/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250811}/PKG-INFO +1 -1
- compressed_tensors-0.10.3a20250811/pyproject.toml +16 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/base.py +8 -3
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +55 -33
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_args.py +3 -1
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_config.py +8 -2
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/quant_scheme.py +4 -2
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/apply.py +4 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/base.py +49 -4
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/hadamard.py +15 -8
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/matrix_multiply.py +18 -8
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_args.py +9 -1
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_config.py +2 -40
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/transform_scheme.py +8 -1
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors-0.10.3a20250811/src/compressed_tensors/utils/type.py +74 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/SOURCES.txt +3 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/conftest.py +4 -3
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/factory/test_correctness.py +15 -17
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/factory/test_memory.py +6 -6
- compressed_tensors-0.10.3a20250811/tests/test_transform/factory/test_serialization.py +54 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_config.py +14 -11
- compressed_tensors-0.10.3a20250811/tests/test_utils/test_type.py +79 -0
- compressed_tensors-0.10.3a20250805/pyproject.toml +0 -7
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/.gitkeep +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/.gitignore +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/LICENSE +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/Makefile +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/README.md +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/setup.cfg +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/setup.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/transform/utils/matrix.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/internal.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/match.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_forward.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_quantization/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_registry.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_args.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/test_transform_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_transform/utils/test_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_match.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_offload.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.10.3a20250805 → compressed_tensors-0.10.3a20250811}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250805
+Version: 0.10.3a20250811
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

pyproject.toml (new file)
@@ -0,0 +1,16 @@
+[build-system]
+requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 88
+target-version = ['py36']
+
+[tool.pytest.ini_options]
+markers = [
+    "unit: tests to ensure code correctness and regression test functionality",
+    "smoke: quick tests to check basic functionality",
+    "sanity: tests to ensure that new changes do not break existing functionality",
+    "regression: detailed tests to ensure major functions work correctly",
+    "integration: tests which integrate with a third party service such as HF",
+]

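The [tool.pytest.ini_options] table above registers the marker names used throughout the test suite. As a rough illustration (the test below is hypothetical, not part of the package), a marker is attached with a decorator and selected via pytest's -m flag:

import pytest


@pytest.mark.unit
def test_example():
    # hypothetical placeholder test; the real tests live under tests/
    assert 1 + 1 == 2

Running `pytest -m unit` then collects only unit-marked tests, and `pytest -m "not integration"` skips the tests that depend on third-party services.
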
src/compressed_tensors/base.py
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-
-
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxillary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"

src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -29,6 +29,7 @@ from compressed_tensors.base import (
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@ from compressed_tensors.quantization import (
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ class ModelCompressor:
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # TODO: transform config is not support by CompressedTensorsConfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,20 +182,27 @@ class ModelCompressor:
         algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
+        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
 
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
-            sparsity_config=sparsity_config,
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
@@ -254,13 +266,17 @@ class ModelCompressor:
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +656,49 @@ class ModelCompressor:
 
         :param save_directory: path to a folder containing a HF model config
         """
-
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained`` or `from_compression_config``
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
+        # write to config.json file, regardless of whether it exists already
+        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if
-
-
-
-
-            return
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
+        else:
+            config_data = {}
 
-
-
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
-        #
-
-
-
-
-
-
-
-        else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
 
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
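After this rewrite, update_config unconditionally rebuilds the quantization_config block, nesting the sparsity and transform dumps and splatting the quantization fields into the same dict. A sketch of the resulting config.json shape (values illustrative; key strings come from the constants in base.py):

config_data = {
    "quantization_config": {
        "version": "0.10.3a20250811",  # COMPRESSION_VERSION_NAME
        "quant_method": "...",         # DEFAULT_QUANTIZATION_METHOD
        "sparsity_config": {},         # {} when self.sparsity_config is None
        "transform_config": {},        # {} when self.transform_config is None
        # plus the quantization config's model_dump() fields, e.g. config_groups
    }
}
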
src/compressed_tensors/quantization/quant_args.py
@@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -358,6 +358,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     def get_observer(self) -> str:
         return self.observer
 
+    model_config = ConfigDict(extra="forbid")
+
 
 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
src/compressed_tensors/quantization/quant_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.utils import (
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None
 
     def model_post_init(self, __context):
         """
@@ -254,3 +257,6 @@ class QuantizationConfig(BaseModel):
                     return True
 
         return False
+
+    # TODO set `extra="forbid"` when upstream transformers is compatible
+    model_config = ConfigDict(extra="ignore")
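A minimal sketch of the backwards-compatibility behavior above: run_compressed is accepted on input, so configs written by recent transformers versions still parse, while Field(exclude=True) keeps it out of serialized output (this assumes an otherwise-empty config validates):

from compressed_tensors.quantization import QuantizationConfig

config = QuantizationConfig(config_groups={}, run_compressed=True)
assert "run_compressed" not in config.model_dump()
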
src/compressed_tensors/quantization/quant_scheme.py
@@ -14,7 +14,7 @@
 
 import warnings
 from copy import deepcopy
-from typing import
+from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
@@ -22,7 +22,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [
@@ -81,6 +81,8 @@ class QuantizationScheme(BaseModel):
 
         return model
 
+    model_config = ConfigDict(extra="forbid")
+
 
 """
 Pre-Set Quantization Scheme Args
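With extra="forbid" set, unknown fields on QuantizationScheme (and likewise on QuantizationArgs above) now raise a validation error instead of being silently dropped; a minimal sketch:

from pydantic import ValidationError

from compressed_tensors.quantization import QuantizationScheme

try:
    QuantizationScheme(targets=["Linear"], not_a_field=True)
except ValidationError:
    pass  # misspelled or unknown keys are now rejected at construction
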
src/compressed_tensors/transform/apply.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
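A minimal sketch of the new attachment behavior, assuming the config classes and apply_transform_config are exported from compressed_tensors.transform (the module and target names here are illustrative):

import torch

from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# toy model; 64 is a power of two, which Hadamard transforms expect
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
config = TransformConfig(
    config_groups={
        "u": TransformScheme(
            type="hadamard",
            apply=[TransformArgs(targets=["Linear"], location="weight_output")],
        )
    }
)
apply_transform_config(model, config)

# the config now rides along with the model, where
# ModelCompressor.from_pretrained_model picks it up via getattr
assert getattr(model, TRANSFORM_CONFIG_NAME) is config
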
src/compressed_tensors/transform/factory/base.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from
+from collections import defaultdict
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -33,6 +33,7 @@ from compressed_tensors.utils import (
     register_offload_module,
     update_offload_parameter,
 )
+from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
 
@@ -49,10 +50,13 @@ class TransformFactory(RegistryMixin, ABC):
     :param seed: random seed used to transform weight randomization
     """
 
+    transforms: List["TransformBase"]
+
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -90,6 +94,8 @@ class TransformFactory(RegistryMixin, ABC):
         for _, module in match_named_modules(model, arg.targets, arg.ignore):
             self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -97,9 +103,17 @@ class TransformFactory(RegistryMixin, ABC):
         :param module: target module to apply transforms to
         :param args: defines how the transform will be applied to the target module
         """
+        if has_offloaded_params(module):
+            if module._hf_hook.place_submodules:
+                raise NotImplementedError(
+                    "Applying transforms to offloaded submodules with "
+                    "`place_submodules=True` is not supported"
+                )
+
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -128,8 +142,9 @@ class TransformFactory(RegistryMixin, ABC):
                 raise ValueError("Offloaded training is not supported")
             P.register_parametrization(module, "weight", transform)
 
-
-
+            else:
+                # transform is no longer needed (unfusing is not supported)
+                delete_offload_module(module, transform_name)
 
         # register output transformation hook
         elif args.location == TransformLocation.OUTPUT:
@@ -143,6 +158,31 @@ class TransformFactory(RegistryMixin, ABC):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+
 
 class TransformBase(InternalModule, ABC):
     """
@@ -151,6 +191,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
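The tied-weight bookkeeping above hinges on Tensor.data_ptr(): parameters that share storage report the same pointer, so grouping by pointer finds every alias. A standalone sketch of that detection idea (simplified, not the library's code):

from collections import defaultdict

import torch

shared = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
mods = [torch.nn.ParameterDict({"weight": shared}) for _ in range(2)]

ptr_to_keys = defaultdict(list)
for mod in mods:
    for name, param in mod.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((mod, name))

# both "weight" entries resolve to the same storage, so they group together
assert any(len(keys) > 1 for keys in ptr_to_keys.values())
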
src/compressed_tensors/transform/factory/hadamard.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@ from compressed_tensors.transform.utils.matrix import (
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ class HadamardFactory(TransformFactory):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype =
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,15 +84,18 @@ class HadamardTransform(TransformBase):
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale =
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +107,11 @@ class HadamardTransform(TransformBase):
         weight = weight.T
 
         return (
-            apply_transform_weight(
+            apply_transform_weight(
+                weight.to(self._precision),
+                value.to(self._precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
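Both transform classes now follow a cast-compute-cast pattern: operands are promoted to scheme.precision for online (runtime) locations, or to float64 when the rotation will be fused into weights offline, and the result is cast back to the caller's dtype. A simplified sketch of that pattern, with a plain matmul standing in for apply_transform_weight:

import torch


def online_rotate(weight: torch.Tensor, value: torch.Tensor,
                  precision: torch.dtype) -> torch.Tensor:
    # compute in the configured precision, normalize, return in caller's dtype
    out = value.to(precision) @ weight.to(precision)
    scale = torch.tensor(weight.size(0), dtype=precision).sqrt()
    return (out / scale).to(value.dtype)


y = online_rotate(torch.eye(8), torch.randn(2, 8, dtype=torch.bfloat16),
                  torch.float32)
assert y.dtype == torch.bfloat16
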
src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -24,7 +24,7 @@ from compressed_tensors.transform.utils.matrix import (
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ class RandomMatrixFactory(TransformFactory):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype =
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args, type(module))
+        return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -70,6 +70,7 @@ class RandomMatrixFactory(TransformFactory):
 
     def _create_inverse(self, weight: Parameter) -> Parameter:
         data = high_precision_invert(weight.data)
+        data = data.contiguous()  # ensure proper serialization
         return Parameter(data, requires_grad=False)
 
 
@@ -77,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
     def __init__(
         self,
         weight: Tensor,
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Parameter:
         return apply_transform_weight(
-            self.weight
-
+            self.weight.to(self._precision),
+            value.to(self._precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse
-
+            inverse.to(self._precision),
+            value.to(self._precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
-    return torch.linalg.inv(weight.to(torch.
+    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
src/compressed_tensors/transform/transform_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import List
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 __all__ = ["TransformArgs", "TransformLocation"]
@@ -68,3 +68,11 @@ class TransformArgs(BaseModel, use_enum_values=True):
         if isinstance(value, str):
             return [value]
         return value
+
+    def is_online(self) -> bool:
+        return self.location not in (
+            TransformLocation.WEIGHT_INPUT,
+            TransformLocation.WEIGHT_OUTPUT,
+        )
+
+    model_config = ConfigDict(extra="forbid")
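A quick sketch of the new helper: is_online() is true for activation-side locations, which run at inference time, and false for the weight locations whose rotations get fused offline. The location strings follow the values used elsewhere in this diff:

from compressed_tensors.transform import TransformArgs

runtime = TransformArgs(targets=["Linear"], location="input")
fused = TransformArgs(targets=["Linear"], location="weight_input")

assert runtime.is_online()
assert not fused.is_online()
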
src/compressed_tensors/transform/transform_config.py
@@ -15,7 +15,7 @@
 from typing import Dict
 
 from compressed_tensors.transform import TransformArgs, TransformScheme
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
 
     config_groups: Dict[str, TransformScheme]
 
-
-# quip / quip sharp
-QUIP = TransformConfig(
-    config_groups={
-        "v": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="input",  # non-mergable
-                ),
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_input",
-                    inverse=True,
-                ),
-            ],
-            randomize=True,
-        ),
-        "u": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_output",
-                ),
-                TransformArgs(
-                    targets=["Linear"], location="output", inverse=True  # non-mergable
-                ),
-            ],
-            randomize=True,
-        ),
-    }
-)
-
-
-PRESET_CONFIGS = {
-    "QUIP": QUIP,
-}
+    model_config = ConfigDict(extra="forbid")
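With the QUIP preset and PRESET_CONFIGS gone, callers now assemble configs directly. A minimal sketch modeled on the deleted preset's "v" group:

from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
)

config = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets=["Linear"], location="input"),
                TransformArgs(
                    targets=["Linear"], location="weight_input", inverse=True
                ),
            ],
            randomize=True,
        )
    }
)
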
src/compressed_tensors/transform/transform_scheme.py
@@ -14,8 +14,10 @@
 
 from typing import List, Optional
 
+import torch
 from compressed_tensors.transform import TransformArgs
-from
+from compressed_tensors.utils import TorchDtype
+from pydantic import BaseModel, ConfigDict, Field
 
 
 __all__ = ["TransformScheme"]
@@ -34,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param precision: Precision at which this transform should be applied during online
+        rotations. Fused (offline) rotations are always performed in float64
     """
 
     type: str
@@ -41,3 +45,6 @@ class TransformScheme(BaseModel):
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
+    precision: TorchDtype = Field(default=torch.float32)
+
+    model_config = ConfigDict(extra="forbid")
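A short sketch of the new field, assuming the TorchDtype annotation (added as src/compressed_tensors/utils/type.py in this release) accepts torch.dtype values directly:

import torch

from compressed_tensors.transform import TransformScheme

scheme = TransformScheme(type="hadamard", precision=torch.bfloat16)
assert scheme.precision is torch.bfloat16  # used for online rotations

# the default stays float32, per the Field(default=torch.float32) above
assert TransformScheme(type="hadamard").precision is torch.float32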