mct-nightly 2.1.0.20240807.445__py3-none-any.whl → 2.1.0.20240808.431__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/RECORD +33 -32
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +14 -1
- model_compression_toolkit/core/common/fusion/graph_fuser.py +135 -0
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +4 -0
- model_compression_toolkit/core/common/quantization/debug_config.py +4 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +29 -1
- model_compression_toolkit/core/runner.py +21 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +13 -11
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +13 -11
- model_compression_toolkit/metadata.py +61 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +12 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -12
- model_compression_toolkit/qat/keras/quantization_facade.py +8 -8
- model_compression_toolkit/qat/pytorch/quantization_facade.py +8 -8
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +10 -13
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +68 -52
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +35 -29
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +35 -28
- model_compression_toolkit/xquant/common/constants.py +3 -0
- model_compression_toolkit/xquant/common/core_report_generator.py +9 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +5 -14
- model_compression_toolkit/xquant/common/tensorboard_utils.py +30 -5
- model_compression_toolkit/xquant/keras/facade_xquant_report.py +2 -0
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -1
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +101 -4
- model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +2 -0
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +3 -2
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +109 -3
- {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/top_level.txt +0 -0
@@ -12,18 +12,20 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
from model_compression_toolkit.constants import MAX_CUT
|
16
16
|
from model_compression_toolkit.core.common import Graph
|
17
17
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
18
18
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
19
19
|
|
20
20
|
|
21
21
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
|
22
|
-
from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG
|
22
|
+
from model_compression_toolkit.xquant.common.constants import TENSORBOARD_DEFAULT_TAG, OUTPUT_SIMILARITY_METRICS_REPR, \
|
23
|
+
OUTPUT_SIMILARITY_METRICS_VAL
|
23
24
|
from model_compression_toolkit.logger import Logger
|
24
25
|
|
25
26
|
|
26
27
|
from typing import Any, Dict, Callable
|
28
|
+
from mct_quantizers.keras.metadata import get_metadata
|
27
29
|
|
28
30
|
|
29
31
|
class TensorboardUtils:
|
@@ -52,7 +54,8 @@ class TensorboardUtils:
|
|
52
54
|
def get_graph_for_tensorboard_display(self,
|
53
55
|
quantized_model: Any,
|
54
56
|
similarity_metrics: Dict[str, Any],
|
55
|
-
repr_dataset: Callable
|
57
|
+
repr_dataset: Callable,
|
58
|
+
quantized_model_metadata: Dict) -> Graph:
|
56
59
|
"""
|
57
60
|
Get the graph for Tensorboard display. The framework-specific implementations
|
58
61
|
(like KerasTensorboardUtils and PytorchTensorboardUtils) should implement this
|
@@ -62,6 +65,7 @@ class TensorboardUtils:
|
|
62
65
|
quantized_model (Any): The quantized model.
|
63
66
|
similarity_metrics (Dict[str, Any]): Metrics for model similarity.
|
64
67
|
repr_dataset (Callable): Representative dataset function.
|
68
|
+
quantized_model_metadata (Dict): Metadata from the quantized model.
|
65
69
|
|
66
70
|
Returns:
|
67
71
|
Graph: The generated graph for Tensorboard display.
|
@@ -81,7 +85,8 @@ class TensorboardUtils:
|
|
81
85
|
def add_graph_to_tensorboard(self,
|
82
86
|
quantized_model: Any,
|
83
87
|
similarity_metrics: Dict[str, Any],
|
84
|
-
repr_dataset: Callable
|
88
|
+
repr_dataset: Callable,
|
89
|
+
quantized_model_metadata: Dict):
|
85
90
|
"""
|
86
91
|
Add a graph to Tensorboard. The graph represents the quantized graph
|
87
92
|
with the similarity metrics that were measured in different nodes.
|
@@ -90,12 +95,32 @@ class TensorboardUtils:
|
|
90
95
|
quantized_model (Any): The quantized model.
|
91
96
|
similarity_metrics (Dict[str, Any]): The similarity metrics that were collected.
|
92
97
|
repr_dataset (Callable): Representative dataset to use (if needed, like in pytorch case).
|
98
|
+
quantized_model_metadata (Dict): Metadata from the quantized model.
|
93
99
|
"""
|
94
100
|
# Generate the quantized graph with similarity metrics.
|
95
101
|
tb_graph = self.get_graph_for_tensorboard_display(quantized_model=quantized_model,
|
96
102
|
similarity_metrics=similarity_metrics,
|
97
|
-
repr_dataset=repr_dataset
|
103
|
+
repr_dataset=repr_dataset,
|
104
|
+
quantized_model_metadata=quantized_model_metadata)
|
98
105
|
|
99
106
|
self.tb_writer.add_graph(tb_graph, TENSORBOARD_DEFAULT_TAG)
|
100
107
|
|
108
|
+
def add_text_information(self,
                         similarity_metrics: Dict[str, Dict[str, float]],
                         quantized_model_metadata: Dict[str, Any]):
    """
    Adds text information (like max cut and output similarity metrics) to the tensorboard writer.

    The max-cut text is written only when the metadata contains a 'scheduling_info' entry that
    holds MAX_CUT. Nothing guarantees the metadata carries scheduling info, so a missing entry is
    skipped instead of aborting the report with a KeyError.

    Args:
        similarity_metrics (Dict[str, Dict[str, float]]): Similarity metrics between the quantized and
            float models. Must contain the OUTPUT_SIMILARITY_METRICS_REPR and
            OUTPUT_SIMILARITY_METRICS_VAL entries, each mapping a metric name to its value.
        quantized_model_metadata (Dict): Metadata from the quantized model (may be empty or None).

    Raises:
        KeyError: If either output-similarity entry is missing from ``similarity_metrics``.
    """
    def _format_metrics(title: str, metrics: Dict[str, float]) -> str:
        # One "name: value" line per metric, values printed with 4 decimals.
        return title + "\n".join(f"{key}: {value:.4f}" for key, value in metrics.items())

    # Add the computed max cut, when the scheduling info was recorded in the metadata.
    scheduling_info = (quantized_model_metadata or {}).get('scheduling_info') or {}
    if MAX_CUT in scheduling_info:
        self.tb_writer.add_text(f"MaxCut: {scheduling_info[MAX_CUT]}", MAX_CUT)

    # Add output similarity between quantized and float models on representative and validation datasets.
    # Both texts are built before either is written, so a missing entry writes neither.
    output_similarity_repr = _format_metrics("Similarity Metrics on outputs using representative dataset: \n",
                                             similarity_metrics[OUTPUT_SIMILARITY_METRICS_REPR])
    output_similarity_val = _format_metrics("Similarity Metrics on outputs using validation dataset: \n",
                                            similarity_metrics[OUTPUT_SIMILARITY_METRICS_VAL])
    self.tb_writer.add_text(output_similarity_repr, OUTPUT_SIMILARITY_METRICS_REPR)
    self.tb_writer.add_text(output_similarity_val, OUTPUT_SIMILARITY_METRICS_VAL)
|
@@ -25,6 +25,7 @@ from model_compression_toolkit.xquant.keras.model_analyzer import KerasModelAnal
|
|
25
25
|
|
26
26
|
from model_compression_toolkit.xquant.keras.similarity_functions import KerasSimilarityFunctions
|
27
27
|
from model_compression_toolkit.xquant.keras.tensorboard_utils import KerasTensorboardUtils
|
28
|
+
from mct_quantizers.keras.metadata import get_metadata
|
28
29
|
|
29
30
|
|
30
31
|
class KerasReportUtils(FrameworkReportUtils):
|
@@ -57,4 +58,5 @@ class KerasReportUtils(FrameworkReportUtils):
|
|
57
58
|
similarity_calculator,
|
58
59
|
dataset_utils,
|
59
60
|
model_folding,
|
60
|
-
tb_utils
|
61
|
+
tb_utils,
|
62
|
+
get_metadata)
|
@@ -12,19 +12,24 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from typing import Dict, Callable
|
15
|
+
from typing import Dict, Callable, Any
|
16
16
|
|
17
17
|
import keras
|
18
18
|
|
19
|
-
from
|
19
|
+
from mct_quantizers import KerasActivationQuantizationHolder, KerasQuantizationWrapper
|
20
|
+
from model_compression_toolkit.constants import MEM_ELEMENTS, CUTS, OP_ORDER, NODE_NAME, NODE_OUTPUT_INDEX, TOTAL_SIZE, FUSED_NODES_MAPPING
|
21
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
20
22
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
23
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
24
|
|
23
25
|
from model_compression_toolkit.core.keras.reader.reader import model_reader
|
24
26
|
|
25
|
-
from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR,
|
27
|
+
from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
|
28
|
+
XQUANT_VAL, INTERMEDIATE_SIMILARITY_METRICS_VAL, CUT_MEMORY_ELEMENTS, CUT_TOTAL_SIZE
|
26
29
|
from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
|
27
30
|
|
31
|
+
NODES_WITHOUT_CUT_INFO = [KerasActivationQuantizationHolder]
|
32
|
+
|
28
33
|
|
29
34
|
class KerasTensorboardUtils(TensorboardUtils):
|
30
35
|
"""
|
@@ -52,7 +57,8 @@ class KerasTensorboardUtils(TensorboardUtils):
|
|
52
57
|
def get_graph_for_tensorboard_display(self,
|
53
58
|
quantized_model: keras.Model,
|
54
59
|
similarity_metrics: Dict[str, Dict[str, float]],
|
55
|
-
repr_dataset: Callable
|
60
|
+
repr_dataset: Callable,
|
61
|
+
quantized_model_metadata: Dict) -> Graph:
|
56
62
|
"""
|
57
63
|
Generate a graph suitable for TensorBoard display from the provided quantized model
|
58
64
|
and similarity metrics.
|
@@ -62,6 +68,7 @@ class KerasTensorboardUtils(TensorboardUtils):
|
|
62
68
|
similarity_metrics (Dict[str, Dict[str, float]]): A dictionary containing similarity metrics
|
63
69
|
for different nodes in the model.
|
64
70
|
repr_dataset (Callable): A function or callable that provides the representative dataset.
|
71
|
+
quantized_model_metadata (Dict): Metadata from the quantized model.
|
65
72
|
|
66
73
|
Returns:
|
67
74
|
Graph: A graph object representing the quantized model, annotated with similarity metrics.
|
@@ -69,6 +76,8 @@ class KerasTensorboardUtils(TensorboardUtils):
|
|
69
76
|
# Read the quantized model into a graph structure.
|
70
77
|
quant_graph = model_reader(quantized_model)
|
71
78
|
|
79
|
+
insert_cut_info_into_graph(quant_graph, quantized_model_metadata)
|
80
|
+
|
72
81
|
# Iterate over each node in the graph.
|
73
82
|
for node in quant_graph.nodes:
|
74
83
|
# Check if the node's name is in the similarity metrics for intermediate representation.
|
@@ -82,3 +91,91 @@ class KerasTensorboardUtils(TensorboardUtils):
|
|
82
91
|
node.framework_attr[XQUANT_VAL] = similarity_metrics[INTERMEDIATE_SIMILARITY_METRICS_VAL][node.name]
|
83
92
|
|
84
93
|
return quant_graph
|
94
|
+
|
95
|
+
|
96
|
+
def populate_fused_node_memory_elements(quantized_model_metadata: Dict[str, Any]) -> Dict[str, list]:
    """
    Build a lookup from fused-node name to the memory elements of its cut.

    Every cut listed under ``quantized_model_metadata['scheduling_info'][CUTS]`` is keyed by the
    last entry of its OP_ORDER list. Cuts whose key starts with 'DummyType' are skipped. When
    several cuts end with the same op name, the later cut overwrites the earlier one.

    Args:
        quantized_model_metadata (dict): Metadata containing scheduling information for the quantized model.

    Returns:
        dict: Fused node name -> the MEM_ELEMENTS value of the corresponding cut.
    """
    cuts = quantized_model_metadata['scheduling_info'][CUTS]
    return {
        cut[OP_ORDER][-1]: cut[MEM_ELEMENTS]
        for cut in cuts
        if not cut[OP_ORDER][-1].startswith('DummyType')
    }
|
116
|
+
|
117
|
+
def assign_cut_info_to_node(node: BaseNode, memory_elements: list):
    """
    Record the tensors held in memory while ``node`` runs, plus their combined size.

    Writes two entries into ``node.framework_attr``:
      - CUT_MEMORY_ELEMENTS: one "<node name>_outTensor_<output index>" label per memory element;
      - CUT_TOTAL_SIZE: the sum of the elements' TOTAL_SIZE values.

    Args:
        node (Node): The node whose framework attributes are updated.
        memory_elements (list): Memory elements of the cut that belongs to this node.
    """
    tensor_labels = [
        f"{element[NODE_NAME]}_outTensor_{element[NODE_OUTPUT_INDEX]}" for element in memory_elements
    ]
    attributes = node.framework_attr
    attributes[CUT_MEMORY_ELEMENTS] = tensor_labels
    attributes[CUT_TOTAL_SIZE] = sum(element[TOTAL_SIZE] for element in memory_elements)
|
133
|
+
|
134
|
+
def process_node_cut_info(node: BaseNode,
                          fused_node_to_memory_elements: Dict[str, list],
                          quantized_model_metadata: Dict[str, Any]):
    """
    Attach cut information to ``node`` by resolving one of its names to a cut.

    Lookup order (first match wins):
      1. ``node.name`` is a key of ``fused_node_to_memory_elements``;
      2. ``node.name`` is in the metadata's FUSED_NODES_MAPPING, and the mapped name is used as the key;
      3. / 4. for a KerasQuantizationWrapper node only: the same two checks, using the wrapped
         layer's config name.
    A node that matches none of these is left untouched.

    Args:
        node (Node): The node to process.
        fused_node_to_memory_elements (dict): Dictionary mapping fused nodes to memory elements.
        quantized_model_metadata (dict): Metadata containing scheduling information for the quantized model.

    Raises:
        KeyError: If a name found in the fused-nodes mapping maps to a name absent from
            ``fused_node_to_memory_elements``.
    """
    not_found = object()

    def _memory_elements_for(name):
        # Direct hit on a name that already keys a cut.
        if name in fused_node_to_memory_elements:
            return fused_node_to_memory_elements[name]
        # Otherwise translate the name through the fused-nodes mapping.
        fused_nodes_mapping = quantized_model_metadata['scheduling_info'][FUSED_NODES_MAPPING]
        if name in fused_nodes_mapping:
            return fused_node_to_memory_elements[fused_nodes_mapping[name]]
        return not_found

    elements = _memory_elements_for(node.name)
    if elements is not_found and node.type == KerasQuantizationWrapper:
        # Wrapper nodes are also matched through the name of the layer they wrap.
        elements = _memory_elements_for(node.framework_attr['layer']['config']['name'])

    if elements is not not_found:
        assign_cut_info_to_node(node, elements)
|
163
|
+
|
164
|
+
def insert_cut_info_into_graph(quant_graph: Graph, quantized_model_metadata: Dict[str, Any]):
    """
    Insert information about cut tensors into the graph nodes based on the provided metadata.

    Nodes whose type is listed in NODES_WITHOUT_CUT_INFO are skipped. If the metadata is empty/None
    or has no 'scheduling_info' entry, the graph is left unchanged.

    Args:
        quant_graph (Graph): The graph representing the quantized model.
        quantized_model_metadata (dict): Metadata expected to contain scheduling information for the
            quantized model.
    """
    # Cut info is derived solely from 'scheduling_info'. Nothing guarantees the model metadata
    # carries it, so skip the annotation instead of failing the whole report with a KeyError.
    if not (quantized_model_metadata or {}).get('scheduling_info'):
        return

    # Populate the mapping of fused nodes to memory elements
    fused_node_to_memory_elements = populate_fused_node_memory_elements(quantized_model_metadata)

    for node in quant_graph.nodes:
        # Activation quantization holders (NODES_WITHOUT_CUT_INFO) get no cut information.
        if node.type in NODES_WITHOUT_CUT_INFO:
            continue
        process_node_cut_info(node,
                              fused_node_to_memory_elements,
                              quantized_model_metadata)
|
181
|
+
|
@@ -24,7 +24,7 @@ from model_compression_toolkit.xquant.pytorch.dataset_utils import PytorchDatase
|
|
24
24
|
from model_compression_toolkit.xquant.pytorch.model_analyzer import PytorchModelAnalyzer
|
25
25
|
from model_compression_toolkit.xquant.pytorch.similarity_functions import PytorchSimilarityFunctions
|
26
26
|
from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
|
27
|
-
|
27
|
+
from mct_quantizers.pytorch.metadata import get_metadata
|
28
28
|
|
29
29
|
class PytorchReportUtils(FrameworkReportUtils):
|
30
30
|
"""
|
@@ -58,4 +58,5 @@ class PytorchReportUtils(FrameworkReportUtils):
|
|
58
58
|
tb_utils=tb_utils,
|
59
59
|
dataset_utils=dataset_utils,
|
60
60
|
similarity_calculator=similarity_calculator,
|
61
|
-
model_folding_utils=model_folding
|
61
|
+
model_folding_utils=model_folding,
|
62
|
+
get_metadata_fn=get_metadata)
|
@@ -12,9 +12,10 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
|
16
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
16
17
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
17
|
-
from typing import Dict, Any, Callable
|
18
|
+
from typing import Dict, Any, Callable, List
|
18
19
|
|
19
20
|
import torch
|
20
21
|
|
@@ -24,6 +25,14 @@ from model_compression_toolkit.xquant.common.constants import XQUANT_REPR, INTER
|
|
24
25
|
from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
|
25
26
|
from model_compression_toolkit.xquant.common.tensorboard_utils import TensorboardUtils
|
26
27
|
|
28
|
+
NODES_WITHOUT_CUT_INFO = [torch.fake_quantize_per_tensor_affine]
|
29
|
+
|
30
|
+
def is_wrapped_linear_op(quantized_model: torch.nn.Module, node) -> bool:
    """
    Check whether a graph node corresponds to a PytorchQuantizationWrapper module of the quantized model.

    The node's name, with a trailing '_layer' suffix removed, is looked up as an attribute of
    ``quantized_model``. The intent is to detect wrapped linear layers (conv2d/linear), but the check
    itself only tests for a PytorchQuantizationWrapper.

    Args:
        quantized_model: The quantized model whose attributes are searched.
        node: Graph node; only its ``name`` attribute is used.

    Returns:
        True if such an attribute exists and is a PytorchQuantizationWrapper, False otherwise.
    """
    # A single getattr with a default replaces the hasattr + getattr double lookup.
    module = getattr(quantized_model, node.name.removesuffix('_layer'), None)
    return module is not None and isinstance(module, PytorchQuantizationWrapper)
|
35
|
+
|
27
36
|
class PytorchTensorboardUtils(TensorboardUtils):
|
28
37
|
"""
|
29
38
|
Utility class for handling PyTorch models with TensorBoard. Inherits from TensorboardUtils.
|
@@ -49,7 +58,8 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
49
58
|
def get_graph_for_tensorboard_display(self,
|
50
59
|
quantized_model: torch.nn.Module,
|
51
60
|
similarity_metrics: Dict[str, Any],
|
52
|
-
repr_dataset: Callable
|
61
|
+
repr_dataset: Callable,
|
62
|
+
quantized_model_metadata: Dict):
|
53
63
|
"""
|
54
64
|
Get the graph to display on TensorBoard. The graph represents the quantized model
|
55
65
|
with the similarity metrics that were measured.
|
@@ -58,6 +68,7 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
58
68
|
quantized_model: The quantized model to be displayed on TensorBoard.
|
59
69
|
similarity_metrics: Dictionary containing the collected similarity metrics values.
|
60
70
|
repr_dataset: Callable that generates the representative dataset used during graph building.
|
71
|
+
quantized_model_metadata (Dict): Metadata from the quantized model.
|
61
72
|
|
62
73
|
Returns:
|
63
74
|
The updated quantized model graph with similarity metrics embedded.
|
@@ -68,6 +79,8 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
68
79
|
to_tensor=self.fw_impl.to_tensor,
|
69
80
|
to_numpy=self.fw_impl.to_numpy)
|
70
81
|
|
82
|
+
insert_cut_info_into_graph(quant_graph, quantized_model_metadata, quantized_model)
|
83
|
+
|
71
84
|
# Iterate through each node in the graph
|
72
85
|
for node in quant_graph.nodes:
|
73
86
|
# Check and add similarity metrics for each node in the graph
|
@@ -85,3 +98,96 @@ class PytorchTensorboardUtils(TensorboardUtils):
|
|
85
98
|
node.name.removesuffix("_layer")]
|
86
99
|
|
87
100
|
return quant_graph
|
101
|
+
|
102
|
+
|
103
|
+
def populate_fused_node_memory_elements(quantized_model_metadata: Dict[str, Any]) -> Dict[str, list]:
    """
    Map each fused node name to the memory elements recorded for its cut.

    A cut is keyed by the last op name in its 'op_order' list. Cuts whose key starts with
    'DummyType' are ignored, and when two cuts share a key the later one overwrites the earlier.

    Args:
        quantized_model_metadata: Metadata whose 'scheduling_info' entry holds the list of cuts.

    Returns:
        dict: Fused node name -> that cut's 'mem_elements' value.
    """
    memory_by_node = {}
    for cut in quantized_model_metadata['scheduling_info']['cuts']:
        last_op = cut['op_order'][-1]
        if last_op.startswith('DummyType'):
            continue  # dummy entries are not recorded
        memory_by_node[last_op] = cut['mem_elements']
    return memory_by_node
|
123
|
+
|
124
|
+
|
125
|
+
def assign_cut_info_to_node(node: BaseNode, memory_elements: List[dict]):
    """
    Store the cut's tensor labels and total size on the node.

    Sets ``node.framework_attr['cut_memory_elements']`` to one
    "<node_name>_outTensor_<node_output_index>" label per element, and
    ``node.framework_attr['cut_total_size']`` to the sum of the elements' 'total_size' values.

    Args:
        node (Node): The node whose framework attributes are updated.
        memory_elements (list): Memory elements of the cut that belongs to this node.
    """
    attrs = node.framework_attr
    attrs['cut_memory_elements'] = [
        f"{elem['node_name']}_outTensor_{elem['node_output_index']}" for elem in memory_elements
    ]
    attrs['cut_total_size'] = sum(elem['total_size'] for elem in memory_elements)
|
140
|
+
|
141
|
+
|
142
|
+
def process_node_cut_info(node: BaseNode, fused_node_to_memory_elements: Dict[str, list], quantized_model_metadata: Dict[str, Any], quantized_model: torch.nn.Module):
    """
    Resolve which cut ``node`` belongs to and attach that cut's information to it.

    Candidates are checked in this order, first match wins:
      1. ``node.name`` is a key of ``fused_node_to_memory_elements``;
      2. the node is a wrapped op (per ``is_wrapped_linear_op``) and its name without the '_layer'
         suffix is such a key;
      3. ``node.name`` is in the metadata's 'fused_nodes_mapping', whose value is used as the key;
      4. the node is a wrapped op and its suffix-stripped name is in 'fused_nodes_mapping'.
    A node matching none of these is left unchanged.

    Args:
        node: The node to process.
        fused_node_to_memory_elements: Dictionary mapping fused nodes to memory elements.
        quantized_model_metadata: Metadata containing scheduling information for the quantized model.
        quantized_model: The quantized model.

    Raises:
        KeyError: If the metadata lacks 'scheduling_info'/'fused_nodes_mapping', or a mapped name is
            missing from ``fused_node_to_memory_elements``.
    """
    stripped_name = node.name.removesuffix('_layer')
    name_mapping = quantized_model_metadata['scheduling_info']['fused_nodes_mapping']

    if node.name in fused_node_to_memory_elements:
        cut_key = node.name
    else:
        # The wrapper check does not depend on the branch taken, so evaluate it once.
        wrapped = is_wrapped_linear_op(quantized_model, node)
        if wrapped and stripped_name in fused_node_to_memory_elements:
            cut_key = stripped_name
        elif node.name in name_mapping:
            cut_key = name_mapping[node.name]
        elif wrapped and stripped_name in name_mapping:
            cut_key = name_mapping[stripped_name]
        else:
            return

    assign_cut_info_to_node(node, fused_node_to_memory_elements[cut_key])
|
172
|
+
|
173
|
+
|
174
|
+
def insert_cut_info_into_graph(quant_graph: Graph,
                               quantized_model_metadata: Dict[str, Any],
                               quantized_model: torch.nn.Module):
    """
    Insert information about cut tensors into the graph nodes based on the provided metadata.

    Nodes whose type is listed in NODES_WITHOUT_CUT_INFO are skipped. If the metadata is empty/None
    or has no 'scheduling_info' entry, the graph is left unchanged.

    Args:
        quant_graph: The graph representing the quantized model.
        quantized_model_metadata: Metadata expected to contain scheduling information for the quantized model.
        quantized_model: The quantized model.
    """
    # Cut info is derived solely from 'scheduling_info'. Nothing guarantees the model metadata
    # carries it, so skip the annotation instead of failing the whole report with a KeyError.
    if not (quantized_model_metadata or {}).get('scheduling_info'):
        return

    # Populate the mapping of fused nodes to memory elements
    fused_node_to_memory_elements = populate_fused_node_memory_elements(quantized_model_metadata)

    for node in quant_graph.nodes:
        # Fake-quantize ops (NODES_WITHOUT_CUT_INFO) get no cut information.
        if node.type in NODES_WITHOUT_CUT_INFO:
            continue
        process_node_cut_info(node, fused_node_to_memory_elements, quantized_model_metadata, quantized_model)
|
192
|
+
|
193
|
+
|
{mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240807.445.dist-info → mct_nightly-2.1.0.20240808.431.dist-info}/top_level.txt
RENAMED
File without changes
|