mct_nightly-2.3.0.20250505.616-py3-none-any.whl → mct_nightly-2.3.0.20250507.555-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- mct_nightly-2.3.0.20250505.616.dist-info/METADATA
+++ mct_nightly-2.3.0.20250507.555.dist-info/METADATA
@@ -1,7 +1,8 @@
  Metadata-Version: 2.4
  Name: mct-nightly
- Version: 2.3.0.20250505.616
+ Version: 2.3.0.20250507.555
  Summary: A Model Compression Toolkit for neural networks
+ Author-email: ssi-dnn-dev@sony.com
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Operating System :: OS Independent
@@ -23,6 +24,7 @@ Requires-Dist: protobuf
  Requires-Dist: mct-quantizers-nightly
  Requires-Dist: pydantic>=2.0
  Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
+ Dynamic: author-email
  Dynamic: classifier
  Dynamic: description
  Dynamic: description-content-type
@@ -51,7 +53,7 @@ ______________________________________________________________________
  </p>
  <p align="center">
  <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/pytorch-2.2%20%7C%202.3%20%7C%202.4%20%7C%202.5-blue" /></a>
- <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-02.14%20%7C%202.15-blue" /></a>
+ <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-2.14%20%7C%202.15-blue" /></a>
  <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue" /></a>
  <a href="https://github.com/sony/model_optimization/releases"><img src="https://img.shields.io/github/v/release/sony/model_optimization" /></a>
  <a href="https://github.com/sony/model_optimization/blob/main/LICENSE.md"><img src="https://img.shields.io/badge/license-Apache%202.0-blue" /></a>
--- mct_nightly-2.3.0.20250505.616.dist-info/RECORD
+++ mct_nightly-2.3.0.20250507.555.dist-info/RECORD
@@ -1,5 +1,5 @@
- mct_nightly-2.3.0.20250505.616.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- model_compression_toolkit/__init__.py,sha256=blX207LzIjS7kQeI15kYyjJnPY-XhwGQPlrTjt2S0CY,1557
+ mct_nightly-2.3.0.20250507.555.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ model_compression_toolkit/__init__.py,sha256=QzNbJcOvpHUdWIDaA4UVEDr7PLGs8z-feuEZ7nopltg,1557
  model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -40,7 +40,7 @@ model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=JH33qnHTaqFXcYSzTVMpDc9N93503y2pY3hiVJELuZI,10704
+ model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=ahuvX2H7__hwTrtR02QbadlDJjagvKovFg6KKNU9svo,10443
  model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py,sha256=X6FK3C3y8ixFRPjC_wm3ClloCX8_06SOdA1TRi7o_LA,3800
  model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=oyz260JXDbvL8aI-DVtUvLHtLRWC2Yu4SBYlGL68c2Y,3498
@@ -69,13 +69,13 @@ model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=onHgDwfw8CUbZFNU-RYit9eqA6FrzAtFA3akVZ2d7IM,4533
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=J8io_axti6gRoch9QR0FmKOP8JSHGeKqX95rf-nG6fI,37719
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=Lk5cftihGpgFQoyqnRGiwJFFqkI8dkx0l1q0sVJi2CE,27505
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=R3UIO9lKf-lpEGfJOqgpQAXdP1IWMatWxXKYDkhWj_E,28096
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=PXBuUUuYDmukjhgyrwEe71egpT_iu-LQt5SqddgkRHo,40793
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=-kNcmQQFVHRPizInaRrCEIuh_q_57CWxC6CIV6azF4g,39640
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
  model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.3.0.20250505.616.dist-info/METADATA,sha256=StJnfqa-V7mHpnvKCpoRrFkWKX2TVY2tXMbSAK5vkA0,25101
- mct_nightly-2.3.0.20250505.616.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
- mct_nightly-2.3.0.20250505.616.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.3.0.20250505.616.dist-info/RECORD,,
+ mct_nightly-2.3.0.20250507.555.dist-info/METADATA,sha256=hIfm1mpPLcDseqKnO60EFxdc5f3T66WfAIb0gwT7TEk,25157
+ mct_nightly-2.3.0.20250507.555.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+ mct_nightly-2.3.0.20250507.555.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.3.0.20250507.555.dist-info/RECORD,,
--- model_compression_toolkit/__init__.py
+++ model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.3.0.20250505.000616"
+ __version__ = "2.3.0.20250507.000555"
--- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
+++ model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
@@ -12,22 +12,25 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import abc
  import uuid

- from typing import Dict, Any, Tuple
-
  from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
      VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
-
+ from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
- import numpy as np
-
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
      CandidateNodeQuantizationConfig
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
+

+ class VirtualNode(BaseNode, abc.ABC):
+     """ Base class for all virtual nodes. """
+     pass

- class VirtualSplitNode(BaseNode):
+
+ class VirtualSplitNode(VirtualNode, abc.ABC):
      """
      A class that represents a node that was split from a kernel node (node with weights).
      """
@@ -73,14 +76,11 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
          super().__init__(origin_node)

          self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
-         # Virtual weights node is created only to be absorbed into virtual composed node right away.
-         # However, in some cases composition is impossible and virtual weights node can remain in the graph.
-         # In such case it messes up resource utilization computation, specifically activation cuts. In order to minimize
-         # the impact, we preserve the behavior of the original node wrt activation (shape and quantization),
-         # so that prev - virtualW cut is identical to prev-origin_node. Only the cut virtualW-virtualA will be different
-         # from the original graph, so in the worst case the utilization will be higher in virtual graph.
-         # This should guarantee that the utilization of the original graph does not exceed the requested target.
-         self.candidates_quantization_cfg = origin_node.candidates_quantization_cfg
+
+         self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
+         for c in self.candidates_quantization_cfg:
+             c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH


  class VirtualSplitActivationNode(VirtualSplitNode):
@@ -113,7 +113,7 @@ class VirtualSplitActivationNode(VirtualSplitNode):
              c.weights_quantization_cfg.weights_n_bits = FLOAT_BITWIDTH


- class VirtualActivationWeightsNode(BaseNode):
+ class VirtualActivationWeightsNode(VirtualNode):
      """
      A node that represents a composition of pair of sequential activation node and weights (kernel) node.
      This structure is used for mixed-precision search with bit-operation constraint.
@@ -149,7 +149,7 @@ class VirtualActivationWeightsNode(BaseNode):
          weights = weights_node.weights.copy()
          act_node_w_rename = {}
          if act_node.weights:
-             if not fw_info.get_kernel_op_attributes(act_node)[0] is None:
+             if fw_info.get_kernel_op_attributes(act_node) != DEFAULT_KERNEL_ATTRIBUTES:
                  raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                            f'VirtualActivationWeightsNode.')
              if act_node.has_any_configurable_weight():
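
For orientation, here is a minimal sketch of the class hierarchy this file now defines. Class names and relationships are taken from the diff above; the stub bodies and the `BaseNode` stand-in are illustrative only, not the package's real implementations:

```python
import abc

class BaseNode:  # stand-in for model_compression_toolkit's BaseNode
    pass

class VirtualNode(BaseNode, abc.ABC):
    """Base class for all virtual nodes."""

class VirtualSplitNode(VirtualNode, abc.ABC):
    """A node that was split from a kernel node (node with weights)."""

class VirtualSplitWeightsNode(VirtualSplitNode):
    """The weights half of a split kernel node."""

class VirtualSplitActivationNode(VirtualSplitNode):
    """The activation half of a split kernel node."""

class VirtualActivationWeightsNode(VirtualNode):
    """Composition of a sequential activation node and a weights (kernel) node."""

# A single isinstance check now identifies any virtual node, e.g. as the
# resource-utilization calculator below asserts:
#   assert not isinstance(n, VirtualNode), 'Use original graph to compute BOPS.'
```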
--- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
+++ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
@@ -19,7 +19,7 @@ from collections import defaultdict

  from tqdm import tqdm

- from typing import Dict, List, Tuple
+ from typing import Dict, List, Tuple, Optional

  import numpy as np

@@ -28,7 +28,7 @@ from model_compression_toolkit.core.common.framework_implementation import Frame
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.graph.base_graph import Graph
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
-     VirtualSplitWeightsNode, VirtualSplitActivationNode
+     VirtualSplitWeightsNode, VirtualSplitActivationNode, VirtualNode
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
      RUTarget, ResourceUtilization
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
@@ -83,10 +83,9 @@ class MixedPrecisionSearchManager:
          self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
          self.max_ru_config: Dict[BaseNode, int] = self.mp_graph.get_max_candidates_config(fw_info)

-         self.config_reconstruction_helper = ConfigReconstructionHelper(virtual_graph=self.mp_graph,
-                                                                        original_graph=self.original_graph)
+         self.config_reconstruction_helper = ConfigReconstructionHelper(self.original_graph)
          if self.using_virtual_graph:
-             real_min_ru_config: Dict[BaseNode, int] = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(self.min_ru_config)
+             real_min_ru_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.min_ru_config)
              self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, real_min_ru_config)
          else:
              self.min_ru = self.ru_helper.compute_utilization(self.ru_targets, self.min_ru_config)
@@ -101,7 +100,7 @@
          mp_config = self._prepare_and_run_solver()

          if self.using_virtual_graph:
-             mp_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(mp_config)
+             mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_config)

          return mp_config

@@ -112,9 +111,9 @@
          Returns:
              Mapping from nodes to indices of the selected bit-widths candidate.
          """
-         layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
          candidates_ru = self._compute_relative_ru_matrices()
          rel_target_ru = self._get_relative_ru_constraint_per_mem_element()
+         layers_candidates_sensitivity: Dict[BaseNode, List[float]] = self._build_sensitivity_mapping()
          solver = MixedPrecisionIntegerLPSolver(layers_candidates_sensitivity, candidates_ru, rel_target_ru)
          mp_config = solver.run()
          return mp_config
@@ -171,8 +170,7 @@
                                    topo_cfg(baseline_cfg) if baseline_cfg else None)

          if self.using_virtual_graph:
-             origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
-                 self.max_ru_config)
+             origin_max_config = self.config_reconstruction_helper.reconstruct_full_configuration(self.max_ru_config)
              max_config_value = compute_metric(origin_max_config)
          else:
              max_config_value = compute_metric(self.max_ru_config)
@@ -192,22 +190,12 @@
          # Build a distance matrix using the function we got from the framework implementation.
          if self.using_virtual_graph:
              # Reconstructing original graph's configuration from virtual graph's configuration
-             origin_mp_model_configuration = \
-                 self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
-                     mp_model_configuration,
-                     changed_virtual_nodes_idx=[node_idx],
-                     original_base_config=origin_max_config)
-             origin_changed_nodes_indices = [i for i, (n, c) in enumerate(origin_max_config.items()) if
-                                             c != origin_mp_model_configuration[n]]
-             metric_value = compute_metric(
-                 origin_mp_model_configuration,
-                 origin_changed_nodes_indices,
-                 origin_max_config)
+             orig_mp_config = self.config_reconstruction_helper.reconstruct_full_configuration(mp_model_configuration)
+             changed_nodes = [orig_sorted_nodes.index(n) for n, ind in orig_mp_config.items()
+                              if origin_max_config[n] != ind]
+             metric_value = compute_metric(orig_mp_config, changed_nodes, origin_max_config)
          else:
-             metric_value = compute_metric(
-                 mp_model_configuration,
-                 [node_idx],
-                 self.max_ru_config)
+             metric_value = compute_metric(mp_model_configuration, [node_idx], self.max_ru_config)
          metric_value = max(metric_value, max_config_value + eps)
          layer_to_metrics_mapping[node].append(metric_value)

@@ -256,7 +244,7 @@
          else:
              cfg = self.min_ru_config.copy()
              cfg[node] = candidate_idx
-             real_cfg = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(cfg)
+             real_cfg = self.config_reconstruction_helper.reconstruct_full_configuration(cfg)
              candidate_rus = self.ru_helper.compute_utilization(self.ru_targets, real_cfg)

          for target, ru in candidate_rus.items():
@@ -326,353 +314,192 @@


  class ConfigReconstructionHelper:
-     """
-     A class to help reconstruct an original mixed-precision configuration from a virtual one,
-     when running mixed-precision search with BOPS utilization.
-     It provides a reconstruct_config_from_virtual_graph which allows to translate a bit-width config of a virtual graph
-     to a config of the original configurable nodes.
-     """
-
-     def __init__(self, virtual_graph: Graph, original_graph: Graph):
-         """
-         Init a ConfigReconstructionHelper object.
-         It holds a dictionary variable named origin_node_idx_to_cfg which holds the mapping from an original graph's
-         configurable node to its actual bit-width index (this data structure is being cleared
-         after every reconstruction call).
-
-         Args:
-             virtual_graph: The virtual graph.
-             original_graph: The original graph.
-         """
-
-         self.virtual_graph = virtual_graph
-         self.original_graph = original_graph
-         self.fw_info = original_graph.fw_info
-
-         self.virtual_sorted_nodes_names = self.virtual_graph.get_configurable_sorted_nodes_names(self.fw_info)
-         self.origin_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
-         self.origin_sorted_conf_nodes_names = [n.name for n in self.origin_sorted_conf_nodes]
-
-         self.origin_node_idx_to_cfg = {}
-
-     def _clear_reconstruction_dict(self):
-         """
-         Clears the origin_node_idx_to_cfg data structure.
-         """
-
-         self.origin_node_idx_to_cfg = {}
-
-     def reconstruct_config_from_virtual_graph(self,
-                                               virtual_mp_cfg: Dict[BaseNode, int],
-                                               changed_virtual_nodes_idx: List[int] = None,
-                                               original_base_config: Dict[BaseNode, int] = None) -> Dict[BaseNode, int]:
-         """
-         Reconstructs the original config for a given virtual graph mixed-precision config.
-         It iterates over all virtual configurable node (that has some chosen bit-width virtual candidate)
-         and translates its chosen candidate to a candidate index of configurable nodes in the original graph.
-         The translation is based of the virtual node's type. Note that if the node is a split activation node
-         for instance, then we need to find its matching weights node in order to construct the original linear node's
-         chosen config.
-
-         Args:
-             virtual_mp_cfg: A mixed-precision configuration (list of candidates indices) of the virtual graph.
-             changed_virtual_nodes_idx: Provide an optional list of virtual nodes indices for which the
-                 config reconstruction will be computed.
-             original_base_config: If changed_virtual_nodes_idx is provided, need to provide a base config from which the
-                 bit-width for all un-changed original nodes will be taken.
-
-         Returns: A mixed-precision configuration (list of candidates indices) of the original graph.
-
-         """
-
-         if changed_virtual_nodes_idx is not None:
-             if original_base_config is None:
-                 Logger.critical("To run config reconstruction for a partial set of nodes, a base original config must be provided.") # pragma: no cover
-
-             updated_virtual_nodes = \
-                 [(idx, self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)[idx]) for idx in changed_virtual_nodes_idx]
-             # Iterating only over the virtual nodes that have updated config
-             for virtual_node_idx, n in updated_virtual_nodes:
-                 self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
-             # Updating reconstructed config for all other nodes based on provided base_config
-             original_sorted_conf_nodes = self.original_graph.get_configurable_sorted_nodes(self.fw_info)
-             for i, (n, qc_ind) in enumerate(original_base_config.items()):
-                 if i not in list(self.origin_node_idx_to_cfg.keys()):
-                     self.update_config_at_original_idx(n=n, origin_cfg_idx=qc_ind)
-         else:
-             # Reconstruct entire config
-             for virtual_node_idx, n in enumerate(self.virtual_graph.get_configurable_sorted_nodes(self.fw_info)):
-                 self.reconstruct_node_config(n, list(virtual_mp_cfg.values()), virtual_node_idx)
-
-         res_config = [self.origin_node_idx_to_cfg[key] for key in sorted(self.origin_node_idx_to_cfg.keys())]
-         self._clear_reconstruction_dict()
-         assert len(res_config) == len(self.origin_sorted_conf_nodes)
-         return {n: candidate_idx for n, candidate_idx in zip(self.origin_sorted_conf_nodes, res_config)}
-
-     def reconstruct_node_config(self,
-                                 n: BaseNode,
-                                 virtual_mp_cfg: List[int],
-                                 virtual_node_idx: int):
-         """
-         Reconstructs the original configuration for a single node. Updates the mapping inplace.
-
-         Args:
-             n: The node to reconstruct the configuration for.
-             virtual_mp_cfg: A mixed-precision configuration (list of candidates indices) of the virtual graph.
-             virtual_node_idx: The index of the virtual node in the virtual mixed-precision configuration.
-         """
-
-         virtual_cfg_idx = virtual_mp_cfg[virtual_node_idx]
-
-         if isinstance(n, VirtualActivationWeightsNode):
-             weights_node = n.original_weights_node
-             if isinstance(weights_node, VirtualSplitWeightsNode):
-                 self.get_activation_for_split_weights(weights_node, n, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 Logger.critical(f"Virtual graph construction error: Expected all weights nodes to be split into weights and activation nodes. Found node '{n.name}' not split as expected. Every weights node should correspond to a VirtualSplitWeightsNode type.") # pragma: no cover
-
-             activation_node = n.original_activation_node
-             if isinstance(activation_node, VirtualSplitActivationNode):
-                 self.get_weights_for_split_activation(activation_node, n, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 if activation_node.name in self.origin_sorted_conf_nodes_names:
-                     # It is possible that the original activation node is not configurable,
-                     # in this case we don't need to retrieve its bit-width config
-                     self.retrieve_activation_only_config(activation_node, n, virtual_cfg_idx)
-         elif isinstance(n, VirtualSplitWeightsNode):
-             # If the node's predecessor have multiple outgoing edges then it is possible that this weights
-             # node is not composed with an activation, but otherwise there is something wrong, and we need
-             # to raise an exception
-             predecessor = self.virtual_graph.get_prev_nodes(n)
-             assert len(predecessor) == 1 # Sanity check
-             predecessor = predecessor[0]
-             if len(self.virtual_graph.out_edges(predecessor)) > 1:
-                 # It's ok, need to find the node's configuration
-                 self.get_activation_for_split_weights(n, n, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{n.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
-         elif isinstance(n, VirtualSplitActivationNode):
-             self.get_weights_for_split_activation(n, n, virtual_cfg_idx, virtual_mp_cfg)
-         else:
-             # Node didn't change in virtual graph - candidates list is similar to original
-             if n.name not in self.origin_sorted_conf_nodes_names:
-                 Logger.critical(f"Configuration mismatch: Node '{n.name}' is configurable in the virtual graph but not in the original graph. Verify node configurations.") # pragma: no cover
-             origin_idx = self.origin_sorted_conf_nodes_names.index(n.name)
-             self.origin_node_idx_to_cfg[origin_idx] = virtual_cfg_idx
-
-     def retrieve_weights_only_config(self, weights_node: BaseNode, virtual_node: BaseNode, virtual_cfg_idx: int):
-         """
-         Retrieves the configuration of an original weights configurable node based on a
-         virtual weights configurable node's chosen config idx, and updates (inplace) the origin_cfg_idx mapping dict.
-         If the original node is not configurable, nothing will be updated.
+     def __init__(self, original_graph):
+         # mapping in order to return the actual node objects from the original graph
+         self.orig_nodes = {n.name: n for n in original_graph.nodes}

-         Args:
-             weights_node: The original weights (possibly configurable) node.
-             virtual_node: The virtual weights configurable node.
-             virtual_cfg_idx: The virtual node's chosen config index.
-         """
-
-         if weights_node.name in self.origin_sorted_conf_nodes_names:
-             # It is possible that the original weights node is not configurable,
-             # in this case we don't need to retrieve its bit-width config
-             kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
-             weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg
-                                 .get_attr_config(kernel_attr).weights_n_bits)
-             origin_cfg_idx = [i for i, c in
-                               enumerate(weights_node.candidates_quantization_cfg) if
-                               c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth]
-
-             self.update_config_at_original_idx(weights_node, origin_cfg_idx[0])
-
-     def retrieve_activation_only_config(self, activation_node: BaseNode, virtual_node: BaseNode, virtual_cfg_idx: int):
+     def reconstruct_full_configuration(self,
+                                        virtual_cfg: Dict[BaseNode, int],
+                                        include_non_configurable: bool = False) -> Dict[BaseNode, int]:
          """
-         Retrieves the configuration of an original activation configurable node based on a
-         virtual activation configurable node's chosen config idx, and updates (inplace) the origin_cfg_idx mapping dict.
-         If the original node is not configurable, nothing will be updated.
+         Convert a configuration of a virtual graph into the corresponding configuration of the original graph.
+         Note that a configurable VirtualActivationWeightsNode might comprise one configurable and one non-configurable
+         original nodes.

          Args:
-             activation_node: The original activation (possibly configurable) node.
-             virtual_node: The virtual activation configurable node.
-             virtual_cfg_idx: The virtual node's chosen config index.
-         """
+             virtual_cfg: a mapping from nodes in the virtual graph to selected candidate index. Should contain all
+                 configurable nodes of the virtual graph, and only configurable nodes.
+             include_non_configurable: whether to return configs for non-configurable original nodes.

-         if activation_node.name in self.origin_sorted_conf_nodes_names:
-             # It is possible that the original activation node is not configurable,
-             # in this case we don't need to retrieve its bit-width config
-             activation_bitwidth = virtual_node.candidates_quantization_cfg[
-                 virtual_cfg_idx].activation_quantization_cfg.activation_n_bits
-             origin_cfg_idx = [i for i, c in
-                               enumerate(activation_node.candidates_quantization_cfg) if
-                               c.activation_quantization_cfg.activation_n_bits == activation_bitwidth]
-
-             self.update_config_at_original_idx(activation_node, origin_cfg_idx[0])
-
-     def retrieve_activation_weights_config(self,
-                                            activation_node: BaseNode,
-                                            weights_node: BaseNode,
-                                            virtual_node: BaseNode,
-                                            virtual_cfg_idx: int,
-                                            virtual_mp_cfg: List[int]):
-         """
-         Retrieves the configuration of an original weights and activation (possibly) configurable node based on a given
-         virtual split weights node and a virtual split activation node which represents its matching in the original graph.
-         it updates (inplace) the origin_cfg_idx mapping dict.
+         Returns:
+             A mapping from configurable nodes in the original graph to their candidate indices.
+         """
+         # Original candidate of a node that has been split might be determined by two different virtual nodes, one
+         # determines activation and one - weights. First, for each virtual node we collect the original
+         # activation / weights nodes, with all original candidates that match the virtual candidate
+         # activation / weights config. If both activation and weights of the original node are determined by virtual
+         # candidates, we look for a common candidate.
+         orig_nodes_a_candidates = {}
+         orig_nodes_w_candidates = {}
+         for virtual_node, virtual_qc_ind in virtual_cfg.items():
+             assert virtual_node.has_configurable_activation() or virtual_node.has_any_configurable_weight()
+             orig_a_node, orig_a_candidates = self._retrieve_matching_orig_a_candidates(virtual_node, virtual_qc_ind)
+             if orig_a_node and (include_non_configurable or orig_a_node.has_configurable_activation()):
+                 assert orig_a_node not in orig_nodes_a_candidates
+                 orig_nodes_a_candidates[orig_a_node] = orig_a_candidates
+             orig_w_node, orig_w_candidates = self._retrieve_matching_orig_w_candidates(virtual_node, virtual_qc_ind)
+             if orig_w_node and (include_non_configurable or orig_w_node.has_any_configurable_weight()):
+                 assert orig_w_node not in orig_nodes_w_candidates
+                 orig_nodes_w_candidates[orig_w_node] = orig_w_candidates
+
+         orig_cfg = {}
+         common_orig_nodes = set(orig_nodes_a_candidates.keys()).intersection(set(orig_nodes_w_candidates))
+         for orig_node in common_orig_nodes:
+             a_candidates = orig_nodes_a_candidates[orig_node]
+             w_candidates = orig_nodes_w_candidates[orig_node]
+             # find the common candidate
+             common_candidates = set(a_candidates).intersection(set(w_candidates))
+             if len(common_candidates) != 1:
+                 raise ValueError(f'Expected to find exactly one candidate with the required activation and weights '
+                                  f'quantization configuration for node {orig_node}. Found {len(common_candidates)}')
+             # in theory it's possible that original non-configurable node gets split and each part is combined
+             # with a configurable part of another node and we end up here
+             if orig_node.has_configurable_activation() or orig_node.has_any_configurable_weight():
+                 orig_cfg[orig_node] = common_candidates.pop()
+             del orig_nodes_a_candidates[orig_node]
+             del orig_nodes_w_candidates[orig_node]
+
+         # remaining a nodes
+         for orig_node, a_candidates in orig_nodes_a_candidates.items():
+             assert not orig_node.has_any_configurable_weight()  # if it had we should have caught it above
+             assert len(a_candidates) == 1
+             assert orig_node not in orig_cfg
+             if include_non_configurable or orig_node.has_configurable_activation():
+                 orig_cfg[orig_node] = a_candidates[0]
+
+         # remaining w nodes
+         for orig_node, w_candidates in orig_nodes_w_candidates.items():
+             assert not orig_node.has_configurable_activation()  # if it had we should have caught it above
+             assert len(w_candidates) == 1
+             assert orig_node not in orig_cfg
+             if include_non_configurable or orig_node.has_any_configurable_weight():
+                 orig_cfg[orig_node] = w_candidates[0]
+
+         return orig_cfg
+
+     def reconstruct_separate_aw_configs(self, virtual_cfg: Dict[BaseNode, int], include_non_configurable: bool) \
+             -> Tuple[Dict[BaseNode, int], Dict[BaseNode, int]]:
+         """
+         Retrieves original activation and weights nodes and corresponding candidates for a given configuration of the
+         virtual graph. Only returns configuration specified by the virtual config, per configurable target (activation
+         or weights). For example, if 'virtual_cfg' contains a single VirtualActivationWeightsNode, the returned
+         configuration will contain only activation config for the original activation node, and only weights config
+         for the original weights node).
+         In practice, we return candidate index in both cases, instead of actual activation or weights config, since
+         sensitivity evaluator heavily depends on it, so we must ignore activation config in weights candidate and vice
+         versa. This is bad!!! TODO

          Args:
-             activation_node: The virtual node that contains the activation that matches the weights node in the original graph.
-             weights_node: The virtual node that contains the weights representation of an original node.
-             virtual_node: The virtual node that contains the virtual weights node (either a composed node or a split weights node).
-             virtual_cfg_idx: The virtual node's chosen config index.
-             virtual_mp_cfg: The virtual graph's chosen mp config.
-         """
+             virtual_cfg: a mapping from nodes in the virtual graph to selected candidate index.
+             include_non_configurable: whether to return configs for non-configurable target (i.e. activation config
+                 for non-configurable activation, and weights config for non-configurable weight).

-         activation_bitwidth = activation_node.candidates_quantization_cfg[virtual_mp_cfg[
-             self.virtual_sorted_nodes_names.index(activation_node.name)]].activation_quantization_cfg.activation_n_bits
-
-         kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
-
-         weights_bitwidth = (virtual_node.candidates_quantization_cfg[virtual_cfg_idx].weights_quantization_cfg
-                             .get_attr_config(kernel_attr).weights_n_bits)
-
-         origin_cfg_idx = [i for i, c in
-                           enumerate(weights_node.origin_node.candidates_quantization_cfg) if
-                           c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and
-                           c.activation_quantization_cfg.activation_n_bits == activation_bitwidth]
-
-         self.update_config_at_original_idx(weights_node.origin_node, origin_cfg_idx[0])
-
-     def retrieve_weights_activation_config(self,
-                                            activation_node: BaseNode,
-                                            weights_node: BaseNode,
-                                            virtual_node: BaseNode,
-                                            virtual_cfg_idx: int,
-                                            virtual_mp_cfg: List[int]):
-         """
-         Retrieves the configuration of an original weights and activation (possibly) configurable node based on a given
-         virtual split activation node and a virtual split weights node which represents its matching in the original graph.
-         it updates (inplace) the origin_cfg_idx mapping dict.
-
-         Args:
-             activation_node: The virtual node that contains the activation representation of an original node.
-             weights_node: The virtual node that contains the weights that matches the activation node in the original graph.
-             virtual_node: The virtual node that contains the virtual activation node (either a composed node or a split activation node).
-             virtual_cfg_idx: The virtual node's chosen config index.
-             virtual_mp_cfg: The virtual graph's chosen mp config.
+         Returns:
+             Configuration for original activation nodes and a separate configuration for original weights nodes.
          """
+         a_cfg = {}
+         w_cfg = {}
+         for virtual_node, virtual_qc_ind in virtual_cfg.items():
+             orig_a_node, orig_a_candidates = self._retrieve_matching_orig_a_candidates(virtual_node, virtual_qc_ind)
+             if orig_a_node and (include_non_configurable or orig_a_node.has_configurable_activation()):
+                 # we may have retrieved multiple candidates with different weights configs and identical activation
+                 # configs, so we just take the first
+                 a_cfg[orig_a_node] = orig_a_candidates[0]

-         kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
-
-         weights_bitwidth = (weights_node.candidates_quantization_cfg[virtual_mp_cfg[
-             self.virtual_sorted_nodes_names.index(weights_node.name)]]
-             .weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits)
-
-         activation_bitwidth = virtual_node.candidates_quantization_cfg[
-             virtual_cfg_idx].activation_quantization_cfg.activation_n_bits
-
-         origin_cfg_idx = [i for i, c in enumerate(activation_node.origin_node.candidates_quantization_cfg) if
-                           c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_bitwidth and
-                           c.activation_quantization_cfg.activation_n_bits == activation_bitwidth]
+             orig_w_node, orig_w_candidates = self._retrieve_matching_orig_w_candidates(virtual_node, virtual_qc_ind)
+             if orig_w_node and (include_non_configurable or orig_w_node.has_any_configurable_weight()):
+                 # we may have retrieved multiple candidates with different activation configs and identical weights
+                 # configs, so we just take the first
+                 w_cfg[orig_w_node] = orig_w_candidates[0]

-         self.update_config_at_original_idx(activation_node.origin_node, origin_cfg_idx[0])
+         return a_cfg, w_cfg

-     def get_activation_for_split_weights(self,
-                                          weights_node: BaseNode,
-                                          virtual_node: BaseNode,
-                                          virtual_cfg_idx: int,
-                                          virtual_mp_cfg: List[int]):
+     def _retrieve_matching_orig_a_candidates(self,
+                                              virtual_node: BaseNode,
+                                              virtual_qc_ind: int) -> Tuple[Optional[BaseNode], Optional[List[int]]]:
          """
-         Finds the matching activation node in the virtual graph for a given split weights node,
-         and calls the relevant method for updating the configuration mapping.
+         Retrieve the original activation node and all its candidates matching activation quantization config of the
+         given virtual candidate (candidate of a node in the virtual graph).
+         Note that we do simple matching, without any filtering, so disabled activation quantization will be also matched.

          Args:
-             weights_node: A virtual weights node.
-             virtual_node: A virtual node that contains the virtual weights node (either a composed node or a split weights node).
-             virtual_cfg_idx: The virtual node's chosen config index.
-             virtual_mp_cfg: The virtual graph's chosen mp config.
+             virtual_node: node in the virtual graph (can be virtual or regular).
+             virtual_qc_ind: candidate index of the virtual node.

-         """
-
-         # This is a weights node that was split, means it has an activation node that should follow it,
-         # and we need its configuration in order to reconstruct the original node's configuration.
-         matching_activation_node = self.virtual_graph.get_next_nodes(virtual_node)
-         assert len(matching_activation_node) == 1
-         activation_node = matching_activation_node[0]
-
-         if isinstance(activation_node, VirtualActivationWeightsNode):
-             if activation_node.original_activation_node.is_activation_quantization_enabled() and not \
-                     activation_node.original_activation_node.is_all_activation_candidates_equal():
-                 assert activation_node.name in self.virtual_sorted_nodes_names # Sanity check
-                 # The original node is both weights and activation configurable
-                 self.retrieve_activation_weights_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 # weights_node here is a split weights node therefore must have 'origin_node'
-                 self.retrieve_weights_only_config(weights_node.origin_node, virtual_node, virtual_cfg_idx)
+         Returns:
+             The original activation node (actual object from the original graph) and a list of its matching candidates.
+         """
+         if not isinstance(virtual_node, VirtualNode):
+             return self.orig_nodes[virtual_node.name], [virtual_qc_ind]
+         if isinstance(virtual_node, VirtualSplitWeightsNode):
+             return None, None
+         if isinstance(virtual_node, VirtualActivationWeightsNode):
+             orig_a_node = virtual_node.original_activation_node
+             if isinstance(orig_a_node, VirtualSplitActivationNode):
+                 orig_a_node = orig_a_node.origin_node
          else:
-             assert isinstance(activation_node, VirtualSplitActivationNode) # Sanity check
-             if activation_node.name in self.virtual_sorted_nodes_names:
-                 self.retrieve_activation_weights_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 # The original node is only weights configurable
-                 # weights_node here is a split weights node therefore must have 'origin_node'
-                 self.retrieve_weights_only_config(weights_node.origin_node, virtual_node, virtual_cfg_idx)
-
-     def get_weights_for_split_activation(self,
-                                          activation_node: BaseNode,
-                                          virtual_node: BaseNode,
-                                          virtual_cfg_idx: int,
-                                          virtual_mp_cfg: List[int]):
-         """
-         Finds the matching weights node in the virtual graph for a given split activation node,
-         and calls the relevant method for updating the configuration mapping.
+             assert isinstance(virtual_node, VirtualSplitActivationNode)
+             orig_a_node = virtual_node.origin_node

-         Args:
-             activation_node: A virtual activation node.
-             virtual_node: A virtual node that contains the virtual activation node (either a composed node or a split activation node).
-             virtual_cfg_idx: The virtual node's chosen config index.
-             virtual_mp_cfg: The virtual graph's chosen mp config.
-
-         """
+         virtual_qc = virtual_node.candidates_quantization_cfg[virtual_qc_ind]
+         matching_orig_a_cfgs = [i for i, orig_qc in enumerate(orig_a_node.candidates_quantization_cfg)
+                                 if orig_qc.activation_quantization_cfg == virtual_qc.activation_quantization_cfg]
+         if not matching_orig_a_cfgs:  # pragma: no cover
+             raise ValueError(f'Could not find matching activation quantization config in the original node '
+                              f'{orig_a_node} for candidate {virtual_qc_ind} of the virtual node {virtual_node}')
+         return self.orig_nodes[orig_a_node.name], matching_orig_a_cfgs

-         # This is an activation node that was split, means it has a weights node that should come before it,
-         # and we need its configuration in order to reconstruct the original node's configuration.
-         matching_weights_node = self.virtual_graph.get_prev_nodes(virtual_node)
-         assert len(matching_weights_node) == 1
-         weights_node = matching_weights_node[0]
-
-         if isinstance(weights_node, VirtualActivationWeightsNode):
-             kernel_attr = self.fw_info.get_kernel_op_attributes(weights_node.type)[0]
-             if weights_node.original_weights_node.is_weights_quantization_enabled(kernel_attr) and not \
-                     weights_node.original_weights_node.is_all_weights_candidates_equal(kernel_attr):
-                 assert weights_node.name in self.virtual_sorted_nodes_names # Sanity check
-                 # The original node is both weights and activation configurable
-                 self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 # The original node is only activation configurable
-                 # activation_node here is a split activation node therefore must have 'origin_node'
-                 self.retrieve_activation_only_config(activation_node.origin_node, virtual_node, virtual_cfg_idx)
-         else:
-             # If the node's predecessor e multiple outgoing edges than it is possible that this weights
-             # node is not composed with an activation, but otherwise this is something wrong and we need
-             # to raise an exception
-             predecessor = self.virtual_graph.get_prev_nodes(weights_node)
-             assert len(predecessor) == 1 # Sanity check
-             predecessor = predecessor[0]
-             if len(self.virtual_graph.out_edges(predecessor)) > 1:
-                 # It's ok, need to find the node's configuration
-                 self.retrieve_weights_activation_config(activation_node, weights_node, virtual_node, virtual_cfg_idx, virtual_mp_cfg)
-             else:
-                 Logger.critical(f"Virtual graph configuration error: Expected the predecessor of node '{weights_node.name}' to have multiple outputs when not composed with an activation node.") # pragma: no cover
-
-     def update_config_at_original_idx(self, n: BaseNode, origin_cfg_idx: int):
+     def _retrieve_matching_orig_w_candidates(self,
+                                              virtual_node: BaseNode,
+                                              virtual_qc_ind: int) -> Tuple[Optional[BaseNode], Optional[List[int]]]:
          """
-         Updates (inplace) the origin_node_idx_to_cfg mapping wit hthe given index for a given original node index
-         (in the original graph's sorted configurable nodes list).
+         Retrieve the original weights node and all its candidates matching weights quantization config of the
+         given virtual candidate (candidate of a node in the virtual graph).

          Args:
-             n: An original graph's node
-             origin_cfg_idx: A candidate index.
+             virtual_node: node in the virtual graph (can be virtual or regular).
+             virtual_qc_ind: candidate index of the virtual node.

-         """
-
-         origin_idx = self.origin_sorted_conf_nodes_names.index(n.name)
-         self.origin_node_idx_to_cfg[origin_idx] = origin_cfg_idx
+         Returns:
+             The original weights node (actual object from the original graph) and a list of all its matching candidates.
+         """
+         if not isinstance(virtual_node, VirtualNode):
+             if virtual_node.weights:
+                 return self.orig_nodes[virtual_node.name], [virtual_qc_ind]
+             return None, None
+         if isinstance(virtual_node, VirtualSplitActivationNode):
+             return None, None
+
+         if isinstance(virtual_node, VirtualActivationWeightsNode):
+             assert isinstance(virtual_node.original_weights_node, VirtualSplitWeightsNode)
+             orig_w_node = virtual_node.original_weights_node.origin_node
+         else:
+             assert isinstance(virtual_node, VirtualSplitWeightsNode)
+             orig_w_node = virtual_node.origin_node
+
+         virtual_qc = virtual_node.candidates_quantization_cfg[virtual_qc_ind]
+
+         # Matching candidate is a candidate with matching configs for configurable weights. We cannot compare the entire
+         # weights config since the virtual node may contain additional non-configurable weights from the activation node
+         orig_configurable_attrs = [attr for attr in orig_w_node.weights if virtual_node.is_configurable_weight(attr)]
+         assert all(virtual_node.is_configurable_weight(attr) for attr in orig_configurable_attrs)
+
+         def get_configurable_attrs_cfgs(qc):
+             return {attr: qc.weights_quantization_cfg.get_attr_config(attr) for attr in orig_configurable_attrs}
+         virtual_cfg = get_configurable_attrs_cfgs(virtual_qc)
+         matching_orig_w_cfgs = [i for i, orig_qc in enumerate(orig_w_node.candidates_quantization_cfg)
+                                 if get_configurable_attrs_cfgs(orig_qc) == virtual_cfg]
+         if not matching_orig_w_cfgs:  # pragma: no cover
+             raise ValueError(f'Could not find matching weights quantization config in the original node '
+                              f'{orig_w_node} for candidate {virtual_qc_ind} of the virtual node {virtual_node}')
+         return self.orig_nodes[orig_w_node.name], matching_orig_w_cfgs
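
The core of the new `reconstruct_full_configuration` logic is candidate intersection: when an original node was split, one virtual node pins down its activation config and another its weights config, and the original candidate must satisfy both. A toy illustration of that step in plain Python (not the package's code; the indices are made up):

```python
# Indices of the original node's candidates that match the chosen virtual
# activation config and the chosen virtual weights config, respectively.
a_candidates = [0, 2]
w_candidates = [2, 3]

# The original candidate must agree with both selections, so intersect the sets.
common_candidates = set(a_candidates).intersection(set(w_candidates))
if len(common_candidates) != 1:
    raise ValueError(f'Expected to find exactly one candidate with the required activation '
                     f'and weights quantization configuration. Found {len(common_candidates)}')
orig_candidate_idx = common_candidates.pop()
print(orig_candidate_idx)  # 2
```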
--- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
+++ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_
  from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
  from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
-     VirtualSplitWeightsNode
+     VirtualSplitWeightsNode, VirtualNode
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
      RUTarget, ResourceUtilization
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
@@ -531,6 +531,7 @@ class ResourceUtilizationCalculator:
          Returns:
              Node's BOPS count.
          """
+         assert not isinstance(n, VirtualNode), 'Use original graph to compute BOPS.'
          if target_criterion is None:
              target_criterion = TargetInclusionCriterion.Any
          if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.Any]:
@@ -539,20 +540,6 @@
          self._validate_custom_qcs(act_qcs, bitwidth_mode)
          self._validate_custom_qcs(w_qc, bitwidth_mode)

-         if isinstance(n, VirtualSplitWeightsNode):
-             # Virtual weights node can only be present if it couldn't be merged into VirtualActivationWeightsNode.
-             # This means that during MP search we cannot compute bops for all A/W nbits combinations. To prevent
-             # inconsistencies we ignore such nodes for bops computation.
-             return 0
-
-         # Fetch the original weights node for mac computation (VirtualActivationWeightsNode input/output shapes are
-         # based on the activation original node, not weights original node)
-         orig_w_node = n
-         if isinstance(n, VirtualActivationWeightsNode):
-             orig_w_node = n.original_weights_node
-             if isinstance(orig_w_node, VirtualSplitWeightsNode):
-                 orig_w_node = orig_w_node.origin_node
-
          # check if the node has kernel
          kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
          if len(kernel_attrs) > 1: # pragma: no cover
@@ -561,21 +548,13 @@
              return 0

          kernel_attr = kernel_attrs[0]
-         node_mac = self.fw_impl.get_node_mac_operations(orig_w_node, self.fw_info)
+         node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
          if node_mac == 0:
              return node_mac

-         # find the activation node from which to get quantization info and for which to look in custom configuration
-         if isinstance(n, VirtualActivationWeightsNode):
-             # we don't need the original node (and cannot use it for custom configuration anyway)
-             a_node = n
-         else:
-             # if we are running on the original (non-virtual) graph, we only compute bops if it would be computed in an
-             # equivalent virtual graph for consistency.
-             a_node = get_input_activation_if_composable(self.graph, n, warn=False)
-             if a_node is None:
-                 return 0
-
+         prev_nodes = self.graph.get_prev_nodes(n)
+         assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
+         a_node = prev_nodes[0]
          if (target_criterion == TargetInclusionCriterion.AnyQuantized and
                  not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
              return 0
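
With virtual nodes excluded up front, the BOPS path no longer unwraps virtual wrappers: a weights node's activation is simply its single predecessor in the original graph. A self-contained toy of the resulting flow (assumed names and data, not MCT's API; BOPS is taken here as the commonly used MACs x activation bit-width x weights bit-width):

```python
# node name -> incoming node names; a weights node must have exactly one input
predecessors = {'conv1': ['relu0']}
activation_bits = {'relu0': 8}   # chosen activation bit-widths
weights_bits = {'conv1': 4}      # chosen kernel bit-widths
macs = {'conv1': 1_000_000}      # MAC operations per node

def node_bops(n: str) -> int:
    prev = predecessors[n]
    assert len(prev) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev)}'
    a_node = prev[0]
    # BOPS = MACs x activation bit-width x weights bit-width
    return macs[n] * activation_bits[a_node] * weights_bits[n]

print(node_bops('conv1'))  # 32000000
```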