mct-nightly 2.3.0.20250505.616__py3-none-any.whl → 2.3.0.20250507.555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/METADATA +4 -2
- {mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/RECORD +9 -9
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +16 -16
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +179 -352
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +6 -27
- {mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/METADATA
RENAMED
@@ -1,7 +1,8 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250507.555
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
|
+
Author-email: ssi-dnn-dev@sony.com
|
5
6
|
Classifier: Programming Language :: Python :: 3
|
6
7
|
Classifier: License :: OSI Approved :: Apache Software License
|
7
8
|
Classifier: Operating System :: OS Independent
|
@@ -23,6 +24,7 @@ Requires-Dist: protobuf
|
|
23
24
|
Requires-Dist: mct-quantizers-nightly
|
24
25
|
Requires-Dist: pydantic>=2.0
|
25
26
|
Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
|
27
|
+
Dynamic: author-email
|
26
28
|
Dynamic: classifier
|
27
29
|
Dynamic: description
|
28
30
|
Dynamic: description-content-type
|
@@ -51,7 +53,7 @@ ______________________________________________________________________
|
|
51
53
|
</p>
|
52
54
|
<p align="center">
|
53
55
|
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/pytorch-2.2%20%7C%202.3%20%7C%202.4%20%7C%202.5-blue" /></a>
|
54
|
-
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-
|
56
|
+
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-2.14%20%7C%202.15-blue" /></a>
|
55
57
|
<a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue" /></a>
|
56
58
|
<a href="https://github.com/sony/model_optimization/releases"><img src="https://img.shields.io/github/v/release/sony/model_optimization" /></a>
|
57
59
|
<a href="https://github.com/sony/model_optimization/blob/main/LICENSE.md"><img src="https://img.shields.io/badge/license-Apache%202.0-blue" /></a>
|
{mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250507.555.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=QzNbJcOvpHUdWIDaA4UVEDr7PLGs8z-feuEZ7nopltg,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -40,7 +40,7 @@ model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-
|
|
40
40
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
|
41
41
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
42
42
|
model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
|
43
|
-
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=
|
43
|
+
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=ahuvX2H7__hwTrtR02QbadlDJjagvKovFg6KKNU9svo,10443
|
44
44
|
model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
45
45
|
model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py,sha256=X6FK3C3y8ixFRPjC_wm3ClloCX8_06SOdA1TRi7o_LA,3800
|
46
46
|
model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=oyz260JXDbvL8aI-DVtUvLHtLRWC2Yu4SBYlGL68c2Y,3498
|
@@ -69,13 +69,13 @@ model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates
|
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=onHgDwfw8CUbZFNU-RYit9eqA6FrzAtFA3akVZ2d7IM,4533
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
|
71
71
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
|
72
|
-
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=
|
72
|
+
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=Lk5cftihGpgFQoyqnRGiwJFFqkI8dkx0l1q0sVJi2CE,27505
|
73
73
|
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=R3UIO9lKf-lpEGfJOqgpQAXdP1IWMatWxXKYDkhWj_E,28096
|
74
74
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
75
75
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
|
76
76
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
77
77
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
|
78
|
-
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256
|
78
|
+
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=-kNcmQQFVHRPizInaRrCEIuh_q_57CWxC6CIV6azF4g,39640
|
79
79
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
|
80
80
|
model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
81
81
|
model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250507.555.dist-info/METADATA,sha256=hIfm1mpPLcDseqKnO60EFxdc5f3T66WfAIb0gwT7TEk,25157
|
532
|
+
mct_nightly-2.3.0.20250507.555.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
533
|
+
mct_nightly-2.3.0.20250507.555.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250507.555.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250507.000555"
|
@@ -12,22 +12,25 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
import abc
|
15
16
|
import uuid
|
16
17
|
|
17
|
-
from typing import Dict, Any, Tuple
|
18
|
-
|
19
18
|
from model_compression_toolkit.core import FrameworkInfo
|
20
19
|
from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
|
21
20
|
VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
|
22
|
-
|
21
|
+
from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
|
23
22
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
24
|
-
import numpy as np
|
25
|
-
|
26
23
|
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
|
27
24
|
CandidateNodeQuantizationConfig
|
25
|
+
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
|
26
|
+
|
28
27
|
|
28
|
+
class VirtualNode(BaseNode, abc.ABC):
|
29
|
+
""" Base class for all virtual nodes. """
|
30
|
+
pass
|
29
31
|
|
30
|
-
|
32
|
+
|
33
|
+
class VirtualSplitNode(VirtualNode, abc.ABC):
|
31
34
|
"""
|
32
35
|
A class that represents a node that was split from a kernel node (node with weights).
|
33
36
|
"""
|
@@ -73,14 +76,11 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
|
|
73
76
|
super().__init__(origin_node)
|
74
77
|
|
75
78
|
self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
# from the original graph, so in the worst case the utilization will be higher in virtual graph.
|
82
|
-
# This should guarantee that the utilization of the original graph does not exceed the requested target.
|
83
|
-
self.candidates_quantization_cfg = origin_node.candidates_quantization_cfg
|
79
|
+
|
80
|
+
self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
|
81
|
+
for c in self.candidates_quantization_cfg:
|
82
|
+
c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
|
83
|
+
c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
|
84
84
|
|
85
85
|
|
86
86
|
class VirtualSplitActivationNode(VirtualSplitNode):
|
@@ -113,7 +113,7 @@ class VirtualSplitActivationNode(VirtualSplitNode):
|
|
113
113
|
c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH
|
114
114
|
|
115
115
|
|
116
|
-
class VirtualActivationWeightsNode(
|
116
|
+
class VirtualActivationWeightsNode(VirtualNode):
|
117
117
|
"""
|
118
118
|
A node that represents a composition of pair of sequential activation node and weights (kernel) node.
|
119
119
|
This structure is used for mixed-precision search with bit-operation constraint.
|
@@ -149,7 +149,7 @@ class VirtualActivationWeightsNode(BaseNode):
|
|
149
149
|
weights = weights_node.weights.copy()
|
150
150
|
act_node_w_rename = {}
|
151
151
|
if act_node.weights:
|
152
|
-
if
|
152
|
+
if fw_info.get_kernel_op_attributes(act_node) != DEFAULT_KERNEL_ATTRIBUTES:
|
153
153
|
raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
|
154
154
|
f'VirtualActivationWeightsNode.')
|
155
155
|
if act_node.has_any_configurable_weight():
|
@@ -19,7 +19,7 @@ from collections import defaultdict
|
|
19
19
|
|
20
20
|
from tqdm import tqdm
|
21
21
|
|
22
|
-
from typing import Dict, List, Tuple
|
22
|
+
from typing import Dict, List, Tuple, Optional
|
23
23
|
|
24
24
|
import numpy as np
|
25
25
|
|
@@ -28,7 +28,7 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
|
|
28
28
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
29
29
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
30
30
|
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
|
31
|
-
VirtualSplitWeightsNode, VirtualSplitActivationNode
|
31
|
+
VirtualSplitWeightsNode, VirtualSplitActivationNode, VirtualNode
|
32
32
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
33
33
|
RUTarget, ResourceUtilization
|
34
34
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
|
@@ -83,10 +83,9 @@ class MixedPrecisionSearchManager:
|
|
83
83
|
self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
|
84
84
|
self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)
|
85
85
|
|
86
|
-
self.config_reconstruction_helper = ConfigReconstructionHelper(
|
87
|
-
original_graph=self.original_graph)
|
86
|
+
self.config_reconstruction_helper = ConfigReconstructionHelper(self.original_graph)
|
88
87
|
if self.using_virtual_graph:
|
89
|
-
real_min_ru_config
|
88
|
+
real_min_ru_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.min_ru_config)
|
90
89
|
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config)
|
91
90
|
else:
|
92
91
|
self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
|
@@ -101,7 +100,7 @@ class MixedPrecisionSearchManager:
|
|
101
100
|
mp_config = self._prepare_and_run_solver()
|
102
101
|
|
103
102
|
if self.using_virtual_graph:
|
104
|
-
mp_config = self.config_reconstruction_helper.
|
103
|
+
mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_config)
|
105
104
|
|
106
105
|
return mp_config
|
107
106
|
|
@@ -112,9 +111,9 @@ class MixedPrecisionSearchManager:
|
|
112
111
|
Returns:
|
113
112
|
Mapping from nodes to indices of the selected bit-widths candidate.
|
114
113
|
"""
|
115
|
-
layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
|
116
114
|
candidates_ru = self._compute_relative_ru_matrices()
|
117
115
|
rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
|
116
|
+
layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
|
118
117
|
solver = MixedPrecisionIntegerLPSolver(layers_candidates_sensitivity, candidates_ru, rel_target_ru)
|
119
118
|
mp_config = solver.run()
|
120
119
|
return mp_config
|
@@ -171,8 +170,7 @@ class MixedPrecisionSearchManager:
|
|
171
170
|
topo_cfg(baseline_cfg) if baseline_cfg else None)
|
172
171
|
|
173
172
|
if self.using_virtual_graph:
|
174
|
-
origin_max_config = self.config_reconstruction_helper.
|
175
|
-
self.max_ru_config)
|
173
|
+
origin_max_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.max_ru_config)
|
176
174
|
max_config_value = compute_metric(origin_max_config)
|
177
175
|
else:
|
178
176
|
max_config_value = compute_metric(self.max_ru_config)
|
@@ -192,22 +190,12 @@ class MixedPrecisionSearchManager:
|
|
192
190
|
# Build a distance matrix using the function we got from the framework implementation.
|
193
191
|
if self.using_virtual_graph:
|
194
192
|
# Reconstructing original graph's configuration from virtual graph's configuration
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
original_base_config=origin_max_config)
|
200
|
-
origin_changed_nodes_indices = [i for i, (n, c) in enumerate(origin_max_config.items()) if
|
201
|
-
c != origin_mp_model_configuration[n]]
|
202
|
-
metric_value = compute_metric(
|
203
|
-
origin_mp_model_configuration,
|
204
|
-
origin_changed_nodes_indices,
|
205
|
-
origin_max_config)
|
193
|
+
orig_mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_model_configuration)
|
194
|
+
changed_nodes = [orig_sorted_nodes.index(n) for n, ind in orig_mp_config.items()
|
195
|
+
if origin_max_config[n] != ind]
|
196
|
+
metric_value = compute_metric(orig_mp_config, changed_nodes, origin_max_config)
|
206
197
|
else:
|
207
|
-
metric_value = compute_metric(
|
208
|
-
mp_model_configuration,
|
209
|
-
[node_idx],
|
210
|
-
self.max_ru_config)
|
198
|
+
metric_value = compute_metric(mp_model_configuration, [node_idx], self.max_ru_config)
|
211
199
|
metric_value = max(metric_value, max_config_value + eps)
|
212
200
|
layer_to_metrics_mapping[node].append(metric_value)
|
213
201
|
|
@@ -256,7 +244,7 @@ class MixedPrecisionSearchManager:
|
|
256
244
|
else:
|
257
245
|
cfg = self.min_ru_config.copy()
|
258
246
|
cfg[node] = candidate_idx
|
259
|
-
real_cfg = self.config_reconstruction_helper.
|
247
|
+
real_cfg = self.config_reconstruction_helper.reconstruct_full_configuration(cfg)
|
260
248
|
candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg)
|
261
249
|
|
262
250
|
for target, ru in candidate_rus.items():
|
@@ -326,353 +314,192 @@ class MixedPrecisionSearchManager:
|
|
326
314
|
|
327
315
|
|
328
316
|
class ConfigReconstructionHelper:
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
It provides a reconstruct_config_from_virtual_graph which allows to translate a bit-width config of a virtual graph
|
333
|
-
to a config of the original configurable nodes.
|
334
|
-
"""
|
335
|
-
|
336
|
-
def __init__(self, virtual_graph: Graph, original_graph: Graph):
|
337
|
-
"""
|
338
|
-
Init a ConfigReconstructionHelper object.
|
339
|
-
It holds a dictionary variable named origin_node_idx_to_cfg which holds the mapping from an original graph's
|
340
|
-
configurable node to its actual bit-width index (this data structure is being cleared
|
341
|
-
after every reconstruction call).
|
342
|
-
|
343
|
-
Args:
|
344
|
-
virtual_graph: The virtual graph.
|
345
|
-
original_graph: The original graph.
|
346
|
-
"""
|
347
|
-
|
348
|
-
self.virtual_graph = virtual_graph
|
349
|
-
self.original_graph = original_graph
|
350
|
-
self.fw_info = original_graph.fw_info
|
351
|
-
|
352
|
-
self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names(self.fw_info)
|
353
|
-
self.origin_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
|
354
|
-
self.origin_sorted_conf_nodes_names = [n.name for n in self.origin_sorted_conf_nodes]
|
355
|
-
|
356
|
-
self.origin_node_idx_to_cfg = {}
|
357
|
-
|
358
|
-
def _clear_reconstruction_dict(self):
|
359
|
-
"""
|
360
|
-
Clears the origin_node_idx_to_cfg data structure.
|
361
|
-
"""
|
362
|
-
|
363
|
-
self.origin_node_idx_to_cfg = {}
|
364
|
-
|
365
|
-
def reconstruct_config_from_virtual_graph(self,
|
366
|
-
virtual_mp_cfg: Dict[BaseNode, int],
|
367
|
-
changed_virtual_nodes_idx: List[int] = None,
|
368
|
-
original_base_config: Dict[BaseNode, int] = None) -> Dict[BaseNode, int]:
|
369
|
-
"""
|
370
|
-
Reconstructs the original config for a given virtual graph mixed-precision config.
|
371
|
-
It iterates over all virtual configurable node (that has some chosen bit-width virtual candidate)
|
372
|
-
and translates its chosen candidate to a candidate index of configurable nodes in the original graph.
|
373
|
-
The translation is based of the virtual node's type. Note that if the node is a split activation node
|
374
|
-
for instance, then we need to find its matching weights node in order to construct the original linear node's
|
375
|
-
chosen config.
|
376
|
-
|
377
|
-
Args:
|
378
|
-
virtual_mp_cfg: A mixed-precision configuration (list of candidates indices) of the virtual graph.
|
379
|
-
changed_virtual_nodes_idx: Provide an optional list of virtual nodes indices for which the
|
380
|
-
config reconstruction will be computed.
|
381
|
-
original_base_config: If changed_virtual_nodes_idx is provided, need to provide a base config from which the
|
382
|
-
bit-width for all un-changed original nodes will be taken.
|
383
|
-
|
384
|
-
Returns: A mixed-precision configuration (list of candidates indices) of the original graph.
|
385
|
-
|
386
|
-
"""
|
387
|
-
|
388
|
-
if changed_virtual_nodes_idx is not None:
|
389
|
-
if original_base_config is None:
|
390
|
-
Logger.critical("To run config reconstruction for a partial set of nodes, a base original config must be provided.") # pragma: no cover
|
391
|
-
|
392
|
-
updated_virtual_nodes = \
|
393
|
-
[(idx, self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)[idx]) for idx in changed_virtual_nodes_idx]
|
394
|
-
# Iterating only over the virtual nodes that have updated config
|
395
|
-
for virtual_node_idx, n in updated_virtual_nodes:
|
396
|
-
self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
|
397
|
-
# Updating reconstructed config for all other nodes based on provided base_config
|
398
|
-
original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
|
399
|
-
for i, (n, qc_ind) in enumerate(original_base_config.items()):
|
400
|
-
if i not in list(self.origin_node_idx_to_cfg.keys()):
|
401
|
-
self.update_config_at_original_idx(n=n, origin_cfg_idx=qc_ind)
|
402
|
-
else:
|
403
|
-
# Reconstruct entire config
|
404
|
-
for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)):
|
405
|
-
self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
|
406
|
-
|
407
|
-
res_config = [self.origin_node_idx_to_cfg[key] for key in sorted(self.origin_node_idx_to_cfg.keys())]
|
408
|
-
self._clear_reconstruction_dict()
|
409
|
-
assert len(res_config) == len(self.origin_sorted_conf_nodes)
|
410
|
-
return {n: candidate_idx for n, candidate_idx in zip(self.origin_sorted_conf_nodes, res_config)}
|
411
|
-
|
412
|
-
def reconstruct_node_config(self,
|
413
|
-
n: BaseNode,
|
414
|
-
virtual_mp_cfg: List[int],
|
415
|
-
virtual_node_idx: int):
|
416
|
-
"""
|
417
|
-
Reconstructs the original configuration for a single node. Updates the mapping inplace.
|
418
|
-
|
419
|
-
Args:
|
420
|
-
n: The node to reconstruct the configuration for.
|
421
|
-
virtual_mp_cfg: A mixed-precision configuration (list of candidates indices) of the virtual graph.
|
422
|
-
virtual_node_idx: The index of the virtual node in the virtual mixed-precision configuration.
|
423
|
-
"""
|
424
|
-
|
425
|
-
virtual_cfg_idx = virtual_mp_cfg[virtual_node_idx]
|
426
|
-
|
427
|
-
if isinstance(n, VirtualActivationWeightsNode):
|
428
|
-
weights_node = n.original_weights_node
|
429
|
-
if isinstance(weights_node, VirtualSplitWeightsNode):
|
430
|
-
self.get_activation_for_split_weights(weights_node, n, virtual_cfg_idx, virtual_mp_cfg)
|
431
|
-
else:
|
432
|
-
Logger.critical(f"Virtual graph construction error: Expected all weights nodes to be split into weights and activation nodes. Found node '{n.name}' not split as expected. Every weights node should correspond to a VirtualSplitWeightsNode type.") # pragma: no cover
|
433
|
-
|
434
|
-
activation_node = n.original_activation_node
|
435
|
-
if isinstance(activation_node, VirtualSplitActivationNode):
|
436
|
-
self.get_weights_for_split_activation(activation_node, n, virtual_cfg_idx, virtual_mp_cfg)
|
437
|
-
else:
|
438
|
-
if activation_node.name in self.origin_sorted_conf_nodes_names:
|
439
|
-
# It is possible that the original activation node is not configurable,
|
440
|
-
# in this case we don't need to retrieve its bit-width config
|
441
|
-
self.retrieve_activation_only_config(activation_node, n, virtual_cfg_idx)
|
442
|
-
elif isinstance(n, VirtualSplitWeightsNode):
|
443
|
-
# If the node's predecessor have multiple outgoing edges then it is possible that this weights
|
444
|
-
# node is not composed with an activation, but otherwise there is something wrong, and we need
|
445
|
-
# to raise an exception
|
446
|
-
predecessor = self.virtual_graph.get_prev_nodes(n)
|
447
|
-
assert len(predecessor) == 1 # Sanity check
|
448
|
-
predecessor = predecessor[0]
|
449
|
-
if len(self.virtual_graph.out_edges(predecessor)) > 1:
|
450
|
-
# It's ok, need to find the node's configuration
|
451
|
-
self.get_activation_for_split_weights(n, n, virtual_cfg_idx, virtual_mp_cfg)
|
452
|
-
else:
|
453
|
-
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{n.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
|
454
|
-
elif isinstance(n, VirtualSplitActivationNode):
|
455
|
-
self.get_weights_for_split_activation(n, n, virtual_cfg_idx, virtual_mp_cfg)
|
456
|
-
else:
|
457
|
-
# Node didn't change in virtual graph - candidates list is similar to original
|
458
|
-
if n.name not in self.origin_sorted_conf_nodes_names:
|
459
|
-
Logger.critical(f"Configuration mismatch: Node '{n.name}' is configurable in the virtual graph but not in the original graph. Verify node configurations.") # pragma: no cover
|
460
|
-
origin_idx = self.origin_sorted_conf_nodes_names.index(n.name)
|
461
|
-
self.origin_node_idx_to_cfg[origin_idx] = virtual_cfg_idx
|
462
|
-
|
463
|
-
def retrieve_weights_only_config(self, weights_node: BaseNode, virtual_node: BaseNode, virtual_cfg_idx: int):
|
464
|
-
"""
|
465
|
-
Retrieves the configuration of an original weights configurable node based on a
|
466
|
-
virtual weights configurable node's chosen config idx, and updates (inplace) the origin_cfg_idx mapping dict.
|
467
|
-
If the original node is not configurable, nothing will be updated.
|
317
|
+
def __init__(self, original_graph):
|
318
|
+
# mapping in order to return the actual node objects from the original graph
|
319
|
+
self.orig_nodes = {n.name: n for n in original_graph.nodes}
|
468
320
|
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
virtual_cfg_idx: The virtual node's chosen config index.
|
473
|
-
"""
|
474
|
-
|
475
|
-
if weights_node.name in self.origin_sorted_conf_nodes_names:
|
476
|
-
# It is possible that the original weights node is not configurable,
|
477
|
-
# in this case we don't need to retrieve its bit-width config
|
478
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
|
479
|
-
weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg
|
480
|
-
.get_attr_config(kernel_attr).weights_n_bits)
|
481
|
-
origin_cfg_idx = [i for i, c in
|
482
|
-
enumerate(weights_node.candidates_quantization_cfg) if
|
483
|
-
c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth]
|
484
|
-
|
485
|
-
self.update_config_at_original_idx(weights_node, origin_cfg_idx[0])
|
486
|
-
|
487
|
-
def retrieve_activation_only_config(self, activation_node: BaseNode, virtual_node: BaseNode, virtual_cfg_idx: int):
|
321
|
+
def reconstruct_full_configuration(self,
|
322
|
+
virtual_cfg: Dict[BaseNode, int],
|
323
|
+
include_non_configurable: bool = False) -> Dict[BaseNode, int]:
|
488
324
|
"""
|
489
|
-
|
490
|
-
|
491
|
-
|
325
|
+
Convert a configuration of a virtual graph into the corresponding configuration of the original graph.
|
326
|
+
Note that a configurable VirtualActivationWeightsNode might comprise one configurable and one non-configurable
|
327
|
+
original nodes.
|
492
328
|
|
493
329
|
Args:
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
"""
|
330
|
+
virtual_cfg: a mapping from nodes in the virtual graph to selected candidate index. Should contain all
|
331
|
+
configurable nodes of the virtual graph, and only configurable nodes.
|
332
|
+
include_non_configurable: whether to return configs for non-configurable original nodes.
|
498
333
|
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
334
|
+
Returns:
|
335
|
+
A mapping from configurable nodes in the original graph to their candidate indices.
|
336
|
+
"""
|
337
|
+
# Original candidate of a node that has been split might be determined by two different virtual nodes, one
|
338
|
+
# determines activation and one - weights. First, for each virtual node we collect the original
|
339
|
+
# activation / weights nodes, with all original candidates that match the virtual candidate
|
340
|
+
# activation / weights config. If both activation and weights of the original node are determined by virtual
|
341
|
+
# candidates, we look for a common candidate.
|
342
|
+
orig_nodes_a_candidates = {}
|
343
|
+
orig_nodes_w_candidates = {}
|
344
|
+
for virtual_node, virtual_qc_ind in virtual_cfg.items():
|
345
|
+
assert virtual_node.has_configurable_activation() or virtual_node.has_any_configurable_weight()
|
346
|
+
orig_a_node, orig_a_candidates = self._retrieve_matching_orig_a_candidates(virtual_node, virtual_qc_ind)
|
347
|
+
if orig_a_node and (include_non_configurable or orig_a_node.has_configurable_activation()):
|
348
|
+
assert orig_a_node not in orig_nodes_a_candidates
|
349
|
+
orig_nodes_a_candidates[orig_a_node] = orig_a_candidates
|
350
|
+
orig_w_node, orig_w_candidates = self._retrieve_matching_orig_w_candidates(virtual_node, virtual_qc_ind)
|
351
|
+
if orig_w_node and (include_non_configurable or orig_w_node.has_any_configurable_weight()):
|
352
|
+
assert orig_w_node not in orig_nodes_w_candidates
|
353
|
+
orig_nodes_w_candidates[orig_w_node] = orig_w_candidates
|
354
|
+
|
355
|
+
orig_cfg = {}
|
356
|
+
common_orig_nodes = set(orig_nodes_a_candidates.keys()).intersection(set(orig_nodes_w_candidates))
|
357
|
+
for orig_node in common_orig_nodes:
|
358
|
+
a_candidates = orig_nodes_a_candidates[orig_node]
|
359
|
+
w_candidates = orig_nodes_w_candidates[orig_node]
|
360
|
+
# find the common candidate
|
361
|
+
common_candidates = set(a_candidates).intersection(set(w_candidates))
|
362
|
+
if len(common_candidates) != 1:
|
363
|
+
raise ValueError(f'Expected to find exactly one candidate with the required activation and weights '
|
364
|
+
f'quantization configuration for node {orig_node}. Found {len(common_candidates)}')
|
365
|
+
# in theory it's possible that original non-configurable node gets split and each part is combined
|
366
|
+
# with a configurable part of another node and we end up here
|
367
|
+
if orig_node.has_configurable_activation() or orig_node.has_any_configurable_weight():
|
368
|
+
orig_cfg[orig_node] = common_candidates.pop()
|
369
|
+
del orig_nodes_a_candidates[orig_node]
|
370
|
+
del orig_nodes_w_candidates[orig_node]
|
371
|
+
|
372
|
+
# remaining a nodes
|
373
|
+
for orig_node, a_candidates in orig_nodes_a_candidates.items():
|
374
|
+
assert not orig_node.has_any_configurable_weight() # if it had we should have caught it above
|
375
|
+
assert len(a_candidates) == 1
|
376
|
+
assert orig_node not in orig_cfg
|
377
|
+
if include_non_configurable or orig_node.has_configurable_activation():
|
378
|
+
orig_cfg[orig_node] = a_candidates[0]
|
379
|
+
|
380
|
+
# remaining w nodes
|
381
|
+
for orig_node, w_candidates in orig_nodes_w_candidates.items():
|
382
|
+
assert not orig_node.has_configurable_activation() # if it had we should have caught it above
|
383
|
+
assert len(w_candidates) == 1
|
384
|
+
assert orig_node not in orig_cfg
|
385
|
+
if include_non_configurable or orig_node.has_any_configurable_weight():
|
386
|
+
orig_cfg[orig_node] = w_candidates[0]
|
387
|
+
|
388
|
+
return orig_cfg
|
389
|
+
|
390
|
+
def reconstruct_separate_aw_configs(self, virtual_cfg: Dict[BaseNode, int], include_non_configurable: bool) \
|
391
|
+
-> Tuple[Dict[BaseNode, int], Dict[BaseNode, int]]:
|
392
|
+
"""
|
393
|
+
Retrieves original activation and weights nodes and corresponding candidates for a given configuration of the
|
394
|
+
virtual graph. Only returns configuration specified by the virtual config, per configurable target (activation
|
395
|
+
or weights). For example, if 'virtual_cfg' contains a single VirtualActivationWeightsNode, the returned
|
396
|
+
configuration will contain only activation config for the original activation node, and only weights config
|
397
|
+
for the original weights node).
|
398
|
+
In practice, we return candidate index in both cases, instead of actual activation or weights config, since
|
399
|
+
sensitivity evaluator heavily depends on it, so we must ignore activation config in weights candidate and vice
|
400
|
+
versa. This is bad!!! TODO
|
520
401
|
|
521
402
|
Args:
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
virtual_cfg_idx: The virtual node's chosen config index.
|
526
|
-
virtual_mp_cfg: The virtual graph's chosen mp config.
|
527
|
-
"""
|
403
|
+
virtual_cfg: a mapping from nodes in the virtual graph to selected candidate index.
|
404
|
+
include_non_configurable: whether to return configs for non-configurable target (i.e. activation config
|
405
|
+
for non-configurable activation, and weights config for non-configurable weight).
|
528
406
|
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
|
533
|
-
|
534
|
-
weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg
|
535
|
-
.get_attr_config(kernel_attr).weights_n_bits)
|
536
|
-
|
537
|
-
origin_cfg_idx = [i for i, c in
|
538
|
-
enumerate(weights_node.origin_node.candidates_quantization_cfg) if
|
539
|
-
c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and
|
540
|
-
c.activation_quantization_cfg.activation_n_bits == activation_bitwidth]
|
541
|
-
|
542
|
-
self.update_config_at_original_idx(weights_node.origin_node, origin_cfg_idx[0])
|
543
|
-
|
544
|
-
def retrieve_weights_activation_config(self,
|
545
|
-
activation_node: BaseNode,
|
546
|
-
weights_node: BaseNode,
|
547
|
-
virtual_node: BaseNode,
|
548
|
-
virtual_cfg_idx: int,
|
549
|
-
virtual_mp_cfg: List[int]):
|
550
|
-
"""
|
551
|
-
Retrieves the configuration of an original weights and activation (possibly) configurable node based on a given
|
552
|
-
virtual split activation node and a virtual split weights node which represents its matching in the original graph.
|
553
|
-
it updates (inplace) the origin_cfg_idx mapping dict.
|
554
|
-
|
555
|
-
Args:
|
556
|
-
activation_node: The virtual node that contains the activation representation of an original node.
|
557
|
-
weights_node: The virtual node that contains the weights that matches the activation node in the original graph.
|
558
|
-
virtual_node: The virtual node that contains the virtual activation node (either a composed node or a split activation node).
|
559
|
-
virtual_cfg_idx: The virtual node's chosen config index.
|
560
|
-
virtual_mp_cfg: The virtual graph's chosen mp config.
|
407
|
+
Returns:
|
408
|
+
Configuration for original activation nodes and a separate configuration for original weights nodes.
|
561
409
|
"""
|
410
|
+
a_cfg = {}
|
411
|
+
w_cfg = {}
|
412
|
+
for virtual_node, virtual_qc_ind in virtual_cfg.items():
|
413
|
+
orig_a_node, orig_a_candidates = self._retrieve_matching_orig_a_candidates(virtual_node, virtual_qc_ind)
|
414
|
+
if orig_a_node and (include_non_configurable or orig_a_node.has_configurable_activation()):
|
415
|
+
# we may have retrieved multiple candidates with different weights configs and identical activation
|
416
|
+
# configs, so we just take the first
|
417
|
+
a_cfg[orig_a_node] = orig_a_candidates[0]
|
562
418
|
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
activation_bitwidth = virtual_node.candidates_quantization_cfg[
|
570
|
-
virtual_cfg_idx].activation_quantization_cfg.activation_n_bits
|
571
|
-
|
572
|
-
origin_cfg_idx = [i for i, c in enumerate(activation_node.origin_node.candidates_quantization_cfg) if
|
573
|
-
c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and
|
574
|
-
c.activation_quantization_cfg.activation_n_bits == activation_bitwidth]
|
419
|
+
orig_w_node, orig_w_candidates = self._retrieve_matching_orig_w_candidates(virtual_node, virtual_qc_ind)
|
420
|
+
if orig_w_node and (include_non_configurable or orig_w_node.has_any_configurable_weight()):
|
421
|
+
# we may have retrieved multiple candidates with different activation configs and identical weights
|
422
|
+
# configs, so we just take the first
|
423
|
+
w_cfg[orig_w_node] = orig_w_candidates[0]
|
575
424
|
|
576
|
-
|
425
|
+
return a_cfg, w_cfg
|
577
426
|
|
578
|
-
def
|
579
|
-
|
580
|
-
|
581
|
-
virtual_cfg_idx: int,
|
582
|
-
virtual_mp_cfg: List[int]):
|
427
|
+
def _retrieve_matching_orig_a_candidates(self,
|
428
|
+
virtual_node: BaseNode,
|
429
|
+
virtual_qc_ind: int) -> Tuple[Optional[BaseNode], Optional[List[int]]]:
|
583
430
|
"""
|
584
|
-
|
585
|
-
|
431
|
+
Retrieve the original activation node and all its candidates matching activation quantization config of the
|
432
|
+
given virtual candidate (candidate of a node in the virtual graph).
|
433
|
+
Note that we do simple matching, without any filtering, so disabled activation quantization will be also matched.
|
586
434
|
|
587
435
|
Args:
|
588
|
-
|
589
|
-
|
590
|
-
virtual_cfg_idx: The virtual node's chosen config index.
|
591
|
-
virtual_mp_cfg: The virtual graph's chosen mp config.
|
436
|
+
virtual_node: node in the virtual graph (can be virtual or regular).
|
437
|
+
virtual_qc_ind: candidate index of the virtual node.
|
592
438
|
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
if
|
603
|
-
|
604
|
-
assert activation_node.name in self.virtual_sorted_nodes_names # Sanity check
|
605
|
-
# The original node is both weights and activation configurable
|
606
|
-
self.retrieve_activation_weights_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
|
607
|
-
else:
|
608
|
-
# weights_node here is a split weights node therefore must have 'origin_node'
|
609
|
-
self.retrieve_weights_only_config(weights_node.origin_node, virtual_node, virtual_cfg_idx)
|
439
|
+
Returns:
|
440
|
+
The original activation node (actual object from the original graph) and a list of its matching candidates.
|
441
|
+
"""
|
442
|
+
if not isinstance(virtual_node, VirtualNode):
|
443
|
+
return self.orig_nodes[virtual_node.name], [virtual_qc_ind]
|
444
|
+
if isinstance(virtual_node, VirtualSplitWeightsNode):
|
445
|
+
return None, None
|
446
|
+
if isinstance(virtual_node, VirtualActivationWeightsNode):
|
447
|
+
orig_a_node = virtual_node.original_activation_node
|
448
|
+
if isinstance(orig_a_node, VirtualSplitActivationNode):
|
449
|
+
orig_a_node = orig_a_node.origin_node
|
610
450
|
else:
|
611
|
-
assert isinstance(
|
612
|
-
|
613
|
-
self.retrieve_activation_weights_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
|
614
|
-
else:
|
615
|
-
# The original node is only weights configurable
|
616
|
-
# weights_node here is a split weights node therefore must have 'origin_node'
|
617
|
-
self.retrieve_weights_only_config(weights_node.origin_node, virtual_node, virtual_cfg_idx)
|
618
|
-
|
619
|
-
def get_weights_for_split_activation(self,
|
620
|
-
activation_node: BaseNode,
|
621
|
-
virtual_node: BaseNode,
|
622
|
-
virtual_cfg_idx: int,
|
623
|
-
virtual_mp_cfg: List[int]):
|
624
|
-
"""
|
625
|
-
Finds the matching weights node in the virtual graph for a given split activation node,
|
626
|
-
and calls the relevant method for updating the configuration mapping.
|
451
|
+
assert isinstance(virtual_node, VirtualSplitActivationNode)
|
452
|
+
orig_a_node = virtual_node.origin_node
|
627
453
|
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
454
|
+
virtual_qc = virtual_node.candidates_quantization_cfg[virtual_qc_ind]
|
455
|
+
matching_orig_a_cfgs = [i for i, orig_qc in enumerate(orig_a_node.candidates_quantization_cfg)
|
456
|
+
if orig_qc.activation_quantization_cfg == virtual_qc.activation_quantization_cfg]
|
457
|
+
if not matching_orig_a_cfgs: # pragma: no cover
|
458
|
+
raise ValueError(f'Could not find matching activation quantization config in the original node '
|
459
|
+
f'{orig_a_node} for candidate {virtual_qc_ind} of the virtual node {virtual_node}')
|
460
|
+
return self.orig_nodes[orig_a_node.name], matching_orig_a_cfgs
|
635
461
|
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
assert len(matching_weights_node) == 1
|
640
|
-
weights_node = matching_weights_node[0]
|
641
|
-
|
642
|
-
if isinstance(weights_node, VirtualActivationWeightsNode):
|
643
|
-
kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
|
644
|
-
if weights_node.original_weights_node.is_weights_quantization_enabled(kernel_attr) and not \
|
645
|
-
weights_node.original_weights_node.is_all_weights_candidates_equal(kernel_attr):
|
646
|
-
assert weights_node.name in self.virtual_sorted_nodes_names # Sanity check
|
647
|
-
# The original node is both weights and activation configurable
|
648
|
-
self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
|
649
|
-
else:
|
650
|
-
# The original node is only activation configurable
|
651
|
-
# activation_node here is a split activation node therefore must have 'origin_node'
|
652
|
-
self.retrieve_activation_only_config(activation_node.origin_node, virtual_node, virtual_cfg_idx)
|
653
|
-
else:
|
654
|
-
# If the node's predecessor e multiple outgoing edges than it is possible that this weights
|
655
|
-
# node is not composed with an activation, but otherwise this is something wrong and we need
|
656
|
-
# to raise an exception
|
657
|
-
predecessor = self.virtual_graph.get_prev_nodes(weights_node)
|
658
|
-
assert len(predecessor) == 1 # Sanity check
|
659
|
-
predecessor = predecessor[0]
|
660
|
-
if len(self.virtual_graph.out_edges(predecessor)) > 1:
|
661
|
-
# It's ok, need to find the node's configuration
|
662
|
-
self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
|
663
|
-
else:
|
664
|
-
Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{weights_node.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
|
665
|
-
|
666
|
-
def update_config_at_original_idx(self, n: BaseNode, origin_cfg_idx: int):
|
462
|
+
def _retrieve_matching_orig_w_candidates(self,
|
463
|
+
virtual_node: BaseNode,
|
464
|
+
virtual_qc_ind: int) -> Tuple[Optional[BaseNode], Optional[List[int]]]:
|
667
465
|
"""
|
668
|
-
|
669
|
-
(
|
466
|
+
Retrieve the original weights node and all its candidates matching weights quantization config of the
|
467
|
+
given virtual candidate (candidate of a node in the virtual graph).
|
670
468
|
|
671
469
|
Args:
|
672
|
-
|
673
|
-
|
470
|
+
virtual_node: node in the virtual graph (can be virtual or regular).
|
471
|
+
virtual_qc_ind: candidate index of the virtual node.
|
674
472
|
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
473
|
+
Returns:
|
474
|
+
The original weights node (actual object from the original graph) and a list of all its matching candidates.
|
475
|
+
"""
|
476
|
+
if not isinstance(virtual_node, VirtualNode):
|
477
|
+
if virtual_node.weights:
|
478
|
+
return self.orig_nodes[virtual_node.name], [virtual_qc_ind]
|
479
|
+
return None, None
|
480
|
+
if isinstance(virtual_node, VirtualSplitActivationNode):
|
481
|
+
return None, None
|
482
|
+
|
483
|
+
if isinstance(virtual_node, VirtualActivationWeightsNode):
|
484
|
+
assert isinstance(virtual_node.original_weights_node, VirtualSplitWeightsNode)
|
485
|
+
orig_w_node = virtual_node.original_weights_node.origin_node
|
486
|
+
else:
|
487
|
+
assert isinstance(virtual_node, VirtualSplitWeightsNode)
|
488
|
+
orig_w_node = virtual_node.origin_node
|
489
|
+
|
490
|
+
virtual_qc = virtual_node.candidates_quantization_cfg[virtual_qc_ind]
|
491
|
+
|
492
|
+
# Matching candidate is a candidate with matching configs for configurable weights. We cannot compare the entire
|
493
|
+
# weights config since the virtual node may contain additional non-configurable weights from the activation node
|
494
|
+
orig_configurable_attrs = [attr for attr in orig_w_node.weights if virtual_node.is_configurable_weight(attr)]
|
495
|
+
assert all(virtual_node.is_configurable_weight(attr) for attr in orig_configurable_attrs)
|
496
|
+
|
497
|
+
def get_configurable_attrs_cfgs(qc):
|
498
|
+
return {attr: qc.weights_quantization_cfg.get_attr_config(attr) for attr in orig_configurable_attrs}
|
499
|
+
virtual_cfg = get_configurable_attrs_cfgs(virtual_qc)
|
500
|
+
matching_orig_w_cfgs = [i for i, orig_qc in enumerate(orig_w_node.candidates_quantization_cfg)
|
501
|
+
if get_configurable_attrs_cfgs(orig_qc) == virtual_cfg]
|
502
|
+
if not matching_orig_w_cfgs: # pragma: no cover
|
503
|
+
raise ValueError(f'Could not find matching weights quantization config in the original node '
|
504
|
+
f'{orig_w_node} for candidate {virtual_qc_ind} of the virtual node {virtual_node}')
|
505
|
+
return self.orig_nodes[orig_w_node.name], matching_orig_w_cfgs
|
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_
|
|
29
29
|
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
|
30
30
|
from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
|
31
31
|
from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
|
32
|
-
VirtualSplitWeightsNode
|
32
|
+
VirtualSplitWeightsNode, VirtualNode
|
33
33
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
34
34
|
RUTarget, ResourceUtilization
|
35
35
|
from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
|
@@ -531,6 +531,7 @@ class ResourceUtilizationCalculator:
|
|
531
531
|
Returns:
|
532
532
|
Node's BOPS count.
|
533
533
|
"""
|
534
|
+
assert not isinstance(n, VirtualNode), 'Use original graph to compute BOPS.'
|
534
535
|
if target_criterion is None:
|
535
536
|
target_criterion = TargetInclusionCriterion.Any
|
536
537
|
if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.Any]:
|
@@ -539,20 +540,6 @@ class ResourceUtilizationCalculator:
|
|
539
540
|
self._validate_custom_qcs(act_qcs, bitwidth_mode)
|
540
541
|
self._validate_custom_qcs(w_qc, bitwidth_mode)
|
541
542
|
|
542
|
-
if isinstance(n, VirtualSplitWeightsNode):
|
543
|
-
# Virtual weights node can only be present if it couldn't be merged into VirtualActivationWeightsNode.
|
544
|
-
# This means that during MP search we cannot compute bops for all A/W nbits combinations. To prevent
|
545
|
-
# inconsistencies we ignore such nodes for bops computation.
|
546
|
-
return 0
|
547
|
-
|
548
|
-
# Fetch the original weights node for mac computation (VirtualActivationWeightsNode input/output shapes are
|
549
|
-
# based on the activation original node, not weights original node)
|
550
|
-
orig_w_node = n
|
551
|
-
if isinstance(n, VirtualActivationWeightsNode):
|
552
|
-
orig_w_node = n.original_weights_node
|
553
|
-
if isinstance(orig_w_node, VirtualSplitWeightsNode):
|
554
|
-
orig_w_node = orig_w_node.origin_node
|
555
|
-
|
556
543
|
# check if the node has kernel
|
557
544
|
kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
|
558
545
|
if len(kernel_attrs) > 1: # pragma: no cover
|
@@ -561,21 +548,13 @@ class ResourceUtilizationCalculator:
|
|
561
548
|
return 0
|
562
549
|
|
563
550
|
kernel_attr = kernel_attrs[0]
|
564
|
-
node_mac = self.fw_impl.get_node_mac_operations(
|
551
|
+
node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
|
565
552
|
if node_mac == 0:
|
566
553
|
return node_mac
|
567
554
|
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
a_node = n
|
572
|
-
else:
|
573
|
-
# if we are running on the original (non-virtual) graph, we only compute bops if it would be computed in an
|
574
|
-
# equivalent virtual graph for consistency.
|
575
|
-
a_node = get_input_activation_if_composable(self.graph, n, warn=False)
|
576
|
-
if a_node is None:
|
577
|
-
return 0
|
578
|
-
|
555
|
+
prev_nodes = self.graph.get_prev_nodes(n)
|
556
|
+
assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
|
557
|
+
a_node = prev_nodes[0]
|
579
558
|
if (target_criterion == TargetInclusionCriterion.AnyQuantized and
|
580
559
|
not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
|
581
560
|
return 0
|
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250505.616.dist-info → mct_nightly-2.3.0.20250507.555.dist-info}/top_level.txt
RENAMED
File without changes
|