mct-nightly 2.3.0.20250323.559__py3-none-any.whl → 2.3.0.20250324.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/RECORD +16 -16
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +14 -4
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +32 -96
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +17 -42
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +179 -60
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +22 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +1 -5
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +14 -94
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +132 -312
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +1 -1
- model_compression_toolkit/core/runner.py +2 -12
- {mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250324.606
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
{mct_nightly-2.3.0.20250323.559.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250324.606.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=cmlxKyxfMU7e0tT0g_dE3z_TjW63WzFf9t_y-H5dZ80,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -9,7 +9,7 @@ model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZ
|
|
9
9
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
10
10
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=CVTjBaci8F6EP3IKDnRMfxkP-Sv8qY8GpkGt6FyII2U,11376
|
11
11
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
|
12
|
-
model_compression_toolkit/core/runner.py,sha256=
|
12
|
+
model_compression_toolkit/core/runner.py,sha256=WjZMVXc-OGBTnkiH0PRjNdJEM5pKQRPvLHXor5tjwjk,13096
|
13
13
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
14
14
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
15
15
|
model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
|
@@ -34,7 +34,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
34
34
|
model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza4vpWlmmqgvkpUmWVdfdx0nEIB0p2n8,6195
|
35
35
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
|
36
36
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
37
|
-
model_compression_toolkit/core/common/graph/base_graph.py,sha256=
|
37
|
+
model_compression_toolkit/core/common/graph/base_graph.py,sha256=VhniLTiMqL7i1Vqg2UBQuFFTvw2cYeJayssUJwabp3E,38112
|
38
38
|
model_compression_toolkit/core/common/graph/base_node.py,sha256=kZbmAMh5cPAwYzlY8KYa8w0ipL58yApB09-WXQ8plrE,33763
|
39
39
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
40
40
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
@@ -67,18 +67,18 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
|
|
67
67
|
model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
|
70
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=
|
71
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=
|
72
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=
|
70
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
|
71
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=LddWtLileazCOvVSz-7j-GA4yskcGD3UHQGo7XUzSTE,5661
|
72
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
|
74
74
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
75
|
-
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=
|
75
|
+
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
|
76
76
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
77
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=
|
78
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=
|
79
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=
|
77
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
|
78
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=xCYL36K0nK41VSsLcy52uDA7zVfoLxhubmOrtXbqw7s,39140
|
79
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
|
80
80
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
81
|
-
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=
|
81
|
+
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=rSSN5MhH5BO5b58d8pe2pY9wc5HbfescoUStfg-nWfk,7263
|
82
82
|
model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
|
83
83
|
model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
|
84
84
|
model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
|
@@ -526,7 +526,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
526
526
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
527
527
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
528
528
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
529
|
-
mct_nightly-2.3.0.
|
530
|
-
mct_nightly-2.3.0.
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
529
|
+
mct_nightly-2.3.0.20250324.606.dist-info/METADATA,sha256=xCfGClT5tOs76bxlrQX92fjTkmtPi-QvD6CBI8TM-EA,27098
|
530
|
+
mct_nightly-2.3.0.20250324.606.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
|
531
|
+
mct_nightly-2.3.0.20250324.606.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
532
|
+
mct_nightly-2.3.0.20250324.606.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250324.000606"
|
@@ -706,14 +706,24 @@ class Graph(nx.MultiDiGraph, GraphSearches):
|
|
706
706
|
"""
|
707
707
|
self.fused_nodes.append(fusion)
|
708
708
|
|
709
|
-
def
|
709
|
+
def has_any_configurable_activation(self) -> bool:
|
710
710
|
"""
|
711
|
-
Checks whether
|
711
|
+
Checks whether any node in the graph has a configurable activation quantization.
|
712
712
|
|
713
|
-
Returns:
|
713
|
+
Returns:
|
714
|
+
Whether any node in the graph has a configurable activation quantization.
|
715
|
+
"""
|
716
|
+
return any([n.has_configurable_activation() for n in self.nodes])
|
717
|
+
|
718
|
+
def has_any_configurable_weights(self):
|
719
|
+
"""
|
720
|
+
Checks whether any node in the graph has any configurable weights quantization.
|
714
721
|
|
722
|
+
Returns:
|
723
|
+
Whether any node in the graph has any configurable weights quantization.
|
715
724
|
"""
|
716
|
-
|
725
|
+
|
726
|
+
return any([n.has_any_configurable_weight() for n in self.nodes])
|
717
727
|
|
718
728
|
def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
|
719
729
|
"""
|
@@ -12,12 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from typing import List, Set, Dict,
|
15
|
+
from typing import List, Set, Dict, Tuple
|
16
16
|
|
17
17
|
import numpy as np
|
18
18
|
|
19
19
|
from model_compression_toolkit.core import FrameworkInfo
|
20
|
-
from model_compression_toolkit.core.common import Graph
|
20
|
+
from model_compression_toolkit.core.common import Graph
|
21
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
22
22
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
23
23
|
RUTarget
|
@@ -36,42 +36,46 @@ class MixedPrecisionRUHelper:
|
|
36
36
|
self.fw_impl = fw_impl
|
37
37
|
self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
38
38
|
|
39
|
-
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg:
|
39
|
+
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
|
40
40
|
"""
|
41
|
-
Compute utilization of requested targets for a specific configuration
|
42
|
-
|
43
|
-
|
41
|
+
Compute utilization of requested targets for a specific configuration:
|
42
|
+
for weights and bops - total utilization,
|
43
|
+
for activations and total - utilization per cut.
|
44
44
|
|
45
45
|
Args:
|
46
46
|
ru_targets: resource utilization targets to compute.
|
47
47
|
mp_cfg: a list of candidates indices for configurable layers.
|
48
48
|
|
49
49
|
Returns:
|
50
|
-
Dict of the computed utilization per target.
|
50
|
+
Dict of the computed utilization per target, as 1d vector.
|
51
51
|
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
assert
|
70
|
-
|
71
|
-
|
52
|
+
act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)
|
53
|
+
|
54
|
+
ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
|
55
|
+
BitwidthMode.QCustom,
|
56
|
+
act_qcs=act_qcs,
|
57
|
+
w_qcs=w_qcs,
|
58
|
+
ru_targets=ru_targets,
|
59
|
+
allow_unused_qcs=True,
|
60
|
+
return_detailed=True)
|
61
|
+
|
62
|
+
ru_dict = {k: np.array([v]) for k, v in ru.get_resource_utilization_dict(restricted_only=True).items()}
|
63
|
+
# For activation and total we need utilization per cut, as different mp configurations might result in
|
64
|
+
# different cuts to be maximal.
|
65
|
+
for target in [RUTarget.ACTIVATION, RUTarget.TOTAL]:
|
66
|
+
if target in ru_dict:
|
67
|
+
ru_dict[target] = np.array(list(detailed_ru[target].values()))
|
68
|
+
|
69
|
+
assert all(v.ndim == 1 for v in ru_dict.values())
|
70
|
+
if RUTarget.ACTIVATION in ru_targets and RUTarget.TOTAL in ru_targets:
|
71
|
+
assert ru_dict[RUTarget.ACTIVATION].shape == ru_dict[RUTarget.TOTAL].shape
|
72
|
+
|
73
|
+
assert len(ru_dict) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
|
74
|
+
f'Requested {ru_targets}')
|
75
|
+
return ru_dict
|
72
76
|
|
73
77
|
def get_quantization_candidates(self, mp_cfg) \
|
74
|
-
-> Tuple[Dict[
|
78
|
+
-> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
|
75
79
|
"""
|
76
80
|
Retrieve quantization candidates objects for weights and activations from the configuration list.
|
77
81
|
|
@@ -87,71 +91,3 @@ class MixedPrecisionRUHelper:
|
|
87
91
|
act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
|
88
92
|
w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
|
89
93
|
return act_qcs, w_qcs
|
90
|
-
|
91
|
-
def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
|
92
|
-
"""
|
93
|
-
Compute weights utilization for configurable weights if configuration is passed,
|
94
|
-
or for non-configurable nodes otherwise.
|
95
|
-
|
96
|
-
Args:
|
97
|
-
w_qcs: nodes quantization configuration to compute, or None.
|
98
|
-
|
99
|
-
Returns:
|
100
|
-
Weight utilization per node.
|
101
|
-
"""
|
102
|
-
if w_qcs:
|
103
|
-
target_criterion = TargetInclusionCriterion.QConfigurable
|
104
|
-
bitwidth_mode = BitwidthMode.QCustom
|
105
|
-
else:
|
106
|
-
target_criterion = TargetInclusionCriterion.QNonConfigurable
|
107
|
-
bitwidth_mode = BitwidthMode.QDefaultSP
|
108
|
-
|
109
|
-
_, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion,
|
110
|
-
bitwidth_mode=bitwidth_mode,
|
111
|
-
w_qcs=w_qcs)
|
112
|
-
nodes_util = {n: u.bytes for n, u in nodes_util.items()}
|
113
|
-
return nodes_util
|
114
|
-
|
115
|
-
def _activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
|
116
|
-
-> Optional[Dict[Any, float]]:
|
117
|
-
"""
|
118
|
-
Compute activation utilization using MaxCut for all quantized nodes if configuration is passed.
|
119
|
-
|
120
|
-
Args:
|
121
|
-
act_qcs: nodes activation configuration or None.
|
122
|
-
|
123
|
-
Returns:
|
124
|
-
Activation utilization per cut, or empty dict if no configuration was passed.
|
125
|
-
"""
|
126
|
-
# Maxcut activation utilization is computed for all quantized nodes, so non-configurable memory is already
|
127
|
-
# covered by the computation of configurable activations.
|
128
|
-
if not act_qcs:
|
129
|
-
return {}
|
130
|
-
|
131
|
-
_, cuts_util, *_ = self.ru_calculator.compute_activation_utilization_by_cut(
|
132
|
-
TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs)
|
133
|
-
cuts_util = {c: u.bytes for c, u in cuts_util.items()}
|
134
|
-
return cuts_util
|
135
|
-
|
136
|
-
def _bops_utilization(self,
|
137
|
-
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]],
|
138
|
-
w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> np.ndarray:
|
139
|
-
"""
|
140
|
-
Computes a resource utilization vector with the respective bit-operations (BOPS) count
|
141
|
-
according to the given mixed-precision configuration.
|
142
|
-
|
143
|
-
Args:
|
144
|
-
act_qcs: nodes activation configuration or None.
|
145
|
-
w_qcs: nodes quantization configuration to compute, or None.
|
146
|
-
Either both are provided, or both are None.
|
147
|
-
|
148
|
-
Returns:
|
149
|
-
A vector of node's BOPS count.
|
150
|
-
"""
|
151
|
-
assert [act_qcs, w_qcs].count(None) in [0, 2], 'act_qcs and w_qcs should both be provided or both be None.'
|
152
|
-
if act_qcs is None:
|
153
|
-
return np.array([])
|
154
|
-
|
155
|
-
_, detailed_bops = self.ru_calculator.compute_bops(TargetInclusionCriterion.Any, BitwidthMode.QCustom,
|
156
|
-
act_qcs=act_qcs, w_qcs=w_qcs)
|
157
|
-
return np.array(list(detailed_bops.values()))
|
@@ -13,37 +13,27 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
import copy
|
17
16
|
from enum import Enum
|
18
|
-
import
|
19
|
-
from typing import List, Callable, Dict
|
17
|
+
from typing import List, Callable
|
20
18
|
|
21
19
|
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
|
22
20
|
from model_compression_toolkit.core.common import Graph
|
23
|
-
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
24
|
-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
|
25
21
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
26
|
-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
|
27
|
-
from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
|
28
|
-
mp_integer_programming_search
|
29
22
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
|
+
from model_compression_toolkit.core.common.hessian import HessianInfoService
|
24
|
+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
|
25
|
+
MixedPrecisionSearchManager
|
26
|
+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
27
|
+
ResourceUtilization
|
30
28
|
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
|
31
29
|
greedy_solution_refinement_procedure
|
32
|
-
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
|
33
|
-
from model_compression_toolkit.logger import Logger
|
34
30
|
|
35
31
|
|
36
32
|
class BitWidthSearchMethod(Enum):
|
37
|
-
# When adding a new search_methods MP configuration method, these enum and factory dictionary
|
38
|
-
# should be updated with it's kind and a search_method implementation.
|
39
33
|
INTEGER_PROGRAMMING = 0
|
40
34
|
|
41
35
|
|
42
|
-
|
43
|
-
BitWidthSearchMethod.INTEGER_PROGRAMMING: mp_integer_programming_search}
|
44
|
-
|
45
|
-
|
46
|
-
def search_bit_width(graph_to_search_cfg: Graph,
|
36
|
+
def search_bit_width(graph: Graph,
|
47
37
|
fw_info: FrameworkInfo,
|
48
38
|
fw_impl: FrameworkImplementation,
|
49
39
|
target_resource_utilization: ResourceUtilization,
|
@@ -60,7 +50,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
60
50
|
target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
|
61
51
|
|
62
52
|
Args:
|
63
|
-
|
53
|
+
graph: Graph to search a MP configuration for.
|
64
54
|
fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
|
65
55
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
66
56
|
target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
|
@@ -75,17 +65,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
75
65
|
bit-width index on the node).
|
76
66
|
|
77
67
|
"""
|
78
|
-
|
79
|
-
# target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
|
80
|
-
if target_resource_utilization is None:
|
81
|
-
Logger.critical("Target ResourceUtilization is required for the bit-width search method's configuration.") # pragma: no cover
|
82
|
-
|
83
|
-
# Set graph for MP search
|
84
|
-
graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
|
85
|
-
if target_resource_utilization.bops_restricted():
|
86
|
-
# TODO: we only need the virtual graph is both activations and weights are configurable
|
87
|
-
# Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
|
88
|
-
graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())
|
68
|
+
assert target_resource_utilization.is_any_restricted()
|
89
69
|
|
90
70
|
# If we only run weights compression with MP than no need to consider activation quantization when computing the
|
91
71
|
# MP metric (it adds noise to the computation)
|
@@ -93,33 +73,28 @@ def search_bit_width(graph_to_search_cfg: Graph,
|
|
93
73
|
weight_only_restricted = tru.weight_restricted() and not (tru.activation_restricted() or
|
94
74
|
tru.total_mem_restricted() or
|
95
75
|
tru.bops_restricted())
|
96
|
-
disable_activation_for_metric = weight_only_restricted or
|
76
|
+
disable_activation_for_metric = weight_only_restricted or not graph.has_any_configurable_activation()
|
97
77
|
|
98
78
|
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
|
99
79
|
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
|
100
80
|
se = fw_impl.get_sensitivity_evaluator(
|
101
|
-
|
81
|
+
graph,
|
102
82
|
mp_config,
|
103
83
|
representative_data_gen=representative_data_gen,
|
104
84
|
fw_info=fw_info,
|
105
85
|
disable_activation_for_metric=disable_activation_for_metric,
|
106
86
|
hessian_info_service=hessian_info_service)
|
107
87
|
|
108
|
-
|
88
|
+
if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
|
89
|
+
raise NotImplementedError()
|
90
|
+
|
91
|
+
# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
|
109
92
|
search_manager = MixedPrecisionSearchManager(graph,
|
110
93
|
fw_info,
|
111
94
|
fw_impl,
|
112
95
|
se,
|
113
|
-
target_resource_utilization
|
114
|
-
|
115
|
-
|
116
|
-
if search_method not in search_methods:
|
117
|
-
raise NotImplementedError() # pragma: no cover
|
118
|
-
|
119
|
-
search_method_fn = search_methods[search_method]
|
120
|
-
# Search for the desired mixed-precision configuration
|
121
|
-
result_bit_cfg = search_method_fn(search_manager,
|
122
|
-
target_resource_utilization)
|
96
|
+
target_resource_utilization)
|
97
|
+
result_bit_cfg = search_manager.search()
|
123
98
|
|
124
99
|
if mp_config.refine_mp_solution:
|
125
100
|
result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
|