mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +43 -29
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
- model_compression_toolkit/core/keras/data_util.py +67 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/data_util.py +163 -0
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
- model_compression_toolkit/core/pytorch/utils.py +22 -19
- model_compression_toolkit/core/quantization_prep_runner.py +2 -1
- model_compression_toolkit/core/runner.py +1 -2
- model_compression_toolkit/gptq/common/gptq_config.py +0 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -114
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
- tests_pytest/keras/__init__.py +14 -0
- tests_pytest/keras/core/__init__.py +14 -0
- tests_pytest/keras/core/test_data_util.py +91 -0
- tests_pytest/pytorch/core/__init__.py +14 -0
- tests_pytest/pytorch/core/test_data_util.py +125 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,10 @@
|
|
15
15
|
|
16
16
|
from typing import Union, List
|
17
17
|
|
18
|
-
|
19
|
-
|
20
|
-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
|
18
|
+
import torch
|
19
|
+
|
21
20
|
from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
|
22
21
|
from model_compression_toolkit.logger import Logger
|
23
|
-
import torch
|
24
22
|
|
25
23
|
|
26
24
|
class HessianScoresCalculatorPytorch(HessianScoresCalculator):
|
@@ -28,28 +26,20 @@ class HessianScoresCalculatorPytorch(HessianScoresCalculator):
|
|
28
26
|
Pytorch-specific implementation of the Hessian approximation scores Calculator.
|
29
27
|
This class serves as a base for other Pytorch-specific Hessian approximation scores calculators.
|
30
28
|
"""
|
31
|
-
def
|
32
|
-
graph: Graph,
|
33
|
-
input_images: List[torch.Tensor],
|
34
|
-
fw_impl,
|
35
|
-
hessian_scores_request: HessianScoresRequest,
|
36
|
-
num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
|
29
|
+
def _generate_random_vectors_batch(self, shape: tuple, device: torch.device) -> torch.Tensor:
|
37
30
|
"""
|
31
|
+
Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
|
38
32
|
|
39
33
|
Args:
|
40
|
-
|
41
|
-
|
42
|
-
fw_impl: Framework-specific implementation for Hessian scores computation.
|
43
|
-
hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
|
44
|
-
num_iterations_for_approximation: Number of iterations to use when approximating the Hessian based scores.
|
34
|
+
shape: target shape.
|
35
|
+
device: target device.
|
45
36
|
|
37
|
+
Returns:
|
38
|
+
Random tensor.
|
46
39
|
"""
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
hessian_scores_request=hessian_scores_request,
|
51
|
-
num_iterations_for_approximation=num_iterations_for_approximation)
|
52
|
-
|
40
|
+
v = torch.randint(high=2, size=shape, device=device)
|
41
|
+
v[v == 0] = -1
|
42
|
+
return v
|
53
43
|
|
54
44
|
def concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
|
55
45
|
"""
|
@@ -12,19 +12,21 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from tqdm import tqdm
|
16
15
|
from typing import List
|
16
|
+
|
17
|
+
import numpy as np
|
17
18
|
import torch
|
18
19
|
from torch import autograd
|
19
|
-
|
20
|
+
from tqdm import tqdm
|
21
|
+
|
22
|
+
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
|
20
23
|
from model_compression_toolkit.core.common import Graph
|
21
24
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
|
25
|
+
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
26
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
22
27
|
from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
|
23
28
|
HessianScoresCalculatorPytorch
|
24
29
|
from model_compression_toolkit.logger import Logger
|
25
|
-
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
26
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
27
|
-
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS
|
28
30
|
|
29
31
|
|
30
32
|
class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
@@ -84,8 +86,8 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
84
86
|
|
85
87
|
prev_mean_results = None
|
86
88
|
for j in tqdm(range(self.num_iterations_for_approximation)):
|
87
|
-
# Getting a random vector with
|
88
|
-
v =
|
89
|
+
# Getting a random vector with the same shape as the model output
|
90
|
+
v = self._generate_random_vectors_batch(output_tensor.shape, device=device)
|
89
91
|
f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
|
90
92
|
for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
|
91
93
|
|
@@ -15,7 +15,7 @@
|
|
15
15
|
import operator
|
16
16
|
from copy import deepcopy
|
17
17
|
from functools import partial
|
18
|
-
from typing import List, Any, Tuple, Callable, Type, Dict
|
18
|
+
from typing import List, Any, Tuple, Callable, Type, Dict, Generator
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
import torch
|
@@ -38,6 +38,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
|
|
38
38
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
39
39
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
|
40
40
|
from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
|
41
|
+
from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
|
41
42
|
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
42
43
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
|
43
44
|
pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
|
@@ -563,4 +564,9 @@ class PytorchImplementation(FrameworkImplementation):
|
|
563
564
|
return get_inferable_quantizers(node,
|
564
565
|
get_weights_quantizer_for_node,
|
565
566
|
get_activations_quantizer_for_node,
|
566
|
-
node.get_node_weights_attributes())
|
567
|
+
node.get_node_weights_attributes())
|
568
|
+
|
569
|
+
@staticmethod
|
570
|
+
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
|
571
|
+
""" Converts data generator into framework dataloader with arbitrary batch size. """
|
572
|
+
return data_gen_to_dataloader(data_gen_fn, batch_size=batch_size)
|
@@ -15,7 +15,7 @@
|
|
15
15
|
import torch
|
16
16
|
from torch import Tensor
|
17
17
|
import numpy as np
|
18
|
-
from typing import Union
|
18
|
+
from typing import Union, Sequence, Optional, List, Tuple
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
|
21
21
|
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
@@ -41,30 +41,33 @@ def set_model(model: torch.nn.Module, train_mode: bool = False):
|
|
41
41
|
model.to(device)
|
42
42
|
|
43
43
|
|
44
|
-
def to_torch_tensor(
|
45
|
-
|
44
|
+
def to_torch_tensor(data,
|
45
|
+
dtype: Optional = torch.float32) -> Union[Tensor, List[Tensor], Tuple[Tensor]]:
|
46
|
+
# TODO it would make more sense to keep the original type by default but it will break lots of existing calls
|
47
|
+
# that count on implicit convertion
|
46
48
|
"""
|
47
|
-
Convert
|
49
|
+
Convert data to Torch tensors and move to the working device.
|
50
|
+
Data can be numpy or torch tensor, a scalar, or a list or a tuple of such data. In the latter case only the inner
|
51
|
+
data is converted.
|
52
|
+
|
48
53
|
Args:
|
49
|
-
|
50
|
-
|
54
|
+
data: Input data
|
55
|
+
dtype: The desired data type for the tensor. Pass None to keep the type of the input data.
|
51
56
|
|
52
57
|
Returns:
|
53
|
-
Torch tensor
|
58
|
+
Torch tensor
|
54
59
|
"""
|
60
|
+
|
55
61
|
working_device = get_working_device()
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
return (to_torch_tensor(t) for t in
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
return torch.from_numpy(np.array(tensor).astype(numpy_type)).to(working_device)
|
66
|
-
else:
|
67
|
-
Logger.critical(f'Unsupported type for conversion to Torch.tensor: {type(tensor)}.')
|
62
|
+
|
63
|
+
if isinstance(data, list):
|
64
|
+
return [to_torch_tensor(t, dtype) for t in data]
|
65
|
+
|
66
|
+
if isinstance(data, tuple):
|
67
|
+
return tuple(to_torch_tensor(t, dtype) for t in data)
|
68
|
+
|
69
|
+
kwargs = {} if dtype is None else {'dtype': dtype}
|
70
|
+
return torch.as_tensor(data, device=working_device, **kwargs)
|
68
71
|
|
69
72
|
|
70
73
|
def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.ndarray, list, tuple]:
|
@@ -90,7 +90,8 @@ def quantization_preparation_runner(graph: Graph,
|
|
90
90
|
# Calculate quantization params
|
91
91
|
######################################
|
92
92
|
|
93
|
-
calculate_quantization_params(graph,
|
93
|
+
calculate_quantization_params(graph, fw_impl=fw_impl, repr_data_gen_fn=representative_data_gen,
|
94
|
+
hessian_info_service=hessian_info_service)
|
94
95
|
|
95
96
|
if tb_w is not None:
|
96
97
|
tb_w.add_graph(graph, 'thresholds_selection')
|
@@ -122,8 +122,7 @@ def core_runner(in_model: Any,
|
|
122
122
|
mixed_precision_enable=core_config.is_mixed_precision_enabled,
|
123
123
|
running_gptq=running_gptq)
|
124
124
|
|
125
|
-
hessian_info_service = HessianInfoService(graph=graph,
|
126
|
-
fw_impl=fw_impl)
|
125
|
+
hessian_info_service = HessianInfoService(graph=graph, fw_impl=fw_impl)
|
127
126
|
|
128
127
|
tg = quantization_preparation_runner(graph=graph,
|
129
128
|
representative_data_gen=representative_data_gen,
|
@@ -17,7 +17,6 @@ from enum import Enum
|
|
17
17
|
from typing import Callable, Any, Dict, Optional
|
18
18
|
|
19
19
|
from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
20
|
-
from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
|
21
20
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
22
21
|
|
23
22
|
|
@@ -54,7 +53,6 @@ class GPTQHessianScoresConfig:
|
|
54
53
|
scale_log_norm: bool = False
|
55
54
|
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
56
55
|
per_sample: bool = False
|
57
|
-
estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN
|
58
56
|
|
59
57
|
|
60
58
|
@dataclass
|
@@ -13,23 +13,21 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import copy
|
16
|
-
import hashlib
|
17
16
|
from abc import ABC, abstractmethod
|
17
|
+
from typing import Callable, List, Any, Iterable, Optional, Generator
|
18
|
+
|
18
19
|
import numpy as np
|
19
|
-
from typing import Callable, List, Any, Dict
|
20
20
|
|
21
|
-
from model_compression_toolkit.
|
22
|
-
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
23
|
-
from model_compression_toolkit.core.common import Graph, BaseNode
|
21
|
+
from model_compression_toolkit.core.common import Graph
|
24
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
|
+
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
24
|
+
HessianScoresGranularity, hessian_info_utils as hessian_utils
|
25
|
+
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
26
|
+
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
25
27
|
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
|
26
28
|
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
27
29
|
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
|
28
|
-
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
29
30
|
from model_compression_toolkit.logger import Logger
|
30
|
-
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
31
|
-
HessianScoresGranularity
|
32
|
-
from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils
|
33
31
|
|
34
32
|
|
35
33
|
class GPTQTrainer(ABC):
|
@@ -43,6 +41,7 @@ class GPTQTrainer(ABC):
|
|
43
41
|
gptq_config: GradientPTQConfig,
|
44
42
|
fw_impl: GPTQFrameworkImplemantation,
|
45
43
|
fw_info: FrameworkInfo,
|
44
|
+
representative_data_gen_fn: Callable[[], Generator],
|
46
45
|
hessian_info_service: HessianInfoService = None):
|
47
46
|
"""
|
48
47
|
Build two models from a graph: A teacher network (float model) and a student network (quantized model).
|
@@ -56,6 +55,7 @@ class GPTQTrainer(ABC):
|
|
56
55
|
gptq_config: GradientPTQConfig with parameters about the tuning process.
|
57
56
|
fw_impl: Framework implementation
|
58
57
|
fw_info: Framework information
|
58
|
+
representative_data_gen_fn: factory for representative data generator.
|
59
59
|
hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
|
60
60
|
"""
|
61
61
|
self.graph_float = copy.deepcopy(graph_float)
|
@@ -63,7 +63,7 @@ class GPTQTrainer(ABC):
|
|
63
63
|
self.gptq_config = gptq_config
|
64
64
|
self.fw_impl = fw_impl
|
65
65
|
self.fw_info = fw_info
|
66
|
-
|
66
|
+
self.representative_data_gen_fn = representative_data_gen_fn
|
67
67
|
# ----------------------------------------------
|
68
68
|
# Build two models and create compare nodes
|
69
69
|
# ----------------------------------------------
|
@@ -131,124 +131,69 @@ class GPTQTrainer(ABC):
|
|
131
131
|
|
132
132
|
return optimizer_with_param
|
133
133
|
|
134
|
-
def compute_hessian_based_weights(self) -> np.ndarray:
|
134
|
+
def compute_hessian_based_weights(self, data_loader: Iterable) -> np.ndarray:
|
135
135
|
"""
|
136
136
|
Computes scores based on the hessian approximation per layer w.r.t activations of the interest points.
|
137
137
|
|
138
138
|
Returns:
|
139
139
|
np.ndarray: Scores based on the hessian matrix approximation.
|
140
140
|
"""
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
# TODO this smells like a potential bug. In hessian calculation target nodes are topo sorted and results are returned
|
148
|
-
# TODO also target nodes are replaced for reuse. Does this work correctly?
|
149
|
-
approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR)
|
150
|
-
compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)}
|
151
|
-
|
152
|
-
# Process the fetched hessian approximations to gather them per images
|
153
|
-
hessian_approx_score_by_image = (
|
154
|
-
self._process_hessian_approximations(compare_point_to_hessian_approx_scores))
|
155
|
-
|
156
|
-
# Check if log normalization is enabled in the configuration
|
157
|
-
if self.gptq_config.hessian_weights_config.log_norm:
|
158
|
-
# Calculate the mean of the approximations across images
|
159
|
-
mean_approx_scores = np.mean(hessian_approx_score_by_image, axis=0)
|
160
|
-
# Reduce unnecessary dims, should remain with one dimension for the number of nodes
|
161
|
-
mean_approx_scores = np.squeeze(mean_approx_scores)
|
162
|
-
# Handle zero values to avoid log(0)
|
163
|
-
mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
|
164
|
-
np.partition(mean_approx_scores, 1)[1])
|
165
|
-
|
166
|
-
# Calculate log weights
|
167
|
-
log_weights = np.log10(mean_approx_scores)
|
168
|
-
|
169
|
-
# Check if scaling of log normalization is enabled in the configuration
|
170
|
-
if self.gptq_config.hessian_weights_config.scale_log_norm:
|
171
|
-
# Scale the log weights to the range [0, 1]
|
172
|
-
return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
|
173
|
-
|
174
|
-
# Offset the log weights so the minimum value is 0
|
175
|
-
return log_weights - np.min(log_weights)
|
176
|
-
else:
|
177
|
-
# If log normalization is not enabled, return the mean of the approximations across images
|
178
|
-
return np.mean(hessian_approx_score_by_image, axis=0)
|
179
|
-
|
180
|
-
def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict[BaseNode, np.ndarray]]:
|
181
|
-
"""
|
182
|
-
Compute sample layer attention scores per image hash per layer.
|
141
|
+
request = self._build_hessian_request(
|
142
|
+
HessianScoresGranularity.PER_TENSOR,
|
143
|
+
data_loader=data_loader,
|
144
|
+
n_samples=self.gptq_config.hessian_weights_config.hessians_num_samples
|
145
|
+
)
|
146
|
+
layers_hessians = self.hessian_service.fetch_hessian(request)
|
183
147
|
|
184
|
-
|
185
|
-
|
148
|
+
hessian_approx_score_by_image = np.stack([layers_hessians[node.name] for node in self.compare_points], axis=1)
|
149
|
+
assert hessian_approx_score_by_image.shape[0] == self.gptq_config.hessian_weights_config.hessians_num_samples
|
186
150
|
|
187
|
-
|
188
|
-
|
151
|
+
if self.gptq_config.hessian_weights_config.norm_scores:
|
152
|
+
hessian_approx_score_by_image = hessian_utils.normalize_scores(hessian_approx_score_by_image)
|
189
153
|
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
hessian_score_per_image_per_layer = {}
|
195
|
-
# If hessian batch is smaller than inputs batch, split it to hessian batches. If hessian batch is larger,
|
196
|
-
# it's currently ignored (TODO)
|
197
|
-
for i in range(0, inputs_batch[0].shape[0], hessian_batch_size):
|
198
|
-
inputs = [t[i: i+hessian_batch_size] for t in inputs_batch]
|
199
|
-
hessian_score_per_image_per_layer.update(
|
200
|
-
self.hessian_service.compute_trackable_per_sample_hessian(request, inputs)
|
201
|
-
)
|
202
|
-
for img_hash, v in hessian_score_per_image_per_layer.items():
|
203
|
-
hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()}
|
204
|
-
return hessian_score_per_image_per_layer
|
205
|
-
|
206
|
-
def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]:
|
207
|
-
"""
|
208
|
-
Fetches hessian approximations for each target node.
|
154
|
+
# Calculate the mean of the approximations across images
|
155
|
+
mean_approx_scores = np.mean(hessian_approx_score_by_image, axis=0)
|
156
|
+
# assert len(mean_approx_scores.shape) == len(self.compare_points)
|
209
157
|
|
210
|
-
|
211
|
-
|
212
|
-
"""
|
213
|
-
hessian_scores_request = self._build_hessian_request(granularity)
|
158
|
+
if not self.gptq_config.hessian_weights_config.log_norm:
|
159
|
+
return mean_approx_scores
|
214
160
|
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
return node_approximations
|
161
|
+
# Reduce unnecessary dims, should remain with one dimension for the number of nodes
|
162
|
+
mean_approx_scores = np.squeeze(mean_approx_scores)
|
163
|
+
# Handle zero values to avoid log(0)
|
164
|
+
mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
|
165
|
+
np.partition(mean_approx_scores, 1)[1])
|
221
166
|
|
222
|
-
|
223
|
-
|
224
|
-
mode=HessianMode.ACTIVATION,
|
225
|
-
granularity=granularity,
|
226
|
-
target_nodes=self.compare_points,
|
227
|
-
distribution=self.gptq_config.hessian_weights_config.estimator_distribution
|
228
|
-
)
|
167
|
+
# Calculate log weights
|
168
|
+
log_weights = np.log10(mean_approx_scores)
|
229
169
|
|
230
|
-
|
170
|
+
if self.gptq_config.hessian_weights_config.scale_log_norm:
|
171
|
+
# Scale the log weights to the range [0, 1]
|
172
|
+
return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
|
173
|
+
|
174
|
+
# Offset the log weights so the minimum value is 0
|
175
|
+
return log_weights - np.min(log_weights)
|
176
|
+
|
177
|
+
def _build_hessian_request(self, granularity: HessianScoresGranularity, data_loader: Iterable,
|
178
|
+
n_samples: Optional[int]) -> HessianScoresRequest:
|
231
179
|
"""
|
232
|
-
|
233
|
-
Receives a dictionary of Node to a list of the length of the number of images that were fetched.
|
234
|
-
Returns list of lists where each inner list is the approximations per image to all interest points.
|
180
|
+
Build hessian request for hessian service.
|
235
181
|
|
236
182
|
Args:
|
237
|
-
|
238
|
-
|
183
|
+
granularity: requested granularity.
|
184
|
+
data_loader: data loader yielding samples to compute hessians on.
|
185
|
+
n_samples: request number of samples.
|
239
186
|
|
240
187
|
Returns:
|
241
|
-
|
242
|
-
per image to all interest points.
|
188
|
+
Hessian request.
|
243
189
|
"""
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
return hessian_approx_score_by_image
|
190
|
+
return HessianScoresRequest(
|
191
|
+
mode=HessianMode.ACTIVATION,
|
192
|
+
granularity=granularity,
|
193
|
+
target_nodes=self.compare_points,
|
194
|
+
data_loader=data_loader,
|
195
|
+
n_samples=n_samples
|
196
|
+
)
|
252
197
|
|
253
198
|
@abstractmethod
|
254
199
|
def build_gptq_model(self):
|
@@ -261,11 +206,9 @@ class GPTQTrainer(ABC):
|
|
261
206
|
f'framework\'s GPTQ model builder method.') # pragma: no cover
|
262
207
|
|
263
208
|
@abstractmethod
|
264
|
-
def train(self
|
209
|
+
def train(self):
|
265
210
|
"""
|
266
|
-
Train the quantized model using GPTQ training process
|
267
|
-
Args:
|
268
|
-
representative_data_gen: Dataset to use for inputs of the models.
|
211
|
+
Train the quantized model using GPTQ training process.
|
269
212
|
"""
|
270
213
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
271
214
|
f'framework\'s train method.') # pragma: no cover
|
@@ -281,6 +224,7 @@ class GPTQTrainer(ABC):
|
|
281
224
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
282
225
|
f'framework\'s update_graph method.') # pragma: no cover
|
283
226
|
|
227
|
+
|
284
228
|
def gptq_training(graph_float: Graph,
|
285
229
|
graph_quant: Graph,
|
286
230
|
gptq_config: GradientPTQConfig,
|
@@ -315,7 +259,7 @@ def gptq_training(graph_float: Graph,
|
|
315
259
|
hessian_info_service=hessian_info_service)
|
316
260
|
|
317
261
|
# Training process
|
318
|
-
gptq_trainer.train(
|
262
|
+
gptq_trainer.train()
|
319
263
|
|
320
264
|
# Update graph
|
321
265
|
graph_quant = gptq_trainer.update_graph()
|
@@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.hessian import HessianInfoService
|
|
24
24
|
# As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
|
25
25
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
26
26
|
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
27
|
+
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
|
27
28
|
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
|
28
29
|
from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
|
29
30
|
from model_compression_toolkit.logger import Logger
|
@@ -82,6 +83,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
82
83
|
gptq_config,
|
83
84
|
fw_impl,
|
84
85
|
fw_info,
|
86
|
+
representative_data_gen_fn=representative_data_gen,
|
85
87
|
hessian_info_service=hessian_info_service)
|
86
88
|
|
87
89
|
self.loss_list = []
|
@@ -115,10 +117,20 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
115
117
|
else:
|
116
118
|
self.input_scale = self.gptq_user_info.input_scale
|
117
119
|
|
118
|
-
self.weights_for_average_loss = self.
|
120
|
+
self.weights_for_average_loss = self._get_compare_points_loss_weights()
|
119
121
|
|
120
122
|
self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
|
121
123
|
|
124
|
+
def _get_compare_points_loss_weights(self):
|
125
|
+
""" Get compare points weights for the distillation loss. """
|
126
|
+
if self.gptq_config.use_hessian_based_weights:
|
127
|
+
hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
|
128
|
+
batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
|
129
|
+
return self.compute_hessian_based_weights(hess_dataloader)
|
130
|
+
|
131
|
+
num_nodes = len(self.compare_points)
|
132
|
+
return np.ones((num_nodes,)) / num_nodes
|
133
|
+
|
122
134
|
def _is_gptq_weights_trainable(self,
|
123
135
|
node: common.BaseNode) -> bool:
|
124
136
|
"""
|
@@ -182,7 +194,6 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
182
194
|
f"but {len(activation_quantizers)} quantizers were found for node '{n}'. "
|
183
195
|
f"Ensure only one quantizer is configured for each node's activation.")
|
184
196
|
|
185
|
-
|
186
197
|
def build_gptq_model(self) -> Tuple[Model, UserInformation]:
|
187
198
|
"""
|
188
199
|
Build the GPTQ model with QuantizationWrappers
|
@@ -243,11 +254,9 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
243
254
|
i += len(p)
|
244
255
|
return loss_value, res
|
245
256
|
|
246
|
-
def train(self
|
257
|
+
def train(self):
|
247
258
|
"""
|
248
259
|
Train the quantized model using GPTQ training process in Keras framework
|
249
|
-
Args:
|
250
|
-
representative_data_gen: Dataset to use for inputs of the models.
|
251
260
|
"""
|
252
261
|
compute_gradients = self.compute_gradients
|
253
262
|
|
@@ -255,7 +264,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
255
264
|
# Training loop
|
256
265
|
# ----------------------------------------------
|
257
266
|
if self.has_params_to_train:
|
258
|
-
self.micro_training_loop(
|
267
|
+
self.micro_training_loop(self.representative_data_gen_fn,
|
259
268
|
compute_gradients,
|
260
269
|
self.optimizer_with_param,
|
261
270
|
self.gptq_config.n_epochs,
|
@@ -79,7 +79,7 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor],
|
|
79
79
|
y_list: First list of tensors.
|
80
80
|
x_list: Second list of tensors.
|
81
81
|
fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
|
82
|
-
loss_weights: layer-sample weights tensor of shape (layers
|
82
|
+
loss_weights: layer-sample weights tensor of shape (batch X layers)
|
83
83
|
|
84
84
|
Returns:
|
85
85
|
Sample Layer Attention loss (a scalar).
|
@@ -87,10 +87,11 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor],
|
|
87
87
|
loss = 0
|
88
88
|
layers_mean_w = []
|
89
89
|
|
90
|
-
for i, (y, x
|
90
|
+
for i, (y, x) in enumerate(zip(y_list, x_list)):
|
91
91
|
norm = (y - x).pow(2).sum(1)
|
92
92
|
if len(norm.shape) > 1:
|
93
93
|
norm = norm.flatten(1).mean(1)
|
94
|
+
w = loss_weights[:, i]
|
94
95
|
loss += torch.mean(w * norm)
|
95
96
|
layers_mean_w.append(w.mean())
|
96
97
|
|