mct-nightly 2.3.0.20250517.552__py3-none-any.whl → 2.3.0.20250519.609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17) hide show
  1. {mct_nightly-2.3.0.20250517.552.dist-info → mct_nightly-2.3.0.20250519.609.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.3.0.20250517.552.dist-info → mct_nightly-2.3.0.20250519.609.dist-info}/RECORD +17 -17
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/fusion/fusing_info.py +99 -32
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +5 -3
  6. model_compression_toolkit/core/common/graph/base_graph.py +1 -1
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +1 -1
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -2
  9. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -3
  10. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +1 -1
  11. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
  12. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +84 -4
  13. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +5 -5
  14. model_compression_toolkit/core/runner.py +1 -1
  15. {mct_nightly-2.3.0.20250517.552.dist-info → mct_nightly-2.3.0.20250519.609.dist-info}/WHEEL +0 -0
  16. {mct_nightly-2.3.0.20250517.552.dist-info → mct_nightly-2.3.0.20250519.609.dist-info}/licenses/LICENSE.md +0 -0
  17. {mct_nightly-2.3.0.20250517.552.dist-info → mct_nightly-2.3.0.20250519.609.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250517.552
3
+ Version: 2.3.0.20250519.609
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Author-email: ssi-dnn-dev@sony.com
6
6
  Classifier: Programming Language :: Python :: 3
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250517.552.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=RueNJeNFQG6WxXCYDucXbXAnF5xB1DP5nCV-ouC3da0,1557
1
+ mct_nightly-2.3.0.20250519.609.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=YPU9BdtahSvlNGtO2EWfbfcVInfNzm5_GyZndC6ZA78,1557
3
3
  model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -9,7 +9,7 @@ model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZ
9
9
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
10
10
  model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
11
11
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
12
- model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
12
+ model_compression_toolkit/core/runner.py,sha256=EM3B_t_TDUr_ttrQvZFhf6qxO9aIAYOwdl5FU8Y32Ow,13064
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
15
15
  model_compression_toolkit/core/common/framework_implementation.py,sha256=JQI_eoZZoNk5Y_jAxLfYt9-wzfs7zGpTldz9UblxmMc,21182
@@ -31,10 +31,10 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
31
31
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
32
32
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
33
33
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
34
- model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=S7hBbUJxL52Z8uJ9_upLdFyoSEJvgmVX0OmneqDIj-c,18656
35
- model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
34
+ model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=uDxF0awrjn3SbcpXBpoQ4OGcKO6Z7HBk8ierZPCGbGo,21970
35
+ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=yxxxuwrmQ4wLW-PlTu0MEW59LmNJEh1OWy9Li15YH-8,7520
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
37
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=YA0c8ucaaZu9eRO-xruLqDT3QFOpxq24ViG6ILS2jqA,41403
38
38
  model_compression_toolkit/core/common/graph/base_node.py,sha256=AbUadAT581zelVcGcK9_--6CAGiht9qwkeWahwT3RzE,33389
39
39
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
40
40
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
@@ -67,16 +67,16 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
67
67
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
69
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=BO4ouM_UVS9Fg0z95gLJSMz1ep6YQC5za_iXI_qW2yQ,5399
70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
70
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=axgAypzsiCOw04ZOtOEjK4riuNsaEU2qU6KkWnEXtMo,4951
71
71
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=KhiHGpmN5QbpyJQnTZmXigdXFlSlRNqpOOyKGj1Fwek,6412
72
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=OzRhlJ2IS9Dwv0rgobee0xTtAeRwlBC6KvVEcx2_oB0,28089
72
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=MXOK9WPy3fSt5uxsWYMF4szwwqWWgrlzNJdE9VIb-AQ,28145
73
73
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4uhUXKgwyMrJqEVK5uJzVr67GI5YzDTHLveV4maB7z0,28079
74
74
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=Zn6SgzGLWWKmuYGHd1YtKxZdYnQWRDeXEkKlBiTbHcs,2929
75
75
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MY8df-c_kITEr_7hOctaxhdiq29hSTA0La9Qo0oTJJY,9678
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
78
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=-kNcmQQFVHRPizInaRrCEIuh_q_57CWxC6CIV6azF4g,39640
79
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
78
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=8f6KDTKD8SzVXDl9jmYJ-p19cQB0Nr_UTdCPuhELTdg,40329
79
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=ZY5yFIDzbaqIk0UzakDBObfsVevn4fydqAfAm4RCikY,4058
80
80
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
81
81
  model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=6Z6nQL9UH7B8dbcUR0cuCTEYFOKZAlvOb-SCk_cAZFA,6670
82
82
  model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
@@ -138,14 +138,14 @@ model_compression_toolkit/core/common/statistics_correction/statistics_correctio
138
138
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
139
139
  model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
140
140
  model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
141
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=kz1Xg2OMNXyRbCW3K-wfZpbv6jmLShJjHYUoziOUNv4,8496
141
+ model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=ndAKcbnNtDQ0DfL9WOYMYPlxU71t7xo9uxvaFZQsfjI,8501
142
142
  model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=dWJpVfomF4Ppeeor3VzS23TXHyBm85QI7snyLOYP_ko,9972
143
143
  model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
144
144
  model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
145
145
  model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
146
146
  model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
147
147
  model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
148
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=zCkdyZHEkbxkORmd071_XWajkpIhnDq9D6FyeE4TQjc,30057
148
+ model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=VcwCVWEooYwg6NGcnSP8OaSzgtzSd4k1r-5a68rpqZc,33713
149
149
  model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
150
150
  model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=w43dRmaG96a8SNECgghxoFCTSoZ-vUb33dXGm2PbomE,4251
151
151
  model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=gt07lXRUvYunJKiwv_w20zfXhcplSW4oT2C1dqiNNXc,4719
@@ -258,7 +258,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape
258
258
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
259
259
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
260
260
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=D1hxN3pZ5-_FLJSS30ZJUo-v8TqUWFcMjhMijFa9aSo,12407
261
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
261
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=n90Fu2ZkuWPoqy1_GchrQSk6O-HlaeuBeVfaCR_O8xI,10755
262
262
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
263
263
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
264
264
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250517.552.dist-info/METADATA,sha256=kxKFMh-zWtlCfUBFowzu71E5L-8ybwVw0pgy_rCxVYw,25135
532
- mct_nightly-2.3.0.20250517.552.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
533
- mct_nightly-2.3.0.20250517.552.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250517.552.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250519.609.dist-info/METADATA,sha256=7AynprEgb8NPAdgWcUm7I4vNV6rnqahpC68k459nm5Y,25135
532
+ mct_nightly-2.3.0.20250519.609.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
533
+ mct_nightly-2.3.0.20250519.609.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250519.609.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250517.000552"
30
+ __version__ = "2.3.0.20250519.000609"
@@ -36,22 +36,28 @@ class FusingInfo:
36
36
  belong to fused operations and validate this info is correct after changes in the graph.
37
37
 
38
38
  The core structures maintained are:
39
+ - 'fusing_patterns': The patterns to generate the fused operators from.
40
+ - 'manual_fused_ops': List of sequences of node names to handle as fused ops (even if they are not part of the fusing patterns).
39
41
  - `fusing_data`: A dictionary mapping fused operation IDs to lists of nodes that belong to that operation.
40
42
  - `node_to_fused_node_map`: A dictionary mapping each node name to the ID of the fused operation it belongs to.
41
43
 
42
44
  """
43
- fusing_patterns: any = None
45
+ fusing_patterns: List[list[any]] = None
46
+ manual_fused_ops: List[List[str]] = None
44
47
  fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
45
48
  node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
46
49
  fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
47
50
 
48
51
  def __post_init__(self):
49
52
  """Validates and initializes mappings after dataclass instantiation."""
53
+ self.fusing_patterns = self.fusing_patterns or []
50
54
  for op_id, op_nodes in self.fusing_data.items():
51
55
  assert isinstance(op_id, str) and op_id.startswith(FUSED_OP_ID_PREFIX), f"Found invalid fused op id: {op_id}"
52
56
  assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
53
57
 
54
58
  self._init_node_mapping()
59
+ self._manual_fused_ops = self.manual_fused_ops or []
60
+ del self.manual_fused_ops
55
61
  self._init_quantization_config_map()
56
62
 
57
63
  def _init_node_mapping(self) -> None:
@@ -63,6 +69,26 @@ class FusingInfo:
63
69
  for node in nodes:
64
70
  self.node_to_fused_node_map[node.name] = op_id
65
71
 
72
+ def get_manual_nodes_to_fuse(self) -> List[List[str]]:
73
+ """
74
+ Get the list of node names to be fused manually.
75
+ """
76
+ return self._manual_fused_ops
77
+
78
+
79
+ def add_manual_nodes_to_fuse(self, node_names: List[str]):
80
+ """
81
+ Add a list of node names to be fused manually.
82
+
83
+ Args:
84
+ node_names: List of nodes to be fused.
85
+
86
+ """
87
+ assert isinstance(node_names, list)
88
+ assert all([isinstance(n, str) for n in node_names])
89
+ assert node_names not in self._manual_fused_ops, f"{node_names} is already in manual fused ops: {self._manual_fused_ops}"
90
+ self._manual_fused_ops.append(node_names)
91
+
66
92
  def _init_quantization_config_map(self) -> None:
67
93
  """
68
94
  Init the mapping between fused operation IDs and their quantization configurations.
@@ -121,12 +147,16 @@ class FusingInfo:
121
147
  raise ValueError(f"Fused operation {op_id} does not exist.")
122
148
  # Remove nodes from the mapping
123
149
  nodes = self.fusing_data[op_id]
150
+ node_names = [n.name for n in nodes]
151
+ if node_names in self._manual_fused_ops:
152
+ self._manual_fused_ops.remove(node_names)
153
+
124
154
  for node in nodes:
125
155
  self.node_to_fused_node_map.pop(node.name, None)
126
156
  del self.fusing_data[op_id]
127
157
  self.fused_op_id_to_quant_config.pop(op_id, None)
128
158
 
129
- def get_fused_node_name(self, node_name: str) -> Optional[str]:
159
+ def get_fused_op_id_for_node(self, node_name: str) -> Optional[str]:
130
160
  """
131
161
  Get the name of the fused node containing the given original node name.
132
162
 
@@ -168,6 +198,12 @@ class FusingInfo:
168
198
  """
169
199
  return self.fusing_data.get(op_id)
170
200
 
201
+ def get_nodes_to_disable_activation_quantization(self) -> List['BaseNode']:
202
+ """
203
+ Returns a list of the nodes whose activation quantization is disabled due to fusing.
204
+ """
205
+ return [node for nodes in self.get_all_fused_operations().values() for node in nodes[:-1]]
206
+
171
207
  def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
172
208
  """
173
209
  Retrieve the quantization configuration for a given fused operation ID.
@@ -268,7 +304,7 @@ class FusingInfo:
268
304
 
269
305
  # Check 4: Ensure the sequence matches a valid fusing pattern
270
306
  valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
271
- if not is_valid_fusion(valid_fusing_patterns, nodes):
307
+ if not is_valid_fusion(valid_fusing_patterns, nodes, self._manual_fused_ops):
272
308
  raise ValueError(
273
309
  f"Fused operation {op_id} does not match any valid fusing pattern "
274
310
  f"from {valid_fusing_patterns}."
@@ -311,13 +347,17 @@ class FusingInfo:
311
347
  f" Total fused operations: {len(self.fusing_data)}\n"
312
348
  f" Fusing Data:\n{fusing_data_repr}\n"
313
349
  f" Node-to-Fused Mapping:\n {mapping_repr}\n"
350
+ f" Manual fused ops:\n {self._manual_fused_ops}\n"
314
351
  f")"
315
352
  )
316
353
 
317
354
 
318
355
  class FusingInfoGenerator:
319
- def __init__(self, fusing_patterns):
320
- self._fusing_patterns = fusing_patterns
356
+ def __init__(self, fusing_patterns: List[list] = None, manual_fused_ops: List[List[str]] = None):
357
+ self._fusing_patterns = fusing_patterns or []
358
+ assert isinstance(self._fusing_patterns, list)
359
+ self._manual_fused_ops = manual_fused_ops or []
360
+ assert isinstance(self._manual_fused_ops, list)
321
361
 
322
362
  def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
323
363
  """
@@ -338,7 +378,7 @@ class FusingInfoGenerator:
338
378
  - Each node belongs to at most one fused operation.
339
379
  """
340
380
  if not self._fusing_patterns:
341
- return FusingInfo(fusing_patterns=self._fusing_patterns)
381
+ return FusingInfo(fusing_patterns=self._fusing_patterns, manual_fused_ops=self._manual_fused_ops)
342
382
 
343
383
  # Extract fusing layer patterns
344
384
  fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
@@ -352,31 +392,53 @@ class FusingInfoGenerator:
352
392
  fusing_info: Dict[str, Tuple['BaseNode']] = {}
353
393
  fused_nodes = [] # nodes that are participating in fusing
354
394
 
355
- for node in nodes:
356
- # Skip if already in fusing
357
- if node in fused_nodes:
358
- continue
359
- # Start fusing search
360
- fusing_nodes = [] # nodes that are candidates for participating in fusing
361
- patterns = copy.deepcopy(fusing_layer_patterns)
362
- next_nodes = [node]
363
- for i in range(max_layer_patterns):
364
- patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
365
- if len(patterns) == 0: # Give up if no more fusion pattern
366
- break
367
- fusing_nodes.append(next_nodes[0])
368
- next_nodes = graph.get_next_nodes(fusing_nodes[-1])
369
- if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion)
370
- break
371
-
372
- # New fusion
373
- if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
374
- fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
375
- assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
376
- fusing_info[fused_op_id] = tuple(fusing_nodes)
377
- fused_nodes.extend(fusing_nodes)
378
-
379
- return FusingInfo(fusing_data=fusing_info, fusing_patterns=self._fusing_patterns)
395
+ if len(self._fusing_patterns)>0:
396
+ for node in nodes:
397
+ # Skip if already in fusing
398
+ if node in fused_nodes:
399
+ continue
400
+ # Start fusing search
401
+ fusing_nodes = [] # nodes that are candidates for participating in fusing
402
+ patterns = copy.deepcopy(fusing_layer_patterns)
403
+ next_nodes = [node]
404
+ for i in range(max_layer_patterns):
405
+ patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
406
+ if len(patterns) == 0: # Give up if no more fusion pattern
407
+ break
408
+ fusing_nodes.append(next_nodes[0])
409
+ next_nodes = graph.get_next_nodes(fusing_nodes[-1])
410
+ if len(next_nodes) != 1: # Give up if node has more than one connection (not supported for fusion)
411
+ break
412
+
413
+ # New fusion
414
+ if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
415
+ fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
416
+ assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
417
+ fusing_info[fused_op_id] = tuple(fusing_nodes)
418
+ fused_nodes.extend(fusing_nodes)
419
+
420
+ for manual_names in self._manual_fused_ops:
421
+ manual_nodes = [graph.find_node_by_name(n) for n in manual_names]
422
+ for n in manual_nodes:
423
+ if len(n) != 1:
424
+ raise ValueError(f"Expected exactly one node, but got {len(n)}")
425
+ manual_nodes = [n[0] for n in manual_nodes]
426
+
427
+ # Remove any existing fused ops containing any of the manual nodes
428
+ fused_ids_to_remove = {
429
+ op_id for op_id, nodes in fusing_info.items()
430
+ if any(node in nodes for node in manual_nodes)
431
+ }
432
+ for op_id in fused_ids_to_remove:
433
+ del fusing_info[op_id]
434
+
435
+ fused_op_id = FusingInfo.generate_fused_op_id(manual_nodes)
436
+ assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
437
+ fusing_info[fused_op_id] = tuple(manual_nodes)
438
+
439
+ return FusingInfo(fusing_data=fusing_info,
440
+ fusing_patterns=self._fusing_patterns,
441
+ manual_fused_ops=self._manual_fused_ops)
380
442
 
381
443
 
382
444
  def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
@@ -404,15 +466,20 @@ def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
404
466
  return valid_fusing_patterns
405
467
 
406
468
 
407
- def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -> bool:
469
+ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode'], manual_fused_names: List[List[str]]=None) -> bool:
408
470
  """
409
471
  Check if the fusion is valid: exist in fusing_patterns
410
472
  Args:
411
473
  fusing_patterns: supported fusing patterns
412
474
  nodes: nodes which are participating in fusion
475
+ manual_fused_names: list of node-name lists, each of which is handled as a valid fused op.
413
476
  Returns:
414
477
  whether the fusion is valid
415
478
  """
479
+ node_names = [n.name for n in nodes]
480
+ if any(manual == node_names for manual in (manual_fused_names or [])):
481
+ return True
482
+
416
483
  fusion_depth = len(nodes)
417
484
  if fusion_depth <= 1:
418
485
  return False
@@ -46,12 +46,14 @@ class GraphFuser:
46
46
  The updated graph with fused nodes replacing the original node groups.
47
47
  """
48
48
  graph_copy = copy.deepcopy(graph)
49
- expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns).generate_fusing_info(graph_copy)
49
+ expected_fusing_info = FusingInfoGenerator(graph_copy.fusing_info.fusing_patterns,
50
+ graph_copy.fusing_info.get_manual_nodes_to_fuse()).generate_fusing_info(graph_copy)
50
51
 
51
- if expected_fusing_info != graph_copy.fusing_info:
52
+ existing_fusing_info = graph_copy.fusing_info
53
+ if expected_fusing_info != existing_fusing_info:
52
54
  raise ValueError(
53
55
  f"Mismatch between expected and existing fusing information.\n"
54
- f"Expected:\n{expected_fusing_info}\nExisting:\n{graph_copy.fusing_info}"
56
+ f"Expected:\n{expected_fusing_info}\nExisting:\n{existing_fusing_info}"
55
57
  )
56
58
 
57
59
  fused_operations = list(graph_copy.fusing_info.get_all_fused_operations().items())
@@ -908,7 +908,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
908
908
  Disable activation quantization for all nodes in fused operations,
909
909
  except for the last node in each fused group.
910
910
  """
911
- nodes_to_disable = [node for nodes in self.fusing_info.get_all_fused_operations().values() for node in nodes[:-1]]
911
+ nodes_to_disable = self.fusing_info.get_nodes_to_disable_activation_quantization()
912
912
  for node in nodes_to_disable:
913
913
  for qc in node.candidates_quantization_cfg:
914
914
  qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
@@ -51,7 +51,7 @@ class MixedPrecisionRUHelper:
51
51
  """
52
52
  act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)
53
53
 
54
- ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
54
+ ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
55
55
  BitwidthMode.QCustom,
56
56
  act_qcs=act_qcs,
57
57
  w_qcs=w_qcs,
@@ -294,8 +294,12 @@ class MixedPrecisionSearchManager:
294
294
  """
295
295
  act_qcs, w_qcs = self.orig_graph_ru_helper.get_quantization_candidates(config)
296
296
  ru = self.orig_graph_ru_helper.ru_calculator.compute_resource_utilization(
297
- target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
298
- w_qcs=w_qcs, ru_targets=self.ru_targets, allow_unused_qcs=True)
297
+ target_criterion=TargetInclusionCriterion.AnyQuantizedNonFused,
298
+ bitwidth_mode=BitwidthMode.QCustom,
299
+ act_qcs=act_qcs,
300
+ w_qcs=w_qcs,
301
+ ru_targets=self.ru_targets,
302
+ allow_unused_qcs=True)
299
303
  return ru
300
304
 
301
305
  def _finalize_distance_metric(self, layer_to_metrics_mapping: Dict[BaseNode, List[float]]):
@@ -67,11 +67,13 @@ class TargetInclusionCriterion(Enum):
67
67
  QNonConfigurable: non-configurable targets (single quantization candidate).
68
68
  AnyQuantized: any quantized targets (configurable and non-configurable).
69
69
  Any: all targets (quantized + float).
70
+ AnyQuantizedNonFused: any quantized targets that are not inside fused operations.
70
71
  """
71
72
  QConfigurable = auto()
72
73
  QNonConfigurable = auto()
73
74
  AnyQuantized = auto()
74
75
  Any = auto()
76
+ AnyQuantizedNonFused = auto()
75
77
 
76
78
 
77
79
  class Utilization(NamedTuple):
@@ -534,8 +536,9 @@ class ResourceUtilizationCalculator:
534
536
  assert not isinstance(n, VirtualNode), 'Use original graph to compute BOPS.'
535
537
  if target_criterion is None:
536
538
  target_criterion = TargetInclusionCriterion.Any
537
- if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.Any]:
538
- raise ValueError('BOPS computation is supported only for Any and AnyQuantized targets.')
539
+ if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.AnyQuantizedNonFused, TargetInclusionCriterion.Any]:
540
+ raise ValueError(
541
+ 'BOPS computation is supported only for Any, AnyQuantized and AnyQuantizedNonFused targets.')
539
542
 
540
543
  self._validate_custom_qcs(act_qcs, bitwidth_mode)
541
544
  self._validate_custom_qcs(w_qc, bitwidth_mode)
@@ -621,7 +624,7 @@ class ResourceUtilizationCalculator:
621
624
  weight_attrs = n.get_node_weights_attributes()
622
625
  if target_criterion == TargetInclusionCriterion.QConfigurable:
623
626
  weight_attrs = [attr for attr in weight_attrs if n.is_configurable_weight(attr)]
624
- elif target_criterion == TargetInclusionCriterion.AnyQuantized:
627
+ elif target_criterion in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.AnyQuantizedNonFused]:
625
628
  weight_attrs = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)]
626
629
  elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
627
630
  quantized = [attr for attr in weight_attrs if n.is_weights_quantization_enabled(attr)]
@@ -671,6 +674,10 @@ class ResourceUtilizationCalculator:
671
674
  nodes = [n for n in nodes if n.has_configurable_activation()]
672
675
  elif target_criterion == TargetInclusionCriterion.AnyQuantized:
673
676
  nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
677
+ elif target_criterion == TargetInclusionCriterion.AnyQuantizedNonFused:
678
+ nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
679
+ # remove fused nodes (due to SNC, where the non-linear is quantized, even though it should not be quantized)
680
+ nodes = [n for n in nodes if n not in self.graph.fusing_info.get_nodes_to_disable_activation_quantization()]
674
681
  elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
675
682
  nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
676
683
  elif target_criterion != TargetInclusionCriterion.Any: # pragma: no cover
@@ -63,4 +63,4 @@ def compute_resource_utilization_data(in_model: Any,
63
63
  running_gptq=False)
64
64
 
65
65
  ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
66
- return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized, BitwidthMode.QDefaultSP)
66
+ return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
@@ -149,7 +149,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
149
149
  # the current info, or this creates a new fusion and the old pattern should be
150
150
  # replaced with the new one.
151
151
  fi = graph.fusing_info
152
- fused_op = fi.get_fused_node_name(source_node.name)
152
+ fused_op = fi.get_fused_op_id_for_node(source_node.name)
153
153
  if fused_op:
154
154
  fused_nodes = list(fi.get_fused_nodes(fused_op))
155
155
  assert source_node in fused_nodes
@@ -23,8 +23,6 @@ from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
24
24
  from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
25
25
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
26
- from mct_quantizers import QuantizationMethod
27
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
28
26
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
29
27
  set_quantization_configs_to_node
30
28
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -33,6 +31,7 @@ from model_compression_toolkit.core.common.quantization.quantization_params_gene
33
31
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
34
32
  _mse_error_histogram
35
33
  from model_compression_toolkit.core.common.quantization.quantization_params_generation import z_score_filter
34
+ from model_compression_toolkit.target_platform_capabilities import QuantizationMethod, AttributeQuantizationConfig
36
35
 
37
36
  """
38
37
  This substitution aims to solve an issue of activation with negative outputs where
@@ -188,6 +187,65 @@ def remove_node_between_two_nodes(graph: Graph,
188
187
  graph.remove_node(node_to_remove)
189
188
 
190
189
 
190
def fuse_padding_with_op2d(graph: 'BaseGraph', pad_node: 'BaseNode', op2d_node: 'BaseNode') -> None:
    """
    Place a padding node at the head of the fused operation that holds op2d_node.

    When op2d_node is already part of a fused operation, that operation is removed from the
    graph's fusing info and registered again with pad_node in front of its nodes. Otherwise,
    a new fused operation made of pad_node followed by op2d_node is registered.
    In both cases the names of the resulting nodes are also passed to the fusing info as
    manual nodes to fuse.

    Args:
        graph: The graph whose fusing info is updated.
        pad_node: The padding node to put first in the fused operation.
        op2d_node: The op2d node (e.g., Conv2D) that follows the padding node.
    """
    fi = graph.fusing_info

    # Default: op2d_node is not fused yet, so the new fused op holds just the pad and the op2d.
    existing_nodes = [op2d_node]
    if fi.is_node_in_fused_op(op2d_node):
        # The current fused op is dropped and replaced by an extended one below.
        current_op_id = fi.get_fused_op_id_for_node(op2d_node.name)
        existing_nodes = fi.get_fused_nodes(current_op_id)
        fi.remove_fused_operation(current_op_id)

    nodes_with_pad = [pad_node, *existing_nodes]
    new_op_id = fi.generate_fused_op_id(nodes_with_pad)

    fi.add_fused_operation(new_op_id, tuple(nodes_with_pad))
    fi.add_manual_nodes_to_fuse([node.name for node in nodes_with_pad])
214
+
215
def update_fused_op_with_add(graph: 'BaseGraph', non_linear_node: 'BaseNode', add_node: 'BaseNode') -> None:
    """
    Register a fused operation that ends with the Add node placed after a non-linear node.

    The nodes of every fused operation that currently contains the non-linear node or its
    first previous node are collected (without duplicates, keeping their order), add_node is
    appended, the old fused operations are removed and a single new fused operation holding
    all collected nodes is registered in the graph's fusing info.

    Args:
        graph: The graph whose fusing info is updated.
        non_linear_node: The non-linear activation node (e.g., ReLU).
        add_node: The Add node that was inserted after the non-linear node.
    """
    fi = graph.fusing_info
    # Only the first previous node of the non-linear node is taken into account.
    anchor_nodes = (graph.get_prev_nodes(non_linear_node)[0], non_linear_node)

    # An insertion-ordered dict is used as an ordered set: each node is kept once,
    # at the position where it was first encountered.
    ordered_members = {}
    for anchor in anchor_nodes:
        if not fi.is_node_in_fused_op(anchor):
            continue
        anchor_op_id = fi.get_fused_op_id_for_node(anchor.name)
        for member in fi.get_fused_nodes(anchor_op_id):
            ordered_members.setdefault(member, None)
    ordered_members.setdefault(add_node, None)
    nodes_to_fuse = list(ordered_members)

    # Drop the fused ops that are being replaced. Membership is queried again for each node
    # right before its removal (presumably so an op shared by both nodes is removed only once).
    for anchor in anchor_nodes:
        if not fi.is_node_in_fused_op(anchor):
            continue
        fi.remove_fused_operation(fi.get_fused_op_id_for_node(anchor.name))

    # Register the merged fused operation.
    merged_op_id = fi.generate_fused_op_id(nodes_to_fuse)
    fi.add_manual_nodes_to_fuse([node.name for node in nodes_to_fuse])
    fi.add_fused_operation(merged_op_id, tuple(nodes_to_fuse))
247
+
248
+
191
249
  def shift_negative_function(graph: Graph,
192
250
  core_config: CoreConfig,
193
251
  non_linear_node: BaseNode,
@@ -232,7 +290,6 @@ def shift_negative_function(graph: Graph,
232
290
  Returns:
233
291
  Graph after applying the shifting and correction.
234
292
  """
235
-
236
293
  min_to_correct, max_value2compare = graph.get_out_stats_collector(non_linear_node).get_min_max_values()
237
294
 
238
295
  if not non_linear_node.is_all_activation_candidates_equal():
@@ -242,6 +299,7 @@ def shift_negative_function(graph: Graph,
242
299
  # all candidates have same activation config, so taking the first candidate for calculations
243
300
  non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
244
301
 
302
+
245
303
  # get the non-linear activation threshold
246
304
  activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)
247
305
 
@@ -350,7 +408,14 @@ def shift_negative_function(graph: Graph,
350
408
  fqc=graph.fqc,
351
409
  mixed_precision_enable=core_config.is_mixed_precision_enabled)
352
410
 
353
- if padding is not None:
411
+ update_fused_op_with_add(graph=graph,
412
+ non_linear_node=non_linear_node,
413
+ add_node=add_node)
414
+
415
+ # If sum([pad_top, pad_btm, pad_left, pad_right])==0 it means we do not pad in any side, thus
416
+ # we do not add a padding node as this is meaningless
417
+ pad_node = None
418
+ if padding is not None and sum([pad_top, pad_btm, pad_left, pad_right])>0:
354
419
  pad_node = create_pad_node(op2d_node.name,
355
420
  add_node.name,
356
421
  shift_value,
@@ -394,8 +459,16 @@ def shift_negative_function(graph: Graph,
394
459
  graph.shift_stats_collector(bypass_node, np.array(shift_value))
395
460
 
396
461
  add_node_qco = add_node.get_qco(graph.fqc).quantization_configurations
462
+ add_supported_bitwidths = [c.activation_n_bits for c in add_node_qco]
463
+ if original_non_linear_activation_nbits not in add_supported_bitwidths:
464
+ raise ValueError(
465
+ f"Add supported activation bit-widths according to the TPC are: {add_supported_bitwidths}, but non-linear "
466
+ f"bitwidth is {original_non_linear_activation_nbits}. Consider adapting the TPC so 'Add' will support the "
467
+ f"same bitwidth as {non_linear_node.type} or disable shift negative correction.")
468
+
397
469
  for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
398
470
  for attr in add_node.get_node_weights_attributes():
471
+ # TODO: do we not quantize the weights of this 'add' on purpose?
399
472
  candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
400
473
 
401
474
  candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
@@ -404,8 +477,15 @@ def shift_negative_function(graph: Graph,
404
477
 
405
478
  candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
406
479
  SIGNED: False})
480
+
407
481
  candidate_qc.activation_quantization_cfg.activation_n_bits = original_non_linear_activation_nbits
408
482
 
483
+ # Add the new padding node to a fused op with the op2d.
484
+ if pad_node:
485
+ fuse_padding_with_op2d(graph=graph,
486
+ pad_node=pad_node,
487
+ op2d_node=op2d_node)
488
+
409
489
  if non_linear_node_cfg_candidate.shift_negative_threshold_recalculation:
410
490
  activation_param = get_activations_qparams(activation_quant_cfg=non_linear_node_cfg_candidate,
411
491
  nodes_prior_info=non_linear_node.prior_info,
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
  import operator
16
17
  from typing import Tuple, Any, Callable
17
18
 
@@ -149,9 +150,9 @@ def create_pad_node(next_node_name: str,
149
150
  op_call_kwargs = {PAD: [pad_left, pad_right, pad_top, pad_btm],
150
151
  VALUE: float(value_to_pad)}
151
152
 
152
- padded_shape = input_shape[0]
153
- padded_shape[1] += pad_top + pad_btm
154
- padded_shape[2] += pad_left + pad_right
153
+ padded_shape = copy.deepcopy(input_shape[0])
154
+ padded_shape[2] += pad_top + pad_btm
155
+ padded_shape[3] += pad_left + pad_right
155
156
  pad_node = common.graph.functional_node.FunctionalNode(pad_node_name,
156
157
  {},
157
158
  input_shape,
@@ -241,5 +242,4 @@ def pytorch_apply_shift_negative_correction(graph: Graph,
241
242
  PADDING,
242
243
  BIAS,
243
244
  USE_BIAS,
244
- params_search_quantization_fn=params_search_quantization_fn
245
- )
245
+ params_search_quantization_fn=params_search_quantization_fn)
@@ -220,7 +220,7 @@ def _set_final_resource_utilization(graph: Graph,
220
220
  ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
221
221
  w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
222
222
  a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
223
- final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
223
+ final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
224
224
  BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs,
225
225
  ru_targets=ru_targets, allow_unused_qcs=True)
226
226
  summary = final_ru.get_summary_str(restricted=True)