mct-nightly 2.3.0.20250322.517__py3-none-any.whl → 2.3.0.20250324.606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {mct_nightly-2.3.0.20250322.517.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250322.517.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/RECORD +16 -16
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +14 -4
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +32 -96
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +17 -42
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +179 -60
  8. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +22 -10
  9. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +1 -5
  10. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +14 -94
  11. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +132 -312
  12. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +1 -1
  13. model_compression_toolkit/core/runner.py +2 -12
  14. {mct_nightly-2.3.0.20250322.517.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/WHEEL +0 -0
  15. {mct_nightly-2.3.0.20250322.517.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/licenses/LICENSE.md +0 -0
  16. {mct_nightly-2.3.0.20250322.517.dist-info → mct_nightly-2.3.0.20250324.606.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mct-nightly
- Version: 2.3.0.20250322.517
+ Version: 2.3.0.20250324.606
  Summary: A Model Compression Toolkit for neural networks
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: Apache Software License
@@ -1,5 +1,5 @@
- mct_nightly-2.3.0.20250322.517.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- model_compression_toolkit/__init__.py,sha256=zwR3rimyZAfQQXE_6fl9E7U2oLz9pyRmT9EYhEEzQaU,1557
+ mct_nightly-2.3.0.20250324.606.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ model_compression_toolkit/__init__.py,sha256=cmlxKyxfMU7e0tT0g_dE3z_TjW63WzFf9t_y-H5dZ80,1557
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -9,7 +9,7 @@ model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZ
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
  model_compression_toolkit/core/graph_prep_runner.py,sha256=CVTjBaci8F6EP3IKDnRMfxkP-Sv8qY8GpkGt6FyII2U,11376
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
- model_compression_toolkit/core/runner.py,sha256=qblr8WM6R5v4jip94kBeWHKsjc-FUOteVgMtunGf8lU,13716
+ model_compression_toolkit/core/runner.py,sha256=WjZMVXc-OGBTnkiH0PRjNdJEM5pKQRPvLHXor5tjwjk,13096
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
  model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
@@ -34,7 +34,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza4vpWlmmqgvkpUmWVdfdx0nEIB0p2n8,6195
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=0zsiEldkV_wjDoTjaGtL8DOMGEv2yQqhajwEAnFgqR8,37819
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=VhniLTiMqL7i1Vqg2UBQuFFTvw2cYeJayssUJwabp3E,38112
  model_compression_toolkit/core/common/graph/base_node.py,sha256=kZbmAMh5cPAwYzlY8KYa8w0ipL58yApB09-WXQ8plrE,33763
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
@@ -67,18 +67,18 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=k7LjEmcvlkiV995DU7S1CrNOllu6qPZrhUUKXcZDIUQ,7538
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=4YH9tsFPOn6rCcedfyocZhZwDLNX5kB1tebu0-nvhyA,7226
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=ItBWNZYOf-Zzi8FaRv1y170wYRXYcR3pJysClOtH8qc,32525
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=LddWtLileazCOvVSz-7j-GA4yskcGD3UHQGo7XUzSTE,5661
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
- model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MQZnBcpBDMd5y6rOunUtH3t41GQH0aBmxVB4muoxNfk,9477
+ model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=T5yVr7lay-6QLuTDBZNI1Ufj02EMBWuY_yHjC8eHx5I,3998
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=Oj-tVGUyBXtTpxNFQVPja8fFcUOpi6B2PdpNKHkAlbc,39314
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=J7gqUGs4ITo4ufl84A5vACxm670LG6RhQyXkejfpbn8,8834
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=xCYL36K0nK41VSsLcy52uDA7zVfoLxhubmOrtXbqw7s,39140
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=9Hh85pr0VL65umhf9mPnrrssJXwJPAsIkBwCZnfzjHY,17575
+ model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=rSSN5MhH5BO5b58d8pe2pY9wc5HbfescoUStfg-nWfk,7263
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
  model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
  model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -526,7 +526,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.3.0.20250322.517.dist-info/METADATA,sha256=aYBvUc3xM30BNiWYY4eIwtN5--e4Scbr-bHnYGl6lAk,27098
- mct_nightly-2.3.0.20250322.517.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
- mct_nightly-2.3.0.20250322.517.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.3.0.20250322.517.dist-info/RECORD,,
+ mct_nightly-2.3.0.20250324.606.dist-info/METADATA,sha256=xCfGClT5tOs76bxlrQX92fjTkmtPi-QvD6CBI8TM-EA,27098
+ mct_nightly-2.3.0.20250324.606.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
+ mct_nightly-2.3.0.20250324.606.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.3.0.20250324.606.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.3.0.20250322.000517"
+ __version__ = "2.3.0.20250324.000606"
@@ -706,14 +706,24 @@ class Graph(nx.MultiDiGraph, GraphSearches):
          """
          self.fused_nodes.append(fusion)

-     def is_single_activation_cfg(self):
+     def has_any_configurable_activation(self) -> bool:
          """
-         Checks whether all nodes in the graph that have activation quantization are quantized with the same bit-width.
+         Checks whether any node in the graph has a configurable activation quantization.

-         Returns: True if all quantization config candidates of all nodes have the same activation quantization bit-width.
+         Returns:
+             Whether any node in the graph has a configurable activation quantization.
+         """
+         return any([n.has_configurable_activation() for n in self.nodes])
+
+     def has_any_configurable_weights(self):
+         """
+         Checks whether any node in the graph has any configurable weights quantization.

+         Returns:
+             Whether any node in the graph has any configurable weights quantization.
          """
-         return all([n.is_all_activation_candidates_equal() for n in self.nodes])
+
+         return any([n.has_any_configurable_weight() for n in self.nodes])

      def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
          """
@@ -12,12 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import List, Set, Dict, Optional, Tuple, Any
+ from typing import List, Set, Dict, Tuple

  import numpy as np

  from model_compression_toolkit.core import FrameworkInfo
- from model_compression_toolkit.core.common import Graph, BaseNode
+ from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
      RUTarget
@@ -36,42 +36,46 @@ class MixedPrecisionRUHelper:
          self.fw_impl = fw_impl
          self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)

-     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]:
+     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
          """
-         Compute utilization of requested targets for a specific configuration in the format expected by LP problem
-         formulation namely a vector of ru values for relevant memory elements (nodes or cuts) in a constant order
-         (between calls).
+         Compute utilization of requested targets for a specific configuration:
+         for weights and bops - total utilization,
+         for activations and total - utilization per cut.

          Args:
              ru_targets: resource utilization targets to compute.
              mp_cfg: a list of candidates indices for configurable layers.

          Returns:
-             Dict of the computed utilization per target.
+             Dict of the computed utilization per target, as 1d vector.
          """
-
-         ru = {}
-         act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg) if mp_cfg else (None, None)
-         if RUTarget.WEIGHTS in ru_targets:
-             wu = self._weights_utilization(w_qcs)
-             ru[RUTarget.WEIGHTS] = np.array(list(wu.values()))
-
-         if RUTarget.ACTIVATION in ru_targets:
-             au = self._activation_utilization(act_qcs)
-             ru[RUTarget.ACTIVATION] = np.array(list(au.values()))
-
-         if RUTarget.BOPS in ru_targets:
-             ru[RUTarget.BOPS] = self._bops_utilization(act_qcs=act_qcs, w_qcs=w_qcs)
-
-         if RUTarget.TOTAL in ru_targets:
-             raise ValueError('Total target should be computed based on weights and activations targets.')
-
-         assert len(ru) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
-                                             f'Requested {ru_targets}')
-         return ru
+         act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)
+
+         ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
+                                                                           BitwidthMode.QCustom,
+                                                                           act_qcs=act_qcs,
+                                                                           w_qcs=w_qcs,
+                                                                           ru_targets=ru_targets,
+                                                                           allow_unused_qcs=True,
+                                                                           return_detailed=True)
+
+         ru_dict = {k: np.array([v]) for k, v in ru.get_resource_utilization_dict(restricted_only=True).items()}
+         # For activation and total we need utilization per cut, as different mp configurations might result in
+         # different cuts to be maximal.
+         for target in [RUTarget.ACTIVATION, RUTarget.TOTAL]:
+             if target in ru_dict:
+                 ru_dict[target] = np.array(list(detailed_ru[target].values()))
+
+         assert all(v.ndim == 1 for v in ru_dict.values())
+         if RUTarget.ACTIVATION in ru_targets and RUTarget.TOTAL in ru_targets:
+             assert ru_dict[RUTarget.ACTIVATION].shape == ru_dict[RUTarget.TOTAL].shape
+
+         assert len(ru_dict) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
+                                                  f'Requested {ru_targets}')
+         return ru_dict

      def get_quantization_candidates(self, mp_cfg) \
-             -> Tuple[Dict[BaseNode, NodeActivationQuantizationConfig], Dict[BaseNode, NodeWeightsQuantizationConfig]]:
+             -> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
          """
          Retrieve quantization candidates objects for weights and activations from the configuration list.
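Editor's note: the reworked compute_utilization above delegates all measurement to the resource utilization calculator and only post-processes shapes: weights and BOPS become length-1 vectors, while activation and total memory keep one entry per cut. A small shape-only sketch under that assumption, with toy numbers rather than the real calculator:

import numpy as np

# Hypothetical calculator outputs: aggregated totals plus a per-cut breakdown for activation/total memory.
aggregated = {'weights': 1_200_000.0, 'bops': 3.4e9}
per_cut = {'activation': {'cut_0': 64_000.0, 'cut_1': 96_000.0, 'cut_2': 80_000.0},
           'total': {'cut_0': 1_264_000.0, 'cut_1': 1_296_000.0, 'cut_2': 1_280_000.0}}

# Totals are wrapped as length-1 vectors; per-cut targets get one entry per cut,
# since a different mixed-precision configuration may make a different cut the maximal one.
ru_vectors = {k: np.array([v]) for k, v in aggregated.items()}
ru_vectors.update({k: np.array(list(v.values())) for k, v in per_cut.items()})

assert all(v.ndim == 1 for v in ru_vectors.values())
assert ru_vectors['activation'].shape == ru_vectors['total'].shape
print({k: v.shape for k, v in ru_vectors.items()})
# {'weights': (1,), 'bops': (1,), 'activation': (3,), 'total': (3,)}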
@@ -87,71 +91,3 @@ class MixedPrecisionRUHelper:
          act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
          w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
          return act_qcs, w_qcs
-
-     def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
-         """
-         Compute weights utilization for configurable weights if configuration is passed,
-         or for non-configurable nodes otherwise.
-
-         Args:
-             w_qcs: nodes quantization configuration to compute, or None.
-
-         Returns:
-             Weight utilization per node.
-         """
-         if w_qcs:
-             target_criterion = TargetInclusionCriterion.QConfigurable
-             bitwidth_mode = BitwidthMode.QCustom
-         else:
-             target_criterion = TargetInclusionCriterion.QNonConfigurable
-             bitwidth_mode = BitwidthMode.QDefaultSP
-
-         _, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion,
-                                                                           bitwidth_mode=bitwidth_mode,
-                                                                           w_qcs=w_qcs)
-         nodes_util = {n: u.bytes for n, u in nodes_util.items()}
-         return nodes_util
-
-     def _activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
-             -> Optional[Dict[Any, float]]:
-         """
-         Compute activation utilization using MaxCut for all quantized nodes if configuration is passed.
-
-         Args:
-             act_qcs: nodes activation configuration or None.
-
-         Returns:
-             Activation utilization per cut, or empty dict if no configuration was passed.
-         """
-         # Maxcut activation utilization is computed for all quantized nodes, so non-configurable memory is already
-         # covered by the computation of configurable activations.
-         if not act_qcs:
-             return {}
-
-         _, cuts_util, *_ = self.ru_calculator.compute_activation_utilization_by_cut(
-             TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs)
-         cuts_util = {c: u.bytes for c, u in cuts_util.items()}
-         return cuts_util
-
-     def _bops_utilization(self,
-                           act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]],
-                           w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> np.ndarray:
-         """
-         Computes a resource utilization vector with the respective bit-operations (BOPS) count
-         according to the given mixed-precision configuration.
-
-         Args:
-             act_qcs: nodes activation configuration or None.
-             w_qcs: nodes quantization configuration to compute, or None.
-                    Either both are provided, or both are None.
-
-         Returns:
-             A vector of node's BOPS count.
-         """
-         assert [act_qcs, w_qcs].count(None) in [0, 2], 'act_qcs and w_qcs should both be provided or both be None.'
-         if act_qcs is None:
-             return np.array([])
-
-         _, detailed_bops = self.ru_calculator.compute_bops(TargetInclusionCriterion.Any, BitwidthMode.QCustom,
-                                                            act_qcs=act_qcs, w_qcs=w_qcs)
-         return np.array(list(detailed_bops.values()))
@@ -13,37 +13,27 @@
  # limitations under the License.
  # ==============================================================================

- import copy
  from enum import Enum
- import numpy as np
- from typing import List, Callable, Dict
+ from typing import List, Callable

  from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
  from model_compression_toolkit.core.common import Graph
- from model_compression_toolkit.core.common.hessian import HessianInfoService
- from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
- from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
-     mp_integer_programming_search
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common.hessian import HessianInfoService
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
+     MixedPrecisionSearchManager
+ from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
+     ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
      greedy_solution_refinement_procedure
- from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
- from model_compression_toolkit.logger import Logger


  class BitWidthSearchMethod(Enum):
-     # When adding a new search_methods MP configuration method, these enum and factory dictionary
-     # should be updated with it's kind and a search_method implementation.
      INTEGER_PROGRAMMING = 0


- search_methods = {
-     BitWidthSearchMethod.INTEGER_PROGRAMMING: mp_integer_programming_search}
-
-
- def search_bit_width(graph_to_search_cfg: Graph,
+ def search_bit_width(graph: Graph,
                       fw_info: FrameworkInfo,
                       fw_impl: FrameworkImplementation,
                       target_resource_utilization: ResourceUtilization,
@@ -60,7 +50,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
      target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.

      Args:
-         graph_to_search_cfg: Graph to search a MP configuration for.
+         graph: Graph to search a MP configuration for.
          fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
          fw_impl: FrameworkImplementation object with specific framework methods implementation.
          target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
@@ -75,17 +65,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
          bit-width index on the node).

      """
-
-     # target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
-     if target_resource_utilization is None:
-         Logger.critical("Target ResourceUtilization is required for the bit-width search method's configuration.") # pragma: no cover
-
-     # Set graph for MP search
-     graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
-     if target_resource_utilization.bops_restricted():
-         # TODO: we only need the virtual graph is both activations and weights are configurable
-         # Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
-         graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())
+     assert target_resource_utilization.is_any_restricted()

      # If we only run weights compression with MP than no need to consider activation quantization when computing the
      # MP metric (it adds noise to the computation)
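Editor's note: the new assert relies on the restriction-query helpers of the target ResourceUtilization (weight_restricted, activation_restricted, total_mem_restricted, bops_restricted, is_any_restricted), all of which appear in this diff. A hypothetical mirror of that pattern, just to illustrate the checks; in this toy version None stands for "unrestricted", which may differ from MCT's actual representation:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyResourceUtilization:
    # Hypothetical stand-in for the target ResourceUtilization; None means no budget was set.
    weights_memory: Optional[float] = None
    activation_memory: Optional[float] = None
    total_memory: Optional[float] = None
    bops: Optional[float] = None

    def weight_restricted(self) -> bool:
        return self.weights_memory is not None

    def activation_restricted(self) -> bool:
        return self.activation_memory is not None

    def total_mem_restricted(self) -> bool:
        return self.total_memory is not None

    def bops_restricted(self) -> bool:
        return self.bops is not None

    def is_any_restricted(self) -> bool:
        return any([self.weight_restricted(), self.activation_restricted(),
                    self.total_mem_restricted(), self.bops_restricted()])

tru = ToyResourceUtilization(weights_memory=2_000_000)
assert tru.is_any_restricted()
weight_only = tru.weight_restricted() and not (tru.activation_restricted()
                                               or tru.total_mem_restricted()
                                               or tru.bops_restricted())
print(weight_only)  # True: only the weights budget is set, so the activation metric can be disabled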
@@ -93,33 +73,28 @@ def search_bit_width(graph_to_search_cfg: Graph,
      weight_only_restricted = tru.weight_restricted() and not (tru.activation_restricted() or
                                                                 tru.total_mem_restricted() or
                                                                 tru.bops_restricted())
-     disable_activation_for_metric = weight_only_restricted or graph_to_search_cfg.is_single_activation_cfg()
+     disable_activation_for_metric = weight_only_restricted or not graph.has_any_configurable_activation()

      # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
      # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
      se = fw_impl.get_sensitivity_evaluator(
-         graph_to_search_cfg,
+         graph,
          mp_config,
          representative_data_gen=representative_data_gen,
          fw_info=fw_info,
          disable_activation_for_metric=disable_activation_for_metric,
          hessian_info_service=hessian_info_service)

-     # Instantiate a manager object
+     if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
+         raise NotImplementedError()
+
+     # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
      search_manager = MixedPrecisionSearchManager(graph,
                                                   fw_info,
                                                   fw_impl,
                                                   se,
-                                                  target_resource_utilization,
-                                                  original_graph=graph_to_search_cfg)
-
-     if search_method not in search_methods:
-         raise NotImplementedError() # pragma: no cover
-
-     search_method_fn = search_methods[search_method]
-     # Search for the desired mixed-precision configuration
-     result_bit_cfg = search_method_fn(search_manager,
-                                       target_resource_utilization)
+                                                  target_resource_utilization)
+     result_bit_cfg = search_manager.search()

      if mp_config.refine_mp_solution:
          result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
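Editor's note: the integer-programming search that the manager now runs internally chooses one bit-width candidate per configurable layer so as to minimize a sensitivity objective under resource budgets. A toy illustration of that kind of formulation using the pulp package (assumed installed); this is not MCT's actual model in linear_programming.py, which among other things constrains activation memory per cut:

import pulp

# Toy data: per-layer sensitivity and weight memory (KB) for each candidate; index 0 = 8 bit, index 1 = 4 bit.
sensitivity = {'conv1': [0.0, 0.8], 'conv2': [0.0, 0.5]}
memory_kb   = {'conv1': [120, 60],  'conv2': [200, 100]}
budget_kb = 240

prob = pulp.LpProblem('toy_mixed_precision', pulp.LpMinimize)
# One binary indicator per (layer, candidate); exactly one candidate must be picked per layer.
x = {(n, i): pulp.LpVariable(f'x_{n}_{i}', cat='Binary')
     for n in sensitivity for i in range(len(sensitivity[n]))}
for n in sensitivity:
    prob += pulp.lpSum(x[n, i] for i in range(len(sensitivity[n]))) == 1
# Objective: minimize total sensitivity; constraint: stay within the weights budget.
prob += pulp.lpSum(sensitivity[n][i] * x[n, i] for (n, i) in x)
prob += pulp.lpSum(memory_kb[n][i] * x[n, i] for (n, i) in x) <= budget_kb

prob.solve(pulp.PULP_CBC_CMD(msg=False))
chosen = {n: next(i for i in range(len(sensitivity[n])) if x[n, i].value() > 0.5) for n in sensitivity}
print(chosen)  # {'conv1': 0, 'conv2': 1}: conv2 is pushed to 4 bit so the 240 KB budget is met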