mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +43 -29
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
  7. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
  9. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
  12. model_compression_toolkit/core/keras/data_util.py +67 -0
  13. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  14. model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
  15. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  16. model_compression_toolkit/core/pytorch/data_util.py +163 -0
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
  18. model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
  19. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
  21. model_compression_toolkit/core/pytorch/utils.py +22 -19
  22. model_compression_toolkit/core/quantization_prep_runner.py +2 -1
  23. model_compression_toolkit/core/runner.py +1 -2
  24. model_compression_toolkit/gptq/common/gptq_config.py +0 -2
  25. model_compression_toolkit/gptq/common/gptq_training.py +58 -114
  26. model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
  27. model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
  29. model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
  30. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
  31. tests_pytest/keras/__init__.py +14 -0
  32. tests_pytest/keras/core/__init__.py +14 -0
  33. tests_pytest/keras/core/test_data_util.py +91 -0
  34. tests_pytest/pytorch/core/__init__.py +14 -0
  35. tests_pytest/pytorch/core/test_data_util.py +125 -0
  36. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -15,12 +15,10 @@
15
15
 
16
16
  from typing import Union, List
17
17
 
18
- from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
19
- from model_compression_toolkit.core.common import Graph
20
- from model_compression_toolkit.core.common.hessian import HessianScoresRequest
18
+ import torch
19
+
21
20
  from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
22
21
  from model_compression_toolkit.logger import Logger
23
- import torch
24
22
 
25
23
 
26
24
  class HessianScoresCalculatorPytorch(HessianScoresCalculator):
@@ -28,28 +26,20 @@ class HessianScoresCalculatorPytorch(HessianScoresCalculator):
28
26
  Pytorch-specific implementation of the Hessian approximation scores Calculator.
29
27
  This class serves as a base for other Pytorch-specific Hessian approximation scores calculators.
30
28
  """
31
- def __init__(self,
32
- graph: Graph,
33
- input_images: List[torch.Tensor],
34
- fw_impl,
35
- hessian_scores_request: HessianScoresRequest,
36
- num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
29
+ def _generate_random_vectors_batch(self, shape: tuple, device: torch.device) -> torch.Tensor:
37
30
  """
31
+ Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
38
32
 
39
33
  Args:
40
- graph: Computational graph for the float model.
41
- input_images: List of input images for the computation.
42
- fw_impl: Framework-specific implementation for Hessian scores computation.
43
- hessian_scores_request: Configuration request for which to compute the Hessian approximation scores.
44
- num_iterations_for_approximation: Number of iterations to use when approximating the Hessian based scores.
34
+ shape: target shape.
35
+ device: target device.
45
36
 
37
+ Returns:
38
+ Random tensor.
46
39
  """
47
- super(HessianScoresCalculatorPytorch, self).__init__(graph=graph,
48
- input_images=input_images,
49
- fw_impl=fw_impl,
50
- hessian_scores_request=hessian_scores_request,
51
- num_iterations_for_approximation=num_iterations_for_approximation)
52
-
40
+ v = torch.randint(high=2, size=shape, device=device)
41
+ v[v == 0] = -1
42
+ return v
53
43
 
54
44
  def concat_tensors(self, tensors_to_concate: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
55
45
  """
@@ -12,19 +12,21 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from tqdm import tqdm
16
15
  from typing import List
16
+
17
+ import numpy as np
17
18
  import torch
18
19
  from torch import autograd
19
- import numpy as np
20
+ from tqdm import tqdm
21
+
22
+ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE
20
23
  from model_compression_toolkit.core.common import Graph
21
24
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
25
+ from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
26
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
22
27
  from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
23
28
  HessianScoresCalculatorPytorch
24
29
  from model_compression_toolkit.logger import Logger
25
- from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
26
- from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
27
- from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_EPS
28
30
 
29
31
 
30
32
  class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
@@ -84,8 +86,8 @@ class WeightsHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
84
86
 
85
87
  prev_mean_results = None
86
88
  for j in tqdm(range(self.num_iterations_for_approximation)):
87
- # Getting a random vector with normal distribution and the same shape as the model output
88
- v = torch.randn_like(output_tensor, device=device)
89
+ # Getting a random vector with the same shape as the model output
90
+ v = self._generate_random_vectors_batch(output_tensor.shape, device=device)
89
91
  f_v = torch.mean(torch.sum(v * output_tensor, dim=-1))
90
92
  for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
91
93
 
@@ -15,7 +15,7 @@
15
15
  import operator
16
16
  from copy import deepcopy
17
17
  from functools import partial
18
- from typing import List, Any, Tuple, Callable, Type, Dict
18
+ from typing import List, Any, Tuple, Callable, Type, Dict, Generator
19
19
 
20
20
  import numpy as np
21
21
  import torch
@@ -38,6 +38,7 @@ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilde
38
38
  from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
39
39
  from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
40
40
  from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
41
+ from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
41
42
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
42
43
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
43
44
  pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
@@ -563,4 +564,9 @@ class PytorchImplementation(FrameworkImplementation):
563
564
  return get_inferable_quantizers(node,
564
565
  get_weights_quantizer_for_node,
565
566
  get_activations_quantizer_for_node,
566
- node.get_node_weights_attributes())
567
+ node.get_node_weights_attributes())
568
+
569
+ @staticmethod
570
+ def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
571
+ """ Converts data generator into framework dataloader with arbitrary batch size. """
572
+ return data_gen_to_dataloader(data_gen_fn, batch_size=batch_size)
@@ -15,7 +15,7 @@
15
15
  import torch
16
16
  from torch import Tensor
17
17
  import numpy as np
18
- from typing import Union
18
+ from typing import Union, Sequence, Optional, List, Tuple
19
19
 
20
20
  from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
21
21
  from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
@@ -41,30 +41,33 @@ def set_model(model: torch.nn.Module, train_mode: bool = False):
41
41
  model.to(device)
42
42
 
43
43
 
44
- def to_torch_tensor(tensor,
45
- numpy_type=np.float32):
44
+ def to_torch_tensor(data,
45
+ dtype: Optional = torch.float32) -> Union[Tensor, List[Tensor], Tuple[Tensor]]:
46
+ # TODO it would make more sense to keep the original type by default but it will break lots of existing calls
47
+ # that count on implicit convertion
46
48
  """
47
- Convert a Numpy array to a Torch tensor.
49
+ Convert data to Torch tensors and move to the working device.
50
+ Data can be numpy or torch tensor, a scalar, or a list or a tuple of such data. In the latter case only the inner
51
+ data is converted.
52
+
48
53
  Args:
49
- tensor: Numpy array.
50
- numpy_type: The desired data type for the tensor. Default is np.float32.
54
+ data: Input data
55
+ dtype: The desired data type for the tensor. Pass None to keep the type of the input data.
51
56
 
52
57
  Returns:
53
- Torch tensor converted from the input Numpy array.
58
+ Torch tensor
54
59
  """
60
+
55
61
  working_device = get_working_device()
56
- if isinstance(tensor, torch.Tensor):
57
- return tensor.to(working_device)
58
- elif isinstance(tensor, list):
59
- return [to_torch_tensor(t) for t in tensor]
60
- elif isinstance(tensor, tuple):
61
- return (to_torch_tensor(t) for t in tensor)
62
- elif isinstance(tensor, np.ndarray):
63
- return torch.from_numpy(tensor.astype(numpy_type)).to(working_device)
64
- elif isinstance(tensor, (int, float)):
65
- return torch.from_numpy(np.array(tensor).astype(numpy_type)).to(working_device)
66
- else:
67
- Logger.critical(f'Unsupported type for conversion to Torch.tensor: {type(tensor)}.')
62
+
63
+ if isinstance(data, list):
64
+ return [to_torch_tensor(t, dtype) for t in data]
65
+
66
+ if isinstance(data, tuple):
67
+ return tuple(to_torch_tensor(t, dtype) for t in data)
68
+
69
+ kwargs = {} if dtype is None else {'dtype': dtype}
70
+ return torch.as_tensor(data, device=working_device, **kwargs)
68
71
 
69
72
 
70
73
  def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.ndarray, list, tuple]:
@@ -90,7 +90,8 @@ def quantization_preparation_runner(graph: Graph,
90
90
  # Calculate quantization params
91
91
  ######################################
92
92
 
93
- calculate_quantization_params(graph, hessian_info_service=hessian_info_service)
93
+ calculate_quantization_params(graph, fw_impl=fw_impl, repr_data_gen_fn=representative_data_gen,
94
+ hessian_info_service=hessian_info_service)
94
95
 
95
96
  if tb_w is not None:
96
97
  tb_w.add_graph(graph, 'thresholds_selection')
@@ -122,8 +122,7 @@ def core_runner(in_model: Any,
122
122
  mixed_precision_enable=core_config.is_mixed_precision_enabled,
123
123
  running_gptq=running_gptq)
124
124
 
125
- hessian_info_service = HessianInfoService(graph=graph, representative_dataset_gen=representative_data_gen,
126
- fw_impl=fw_impl)
125
+ hessian_info_service = HessianInfoService(graph=graph, fw_impl=fw_impl)
127
126
 
128
127
  tg = quantization_preparation_runner(graph=graph,
129
128
  representative_data_gen=representative_data_gen,
@@ -17,7 +17,6 @@ from enum import Enum
17
17
  from typing import Callable, Any, Dict, Optional
18
18
 
19
19
  from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
20
- from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
21
20
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
22
21
 
23
22
 
@@ -54,7 +53,6 @@ class GPTQHessianScoresConfig:
54
53
  scale_log_norm: bool = False
55
54
  hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
56
55
  per_sample: bool = False
57
- estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN
58
56
 
59
57
 
60
58
  @dataclass
@@ -13,23 +13,21 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import copy
16
- import hashlib
17
16
  from abc import ABC, abstractmethod
17
+ from typing import Callable, List, Any, Iterable, Optional, Generator
18
+
18
19
  import numpy as np
19
- from typing import Callable, List, Any, Dict
20
20
 
21
- from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
22
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
23
- from model_compression_toolkit.core.common import Graph, BaseNode
21
+ from model_compression_toolkit.core.common import Graph
24
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
24
+ HessianScoresGranularity, hessian_info_utils as hessian_utils
25
+ from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
26
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
25
27
  from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
26
28
  from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
27
29
  from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
28
- from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
29
30
  from model_compression_toolkit.logger import Logger
30
- from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
31
- HessianScoresGranularity
32
- from model_compression_toolkit.core.common.hessian import hessian_info_utils as hessian_utils
33
31
 
34
32
 
35
33
  class GPTQTrainer(ABC):
@@ -43,6 +41,7 @@ class GPTQTrainer(ABC):
43
41
  gptq_config: GradientPTQConfig,
44
42
  fw_impl: GPTQFrameworkImplemantation,
45
43
  fw_info: FrameworkInfo,
44
+ representative_data_gen_fn: Callable[[], Generator],
46
45
  hessian_info_service: HessianInfoService = None):
47
46
  """
48
47
  Build two models from a graph: A teacher network (float model) and a student network (quantized model).
@@ -56,6 +55,7 @@ class GPTQTrainer(ABC):
56
55
  gptq_config: GradientPTQConfig with parameters about the tuning process.
57
56
  fw_impl: Framework implementation
58
57
  fw_info: Framework information
58
+ representative_data_gen_fn: factory for representative data generator.
59
59
  hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
60
60
  """
61
61
  self.graph_float = copy.deepcopy(graph_float)
@@ -63,7 +63,7 @@ class GPTQTrainer(ABC):
63
63
  self.gptq_config = gptq_config
64
64
  self.fw_impl = fw_impl
65
65
  self.fw_info = fw_info
66
-
66
+ self.representative_data_gen_fn = representative_data_gen_fn
67
67
  # ----------------------------------------------
68
68
  # Build two models and create compare nodes
69
69
  # ----------------------------------------------
@@ -131,124 +131,69 @@ class GPTQTrainer(ABC):
131
131
 
132
132
  return optimizer_with_param
133
133
 
134
- def compute_hessian_based_weights(self) -> np.ndarray:
134
+ def compute_hessian_based_weights(self, data_loader: Iterable) -> np.ndarray:
135
135
  """
136
136
  Computes scores based on the hessian approximation per layer w.r.t activations of the interest points.
137
137
 
138
138
  Returns:
139
139
  np.ndarray: Scores based on the hessian matrix approximation.
140
140
  """
141
- if not self.gptq_config.use_hessian_based_weights:
142
- # Return a default weight distribution based on the number of compare points
143
- num_nodes = len(self.compare_points)
144
- return np.asarray([1 / num_nodes for _ in range(num_nodes)])
145
-
146
- # Fetch hessian approximations for each target node
147
- # TODO this smells like a potential bug. In hessian calculation target nodes are topo sorted and results are returned
148
- # TODO also target nodes are replaced for reuse. Does this work correctly?
149
- approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR)
150
- compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)}
151
-
152
- # Process the fetched hessian approximations to gather them per images
153
- hessian_approx_score_by_image = (
154
- self._process_hessian_approximations(compare_point_to_hessian_approx_scores))
155
-
156
- # Check if log normalization is enabled in the configuration
157
- if self.gptq_config.hessian_weights_config.log_norm:
158
- # Calculate the mean of the approximations across images
159
- mean_approx_scores = np.mean(hessian_approx_score_by_image, axis=0)
160
- # Reduce unnecessary dims, should remain with one dimension for the number of nodes
161
- mean_approx_scores = np.squeeze(mean_approx_scores)
162
- # Handle zero values to avoid log(0)
163
- mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
164
- np.partition(mean_approx_scores, 1)[1])
165
-
166
- # Calculate log weights
167
- log_weights = np.log10(mean_approx_scores)
168
-
169
- # Check if scaling of log normalization is enabled in the configuration
170
- if self.gptq_config.hessian_weights_config.scale_log_norm:
171
- # Scale the log weights to the range [0, 1]
172
- return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
173
-
174
- # Offset the log weights so the minimum value is 0
175
- return log_weights - np.min(log_weights)
176
- else:
177
- # If log normalization is not enabled, return the mean of the approximations across images
178
- return np.mean(hessian_approx_score_by_image, axis=0)
179
-
180
- def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict[BaseNode, np.ndarray]]:
181
- """
182
- Compute sample layer attention scores per image hash per layer.
141
+ request = self._build_hessian_request(
142
+ HessianScoresGranularity.PER_TENSOR,
143
+ data_loader=data_loader,
144
+ n_samples=self.gptq_config.hessian_weights_config.hessians_num_samples
145
+ )
146
+ layers_hessians = self.hessian_service.fetch_hessian(request)
183
147
 
184
- Args:
185
- inputs_batch: a list containing a batch of inputs.
148
+ hessian_approx_score_by_image = np.stack([layers_hessians[node.name] for node in self.compare_points], axis=1)
149
+ assert hessian_approx_score_by_image.shape[0] == self.gptq_config.hessian_weights_config.hessians_num_samples
186
150
 
187
- Returns:
188
- A dictionary with a structure {img_hash: {layer: score}}.
151
+ if self.gptq_config.hessian_weights_config.norm_scores:
152
+ hessian_approx_score_by_image = hessian_utils.normalize_scores(hessian_approx_score_by_image)
189
153
 
190
- """
191
- request = self._build_hessian_request(HessianScoresGranularity.PER_OUTPUT_CHANNEL)
192
- hessian_batch_size = self.gptq_config.hessian_weights_config.hessian_batch_size
193
-
194
- hessian_score_per_image_per_layer = {}
195
- # If hessian batch is smaller than inputs batch, split it to hessian batches. If hessian batch is larger,
196
- # it's currently ignored (TODO)
197
- for i in range(0, inputs_batch[0].shape[0], hessian_batch_size):
198
- inputs = [t[i: i+hessian_batch_size] for t in inputs_batch]
199
- hessian_score_per_image_per_layer.update(
200
- self.hessian_service.compute_trackable_per_sample_hessian(request, inputs)
201
- )
202
- for img_hash, v in hessian_score_per_image_per_layer.items():
203
- hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()}
204
- return hessian_score_per_image_per_layer
205
-
206
- def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]:
207
- """
208
- Fetches hessian approximations for each target node.
154
+ # Calculate the mean of the approximations across images
155
+ mean_approx_scores = np.mean(hessian_approx_score_by_image, axis=0)
156
+ # assert len(mean_approx_scores.shape) == len(self.compare_points)
209
157
 
210
- Returns:
211
- Mapping of target nodes to their hessian approximations.
212
- """
213
- hessian_scores_request = self._build_hessian_request(granularity)
158
+ if not self.gptq_config.hessian_weights_config.log_norm:
159
+ return mean_approx_scores
214
160
 
215
- node_approximations = self.hessian_service.fetch_hessian(
216
- hessian_scores_request=hessian_scores_request,
217
- required_size=self.gptq_config.hessian_weights_config.hessians_num_samples,
218
- batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size
219
- )
220
- return node_approximations
161
+ # Reduce unnecessary dims, should remain with one dimension for the number of nodes
162
+ mean_approx_scores = np.squeeze(mean_approx_scores)
163
+ # Handle zero values to avoid log(0)
164
+ mean_approx_scores = np.where(mean_approx_scores != 0, mean_approx_scores,
165
+ np.partition(mean_approx_scores, 1)[1])
221
166
 
222
- def _build_hessian_request(self, granularity):
223
- return HessianScoresRequest(
224
- mode=HessianMode.ACTIVATION,
225
- granularity=granularity,
226
- target_nodes=self.compare_points,
227
- distribution=self.gptq_config.hessian_weights_config.estimator_distribution
228
- )
167
+ # Calculate log weights
168
+ log_weights = np.log10(mean_approx_scores)
229
169
 
230
- def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List:
170
+ if self.gptq_config.hessian_weights_config.scale_log_norm:
171
+ # Scale the log weights to the range [0, 1]
172
+ return (log_weights - np.min(log_weights)) / (np.max(log_weights) - np.min(log_weights))
173
+
174
+ # Offset the log weights so the minimum value is 0
175
+ return log_weights - np.min(log_weights)
176
+
177
+ def _build_hessian_request(self, granularity: HessianScoresGranularity, data_loader: Iterable,
178
+ n_samples: Optional[int]) -> HessianScoresRequest:
231
179
  """
232
- Processes the fetched hessian approximations by image.
233
- Receives a dictionary of Node to a list of the length of the number of images that were fetched.
234
- Returns list of lists where each inner list is the approximations per image to all interest points.
180
+ Build hessian request for hessian service.
235
181
 
236
182
  Args:
237
- approximations: Hessian scores approximations mapping to process.
238
- Dictionary of Node to a list of the length of the number of images that were fetched.
183
+ granularity: requested granularity.
184
+ data_loader: data loader yielding samples to compute hessians on.
185
+ n_samples: request number of samples.
239
186
 
240
187
  Returns:
241
- Processed approximations as a list of lists where each inner list is the approximations
242
- per image to all interest points.
188
+ Hessian request.
243
189
  """
244
- hessian_approx_score_by_image = [[approximations[target_node][image_idx] for target_node in self.compare_points]
245
- for image_idx in
246
- range(self.gptq_config.hessian_weights_config.hessians_num_samples)]
247
-
248
- if self.gptq_config.hessian_weights_config.norm_scores:
249
- hessian_approx_score_by_image = hessian_utils.normalize_scores(hessian_approx_score_by_image)
250
-
251
- return hessian_approx_score_by_image
190
+ return HessianScoresRequest(
191
+ mode=HessianMode.ACTIVATION,
192
+ granularity=granularity,
193
+ target_nodes=self.compare_points,
194
+ data_loader=data_loader,
195
+ n_samples=n_samples
196
+ )
252
197
 
253
198
  @abstractmethod
254
199
  def build_gptq_model(self):
@@ -261,11 +206,9 @@ class GPTQTrainer(ABC):
261
206
  f'framework\'s GPTQ model builder method.') # pragma: no cover
262
207
 
263
208
  @abstractmethod
264
- def train(self, representative_data_gen: Callable):
209
+ def train(self):
265
210
  """
266
- Train the quantized model using GPTQ training process
267
- Args:
268
- representative_data_gen: Dataset to use for inputs of the models.
211
+ Train the quantized model using GPTQ training process.
269
212
  """
270
213
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
271
214
  f'framework\'s train method.') # pragma: no cover
@@ -281,6 +224,7 @@ class GPTQTrainer(ABC):
281
224
  raise NotImplemented(f'{self.__class__.__name__} have to implement the '
282
225
  f'framework\'s update_graph method.') # pragma: no cover
283
226
 
227
+
284
228
  def gptq_training(graph_float: Graph,
285
229
  graph_quant: Graph,
286
230
  gptq_config: GradientPTQConfig,
@@ -315,7 +259,7 @@ def gptq_training(graph_float: Graph,
315
259
  hessian_info_service=hessian_info_service)
316
260
 
317
261
  # Training process
318
- gptq_trainer.train(representative_data_gen)
262
+ gptq_trainer.train()
319
263
 
320
264
  # Update graph
321
265
  graph_quant = gptq_trainer.update_graph()
@@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.hessian import HessianInfoService
24
24
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
25
25
  from model_compression_toolkit.core.common.user_info import UserInformation
26
26
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
27
+ from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
27
28
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
28
29
  from model_compression_toolkit.gptq.keras.quantizer.quantization_builder import quantization_builder
29
30
  from model_compression_toolkit.logger import Logger
@@ -82,6 +83,7 @@ class KerasGPTQTrainer(GPTQTrainer):
82
83
  gptq_config,
83
84
  fw_impl,
84
85
  fw_info,
86
+ representative_data_gen_fn=representative_data_gen,
85
87
  hessian_info_service=hessian_info_service)
86
88
 
87
89
  self.loss_list = []
@@ -115,10 +117,20 @@ class KerasGPTQTrainer(GPTQTrainer):
115
117
  else:
116
118
  self.input_scale = self.gptq_user_info.input_scale
117
119
 
118
- self.weights_for_average_loss = self.compute_hessian_based_weights()
120
+ self.weights_for_average_loss = self._get_compare_points_loss_weights()
119
121
 
120
122
  self.reg_func = get_regularization(self.gptq_config, representative_data_gen)
121
123
 
124
+ def _get_compare_points_loss_weights(self):
125
+ """ Get compare points weights for the distillation loss. """
126
+ if self.gptq_config.use_hessian_based_weights:
127
+ hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
128
+ batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
129
+ return self.compute_hessian_based_weights(hess_dataloader)
130
+
131
+ num_nodes = len(self.compare_points)
132
+ return np.ones((num_nodes,)) / num_nodes
133
+
122
134
  def _is_gptq_weights_trainable(self,
123
135
  node: common.BaseNode) -> bool:
124
136
  """
@@ -182,7 +194,6 @@ class KerasGPTQTrainer(GPTQTrainer):
182
194
  f"but {len(activation_quantizers)} quantizers were found for node '{n}'. "
183
195
  f"Ensure only one quantizer is configured for each node's activation.")
184
196
 
185
-
186
197
  def build_gptq_model(self) -> Tuple[Model, UserInformation]:
187
198
  """
188
199
  Build the GPTQ model with QuantizationWrappers
@@ -243,11 +254,9 @@ class KerasGPTQTrainer(GPTQTrainer):
243
254
  i += len(p)
244
255
  return loss_value, res
245
256
 
246
- def train(self, representative_data_gen: Callable):
257
+ def train(self):
247
258
  """
248
259
  Train the quantized model using GPTQ training process in Keras framework
249
- Args:
250
- representative_data_gen: Dataset to use for inputs of the models.
251
260
  """
252
261
  compute_gradients = self.compute_gradients
253
262
 
@@ -255,7 +264,7 @@ class KerasGPTQTrainer(GPTQTrainer):
255
264
  # Training loop
256
265
  # ----------------------------------------------
257
266
  if self.has_params_to_train:
258
- self.micro_training_loop(representative_data_gen,
267
+ self.micro_training_loop(self.representative_data_gen_fn,
259
268
  compute_gradients,
260
269
  self.optimizer_with_param,
261
270
  self.gptq_config.n_epochs,
@@ -79,7 +79,7 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor],
79
79
  y_list: First list of tensors.
80
80
  x_list: Second list of tensors.
81
81
  fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
82
- loss_weights: layer-sample weights tensor of shape (layers, batch)
82
+ loss_weights: layer-sample weights tensor of shape (batch X layers)
83
83
 
84
84
  Returns:
85
85
  Sample Layer Attention loss (a scalar).
@@ -87,10 +87,11 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor],
87
87
  loss = 0
88
88
  layers_mean_w = []
89
89
 
90
- for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)):
90
+ for i, (y, x) in enumerate(zip(y_list, x_list)):
91
91
  norm = (y - x).pow(2).sum(1)
92
92
  if len(norm.shape) > 1:
93
93
  norm = norm.flatten(1).mean(1)
94
+ w = loss_weights[:, i]
94
95
  loss += torch.mean(w * norm)
95
96
  layers_mean_w.append(w.mean())
96
97