mct-nightly 2.2.0.20250105.534__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
- {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
- model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
- model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
- model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/qat/__init__.py +5 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
- model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
- model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
- {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/keras/resource_utilization_data_facade.py
@@ -18,6 +18,7 @@ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, Cor
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -27,6 +28,8 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from tensorflow.keras.models import Model
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras

     from model_compression_toolkit import get_target_platform_capabilities

@@ -36,7 +39,8 @@ if FOUND_TF:
                                         representative_data_gen: Callable,
                                         core_config: CoreConfig = CoreConfig(
                                             mixed_precision_config=MixedPrecisionQuantizationConfig()),
-                                        target_platform_capabilities:
+                                        target_platform_capabilities: TargetPlatformModel = KERAS_DEFAULT_TPC
+                                        ) -> ResourceUtilization:
         """
         Computes resource utilization data that can be used to calculate the desired target resource utilization
         for mixed-precision quantization.
@@ -78,6 +82,12 @@ if FOUND_TF:

         fw_impl = KerasImplementation()

+        # Attach tpc model to framework
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         return compute_resource_utilization_data(in_model,
                                                  representative_data_gen,
                                                  core_config,
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py
ADDED
@@ -0,0 +1,499 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import operator
+from typing import List
+
+import numpy as np
+import torch
+
+from model_compression_toolkit.core.common.graph.base_graph import OutTensor
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.pytorch.constants import *
+from model_compression_toolkit.logger import Logger
+
+
+class MatMulParams:
+    """
+    A data class to hold all relevant parameter shapes and nodes for MatMul decomposition.
+    """
+
+    def __init__(self,
+                 matmul_node: FunctionalNode):
+        """
+        Extract params for all the substitution nodes from original matmul node.
+        Args:
+            matmul_node: original MatMul Node
+
+        Naming convention:
+            * First parameter - input
+            * Second parameter - other
+        """
+        self.head_input_node, self.head_other_node = None, None
+        self.prev_input_node, self.prev_other_node = None, None
+
+        self.input_shape, self.other_shape = matmul_node.input_shape
+
+        # Step 1 - Expand
+        expand_shape = np.max(np.vstack((self.input_shape[1:-2], self.other_shape[1:-2])), axis=0).tolist()
+        self.input_expand_shape = tuple([-1] + expand_shape + list(self.input_shape[-2:]))
+        self.other_expand_shape = tuple([-1] + expand_shape + list(self.other_shape[-2:]))
+
+        # Step 2 - Reshape
+        # (B, D_1, ... , D_N, m, p) --> (B, (D_1*...*D_N), m, p)
+        self.input_reshape_shape = [
+            -1,
+            int(np.prod(self.input_expand_shape[1:-2])),
+            self.input_expand_shape[-2],
+            self.input_expand_shape[-1]
+        ]
+        # (B, D_1, ... , D_N, p, n) --> (B, (D_1*...*D_N), p, n)
+        self.other_reshape_shape = [
+            -1,
+            int(np.prod(self.other_expand_shape[1:-2])),
+            self.other_expand_shape[-2],
+            self.other_expand_shape[-1]
+        ]
+
+        # Step 3 - Split
+        # (B, (D_1*...*D_N), m, p) --> [(B, m, p)] * (D_1*...*D_N)
+        self.input_matmul_shape = [-1] + self.input_reshape_shape[-2:]
+        self.input_split_shape = tuple([self.input_matmul_shape] * self.input_reshape_shape[1])
+        # (B, (D_1*...*D_N), p, n) --> [(B, p, n)] * (D_1*...*D_N)
+        self.other_matmul_shape = [-1] + self.other_reshape_shape[-2:]
+        self.other_split_shape = tuple([self.other_matmul_shape] * self.other_reshape_shape[1])
+
+        # Step 4 - Matmul loop
+        # [(B, m, p)] * (D_1*...*D_N) X [(B, p, n)] * (D_1*...*D_N) --> [(B, m, n)] * (D_1*...*D_N)
+        self.single_matmul_shape = self.input_matmul_shape[:-1] + [self.other_matmul_shape[-1]]
+
+        # Step 5 - Stack and Reshape all matmul outputs to original dimensions
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        self.matmul_stack_shape = tuple([-1] + [self.input_reshape_shape[1]] + self.single_matmul_shape[1:])
+        # (B, (D_1*...*D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+        self.output_shape = tuple(list(self.input_expand_shape)[:-1] + [self.matmul_stack_shape[-1]])
+
+    def update_nodes(self,
+                     input_node: FunctionalNode,
+                     other_node: FunctionalNode):
+        """
+        Updates the head and prev nodes to support the option of skipping unnecessary operations.
+        Args:
+            input_node: node that operates on the input branch
+            other_node: node that operates on the other branch
+        """
+        if not self.head_input_node:
+            self.head_input_node = input_node
+        if not self.head_other_node:
+            self.head_other_node = other_node
+        self.prev_input_node = input_node
+        self.prev_other_node = other_node
+
+
+class MatMulDecomposition(BaseSubstitution):
+    """
+    Removes A MatMul node from the graph if one of its inputs has >3 dimensions and replaces it with unbind, matmul
+    and stack nodes. Substitution is done inplace.
+
+    Naming convention:
+        * First parameter - input
+        * Second parameter - other
+    """
+
+    def __init__(self):
+        """
+        Matches: torch matmul or matmul operator.
+        """
+        func_node = NodeOperationMatcher(torch.matmul) | NodeOperationMatcher(operator.matmul)
+        super().__init__(matcher_instance=func_node)
+
+    def substitute(self,
+                   graph: Graph,
+                   matmul_node: FunctionalNode) -> Graph:
+        """
+        Decompose matmul of matrices with >3 dimensions to multiple matmuls and reconstruct graph.
+        Args:
+            graph: Graph we apply the substitution on.
+            matmul_node: MatMul node to substitute
+        Returns:
+            A graph after applying the substitution.
+        """
+
+        # If both matrices are already 3D or less, no need to change the graph
+        if len(matmul_node.input_shape[0]) <= 3 and len(matmul_node.input_shape[1]) <= 3:
+            return graph
+
+        if len(matmul_node.input_shape[0]) != len(matmul_node.input_shape[1]):
+            Logger.critical(f'Mismatch between number of input dimensions in node {matmul_node.name}.')
+
+        matmul_params = MatMulParams(matmul_node)
+
+        # Expand inputs to equal dimensions (instead of broadcasting) - if needed
+        if not np.array_equal(matmul_params.input_shape[1:-2], matmul_params.other_shape[1:-2]):
+            input_expand_node, other_expand_node = self._expand_inputs(
+                graph,
+                matmul_node,
+                matmul_params
+            )
+            matmul_params.update_nodes(input_node=input_expand_node, other_node=other_expand_node)
+
+        # Reshape inputs - if needed
+        # (B, D_1, ... , D_N, m, p) --> (B, (D_1*...*D_N), m, p)
+        # (B, D_1, ... , D_N, p, n) --> (B, (D_1*...*D_N), p, n)
+        if len(matmul_params.input_shape) > 4:  # both input & other have the same number of dimensions
+            input_reshape_node, other_reshape_node = self._reshape_input(
+                graph,
+                matmul_node,
+                matmul_params
+            )
+            matmul_params.update_nodes(input_node=input_reshape_node, other_node=other_reshape_node)
+
+        # Split inputs
+        # (B, (D_1*...*D_N), m, p) --> [(B, m, p)] * (D_1*...*D_N)
+        # (B, (D_1*...*D_N), p, n) --> [(B, p, n)] * (D_1*...*D_N)
+        input_split_node, other_split_node = self._split_inputs(
+            graph,
+            matmul_node,
+            matmul_params
+        )
+        matmul_params.update_nodes(input_node=input_split_node, other_node=other_split_node)
+
+        # Matmul each pair
+        # [(B, m, p)] * (D_1*...*D_N) X [(B, p, n)] * (D_1*...*D_N) --> [(B, m, n)] * (D_1*...*D_N)
+        split_matmul_nodes = []
+        for idx in range(matmul_params.input_reshape_shape[1]):
+            split_matmul_node = self._calc_single_matmul(
+                graph,
+                matmul_node,
+                input_split_node,
+                other_split_node,
+                idx,
+                matmul_params
+            )
+            split_matmul_nodes.append(split_matmul_node)
+
+        # Stack and reshape all results - reshape if needed
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        # (B, (D_1*...*D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+        output_node = self._stack_matmul_outputs(
+            graph,
+            matmul_node,
+            split_matmul_nodes,
+            matmul_params
+        )
+
+        # connect edges to new nodes
+        self._connect_to_graph(
+            graph,
+            matmul_node,
+            matmul_params.head_input_node,
+            matmul_params.head_other_node,
+            output_node
+        )
+
+        # remove the original matmul node
+        graph.remove_node(matmul_node, new_graph_outputs=[OutTensor(output_node, 0)])
+
+        return graph
+
+    @staticmethod
+    def _expand_inputs(graph: Graph,
+                       matmul_node: FunctionalNode,
+                       params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that expand the inputs such that the dimensions fit the MatMul process.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other expand nodes.
+        """
+        if params.input_shape[1:] != list(params.input_expand_shape[1:]):
+            input_expand_node = FunctionalNode(
+                name=f'{matmul_node.name}_input_expand',
+                framework_attr={},
+                input_shape=params.input_shape,
+                output_shape=params.input_expand_shape,
+                weights={},
+                layer_class=torch.broadcast_to,
+                op_call_args=[params.input_expand_shape],
+                op_call_kwargs={},
+                functional_op=torch.broadcast_to
+            )
+            graph.add_node(input_expand_node)
+        else:
+            input_expand_node = None
+
+        if params.other_shape[1:] != list(params.other_expand_shape[1:]):
+            other_expand_node = FunctionalNode(
+                name=f'{matmul_node.name}_other_expand',
+                framework_attr={},
+                input_shape=params.other_shape,
+                output_shape=params.other_expand_shape,
+                weights={},
+                layer_class=torch.broadcast_to,
+                op_call_args=[params.other_expand_shape],
+                op_call_kwargs={},
+                functional_op=torch.broadcast_to
+            )
+            graph.add_node(other_expand_node)
+        else:
+            other_expand_node = None
+
+        return [input_expand_node, other_expand_node]
+
+    @staticmethod
+    def _reshape_input(graph: Graph,
+                       matmul_node: FunctionalNode,
+                       params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that reshape the input nodes to be 4D before the split.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other reshape nodes.
+        """
+        input_reshape_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_reshape',
+            framework_attr={},
+            input_shape=params.input_expand_shape,
+            output_shape=params.input_reshape_shape,
+            weights={},
+            layer_class=torch.reshape,
+            op_call_args=[params.input_reshape_shape],
+            op_call_kwargs={},
+            functional_op=torch.reshape
+        )
+        other_reshape_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_reshape',
+            framework_attr={},
+            input_shape=params.other_expand_shape,
+            output_shape=params.other_reshape_shape,
+            weights={},
+            layer_class=torch.reshape,
+            op_call_args=[params.other_reshape_shape],
+            op_call_kwargs={},
+            functional_op=torch.reshape
+        )
+        # Add reshapes to graph
+        if params.prev_input_node:
+            graph.add_node_with_in_edges(input_reshape_node, [params.prev_input_node])
+        else:
+            graph.add_node(input_reshape_node)
+
+        if params.prev_other_node:
+            graph.add_node_with_in_edges(other_reshape_node, [params.prev_other_node])
+        else:
+            graph.add_node(other_reshape_node)
+
+        return [input_reshape_node, other_reshape_node]
+
+    @staticmethod
+    def _split_inputs(graph: Graph,
+                      matmul_node: FunctionalNode,
+                      params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that split the parameters from 4D to 3D for single MatMul operations.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other unbind nodes - output of each is list of 3D matrices
+        """
+        input_split_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_split',
+            framework_attr={},
+            input_shape=params.input_reshape_shape,
+            output_shape=params.input_split_shape,
+            weights={},
+            layer_class=torch.unbind,
+            op_call_args=[1],
+            op_call_kwargs={},
+            functional_op=torch.unbind
+        )
+
+        other_split_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_split',
+            framework_attr={},
+            input_shape=params.other_reshape_shape,
+            output_shape=params.other_split_shape,
+            weights={},
+            layer_class=torch.unbind,
+            op_call_args=[1],
+            op_call_kwargs={},
+            functional_op=torch.unbind
+        )
+
+        if params.prev_input_node:
+            graph.add_node_with_in_edges(input_split_node, [params.prev_input_node])
+        else:
+            graph.add_node(input_split_node)
+        if params.prev_other_node:
+            graph.add_node_with_in_edges(other_split_node, [params.prev_other_node])
+        else:
+            graph.add_node(other_split_node)
+
+        return [input_split_node, other_split_node]
+
+    @staticmethod
+    def _calc_single_matmul(graph: Graph,
+                            matmul_node: FunctionalNode,
+                            input_split_node: FunctionalNode,
+                            other_split_node: FunctionalNode,
+                            dim_index: int,
+                            params: MatMulParams) -> FunctionalNode:
+        """
+        This method creates the per channel (index) matmul.
+        Retrieves the matrices from index dim_index and multiplies them.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: Original Matmul node
+            input_split_node: input after reshape and split.
+            other_split_node: other after reshape and split.
+            dim_index: index to run matmul on
+            params: MatMul Params
+
+        Returns:
+            Node after matmul of single dimension
+        """
+        # (B, m, n) X (B, n, p) -> (B, m, p)
+        # Get the input in index dim_index
+        get_input_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_{dim_index}',
+            framework_attr={},
+            input_shape=params.input_split_shape,
+            output_shape=params.input_matmul_shape,
+            weights={},
+            layer_class=operator.getitem,
+            op_call_args=[dim_index],
+            op_call_kwargs={},
+            functional_op=operator.getitem
+        )
+        graph.add_node_with_in_edges(get_input_node, [input_split_node], [dim_index])
+        # Get the other in index dim_index
+        get_other_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_{dim_index}',
+            framework_attr={},
+            input_shape=params.other_split_shape,
+            output_shape=params.other_matmul_shape,
+            weights={},
+            layer_class=operator.getitem,
+            op_call_args=[dim_index],
+            op_call_kwargs={},
+            functional_op=operator.getitem
+        )
+        graph.add_node_with_in_edges(get_other_node, [other_split_node], [dim_index])
+
+        matmul_node = FunctionalNode(name=f'{matmul_node.name}_matmul_{dim_index}',
+                                     framework_attr={},
+                                     input_shape=(params.input_matmul_shape, params.other_matmul_shape),
+                                     output_shape=[params.single_matmul_shape],
+                                     weights={},
+                                     layer_class=torch.matmul,
+                                     op_call_args=[],
+                                     op_call_kwargs={},
+                                     functional_op=torch.matmul)
+        graph.add_node_with_in_edges(matmul_node, [get_input_node, get_other_node])
+
+        return matmul_node
+
+    @staticmethod
+    def _stack_matmul_outputs(graph: Graph,
+                              matmul_node: FunctionalNode,
+                              split_matmul_nodes: List[FunctionalNode],
+                              params: MatMulParams) -> FunctionalNode:
+        """
+        This method creates the node that concats all single matmuls together and then reshapes to the original output
+        shape.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: Original Matmul node
+            split_matmul_nodes: list of all single matmul nodes.
+            params: MatMul Params
+
+        Returns:
+            Node after reshape - final output
+        """
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        cat_node = FunctionalNode(
+            name=f'{matmul_node.name}_stack',
+            framework_attr={DIM: 1},
+            input_shape=[params.single_matmul_shape] * params.matmul_stack_shape[1],
+            output_shape=params.matmul_stack_shape,
+            weights={},
+            layer_class=torch.stack,
+            op_call_args=[],
+            op_call_kwargs={DIM: 1},
+            functional_op=torch.stack,
+            inputs_as_list=True
+        )
+        graph.add_node_with_in_edges(cat_node, split_matmul_nodes)
+
+        if params.matmul_stack_shape != params.output_shape:
+            # (B, (D_1 * ... * D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+            matmul_reshape_node = FunctionalNode(
+                name=f'{matmul_node.name}_reshape',
+                framework_attr={},
+                input_shape=params.matmul_stack_shape,
+                output_shape=params.output_shape,
+                weights={},
+                layer_class=torch.reshape,
+                op_call_args=[params.output_shape],
+                op_call_kwargs={},
+                functional_op=torch.reshape
+            )
+            graph.add_node_with_in_edges(matmul_reshape_node, [cat_node])
+
+        return matmul_reshape_node if params.matmul_stack_shape != params.output_shape else cat_node
+
+    @staticmethod
+    def _connect_to_graph(
+            graph: Graph,
+            matmul_node: FunctionalNode,
+            head_input_node: FunctionalNode,
+            head_other_node: FunctionalNode,
+            output_node: FunctionalNode):
+        """
+        Connect the subgraph to the input graph.
+        Args:
+            graph: input graph
+            matmul_node: MatMul node to substitute inputs and outputs with
+            head_input_node: 1st input to MatMul Node
+            head_other_node: 2nd input to MatMul Node
+            output_node: output node of decomposed MatMul.
+        """
+        input_in_edge, other_in_edge = graph.in_edges(matmul_node)
+        if graph.get_edge_data(*input_in_edge, 0).get('sink_index') == 0:
+            graph.add_edge(input_in_edge[0], head_input_node, **graph.get_edge_data(*input_in_edge, 0))
+            graph.add_edge(other_in_edge[0], head_other_node, **graph.get_edge_data(*other_in_edge, 0))
+        else:
+            graph.add_edge(input_in_edge[0], head_other_node, **graph.get_edge_data(*input_in_edge, 0))
+            graph.add_edge(other_in_edge[0], head_input_node, **graph.get_edge_data(*other_in_edge, 0))
+        graph.remove_edge(input_in_edge[0], matmul_node)
+        graph.remove_edge(other_in_edge[0], matmul_node)
+        graph.reconnect_out_edges(current_node=matmul_node, new_node=output_node)
model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -52,6 +52,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.fu
     FunctionalLayerNorm
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_linear import \
     FunctionalLinear
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.matmul_decomposition import \
+    MatMulDecomposition
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
     pytorch_linear_collapsing
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -264,6 +266,7 @@ class PytorchImplementation(FrameworkImplementation):
         return [ReshapeWithStaticShapes(),
                 MultiHeadAttentionDecomposition(),
                 ScaledDotProductDecomposition(),
+                MatMulDecomposition(),
                 TransformFunctionCallMethod(),
                 FunctionalConvSubstitution(fw_info),
                 FunctionalBatchNorm(),
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py
@@ -17,19 +17,21 @@ from typing import Callable

 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 from model_compression_toolkit.verify_packages import FOUND_TORCH

 if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch

     from model_compression_toolkit import get_target_platform_capabilities

@@ -39,7 +41,7 @@ if FOUND_TORCH:
     def pytorch_resource_utilization_data(in_model: Module,
                                           representative_data_gen: Callable,
                                           core_config: CoreConfig = CoreConfig(),
-                                          target_platform_capabilities:
+                                          target_platform_capabilities: TargetPlatformModel= PYTORCH_DEFAULT_TPC
                                           ) -> ResourceUtilization:
         """
         Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
@@ -80,6 +82,12 @@ if FOUND_TORCH:

         fw_impl = PytorchImplementation()

+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = (
+            attach2pytorch.attach(target_platform_capabilities,
+                                  custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer))
+
         return compute_resource_utilization_data(in_model,
                                                  representative_data_gen,
                                                  core_config,
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -22,6 +22,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
     LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -47,6 +48,8 @@ if FOUND_TF:
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.keras.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras

     # As from TF2.9 optimizers package is changed
     if version.parse(tf.__version__) < version.parse("2.9"):
@@ -152,7 +155,7 @@ if FOUND_TF:
                                                   gptq_representative_data_gen: Callable = None,
                                                   target_resource_utilization: ResourceUtilization = None,
                                                   core_config: CoreConfig = CoreConfig(),
-                                                  target_platform_capabilities:
+                                                  target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -237,6 +240,12 @@ if FOUND_TF:

         fw_impl = GPTQKerasImplemantation()

+        # Attach tpc model to framework
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                    representative_data_gen=representative_data_gen,
                                                                                    core_config=core_config,
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -31,6 +31,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.metadata import create_model_metadata
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH

@@ -47,6 +48,9 @@ if FOUND_TORCH:
     from torch.optim import Adam, Optimizer
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.pytorch.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch
+
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

     def get_pytorch_gptq_config(n_epochs: int,
@@ -140,7 +144,7 @@ if FOUND_TORCH:
                                                      core_config: CoreConfig = CoreConfig(),
                                                      gptq_config: GradientPTQConfig = None,
                                                      gptq_representative_data_gen: Callable = None,
-                                                     target_platform_capabilities:
+                                                     target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -209,6 +213,11 @@ if FOUND_TORCH:

         fw_impl = GPTQPytorchImplemantation()

+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+                                                             core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # ---------------------- #
         # Core Runner
         # ---------------------- #