compressed-tensors 0.10.1a20250605__tar.gz → 0.10.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/report.yml +1 -1
- {compressed_tensors-0.10.1a20250605/src/compressed_tensors.egg-info → compressed_tensors-0.10.2}/PKG-INFO +1 -1
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/setup.py +1 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +7 -1
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/dense.py +19 -1
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -3
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/transform/__init__.py +5 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/base.py +164 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/hadamard.py +79 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/matrix_multiply.py +90 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/random_hadamard.py +34 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/utils/hadamard.py +160 -0
- compressed_tensors-0.10.2/src/compressed_tensors/transform/utils/hadamards.safetensors +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/offload.py +158 -71
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/version.py +2 -2
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors.egg-info/SOURCES.txt +8 -0
- compressed_tensors-0.10.2/tests/test_transform/factory/test_correctness.py +116 -0
- compressed_tensors-0.10.2/tests/test_transform/factory/test_memory.py +112 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_transform/utils/test_hadamard.py +38 -32
- compressed_tensors-0.10.2/tests/test_utils/__init__.py +13 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_utils/test_offload.py +126 -12
- compressed_tensors-0.10.1a20250605/src/compressed_tensors/transform/utils/hadamard.py +0 -161
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/.gitkeep +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.gitignore +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/LICENSE +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/Makefile +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/README.md +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/pyproject.toml +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/setup.cfg +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/quant_args.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/quant_config.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.10.1a20250605/src/compressed_tensors/transform/utils → compressed_tensors-0.10.2/src/compressed_tensors/transform/factory}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/transform/transform_args.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/transform/transform_config.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/transform/transform_scheme.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests → compressed_tensors-0.10.2/src/compressed_tensors/transform/utils}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/transform/utils/utils.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_compressors → compressed_tensors-0.10.2/tests}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/conftest.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_compressors/model_compressors → compressed_tensors-0.10.2/tests/test_compressors}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_compressors/quantized_compressors → compressed_tensors-0.10.2/tests/test_compressors/model_compressors}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_compressors/sparse_compressors → compressed_tensors-0.10.2/tests/test_compressors/quantized_compressors}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_compressors/sparse_quantized_compressors → compressed_tensors-0.10.2/tests/test_compressors/sparse_compressors}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_configs → compressed_tensors-0.10.2/tests/test_compressors/sparse_quantized_compressors}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_linear → compressed_tensors-0.10.2/tests/test_configs}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_quantization → compressed_tensors-0.10.2/tests/test_linear}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_quantization/lifecycle → compressed_tensors-0.10.2/tests/test_quantization}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_quantization/test_configs → compressed_tensors-0.10.2/tests/test_quantization/lifecycle}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_forward.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.10.1a20250605/tests/test_utils → compressed_tensors-0.10.2/tests/test_quantization/test_configs}/__init__.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_quantization/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_registry.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_transform/test_transform_args.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_transform/test_transform_config.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_transform/test_transform_scheme.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/utils/copyright.py +0 -0
{compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/.github/workflows/report.yml
RENAMED
@@ -120,7 +120,7 @@ jobs:
       shell: bash

     - name: report to reportportal
-      uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.
+      uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.22.0
       with:
         droute_username: ${{ secrets.DROUTE_USERNAME }}
         droute_password: ${{ secrets.DROUTE_PASSWORD }}

{compressed_tensors-0.10.1a20250605/src/compressed_tensors.egg-info → compressed_tensors-0.10.2}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.1a20250605
+Version: 0.10.2
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.

{compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/model_compressors/model_compressor.py
RENAMED
@@ -68,6 +68,10 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME


+if TYPE_CHECKING:
+    from compressed_tensors.compressors import BaseQuantizationCompressor
+
+
 __all__ = ["ModelCompressor", "map_module_to_scheme"]

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -257,7 +261,9 @@ class ModelCompressor:
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.sparsity_compressor = None
-        self.quantization_compressor = None
+        self.quantization_compressor: Optional[
+            Union[BaseQuantizationCompressor, DenseCompressor]
+        ] = None

         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
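
Note: the `if TYPE_CHECKING:` import above is the standard way to annotate a circular (or import-heavy) dependency without paying for it at runtime. A minimal self-contained sketch of the pattern — `mypackage` is an illustrative stand-in, not the library's real module layout:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright); at runtime
        # this import never executes, so no circular import can occur.
        from mypackage.compressors import BaseQuantizationCompressor  # hypothetical

    class Holder:
        def __init__(self) -> None:
            # Quoting the annotation makes it a forward reference, so the
            # guarded name is never looked up at runtime.
            self.compressor: Optional["BaseQuantizationCompressor"] = None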

{compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/compressors/sparse_compressors/dense.py
RENAMED
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Dict, Generator, Tuple

+import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from torch import Tensor


+if TYPE_CHECKING:
+    from compressed_tensors.quantization import QuantizationScheme
+
+
 @BaseCompressor.register(name=CompressionFormat.dense.value)
 class DenseCompressor(BaseCompressor):
     """

@@ -47,3 +52,16 @@ class DenseCompressor(BaseCompressor):
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         for key, value in state_dict.items():
             yield key, value
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: "QuantizationScheme",
+    ) -> Dict[str, torch.Tensor]:
+        """
+        This function is implemented as a workaround because of how
+        `ModelCompressor.quantization_compressor` can be set to either
+        an instance of `BaseQuantizationCompressor` or `DenseCompressor`.
+        """
+        return state_dict.copy()
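
Note: the new no-op override lets a caller holding either compressor type decompress uniformly, without `isinstance` branching. A toy illustration of the idea — the stub class below is hypothetical, not the library's API:

    import torch

    class _NoOpCompressor:
        # mirrors the DenseCompressor override: "decompressing" dense
        # tensors is just a shallow copy of the state dict
        def decompress_module_from_state_dict(self, prefix, state_dict, scheme):
            return state_dict.copy()

    compressor = _NoOpCompressor()  # could equally be a quantized compressor
    restored = compressor.decompress_module_from_state_dict(
        "layer.", {"weight": torch.zeros(4, 4)}, scheme=None
    )
    assert torch.equal(restored["weight"], torch.zeros(4, 4))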

{compressed_tensors-0.10.1a20250605 → compressed_tensors-0.10.2}/src/compressed_tensors/quantization/lifecycle/apply.py
RENAMED
@@ -183,9 +183,7 @@ def apply_quantization_config(
                 replace_module(model, name, compressed_linear)

             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = _scheme_from_targets(
-                target_to_scheme, targets, name
-            )
+            submodule.quantization_scheme = scheme

             names_to_scheme[name] = submodule.quantization_scheme


compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/base.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn.utils.parametrize as P
from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
from compressed_tensors.registry.registry import RegistryMixin, T
from compressed_tensors.transform import (
    TransformArgs,
    TransformLocation,
    TransformScheme,
)
from compressed_tensors.utils import (
    align_module_device,
    has_offloaded_params,
    patch_attr,
    register_offload_module,
    update_offload_parameter,
)
from torch import Tensor
from torch.nn import Module, Parameter


__all__ = ["TransformFactory", "TransformBase"]


class TransformFactory(RegistryMixin, ABC):
    """
    Abstract factory base used to create and apply transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        self.name = name
        self.scheme = scheme
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)

    @classmethod
    def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T:
        """
        Create a transform factory from a scheme

        :param scheme: defines how transforms should be created
        :param kwargs: TransformFactory constructor arguments
        :return: subclass of `TransformFactory` corresponding to the scheme type
        """
        constructor = cls.get_value_from_registry(name=scheme.type)
        return constructor(scheme=scheme, **kwargs)

    @abstractmethod
    def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
        """
        Abstract method which defines how a transform should be created. May utilize
        caching to maximize shared memory

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        :return: instance of TransformBase
        """
        raise NotImplementedError()

    def apply_to_model(self, model: Module):
        """
        Create transforms and apply them to the model

        :param model: module to apply transforms to
        """
        for arg in self.scheme.apply:
            for name, module in list(model.named_modules()):
                if is_target(name, module, arg.targets, arg.ignore):
                    self._apply_to_module(module, arg)

    def _apply_to_module(self, module: Module, args: TransformArgs):
        """
        Create transforms and apply them to the module

        :param module: target module to apply transforms to
        :param args: defines how the transform will be applied to the target module
        """
        # create transform as submodule
        transform_name = f"{self.name}_{args.location.value}"
        transform = self.create_transform(module, args)
        register_offload_module(module, transform_name, transform)  # (1)

        # register input transformation hook
        if args.location == TransformLocation.INPUT:

            def input_hook(_, args):
                input = args[0]
                return transform(input)

            module.register_forward_pre_hook(input_hook, prepend=True)

        # eagerly apply transformation to weight
        elif args.location in (
            TransformLocation.WEIGHT_INPUT,
            TransformLocation.WEIGHT_OUTPUT,
        ):
            assert isinstance(module, torch.nn.Linear)
            assert module.bias is None

            with torch.no_grad(), align_module_device(module):
                update_offload_parameter(module, "weight", transform(module.weight))

            if self.scheme.requires_grad:
                # for training, the weight changes with every forward pass
                # so we can leverage parametrization to propagate the gradient
                if has_offloaded_params(module):
                    raise ValueError("Offloaded training is not supported")
                P.register_parametrization(module, "weight", transform)

        # register output transformation hook
        elif args.location == TransformLocation.OUTPUT:

            def output_hook(_, _input, output):
                return transform(output)

            module.register_forward_hook(output_hook)

        # other locations such as q_attn and k_attn have not been implemented
        else:
            raise NotImplementedError()

    # (1) even in the `weight` cases, this submodule attachment is needed in order
    # to support saving in the frozen state


class TransformBase(Module, ABC):
    """
    Represents the application of a transform accord to TransformArgs
    """

    args: TransformArgs
    weight: Parameter

    @abstractmethod
    def forward(self, value: Tensor) -> Tensor:
        raise NotImplementedError()

    def right_inverse(self, value: Tensor) -> Tensor:
        with patch_attr(self.args, "inverse", not self.args.inverse):
            return self.forward(value)

    def __repr__(self):
        return f"{self.__class__.__name__}(inverse={self.args.inverse})"
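
Note: end to end, the factory machinery above would be driven roughly as follows. This is a hedged sketch: the `TransformScheme`/`TransformArgs` constructor arguments are inferred from how `apply_to_model` consumes them (`scheme.type`, `scheme.apply`, `args.targets`, `args.location`) and are not spelled out in this diff.

    import torch
    from compressed_tensors.transform import TransformArgs, TransformScheme
    from compressed_tensors.transform.factory.base import TransformFactory

    # assumed field names, mirroring how base.py reads them
    scheme = TransformScheme(
        type="hadamard",
        apply=[TransformArgs(targets=["Linear"], location="weight_input")],
    )

    model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))

    # from_scheme() resolves the registered subclass for "hadamard", then
    # apply_to_model() walks named_modules() and attaches a transform to
    # every module matched by is_target()
    factory = TransformFactory.from_scheme(scheme, name="example")
    factory.apply_to_model(model)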

compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/hadamard.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("hadamard")
class HadamardFactory(TransformFactory):
    """
    Factory used to apply hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a HadamardTransform for applying to a module. Transforms with the same
        size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        weight = self.weights[size, dtype, device]
        return HadamardTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = deterministic_hadamard_matrix(size, dtype, device)
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)


class HadamardTransform(TransformBase):
    def __init__(self, weight: Parameter, args: TransformArgs):
        super().__init__()
        self.weight = weight
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
        if not self.args.inverse:
            weight = self.weight
        else:
            weight = self.weight.T

        return apply_transform_weight(weight, value, self.args.location)
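
Note: the `weight.T` branch in `forward` works because the factory's matrices are orthonormal: both Hadamard constructions below scale by 1/sqrt(n), so H·Hᵀ = I and the transpose is the exact inverse. A standalone check of that property using the same Sylvester construction (this sketch does not import the package):

    import math
    import torch

    # build a 64x64 Sylvester Hadamard matrix, normalized like the factory does
    H = torch.tensor([[1.0]])
    for _ in range(6):  # 2**6 == 64
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
    H = H / math.sqrt(64)

    # orthonormal: the transpose inverts the transform exactly
    assert torch.allclose(H @ H.T, torch.eye(64), atol=1e-6)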

compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/matrix_multiply.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
from compressed_tensors.transform.utils.utils import (
    apply_transform_weight,
    get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter


@TransformFactory.register("random-matrix")
class RandomMatrixFactory(TransformFactory):
    """
    Factory used to apply random matrix transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to transform weight randomization
    """

    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)
        self.inverses = ParameterizedDefaultDict(self._create_inverse)

    def create_transform(self, module: Module, args: TransformArgs):
        """
        Create a RandomMatrixTransform for applying to a module. Transforms with the
        same size, dtype, and device are cached

        :param module: parent module that transform will be applied to
        :param args: defines how the transform will be applied to the module
        """
        assert isinstance(module, Linear)
        size = get_matrix_size(module, args.location)
        dtype = module.weight.dtype
        device = get_offloaded_device(module)

        weight = self.weights[size, dtype, device]
        if args.inverse:
            weight = self.inverses[weight]

        return RandomMatrixTransform(weight, args)

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = torch.rand(
            (size, size), generator=self.generator, dtype=dtype, device=device
        )
        return Parameter(data, requires_grad=self.scheme.requires_grad)

    def _create_inverse(self, weight: Parameter) -> Parameter:
        data = high_precision_invert(weight.data)
        return Parameter(data, requires_grad=False)


class RandomMatrixTransform(TransformBase):
    def __init__(self, weight: Tensor, args: TransformArgs):
        super().__init__()
        self.weight = weight  # is an inverse if args.inverse
        self.args = args

    def forward(self, value: Tensor) -> Parameter:
        return apply_transform_weight(self.weight, value, self.args.location)

    def right_inverse(self, value: Tensor) -> Tensor:
        inverse = high_precision_invert(self.weight)
        return apply_transform_weight(inverse, value, self.args.location)


def high_precision_invert(weight: Tensor) -> Tensor:
    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
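
Note: `high_precision_invert` exists because inversion directly in half precision is numerically fragile, and (to the best of my knowledge) `torch.linalg.inv` does not accept bfloat16 inputs at all, so the weight is promoted to float32 for the solve and cast back. A quick numerical sanity check of the idiom:

    import torch

    w = torch.rand(64, 64, dtype=torch.bfloat16)

    # invert in float32 (bfloat16 is not supported by linalg.inv),
    # then cast back to the weight dtype
    inv = torch.linalg.inv(w.to(torch.float32)).to(w.dtype)

    # the round trip is only approximately identity after the downcast
    err = (w.to(torch.float32) @ inv.to(torch.float32) - torch.eye(64)).abs().max()
    print(f"max |W @ inv(W) - I| = {err:.2e}")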

compressed_tensors-0.10.2/src/compressed_tensors/transform/factory/random_hadamard.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
from torch.nn import Parameter


@TransformFactory.register("random-hadamard")
class RandomHadamardFactory(HadamardFactory):
    """
    Factory used to apply random hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to transform weight randomization
    """

    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
        data = random_hadamard_matrix(size, dtype, device, self.generator)
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)

compressed_tensors-0.10.2/src/compressed_tensors/transform/utils/hadamard.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from pathlib import Path
from typing import Optional

import torch
from safetensors import safe_open


REPO_PATH = Path(__file__).parent / "hadamards.safetensors"


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]


# note that hadamard matrix multiplication can be accelerated using a library such as
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def deterministic_hadamard_matrix(
    size: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """
    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
    `n` must be a power of 2.

    Adapated from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501

    :param size: order of the matrix, must be a power of 2
    :param dtype: data type of matrix
    :param device: device to construct matrix on
    :return: hadamard matrix of size `size`
    """
    if size <= 0:
        raise ValueError("Cannot construct deterministic hadamard of size <= 0")

    log2 = int(math.log2(size))
    if size != 2**log2:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

    H = torch.tensor([[1]], dtype=dtype, device=device)

    # Sylvester's construction
    for _ in range(log2):
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

    return H / math.sqrt(size)


def random_hadamard_matrix(
    size: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cpu"),
    gen: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """
    Produces a randomly generated Hadamard matrix. Differs from
    `deterministic_hadamard_matrix` in that this function supports non powers of 2
    and randomization using a seeded generator

    Adapated from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501

    :param size: The dimension of the hamadard matrix
    :param dtype: data type of matrix
    :param device: device to construct matrix on
    :param gen: Optional generator random values
    :return: randomly generated hadamard matrix
    """
    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
    Q = Q.to(device=device)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q) / math.sqrt(size)


def is_pow2(n: int) -> bool:
    """
    Check if a number is a power of 2

    :param n: number to check
    :return: True iff `n` is a power of 2
    """
    return n > 0 and (n & (n - 1) == 0)


def _fetch_hadamard_divisor(
    n: int,
    dtype: torch.dtype,
    device: torch.device = torch.device("cpu"),
    file_path: str = REPO_PATH,
) -> Optional[torch.Tensor]:
    """
    Fetch a known hadamard matrix from the given file path. The returned matrix will
    be of of size `k` such that `n / k` is a power of two. Return None if no such
    matrix exists.

    Note: This function reopens the safetensors file every time it is called.
    This is technically inefficient, but a very small runtime cost and simpler
    than forcing callers to manage the file open context

    :param n: size of known hadamard matrix
    :return: a known hadamard matrix of size `n` if one exists, else None
    """
    with safe_open(file_path, framework="pt", device=str(device)) as file:
        divisors = sorted((int(key) for key in file.keys()), reverse=True)
        for divisor in divisors:
            if n % divisor == 0 and is_pow2(n // divisor):
                return file.get_tensor(str(divisor)).to(dtype=dtype)

    return None


def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
    size = X.size(0)
    dtype = X.dtype
    device = X.device

    # Check if we have the determined hadamard matrix
    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
    if hadK is None:
        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
    K = hadK.size(0)

    # Reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, size, 1)
    output = input.clone()
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    assert input.shape[1] == K
    del output

    # Do not explicitly repeat - OOM
    # input = torch.bmm(
    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
    # Use bcast instead
    input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape)
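
Note: `_matmul_hadU` only succeeds when `size` factors as `k * 2**m`, where `k` is one of the known Sloane matrices shipped in `hadamards.safetensors`; the butterfly loop handles the `2**m` part and the final broadcast matmul handles the known k-by-k block. The admissibility test can be checked standalone — the divisor list below is illustrative, since the real set is whatever keys the safetensors file ships:

    def is_pow2(n: int) -> bool:
        return n > 0 and (n & (n - 1) == 0)

    # hypothetical known sizes; the shipped file defines the real ones
    KNOWN = [40, 28, 20, 12, 1]

    def decomposable(size: int) -> bool:
        # same test as _fetch_hadamard_divisor: size = divisor * 2**m
        return any(size % k == 0 and is_pow2(size // k) for k in KNOWN)

    assert decomposable(768)     # 12 * 64
    assert decomposable(4096)    # 1 * 2**12
    assert not decomposable(36)  # 36 = 12*3: no power-of-2 cofactor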