mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
@@ -32,8 +32,7 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
32
32
|
tb_w: TensorboardWriter,
|
33
33
|
float_graph: Graph,
|
34
34
|
quantized_graph: Graph,
|
35
|
-
fw_impl: FrameworkImplementation
|
36
|
-
fw_info: FrameworkInfo):
|
35
|
+
fw_impl: FrameworkImplementation):
|
37
36
|
"""
|
38
37
|
Plot the cosine similarity of different points on the graph between the float and quantized
|
39
38
|
graphs. Add them to the passed TensorboardWriter object and close all tensorboard writer open
|
@@ -45,14 +44,12 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
45
44
|
float_graph: Graph of float model.
|
46
45
|
quantized_graph: Graph of quantized model.
|
47
46
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
48
|
-
fw_info: Information needed for quantization about the specific framework.
|
49
47
|
|
50
48
|
"""
|
51
49
|
if tb_w is not None:
|
52
50
|
visual = NNVisualizer(float_graph,
|
53
51
|
quantized_graph,
|
54
|
-
fw_impl=fw_impl
|
55
|
-
fw_info=fw_info)
|
52
|
+
fw_impl=fw_impl)
|
56
53
|
if not visual.has_compare_points():
|
57
54
|
Logger.error(f'No comparing points were found to plot analyze similarity.')
|
58
55
|
else:
|
@@ -28,20 +28,17 @@ class BaseModelBuilder(ABC):
|
|
28
28
|
def __init__(self,
|
29
29
|
graph: common.Graph,
|
30
30
|
append2output=None,
|
31
|
-
fw_info: FrameworkInfo = None,
|
32
31
|
return_float_outputs: bool = False):
|
33
32
|
"""
|
34
33
|
|
35
34
|
Args:
|
36
35
|
graph: Graph to build the model from.
|
37
36
|
append2output: Nodes of graph to append to model's output.
|
38
|
-
fw_info: Information about the specific framework of the model that is built.
|
39
37
|
return_float_outputs: Whether the model returns float tensors or not.
|
40
38
|
"""
|
41
39
|
|
42
40
|
self.graph = graph
|
43
41
|
self.append2output = append2output
|
44
|
-
self.fw_info = fw_info
|
45
42
|
self.return_float_outputs = return_float_outputs
|
46
43
|
|
47
44
|
@abstractmethod
|
@@ -125,18 +125,16 @@ class FrameworkImplementation(ABC):
|
|
125
125
|
graph: Graph,
|
126
126
|
mode: ModelBuilderMode,
|
127
127
|
append2output: List[Any],
|
128
|
-
fw_info: FrameworkInfo,
|
129
128
|
return_float_outputs: bool = False) -> Tuple:
|
130
129
|
"""
|
131
130
|
Build a framework model from a graph.
|
132
|
-
The mode determines how the model should be
|
131
|
+
The mode determines how the model should be built. append2output is a list of Nodes
|
133
132
|
to set as the model outputs.
|
134
133
|
|
135
134
|
Args:
|
136
135
|
graph: Graph to build the model from it.
|
137
136
|
mode: Mode for how to build the model.
|
138
137
|
append2output: List of Nodes to set as the model's outputs.
|
139
|
-
fw_info: FrameworkInfo object with information about the specific framework's model
|
140
138
|
return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
|
141
139
|
|
142
140
|
Returns:
|
@@ -170,15 +168,13 @@ class FrameworkImplementation(ABC):
|
|
170
168
|
@abstractmethod
|
171
169
|
def shift_negative_correction(self,
|
172
170
|
graph: Graph,
|
173
|
-
core_config: CoreConfig
|
174
|
-
fw_info: FrameworkInfo) -> Graph:
|
171
|
+
core_config: CoreConfig) -> Graph:
|
175
172
|
"""
|
176
173
|
Apply shift negative correction (SNC) on a graph.
|
177
174
|
|
178
175
|
Args:
|
179
176
|
graph: Graph to apply SNC on.
|
180
177
|
core_config: Quantization configuration.
|
181
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
182
178
|
|
183
179
|
Returns:
|
184
180
|
Graph after SNC.
|
@@ -189,15 +185,13 @@ class FrameworkImplementation(ABC):
|
|
189
185
|
@abstractmethod
|
190
186
|
def compute_activation_bias_correction(self,
|
191
187
|
graph: Graph,
|
192
|
-
quant_config: QuantizationConfig
|
193
|
-
fw_info: FrameworkInfo) -> Graph:
|
188
|
+
quant_config: QuantizationConfig) -> Graph:
|
194
189
|
"""
|
195
190
|
Compute activation bias correction on a graph.
|
196
191
|
|
197
192
|
Args:
|
198
193
|
graph: Graph to apply activation bias correction on.
|
199
194
|
quant_config: QuantizationConfig of how the model should be quantized.
|
200
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
201
195
|
|
202
196
|
Returns:
|
203
197
|
Graph after activation bias correction computing.
|
@@ -207,30 +201,28 @@ class FrameworkImplementation(ABC):
|
|
207
201
|
|
208
202
|
@abstractmethod
|
209
203
|
def get_substitutions_channel_equalization(self,
|
210
|
-
quant_config: QuantizationConfig
|
211
|
-
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
204
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
212
205
|
"""
|
213
206
|
Return a list of the framework substitutions used for channel equalization.
|
214
207
|
|
215
208
|
Args:
|
216
209
|
quant_config: QuantizationConfig to determine which substitutions to return.
|
217
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
218
210
|
|
219
211
|
Returns:
|
220
212
|
A list of the framework substitutions used after we collect statistics.
|
221
213
|
"""
|
222
214
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
223
|
-
|
215
|
+
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover
|
224
216
|
|
225
217
|
@abstractmethod
|
226
|
-
def get_substitutions_prepare_graph(self
|
218
|
+
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
|
227
219
|
"""
|
228
220
|
|
229
221
|
Returns: A list of the framework substitutions used to prepare the graph.
|
230
222
|
|
231
223
|
"""
|
232
224
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
233
|
-
|
225
|
+
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover
|
234
226
|
|
235
227
|
@abstractmethod
|
236
228
|
def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
|
@@ -328,14 +320,12 @@ class FrameworkImplementation(ABC):
|
|
328
320
|
f'method.') # pragma: no cover
|
329
321
|
|
330
322
|
def get_node_prior_info(self, node: BaseNode,
|
331
|
-
fw_info: FrameworkInfo,
|
332
323
|
graph: Graph) -> NodePriorInfo:
|
333
324
|
"""
|
334
325
|
Get a NodePriorInfo object for a node.
|
335
326
|
|
336
327
|
Args:
|
337
328
|
node: Node to get its prior info.
|
338
|
-
fw_info: Framework specific information needed to create the prior info of the node.
|
339
329
|
graph: Graph to check the next node type.
|
340
330
|
|
341
331
|
Returns:
|
@@ -343,7 +333,7 @@ class FrameworkImplementation(ABC):
|
|
343
333
|
"""
|
344
334
|
|
345
335
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
346
|
-
|
336
|
+
f'framework\'s get_node_prior_info method.') # pragma: no cover
|
347
337
|
|
348
338
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
349
339
|
"""
|
@@ -394,20 +384,18 @@ class FrameworkImplementation(ABC):
|
|
394
384
|
|
395
385
|
@abstractmethod
|
396
386
|
def get_node_mac_operations(self,
|
397
|
-
node: BaseNode
|
398
|
-
fw_info: FrameworkInfo) -> float:
|
387
|
+
node: BaseNode) -> float:
|
399
388
|
"""
|
400
389
|
Gets the MAC operation count for a given operation.
|
401
390
|
|
402
391
|
Args:
|
403
392
|
node: A graph node that wraps the operation for which the MAC count is computed.
|
404
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
405
393
|
|
406
394
|
Returns: The MAC count of the operation
|
407
395
|
"""
|
408
396
|
|
409
397
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
410
|
-
|
398
|
+
f'framework\'s get_node_mac_operations method.') # pragma: no cover
|
411
399
|
|
412
400
|
@abstractmethod
|
413
401
|
def apply_second_moment_correction(self,
|
@@ -16,16 +16,16 @@
|
|
16
16
|
|
17
17
|
from collections.abc import Callable
|
18
18
|
from enum import Enum
|
19
|
-
from typing import Dict, Any,
|
19
|
+
from typing import Dict, Any, Tuple, NamedTuple
|
20
|
+
from abc import ABC, abstractmethod
|
20
21
|
|
21
22
|
from mct_quantizers import QuantizationMethod
|
22
|
-
from model_compression_toolkit.defaultdict import DefaultDict
|
23
23
|
|
24
24
|
|
25
25
|
# Default value to use for ops without kernel.
|
26
26
|
# This is a weird default, but it's used all over the place, so for now only extract it to const so that it can be
|
27
27
|
# referenced by variable instead of hard-coded.
|
28
|
-
|
28
|
+
DEFAULT_KERNEL_ATTRIBUTE = None
|
29
29
|
|
30
30
|
|
31
31
|
class ChannelAxis(Enum):
|
@@ -42,89 +42,83 @@ class ChannelAxis(Enum):
|
|
42
42
|
NCHW = 1
|
43
43
|
|
44
44
|
|
45
|
-
class
|
45
|
+
class ChannelAxisMapping(NamedTuple):
|
46
|
+
output: int
|
47
|
+
input: int
|
46
48
|
|
47
|
-
def __init__(self,
|
48
|
-
activation_quantizer_mapping: Dict[QuantizationMethod, Callable],
|
49
|
-
kernel_channels_mapping: DefaultDict,
|
50
|
-
activation_min_max_mapping: Dict[str, tuple],
|
51
|
-
layer_min_max_mapping: Dict[Any, tuple],
|
52
|
-
kernel_ops_attributes_mapping: DefaultDict,
|
53
|
-
out_channel_axis_mapping: DefaultDict):
|
54
|
-
"""
|
55
|
-
A class to wrap all information about a specific framework the library needs to quantize a model.
|
56
|
-
Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
|
57
|
-
layer to it kernel channels indices, and a layer to its min/max values, etc.
|
58
|
-
The layers lists are divided into three groups:
|
59
|
-
kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
|
60
|
-
activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
|
61
|
-
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
|
62
|
-
|
63
|
-
Args:
|
64
|
-
activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
|
65
|
-
kernel_channels_mapping (DefaultDict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
|
66
|
-
activation_min_max_mapping (Dict[str, tuple]): Dictionary from an activation function to its min/max output values.
|
67
|
-
layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
|
68
|
-
kernel_ops_attributes_mapping (DefaultDict): Dictionary from a framework operator to a list of its weights attirbutes to quantize.
|
69
|
-
out_channel_axis_mapping (DefaultDict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
|
70
|
-
|
71
|
-
Examples:
|
72
|
-
When quantizing a Keras model, if we want to quantize the kernels of Conv2D layers only, we can
|
73
|
-
set, and we know it's kernel out/in channel indices are (3, 2) respectivly:
|
74
|
-
|
75
|
-
>>> import tensorflow as tf
|
76
|
-
>>> kernel_ops = [tf.keras.layers.Conv2D]
|
77
|
-
>>> kernel_channels_mapping = DefaultDict({tf.keras.layers.Conv2D: (3,2)})
|
78
49
|
|
79
|
-
|
50
|
+
class FrameworkInfo(ABC):
|
51
|
+
"""
|
52
|
+
A class to wrap all information about a specific framework the library needs to quantize a model.
|
53
|
+
Specifically, FrameworkInfo holds lists of layers by how they should be quantized, and multiple mappings such as
|
54
|
+
layer to it kernel channels indices, and a layer to its min/max values, etc.
|
55
|
+
The layers lists are divided into three groups:
|
56
|
+
kernel_ops: Layers that have coefficients and need to get quantized (e.g., Conv2D, Dense, etc.)
|
57
|
+
activation_ops: Layers that their outputs should get quantized (e.g., Add, ReLU, etc.)
|
58
|
+
no_quantization_ops:Layers that should not get quantized (e.g., Reshape, Transpose, etc.)
|
59
|
+
|
60
|
+
Fields:
|
61
|
+
activation_quantizer_mapping (Dict[QuantizationMethod, Callable]): A dictionary mapping from QuantizationMethod to a quantization function.
|
62
|
+
kernel_channels_mapping (Dict): Dictionary from a layer to a tuple of its kernel in/out channels indices.
|
63
|
+
kernel_ops_attribute_mapping (Dict): Dictionary from a framework operator to its weight attribute to quantize.
|
64
|
+
out_channel_axis_mapping (Dict): Dictionary of output channels of the model's layers (for computing statistics per-channel).
|
65
|
+
_layer_min_max_mapping (Dict[Any, tuple]): Dictionary from a layer to its min/max output values.
|
80
66
|
|
81
|
-
|
67
|
+
"""
|
82
68
|
|
83
|
-
|
69
|
+
activation_quantizer_mapping: Dict[QuantizationMethod, Callable]
|
70
|
+
kernel_channels_mapping: Dict[Any, ChannelAxisMapping]
|
71
|
+
kernel_ops_attribute_mapping: Dict[Any, str]
|
72
|
+
out_channel_axis_mapping: Dict[Any, int]
|
73
|
+
_layer_min_max_mapping: Dict[Any, tuple]
|
84
74
|
|
85
|
-
|
86
|
-
>>> FrameworkInfo(kernel_channels_mapping, activation_min_max_mapping, {})
|
75
|
+
_default_channel_mapping = ChannelAxisMapping(None, None)
|
87
76
|
|
88
|
-
|
77
|
+
@classmethod
|
78
|
+
def get_kernel_op_attribute(cls, node_type: Any) -> str:
|
79
|
+
"""
|
80
|
+
Get attribute of a layer's weight to quantize.
|
89
81
|
|
90
|
-
|
91
|
-
|
82
|
+
Args:
|
83
|
+
node_type: Layer to get its attribute.
|
92
84
|
|
85
|
+
Returns:
|
86
|
+
Attribute the layer has and should be quantized.
|
93
87
|
"""
|
88
|
+
return cls.kernel_ops_attribute_mapping.get(node_type, DEFAULT_KERNEL_ATTRIBUTE)
|
94
89
|
|
95
|
-
|
96
|
-
|
97
|
-
self.activation_min_max_mapping = activation_min_max_mapping
|
98
|
-
self.layer_min_max_mapping = layer_min_max_mapping
|
99
|
-
self.kernel_ops_attributes_mapping = kernel_ops_attributes_mapping
|
100
|
-
self.out_channel_axis_mapping = out_channel_axis_mapping
|
101
|
-
|
102
|
-
def get_kernel_op_attributes(self, node_type: Any) -> List[str]:
|
90
|
+
@classmethod
|
91
|
+
def is_kernel_op(cls, node_type: Any) -> bool:
|
103
92
|
"""
|
104
|
-
|
93
|
+
Check is the node is a kernel operation.
|
105
94
|
|
106
95
|
Args:
|
107
96
|
node_type: Layer to get its attributes.
|
108
97
|
|
109
98
|
Returns:
|
110
|
-
|
99
|
+
True if node type is a kernel operation, else False.
|
111
100
|
"""
|
112
|
-
|
113
|
-
return attr_list
|
101
|
+
return node_type in cls.kernel_ops_attribute_mapping
|
114
102
|
|
115
|
-
|
103
|
+
@classmethod
|
104
|
+
def get_layer_min_max(cls, layer: Any, fw_attrs: Dict) -> Tuple[float, float]:
|
116
105
|
"""
|
117
|
-
|
118
|
-
|
106
|
+
Return layer min/max mapping the FrameworkInfo holds.
|
119
107
|
Args:
|
120
|
-
|
108
|
+
layer: A layer to check if has a min/max known values.
|
109
|
+
fw_attrs: framework attributes from framework layer.
|
121
110
|
|
122
111
|
Returns:
|
123
|
-
|
112
|
+
Layer's min/max known values.
|
124
113
|
"""
|
125
|
-
return node_type in self.kernel_ops_attributes_mapping.keys()
|
126
114
|
|
127
|
-
|
115
|
+
if cls.layers_has_min_max(layer):
|
116
|
+
return cls._layer_min_max_mapping[layer]
|
117
|
+
else:
|
118
|
+
return None, None
|
119
|
+
|
120
|
+
@classmethod
|
121
|
+
def layers_has_min_max(cls, layer: Any) -> bool:
|
128
122
|
"""
|
129
123
|
Check if a layer is in a layer to min/max mapping the FrameworkInfo holds.
|
130
124
|
Args:
|
@@ -134,17 +128,60 @@ class FrameworkInfo:
|
|
134
128
|
Whether a layer has a min/max known values or not.
|
135
129
|
"""
|
136
130
|
|
137
|
-
return layer in
|
131
|
+
return layer in cls._layer_min_max_mapping
|
138
132
|
|
139
|
-
|
133
|
+
@classmethod
|
134
|
+
@abstractmethod
|
135
|
+
def get_kernel_channels(cls, node_type: Any) -> ChannelAxisMapping:
|
140
136
|
"""
|
141
|
-
|
137
|
+
Returns node's channels mapping from kernel_channels_mapping or framework specific default value.
|
138
|
+
Args:
|
139
|
+
node_type: A node type
|
142
140
|
|
141
|
+
Returns:
|
142
|
+
Node's channels mapping.
|
143
|
+
"""
|
144
|
+
pass
|
145
|
+
|
146
|
+
@classmethod
|
147
|
+
@abstractmethod
|
148
|
+
def get_out_channel_axis(cls, node_type: Any):
|
149
|
+
"""
|
150
|
+
Returns node's output channel mapping from out_channel_axis_mapping or framework specific default value.
|
143
151
|
Args:
|
144
|
-
|
152
|
+
node_type: A node type.
|
145
153
|
|
146
154
|
Returns:
|
147
|
-
|
155
|
+
Node's output channel axis.
|
156
|
+
|
148
157
|
"""
|
158
|
+
pass
|
159
|
+
|
160
|
+
|
161
|
+
# Pointer to current FrameworkInfo class.
|
162
|
+
_current_framework_info: type[FrameworkInfo] = None
|
163
|
+
|
164
|
+
|
165
|
+
def get_fw_info():
|
166
|
+
"""
|
167
|
+
A common function to get the current FrameworkInfo class. Raises an error if the pointer wasn't initialized.
|
168
|
+
|
169
|
+
Returns: FrameworkInfo class.
|
170
|
+
"""
|
171
|
+
assert _current_framework_info is not None, "fw_info isn't initialized."
|
172
|
+
assert issubclass(_current_framework_info, FrameworkInfo), "fw_info isn't initialized to a FrameworkInfo class."
|
173
|
+
return _current_framework_info
|
174
|
+
|
175
|
+
|
176
|
+
def set_fw_info(fw_info: type[FrameworkInfo]):
|
177
|
+
"""
|
178
|
+
A common function to set the current FrameworkInfo class. Raises an error if fw_info doesn't inherit from FrameworkInfo.
|
179
|
+
|
180
|
+
Args:
|
181
|
+
fw_info: Framework specific object implementing the FrameworkInfo.
|
182
|
+
"""
|
183
|
+
global _current_framework_info
|
184
|
+
assert _current_framework_info in [None, _current_framework_info], "FrameworkInfo already initialized."
|
185
|
+
assert issubclass(fw_info, FrameworkInfo), "fw_info must inherit from FrameworkInfo."
|
149
186
|
|
150
|
-
|
187
|
+
_current_framework_info = fw_info
|
@@ -23,7 +23,6 @@ import numpy as np
|
|
23
23
|
|
24
24
|
from networkx.algorithms.dag import topological_sort
|
25
25
|
|
26
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
27
26
|
from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
|
28
27
|
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
|
29
28
|
from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
|
@@ -74,7 +73,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
74
73
|
input_nodes: List[BaseNode],
|
75
74
|
output_nodes: List[OutTensor],
|
76
75
|
edge_list: List[Edge],
|
77
|
-
fw_info: FrameworkInfo = None,
|
78
76
|
**attr):
|
79
77
|
"""
|
80
78
|
Args:
|
@@ -82,7 +80,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
82
80
|
input_nodes: List of input nodes the model
|
83
81
|
output_nodes: List of output nodes of the model to a list of their output indices.
|
84
82
|
edge_list: List of edges the graph has between nodes.
|
85
|
-
fw_info: FrameworkInfo object (needed for computing the graph's weights memory).
|
86
83
|
**attr: Attributes to add to graph as key=value pairs.
|
87
84
|
"""
|
88
85
|
|
@@ -103,7 +100,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
103
100
|
e.sink_node,
|
104
101
|
**e.get_attributes())
|
105
102
|
self.user_info = UserInformation()
|
106
|
-
self.fw_info = fw_info
|
107
103
|
|
108
104
|
@property
|
109
105
|
def skip_validation_check(self) -> bool:
|
@@ -124,16 +120,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
124
120
|
def fusing_info(self, fusing_info: FusingInfo):
|
125
121
|
self._fusing_info = fusing_info
|
126
122
|
|
127
|
-
def set_fw_info(self,
|
128
|
-
fw_info: FrameworkInfo):
|
129
|
-
"""
|
130
|
-
Set the graph's framework info.
|
131
|
-
Args:
|
132
|
-
fw_info: FrameworkInfo object.
|
133
|
-
"""
|
134
|
-
|
135
|
-
self.fw_info = fw_info
|
136
|
-
|
137
123
|
def set_fqc(self,
|
138
124
|
fqc: FrameworkQuantizationCapabilities):
|
139
125
|
"""
|
@@ -563,7 +549,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
563
549
|
return output_edges
|
564
550
|
|
565
551
|
def get_configurable_sorted_nodes_names(self,
|
566
|
-
fw_info: FrameworkInfo,
|
567
552
|
include_reused_nodes: bool = False) -> List[str]:
|
568
553
|
"""
|
569
554
|
Get a list of nodes' names that can be configured (namely, has one or
|
@@ -571,56 +556,49 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
571
556
|
order of the graph.
|
572
557
|
|
573
558
|
Args:
|
574
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
575
559
|
include_reused_nodes: Whether or not to include reused nodes (False by default).
|
576
560
|
|
577
561
|
Returns: List of nodes' names that can be configured (namely, has one or
|
578
562
|
more weight qc candidate) sorted topology.
|
579
563
|
|
580
564
|
"""
|
581
|
-
sorted_names = [n.name for n in self.get_configurable_sorted_nodes(
|
582
|
-
include_reused_nodes=include_reused_nodes)]
|
565
|
+
sorted_names = [n.name for n in self.get_configurable_sorted_nodes(include_reused_nodes=include_reused_nodes)]
|
583
566
|
return sorted_names
|
584
567
|
|
585
568
|
def get_weights_configurable_nodes(self,
|
586
|
-
fw_info: FrameworkInfo,
|
587
569
|
include_reused_nodes: bool = False) -> List[BaseNode]:
|
588
570
|
"""
|
589
571
|
Get a list of nodes that their weights can be configured (namely, has one or
|
590
572
|
more weight qc candidate and their weights should be quantized).
|
591
573
|
|
592
574
|
Args:
|
593
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
594
575
|
include_reused_nodes: Whether to include reused nodes (False by default).
|
595
576
|
|
596
577
|
Returns:
|
597
578
|
A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
|
598
579
|
"""
|
599
580
|
# configurability is only relevant for kernel attribute quantization
|
600
|
-
potential_conf_nodes = [n for n in list(self) if
|
581
|
+
potential_conf_nodes = [n for n in list(self) if n.is_kernel_op]
|
601
582
|
|
602
583
|
def is_configurable(n):
|
603
|
-
|
604
|
-
return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
|
584
|
+
return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes)
|
605
585
|
|
606
586
|
return [n for n in potential_conf_nodes if is_configurable(n)]
|
607
587
|
|
608
588
|
def get_sorted_weights_configurable_nodes(self,
|
609
|
-
fw_info: FrameworkInfo,
|
610
589
|
include_reused_nodes: bool = False) -> List[BaseNode]:
|
611
590
|
"""
|
612
591
|
Get a list of sorted nodes that their weights can be configured (namely, has one or
|
613
592
|
more weight qc candidate and their weights should be quantized).
|
614
593
|
|
615
594
|
Args:
|
616
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
617
595
|
include_reused_nodes: Whether to include reused nodes (False by default).
|
618
596
|
|
619
597
|
Returns:
|
620
598
|
A list of nodes that their weights can be configured (namely, has one or more weight qc candidate)
|
621
599
|
sorted topologically.
|
622
600
|
"""
|
623
|
-
return self._sort_nodes_in_list(self.get_weights_configurable_nodes(
|
601
|
+
return self._sort_nodes_in_list(self.get_weights_configurable_nodes(include_reused_nodes))
|
624
602
|
|
625
603
|
def get_activation_configurable_nodes(self) -> List[BaseNode]:
|
626
604
|
"""
|
@@ -644,7 +622,6 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
644
622
|
return self._sort_nodes_in_list(self.get_activation_configurable_nodes())
|
645
623
|
|
646
624
|
def get_configurable_sorted_nodes(self,
|
647
|
-
fw_info: FrameworkInfo,
|
648
625
|
include_reused_nodes: bool = False) -> List[BaseNode]:
|
649
626
|
"""
|
650
627
|
Get a list of nodes that can be configured (namely, has one or
|
@@ -652,14 +629,13 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
652
629
|
The nodes are sorted according to the topological order of the graph.
|
653
630
|
|
654
631
|
Args:
|
655
|
-
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
|
656
632
|
include_reused_nodes: Whether or not to include reused nodes (False by default).
|
657
633
|
|
658
634
|
Returns:
|
659
635
|
A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology.
|
660
636
|
|
661
637
|
"""
|
662
|
-
weights_configurable_nodes = self.get_weights_configurable_nodes(
|
638
|
+
weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes)
|
663
639
|
activation_configurable_nodes = self.get_activation_configurable_nodes()
|
664
640
|
|
665
641
|
# combine and remove duplications
|
@@ -684,7 +660,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
684
660
|
sorted_configurable_nodes.append(n)
|
685
661
|
return sorted_configurable_nodes
|
686
662
|
|
687
|
-
def get_min_candidates_config(self
|
663
|
+
def get_min_candidates_config(self) -> Dict[BaseNode, int]:
|
688
664
|
"""
|
689
665
|
Builds a minimal configuration.
|
690
666
|
Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
|
@@ -697,26 +673,23 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
697
673
|
Returns:
|
698
674
|
A dict from layer to an index of its minimal candidate.
|
699
675
|
"""
|
700
|
-
conf_sorted_nodes = self.get_configurable_sorted_nodes(
|
676
|
+
conf_sorted_nodes = self.get_configurable_sorted_nodes()
|
701
677
|
return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
|
702
678
|
|
703
|
-
def get_max_candidates_config(self
|
679
|
+
def get_max_candidates_config(self) -> Dict[BaseNode, int]:
|
704
680
|
"""
|
705
681
|
Builds a maximal configuration.
|
706
682
|
Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
|
707
683
|
with maximal n_bits (in both weight and activation if both are quantized, or in the relevant one if only
|
708
684
|
one of them is quantized)
|
709
685
|
|
710
|
-
Args:
|
711
|
-
fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
|
712
|
-
|
713
686
|
Returns:
|
714
687
|
A dict from layer to an index of its maximal candidate.
|
715
688
|
"""
|
716
|
-
conf_sorted_nodes = self.get_configurable_sorted_nodes(
|
689
|
+
conf_sorted_nodes = self.get_configurable_sorted_nodes()
|
717
690
|
return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
|
718
691
|
|
719
|
-
def get_final_weights_config(self
|
692
|
+
def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]:
|
720
693
|
"""
|
721
694
|
Gets the final number of bits for quantization of each weights' configurable layer.
|
722
695
|
|
@@ -726,9 +699,9 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
726
699
|
Returns: A list of pairs of (node type, node's weights quantization bitwidth).
|
727
700
|
|
728
701
|
"""
|
729
|
-
sorted_conf_weights = self.get_sorted_weights_configurable_nodes(
|
702
|
+
sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
|
730
703
|
# a configurable node by definition has a kernel op
|
731
|
-
return [(n, n.final_weights_quantization_cfg.get_attr_config(
|
704
|
+
return [(n, n.final_weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits)
|
732
705
|
for n in sorted_conf_weights]
|
733
706
|
|
734
707
|
def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
|
@@ -846,7 +819,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
846
819
|
next_node = self.out_edges(next_node)[0].sink_node
|
847
820
|
|
848
821
|
# If next_node is an exit node and has only one incoming edge, the topology is prunable.
|
849
|
-
if fw_impl.is_node_exit_node(next_node, entry_node
|
822
|
+
if fw_impl.is_node_exit_node(next_node, entry_node) and len(self.in_edges(next_node)) == 1:
|
850
823
|
return True
|
851
824
|
|
852
825
|
# If the next node is not an intermediate node or has more than one incoming/outgoing edge,
|
@@ -876,7 +849,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
876
849
|
|
877
850
|
intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)
|
878
851
|
|
879
|
-
if not fw_impl.is_node_exit_node(exit_node, entry_node
|
852
|
+
if not fw_impl.is_node_exit_node(exit_node, entry_node):
|
880
853
|
Logger.critical(f"Node {exit_node} is not a valid exit node for the pruning section starting with {entry_node}.") # pragma: no cover
|
881
854
|
|
882
855
|
return PruningSection(entry_node=entry_node,
|
@@ -897,7 +870,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
897
870
|
"""
|
898
871
|
intermediate_nodes = []
|
899
872
|
next_node = self.out_edges(entry_node)[0].sink_node
|
900
|
-
while not fw_impl.is_node_exit_node(next_node, entry_node
|
873
|
+
while not fw_impl.is_node_exit_node(next_node, entry_node):
|
901
874
|
intermediate_nodes.append(next_node)
|
902
875
|
next_node = self.out_edges(next_node)[0].sink_node
|
903
876
|
|