compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0

compressed_tensors/compressors/pack_quantized.py

@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import logging
 import math
-from typing import Dict, Generator, Tuple
+from typing import Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -23,15 +21,10 @@ from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from tqdm import tqdm
-
 
-__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
 
-
+__all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
 
 
 @Compressor.register(name=CompressionFormat.pack_quantized.value)
@@ -44,142 +37,142 @@ class PackedQuantizationCompressor(Compressor):
         "weight_packed",
         "weight_scale",
         "weight_zero_point",
+        "weight_g_idx",
         "weight_shape",
     ]
 
-    def compress(
+    def compression_param_info(
         self,
-
-
-
-    ) -> Dict[str, Tensor]:
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
         """
-
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
 
-        :param
-        :param
-
-
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        pack_factor = 32 // quantization_args.num_bits
+        packed_size = math.ceil(weight_shape[1] / pack_factor)
+        return {
+            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+            "weight_shape": (torch.Size((2,)), torch.int32),
+        }
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
         """
         compressed_dict = {}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                            scale=scale,
-                            zero_point=zp,
-                            args=quant_args,
-                            dtype=torch.int8,
-                        )
-                    value = pack_4bit_ints(value.cpu())
-                    compressed_dict[merge_names(prefix, "weight_shape")] = shape
-                    compressed_dict[merge_names(prefix, "weight_packed")] = value
-                    continue
-
-            elif name.endswith("zero_point"):
-                if torch.all(value == 0):
-                    # all zero_points are 0, no need to include in
-                    # compressed state_dict
-                    continue
-
-            compressed_dict[name] = value.to("cpu")
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=torch.int8,
+            )
+
+        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+        weight_shape = torch.tensor(weight.shape)
+        if device is not None:
+            packed_weight = packed_weight.to(device)
+            weight_shape = weight_shape.to(device)
+
+        compressed_dict["weight_shape"] = weight_shape
+        compressed_dict["weight_packed"] = packed_weight
 
         return compressed_dict
 
-    def decompress(
-        self,
-
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
         """
-
-
-
-
-        :param path_to_model_or_tensors: path to compressed safetensors model (directory with
-            one or more safetensors files) or compressed tensors file
-        :param device: optional device to load intermediate weights into
-        :return: compressed state dict
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
         """
-
-
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        original_shape = torch.Size(compressed_data["weight_shape"])
+        num_bits = quantization_args.num_bits
+        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
-
-
-
-
-
-                    weight_data[param_name] = f.get_tensor(full_name)
-
-            if "weight_scale" in weight_data:
-                zero_point = weight_data.get("weight_zero_point", None)
-                scale = weight_data["weight_scale"]
-                if zero_point is None:
-                    # zero_point assumed to be 0 if not included in state_dict
-                    zero_point = torch.zeros_like(scale)
-
-                weight = weight_data["weight_packed"]
-                original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_4bit_ints(weight, original_shape)
-                decompressed = dequantize(
-                    x_q=unpacked,
-                    scale=scale,
-                    zero_point=zero_point,
-                )
-                yield merge_names(weight_name, "weight"), decompressed
-
-
-def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
+
+        return decompressed_weight
+
+
+def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
     """
-    Packs a tensor of
+    Packs a tensor of quantized weights stored in int8 into int32s with padding
 
     :param value: tensor to pack
+    :param num_bits: number of bits used to store underlying data
     :returns: packed int32 tensor
     """
     if value.dtype is not torch.int8:
         raise ValueError("Tensor must be quantized to torch.int8 before packing")
 
-
-
-    bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
-    ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
-    only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+    if num_bits > 8:
+        raise ValueError("Packing is only supported for less than 8 bits")
 
-    #
-
-
-
-
-
-
-    )
+    # convert to unsigned for packing
+    offset = pow(2, num_bits) // 2
+    value = (value + offset).to(torch.uint8)
+    value = value.cpu().numpy().astype(np.uint32)
+    pack_factor = 32 // num_bits
+
+    # pad input tensor and initialize packed output
+    packed_size = math.ceil(value.shape[1] / pack_factor)
+    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+    padding = packed.shape[1] * pack_factor - value.shape[1]
+    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
 
-    #
-
-
-    compressed = np.ascontiguousarray(compressed).view(np.int32)
+    # pack values
+    for i in range(pack_factor):
+        packed |= value[:, i::pack_factor] << num_bits * i
 
-
+    # convert back to signed and torch
+    packed = np.ascontiguousarray(packed).view(np.int32)
+    return torch.from_numpy(packed)
 
 
-def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+def unpack_from_int32(
+    value: torch.Tensor, num_bits: int, shape: torch.Size
+) -> torch.Tensor:
     """
-    Unpacks a tensor packed
-    original their
+    Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
+    original their bit range
 
     :param value: tensor to upack
+    :param num_bits: number of bits to unpack each data point into
     :param shape: shape to unpack into, used to remove padding
     :returns: unpacked int8 tensor
     """
@@ -188,25 +181,28 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
             f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
         )
 
-
-
-    as_uint8 = value.numpy().view(np.uint8)
-    bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
-    original_row_size = int(shape[1] * individual_depth)
-    bits = bits[:, :original_row_size]
+    if num_bits > 8:
+        raise ValueError("Unpacking is only supported for less than 8 bits")
 
-
-
-
-
-
-
+    pack_factor = 32 // num_bits
+
+    # unpack
+    mask = pow(2, num_bits) - 1
+    unpacked = torch.zeros(
+        (value.shape[0], value.shape[1] * pack_factor),
+        device=value.device,
+        dtype=torch.int32,
+    )
+    for i in range(pack_factor):
+        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
 
-    #
-
+    # remove padding
+    original_row_size = int(shape[1])
+    unpacked = unpacked[:, :original_row_size]
 
     # bits are packed in unsigned format, reformat to signed
-    # update the value range from
-
+    # update the value range from unsigned to signed
+    offset = pow(2, num_bits) // 2
+    unpacked = (unpacked - offset).to(torch.int8)
 
-    return
+    return unpacked
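
The helpers above replace the old 4-bit-only routines with width-generic ones: values are shifted to unsigned, packed 32 // num_bits to an int32 column, padded to a whole number of int32s, and shifted back to signed on unpack. A minimal round-trip sketch, assuming compressed-tensors 0.6.0 is installed and importing from the module shown above:

import torch
from compressed_tensors.compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

num_bits = 4
# int4-range values stored in an int8 tensor, as pack_to_int32 requires
weight = torch.randint(-8, 8, (64, 512), dtype=torch.int8)

packed = pack_to_int32(weight, num_bits)  # (64, 512 // (32 // 4)) == (64, 64), dtype int32
restored = unpack_from_int32(packed, num_bits, weight.shape)

assert packed.dtype == torch.int32
assert torch.equal(restored, weight)  # lossless round trip for in-range values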

compressed_tensors/compressors/sparse_bitmask.py

@@ -72,7 +72,7 @@ class BitmaskCompressor(Compressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located

compressed_tensors/config/base.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Optional
+from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
     int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
 
@@ -35,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel):
     Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
+    :param targets: List of layer names or layer types that aren't sparse and should
+        be ignored during compression. By default, assume all layers are targeted
+    :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
     :param global_sparsity: average sparsity of the entire model
     :param sparsity_structure: structure of the sparsity, such as
         "unstructured", "2:4", "8:16" etc
     """
 
     format: str
+    targets: Optional[List[str]] = None
+    ignore: Optional[List[str]] = None
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
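
As a quick illustration of the two new fields, a config carrying explicit target and ignore lists might be built like this (the values are illustrative assumptions, not taken from the diff):

from compressed_tensors.config import SparsityCompressionConfig

sparsity_config = SparsityCompressionConfig(
    format="sparse-bitmask",
    targets=["Linear"],    # optional: restrict which layers the config describes
    ignore=["lm_head"],    # optional: layer names to exclude from targets
    global_sparsity=0.5,
    sparsity_structure="2:4",
)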

compressed_tensors/{compressors/utils → linear}/__init__.py

@@ -11,9 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# flake8: noqa
-
-from .helpers import *
-from .permutations_24 import *
-from .semi_structured_conversions import *

compressed_tensors/linear/compressed_linear.py

@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import Compressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = Compressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
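
A rough usage sketch for the new module (the quantization scheme below is an assumption for illustration, not taken from this diff): from_linear swaps the class of an existing Linear in place, drops its dense weight, and registers empty compressed parameters that are populated when a compressed checkpoint is loaded.

import torch
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
)
dense = torch.nn.Linear(512, 64)
compressed = CompressedLinear.from_linear(
    dense, quantization_scheme=scheme, quantization_format="pack-quantized"
)
# "weight" is gone; weight_packed, weight_shape and the quantization parameters
# are now registered, ready to be filled from a compressed state dict
print([name for name, _ in compressed.named_parameters()])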