compressed-tensors 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- compressed_tensors/config/base.py +60 -2
- compressed_tensors/linear/compressed_linear.py +3 -1
- compressed_tensors/quantization/__init__.py +0 -1
- compressed_tensors/quantization/lifecycle/__init__.py +0 -2
- compressed_tensors/quantization/lifecycle/apply.py +3 -17
- compressed_tensors/quantization/lifecycle/forward.py +24 -87
- compressed_tensors/quantization/lifecycle/initialize.py +21 -24
- compressed_tensors/quantization/quant_args.py +27 -25
- compressed_tensors/quantization/quant_config.py +2 -2
- compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed_tensors/quantization/utils/helpers.py +125 -8
- compressed_tensors/registry/registry.py +1 -1
- compressed_tensors/utils/helpers.py +33 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/METADATA +1 -1
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD +23 -31
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/WHEEL +1 -1
- compressed_tensors/quantization/cache.py +0 -201
- compressed_tensors/quantization/lifecycle/calibration.py +0 -70
- compressed_tensors/quantization/lifecycle/frozen.py +0 -55
- compressed_tensors/quantization/observers/__init__.py +0 -21
- compressed_tensors/quantization/observers/base.py +0 -213
- compressed_tensors/quantization/observers/helpers.py +0 -149
- compressed_tensors/quantization/observers/min_max.py +0 -104
- compressed_tensors/quantization/observers/mse.py +0 -162
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_scheme.py CHANGED
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 
 __all__ = [
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
     of modules should be quantized
 
     :param targets: list of modules to apply the QuantizationArgs to, can be layer
-        names, layer types or a regular expression
+        names, layer types or a regular expression, typically ["Linear"]
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
@@ -47,27 +47,20 @@
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
 
-    @classmethod
-    def ...
-        ...
-        return cls(
-            targets=targets,
-            weights=weights,
-            input_activations=input_activations,
-            output_activations=output_activations,
-        )
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        inputs = model.input_activations
+        outputs = model.output_activations
+
+        if inputs is not None:
+            if inputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to input activations")
+
+        if outputs is not None:
+            if outputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to output activations")
+
+        return model
 
 
 """
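For context, a minimal sketch (not part of the diff) of how the new validator surfaces to callers; the field values passed to QuantizationArgs are illustrative, and pydantic reports the ValueError raised inside validate_model_after as a ValidationError:

```python
from pydantic import ValidationError

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# activation ordering on weights is still allowed
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4, strategy="group", group_size=128, actorder="group"
    ),
)

# activation ordering on input (or output) activations is now rejected
try:
    QuantizationScheme(
        targets=["Linear"],
        input_activations=QuantizationArgs(
            num_bits=8, strategy="group", group_size=128, actorder="group"
        ),
    )
except ValidationError as err:
    print(err)  # wraps "Cannot apply actorder to input activations"
```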
compressed_tensors/quantization/utils/helpers.py CHANGED
@@ -16,9 +16,14 @@ import logging
 from typing import Generator, List, Optional, Tuple
 
 import torch
-from compressed_tensors.quantization. ...
-...
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -36,6 +41,9 @@ __all__ = [
     "is_kv_cache_quant_scheme",
     "iter_named_leaf_modules",
     "iter_named_quantizable_modules",
+    "compute_dynamic_scales_and_zp",
+    "calculate_range",
+    "calculate_qparams",
 ]
 
 # target the self_attn layer
@@ -45,6 +53,105 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+def calculate_qparams(
+    min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+) -> Tuple[FloatTensor, IntTensor]:
+    """
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
+        from
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
+        from
+    :param quantization_args: settings to quantization
+    :return: tuple of the calculated scale(s) and zero point(s)
+    """
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
+
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
+    if quantization_args.symmetric:
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+        scales = max_val_pos / (float(bit_range) / 2)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
+    else:
+        scales = (max_vals - min_vals) / float(bit_range)
+        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
+
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
+    return scales, zero_points
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    qunatization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :param reduce_dims: optional tuple of dimensions to reduce along,
+        returned scale and zero point will be shaped (1,) along the
+        reduced dimensions
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
+            "must be used for dynamic quantization",
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
+
+
 def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
     """
     Checks the quantization status of a model. Assumes all modules in the model have
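A rough usage sketch for the new static-range helpers (not part of the diff, and assuming they are re-exported from compressed_tensors.quantization.utils); the round trip at the end only shows how the scale, zero point, and range fit together:

```python
import torch

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_qparams, calculate_range

# illustrative int8, symmetric, per-tensor settings
args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")

weights = torch.randn(4096, 4096)
scale, zero_point = calculate_qparams(weights.amin(), weights.amax(), args)
q_min, q_max = calculate_range(args, device="cpu")

# fake-quantize round trip with the derived parameters
q = torch.clamp(torch.round(weights / scale + zero_point), q_min, q_max)
dq = (q - zero_point) * scale
```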
@@ -118,12 +225,17 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None
     """
     for name, submodule in model.named_modules():
         children = list(submodule.children())
-        if ...
+        # TODO: verify if an observer would ever be attached in this case/remove check
+        if len(children) == 0 and "observer" in name:
             yield name, submodule
         else:
+            if len(children) > 0:
+                named_children, children = zip(*list(submodule.named_children()))
             has_non_observer_children = False
-            for ...
-                ...
+            for i in range(len(children)):
+                child_name = named_children[i]
+
+                if "observer" not in child_name:
                     has_non_observer_children = True
 
             if not has_non_observer_children:
@@ -144,14 +256,19 @@ def iter_named_quantizable_modules(
     :returns: generator tuple of (name, submodule)
     """
     for name, submodule in model.named_modules():
+        # TODO: verify if an observer would ever be attached in this case/remove check
         if include_children:
             children = list(submodule.children())
-            if len(children) == 0 and not ...
+            if len(children) == 0 and "observer" not in name:
                 yield name, submodule
             else:
+                if len(children) > 0:
+                    named_children, children = zip(*list(submodule.named_children()))
                 has_non_observer_children = False
-                for ...
-                    ...
+                for i in range(len(children)):
+                    child_name = named_children[i]
+
+                    if "observer" not in child_name:
                         has_non_observer_children = True
 
                 if not has_non_observer_children:
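A small illustration (not from the diff) of how the reworked leaf iterator behaves on a plain module tree with no observer submodules, assuming it is re-exported from compressed_tensors.quantization.utils:

```python
import torch.nn as nn

from compressed_tensors.quantization.utils import iter_named_leaf_modules

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

for name, module in iter_named_leaf_modules(model):
    print(name, type(module).__name__)
# 0 Linear
# 1 ReLU
# 2 Linear
```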
compressed_tensors/registry/registry.py CHANGED
@@ -258,7 +258,7 @@ def get_from_registry(
         retrieved_value = _import_and_get_value_from_module(module_path, value_name)
     else:
         # look up name in alias registry
-        name = _ALIAS_REGISTRY[parent_class].get(name)
+        name = _ALIAS_REGISTRY[parent_class].get(name, name)
         # look up name in registry
         retrieved_value = _REGISTRY[parent_class].get(name)
         if retrieved_value is None:
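The one-character change matters because dict.get(name) returns None when a name has no alias, wiping out the key used for the registry lookup that follows, while dict.get(name, name) leaves unaliased names untouched. A toy illustration with a hypothetical alias table:

```python
aliases = {"int8": "naive-quantized"}    # hypothetical alias table

name = "pack-quantized"                  # registered under its own name, no alias
resolved_old = aliases.get(name)         # None -> downstream registry lookup fails
resolved_new = aliases.get(name, name)   # falls back to the original name

assert resolved_old is None
assert resolved_new == "pack-quantized"
```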
compressed_tensors/utils/helpers.py CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 import torch
 from transformers import AutoConfig
@@ -24,6 +24,7 @@ __all__ = [
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "Aliasable",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -119,3 +120,34 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
         return isinstance(compression_config, CompressedTensorsConfig)
     except ImportError:
         return False
+
+
+class Aliasable:
+    """
+    A mixin for enums to allow aliasing of enum members
+
+    Example:
+        >>> class MyClass(Aliasable, int, Enum):
+        >>>     ...
+    """
+
+    @staticmethod
+    def get_aliases() -> Dict[str, str]:
+        raise NotImplementedError()
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            aliases = self.get_aliases()
+            return self.value == other.value or (
+                aliases.get(self.value, self.value)
+                == aliases.get(other.value, other.value)
+            )
+        else:
+            aliases = self.get_aliases()
+            self_value = aliases.get(self.value, self.value)
+            other_value = aliases.get(other, other)
+            return self_value == other_value
+
+    def __hash__(self):
+        canonical_value = self.aliases.get(self.value, self.value)
+        return hash(canonical_value)
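A hedged sketch of how the Aliasable mixin is intended to be used (the Precision enum and its alias table are made up for illustration); equality resolves both sides through get_aliases() before comparing:

```python
from enum import Enum

from compressed_tensors.utils import Aliasable


class Precision(Aliasable, str, Enum):
    # hypothetical enum, only to demonstrate the mixin
    GROUP = "group"
    CHANNEL = "channel"

    @staticmethod
    def get_aliases():
        return {"per_group": "group"}  # alias -> canonical value


assert Precision.GROUP == "group"      # plain value comparison
assert Precision.GROUP == "per_group"  # alias resolved via get_aliases()
assert Precision.CHANNEL != "per_group"
```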
compressed_tensors/version.py CHANGED
{compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD CHANGED
@@ -1,58 +1,50 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=...
+compressed_tensors/version.py,sha256=U6bppqc5inOxvcJDHWhDoSXvBrvbH425oJM2WG7TECY,1585
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=D9TNwQcjanDiAHODPbg8JUqc66e3j50rctY7A708NEs,6743
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=...
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=sxh1TvW1Bp9YJE41hW0XZfd0kYYB85nhJvBLVRTDcV0,15886
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=K1KOnS6Y8nUA1-HN7VhyfsDc01nilW0WfXMUhuD-l8w,5954
-compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=...
-compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=...
+compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=MMUya3Iwarm0BkeYXqKTUnEDPiBw98GKF09QiNST45k,4960
+compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=1CLwvBlu4AtGkuo3IisD1-rQzwLiA6hE1bCc-pF_XGo,7758
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=i2TESH27l7KXeOhJ6hShIoI904XX96l-cRQiMR6MAaU,704
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=Ua4rUSGyucEs-YJI5z3oIUF-zqQLrFsQ9f-qKasEdUM,4410
 compressed_tensors/compressors/sparse_compressors/dense.py,sha256=lSKNWRx6H7aUqaJj1j4qbXk8Gkm1UohbnvW1Rvq6Ra4,1284
 compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=4fKwCG7ZM8mUtSnjPvubzEHl-mTnxMzwjmcs7L43WLY,6622
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
-compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=...
+compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=BMIQWTLlnUvxy14iEJegtiP75WHJeOVojey9mKOK1hE,9427
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
-compressed_tensors/config/base.py,sha256=...
+compressed_tensors/config/base.py,sha256=3bFAdwDZjOt-U3fneOeL8dRci-PS8DqstnXuQVtkfiQ,3435
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/linear/compressed_linear.py,sha256=...
-compressed_tensors/quantization/__init__.py,sha256=...
-compressed_tensors/quantization/...
-compressed_tensors/quantization/...
-compressed_tensors/quantization/...
-compressed_tensors/quantization/...
-compressed_tensors/quantization/lifecycle/...
-compressed_tensors/quantization/lifecycle/apply.py,sha256=czaayvpeUYyWRJhO_klffw6esptOgA9sBKL5TWQcRdw,15805
-compressed_tensors/quantization/lifecycle/calibration.py,sha256=IuLeRkVQPrMxkMcIjr4OMFlIUMHkqjH4qAxC2KiUBGw,2673
+compressed_tensors/linear/compressed_linear.py,sha256=MJa-UfoKhIkdUWRD1shrXXri2cOwR5GK0a4t4bNYosM,3268
+compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
+compressed_tensors/quantization/quant_args.py,sha256=jwC__lSmuiJ2qSJYYZGgWgQNbZu6YhhS0e-qugrTNXE,9058
+compressed_tensors/quantization/quant_config.py,sha256=K6kOZ6LDXpFlqsVzR4NEATV6y6Ea83rJWnNyVlvw-pI,10379
+compressed_tensors/quantization/quant_scheme.py,sha256=eQ0JrRZ80GX69fpwW87VzPzzhajhk4mUaJScjk82OY4,6010
+compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
+compressed_tensors/quantization/lifecycle/apply.py,sha256=jCUSgeOBtagE5IhgIbyYMZ4kv8Rm20VGJ4IxXZ5HAnw,15066
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=...
-compressed_tensors/quantization/lifecycle/frozen.py,sha256=NiJw7NP7pcT6idWFa8vksgiLoT8oQ975e57S4QfD2QQ,1874
+compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=...
-compressed_tensors/quantization/observers/__init__.py,sha256=DYrttzq-8MHLZUzpX-xzzm4hrw6HcXkMkux82KBKb1M,738
-compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
-compressed_tensors/quantization/observers/helpers.py,sha256=o9hg4E9b5cCb5PaEAj6jHiUWkNrKtYtv0b1pGg-T9B4,5516
-compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
-compressed_tensors/quantization/observers/mse.py,sha256=Aeh-253Vbab1F8cYuBiGNn4OXWJ67wXQ_JVfl3mu2a8,6034
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=C41hKA5VANyEwkB5FxzEn3Z0Da5tfxF1I07P8rUcyS0,8537
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=...
+compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
-compressed_tensors/registry/registry.py,sha256=...
+compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
-compressed_tensors/utils/helpers.py,sha256=...
+compressed_tensors/utils/helpers.py,sha256=T3p0TbhWbQIRjL6Up2Z7UhZO5jpR6WxBhYPPvrhE6lE,5018
 compressed_tensors/utils/offload.py,sha256=d9q8LNe8HyF8tOjgjA7QGLD3HRysmNp0d8eBbdqBgIM,4089
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.7.1.dist-info/...
-compressed_tensors-0.7.1.dist-info/...
-compressed_tensors-0.7.1.dist-info/...
-compressed_tensors-0.7.1.dist-info/...
-compressed_tensors-0.7.1.dist-info/...
+compressed_tensors-0.8.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.8.1.dist-info/METADATA,sha256=rDPAoGePUI_yRN7LRP23t3vKWhDfxPbeNR1TX6vpPPI,6782
+compressed_tensors-0.8.1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+compressed_tensors-0.8.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.8.1.dist-info/RECORD,,
compressed_tensors/quantization/cache.py DELETED
@@ -1,201 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
-
-from compressed_tensors.quantization.observers import Observer
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from torch import Tensor
-from transformers import DynamicCache as HFDyanmicCache
-
-
-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
-class QuantizedKVParameterCache(HFDyanmicCache):
-
-    """
-    Quantized KV cache used in the forward call based on HF's dynamic cache.
-    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
-    Singleton, so that the same cache gets reused in all forward call of self_attn.
-    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
-    gets called appropriately.
-    The size of tensor is
-    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
-
-
-    Triggered by adding kv_cache_scheme in the recipe.
-
-    Example:
-
-    ```python3
-    recipe = '''
-    quant_stage:
-        quant_modifiers:
-            QuantizationModifier:
-                kv_cache_scheme:
-                    num_bits: 8
-                    type: float
-                    strategy: tensor
-                    dynamic: false
-                    symmetric: true
-    '''
-
-    """
-
-    _instance = None
-    _initialized = False
-
-    def __new__(cls, *args, **kwargs):
-        """Singleton"""
-        if cls._instance is None:
-            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
-        return cls._instance
-
-    def __init__(self, quantization_args: QuantizationArgs):
-        if not self._initialized:
-            super().__init__()
-
-            self.quantization_args = quantization_args
-
-            self.k_observers: List[Observer] = []
-            self.v_observers: List[Observer] = []
-
-            # each index corresponds to layer_idx of the attention layer
-            self.k_scales: List[Tensor] = []
-            self.v_scales: List[Tensor] = []
-
-            self.k_zps: List[Tensor] = []
-            self.v_zps: List[Tensor] = []
-
-            self._initialized = True
-
-    def update(
-        self,
-        key_states: Tensor,
-        value_states: Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get the k_scale and v_scale and output the
-        fakequant-ed key_states and value_states
-        """
-
-        if len(self.k_observers) <= layer_idx:
-            k_observer = self.quantization_args.get_observer()
-            v_observer = self.quantization_args.get_observer()
-
-            self.k_observers.append(k_observer)
-            self.v_observers.append(v_observer)
-
-        q_key_states = self._quantize(
-            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
-        )
-        q_value_states = self._quantize(
-            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
-        )
-
-        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
-        qdq_value_states = self._dequantize(
-            q_value_states, KVCacheScaleType.VALUE, layer_idx
-        )
-
-        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
-
-        return keys_to_return, values_to_return
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """
-        Returns the sequence length of the cached states.
-        A layer index can be optionally passed.
-        """
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        # since we cannot get the seq_length of each layer directly and
-        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
-        # this is a hack to get the actual seq_length for the given layer_idx
-        # this part of code otherwise fails when used to
-        # verify attn_weight shape in some models
-        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
-
-    def reset_states(self):
-        """reset the kv states (used in calibration)"""
-        self.key_cache: List[Tensor] = []
-        self.value_cache: List[Tensor] = []
-        # Used in `generate` to keep tally of how many tokens the cache has seen
-        self._seen_tokens = 0
-        self._quantized_key_cache: List[Tensor] = []
-        self._quantized_value_cache: List[Tensor] = []
-
-    def reset(self):
-        """
-        Reset the instantiation, create new instance on init
-        """
-        QuantizedKVParameterCache._instance = None
-        QuantizedKVParameterCache._initialized = False
-
-    def _quantize(self, tensor, kv_type, layer_idx):
-        """Quantizes a key/value using a defined quantization method."""
-        from compressed_tensors.quantization.lifecycle.forward import quantize
-
-        if kv_type == KVCacheScaleType.KEY:  # key type
-            observer = self.k_observers[layer_idx]
-            scales = self.k_scales
-            zps = self.k_zps
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            observer = self.v_observers[layer_idx]
-            scales = self.v_scales
-            zps = self.v_zps
-
-        scale, zp = observer(tensor)
-        if len(scales) <= layer_idx:
-            scales.append(scale)
-            zps.append(zp)
-        else:
-            scales[layer_idx] = scale
-            zps[layer_idx] = scale
-
-        q_tensor = quantize(
-            x=tensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return q_tensor
-
-    def _dequantize(self, qtensor, kv_type, layer_idx):
-        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
-        from compressed_tensors.quantization.lifecycle.forward import dequantize
-
-        if kv_type == KVCacheScaleType.KEY:
-            scale = self.k_scales[layer_idx]
-            zp = self.k_zps[layer_idx]
-        else:
-            assert kv_type == KVCacheScaleType.VALUE
-            scale = self.v_scales[layer_idx]
-            zp = self.v_zps[layer_idx]
-
-        qdq_tensor = dequantize(
-            x_q=qtensor,
-            scale=scale,
-            zero_point=zp,
-            args=self.quantization_args,
-        )
-        return qdq_tensor