mct-nightly 2.3.0.20250223.538__py3-none-any.whl → 2.3.0.20250225.512__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250223.538.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/METADATA +2 -2
- {mct_nightly-2.3.0.20250223.538.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/RECORD +24 -23
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/collectors/histogram_collector.py +19 -20
- model_compression_toolkit/core/common/collectors/statistics_collector.py +7 -3
- model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py +114 -0
- model_compression_toolkit/core/common/framework_implementation.py +9 -4
- model_compression_toolkit/core/common/graph/base_node.py +16 -6
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +31 -15
- model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
- model_compression_toolkit/core/common/model_collector.py +115 -17
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +2 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +110 -33
- model_compression_toolkit/core/keras/keras_implementation.py +35 -27
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +23 -61
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +34 -18
- model_compression_toolkit/core/quantization_prep_runner.py +1 -0
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py +2 -2
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +2 -1
- {mct_nightly-2.3.0.20250223.538.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250223.538.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250223.538.dist-info → mct_nightly-2.3.0.20250225.512.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,8 @@ import dataclasses
|
|
17
17
|
|
18
18
|
from enum import Enum
|
19
19
|
|
20
|
+
from model_compression_toolkit.logger import Logger
|
21
|
+
|
20
22
|
if TYPE_CHECKING: # pragma: no cover
|
21
23
|
from model_compression_toolkit.core.common import BaseNode
|
22
24
|
|
@@ -60,16 +62,19 @@ class HessianScoresRequest:
|
|
60
62
|
the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
|
61
63
|
n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
|
62
64
|
of the data loader.
|
65
|
+
compute_from_tensors: If `True`, Hessians are computed directly from given tensors instead of using the data loader.
|
66
|
+
|
63
67
|
"""
|
64
68
|
mode: HessianMode
|
65
69
|
granularity: HessianScoresGranularity
|
66
70
|
target_nodes: Sequence['BaseNode']
|
67
71
|
data_loader: Optional[Iterable]
|
68
72
|
n_samples: Optional[int]
|
73
|
+
compute_from_tensors: bool = False
|
69
74
|
|
70
75
|
def __post_init__(self):
|
71
|
-
if self.data_loader is None and self.n_samples is None:
|
72
|
-
|
76
|
+
if self.data_loader is None and self.n_samples is None and not self.compute_from_tensors:
|
77
|
+
Logger.critical('Data loader and the number of samples cannot both be None.')
|
73
78
|
|
74
79
|
def clone(self, **kwargs):
|
75
80
|
""" Create a clone with optional overrides """
|
@@ -218,7 +218,12 @@ def _add_ru_constraints(search_manager: MixedPrecisionSearchManager,
|
|
218
218
|
ru_vec = np.concatenate([ru_vec, non_conf_ru_vec])
|
219
219
|
ru_indicated_vectors[target] = ru_vec
|
220
220
|
|
221
|
-
#
|
221
|
+
# Add constraints only for the restricted targets in target resource utilization.
|
222
|
+
# Adding activation constraints modifies the lp term in ru_indicated_vectors, so if both activation and total
|
223
|
+
# are restricted we first add the constraints for total.
|
224
|
+
if RUTarget.TOTAL in constraints_targets and RUTarget.ACTIVATION in constraints_targets:
|
225
|
+
constraints_targets.remove(RUTarget.ACTIVATION)
|
226
|
+
constraints_targets = list(constraints_targets) + [RUTarget.ACTIVATION]
|
222
227
|
for target in constraints_targets:
|
223
228
|
target_resource_utilization_value = target_resource_utilization.get_resource_utilization_dict()[target]
|
224
229
|
aggr_ru = _aggregate_for_lp(ru_indicated_vectors, target)
|
@@ -15,13 +15,15 @@
|
|
15
15
|
|
16
16
|
|
17
17
|
import numpy as np
|
18
|
-
from typing import List
|
18
|
+
from typing import List, Union, Tuple, Optional
|
19
19
|
|
20
20
|
from networkx.algorithms.dag import topological_sort
|
21
|
-
from model_compression_toolkit.core import FrameworkInfo
|
21
|
+
from model_compression_toolkit.core import FrameworkInfo, QuantizationErrorMethod
|
22
22
|
from model_compression_toolkit.core import common
|
23
23
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
24
24
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
25
|
+
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity, HessianMode, \
|
26
|
+
HessianScoresRequest
|
25
27
|
from model_compression_toolkit.logger import Logger
|
26
28
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
27
29
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
@@ -66,12 +68,67 @@ def create_tensor2node(graph: common.Graph,
|
|
66
68
|
|
67
69
|
"""
|
68
70
|
current_sc = graph.get_out_stats_collector(node)
|
69
|
-
is_list_nostat_collectors = isinstance(current_sc, list) and len(
|
71
|
+
is_list_nostat_collectors = isinstance(current_sc, list) and len(
|
72
|
+
[sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
|
70
73
|
if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
|
71
74
|
stats_collector = common.StatsCollector(fw_info.out_channel_axis_mapping.get(node.type))
|
72
75
|
graph.set_out_stats_collector_to_node(node, stats_collector)
|
73
76
|
|
74
77
|
|
78
|
+
def ensure_matching_data_lengths(
|
79
|
+
stats_collector: Union[List[BaseStatsCollector], Tuple[BaseStatsCollector, ...]],
|
80
|
+
tensor_data: Union[List, Tuple],
|
81
|
+
hessian_data: Union[List, Tuple]
|
82
|
+
):
|
83
|
+
"""
|
84
|
+
Ensures that the lengths of `tensor_data`, `hessian_data`, and `stats_collector` are matching.
|
85
|
+
If the types or lengths do not match, a critical error is logged.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
stats_collector: A list or tuple of statistics collectors.
|
89
|
+
tensor_data: A list or tuple of tensors corresponding to the statistics collectors.
|
90
|
+
hessian_data: A list or tuple of Hessian tensors corresponding to the statistics collectors.
|
91
|
+
|
92
|
+
Raises:
|
93
|
+
Logs a critical error and halts execution if there is a type mismatch or
|
94
|
+
if the lengths of the inputs do not match.
|
95
|
+
"""
|
96
|
+
|
97
|
+
if not isinstance(tensor_data, (list, tuple)):
|
98
|
+
Logger.critical(
|
99
|
+
f"'tensor_data' is of type {type(tensor_data)}, but must be of the same type as 'stats_collector' ({type(stats_collector)})."
|
100
|
+
) # pragma: no cover
|
101
|
+
|
102
|
+
if len(stats_collector) != len(tensor_data):
|
103
|
+
Logger.critical(
|
104
|
+
"'tensor_data' and 'stats_collector' must have matching lengths."
|
105
|
+
) # pragma: no cover
|
106
|
+
|
107
|
+
if not isinstance(hessian_data, (list, tuple)):
|
108
|
+
Logger.critical(
|
109
|
+
f"'hessian_data' is of type {type(hessian_data)}, but must be of the same type as 'stats_collector' ({type(stats_collector)})."
|
110
|
+
) # pragma: no cover
|
111
|
+
|
112
|
+
if len(stats_collector) != len(hessian_data):
|
113
|
+
Logger.critical(
|
114
|
+
"'hessian_data' and 'stats_collector' must have matching lengths."
|
115
|
+
) # pragma: no cover
|
116
|
+
|
117
|
+
|
118
|
+
def convert_to_numpy_and_abs(tensor: Optional[np.ndarray], fw_impl: FrameworkImplementation) -> Optional[np.ndarray]:
|
119
|
+
"""
|
120
|
+
Converts a tensor to a NumPy array and applies the absolute value operation if the tensor is not None.
|
121
|
+
|
122
|
+
Args:
|
123
|
+
tensor: Input tensor to be converted to a NumPy array.
|
124
|
+
fw_impl: Framework implementation that provides the 'to_numpy' method for tensor conversion.
|
125
|
+
|
126
|
+
Returns:
|
127
|
+
A NumPy array of the input tensor with absolute values applied. If the input tensor is None, returns None.
|
128
|
+
"""
|
129
|
+
return tensor if tensor is None else np.abs(fw_impl.to_numpy(tensor))
|
130
|
+
|
131
|
+
|
75
132
|
class ModelCollector:
|
76
133
|
"""
|
77
134
|
Build a model from a graph for statistics collection purposes.
|
@@ -83,6 +140,7 @@ class ModelCollector:
|
|
83
140
|
def __init__(self, graph: Graph,
|
84
141
|
fw_impl: FrameworkImplementation,
|
85
142
|
fw_info: FrameworkInfo,
|
143
|
+
hessian_info_service: HessianInfoService = None,
|
86
144
|
qc: common.QuantizationConfig = common.DEFAULTCONFIG):
|
87
145
|
"""
|
88
146
|
Build a model from a graph per framework for statistics collection.
|
@@ -96,14 +154,18 @@ class ModelCollector:
|
|
96
154
|
|
97
155
|
self.fw_impl = fw_impl
|
98
156
|
self.fw_info = fw_info
|
157
|
+
self.hessian_service = hessian_info_service
|
158
|
+
self.qc = qc
|
159
|
+
self.model_outputs = [out.node for out in graph.get_outputs()]
|
99
160
|
|
100
|
-
# Assign
|
161
|
+
# Assign statistics collectors to nodes
|
101
162
|
for n in graph.get_topo_sorted_nodes():
|
102
163
|
sc = create_stats_collector_for_node(n, fw_info=fw_info) # Get static collector for the node
|
103
164
|
# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
|
104
165
|
# its previous nodes' tensors are consistent with this node.
|
105
166
|
kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
|
106
|
-
if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
|
167
|
+
if qc.weights_bias_correction and kernel_attr is not None and n.is_weights_quantization_enabled(
|
168
|
+
kernel_attr):
|
107
169
|
for ie in graph.incoming_edges(n):
|
108
170
|
input_node = ie.source_node
|
109
171
|
create_tensor2node(graph,
|
@@ -112,7 +174,6 @@ class ModelCollector:
|
|
112
174
|
if sc is not None:
|
113
175
|
graph.set_out_stats_collector_to_node(n, sc)
|
114
176
|
|
115
|
-
|
116
177
|
outputs_nodes = [] # List of graph nodes, the model should output their outputs.
|
117
178
|
self.stats_containers_list = [] # List of output statistics containers of nodes ordered
|
118
179
|
# the same as outputs_nodes so statistics of outputs can be gathered for the correct statistics container.
|
@@ -135,11 +196,19 @@ class ModelCollector:
|
|
135
196
|
outputs_nodes.append(n)
|
136
197
|
self.stats_containers_list.append(out_stats_container)
|
137
198
|
|
199
|
+
self.intermediate_output_tensors = [n for n in outputs_nodes if n not in self.model_outputs]
|
200
|
+
|
201
|
+
# Append nodes from graph.get_outputs() that are not already in outputs_nodes for Hessian
|
202
|
+
# calculation for output nodes that don't collect statistics such as "permute", "transpose" etc.
|
203
|
+
# TODO: Add integration test for this case
|
204
|
+
append2output = outputs_nodes + [n for n in self.model_outputs if n not in outputs_nodes]
|
205
|
+
|
206
|
+
|
138
207
|
# Build a float model and output all layers' outputs
|
139
208
|
# (that should be collected) as the model's outputs
|
140
209
|
self.model, _ = self.fw_impl.model_builder(graph,
|
141
210
|
mode=ModelBuilderMode.FLOAT,
|
142
|
-
append2output=
|
211
|
+
append2output=append2output,
|
143
212
|
fw_info=self.fw_info)
|
144
213
|
|
145
214
|
def infer(self, inputs_list: List[np.ndarray]):
|
@@ -154,14 +223,43 @@ class ModelCollector:
|
|
154
223
|
|
155
224
|
# TODO: Thinking about delegating collections to framework
|
156
225
|
# TODO: migrate datasets to framework datasets
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
226
|
+
compute_hessians = self.qc.activation_error_method == QuantizationErrorMethod.HMSE
|
227
|
+
|
228
|
+
# Retrieve intermediate layer activations for statistical analysis.
|
229
|
+
# Enable gradient computation if Hessian calculations are required.
|
230
|
+
activation_tensors = self.fw_impl.run_model_inference(self.model, inputs_list, requires_grad=compute_hessians)
|
231
|
+
|
232
|
+
if compute_hessians:
|
233
|
+
if self.hessian_service is None:
|
234
|
+
Logger.critical(
|
235
|
+
"Hessian computation is enabled but `hessian_service` is not initialized. "
|
236
|
+
"Ensure that `hessian_service` is properly set."
|
237
|
+
) # pragma: no cover
|
238
|
+
request = HessianScoresRequest(
|
239
|
+
mode=HessianMode.ACTIVATION,
|
240
|
+
granularity=HessianScoresGranularity.PER_ELEMENT,
|
241
|
+
target_nodes=self.intermediate_output_tensors,
|
242
|
+
data_loader=None,
|
243
|
+
n_samples=None,
|
244
|
+
compute_from_tensors=True
|
245
|
+
)
|
246
|
+
hessian_tensors = self.hessian_service.fetch_hessian(request=request,
|
247
|
+
activation_tensors=activation_tensors)
|
248
|
+
hessian_tensors = list(hessian_tensors.values())
|
249
|
+
else:
|
250
|
+
hessian_tensors = []
|
251
|
+
|
252
|
+
# Hessian is not calculated for the output, add "None" as weights for output tenosrs
|
253
|
+
hessian_tensors += [None for _ in range(len(activation_tensors) - len(hessian_tensors))]
|
254
|
+
|
255
|
+
for activation_tensor, hessian_tensor, stats_container in zip(activation_tensors, hessian_tensors, self.stats_containers_list):
|
256
|
+
if isinstance(stats_container, (list, tuple)):
|
257
|
+
if hessian_tensor is None:
|
258
|
+
hessian_tensor = [None for _ in range(len(activation_tensor))]
|
259
|
+
ensure_matching_data_lengths(activation_tensor, hessian_tensor, stats_container)
|
260
|
+
for activation_tensor_i, hessian_tensor_i, sci in zip(activation_tensor, hessian_tensor, stats_container):
|
261
|
+
sci.update_statistics(self.fw_impl.to_numpy(activation_tensor_i),
|
262
|
+
convert_to_numpy_and_abs(hessian_tensor_i, self.fw_impl))
|
166
263
|
else:
|
167
|
-
|
264
|
+
stats_container.update_statistics(self.fw_impl.to_numpy(activation_tensor),
|
265
|
+
convert_to_numpy_and_abs(hessian_tensor, self.fw_impl))
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py
CHANGED
@@ -514,6 +514,8 @@ def get_threshold_selection_histogram_error_function(quantization_method: Quanti
|
|
514
514
|
quant_method_error_function_mapping = {
|
515
515
|
qc.QuantizationErrorMethod.MSE: lambda q_bins, q_count, bins, counts, threshold, _range:
|
516
516
|
_mse_error_histogram(q_bins, q_count, bins, counts),
|
517
|
+
qc.QuantizationErrorMethod.HMSE: lambda q_bins, q_count, bins, counts, threshold, _range:
|
518
|
+
_mse_error_histogram(q_bins, q_count, bins, counts), #HMSE need the same functionality as MSE
|
517
519
|
qc.QuantizationErrorMethod.MAE: lambda q_bins, q_count, bins, counts, threshold, _range:
|
518
520
|
_mae_error_histogram(q_bins, q_count, bins, counts),
|
519
521
|
qc.QuantizationErrorMethod.LP: lambda q_bins, q_count, bins, counts, threshold, _range:
|
@@ -13,66 +13,143 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import numpy as np
|
16
|
-
from typing import Dict, Union
|
16
|
+
from typing import Dict, Union, Optional, Tuple
|
17
17
|
|
18
18
|
from mct_quantizers import QuantizationMethod
|
19
|
+
from model_compression_toolkit.core import QuantizationErrorMethod
|
19
20
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
|
20
21
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
21
22
|
from model_compression_toolkit.core.common.quantization import quantization_params_generation
|
22
23
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
23
24
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
|
24
25
|
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
26
|
+
def get_histogram_data(
|
27
|
+
activation_quant_cfg: NodeActivationQuantizationConfig,
|
28
|
+
out_stats_container: BaseStatsCollector
|
29
|
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
|
29
30
|
"""
|
30
|
-
|
31
|
+
Extract and filter the histogram data from the statistics container.
|
31
32
|
|
32
33
|
Args:
|
33
|
-
activation_quant_cfg:
|
34
|
-
|
35
|
-
out_stats_container: Tensor containing output statistics of the node.
|
34
|
+
activation_quant_cfg: Node's activation quantization configuration.
|
35
|
+
out_stats_container: Statistics container with histogram data.
|
36
36
|
|
37
37
|
Returns:
|
38
|
-
|
38
|
+
A tuple containing the filtered bins_values and bins_counts.
|
39
39
|
"""
|
40
|
-
|
41
40
|
bins_values, bins_counts = None, None
|
42
41
|
|
43
42
|
# If the statistics container collected the histogram, we start by filtering outliers using z threshold
|
44
43
|
# filtering, and then computing the threshold based on the filtered histogram.
|
45
44
|
if out_stats_container.require_collection():
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
45
|
+
if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
|
46
|
+
bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
|
47
|
+
else:
|
48
|
+
bins_values, bins_counts = out_stats_container.hc.get_histogram()
|
49
|
+
bins_counts = quantization_params_generation.z_score_filter(
|
50
|
+
activation_quant_cfg.z_threshold,
|
51
|
+
bins_values,
|
52
|
+
bins_counts
|
53
|
+
)
|
54
|
+
return bins_values, bins_counts
|
55
|
+
|
56
|
+
def determine_signedness(
|
57
|
+
activation_quant_cfg: NodeActivationQuantizationConfig,
|
58
|
+
nodes_prior_info: NodePriorInfo,
|
59
|
+
min_value: float,
|
60
|
+
bins_values: Optional[np.ndarray],
|
61
|
+
bins_counts: Optional[np.ndarray]
|
62
|
+
) -> bool:
|
63
|
+
"""
|
64
|
+
Determine if the activations should be considered signed based on the quantization configuration,
|
65
|
+
node prior information, and histogram statistics.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
activation_quant_cfg: Node's activation quantization configuration.
|
69
|
+
nodes_prior_info: Prior info collected for the node that is being quantized.
|
70
|
+
min_value: Minimum value from the statistics container.
|
71
|
+
bins_values: Numpy array of histogram bin values.
|
72
|
+
bins_counts: Numpy array of histogram bin counts.
|
51
73
|
|
74
|
+
Returns:
|
75
|
+
A boolean indicating if the activations are signed.
|
76
|
+
"""
|
52
77
|
if activation_quant_cfg.signedness in [Signedness.SIGNED, Signedness.UNSIGNED]:
|
53
|
-
|
54
|
-
elif nodes_prior_info.is_output_bounded():
|
55
|
-
signed = min_value < 0
|
56
|
-
else:
|
57
|
-
signed = np.any(bins_values[:-1][bins_counts > 0] < 0)
|
78
|
+
return activation_quant_cfg.signedness == Signedness.SIGNED
|
58
79
|
|
80
|
+
if nodes_prior_info.is_output_bounded():
|
81
|
+
return min_value < 0
|
82
|
+
|
83
|
+
return np.any(bins_values[:-1][bins_counts > 0] < 0)
|
84
|
+
|
85
|
+
|
86
|
+
def update_activation_quantization_params_fn(
|
87
|
+
activation_quant_cfg: NodeActivationQuantizationConfig,
|
88
|
+
nodes_prior_info: NodePriorInfo):
|
89
|
+
"""
|
90
|
+
Update the activation quantization parameters function based on the quantization method
|
91
|
+
and whether the node's output is bounded.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
activation_quant_cfg: Node's activation quantization configuration.
|
95
|
+
nodes_prior_info: Prior info collected for the node that is being quantized.
|
96
|
+
"""
|
59
97
|
if nodes_prior_info.is_output_bounded():
|
60
98
|
if activation_quant_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
|
61
|
-
activation_quant_cfg.
|
99
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
62
100
|
quantization_params_generation.power_of_two_no_clipping_selection_min_max
|
101
|
+
)
|
63
102
|
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC:
|
64
|
-
activation_quant_cfg.
|
103
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
65
104
|
quantization_params_generation.symmetric_no_clipping_selection_min_max
|
105
|
+
)
|
66
106
|
elif activation_quant_cfg.activation_quantization_method == QuantizationMethod.UNIFORM:
|
67
|
-
activation_quant_cfg.
|
107
|
+
activation_quant_cfg.set_activation_quantization_params_fn(
|
68
108
|
quantization_params_generation.uniform_no_clipping_selection_min_max
|
109
|
+
)
|
110
|
+
|
111
|
+
|
112
|
+
def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
|
113
|
+
nodes_prior_info: NodePriorInfo,
|
114
|
+
out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
|
115
|
+
"""
|
116
|
+
Compute the activations params for a given node in a graph according to a params function.
|
117
|
+
|
118
|
+
Args:
|
119
|
+
activation_quant_cfg: node's activation quantization configuration.
|
120
|
+
nodes_prior_info: Prior info collected for the node that is being quantized.
|
121
|
+
out_stats_container: Tensor containing output statistics of the node.
|
122
|
+
|
123
|
+
Returns:
|
124
|
+
The computed activation quantization params.
|
125
|
+
"""
|
126
|
+
# Update quantization parameters function based on output bounds and quantization method.
|
127
|
+
update_activation_quantization_params_fn(activation_quant_cfg, nodes_prior_info)
|
128
|
+
|
129
|
+
# Extract and filter histogram data from the statistics container.
|
130
|
+
bins_values, bins_counts = get_histogram_data(activation_quant_cfg, out_stats_container)
|
131
|
+
|
132
|
+
# Retrieve the minimum and maximum values from the statistics container.
|
133
|
+
min_value, max_value = out_stats_container.get_min_max_values()
|
134
|
+
|
135
|
+
# Determine if the activations should be considered signed.
|
136
|
+
signed = determine_signedness(
|
137
|
+
activation_quant_cfg,
|
138
|
+
nodes_prior_info,
|
139
|
+
min_value,
|
140
|
+
bins_values,
|
141
|
+
bins_counts
|
142
|
+
)
|
69
143
|
|
70
|
-
return
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
144
|
+
# Compute and return the activation quantization parameters.
|
145
|
+
return activation_quant_cfg.activation_quantization_params_fn(
|
146
|
+
bins_values,
|
147
|
+
bins_counts,
|
148
|
+
activation_quant_cfg.l_p_value,
|
149
|
+
activation_quant_cfg.activation_n_bits,
|
150
|
+
min_value,
|
151
|
+
max_value,
|
152
|
+
min_threshold=activation_quant_cfg.min_threshold,
|
153
|
+
quant_error_method=activation_quant_cfg.activation_error_method,
|
154
|
+
is_signed=signed
|
155
|
+
)
|
@@ -188,18 +188,30 @@ class KerasImplementation(FrameworkImplementation):
|
|
188
188
|
|
189
189
|
def run_model_inference(self,
|
190
190
|
model: Any,
|
191
|
-
input_list: List[Any]
|
191
|
+
input_list: List[Any],
|
192
|
+
requires_grad: bool = False) -> Tuple[tf.Tensor]:
|
192
193
|
"""
|
193
|
-
|
194
|
+
Runs inference on the given Keras model with the provided inputs.
|
195
|
+
|
196
|
+
This method executes the model on the given input data. If `requires_grad` is set to `False`,
|
197
|
+
gradients will not be computed during inference by wrapping execution in a `tf.stop_gradient()` context.
|
194
198
|
|
195
199
|
Args:
|
196
|
-
model: Keras model.
|
197
|
-
input_list:
|
200
|
+
model: The Keras model to execute.
|
201
|
+
input_list: A list of inputs for the model.
|
202
|
+
requires_grad: If False, prevents gradient computation (default: False).
|
198
203
|
|
199
204
|
Returns:
|
200
|
-
|
201
|
-
"""
|
202
|
-
|
205
|
+
A tuple containing the model's output tensors.
|
206
|
+
"""
|
207
|
+
# Prevent gradient computation if requires_grad is False
|
208
|
+
if requires_grad:
|
209
|
+
# Record operations for automatic differentiation
|
210
|
+
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
|
211
|
+
g.watch(input_list)
|
212
|
+
return model(input_list)
|
213
|
+
else:
|
214
|
+
return model(input_list)
|
203
215
|
|
204
216
|
def shift_negative_correction(self,
|
205
217
|
graph: Graph,
|
@@ -553,28 +565,24 @@ class KerasImplementation(FrameworkImplementation):
|
|
553
565
|
|
554
566
|
Returns: The MAC count og the operation
|
555
567
|
"""
|
556
|
-
|
557
|
-
|
558
|
-
kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
|
559
|
-
output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
560
|
-
|
561
|
-
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose):
|
562
|
-
# (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
|
563
|
-
return np.prod([x for x in output_shape if x is not None]) * \
|
564
|
-
kernel_shape[input_channel_axis] * \
|
565
|
-
(kernel_shape[0] * kernel_shape[1])
|
566
|
-
elif node.is_match_type(DepthwiseConv2D):
|
567
|
-
# Depth * (W_out * H_out) * C_in * (W_kernel * H_kernel)
|
568
|
-
return node.framework_attr.get(DEPTH_MULTIPLIER) * \
|
569
|
-
np.prod([x for x in output_shape if x is not None]) / output_shape[output_channel_axis] * \
|
570
|
-
kernel_shape[input_channel_axis] * \
|
571
|
-
(kernel_shape[0] * kernel_shape[1])
|
572
|
-
elif node.is_match_type(Dense):
|
573
|
-
# IN * OUT
|
574
|
-
return kernel_shape[0] * kernel_shape[1]
|
575
|
-
else:
|
568
|
+
kernels = fw_info.get_kernel_op_attributes(node.type)
|
569
|
+
if not kernels or kernels[0] is None:
|
576
570
|
return 0
|
577
571
|
|
572
|
+
assert len(kernels) == 1
|
573
|
+
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
|
574
|
+
|
575
|
+
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
|
576
|
+
h, w = node.get_output_shapes_list()[0][-3:-1]
|
577
|
+
return np.prod(kernel_shape) * h * w
|
578
|
+
|
579
|
+
if node.is_match_type(Dense):
|
580
|
+
# IN * OUT * (all previous dims[:-1])
|
581
|
+
_, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
|
582
|
+
return node.get_total_output_params() * kernel_shape[input_channel_axis]
|
583
|
+
|
584
|
+
return 0
|
585
|
+
|
578
586
|
def apply_second_moment_correction(self,
|
579
587
|
quantized_model: Any,
|
580
588
|
core_config: CoreConfig,
|
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py
CHANGED
@@ -72,19 +72,21 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
72
72
|
input_tensor.requires_grad_()
|
73
73
|
input_tensor.retain_grad()
|
74
74
|
|
75
|
-
|
75
|
+
model_output_tensors = model(*self.input_images)
|
76
76
|
|
77
|
-
if len(
|
77
|
+
if len(model_output_tensors) != len(grad_model_outputs): # pragma: no cover
|
78
78
|
Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
|
79
|
-
f"Expected {len(grad_model_outputs)} outputs, received {len(
|
79
|
+
f"Expected {len(grad_model_outputs)} outputs, received {len(model_output_tensors)}.")
|
80
|
+
return model_output_tensors
|
80
81
|
|
82
|
+
def _prep_tensors_for_compute(self, model_output_tensors):
|
81
83
|
# Extracting the intermediate activation tensors and the model real output.
|
82
84
|
# Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
|
83
85
|
num_target_nodes = len(self.hessian_request.target_nodes)
|
84
86
|
# Extract activation tensors of nodes for which we want to compute Hessian
|
85
|
-
target_activation_tensors =
|
87
|
+
target_activation_tensors = model_output_tensors[:num_target_nodes]
|
86
88
|
# Extract the model outputs
|
87
|
-
output_tensors =
|
89
|
+
output_tensors = model_output_tensors[num_target_nodes:]
|
88
90
|
|
89
91
|
# Concat outputs
|
90
92
|
# First, we need to unfold all outputs that are given as list, to extract the actual output tensors
|
@@ -98,79 +100,39 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
98
100
|
Returns:
|
99
101
|
List[np.ndarray]: Scores based on the approximated Hessian for the requested nodes.
|
100
102
|
"""
|
101
|
-
|
102
|
-
|
103
|
-
if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
|
104
|
-
hessian_scores = self._compute_per_tensor(output, target_activation_tensors)
|
105
|
-
elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
|
106
|
-
hessian_scores = self._compute_per_channel(output, target_activation_tensors)
|
103
|
+
if self.hessian_request.compute_from_tensors:
|
104
|
+
model_output_tensors = self.input_images
|
107
105
|
else:
|
108
|
-
|
109
|
-
|
110
|
-
# Convert results to list of numpy arrays
|
111
|
-
hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores]
|
112
|
-
return hessian_results
|
106
|
+
model_output_tensors = self.forward_pass()
|
107
|
+
output, target_activation_tensors = self._prep_tensors_for_compute(model_output_tensors)
|
113
108
|
|
114
|
-
|
115
|
-
assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR
|
116
|
-
ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device)
|
109
|
+
ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
|
117
110
|
for _ in range(len(target_activation_tensors))]
|
118
|
-
|
119
|
-
for j in tqdm(range(self.num_iterations_for_approximation),
|
120
|
-
|
111
|
+
|
112
|
+
for j in tqdm(range(self.num_iterations_for_approximation),
|
113
|
+
"Hessian random iterations"): # Approximation iterations
|
121
114
|
v = self._generate_random_vectors_batch(output.shape, output.device)
|
122
115
|
f_v = torch.sum(v * output)
|
123
116
|
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
|
124
|
-
# Computing the hessian-approximation scores by getting the gradient of (output * v)
|
125
117
|
hess_v = autograd.grad(outputs=f_v,
|
126
118
|
inputs=ipt_tensor,
|
127
119
|
retain_graph=True,
|
128
120
|
allow_unused=True)[0]
|
129
|
-
|
130
121
|
if hess_v is None:
|
131
122
|
# In case we have an output node, which is an interest point, but it is not differentiable,
|
132
123
|
# we consider its Hessian to be the initial value 0.
|
133
124
|
continue # pragma: no cover
|
134
125
|
|
135
|
-
# Mean over all dims but the batch (CXHXW for conv)
|
136
|
-
hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
|
137
|
-
|
138
|
-
# Update node Hessian approximation mean over random iterations
|
139
|
-
ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
|
140
|
-
|
141
|
-
# If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
|
142
|
-
if j > MIN_HESSIAN_ITER:
|
143
|
-
if prev_mean_results is not None:
|
144
|
-
new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
|
145
|
-
relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
|
146
|
-
(torch.abs(new_mean_res) + 1e-6))
|
147
|
-
max_delta = torch.max(relative_delta_per_node)
|
148
|
-
if max_delta < HESSIAN_COMP_TOLERANCE:
|
149
|
-
break
|
150
|
-
prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
|
151
|
-
|
152
|
-
# add extra dimension to preserve previous behaviour
|
153
|
-
ipts_hessian_approx_scores = [torch.unsqueeze(t, -1) for t in ipts_hessian_approx_scores]
|
154
|
-
return ipts_hessian_approx_scores
|
155
|
-
|
156
|
-
def _compute_per_channel(self, output, target_activation_tensors):
|
157
|
-
assert self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL
|
158
|
-
ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
|
159
|
-
for _ in range(len(target_activation_tensors))]
|
160
|
-
|
161
|
-
for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
|
162
|
-
v = self._generate_random_vectors_batch(output.shape, output.device)
|
163
|
-
f_v = torch.sum(v * output)
|
164
|
-
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
|
165
|
-
hess_v = autograd.grad(outputs=f_v,
|
166
|
-
inputs=ipt_tensor,
|
167
|
-
retain_graph=True)[0]
|
168
126
|
hessian_approx_scores = hess_v ** 2
|
169
|
-
|
170
|
-
if
|
171
|
-
hessian_approx_scores = torch.
|
127
|
+
num_dims = len(hess_v.shape)
|
128
|
+
if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
|
129
|
+
hessian_approx_scores = torch.sum(hessian_approx_scores, dim=tuple(range(1,num_dims))).unsqueeze(-1)
|
130
|
+
elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL and num_dims > 2:
|
131
|
+
hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, num_dims)))
|
172
132
|
|
173
133
|
# Update node Hessian approximation mean over random iterations
|
174
134
|
ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
|
175
135
|
|
176
|
-
|
136
|
+
# Convert results to list of numpy arrays
|
137
|
+
hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores]
|
138
|
+
return hessian_results
|