mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250927.534__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/METADATA +6 -3
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/RECORD +165 -159
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +5 -2
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
- model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
- model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
- model_compression_toolkit/core/common/framework_implementation.py +22 -10
- model_compression_toolkit/core/common/framework_info.py +83 -93
- model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
- model_compression_toolkit/core/common/graph/base_graph.py +72 -45
- model_compression_toolkit/core/common/graph/base_node.py +141 -121
- model_compression_toolkit/core/common/graph/functional_node.py +2 -19
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
- model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
- model_compression_toolkit/core/common/model_collector.py +18 -22
- model_compression_toolkit/core/common/model_validation.py +44 -0
- model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
- model_compression_toolkit/core/common/network_editors/actions.py +130 -14
- model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
- model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
- model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
- model_compression_toolkit/core/common/pruning/pruner.py +6 -1
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
- model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
- model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
- model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
- model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
- model_compression_toolkit/core/graph_prep_runner.py +35 -22
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/keras/default_framework_info.py +91 -131
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
- model_compression_toolkit/core/keras/keras_implementation.py +37 -17
- model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
- model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
- model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
- model_compression_toolkit/core/quantization_prep_runner.py +11 -6
- model_compression_toolkit/core/runner.py +15 -5
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
- model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
- model_compression_toolkit/gptq/common/gptq_training.py +8 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
- model_compression_toolkit/gptq/keras/graph_info.py +6 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
- model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
- model_compression_toolkit/gptq/runner.py +7 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
- model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
- model_compression_toolkit/ptq/runner.py +4 -1
- model_compression_toolkit/qat/common/qat_config.py +6 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
- model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
- model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
- model_compression_toolkit/xquant/__init__.py +1 -0
- model_compression_toolkit/xquant/common/constants.py +1 -0
- model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
- model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
- model_compression_toolkit/xquant/common/xquant_config.py +27 -1
- model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
- model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
- model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
- model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
- model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
- model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
- model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
- model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
- model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
- model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
- model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
- model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
- model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
- model_compression_toolkit/quantization_preparation/__init__.py +0 -14
- model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250927.534.dist-info}/top_level.txt +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
- /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/core/common/graph/base_graph.py

@@ -23,6 +23,7 @@ import numpy as np
 
 from networkx.algorithms.dag import topological_sort
 
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
 from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX, EDGE_SOURCE_INDEX
 from model_compression_toolkit.core.common.graph.edge import Edge, convert_to_edge
@@ -32,8 +33,7 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
 from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
 from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.common.quantization.node_quantization_config import \
-    NodeActivationQuantizationConfig, ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
@@ -74,6 +74,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                  input_nodes: List[BaseNode],
                  output_nodes: List[OutTensor],
                  edge_list: List[Edge],
+                 fw_info: FrameworkInfo = None,
                  **attr):
         """
         Args:
@@ -81,6 +82,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             input_nodes: List of input nodes the model
             output_nodes: List of output nodes of the model to a list of their output indices.
             edge_list: List of edges the graph has between nodes.
+            fw_info: FrameworkInfo object (needed for computing the graph's weights memory).
             **attr: Attributes to add to graph as key=value pairs.
         """
 
@@ -101,6 +103,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                           e.sink_node,
                           **e.get_attributes())
         self.user_info = UserInformation()
+        self.fw_info = fw_info
 
     @property
     def skip_validation_check(self) -> bool:
@@ -121,13 +124,38 @@ class Graph(nx.MultiDiGraph, GraphSearches):
     def fusing_info(self, fusing_info: FusingInfo):
         self._fusing_info = fusing_info
 
-    def set_fqc(self, fqc: FrameworkQuantizationCapabilities):
+    def set_fw_info(self,
+                    fw_info: FrameworkInfo):
+        """
+        Set the graph's framework info.
+        Args:
+            fw_info: FrameworkInfo object.
+        """
+
+        self.fw_info = fw_info
+
+    def set_fqc(self,
+                fqc: FrameworkQuantizationCapabilities):
         """
         Set the graph's FQC.
         Args:
             fqc: FrameworkQuantizationCapabilities object.
         """
-        #
+        # validate graph nodes are either from the framework or a custom layer defined in the FQC
+        # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC
+        fqc_layers = fqc.op_sets_to_layers.get_layers()
+        fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)]
+        for n in self.nodes:
+            is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \
+                             any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers])
+            if n.is_custom:
+                if not is_node_in_fqc:
+                    Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
+                                    ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature '
+                                    'request or an issue if you believe this should be supported.')  # pragma: no cover
+                if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]):
+                    Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover
+
         self.fqc = fqc
 
     def get_topo_sorted_nodes(self):
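The hunks above give Graph an explicit FrameworkInfo hook: the constructor gains an optional fw_info argument, set_fw_info() stores it on the graph, and set_fqc() now validates custom layers against the FQC before attaching it. The following is a hypothetical driver sketch based only on the signatures shown above; how graph, fw_info and fqc are produced by the MCT core flow is assumed, not taken from this diff.

# Hypothetical usage sketch; only set_fw_info/set_fqc and the import paths come
# from the hunks above, the surrounding flow is an assumption.
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
    FrameworkQuantizationCapabilities

def attach_framework_context(graph, fw_info: FrameworkInfo, fqc: FrameworkQuantizationCapabilities):
    # fw_info may also be passed to Graph(...) directly via the new optional argument.
    graph.set_fw_info(fw_info)
    # set_fqc() now rejects custom layers that are not registered in the FQC, and
    # custom layers with weight quantization enabled (Logger.critical in both cases).
    graph.set_fqc(fqc)
    return graph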
@@ -535,6 +563,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return output_edges
 
     def get_configurable_sorted_nodes_names(self,
+                                            fw_info: FrameworkInfo,
                                             include_reused_nodes: bool = False) -> List[str]:
         """
         Get a list of nodes' names that can be configured (namely, has one or
@@ -542,49 +571,56 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         order of the graph.
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether or not to include reused nodes (False by default).
 
         Returns: List of nodes' names that can be configured (namely, has one or
         more weight qc candidate) sorted topology.
 
         """
-        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(
+        sorted_names = [n.name for n in self.get_configurable_sorted_nodes(fw_info=fw_info,
+                                                                           include_reused_nodes=include_reused_nodes)]
         return sorted_names
 
     def get_weights_configurable_nodes(self,
+                                       fw_info: FrameworkInfo,
                                        include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether to include reused nodes (False by default).
 
         Returns:
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate).
         """
         # configurability is only relevant for kernel attribute quantization
-        potential_conf_nodes = [n for n in self
+        potential_conf_nodes = [n for n in list(self) if fw_info.is_kernel_op(n.type)]
 
         def is_configurable(n):
-
+            kernel_attrs = fw_info.get_kernel_op_attributes(n.type)
+            return any(n.is_configurable_weight(attr) for attr in kernel_attrs) and (not n.reuse or include_reused_nodes)
 
         return [n for n in potential_conf_nodes if is_configurable(n)]
 
     def get_sorted_weights_configurable_nodes(self,
+                                              fw_info: FrameworkInfo,
                                               include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of sorted nodes that their weights can be configured (namely, has one or
         more weight qc candidate and their weights should be quantized).
 
         Args:
+            fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether to include reused nodes (False by default).
 
         Returns:
             A list of nodes that their weights can be configured (namely, has one or more weight qc candidate)
             sorted topologically.
         """
-        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(include_reused_nodes))
+        return self._sort_nodes_in_list(self.get_weights_configurable_nodes(fw_info, include_reused_nodes))
 
     def get_activation_configurable_nodes(self) -> List[BaseNode]:
         """
@@ -608,6 +644,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         return self._sort_nodes_in_list(self.get_activation_configurable_nodes())
 
     def get_configurable_sorted_nodes(self,
+                                      fw_info: FrameworkInfo,
                                       include_reused_nodes: bool = False) -> List[BaseNode]:
         """
         Get a list of nodes that can be configured (namely, has one or
@@ -615,13 +652,14 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         The nodes are sorted according to the topological order of the graph.
 
         Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
             include_reused_nodes: Whether or not to include reused nodes (False by default).
 
         Returns:
             A list of nodes that can be configured (namely, has one or more qc candidate) sorted topology.
 
         """
-        weights_configurable_nodes = self.get_weights_configurable_nodes(include_reused_nodes)
+        weights_configurable_nodes = self.get_weights_configurable_nodes(fw_info, include_reused_nodes)
         activation_configurable_nodes = self.get_activation_configurable_nodes()
 
         # combine and remove duplications
@@ -646,7 +684,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
             sorted_configurable_nodes.append(n)
         return sorted_configurable_nodes
 
-    def get_min_candidates_config(self) -> Dict[BaseNode, int]:
+    def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
         """
         Builds a minimal configuration.
         Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
@@ -659,33 +697,38 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         Returns:
             A dict from layer to an index of its minimal candidate.
         """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes()
+        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
         return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}
 
-    def get_max_candidates_config(self) -> Dict[BaseNode, int]:
+    def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
         """
         Builds a maximal configuration.
         Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
         with maximal n_bits (in both weight and activation if both are quantized, or in the relevant one if only
         one of them is quantized)
 
+        Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
+
         Returns:
             A dict from layer to an index of its maximal candidate.
         """
-        conf_sorted_nodes = self.get_configurable_sorted_nodes()
+        conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
         return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}
 
-    def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]:
+    def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
         """
         Gets the final number of bits for quantization of each weights' configurable layer.
 
-
-
+        Args:
+            fw_info: fw_info: FrameworkInfo object with information about the specific framework's model.
+
+        Returns: A list of pairs of (node type, node's weights quantization bitwidth).
 
         """
-        sorted_conf_weights = self.get_sorted_weights_configurable_nodes()
+        sorted_conf_weights = self.get_sorted_weights_configurable_nodes(fw_info)
         # a configurable node by definition has a kernel op
-        return [(n, n.final_weights_quantization_cfg.get_attr_config(n.
+        return [(n, n.final_weights_quantization_cfg.get_attr_config(self.fw_info.get_kernel_op_attributes(n.type)[0]).weights_n_bits)
                 for n in sorted_conf_weights]
 
     def get_final_activation_config(self) -> List[Tuple[BaseNode, int]]:
@@ -803,7 +846,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
                 next_node = self.out_edges(next_node)[0].sink_node
 
             # If next_node is an exit node and has only one incoming edge, the topology is prunable.
-            if fw_impl.is_node_exit_node(next_node, entry_node) and len(self.in_edges(next_node)) == 1:
+            if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1:
                 return True
 
             # If the next node is not an intermediate node or has more than one incoming/outgoing edge,
@@ -833,7 +876,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
 
         intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)
 
-        if not fw_impl.is_node_exit_node(exit_node, entry_node):
+        if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info):
             Logger.critical(f"Node {exit_node} is not a valid exit node for the pruning section starting with {entry_node}.")  # pragma: no cover
 
         return PruningSection(entry_node=entry_node,
@@ -854,37 +897,21 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         intermediate_nodes = []
         next_node = self.out_edges(entry_node)[0].sink_node
-        while not fw_impl.is_node_exit_node(next_node, entry_node):
+        while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info):
             intermediate_nodes.append(next_node)
             next_node = self.out_edges(next_node)[0].sink_node
 
         return intermediate_nodes, next_node
 
-
-    def override_fused_node_activation_quantization_candidates(self):
+    def disable_fused_nodes_activation_quantization(self):
         """
-
+        Disable activation quantization for all nodes in fused operations,
         except for the last node in each fused group.
-
-
-
-
-
-            fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
-            if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
-                def update(qc):
-                    qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg)
-                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
-                node.quantization_cfg.update_all(update, remove_duplicates=True)
-            else:
-                node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
-            # Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
-            # be identical, but we do want to treat them as such. So we only check duplication by weight configs.
-            uniq_qcs = []
-            for qc in node.candidates_quantization_cfg:
-                if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
-                    uniq_qcs.append(qc)
-            node.quantization_cfg.candidates_quantization_cfg = uniq_qcs
+        """
+        nodes_to_disable = self.fusing_info.get_inner_fln_nodes()
+        for node in nodes_to_disable:
+            for qc in node.candidates_quantization_cfg:
+                qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
 
     def validate(self):
         """
@@ -908,4 +935,4 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         """
         Wrap networkx functions (that modifies the graph) with our validate decorator.
         """
-        return super().remove_edge(*args, **kwargs)
+        return super().remove_edge(*args, **kwargs)
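Beyond the setters, the remaining hunks thread fw_info through the configurable-node queries instead of reading it from a framework-wide default. A short illustrative sketch of the updated call pattern follows; graph and fw_info are assumed to come from the regular MCT core flow, and only the method signatures used below are taken from the diff above.

# Hypothetical caller; every Graph method signature used here appears in the hunks above.
from model_compression_toolkit.core.common.framework_info import FrameworkInfo

def summarize_mixed_precision_search_space(graph, fw_info: FrameworkInfo):
    names = graph.get_configurable_sorted_nodes_names(fw_info=fw_info)
    weights_nodes = graph.get_sorted_weights_configurable_nodes(fw_info)
    min_cfg = graph.get_min_candidates_config(fw_info)  # node -> minimal candidate index
    max_cfg = graph.get_max_candidates_config(fw_info)  # node -> maximal candidate index
    return names, weights_nodes, min_cfg, max_cfg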