compressed-tensors 0.9.5a20250428__tar.gz → 0.9.5a20250502__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.9.5a20250428/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250502}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +18 -36
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/quantized_compressors/base.py +84 -77
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +23 -11
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/apply.py +4 -4
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/initialize.py +2 -1
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/helpers.py +7 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +7 -7
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/quantized_compressors/test_int_quant.py +9 -9
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/quantized_compressors/test_pack_quant.py +17 -21
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +4 -4
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_initialize.py +3 -1
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/.gitkeep +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/.gitignore +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/LICENSE +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/Makefile +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/README.md +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/pyproject.toml +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/setup.cfg +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/setup.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/base.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/quant_args.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/quant_config.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_forward.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_quantization/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_registry.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_utils/test_offload.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.9.5a20250428 → compressed_tensors-0.9.5a20250502}/utils/copyright.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250428
+Version: 0.9.5a20250502
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -19,7 +19,7 @@ import os
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -36,12 +36,12 @@ from compressed_tensors.config import CompressionFormat, SparsityCompressionConf
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
+    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -64,7 +64,7 @@ from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+__all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,20 +372,17 @@ class ModelCompressor:
         :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
+
         if state_dict is None:
             state_dict = model.state_dict()
 
-        compressed_state_dict = state_dict
-
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
-
         if self.quantization_compressor is not None:
-            compressed_state_dict = self.quantization_compressor.compress(
-                compressed_state_dict, names_to_scheme=quantized_modules_to_args
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
             )
 
+            # TODO: consider sparse compression to also be compression
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
                     QuantizationStatus.COMPRESSED
@@ -397,8 +394,8 @@ class ModelCompressor:
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict,
+            state_dict = self.sparsity_compressor.compress(
+                state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -407,7 +404,7 @@ class ModelCompressor:
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return compressed_state_dict
+        return state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -605,30 +602,15 @@ class ModelCompressor:
                 update_parameter_data(module, param_data, param_name)
 
 
-def map_modules_to_quant_args(
-    model: Module,
-) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
+def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Returns a dictionary which maps quantized module names
-    to the weight QuantizationArgs. If running input activation quantization, will also
-    map to the input QuantizationArgs in a tuple.
-
-    :param model: pytorch model
+    Returns a dictionary which maps quantized module names to their quantization schemes
     """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            name = fix_fsdp_module_name(name)
-            quantized_modules_to_args[name] = submodule.quantization_scheme.weights
-            if submodule.quantization_scheme.input_activations is not None:
-                weight_args = quantized_modules_to_args.get(name)
-                quantized_modules_to_args[name] = (
-                    weight_args,
-                    submodule.quantization_scheme.input_activations,
-                )
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
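Note on the refactor above: map_module_to_scheme replaces map_modules_to_quant_args, so callers now receive one QuantizationScheme per quantized module rather than weight args or a (weight, input) args tuple. A minimal usage sketch, assuming the helper's import path matches the file above and that is_module_quantized keys off an attached quantization_scheme attribute (both assumptions, not shown in this diff):

import torch
from compressed_tensors.compressors.model_compressors.model_compressor import (
    map_module_to_scheme,
)
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
# assumption: attaching quantization_scheme is what is_module_quantized checks
model[0].quantization_scheme = QuantizationScheme(
    targets=["Linear"], weights=QuantizationArgs(num_bits=8)
)

print(map_module_to_scheme(model))  # {"0": QuantizationScheme(...)}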
src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -14,15 +14,16 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator,
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import
+from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
+    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str,
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -81,87 +82,87 @@ class BaseQuantizationCompressor(BaseCompressor):
         :return: compressed state dict
         """
         compressed_dict = {}
-        weight_suffix = ".weight"
-        input_zp_suffix = ".input_zero_point"
-        weight_zp_suffix = ".weight_zero_point"
+        save_device = "cpu"
+
+        uncompressed_names = list(model_state.keys())
+        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+            value = model_state[name]
+
+            # compress weights
+            if name.endswith("weight"):
+                prefix = remove_suffix(name, "weight")
+
+                # gather qparams
+                scale = model_state.get(prefix + "weight_scale", None)
+                g_idx = model_state.get(prefix + "weight_g_idx", None)
+                zp = model_state.get(prefix + "weight_zero_point", None)
+
+                # is scale does not exist, then weight cannot be compressed
+                if scale is None:
+                    compressed_dict[name] = value.to(save_device)
+                    continue
+
+                # compress values on cpu (memory movement too expensive)
+                module_path = prefix[:-1] if prefix.endswith(".") else prefix
+                quant_args = names_to_scheme[module_path].weights
+                compressed_values = self.compress_weight(
+                    weight=value,
+                    scale=scale,
+                    zero_point=zp,
+                    g_idx=g_idx,
+                    quantization_args=quant_args,
+                    device="cpu",
+                )
+
+                # update state dict
+                for key, value in compressed_values.items():
+                    compressed_dict[prefix + key] = value.to(save_device)
 
-        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
-            # check if the parameter we're compressing is the weight zp
-            # or the input zp
-            is_weight_zp = name.endswith(weight_zp_suffix)
-            is_input_zp = name.endswith(input_zp_suffix)
-
-            # if we're saving the weight zp, fetch weight quant args
-            if is_weight_zp:
-                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
-                if isinstance(quant_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    quant_args_zp = quant_args_zp[0]
-
-            # if we're saving the input zp, fetch input quant args
-            if is_input_zp:
-                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
-                if isinstance(input_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    input_args_zp = input_args_zp[-1]
-
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    if isinstance(names_to_scheme[prefix], tuple):
-                        quant_args = names_to_scheme[prefix][0]
-                    else:
-                        quant_args = names_to_scheme[prefix]
-
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            # only save zp if asym and not packed zp
-            elif is_weight_zp and (
-                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
-            ):
-                continue
-            # only save if asym
-            elif is_input_zp and input_args_zp.symmetric:
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
             else:
-                compressed_dict[name] = value.to("cpu")
+                # omit saving zero points for symmetric or packed quantization
+                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                    continue
+
+                # omit saving for g_idx if uninitialized
+                # TODO: does this case actually occur?
+                elif name.endswith("g_idx") and torch.any(value <= -1):
+                    continue
+
+                compressed_dict[name] = value.to(save_device)
 
         return compressed_dict
 
-    def _check_if_zp_pack_quantized(
+    def _skip_zp(
+        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
+    ) -> bool:
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        if
+        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
+        scheme = names_to_scheme[module_name]
+
+        if zp_name == "weight_zero_point":
+            args = scheme.weights
+        if zp_name == "input_zero_point":
+            args = scheme.input_activations
+        if zp_name == "output_zero_point":
+            args = scheme.output_activations
+
+        symmetric = args.symmetric
+        packable_strategies = [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]
+        packed = (
+            isinstance(self, PackedQuantizationCompressor)
+            and args.strategy in packable_strategies
+        )
+
+        return symmetric or packed
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str,
+        names_to_scheme: Dict[str, QuantizationScheme],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -170,8 +171,9 @@ class BaseQuantizationCompressor(BaseCompressor):
         dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization
-        :param device: optional device to load intermediate weights into
+        :param names_to_scheme: quantization scheme for each quantized weight
+        :param device: optional device to load intermediate weights into (must be `str`,
+            not `torch.device`)
         :return: compressed state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -184,7 +186,12 @@ class BaseQuantizationCompressor(BaseCompressor):
             path_to_model_or_tensors, names_to_scheme
         )
 
-    def _decompress_from_path(
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: str,
+    ):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -195,7 +202,7 @@ class BaseQuantizationCompressor(BaseCompressor):
            with safe_open(safe_path, framework="pt", device=device) as f:
                weight_data[param_name] = f.get_tensor(full_name)
        if "weight_scale" in weight_data:
-            quant_args = names_to_scheme[weight_name]
+            quant_args = names_to_scheme[weight_name].weights
            decompressed = self.decompress_weight(
                compressed_data=weight_data, quantization_args=quant_args
            )
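To make the rewritten compress() control flow concrete: every parameter named *.weight is compressed under its owning module's scheme, and zero points that carry no information are dropped. A self-contained sketch of those two decisions, with toy names standing in for the real methods above:

from typing import Dict

def remove_suffix(value: str, suffix: str) -> str:
    # same contract as the helper added in utils/helpers.py
    assert value.endswith(suffix)
    return value[: -len(suffix)]

def skip_zero_point(name: str, symmetric_by_module: Dict[str, bool]) -> bool:
    # toy stand-in for _skip_zp: a symmetric zero point is all zeros,
    # so there is nothing worth serializing
    module_name, param = name.rsplit(".", 1) if "." in name else ("", name)
    return param.endswith("zero_point") and symmetric_by_module.get(module_name, False)

prefix = remove_suffix("model.layers.0.weight", "weight")   # -> "model.layers.0."
module_path = prefix[:-1]                                   # -> "model.layers.0"
print(skip_zero_point("model.layers.0.weight_zero_point", {"model.layers.0": True}))  # True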
src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
@@ -19,7 +19,11 @@ import numpy as np
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.forward import quantize
 from compressed_tensors.utils import (
     get_permutations_24,
@@ -44,19 +48,25 @@ class Marlin24Compressor(BaseCompressor):
 
     @staticmethod
     def validate_quant_compatability(
+        names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24
         compression. Quantization must be channel or group strategy with group_size
         of 128. Only symmetric quantization is supported
 
-        :param
-            quantization
+        :param names_to_scheme: dictionary of mapping module names to their
+            quantization schemes
         :return: True if all modules are compatible with Marlin24 compression, raises
             a ValueError otherwise
         """
-        for name,
+        for name, scheme in names_to_scheme.items():
+            quant_args = scheme.weights
+            if quant_args is None:
+                raise ValueError(
+                    "Marlin24 Compressor is only valid for weight quantization schemes"
+                )
+
             strategy = quant_args.strategy
             group_size = quant_args.group_size
             symmetric = quant_args.symmetric
@@ -114,7 +124,7 @@ class Marlin24Compressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str,
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -122,8 +132,8 @@ class Marlin24Compressor(BaseCompressor):
         with the Marlin24 kernel
 
         :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization
+        :param names_to_scheme: quantization scheme for each quantized weight, needed
+            for quantize function to calculate bit depth
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)
@@ -146,7 +156,7 @@ class Marlin24Compressor(BaseCompressor):
                 value = value.to(torch.float16)
 
                 # quantize weight, keeping it as a float16 for now
-                quant_args = names_to_scheme[prefix]
+                quant_args = names_to_scheme[prefix].weights
                 value = quantize(
                     x=value, scale=scale, zero_point=zp, args=quant_args
                 )
@@ -215,7 +225,7 @@ def pack_weight_24(
     weight: Tensor,
     quantization_args: QuantizationArgs,
     tile: int = 16,
-):
+) -> torch.Tensor:
     size_k = weight.shape[0]
     size_n = weight.shape[1]
     num_bits = quantization_args.num_bits
@@ -236,7 +246,9 @@ def pack_weight_24(
     return q_packed
 
 
-def pack_scales_24(
+def pack_scales_24(
+    scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size
+) -> torch.Tensor:
     size_k = w_shape[0]
     size_n = w_shape[1]
     num_bits = quantization_args.num_bits
src/compressed_tensors/quantization/lifecycle/apply.py
@@ -119,7 +119,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) ->
+) -> Dict[str, QuantizationScheme]:
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally coverts quantizable modules to compressed_linear modules
@@ -131,7 +131,7 @@ def apply_quantization_config(
     """
     # Workaround for when HF Quantizer passes None, see PR #180
     if config is None:
-        return
+        return dict()
 
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
@@ -141,7 +141,7 @@ def apply_quantization_config(
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
     config = process_quantization_config(config)
-    names_to_scheme =
+    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -187,7 +187,7 @@ def apply_quantization_config(
                 target_to_scheme, targets, name
             )
 
-            names_to_scheme[name] = submodule.quantization_scheme
+            names_to_scheme[name] = submodule.quantization_scheme
 
     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
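The apply.py changes above pin down a return type: apply_quantization_config now always yields a Dict[str, QuantizationScheme], even on the None-config workaround path. A small sketch of that path (the Linear module is a placeholder):

import torch
from compressed_tensors.quantization import apply_quantization_config

# config=None (the HF Quantizer workaround noted in the diff) now
# returns an empty dict rather than None
schemes = apply_quantization_config(torch.nn.Linear(4, 4), config=None)
assert schemes == {}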
src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,6 +14,7 @@
 
 
 import logging
+import math
 from enum import Enum
 from typing import Optional
 
@@ -162,7 +163,7 @@ def _initialize_scale_zero_point(
         # (output_channels, 1)
         expected_shape = (weight_shape[0], 1)
     elif quantization_args.strategy == QuantizationStrategy.GROUP:
-        num_groups = weight_shape[1]
+        num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
         expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
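The one-line initialize.py fix above sizes GROUP-strategy scales by the actual group count, including a ragged final group. Checking the arithmetic:

import math

columns, group_size = 100, 32
num_groups = math.ceil(columns / group_size)  # groups of 32, 32, 32 and a final 4
assert num_groups == 4
# per the surrounding code, scales then get shape (output_channels, max(num_groups, 1))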
src/compressed_tensors/utils/helpers.py
@@ -38,6 +38,7 @@ __all__ = [
     "shard_tensor",
     "pack_bitmasks",
     "unpack_bitmasks",
+    "remove_suffix",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +329,9 @@ def unpack_bitmasks(
     )
 
     return unpacked_bitmasks_torch
+
+
+def remove_suffix(value: str, suffix: str) -> str:
+    # can replace with str.removesuffix in python3.9+
+    assert value.endswith(suffix)
+    return value[: -len(suffix)]
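For reference, the helper added above mirrors str.removesuffix (Python 3.9+), except that a missing suffix is an assertion error rather than a no-op, which suits its call sites where endswith has already been checked. Usage, assuming compressed_tensors.utils re-exports it as the __all__ addition suggests:

from compressed_tensors.utils import remove_suffix

assert remove_suffix("decoder.layers.0.weight", "weight") == "decoder.layers.0."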
src/compressed_tensors.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250428
+Version: 0.9.5a20250502
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
tests/test_compressors/quantized_compressors/test_fp8_quant.py
@@ -84,9 +84,9 @@ def test_quant_format(strategy, group_size, sc, zp):
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = FloatQuantizationCompressor(config=quant_config)
+    module_name_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=
+        dense_state_dict, names_to_scheme=module_name_to_scheme
     )
 
     # state_dict params should be the same, minus the zero_point if symmetric
@@ -140,15 +140,15 @@ def test_reload_match(
     )
 
     compressor = FloatQuantizationCompressor(config=quant_config)
-        "dummy": quant_config.config_groups["group_1"]
+    module_name_to_scheme = {
+        "dummy": quant_config.config_groups["group_1"],
     }
     compressed_state_dict = compressor.compress(
-        model.state_dict(), names_to_scheme=
+        model.state_dict(), names_to_scheme=module_name_to_scheme
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=
+        tmp_path, names_to_scheme=module_name_to_scheme
     )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
@@ -158,7 +158,7 @@ def test_reload_match(
         model.dummy.weight,
         scale=model.dummy.weight_scale,
         zero_point=model.dummy.weight_zero_point,
-        args=
+        args=module_name_to_scheme["dummy"].weights,
     )
     assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy"].get("weight"))
 
tests/test_compressors/quantized_compressors/test_int_quant.py
@@ -76,9 +76,9 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp):
     )
 
     compressor = IntQuantizationCompressor(config=quant_config)
+    quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=
+        dense_state_dict, names_to_scheme=quantized_modules_to_scheme
     )
 
     # state_dict params should be the same, minus the zero_point if symmetric
@@ -124,16 +124,16 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
 
     compressor = IntQuantizationCompressor(config=quant_config)
-        "dummy": quant_config.config_groups["group_1"]
-        "dummy2": quant_config.config_groups["group_1"]
+    module_name_to_scheme = {
+        "dummy": quant_config.config_groups["group_1"],
+        "dummy2": quant_config.config_groups["group_1"],
     }
     compressed_state_dict = compressor.compress(
-        dense_state_dict, names_to_scheme=
+        dense_state_dict, names_to_scheme=module_name_to_scheme
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
     reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=
+        tmp_path, names_to_scheme=module_name_to_scheme
    )
     reconstructed_dense = {}
     for name, value in reconstructed_dense_gen:
@@ -143,7 +143,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         dense_state_dict["dummy.weight"],
         scale=dense_state_dict["dummy.weight_scale"],
         zero_point=dense_state_dict["dummy.weight_zero_point"],
-        args=
+        args=module_name_to_scheme["dummy"].weights,
     )
     assert torch.equal(
         fake_quant_dummy, reconstructed_dense["dummy"].get("weight").to(torch.float32)
@@ -153,7 +153,7 @@ def test_reload_match(strategy, group_size, sc, zp, tmp_path):
         dense_state_dict["dummy2.weight"],
         scale=dense_state_dict["dummy2.weight_scale"],
         zero_point=dense_state_dict["dummy2.weight_zero_point"],
-        args=
+        args=module_name_to_scheme["dummy2"].weights,
     )
     assert torch.equal(
         fake_quant_dummy2, reconstructed_dense["dummy2"].get("weight").to(torch.float32)