mct-nightly 2.1.0.20240807.445__py3-none-any.whl → 2.1.0.20240809.432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240809.432.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240809.432.dist-info}/RECORD +33 -32
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +14 -1
  5. model_compression_toolkit/core/common/fusion/graph_fuser.py +135 -0
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +4 -0
  7. model_compression_toolkit/core/common/quantization/debug_config.py +4 -1
  8. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +29 -1
  9. model_compression_toolkit/core/runner.py +21 -1
  10. model_compression_toolkit/gptq/keras/quantization_facade.py +13 -11
  11. model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -11
  12. model_compression_toolkit/metadata.py +61 -2
  13. model_compression_toolkit/ptq/keras/quantization_facade.py +12 -10
  14. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -12
  15. model_compression_toolkit/qat/keras/quantization_facade.py +8 -8
  16. model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +10 -13
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +68 -52
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +35 -29
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +35 -28
  21. model_compression_toolkit/xquant/common/constants.py +3 -0
  22. model_compression_toolkit/xquant/common/core_report_generator.py +9 -1
  23. model_compression_toolkit/xquant/common/framework_report_utils.py +5 -14
  24. model_compression_toolkit/xquant/common/tensorboard_utils.py +30 -5
  25. model_compression_toolkit/xquant/keras/facade_xquant_report.py +2 -0
  26. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -1
  27. model_compression_toolkit/xquant/keras/tensorboard_utils.py +101 -4
  28. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +2 -0
  29. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -2
  30. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +109 -3
  31. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240809.432.dist-info}/LICENSE.md +0 -0
  32. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240809.432.dist-info}/WHEEL +0 -0
  33. {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240809.432.dist-info}/top_level.txt +0 -0
model_compression_toolkit/xquant/common/tensorboard_utils.py

@@ -12,18 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+from model_compression_toolkit.constants import MAX_CUT
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
 
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
-from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG
+from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG, OUTPUT_SIMILARITY_METRICS_REPR, \
+    OUTPUT_SIMILARITY_METRICS_VAL
 from model_compression_toolkit.logger import Logger
 
 
 from typing import Any, Dict, Callable
+from mct_quantizers.keras.metadata import get_metadata
 
 
 class TensorboardUtils:
@@ -52,7 +54,8 @@ class TensorboardUtils:
     def get_graph_for_tensorboard_display(self,
                                           quantized_model: Any,
                                           similarity_metrics: Dict[str, Any],
-                                          repr_dataset: Callable) -> Graph:
+                                          repr_dataset: Callable,
+                                          quantized_model_metadata: Dict) -> Graph:
         """
         Get the graph for Tensorboard display. The framework-specific implementations
         (like KerasTensorboardUtils and PytorchTensorboardUtils) should implement this
@@ -62,6 +65,7 @@ class TensorboardUtils:
             quantized_model (Any): The quantized model.
             similarity_metrics (Dict[str, Any]): Metrics for model similarity.
             repr_dataset (Callable): Representative dataset function.
+            quantized_model_metadata (Dict): Metadata from the quantized model.
 
         Returns:
             Graph: The generated graph for Tensorboard display.
@@ -81,7 +85,8 @@ class TensorboardUtils:
     def add_graph_to_tensorboard(self,
                                  quantized_model: Any,
                                  similarity_metrics: Dict[str, Any],
-                                 repr_dataset: Callable):
+                                 repr_dataset: Callable,
+                                 quantized_model_metadata: Dict):
         """
         Add a graph to Tensorboard. The graph represents the quantized graph
         with the similarity metrics that were measured in different nodes.
@@ -90,12 +95,32 @@ class TensorboardUtils:
             quantized_model (Any): The quantized model.
             similarity_metrics (Dict[str, Any]): The similarity metrics that were collected.
             repr_dataset (Callable): Representative dataset to use (if needed, like in pytorch case).
+            quantized_model_metadata (Dict): Metadata from the quantized model.
         """
         # Generate the quantized graph with similarity metrics.
         tb_graph = self.get_graph_for_tensorboard_display(quantized_model=quantized_model,
                                                           similarity_metrics=similarity_metrics,
-                                                          repr_dataset=repr_dataset)
+                                                          repr_dataset=repr_dataset,
+                                                          quantized_model_metadata=quantized_model_metadata)
 
         self.tb_writer.add_graph(tb_graph, TENSORBOARD_DEFAULT_TAG)
 
+    def add_text_information(self,
+                             similarity_metrics: Dict[str, Dict[str, float]],
+                             quantized_model_metadata: Dict[str, Any]):
+        """
+        Adds text information (like max cut and output similarity metrics) to the tensorboard writer.
 
+        Args:
+            similarity_metrics (Dict[str, Dict[str, float]]): A dictionary containing similarity metrics between quantized and float models for both representative and validation datasets.
+            quantized_model_metadata (Dict): Metadata from the quantized model.
+        """
+        # Add the computed max cut
+        maxcut_str = f"MaxCut: {quantized_model_metadata['scheduling_info'][MAX_CUT]}"
+        self.tb_writer.add_text(maxcut_str, MAX_CUT)
+
+        # Add output similarity between quantized and float models on representative and validation datasets
+        output_similarity_repr = f"Similarity Metrics on outputs using representative dataset: \n" + "\n".join([f"{key}: {value:.4f}" for key, value in similarity_metrics[OUTPUT_SIMILARITY_METRICS_REPR].items()])
+        output_similarity_val = f"Similarity Metrics on outputs using validation dataset: \n" + "\n".join([f"{key}: {value:.4f}" for key, value in similarity_metrics[OUTPUT_SIMILARITY_METRICS_VAL].items()])
+        self.tb_writer.add_text(output_similarity_repr, OUTPUT_SIMILARITY_METRICS_REPR)
+        self.tb_writer.add_text(output_similarity_val, OUTPUT_SIMILARITY_METRICS_VAL)
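For orientation, here is a minimal standalone sketch (not part of the diff) of the data shapes the new add_text_information hook indexes into. The nesting under 'scheduling_info' and the per-dataset metric dictionaries follow from the method body above; the literal key strings, metric names, and values below are assumptions for illustration, since the concrete values of MAX_CUT and OUTPUT_SIMILARITY_METRICS_REPR/VAL live in the package constants and are not shown in this diff.

# Hypothetical shapes, mirroring what add_text_information() reads above.
quantized_model_metadata = {'scheduling_info': {'max_cut': 1572864.0}}   # assumed key/value for MAX_CUT
similarity_metrics = {
    'output_similarity_metrics_repr': {'mse': 0.0003, 'cs': 0.9991},     # assumed metric names/values
    'output_similarity_metrics_val': {'mse': 0.0004, 'cs': 0.9988},
}

# The method formats these into TensorBoard text summaries, e.g.:
maxcut_str = f"MaxCut: {quantized_model_metadata['scheduling_info']['max_cut']}"
output_repr = ("Similarity Metrics on outputs using representative dataset: \n"
               + "\n".join(f"{k}: {v:.4f}" for k, v in similarity_metrics['output_similarity_metrics_repr'].items()))
print(maxcut_str)   # MaxCut: 1572864.0
print(output_repr)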
model_compression_toolkit/xquant/keras/facade_xquant_report.py

@@ -56,6 +56,8 @@ if FOUND_TF:
                                                 fw_report_utils=keras_report_utils,
                                                 xquant_config=xquant_config)
 
+        Logger.shutdown()
+
         return _collected_data
 else:
     def xquant_report_keras_experimental(*args, **kwargs):
model_compression_toolkit/xquant/keras/keras_report_utils.py

@@ -25,6 +25,7 @@ from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnal
 
 from model_compression_toolkit.xquant.keras.similarity_functions import KerasSimilarityFunctions
 from model_compression_toolkit.xquant.keras.tensorboard_utils import KerasTensorboardUtils
+from mct_quantizers.keras.metadata import get_metadata
 
 
 class KerasReportUtils(FrameworkReportUtils):
@@ -57,4 +58,5 @@ class KerasReportUtils(FrameworkReportUtils):
                          similarity_calculator,
                          dataset_utils,
                          model_folding,
-                         tb_utils)
+                         tb_utils,
+                         get_metadata)
model_compression_toolkit/xquant/keras/tensorboard_utils.py

@@ -12,19 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Dict, Callable
+from typing import Dict, Callable, Any
 
 import keras
 
-from model_compression_toolkit.core.common import Graph
+from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
+from model_compression_toolkit.constants import MEM_ELEMENTS, CUTS, OP_ORDER, NODE_NAME, NODE_OUTPUT_INDEX, TOTAL_SIZE, FUSED_NODES_MAPPING
+from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 
 from model_compression_toolkit.core.keras.reader.reader import model_reader
 
-from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL
+from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
+    XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL, CUT_MEMORY_ELEMENTS, CUT_TOTAL_SIZE
 from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
 
+NODES_WITHOUT_CUT_INFO = [KerasActivationQuantizationHolder]
+
 
 class KerasTensorboardUtils(TensorboardUtils):
     """
@@ -52,7 +57,8 @@ class KerasTensorboardUtils(TensorboardUtils):
     def get_graph_for_tensorboard_display(self,
                                           quantized_model: keras.Model,
                                           similarity_metrics: Dict[str, Dict[str, float]],
-                                          repr_dataset: Callable) -> Graph:
+                                          repr_dataset: Callable,
+                                          quantized_model_metadata: Dict) -> Graph:
         """
         Generate a graph suitable for TensorBoard display from the provided quantized model
         and similarity metrics.
@@ -62,6 +68,7 @@ class KerasTensorboardUtils(TensorboardUtils):
             similarity_metrics (Dict[str, Dict[str, float]]): A dictionary containing similarity metrics
                 for different nodes in the model.
             repr_dataset (Callable): A function or callable that provides the representative dataset.
+            quantized_model_metadata (Dict): Metadata from the quantized model.
 
         Returns:
             Graph: A graph object representing the quantized model, annotated with similarity metrics.
@@ -69,6 +76,8 @@ class KerasTensorboardUtils(TensorboardUtils):
         # Read the quantized model into a graph structure.
         quant_graph = model_reader(quantized_model)
 
+        insert_cut_info_into_graph(quant_graph, quantized_model_metadata)
+
         # Iterate over each node in the graph.
         for node in quant_graph.nodes:
             # Check if the node's name is in the similarity metrics for intermediate representation.
@@ -82,3 +91,91 @@ class KerasTensorboardUtils(TensorboardUtils):
                 node.framework_attr[XQUANT_VAL] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL][node.name]
 
         return quant_graph
+
+
+def populate_fused_node_memory_elements(quantized_model_metadata: Dict[str, Any]) -> Dict[str, list]:
+    """
+    Populate a dictionary mapping fused node names to their corresponding memory elements.
+
+    Args:
+        quantized_model_metadata (dict): Metadata containing scheduling information for the quantized model.
+
+    Returns:
+        dict: A dictionary with fused node names as keys and memory elements as values.
+    """
+    fused_node_to_memory_elements = {}
+
+    for cut in quantized_model_metadata['scheduling_info'][CUTS]:
+        fused_node = cut[OP_ORDER][-1]
+
+        # Ignore dummy types
+        if not fused_node.startswith('DummyType'):
+            fused_node_to_memory_elements[fused_node] = cut[MEM_ELEMENTS]
+
+    return fused_node_to_memory_elements
+
+def assign_cut_info_to_node(node: BaseNode, memory_elements: list):
+    """
+    Assign cut memory elements and total size to a node's attributes according to the
+    tensors in the cut of this node.
+
+    Args:
+        node (Node): The node to which the memory elements and total size will be assigned.
+        memory_elements (list): List of memory elements to be assigned to the node since they are in memory during this node inference.
+    """
+    node.framework_attr[CUT_MEMORY_ELEMENTS] = [
+        f"{mem_element[NODE_NAME]}_outTensor_{mem_element[NODE_OUTPUT_INDEX]}"
+        for mem_element in memory_elements
+    ]
+    node.framework_attr[CUT_TOTAL_SIZE] = sum(
+        mem_element[TOTAL_SIZE] for mem_element in memory_elements
+    )
+
+def process_node_cut_info(node: BaseNode,
+                          fused_node_to_memory_elements: Dict[str, list],
+                          quantized_model_metadata: Dict[str, Any]):
+    """
+    Process and assign cut information for a given node based on metadata and fused nodes mapping.
+
+    Args:
+        node (Node): The node to process.
+        fused_node_to_memory_elements (dict): Dictionary mapping fused nodes to memory elements.
+        quantized_model_metadata (dict): Metadata containing scheduling information for the quantized model.
+    """
+    if node.name in fused_node_to_memory_elements:
+        # Directly assign cut info if node name is in fused_node_to_memory_elements
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[node.name])
+
+    elif node.name in quantized_model_metadata['scheduling_info'][FUSED_NODES_MAPPING]:
+        # Assign cut info if the node name is in the fused nodes mapping
+        original_node_name = quantized_model_metadata['scheduling_info'][FUSED_NODES_MAPPING][node.name]
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[original_node_name])
+
+    elif node.type == KerasQuantizationWrapper:
+        if node.framework_attr['layer']['config']['name'] in fused_node_to_memory_elements:
+            # Assign cut info if the node is a KerasQuantizationWrapper with a matching layer name
+            assign_cut_info_to_node(node, fused_node_to_memory_elements[node.framework_attr['layer']['config']['name']])
+
+        elif node.framework_attr['layer']['config']['name'] in quantized_model_metadata['scheduling_info'][FUSED_NODES_MAPPING]:
+            # Assign cut info if the node is a KerasQuantizationWrapper and its layer name is in the fused nodes mapping
+            original_node_name = quantized_model_metadata['scheduling_info'][FUSED_NODES_MAPPING][node.framework_attr['layer']['config']['name']]
+            assign_cut_info_to_node(node, fused_node_to_memory_elements[original_node_name])
+
+def insert_cut_info_into_graph(quant_graph: Graph, quantized_model_metadata: Dict[str, Any]):
+    """
+    Insert information about cut tensors into the graph nodes based on the provided metadata.
+
+    Args:
+        quant_graph (Graph): The graph representing the quantized model.
+        quantized_model_metadata (dict): Metadata containing scheduling information for the quantized model.
+    """
+    # Populate the mapping of fused nodes to memory elements
+    fused_node_to_memory_elements = populate_fused_node_memory_elements(quantized_model_metadata)
+
+    for node in quant_graph.nodes:
+        # Skip nodes without cut information
+        if node.type not in NODES_WITHOUT_CUT_INFO:
+            process_node_cut_info(node,
                                  fused_node_to_memory_elements,
                                  quantized_model_metadata)
+
model_compression_toolkit/xquant/pytorch/facade_xquant_report.py

@@ -54,6 +54,8 @@ if FOUND_TORCH:
                                                 fw_report_utils=pytorch_report_utils,
                                                 xquant_config=xquant_config)
 
+        Logger.shutdown()
+
         return _collected_data
 
 else:
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py

@@ -24,7 +24,7 @@ from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatase
 from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
 from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
 from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
-
+from mct_quantizers.pytorch.metadata import get_metadata
 
 class PytorchReportUtils(FrameworkReportUtils):
     """
@@ -58,4 +58,5 @@ class PytorchReportUtils(FrameworkReportUtils):
                          tb_utils=tb_utils,
                          dataset_utils=dataset_utils,
                          similarity_calculator=similarity_calculator,
-                         model_folding_utils=model_folding)
+                         model_folding_utils=model_folding,
+                         get_metadata_fn=get_metadata)
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
+from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from typing import Dict, Any, Callable
+from typing import Dict, Any, Callable, List
 
 import torch
 
@@ -24,6 +25,14 @@ from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTER
 from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
 from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
 
+NODES_WITHOUT_CUT_INFO = [torch.fake_quantize_per_tensor_affine]
+
+def is_wrapped_linear_op(quantized_model, node):
+    # Check if a node in a torch fx graph represents a linear layer (conv2d/linear)
+    # that is wrapped in the quantized model
+    return hasattr(quantized_model, node.name.removesuffix('_layer')) and isinstance(
+        getattr(quantized_model, node.name.removesuffix('_layer')), PytorchQuantizationWrapper)
+
 class PytorchTensorboardUtils(TensorboardUtils):
     """
     Utility class for handling PyTorch models with TensorBoard. Inherits from TensorboardUtils.
@@ -49,7 +58,8 @@ class PytorchTensorboardUtils(TensorboardUtils):
     def get_graph_for_tensorboard_display(self,
                                           quantized_model: torch.nn.Module,
                                           similarity_metrics: Dict[str, Any],
-                                          repr_dataset: Callable):
+                                          repr_dataset: Callable,
+                                          quantized_model_metadata: Dict):
         """
         Get the graph to display on TensorBoard. The graph represents the quantized model
         with the similarity metrics that were measured.
@@ -58,6 +68,7 @@ class PytorchTensorboardUtils(TensorboardUtils):
            quantized_model: The quantized model to be displayed on TensorBoard.
            similarity_metrics: Dictionary containing the collected similarity metrics values.
            repr_dataset: Callable that generates the representative dataset used during graph building.
+           quantized_model_metadata (Dict): Metadata from the quantized model.
 
         Returns:
            The updated quantized model graph with similarity metrics embedded.
@@ -68,6 +79,8 @@ class PytorchTensorboardUtils(TensorboardUtils):
                                                 to_tensor=self.fw_impl.to_tensor,
                                                 to_numpy=self.fw_impl.to_numpy)
 
+        insert_cut_info_into_graph(quant_graph, quantized_model_metadata, quantized_model)
+
         # Iterate through each node in the graph
         for node in quant_graph.nodes:
             # Check and add similarity metrics for each node in the graph
@@ -85,3 +98,96 @@ class PytorchTensorboardUtils(TensorboardUtils):
                     node.name.removesuffix("_layer")]
 
         return quant_graph
+
+
+def populate_fused_node_memory_elements(quantized_model_metadata: Dict[str, Any]) -> Dict[str, list]:
+    """
+    Populate a dictionary mapping fused node names to their corresponding memory elements.
+
+    Args:
+        quantized_model_metadata: Metadata containing scheduling information for the quantized model.
+
+    Returns:
+        dict: A dictionary with fused node names as keys and memory elements as values.
+    """
+    fused_node_to_memory_elements = {}
+
+    for cut in quantized_model_metadata['scheduling_info']['cuts']:
+        fused_node = cut['op_order'][-1]
+
+        # Ignore dummy types
+        if not fused_node.startswith('DummyType'):
+            fused_node_to_memory_elements[fused_node] = cut['mem_elements']
+
+    return fused_node_to_memory_elements
+
+
+def assign_cut_info_to_node(node: BaseNode, memory_elements: List[dict]):
+    """
+    Assign cut memory elements and total size to a node's framework attributes.
+
+    Args:
+        node (Node): The node to which the memory elements and total size will be assigned.
+        memory_elements (list): List of memory elements to be assigned to the node.
+    """
+    node.framework_attr['cut_memory_elements'] = [
+        f"{mem_element['node_name']}_outTensor_{mem_element['node_output_index']}"
+        for mem_element in memory_elements
+    ]
+    node.framework_attr['cut_total_size'] = sum(
+        mem_element['total_size'] for mem_element in memory_elements
+    )
+
+
+def process_node_cut_info(node: BaseNode, fused_node_to_memory_elements: Dict[str, list], quantized_model_metadata: Dict[str, Any], quantized_model: torch.nn.Module):
+    """
+    Process and assign cut information for a given node based on metadata and fused nodes mapping.
+
+    Args:
+        node: The node to process.
+        fused_node_to_memory_elements: Dictionary mapping fused nodes to memory elements.
+        quantized_model_metadata: Metadata containing scheduling information for the quantized model.
+        quantized_model: The quantized model.
+    """
+    node_name_without_suffix = node.name.removesuffix('_layer')
+    fused_nodes_mapping = quantized_model_metadata['scheduling_info']['fused_nodes_mapping']
+
+    if node.name in fused_node_to_memory_elements:
+        # Directly assign cut info if node name is in fused_node_to_memory_elements
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[node.name])
+
+    elif is_wrapped_linear_op(quantized_model, node) and node_name_without_suffix in fused_node_to_memory_elements:
+        # Assign cut info if the node is a wrapped linear operation with a matching name without suffix
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[node_name_without_suffix])
+
+    elif node.name in fused_nodes_mapping:
+        # Assign cut info if the node name is in the fused nodes mapping
+        original_node_name = fused_nodes_mapping[node.name]
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[original_node_name])
+
+    elif is_wrapped_linear_op(quantized_model, node) and node_name_without_suffix in fused_nodes_mapping:
+        # Assign cut info if the node is a wrapped linear operation and its name without suffix is in the fused nodes mapping
+        original_node_name = fused_nodes_mapping[node_name_without_suffix]
+        assign_cut_info_to_node(node, fused_node_to_memory_elements[original_node_name])
+
+
+def insert_cut_info_into_graph(quant_graph: Graph,
+                               quantized_model_metadata: Dict[str, Any],
+                               quantized_model: torch.nn.Module):
+    """
+    Insert information about cut tensors into the graph nodes based on the provided metadata.
+
+    Args:
+        quant_graph: The graph representing the quantized model.
+        quantized_model_metadata: Metadata containing scheduling information for the quantized model.
+        quantized_model: The quantized model.
+    """
+    # Populate the mapping of fused nodes to memory elements
+    fused_node_to_memory_elements = populate_fused_node_memory_elements(quantized_model_metadata)
+
+    for node in quant_graph.nodes:
+        # Skip nodes without cut information
+        if node.type not in NODES_WITHOUT_CUT_INFO:
+            process_node_cut_info(node, fused_node_to_memory_elements, quantized_model_metadata, quantized_model)
+
+
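To make the new cut-annotation helpers easier to follow, here is an illustrative sketch of the scheduling_info layout they traverse, reconstructed from the string keys used in the PyTorch code above ('cuts', 'op_order', 'mem_elements', 'node_name', 'node_output_index', 'total_size', 'fused_nodes_mapping'). The node names and sizes are invented for this example; only the nesting mirrors the code.

# Invented example of the metadata consumed by populate_fused_node_memory_elements()
# and process_node_cut_info() above.
scheduling_info = {
    'cuts': [
        {
            'op_order': ['conv1', 'conv1_relu'],   # last entry is taken as the fused node of the cut
            'mem_elements': [
                {'node_name': 'conv1_relu', 'node_output_index': 0, 'total_size': 401408},
            ],
        },
    ],
    # maps a quantized-graph node name to the fused node whose cut it belongs to
    'fused_nodes_mapping': {'conv1_bn': 'conv1_relu'},
}
quantized_model_metadata = {'scheduling_info': scheduling_info}

# For a node resolved to the cut above, assign_cut_info_to_node() would set:
#   node.framework_attr['cut_memory_elements'] == ['conv1_relu_outTensor_0']
#   node.framework_attr['cut_total_size']      == 401408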