compressed-tensors 0.10.3a20250716__tar.gz → 0.10.3a20250724__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.10.3a20250716/src/compressed_tensors.egg-info → compressed_tensors-0.10.3a20250724}/PKG-INFO +1 -1
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +12 -6
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/forward.py +68 -5
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/initialize.py +35 -2
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_args.py +31 -8
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_scheme.py +41 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/utils/helpers.py +11 -2
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/base.py +3 -4
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/__init__.py +1 -0
- compressed_tensors-0.10.3a20250724/src/compressed_tensors/utils/match.py +191 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_examples/test_bitmask_compression_ipynb.py +3 -1
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_forward.py +50 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_initialize.py +13 -3
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_args.py +2 -1
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_utils/test_helpers.py +28 -1
- compressed_tensors-0.10.3a20250724/tests/test_utils/test_match.py +426 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/.gitkeep +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/.gitignore +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/LICENSE +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/Makefile +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/README.md +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/pyproject.toml +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/setup.cfg +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/setup.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/quant_config.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/apply.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/matrix_multiply.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/factory/random_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_args.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_config.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/transform_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/transform/utils/matrix.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/internal.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_registry.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/conftest.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/factory/test_correctness.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/factory/test_memory.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_args.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_config.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/test_transform_scheme.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_transform/utils/test_hadamard.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_offload.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.10.3a20250716 → compressed_tensors-0.10.3a20250724}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250716
+Version: 0.10.3a20250724
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -400,7 +400,10 @@ class ModelCompressor:

                 # in the future, support compression on same device
                 with align_module_device(module, execution_device=exec_device):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # quantization first
                 if prefix in module_to_scheme:
@@ -421,7 +424,7 @@ class ModelCompressor:

                 # remove any existing parameters
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with compressed parameters
@@ -458,7 +461,10 @@ class ModelCompressor:
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 # in the future, support decompression on same device
                 with align_module_device(module, execution_device="cpu"):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # sparsity first
                 if prefix in sparse_compression_targets:
@@ -483,7 +489,7 @@ class ModelCompressor:
                 # remove any existing parameters
                 exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with decompressed parameters
@@ -754,8 +760,8 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
         fix_fsdp_module_name(name): module.quantization_scheme
         for name, module in model.named_modules()
         if (
-            hasattr(module, "quantization_scheme")
-            module.quantization_scheme.weights is not None
+            hasattr(module, "quantization_scheme")
+            and module.quantization_scheme.weights is not None
         )
     }

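Note: the compress/decompress paths above now build the per-module state dict from `named_parameters(recurse=False)` rather than the module's full state dict, so only the module's own parameters (not its children's) are collected and later deleted. A minimal sketch of the difference, using a hypothetical toy module (not code from this package):

```python
import torch

# Hypothetical toy module with one child, to illustrate recurse=False
parent = torch.nn.Sequential(torch.nn.Linear(4, 4))
parent.register_parameter("own_param", torch.nn.Parameter(torch.zeros(2)))

prefix = "model.layers.0"
state_dict = {
    f"{prefix}.{name}": param
    for name, param in parent.named_parameters(recurse=False)
}
print(list(state_dict))                      # ['model.layers.0.own_param'] -- child params excluded
print(len(list(parent.named_parameters())))  # 3: own_param plus the child Linear's weight and bias
```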
src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,11 +111,18 @@ def dequantize(
         elif scale.ndim == 2:
             if scale.shape[1] == 1:
                 args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-            else:
+            # Scale height matches input or is 1 -> group quantization across columns
+            #
+            # Example 1: scale.shape[0] == 1
+            # x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+            #
+            # Example 2: scale.shape[0] == x_q.shape[0]
+            # x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
                 group_size = int(x_q.shape[1] / scale.shape[1])
-                args = QuantizationArgs(
-                    strategy=QuantizationStrategy.GROUP, group_size=group_size
-                )
+                args = QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=group_size)
+            else:
+                args = QuantizationArgs(strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape)
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
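Note: with this change, `dequantize` infers a strategy from the scale's shape when no `QuantizationArgs` are passed. A rough summary of the mapping, with illustrative shapes (the specific sizes below are assumptions, not values from the diff):

```python
import torch

# Illustrative shapes only; the inference lives in dequantize() when args is None
x_q = torch.zeros(4, 8)

channel_scale = torch.ones(4, 1)  # shape (rows, 1)                  -> CHANNEL
group_scale_a = torch.ones(1, 4)  # one scale row, 4 groups          -> GROUP, group_size = 8 // 4 = 2
group_scale_b = torch.ones(4, 4)  # rows match x_q, 4 groups per row -> GROUP, group_size = 2
block_scale = torch.ones(2, 4)    # any other 2D shape               -> BLOCK, block_structure = scale.shape
```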
@@ -189,7 +196,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
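Note: the reshape/transpose above lines each 2D block up with a single scale entry. A small sketch of the same bookkeeping with assumed shapes (a 256x512 tensor, 128x128 blocks); the clamp is only a stand-in for the real `_quantize`/`_dequantize` calls:

```python
import torch

x = torch.randn(256, 512)                     # assumed example weight
block_height, block_width = 128, 128
num_rows_blocks = 256 // block_height         # 2
num_cols_blocks = 512 // block_width          # 4

# (num_rows_blocks, num_cols_blocks, block_height, block_width) after the transpose
x_blocks = x.reshape(num_rows_blocks, block_height, num_cols_blocks, block_width).transpose(1, 2)
assert x_blocks.shape == (2, 4, 128, 128)

# one scale per block; two trailing singleton dims broadcast over each block
scale = x_blocks.abs().amax(dim=(-1, -2)) / 448.0  # 448 ~ FP8 E4M3 max, an assumption here
sb = scale.unsqueeze(-1).unsqueeze(-1)             # (2, 4, 1, 1)

x_dq = (x_blocks / sb).clamp(-448.0, 448.0) * sb   # stand-in for _quantize + _dequantize
out = x_dq.transpose(1, 2).reshape(256, 512)       # restore the original layout
```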
src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -15,6 +15,7 @@

 import logging
 import math
+import warnings
 from enum import Enum
 from typing import List, Optional

@@ -172,14 +173,41 @@ def _initialize_scale_zero_point(

     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
+            # (output_channels, 1) - only for weights
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP,
             QuantizationStrategy.GROUP,
         ):
+            # GROUP/TENSOR_GROUP for both weights and activations
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
+        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # For block quantization, scale shape should match number of blocks - only for weights
+            if quantization_args.block_structure is None:
+                raise ValueError("Block quantization requires block_structure to be specified")
+            block_height, block_width = quantization_args.block_structure
+            rows, cols = weight_shape[-2], weight_shape[-1]
+            num_rows_blocks = math.ceil(rows / block_height)
+            num_cols_blocks = math.ceil(cols / block_width)
+
+            # Warn if dimensions don't divide evenly
+            if rows % block_height != 0 or cols % block_width != 0:
+                warnings.warn(
+                    f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+                    f"by block structure {quantization_args.block_structure}. "
+                    f"Some blocks will be incomplete which may affect quantization quality.",
+                    UserWarning
+                )
+
+            expected_shape = (num_rows_blocks, num_cols_blocks)
+    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+        warnings.warn(
+            f"BLOCK quantization not supported for {base_name} activations. "
+            f"Falling back to tensor-level quantization.",
+            UserWarning
+        )
+        expected_shape = 1

     # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
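Note: for BLOCK weights the scale and zero point are therefore initialized with one entry per block. A quick worked example with assumed dimensions:

```python
import math

weight_shape = (1024, 4096)           # assumed example weight
block_height, block_width = 128, 128  # block_structure=[128, 128]

num_rows_blocks = math.ceil(weight_shape[-2] / block_height)  # 8
num_cols_blocks = math.ceil(weight_shape[-1] / block_width)   # 32
expected_shape = (num_rows_blocks, num_cols_blocks)           # scale and zero_point shape: (8, 32)
```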
@@ -189,7 +217,12 @@ def _initialize_scale_zero_point(
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        if scale_dtype not in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+            torch.float64,
+        ]:
             scale_dtype = torch.float16
         zp_dtype = quantization_args.pytorch_dtype()

src/compressed_tensors/quantization/quant_args.py
@@ -14,7 +14,7 @@

 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy
-        must be of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ class QuantizationArgs(BaseModel, use_enum_values=True):

         return value

+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                return [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
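Note: the new validator accepts the list form and normalizes the legacy string form. A hedged usage sketch (assuming the public `QuantizationArgs` import path and that no other validators reject these arguments):

```python
from compressed_tensors.quantization import QuantizationArgs

# Preferred list form
args = QuantizationArgs(num_bits=8, strategy="block", block_structure=[128, 128])

# Legacy "RxC" string form is converted by the validator
legacy = QuantizationArgs(num_bits=8, strategy="block", block_structure="128x128")
assert legacy.block_structure == [128, 128]

# Anything that is not two ints is rejected, e.g.:
# QuantizationArgs(num_bits=8, strategy="block", block_structure=[128])  # would raise ValueError
```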
@@ -277,14 +299,15 @@ class QuantizationArgs(BaseModel, use_enum_values=True):

         # infer observer w.r.t. dynamic
         if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )

             if (
src/compressed_tensors/quantization/quant_scheme.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional

@@ -52,6 +53,7 @@ class QuantizationScheme(BaseModel):
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
+        weights = model.weights

         if inputs is not None:
             if inputs.actorder is not None:
@@ -61,6 +63,21 @@ class QuantizationScheme(BaseModel):
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")

+        if (
+            inputs and weights
+            and weights.strategy == QuantizationStrategy.GROUP
+            and inputs.strategy == QuantizationStrategy.GROUP
+            and weights.group_size != inputs.group_size
+        ):
+            warnings.warn(
+                "Using GROUP strategy for both weights and input_activations "
+                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
+                "may complicate fused kernel implementations. Consider using "
+                "TENSOR_GROUP strategy for both or matching group sizes.",
+                UserWarning,
+                stacklevel=2
+            )
+
         return model

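Note: a minimal sketch of a scheme that would now trigger the new warning (assuming the public `QuantizationScheme`/`QuantizationArgs` import path and that the remaining validators accept these arguments; the target list is illustrative):

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# GROUP weights and GROUP input activations with mismatched group sizes -> UserWarning
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
    input_activations=QuantizationArgs(num_bits=8, strategy="group", group_size=64),
)
```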
@@ -243,6 +260,29 @@ FP8_DYNAMIC = dict(
     ),
 )

+# Block‐wise FP8 (deepseekv3-style quantization):
+# static 128x128 per‐block weights and
+# dynamic per‐token‐group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +297,7 @@ PRESET_SCHEMES = {
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }
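Note: once registered, the preset can be expanded into a scheme like any other entry in `PRESET_SCHEMES`. A short sketch assuming the package's existing `preset_name_to_scheme` helper and an illustrative target list:

```python
from compressed_tensors.quantization import preset_name_to_scheme

scheme = preset_name_to_scheme("FP8_BLOCK", ["Linear"])
print(scheme.weights.strategy, scheme.weights.block_structure)                  # block [128, 128]
print(scheme.input_activations.strategy, scheme.input_activations.group_size)  # group 128
```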
src/compressed_tensors/quantization/utils/helpers.py
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)

@@ -187,9 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{
+            f"{supported_strategies}",
         )

     if not reduce_dims:
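Note: allowing GROUP here is what lets the FP8_BLOCK preset compute per-token-group activation scales at runtime. A generic illustration of what a per-group dynamic scale reduction looks like (the shapes and the FP8 max of 448 are assumptions, not code from the diff):

```python
import torch

value = torch.randn(2, 256)  # assumed (batch, hidden) activation
group_size = 128

grouped = value.reshape(value.shape[0], -1, group_size)    # (2, 2, 128)
scales = grouped.abs().amax(dim=-1, keepdim=True) / 448.0  # one scale per token group
print(scales.shape)                                        # torch.Size([2, 2, 1])
```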
src/compressed_tensors/transform/factory/base.py
@@ -18,7 +18,6 @@ from typing import Optional

 import torch
 import torch.nn.utils.parametrize as P
 from compressed_tensors import InternalModule
-from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -29,6 +28,7 @@ from compressed_tensors.utils import (
     align_module_device,
     delete_offload_module,
     has_offloaded_params,
+    match_named_modules,
     patch_attr,
     register_offload_module,
     update_offload_parameter,
@@ -87,9 +87,8 @@ class TransformFactory(RegistryMixin, ABC):
         :param model: module to apply transforms to
         """
         for arg in self.scheme.apply:
-            for name, module in list(model.named_modules()):
-                if is_target(name, module, arg.targets, arg.ignore):
-                    self._apply_to_module(module, arg)
+            for _, module in match_named_modules(model, arg.targets, arg.ignore):
+                self._apply_to_module(module, arg)

     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
src/compressed_tensors/utils/match.py (new file)
@@ -0,0 +1,191 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from collections.abc import Generator
+from typing import Iterable, Tuple
+
+import torch
+
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+__all__ = [
+    "match_named_modules",
+    "match_named_parameters",
+    "match_modules_set",
+    "is_match",
+    "match_name",
+    "match_class",
+]
+
+
+def match_named_modules(
+    model: torch.nn.Module,
+    targets: Iterable[str],
+    ignore: Iterable[str] = tuple(),
+    warn_on_fail: bool = False,
+) -> Generator[Tuple[str, torch.nn.Module]]:
+    """
+    Yields names and modules which match `targets` but do not match `ignore`.
+    Values are returned in order of `model.named_modules()`
+
+    :param model: model containing submodules to match against
+    :param targets: target strings, potentially containing "re:" prefixes
+    :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :param warn_on_fail: if True, warns if any targets do not match any modules in model
+    :return: generator of module names and modules
+    """
+    unmatched_targets = set(targets)
+    for name, module in model.named_modules():
+        for target in targets:
+            if is_match(name, module, target):
+                unmatched_targets -= {target}
+
+                if not any(is_match(name, module, ign) for ign in ignore):
+                    yield name, module
+
+    if warn_on_fail:
+        for target in unmatched_targets:
+            _LOGGER.warning(
+                f"Could not match `{target}` in instance of {model.__class__.__name__}"
+            )
+
+
+def match_named_parameters(
+    model: torch.nn.Module,
+    targets: Iterable[str],
+    ignore: Iterable[str] = tuple(),
+    warn_on_fail: bool = False,
+) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
+    """
+    Yields parameters which match `targets` but do not match `ignore`.
+    Values are returned in order of `model.named_modules()`
+
+    :param model: model containing params to match against
+    :param targets: target strings, potentially containing "re:" prefixes
+    :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :param warn_on_fail: if True, warns if any targets do not match any params in model
+    :return: generator of fully-qualified param names, parent modules, and params
+    """
+    unmatched_targets = set(targets)
+    for module_name, module in model.named_modules():
+        for param_name, param in module.named_parameters(recurse=False):
+            param_fqn = f"{module_name}.{param_name}"
+            for target in targets:
+                if match_name(param_fqn, target):
+                    unmatched_targets -= {target}
+
+                    if not any(match_name(param_fqn, ign) for ign in ignore):
+                        yield param_fqn, module, param
+
+    if warn_on_fail:
+        for target in unmatched_targets:
+            _LOGGER.warning(
+                f"Could not match `{target}` in instance of {model.__class__.__name__}"
+            )
+
+
+def match_modules_set(
+    model: torch.nn.Module,
+    targets: Iterable[str],
+    ignore: Iterable[str] = tuple(),
+) -> Generator[Iterable[torch.nn.Module]]:
+    """
+    Yields modules grouped with the same order and size as `targets`.
+    Values are returned in order of `model.named_modules()`
+
+    For example, the following targets would yield module belonging to the following layers:
+    ```python3
+    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
+        (
+            `model.layers.0.self_attn.q_proj`,
+            `model.layers.0.self_attn.k_proj`,
+            `model.layers.0.self_attn.v_proj`,
+        ),
+        (
+            `model.layers.1.self_attn.q_proj`,
+            `model.layers.1.self_attn.k_proj`,
+            `model.layers.1.self_attn.v_proj`,
+        ),
+        ...
+        (
+            `model.layers.32.self_attn.q_proj`,
+            `model.layers.32.self_attn.k_proj`,
+            `model.layers.32.self_attn.v_proj`,
+        ),
+    )
+    ```
+
+    This can be used to match layers to their corresponding downstream counterparts.
+    For example, matching layer norms to their subsequent linear layers
+    ```python3
+    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
+        fuse_norm_linears(norm, [q, k, v])
+
+    :param model: model containing modules to match against
+    :param targets: target strings, potentially containing "re:" prefixes
+    :param ignore: targets to ignore, potentially containing "re:" prefixes
+    """
+    matches = dict.fromkeys(targets, None)
+    for name, module in model.named_modules():
+        # match until we get a full set
+        for target in targets:
+            if is_match(name, module, target) and not any(
+                is_match(name, module, ign) for ign in ignore
+            ):
+                if matches[target] is not None:
+                    raise ValueError(f"Matched a {target} twice before completing set")
+                matches[target] = module
+
+        # once we have a full set, yield and reset
+        if targets and all((matches[target] is not None for target in targets)):
+            yield [matches[target] for target in targets]  # ensure correct ordering
+            matches = dict.fromkeys(targets, None)
+
+    # check that none are left over
+    unmatched_keys = [match for match, value in matches.items() if value is not None]
+    if len(unmatched_keys):
+        raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
+
+
+def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+    """
+    Returns true if either module name or module parent classes match against target
+    """
+    return match_name(name, target) or match_class(module, target)
+
+
+def match_name(name: str, target: str) -> bool:
+    """
+    Returns true if target string begins with "re:" and
+    regex matches or if target string exactly matches name
+    """
+    if target.startswith("re:"):
+        return re.match(target.removeprefix("re:"), name) is not None
+    else:
+        return target == name
+
+
+def match_class(module: torch.nn.Module, target: str) -> bool:
+    """
+    Returns true if any torch parent class names match the target string exactly
+    """
+    # will never match against a regex pattern since `:` is not allowed in class names
+    return any(
+        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        for cls in module.__class__.__mro__
+    )
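Note: a short, hypothetical usage sketch of the new matching utilities (the toy model is illustrative; the import path follows the `from compressed_tensors.utils import match_named_modules` usage seen in transform/factory/base.py above):

```python
import torch
from compressed_tensors.utils import match_modules_set, match_named_modules

# Toy model standing in for one attention block
model = torch.nn.ModuleDict(
    {
        "q_proj": torch.nn.Linear(8, 8),
        "k_proj": torch.nn.Linear(8, 8),
        "v_proj": torch.nn.Linear(8, 8),
        "norm": torch.nn.LayerNorm(8),
    }
)

# Exact names, "re:" regex targets, and class-name targets are all supported
for name, module in match_named_modules(model, ["re:.*_proj", "LayerNorm"], ignore=["v_proj"]):
    print(name, type(module).__name__)  # q_proj, k_proj, norm

# Group matches into ordered sets, e.g. pairing a norm with its projections
for norm, q, k, v in match_modules_set(model, ["norm", "q_proj", "k_proj", "v_proj"]):
    print(type(norm).__name__, type(q).__name__)
```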
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250716
+Version: 0.10.3a20250724
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors.egg-info/SOURCES.txt
@@ -88,6 +88,7 @@ src/compressed_tensors/transform/utils/matrix.py
 src/compressed_tensors/utils/__init__.py
 src/compressed_tensors/utils/helpers.py
 src/compressed_tensors/utils/internal.py
+src/compressed_tensors/utils/match.py
 src/compressed_tensors/utils/offload.py
 src/compressed_tensors/utils/permutations_24.py
 src/compressed_tensors/utils/permute.py
@@ -141,6 +142,7 @@ tests/test_transform/factory/test_memory.py
 tests/test_transform/utils/test_hadamard.py
 tests/test_utils/__init__.py
 tests/test_utils/test_helpers.py
+tests/test_utils/test_match.py
 tests/test_utils/test_offload.py
 tests/test_utils/test_safetensors_load.py
 utils/copyright.py
tests/test_examples/test_bitmask_compression_ipynb.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import nbformat
 import pytest
+
+
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor
