compressed-tensors 0.9.5a20250530__tar.gz → 0.9.5a20250603__tar.gz
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- {compressed_tensors-0.9.5a20250530/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250603}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -10
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/forward.py +35 -24
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/initialize.py +7 -113
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_args.py +1 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/utils/helpers.py +9 -7
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/offload.py +134 -1
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +3 -3
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_int_quant.py +2 -2
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_forward.py +12 -12
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_utils/test_helpers.py +3 -5
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_offload.py +95 -5
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/.gitkeep +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/.gitignore +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/LICENSE +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/Makefile +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/README.md +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/pyproject.toml +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/setup.cfg +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/setup.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_config.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_args.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_config.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/transform/transform_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/SOURCES.txt +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_initialize.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_registry.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_args.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_config.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_transform/test_transform_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/utils/copyright.py +0 -0
{compressed_tensors-0.9.5a20250530/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250603}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250530
+Version: 0.9.5a20250603
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/apply.py

@@ -27,14 +27,8 @@ from compressed_tensors.quantization.lifecycle.compressed import (
 )
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
-    update_fused_layer_weight_global_scales,
-)
-from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
-    FP8_E4M3_DATA,
-    QuantizationArgs,
-    QuantizationType,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,

@@ -272,9 +266,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )

-    if status == QuantizationStatus.INITIALIZED:
-        update_fused_layer_weight_global_scales(model)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/forward.py

@@ -227,31 +227,42 @@ def _process_quantization(
         perm = torch.argsort(g_idx)
         x = safe_permute(x, perm, dim=1)

-
-
-
-
-
-
-
-
-        if do_quantize:
-            output[:, start:end] = _quantize(
-                x=x[:, start:end],
-                scale=sc,
-                zero_point=zp,
-                q_min=q_min,
-                q_max=q_max,
-                args=args,
-                dtype=dtype,
-                global_scale=global_scale,
-            )
+        x = torch.reshape(
+            x,
+            (
+                x.shape[0],
+                ceil(x.shape[1] / group_size),
+                group_size,
+            ),
+        )

-
-
-
-
-            )
+        if do_quantize:
+            output = _quantize(
+                x=x,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                dtype=dtype,
+                global_scale=global_scale,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+            )
+
+        if do_dequantize:
+            input = output if do_quantize else x
+            output = _dequantize(
+                x_q=input,
+                scale=scale.unsqueeze(-1),
+                zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
+                global_scale=global_scale,
+            )
+
+        output = torch.reshape(
+            output,
+            (output.shape[0], output.shape[1] * output.shape[2]),
+        )
+
+        output = output.to(output_dtype)

         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/lifecycle/initialize.py

@@ -23,26 +23,18 @@ from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
-    FP4_E2M1_DATA,
     FP8_E4M3_DATA,
     ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    generate_global_scale,
-    is_fp4,
-    is_kv_cache_quant_scheme,
-    iter_named_quantizable_modules,
-)
+from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
     register_offload_parameter,
-    update_parameter_data,
 )
 from torch.nn import Module, Parameter

@@ -51,7 +43,6 @@ __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
     "KVCacheScaleType",
-    "update_fused_layer_weight_global_scales",
 ]


@@ -162,22 +153,13 @@ def _initialize_scale_zero_point(
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)

-    # 1. Create global_scales for tensor_group
+    # 1. Create global_scales for tensor_group - generates
+    # a per tensor scale
     if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
-
-
-
-
-            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-            # which is the expected dtype of NVFP4A16 scales
-            value = generate_global_scale(input_tensor=module.weight)
-            value = value.to(device)
-            init_global_scale = Parameter(value, requires_grad=False)
-        else:
-            init_global_scale = Parameter(
-                torch.empty(1, dtype=torch.float32, device=device),
-                requires_grad=False,
-            )
+        init_global_scale = Parameter(
+            torch.empty(1, dtype=torch.float32, device=device),
+            requires_grad=False,
+        )
         register_offload_parameter(
             module, f"{base_name}_global_scale", init_global_scale
         )

@@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
-
-
-# TODO: Potentially introduce an argument to turn this off
-# Only relevant for NVFP4A16 currently
-def update_fused_layer_weight_global_scales(model: torch.nn.Module):
-    """
-    When running NVFP4A16 quantization, update the global scale
-    such that q,k,v layers are treated as one tensor with the same
-    global_scale and gate_proj/up_proj layers are treated as one tensor
-    with the same global scale. This is requirement currently being set
-    by vLLM and may be removed in the future OR potentially make it
-    an optional step.
-
-    :param model: model to quantize
-    """
-
-    def _is_attention_module(module: Module):
-        return "attention" in module.__class__.__name__.lower() and (
-            hasattr(module, "k_proj")
-            or hasattr(module, "v_proj")
-            or hasattr(module, "qkv_proj")
-        )
-
-    def _is_mlp_module(module: Module):
-        return "mlp" in module.__class__.__name__.lower() and (
-            hasattr(module, "gate_proj") or hasattr(module, "up_proj")
-        )
-
-    def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
-        """
-        Return True if all the linear layers in the layer_list are
-        NVFP4A16 quantized.
-        """
-        for layer in layer_list:
-            scheme = getattr(layer, "quantization_scheme", None)
-            if scheme is None:
-                return False
-
-            weight_quant_args = scheme.weights
-
-            if weight_quant_args is None:
-                return False
-
-            if not is_fp4(quantization_args=weight_quant_args):
-                return False
-        return True
-
-    for name, submodule in iter_named_quantizable_modules(
-        model,
-        include_attn=True,
-        include_mlp=True,
-    ):
-
-        if _is_attention_module(submodule):
-            # already fused/treated as one layer
-            if hasattr(submodule, "qkv_proj"):
-                continue
-
-            if not _valid_fp4_quant(
-                [submodule.q_proj, submodule.v_proj, submodule.k_proj]
-            ):
-                continue
-
-            q_weight = submodule.q_proj.weight.data
-            v_weight = submodule.v_proj.weight.data
-            k_weight = submodule.k_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
-            )
-
-            update_parameter_data(submodule.q_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.k_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.v_proj, value, "weight_global_scale")
-
-        if _is_mlp_module(submodule):
-            if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
-                continue
-
-            gate_data = submodule.gate_proj.weight.data
-            up_data = submodule.up_proj.weight.data
-
-            value = generate_global_scale(
-                input_tensor=torch.cat((gate_data, up_data), dim=0)
-            )
-
-            update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
-            update_parameter_data(submodule.up_proj, value, "weight_global_scale")
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/quantization/utils/helpers.py

@@ -47,7 +47,7 @@ __all__ = [
     "compute_dynamic_scales_and_zp",
     "calculate_range",
     "calculate_qparams",
-    "generate_global_scale",
+    "generate_gparam",
     "is_fp4",
 ]


@@ -81,7 +81,7 @@ def calculate_qparams(
         currently only applied/supported for Fp4

     :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
-        scale
+        scale is of dtype FP8
     """
     # based on the implementations for consuming quantized values,
     # 0.0 must always be representable within the quantized range

@@ -475,8 +475,9 @@ def parse_out_kv_cache_args(
     return kv_cache_args, quant_scheme_to_layers


-def generate_global_scale(
-
+def generate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
     scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
     quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
     dtype: Optional[torch.dtype] = torch.float32,

@@ -490,7 +491,8 @@ def generate_global_scale(
     attempts to use the entire FP8 dtype range while mapping a per-group max
     to the FP4 max.
    """
-
-
-
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
     return global_scale.to(dtype)
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/src/compressed_tensors/utils/offload.py

@@ -28,15 +28,18 @@ Utilities associated with offloading functionality provided by `accelerate`.
 import contextlib
 import warnings
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union

 import torch


 try:
+    from accelerate import dispatch_model
     from accelerate.hooks import (
         AlignDevicesHook,
         add_hook_to_module,
+        attach_align_device_hook,
+        named_module_tensors,
         remove_hook_from_module,
     )
     from accelerate.utils import (

@@ -54,6 +57,9 @@ except ImportError:
     OffloadedWeightsLoader = None
     PrefixedDataset = None
     set_module_tensor_to_device = None
+    named_module_tensors = None
+    dispatch_model = None
+    attach_align_device_hook = None


 __all__ = [

@@ -70,6 +76,9 @@ __all__ = [
     "disable_offload",
     "align_modules",
     "align_module_device",
+    "register_offload_module",
+    "delete_offload_module",
+    "force_cpu_offload",
 ]


@@ -77,6 +86,11 @@ def check_accelerate(fallback: Any):
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:

+            if fallback == "error":
+                raise ValueError(
+                    "Please install `accelerate` in order to use this function"
+                )
+
             @wraps(func)
             def fallback_fn(*args, **kwargs):
                 return fallback

@@ -346,6 +360,7 @@ def delete_from_weights_map(
         )


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def disable_offload(module: torch.nn.Module):
     """

@@ -362,6 +377,7 @@ def disable_offload(module: torch.nn.Module):
         yield


+@check_accelerate(fallback=contextlib.nullcontext())
 @contextlib.contextmanager
 def align_modules(
     modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],

@@ -383,6 +399,123 @@ def align_modules(
         yield


+def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
+    """
+    Register a submodule with offloading if the parent module is offloaded
+
+    :param base: module to attach submodule to
+    :param name: name of submodule
+    :param module: submodule to attach
+    """
+
+    if has_offloaded_params(base):
+        hook: AlignDevicesHook = base._hf_hook
+        assert hook.offload
+        assert hook.weights_map is not None
+        assert hook.tied_params_map is not None
+
+        # offloading kwargs for submodule
+        place_submodules = False
+        offload_buffers = True
+
+        # copy device offloading arguments from parent
+        current_device = next(base.parameters()).device  # assume base has parameters
+        offload_device = get_offloaded_device(base)
+
+        # offload parameters to weights map
+        for param_name, param in named_module_tensors(
+            module, include_buffers=offload_buffers, recurse=place_submodules
+        ):
+            offloaded = param.to(offload_device)
+            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
+            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
+
+            # if the parent places submodules, offload here
+            if hook.place_submodules:
+                set_module_tensor_to_device(module, param_name, current_device)
+
+        # if the parent does not place submodules, then add a hook
+        # parameters are offloaded by `add_hook_to_module`
+        if not hook.place_submodules:
+            weights_map = PrefixedDataset(
+                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
+            )
+
+            submodule_hook = AlignDevicesHook(
+                execution_device=hook.execution_device,
+                offload=hook.offload,
+                io_same_device=False,
+                weights_map=weights_map,
+                offload_buffers=offload_buffers,
+                place_submodules=place_submodules,
+                skip_keys=None,
+                tied_params_map=hook.tied_params_map,
+            )
+            add_hook_to_module(module, submodule_hook)
+
+    base.register_module(name, module)
+
+    # (1): Since we cannot know which pointers are shared when we add parameters in an
+    # online way, assume that all pointers are shared. This comes at no runtime cost
+
+
+def delete_offload_module(base: torch.nn.Module, name: str):
+    """
+    Delete a submodule from a model which may contain offloading
+    :param base: parent module to delete submodule from
+    :param name: name of submodule on parent
+    """
+    module: torch.nn.Module = getattr(base, name)
+
+    for param_name, _ in list(module.named_parameters()):
+        delete_offload_parameter(module, param_name)
+
+    delattr(base, name)
+
+
+@check_accelerate(fallback="error")
+def force_cpu_offload(
+    module: torch.nn.Module, execution_device: torch.device
+) -> torch.nn.Module:
+    """
+    Force cpu offloading a module, primarily used for testing
+
+    :param module: module containing parameters to offload
+    :param execution_device: execution device submodules
+    :return: module with hooks to perform cpu offloading
+    """
+    # edge case: there is a bug in `dispatch_model` which causes
+    # the function to only work if the model contains submodules
+    if next(module.children(), None) is None:
+        attach_align_device_hook(
+            module,
+            execution_device=execution_device,
+            offload=True,
+            weights_map=module.state_dict(),
+            tied_params_map={},
+        )
+        return module
+
+    device_map = {}
+
+    def collect_device_map(name: List[str], module: torch.nn.Module):
+        if next(module.parameters(recurse=False), None) is not None:
+            device_map[".".join(name)] = "cpu"
+            return
+
+        else:
+            for submodule_name, submodule in module.named_children():
+                name.append(submodule_name)
+                collect_device_map(name, submodule)
+                name.pop()
+
+    collect_device_map([], module)
+
+    return dispatch_model(
+        module, device_map, main_device=execution_device, force_hooks=True
+    )
+
+
 """ Upstreamed Functions """

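
A rough usage sketch of the three offload helpers added above (`force_cpu_offload`, `register_offload_module`, `delete_offload_module`). It assumes `accelerate` is installed (otherwise `force_cpu_offload` raises through the new `check_accelerate(fallback="error")` path); the toy `Sequential` model and the submodule name "extra" are illustrative only.

```python
import torch
from compressed_tensors.utils.offload import (
    delete_offload_module,
    force_cpu_offload,
    register_offload_module,
)

# toy model; force_cpu_offload dispatches it with CPU-backed hooks and is
# primarily intended for testing per its docstring
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
model = force_cpu_offload(model, execution_device=torch.device("cpu"))

# attach a new submodule; if the parent holds offloaded parameters, the new
# parameters are added to its weights map and given a matching hook
register_offload_module(model, "extra", torch.nn.Linear(8, 8))

# remove it again, deleting any offloaded parameters it registered
delete_offload_module(model, "extra")
```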
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603/src/compressed_tensors.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250530
+Version: 0.9.5a20250603
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_fp8_quant.py

@@ -61,8 +61,8 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
         [
             QuantizationStrategy.GROUP,
             128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8), dtype=torch.int8),
         ],
         [
             QuantizationStrategy.CHANNEL,

@@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp):
         "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.float32),
     }
     if group_size is not None:
-        dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(
+        dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(1024, group_size)

     quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)

{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_compressors/quantized_compressors/test_int_quant.py

@@ -53,8 +53,8 @@ def get_dummy_quant_config(strategy, group_size=None, symmetric=True):
            QuantizationStrategy.GROUP,
            True,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8), dtype=torch.int8),
        ],
        [
            QuantizationStrategy.CHANNEL,
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/lifecycle/test_forward.py

@@ -108,8 +108,8 @@ def test_forward_quantize(
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (

@@ -117,8 +117,8 @@ def test_forward_quantize(
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
        (

@@ -135,8 +135,8 @@ def test_forward_quantize(
            "float",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (

@@ -144,8 +144,8 @@ def test_forward_quantize(
            "float",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
    ],

@@ -174,8 +174,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            None,
        ),
        (

@@ -183,8 +183,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
            "int",
            QuantizationStrategy.GROUP,
            128,
-            torch.rand((512, 8
-            torch.zeros((512, 8
+            torch.rand((512, 8)) * 0.01,
+            torch.zeros((512, 8)),
            make_dummy_g_idx(1024, 128),
        ),
    ],
{compressed_tensors-0.9.5a20250530 → compressed_tensors-0.9.5a20250603}/tests/test_quantization/test_utils/test_helpers.py

@@ -20,10 +20,7 @@ from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import (
-    calculate_qparams,
-    generate_global_scale,
-)
+from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam


 @pytest.mark.parametrize(

@@ -70,7 +67,8 @@ def test_fused_global_scales():
     layer = torch.nn.Linear(7, 8)
     max_tensor_value = torch.abs(layer.weight.data).max()
     # use defaults
-
+    min_val, max_val = torch.aminmax(layer.weight)
+    global_scale = generate_gparam(min_val.data, max_val.data)
     # max value should be = (448 * 6) / global_scale
     assert max_tensor_value == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001