compressed-tensors 0.9.5a20250520__tar.gz → 0.9.5a20250528__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed_tensors-0.9.5a20250520/src/compressed_tensors.egg-info → compressed_tensors-0.9.5a20250528}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/forward.py +16 -3
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/initialize.py +47 -36
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/quant_args.py +47 -10
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/quant_config.py +2 -2
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/quant_scheme.py +23 -1
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/utils/helpers.py +31 -6
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/registry/registry.py +12 -11
- compressed_tensors-0.9.5a20250528/src/compressed_tensors/transform/__init__.py +20 -0
- compressed_tensors-0.9.5a20250528/src/compressed_tensors/transform/transform_args.py +54 -0
- compressed_tensors-0.9.5a20250528/src/compressed_tensors/transform/transform_config.py +73 -0
- compressed_tensors-0.9.5a20250528/src/compressed_tensors/transform/transform_scheme.py +43 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/version.py +1 -1
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528/src/compressed_tensors.egg-info}/PKG-INFO +1 -1
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors.egg-info/SOURCES.txt +7 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_initialize.py +33 -9
- compressed_tensors-0.9.5a20250528/tests/test_transform/test_transform_args.py +55 -0
- compressed_tensors-0.9.5a20250528/tests/test_transform/test_transform_config.py +71 -0
- compressed_tensors-0.9.5a20250528/tests/test_transform/test_transform_scheme.py +74 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/.gitkeep +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/actions/test/action.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/scripts/step-status +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/build-test.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/build.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/report.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/test-check.yaml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/test.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/trigger-all.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.github/workflows/upload.yml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/.gitignore +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/LICENSE +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/Makefile +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/README.md +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/bit_packing/ex_quantize_and_pack.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/bit_packing/int4_config.json +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/bitmask_compression.ipynb +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/llama_1.1b/ex_config_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/llama_1.1b/ex_llmcompressor_quantization.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/llama_1.1b/example_quant_config.json +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/llama_1.1b/example_quant_recipe.yaml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/examples/quantize_and_pack_int4.ipynb +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/pyproject.toml +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/setup.cfg +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/setup.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/README.md +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/config/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/config/base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/config/dense.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/config/sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/linear/compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/offload.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/src/compressed_tensors.egg-info/top_level.txt +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/model_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/model_compressors/test_model_compressor.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/quantized_compressors/test_fp8_quant.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/quantized_compressors/test_int_quant.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/quantized_compressors/test_nvfp4_quant.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/quantized_compressors/test_pack_quant.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/sparse_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/sparse_compressors/test_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_configs/test_base.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_examples/test_bitmask_compression_ipynb.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_linear/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_linear/test_compressed_linear.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/conftest.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_apply.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_enabled.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_forward.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/lifecycle/test_lifecycle.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_configs/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_configs/test_bit_depths.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_configs/test_strategies.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_quant_args.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_quant_config.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_quant_scheme.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_quantization/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_registry.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_utils/__init__.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_utils/test_helpers.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_utils/test_offload.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/test_utils/test_safetensors_load.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/tests/testing_utils.py +0 -0
- {compressed_tensors-0.9.5a20250520 → compressed_tensors-0.9.5a20250528}/utils/copyright.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250520
+Version: 0.9.5a20250528
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
src/compressed_tensors/quantization/lifecycle/forward.py

@@ -18,6 +18,7 @@ from typing import Optional
 
 import torch
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -189,7 +190,11 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy == QuantizationStrategy.GROUP:
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        n_dims = x.shape
+        if len(n_dims) > 2:
+            x = x.squeeze(0)
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[1]
@@ -251,6 +256,9 @@ def _process_quantization(
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
+        if len(n_dims) > 2:
+            output = output.unsqueeze(0)
+
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
@@ -352,9 +360,11 @@ def forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
     global_scale = getattr(module, f"{base_name}_global_scale", None)
 
-    if args.dynamic:
+    if args.dynamic in (True, DynamicType.LOCAL):
         # dynamic quantization - determine the scale/zp on the fly
-        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
+        scale, zero_point = compute_dynamic_scales_and_zp(
+            value=value, args=args, module=module, global_scale=global_scale
+        )
     else:
         # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
@@ -388,6 +398,7 @@ def _quantize(
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
+
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
@@ -398,6 +409,7 @@ def _quantize(
         q_max,
     )
     quantized_value = round_to_quantized_type(clamped_value, args)
+
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
 
@@ -422,6 +434,7 @@ def _dequantize(
 
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
+
     dequant_value = dequant_value * scale
 
     if dtype is not None:
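For context, the GROUP/TENSOR_GROUP branch above squeezes a rank-3 activation down to 2-D before the column-group loop and restores the leading dimension afterwards. The following is a minimal standalone sketch of that shape handling only, with made-up sizes; it is not the library code:

    import math
    import torch

    # Illustrative squeeze -> group-quantize -> unsqueeze flow (not compressed-tensors code).
    group_size = 16
    x = torch.randn(1, 8, 64)                      # (batch=1, seq, hidden)

    n_dims = x.shape
    if len(n_dims) > 2:
        x = x.squeeze(0)                           # (seq, hidden)

    num_groups = math.ceil(x.shape[1] / group_size)
    grouped = x.reshape(x.shape[0], num_groups, group_size)
    scales = grouped.abs().amax(dim=-1).clamp(min=1e-8) / 7.0   # e.g. symmetric int4 range
    q = torch.clamp(torch.round(grouped / scales.unsqueeze(-1)), -8, 7)
    output = (q * scales.unsqueeze(-1)).reshape(x.shape)

    if len(n_dims) > 2:
        output = output.unsqueeze(0)               # back to (1, seq, hidden)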
src/compressed_tensors/quantization/lifecycle/initialize.py

@@ -156,13 +156,33 @@ def _initialize_scale_zero_point(
     force_zero_point: bool = True,
     scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic:
+    if quantization_args.dynamic is True:
         return
 
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # infer expected scale/zero point shape
+    # 1. Create global_scales for tensor_group
+    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        # TODO: should move to llmcompressor
+        if base_name == "weight":
+            # When applying weight-only FP4 quantization, generate a global_scale
+            # This scale is applied during runtime to ensure that the generated
+            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
+            # which is the expected dtype of NVFP4A16 scales
+            value = generate_global_scale(input_tensor=module.weight)
+            value = value.to(device)
+            init_global_scale = Parameter(value, requires_grad=False)
+        else:
+            init_global_scale = Parameter(
+                torch.empty(1, dtype=torch.float32, device=device),
+                requires_grad=False,
+            )
+        register_offload_parameter(
+            module, f"{base_name}_global_scale", init_global_scale
+        )
+
+    # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
     else:
@@ -172,47 +192,35 @@ def _initialize_scale_zero_point(
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+        elif quantization_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        ):
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
 
+    # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
-    # TODO: consider erroring out in the future as if the dtype if not one fo these,
-    # there is likely bug
-
-    if is_fp4(quantization_args=quantization_args) and base_name == "weight":
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # When applying weight-only FP4 quantization, generate a global_scale
-        # This scale is applied during runtime to ensure that the generated
-        # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-        # which is the expected dtype of NVFP4A16 scales
-        value = generate_global_scale(input_tensor=module.weight)
-        value = value.to(device)
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
-
-    if scale_dtype not in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ] and not is_fp4(quantization_args=quantization_args):
-        scale_dtype = torch.float16
 
-    # initializes empty scale, zero point, and g_idx parameters for the module
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    if is_fp4(quantization_args=quantization_args):
+        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+    else:
+        # TODO: consider erroring out in the future as if the dtype if not one of these,
+        # there is likely bug
+        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+            scale_dtype = torch.float16
+        zp_dtype = quantization_args.pytorch_dtype()
+
+    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
+    # do not init scales for quantzation_args.dynamic == DynamicType.local
+    if not quantization_args.dynamic:
+        init_scale = Parameter(
+            torch.empty(expected_shape, dtype=scale_dtype, device=device),
+            requires_grad=False,
+        )
+        register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
-        if is_fp4(quantization_args=quantization_args):
-            zp_dtype = FP8_E4M3_DATA.dtype
-        else:
-            zp_dtype = quantization_args.pytorch_dtype()
-
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
@@ -304,6 +312,9 @@ def update_fused_layer_weight_global_scales(model: torch.nn.Module):
     ):
 
         if _is_attention_module(submodule):
+            # already fused/treated as one layer
+            if hasattr(submodule, "qkv_proj"):
+                continue
 
             if not _valid_fp4_quant(
                 [submodule.q_proj, submodule.v_proj, submodule.k_proj]
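To make the shapes concrete: for the group and tensor_group strategies the scale/zero-point parameters are per output row and per column group, while the tensor_group global scale is a single FP32 value. A small illustrative calculation with hypothetical layer sizes (not taken from the diff):

    import math
    import torch

    # Hypothetical Linear layer sizes, used only to illustrate the shapes above.
    out_features, in_features, group_size = 4096, 11008, 16

    num_groups = math.ceil(in_features / group_size)             # 688
    scale_shape = (out_features, max(num_groups, 1))             # (4096, 688)

    weight_scale = torch.empty(scale_shape, dtype=torch.float16)
    weight_global_scale = torch.empty(1, dtype=torch.float32)    # tensor_group only
    print(weight_scale.shape, weight_global_scale.shape)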
src/compressed_tensors/quantization/quant_args.py

@@ -32,6 +32,7 @@ __all__ = [
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
+    "DynamicType",
 ]
 
 
@@ -98,6 +99,22 @@ class QuantizationStrategy(str, Enum):
     GROUP = "group"
     BLOCK = "block"
     TOKEN = "token"
+    TENSOR_GROUP = "tensor_group"
+
+
+class DynamicType(str, Enum):
+    """
+    Enum storing potential dynamic types.
+
+    1. If dynamic is True, all quantization parameters are generated on the fly.
+    2. If dynamic is False, all quantization parameters generated are static.
+    3. If "local" is provided, only local quantization parameters are dynamic.
+
+    Note: "local" is only currently supported for NVFP4.
+
+    """
+
+    LOCAL = "local"
 
 
 class ActivationOrdering(Aliasable, str, Enum):
@@ -152,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
-    dynamic: bool = False
+    dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
         default=None,
@@ -206,6 +223,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
         return value
 
+    @field_validator("dynamic", mode="before")
+    def validate_dynamic(cls, value) -> Union[DynamicType, bool]:
+        if isinstance(value, str):
+            return DynamicType(value.lower())
+        return value
+
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         # extract user-passed values from dictionary
@@ -239,7 +262,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         if (
             group_size is not None
             and group_size > 0
-            and strategy != QuantizationStrategy.GROUP
+            and strategy
+            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
         ):
             raise ValueError("group_size requires strategy to be set to 'group'")
 
@@ -255,18 +279,31 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             if strategy not in (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
+                QuantizationStrategy.TENSOR_GROUP,
             ):
                 raise ValueError(
-                    f"One of {QuantizationStrategy.TOKEN} or "
-                    f"{QuantizationStrategy.TENSOR} must be used for dynamic "
-                    "quantization",
+                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
+                    "must be used for dynamic quantization",
                 )
+
+            if (
+                dynamic == DynamicType.LOCAL
+                and strategy != QuantizationStrategy.TENSOR_GROUP
+            ):
+                raise ValueError("local is only supported for strategy tensor_group")
+
             if observer is not None:
-                if observer != "memoryless":  # avoid annoying users with old configs
-                    warnings.warn(
-                        "No observer is used for dynamic quantization, setting to None"
-                    )
-                observer = None
+                if dynamic is True:  # checking if dynamic is True, not "local"
+                    if (
+                        observer != "memoryless"
+                    ):  # avoid annoying users with old configs
+                        warnings.warn(
+                            "No observer is used for dynamic quantization, setting to None"
+                        )
+                    observer = None
+                else:
+                    if dynamic == DynamicType.LOCAL:
+                        observer = "minmax"
 
         elif observer is None:
             # default to minmax for non-dynamic cases
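Taken together, `dynamic` now accepts a bool or the string/enum "local"; the new `validate_dynamic` pre-validator lower-cases strings and converts them to `DynamicType`. A usage sketch that mirrors the NVFP4 activation settings introduced in this release (values copied from the preset shown below in quant_scheme.py):

    from compressed_tensors.quantization.quant_args import (
        DynamicType,
        QuantizationArgs,
        QuantizationStrategy,
        QuantizationType,
    )

    # "local"/"LOCAL" is coerced to DynamicType.LOCAL by validate_dynamic; the model
    # validator then requires the tensor_group strategy for local dynamic quantization.
    args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.TENSOR_GROUP,
        symmetric=True,
        dynamic="local",
        group_size=16,
    )
    assert args.dynamic == DynamicType.LOCAL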
src/compressed_tensors/quantization/quant_config.py

@@ -16,7 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -251,7 +251,7 @@ class QuantizationConfig(BaseModel):
 
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
-                if not scheme.input_activations.dynamic:
+                if scheme.input_activations.dynamic in (False, DynamicType.LOCAL):
                     return True
             if scheme.output_activations is not None:
                 if not scheme.output_activations.dynamic:
src/compressed_tensors/quantization/quant_scheme.py

@@ -16,6 +16,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -104,13 +105,33 @@ NVFP4A16 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
-        strategy=QuantizationStrategy.GROUP,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
         dynamic=False,
         group_size=16,
     )
 )
 
+
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=DynamicType.LOCAL,
+        group_size=16,
+    ),
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -237,4 +258,5 @@ PRESET_SCHEMES = {
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
     "NVFP4A16": NVFP4A16,
+    "NVFP4": NVFP4,
 }
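The new `NVFP4` preset extends `NVFP4A16` with FP4 input activations whose group scales are computed on the fly (`dynamic=DynamicType.LOCAL`) against a calibrated global scale. A quick way to inspect it, assuming the module path shown above:

    from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

    nvfp4 = PRESET_SCHEMES["NVFP4"]
    # Weights are statically quantized per tensor_group; activations are "local" dynamic.
    print(nvfp4["weights"].strategy, nvfp4["weights"].dynamic)
    print(nvfp4["input_activations"].strategy, nvfp4["input_activations"].dynamic)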
src/compressed_tensors/quantization/utils/helpers.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 from typing import Generator, List, Optional, Tuple
 
 import torch
@@ -103,7 +104,9 @@ def calculate_qparams(
     if is_fp4(quantization_args=quantization_args) and global_scale is not None:
         # Conditionally scale the generated local scale by a global_scale
         scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+        scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
         scales = scales.to(FP8_E4M3_DATA.dtype)
+
     else:
         scales = max_val_pos / (float(bit_range) / 2)
 
@@ -143,7 +146,12 @@ def calculate_qparams(
     return scales, zero_points
 
 
-def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+def compute_dynamic_scales_and_zp(
+    value: Tensor,
+    args: QuantizationArgs,
+    module: torch.nn.Module,
+    global_scale: Optional[Tensor] = None,
+):
     """
     Returns the computed scales and zero points for dynamic activation
     quantization.
@@ -155,24 +163,41 @@ def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
         reduced dimensions
     :return: tuple of scale and zero point derived from the observed tensor
     """
+
+    keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
         dim = {1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
+    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        if len(value.shape) > 2:
+            value = value.squeeze(0)
+
+        dim = {0, 1}
+        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        keep_dims = False
+        value = torch.reshape(
+            value,
+            (
+                value.shape[0],
+                math.ceil(value.shape[1] / args.group_size),
+                args.group_size,
+            ),
+        )
     else:
         raise ValueError(
-            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} "
-            "must be used for dynamic quantization",
+            "Dynamic quantization is only supported for ",
+            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
         )
 
     if not reduce_dims:
         min_val, max_val = torch.aminmax(value)
     else:
-        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
-        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
 
-    return calculate_qparams(min_val, max_val, args)
+    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
 
 
 def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
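In the TENSOR_GROUP branch above, the activation is viewed as (tokens, num_groups, group_size) and min/max are reduced over the last axis without keeping dims, so the resulting statistics have shape (tokens, num_groups). A standalone sketch of just that reduction, with made-up sizes (not the library code):

    import math
    import torch

    group_size = 16
    value = torch.randn(1, 8, 64).squeeze(0)              # (tokens=8, hidden=64)

    grouped = torch.reshape(
        value,
        (value.shape[0], math.ceil(value.shape[1] / group_size), group_size),
    )
    min_val = torch.amin(grouped, dim=-1, keepdim=False)   # (8, 4)
    max_val = torch.amax(grouped, dim=-1, keepdim=False)   # (8, 4)
    print(min_val.shape, max_val.shape)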
src/compressed_tensors/registry/registry.py

@@ -19,7 +19,7 @@ of neuralmagic utilities
 
 import importlib
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, TypeVar, Union
 
 
 __all__ = [
@@ -32,8 +32,9 @@ __all__ = [
 ]
 
 
-_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
-_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
+_ALIAS_REGISTRY: Dict[type, Dict[str, str]] = defaultdict(dict)
+_REGISTRY: Dict[type, Dict[str, Any]] = defaultdict(dict)
+T = TypeVar("T", bound="RegistryMixin")
 
 
 def standardize_lookup_name(name: str) -> str:
@@ -159,7 +160,7 @@ class RegistryMixin:
         )
 
     @classmethod
-    def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
+    def load_from_registry(cls: type[T], name: str, **constructor_kwargs) -> T:
         """
         :param name: name of registered class to load
         :param constructor_kwargs: arguments to pass to the constructor retrieved
@@ -172,7 +173,7 @@ class RegistryMixin:
         return constructor(**constructor_kwargs)
 
     @classmethod
-    def get_value_from_registry(cls, name: str):
+    def get_value_from_registry(cls: type[T], name: str) -> T:
         """
         :param name: name to retrieve from the registry
         :return: value from retrieved the registry for the given name, raises
@@ -200,7 +201,7 @@ class RegistryMixin:
 
 
 def register(
-    parent_class: Type,
+    parent_class: type,
     value: Any,
     name: Optional[str] = None,
     alias: Union[List[str], str, None] = None,
@@ -240,7 +241,7 @@ def register(
 
 
 def get_from_registry(
-    parent_class: Type, name: str, require_subclass: bool = False
+    parent_class: type, name: str, require_subclass: bool = False
 ) -> Any:
     """
     :param parent_class: class that the name is registered under
@@ -276,7 +277,7 @@ def get_from_registry(
     return retrieved_value
 
 
-def registered_names(parent_class: Type) -> List[str]:
+def registered_names(parent_class: type) -> List[str]:
     """
     :param parent_class: class to look up the registry of
     :return: all names registered to the given class
@@ -284,7 +285,7 @@ def registered_names(parent_class: Type) -> List[str]:
     return list(_REGISTRY[parent_class].keys())
 
 
-def registered_aliases(parent_class: Type) -> List[str]:
+def registered_aliases(parent_class: type) -> List[str]:
     """
     :param parent_class: class to look up the registry of
     :return: all aliases registered to the given class
@@ -297,7 +298,7 @@ def registered_aliases(parent_class: Type) -> List[str]:
 
 
 def register_alias(
-    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
+    name: str, parent_class: type, alias: Union[str, List[str], None] = None
 ):
     """
     Updates the mapping from the alias(es) to the given name.
@@ -352,7 +353,7 @@ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
     return value
 
 
-def _validate_subclass(parent_class: Type, child_class: Type):
+def _validate_subclass(parent_class: type, child_class: type):
     if not issubclass(child_class, parent_class):
         raise ValueError(
             f"class {child_class} is not a subclass of the class it is "
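With `T = TypeVar("T", bound="RegistryMixin")`, type checkers can now infer the concrete class returned by `load_from_registry` instead of a bare object. A hedged sketch of what that buys (the decorator-style `register` classmethod is assumed from how the library's compressors register themselves; the class names here are made up):

    from compressed_tensors.registry import RegistryMixin


    class Compressor(RegistryMixin):
        """Hypothetical registry base used only for this sketch."""


    @Compressor.register(name="noop")
    class NoopCompressor(Compressor):
        def __init__(self, level: int = 0):
            self.level = level


    # With load_from_registry(cls: type[T], ...) -> T, type checkers see `compressor`
    # as a Compressor rather than an untyped object.
    compressor = Compressor.load_from_registry("noop", level=3)
    assert isinstance(compressor, NoopCompressor)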
src/compressed_tensors/transform/__init__.py (new file)

@@ -0,0 +1,20 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .transform_args import *
+from .transform_scheme import *
+from .transform_config import *
src/compressed_tensors/transform/transform_args.py (new file)

@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Any, List
+
+from pydantic import BaseModel, Field, field_validator
+
+
+__all__ = ["TransformArgs"]
+
+
+class TransformLocation(str, Enum):
+    INPUT = "input"
+    WEIGHT_INPUT = "weight_input"
+    WEIGHT_OUTPUT = "weight_output"
+    OUTPUT = "output"
+    K_CACHE = "k_cache"
+    Q_ATTN = "q_attn"
+
+
+class TransformArgs(BaseModel):
+    """
+    Arguments which define *how* and where a transform should be applied to a model
+
+    :param targets: list of modules to apply transforms to
+    :param location: where to apply transform on module, one of (`input`, `weight`,
+        `output`, `k_cache`, `q_attn`)
+    :param inverse: whether or not to apply the inverse of a transform
+    :param ignore: any modules which should be ignored from the targets list
+    """
+
+    targets: List[str]
+    location: TransformLocation
+    inverse: bool = Field(default=False)
+    ignore: List[str] = Field(default_factory=list)
+
+    @field_validator("targets", "ignore", mode="before")
+    @classmethod
+    def wrap_singleton(cls, value):
+        if isinstance(value, str):
+            return [value]
+        return value
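TransformArgs is a small pydantic model: `wrap_singleton` lets a single string be passed for `targets` or `ignore`, and `location` is validated against `TransformLocation`. A minimal usage sketch:

    from compressed_tensors.transform import TransformArgs

    # A bare string target is wrapped into a list by wrap_singleton, and the
    # location string must match one of the TransformLocation values.
    args = TransformArgs(targets="Linear", location="weight_input", inverse=True)
    assert args.targets == ["Linear"]
    assert args.ignore == []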
src/compressed_tensors/transform/transform_config.py (new file)

@@ -0,0 +1,73 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from pydantic import BaseModel
+
+
+__all__ = ["TransformConfig"]
+
+
+class TransformConfig(BaseModel):
+    """
+    Configuration of transforms to be applied to a model. This config is to be
+    serialized within a model's `config.json` file
+
+    :param config_groups: A dictionary of `TransformSchemes` that should be applied
+        to a particular model. The keys can be any arbitrary string
+    """
+
+    config_groups: Dict[str, TransformScheme]
+
+
+# quip / quip sharp
+QUIP = TransformConfig(
+    config_groups={
+        "v": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="input",  # non-mergable
+                ),
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+            randomize_modules=True,
+        ),
+        "u": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=["Linear"], location="output", inverse=True  # non-mergable
+                ),
+            ],
+            randomize_modules=True,
+        ),
+    }
+)
+
+
+PRESET_CONFIGS = {
+    "QUIP": QUIP,
+}
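The bundled QUIP preset pairs each Hadamard transform with its inverse around Linear weights, and because TransformConfig is a pydantic model it can be dumped straight into a model's config.json. A small inspection sketch (PRESET_CONFIGS is imported from the module directly, since only TransformConfig is re-exported by the package __init__):

    from compressed_tensors.transform.transform_config import PRESET_CONFIGS

    quip = PRESET_CONFIGS["QUIP"]
    print(sorted(quip.config_groups))                    # ['u', 'v']
    print(quip.config_groups["v"].apply[0].location)     # input-side transform
    print(quip.model_dump_json(indent=2)[:200])          # serializable for config.json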