mct-nightly 2.2.0.20250105.534__py3-none-any.whl → 2.2.0.20250107.15510__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/RECORD +43 -78
  3. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/__init__.py +1 -1
  6. model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +1 -1
  7. model_compression_toolkit/core/common/graph/memory_graph/cut.py +5 -2
  8. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +25 -25
  9. model_compression_toolkit/core/common/quantization/quantization_config.py +19 -1
  10. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -33
  11. model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py +2 -2
  12. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +11 -1
  13. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/matmul_decomposition.py +499 -0
  14. model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
  15. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +11 -3
  16. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
  17. model_compression_toolkit/gptq/pytorch/quantization_facade.py +10 -1
  18. model_compression_toolkit/pruning/keras/pruning_facade.py +8 -2
  19. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -2
  20. model_compression_toolkit/ptq/keras/quantization_facade.py +10 -1
  21. model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -1
  22. model_compression_toolkit/qat/__init__.py +5 -2
  23. model_compression_toolkit/qat/keras/quantization_facade.py +9 -1
  24. model_compression_toolkit/qat/pytorch/quantization_facade.py +9 -1
  25. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
  26. model_compression_toolkit/target_platform_capabilities/schema/v1.py +63 -55
  27. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +29 -18
  28. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +78 -57
  29. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +69 -54
  30. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +2 -4
  31. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +0 -10
  32. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +93 -0
  33. model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py +46 -28
  34. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +6 -5
  35. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +51 -19
  36. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +8 -4
  37. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +19 -9
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +7 -4
  39. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +46 -32
  40. model_compression_toolkit/xquant/keras/keras_report_utils.py +11 -3
  41. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -2
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +0 -98
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +0 -129
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +0 -108
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/__init__.py +0 -16
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +0 -217
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +0 -130
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +0 -109
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/__init__.py +0 -16
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +0 -215
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +0 -130
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +0 -110
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/__init__.py +0 -16
  54. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +0 -222
  55. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +0 -132
  56. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +0 -110
  57. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/__init__.py +0 -16
  58. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +0 -219
  59. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +0 -132
  60. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +0 -109
  61. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +0 -16
  62. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +0 -246
  63. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +0 -135
  64. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +0 -113
  65. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +0 -16
  66. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +0 -230
  67. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +0 -132
  68. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +0 -110
  69. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/__init__.py +0 -16
  70. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +0 -332
  71. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +0 -140
  72. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +0 -122
  73. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +0 -55
  74. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +0 -89
  75. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +0 -78
  76. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +0 -55
  77. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +0 -118
  78. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +0 -100
  79. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/LICENSE.md +0 -0
  80. {mct_nightly-2.2.0.20250105.534.dist-info → mct_nightly-2.2.0.20250107.15510.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ from model_compression_toolkit.core import MixedPrecisionQuantizationConfig, Cor
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -27,6 +28,8 @@ if FOUND_TF:
     from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from tensorflow.keras.models import Model
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
 
     from model_compression_toolkit import get_target_platform_capabilities
 
@@ -36,7 +39,8 @@ if FOUND_TF:
                                         representative_data_gen: Callable,
                                         core_config: CoreConfig = CoreConfig(
                                             mixed_precision_config=MixedPrecisionQuantizationConfig()),
-                                        target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> ResourceUtilization:
+                                        target_platform_capabilities: TargetPlatformModel = KERAS_DEFAULT_TPC
+                                        ) -> ResourceUtilization:
         """
         Computes resource utilization data that can be used to calculate the desired target resource utilization
         for mixed-precision quantization.
@@ -78,6 +82,12 @@ if FOUND_TF:
 
         fw_impl = KerasImplementation()
 
+        # Attach tpc model to framework
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         return compute_resource_utilization_data(in_model,
                                                  representative_data_gen,
                                                  core_config,
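
Note (not part of the diff): a minimal usage sketch of the updated Keras facade under the new signature. It assumes the public entry point mct.core.keras_resource_utilization_data and a toy model/dataset; per the hunks above, callers keep passing a schema-level TargetPlatformModel (defaulting to KERAS_DEFAULT_TPC), and the facade now attaches it to Keras internally via AttachTpcToKeras.

import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                             tf.keras.layers.ReLU()])

def representative_data_gen():
    # One random calibration batch; a real calibration set would be used in practice.
    yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

# target_platform_capabilities is left at its default TargetPlatformModel here.
ru = mct.core.keras_resource_utilization_data(model, representative_data_gen)
print(ru)
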
@@ -0,0 +1,499 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import operator
+from typing import List
+
+import numpy as np
+import torch
+
+from model_compression_toolkit.core.common.graph.base_graph import OutTensor
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.pytorch.constants import *
+from model_compression_toolkit.logger import Logger
+
+
+class MatMulParams:
+    """
+    A data class to hold all relevant parameter shapes and nodes for MatMul decomposition.
+    """
+
+    def __init__(self,
+                 matmul_node: FunctionalNode):
+        """
+        Extract params for all the substitution nodes from original matmul node.
+        Args:
+            matmul_node: original MatMul Node
+
+        Naming convention:
+            * First parameter - input
+            * Second parameter - other
+        """
+        self.head_input_node, self.head_other_node = None, None
+        self.prev_input_node, self.prev_other_node = None, None
+
+        self.input_shape, self.other_shape = matmul_node.input_shape
+
+        # Step 1 - Expand
+        expand_shape = np.max(np.vstack((self.input_shape[1:-2], self.other_shape[1:-2])), axis=0).tolist()
+        self.input_expand_shape = tuple([-1] + expand_shape + list(self.input_shape[-2:]))
+        self.other_expand_shape = tuple([-1] + expand_shape + list(self.other_shape[-2:]))
+
+        # Step 2 - Reshape
+        # (B, D_1, ... , D_N, m, p) --> (B, (D_1*...*D_N), m, p)
+        self.input_reshape_shape = [
+            -1,
+            int(np.prod(self.input_expand_shape[1:-2])),
+            self.input_expand_shape[-2],
+            self.input_expand_shape[-1]
+        ]
+        # (B, D_1, ... , D_N, p, n) --> (B, (D_1*...*D_N), p, n)
+        self.other_reshape_shape = [
+            -1,
+            int(np.prod(self.other_expand_shape[1:-2])),
+            self.other_expand_shape[-2],
+            self.other_expand_shape[-1]
+        ]
+
+        # Step 3 - Split
+        # (B, (D_1*...*D_N), m, p) --> [(B, m, p)] * (D_1*...*D_N)
+        self.input_matmul_shape = [-1] + self.input_reshape_shape[-2:]
+        self.input_split_shape = tuple([self.input_matmul_shape] * self.input_reshape_shape[1])
+        # (B, (D_1*...*D_N), p, n) --> [(B, p, n)] * (D_1*...*D_N)
+        self.other_matmul_shape = [-1] + self.other_reshape_shape[-2:]
+        self.other_split_shape = tuple([self.other_matmul_shape] * self.other_reshape_shape[1])
+
+        # Step 4 - Matmul loop
+        # [(B, m, p)] * (D_1*...*D_N) X [(B, p, n)] * (D_1*...*D_N) --> [(B, m, n)] * (D_1*...*D_N)
+        self.single_matmul_shape = self.input_matmul_shape[:-1] + [self.other_matmul_shape[-1]]
+
+        # Step 5 - Stack and Reshape all matmul outputs to original dimensions
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        self.matmul_stack_shape = tuple([-1] + [self.input_reshape_shape[1]] + self.single_matmul_shape[1:])
+        # (B, (D_1*...*D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+        self.output_shape = tuple(list(self.input_expand_shape)[:-1] + [self.matmul_stack_shape[-1]])
+
+    def update_nodes(self,
+                     input_node: FunctionalNode,
+                     other_node: FunctionalNode):
+        """
+        Updates the head and prev nodes to support the option of skipping unnecessary operations.
+        Args:
+            input_node: node that operates on the input branch
+            other_node: node that operates on the other branch
+        """
+        if not self.head_input_node:
+            self.head_input_node = input_node
+        if not self.head_other_node:
+            self.head_other_node = other_node
+        self.prev_input_node = input_node
+        self.prev_other_node = other_node
+
+
+class MatMulDecomposition(BaseSubstitution):
+    """
+    Removes A MatMul node from the graph if one of its inputs has >3 dimensions and replaces it with unbind, matmul
+    and stack nodes. Substitution is done inplace.
+
+    Naming convention:
+        * First parameter - input
+        * Second parameter - other
+    """
+
+    def __init__(self):
+        """
+        Matches: torch matmul or matmul operator.
+        """
+        func_node = NodeOperationMatcher(torch.matmul) | NodeOperationMatcher(operator.matmul)
+        super().__init__(matcher_instance=func_node)
+
+    def substitute(self,
+                   graph: Graph,
+                   matmul_node: FunctionalNode) -> Graph:
+        """
+        Decompose matmul of matrices with >3 dimensions to multiple matmuls and reconstruct graph.
+        Args:
+            graph: Graph we apply the substitution on.
+            matmul_node: MatMul node to substitute
+        Returns:
+            A graph after applying the substitution.
+        """
+
+        # If both matrices are already 3D or less, no need to change the graph
+        if len(matmul_node.input_shape[0]) <= 3 and len(matmul_node.input_shape[1]) <= 3:
+            return graph
+
+        if len(matmul_node.input_shape[0]) != len(matmul_node.input_shape[1]):
+            Logger.critical(f'Mismatch between number of input dimensions in node {matmul_node.name}.')
+
+        matmul_params = MatMulParams(matmul_node)
+
+        # Expand inputs to equal dimensions (instead of broadcasting) - if needed
+        if not np.array_equal(matmul_params.input_shape[1:-2], matmul_params.other_shape[1:-2]):
+            input_expand_node, other_expand_node = self._expand_inputs(
+                graph,
+                matmul_node,
+                matmul_params
+            )
+            matmul_params.update_nodes(input_node=input_expand_node, other_node=other_expand_node)
+
+        # Reshape inputs - if needed
+        # (B, D_1, ... , D_N, m, p) --> (B, (D_1*...*D_N), m, p)
+        # (B, D_1, ... , D_N, p, n) --> (B, (D_1*...*D_N), p, n)
+        if len(matmul_params.input_shape) > 4:  # both input & other have the same number of dimensions
+            input_reshape_node, other_reshape_node = self._reshape_input(
+                graph,
+                matmul_node,
+                matmul_params
+            )
+            matmul_params.update_nodes(input_node=input_reshape_node, other_node=other_reshape_node)
+
+        # Split inputs
+        # (B, (D_1*...*D_N), m, p) --> [(B, m, p)] * (D_1*...*D_N)
+        # (B, (D_1*...*D_N), p, n) --> [(B, p, n)] * (D_1*...*D_N)
+        input_split_node, other_split_node = self._split_inputs(
+            graph,
+            matmul_node,
+            matmul_params
+        )
+        matmul_params.update_nodes(input_node=input_split_node, other_node=other_split_node)
+
+        # Matmul each pair
+        # [(B, m, p)] * (D_1*...*D_N) X [(B, p, n)] * (D_1*...*D_N) --> [(B, m, n)] * (D_1*...*D_N)
+        split_matmul_nodes = []
+        for idx in range(matmul_params.input_reshape_shape[1]):
+            split_matmul_node = self._calc_single_matmul(
+                graph,
+                matmul_node,
+                input_split_node,
+                other_split_node,
+                idx,
+                matmul_params
+            )
+            split_matmul_nodes.append(split_matmul_node)
+
+        # Stack and reshape all results - reshape if needed
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        # (B, (D_1*...*D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+        output_node = self._stack_matmul_outputs(
+            graph,
+            matmul_node,
+            split_matmul_nodes,
+            matmul_params
+        )
+
+        # connect edges to new nodes
+        self._connect_to_graph(
+            graph,
+            matmul_node,
+            matmul_params.head_input_node,
+            matmul_params.head_other_node,
+            output_node
+        )
+
+        # remove the original matmul node
+        graph.remove_node(matmul_node, new_graph_outputs=[OutTensor(output_node, 0)])
+
+        return graph
+
+    @staticmethod
+    def _expand_inputs(graph: Graph,
+                       matmul_node: FunctionalNode,
+                       params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that expand the inputs such that the dimensions fit the MatMul process.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other expand nodes.
+        """
+        if params.input_shape[1:] != list(params.input_expand_shape[1:]):
+            input_expand_node = FunctionalNode(
+                name=f'{matmul_node.name}_input_expand',
+                framework_attr={},
+                input_shape=params.input_shape,
+                output_shape=params.input_expand_shape,
+                weights={},
+                layer_class=torch.broadcast_to,
+                op_call_args=[params.input_expand_shape],
+                op_call_kwargs={},
+                functional_op=torch.broadcast_to
+            )
+            graph.add_node(input_expand_node)
+        else:
+            input_expand_node = None
+
+        if params.other_shape[1:] != list(params.other_expand_shape[1:]):
+            other_expand_node = FunctionalNode(
+                name=f'{matmul_node.name}_other_expand',
+                framework_attr={},
+                input_shape=params.other_shape,
+                output_shape=params.other_expand_shape,
+                weights={},
+                layer_class=torch.broadcast_to,
+                op_call_args=[params.other_expand_shape],
+                op_call_kwargs={},
+                functional_op=torch.broadcast_to
+            )
+            graph.add_node(other_expand_node)
+        else:
+            other_expand_node = None
+
+        return [input_expand_node, other_expand_node]
+
+    @staticmethod
+    def _reshape_input(graph: Graph,
+                       matmul_node: FunctionalNode,
+                       params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that reshape the input nodes to be 4D before the split.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other reshape nodes.
+        """
+        input_reshape_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_reshape',
+            framework_attr={},
+            input_shape=params.input_expand_shape,
+            output_shape=params.input_reshape_shape,
+            weights={},
+            layer_class=torch.reshape,
+            op_call_args=[params.input_reshape_shape],
+            op_call_kwargs={},
+            functional_op=torch.reshape
+        )
+        other_reshape_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_reshape',
+            framework_attr={},
+            input_shape=params.other_expand_shape,
+            output_shape=params.other_reshape_shape,
+            weights={},
+            layer_class=torch.reshape,
+            op_call_args=[params.other_reshape_shape],
+            op_call_kwargs={},
+            functional_op=torch.reshape
+        )
+        # Add reshapes to graph
+        if params.prev_input_node:
+            graph.add_node_with_in_edges(input_reshape_node, [params.prev_input_node])
+        else:
+            graph.add_node(input_reshape_node)
+
+        if params.prev_other_node:
+            graph.add_node_with_in_edges(other_reshape_node, [params.prev_other_node])
+        else:
+            graph.add_node(other_reshape_node)
+
+        return [input_reshape_node, other_reshape_node]
+
+    @staticmethod
+    def _split_inputs(graph: Graph,
+                      matmul_node: FunctionalNode,
+                      params: MatMulParams) -> List[FunctionalNode]:
+        """
+        This method creates the nodes that split the parameters from 4D to 3D for single MatMul operations.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: MatMul node.
+            params: MatMul shape params.
+
+        Returns:
+            Input & Other unbind nodes - output of each is list of 3D matrices
+        """
+        input_split_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_split',
+            framework_attr={},
+            input_shape=params.input_reshape_shape,
+            output_shape=params.input_split_shape,
+            weights={},
+            layer_class=torch.unbind,
+            op_call_args=[1],
+            op_call_kwargs={},
+            functional_op=torch.unbind
+        )
+
+        other_split_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_split',
+            framework_attr={},
+            input_shape=params.other_reshape_shape,
+            output_shape=params.other_split_shape,
+            weights={},
+            layer_class=torch.unbind,
+            op_call_args=[1],
+            op_call_kwargs={},
+            functional_op=torch.unbind
+        )
+
+        if params.prev_input_node:
+            graph.add_node_with_in_edges(input_split_node, [params.prev_input_node])
+        else:
+            graph.add_node(input_split_node)
+        if params.prev_other_node:
+            graph.add_node_with_in_edges(other_split_node, [params.prev_other_node])
+        else:
+            graph.add_node(other_split_node)
+
+        return [input_split_node, other_split_node]
+
+    @staticmethod
+    def _calc_single_matmul(graph: Graph,
+                            matmul_node: FunctionalNode,
+                            input_split_node: FunctionalNode,
+                            other_split_node: FunctionalNode,
+                            dim_index: int,
+                            params: MatMulParams) -> FunctionalNode:
+        """
+        This method creates the per channel (index) matmul.
+        Retrieves the matrices from index dim_index and multiplies them.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: Original Matmul node
+            input_split_node: input after reshape and split.
+            other_split_node: other after reshape and split.
+            dim_index: index to run matmul on
+            params: MatMul Params
+
+        Returns:
+            Node after matmul of single dimension
+        """
+        # (B, m, n) X (B, n, p) -> (B, m, p)
+        # Get the input in index dim_index
+        get_input_node = FunctionalNode(
+            name=f'{matmul_node.name}_input_{dim_index}',
+            framework_attr={},
+            input_shape=params.input_split_shape,
+            output_shape=params.input_matmul_shape,
+            weights={},
+            layer_class=operator.getitem,
+            op_call_args=[dim_index],
+            op_call_kwargs={},
+            functional_op=operator.getitem
+        )
+        graph.add_node_with_in_edges(get_input_node, [input_split_node], [dim_index])
+        # Get the other in index dim_index
+        get_other_node = FunctionalNode(
+            name=f'{matmul_node.name}_other_{dim_index}',
+            framework_attr={},
+            input_shape=params.other_split_shape,
+            output_shape=params.other_matmul_shape,
+            weights={},
+            layer_class=operator.getitem,
+            op_call_args=[dim_index],
+            op_call_kwargs={},
+            functional_op=operator.getitem
+        )
+        graph.add_node_with_in_edges(get_other_node, [other_split_node], [dim_index])
+
+        matmul_node = FunctionalNode(name=f'{matmul_node.name}_matmul_{dim_index}',
+                                     framework_attr={},
+                                     input_shape=(params.input_matmul_shape, params.other_matmul_shape),
+                                     output_shape=[params.single_matmul_shape],
+                                     weights={},
+                                     layer_class=torch.matmul,
+                                     op_call_args=[],
+                                     op_call_kwargs={},
+                                     functional_op=torch.matmul)
+        graph.add_node_with_in_edges(matmul_node, [get_input_node, get_other_node])
+
+        return matmul_node
+
+    @staticmethod
+    def _stack_matmul_outputs(graph: Graph,
+                              matmul_node: FunctionalNode,
+                              split_matmul_nodes: List[FunctionalNode],
+                              params: MatMulParams) -> FunctionalNode:
+        """
+        This method creates the node that concats all single matmuls together and then reshapes to the original output
+        shape.
+
+        Args:
+            graph: Graph to apply the substitution on.
+            matmul_node: Original Matmul node
+            split_matmul_nodes: list of all single matmul nodes.
+            params: MatMul Params
+
+        Returns:
+            Node after reshape - final output
+        """
+        # [(B, m, n)] * (D_1*...*D_N) --> (B, (D_1*...*D_N), m, n)
+        cat_node = FunctionalNode(
+            name=f'{matmul_node.name}_stack',
+            framework_attr={DIM: 1},
+            input_shape=[params.single_matmul_shape] * params.matmul_stack_shape[1],
+            output_shape=params.matmul_stack_shape,
+            weights={},
+            layer_class=torch.stack,
+            op_call_args=[],
+            op_call_kwargs={DIM: 1},
+            functional_op=torch.stack,
+            inputs_as_list=True
+        )
+        graph.add_node_with_in_edges(cat_node, split_matmul_nodes)
+
+        if params.matmul_stack_shape != params.output_shape:
+            # (B, (D_1 * ... * D_N), m, n) --> (B, D_1, ..., D_N, m, n)
+            matmul_reshape_node = FunctionalNode(
+                name=f'{matmul_node.name}_reshape',
+                framework_attr={},
+                input_shape=params.matmul_stack_shape,
+                output_shape=params.output_shape,
+                weights={},
+                layer_class=torch.reshape,
+                op_call_args=[params.output_shape],
+                op_call_kwargs={},
+                functional_op=torch.reshape
+            )
+            graph.add_node_with_in_edges(matmul_reshape_node, [cat_node])
+
+        return matmul_reshape_node if params.matmul_stack_shape != params.output_shape else cat_node
+
+    @staticmethod
+    def _connect_to_graph(
+            graph: Graph,
+            matmul_node: FunctionalNode,
+            head_input_node: FunctionalNode,
+            head_other_node: FunctionalNode,
+            output_node: FunctionalNode):
+        """
+        Connect the subgraph to the input graph.
+        Args:
+            graph: input graph
+            matmul_node: MatMul node to substitute inputs and outputs with
+            head_input_node: 1st input to MatMul Node
+            head_other_node: 2nd input to MatMul Node
+            output_node: output node of decomposed MatMul.
+        """
+        input_in_edge, other_in_edge = graph.in_edges(matmul_node)
+        if graph.get_edge_data(*input_in_edge, 0).get('sink_index') == 0:
+            graph.add_edge(input_in_edge[0], head_input_node, **graph.get_edge_data(*input_in_edge, 0))
+            graph.add_edge(other_in_edge[0], head_other_node, **graph.get_edge_data(*other_in_edge, 0))
+        else:
+            graph.add_edge(input_in_edge[0], head_other_node, **graph.get_edge_data(*input_in_edge, 0))
+            graph.add_edge(other_in_edge[0], head_input_node, **graph.get_edge_data(*other_in_edge, 0))
+        graph.remove_edge(input_in_edge[0], matmul_node)
+        graph.remove_edge(other_in_edge[0], matmul_node)
+        graph.reconnect_out_edges(current_node=matmul_node, new_node=output_node)
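
Note (not part of the diff): a small plain-PyTorch sketch of the transformation this substitution encodes, for two 4-D operands of shape (B, D, m, p) and (B, D, p, n). Shapes and names are illustrative only; the real pass builds the equivalent unbind/matmul/stack subgraph as FunctionalNodes.

import torch

B, D, m, p, n = 2, 3, 4, 5, 6
a = torch.randn(B, D, m, p)
b = torch.randn(B, D, p, n)

# Split on the folded dimension, matmul each 3-D pair, stack the results back.
a_slices = torch.unbind(a, dim=1)   # D tensors of shape (B, m, p)
b_slices = torch.unbind(b, dim=1)   # D tensors of shape (B, p, n)
outs = [torch.matmul(x, y) for x, y in zip(a_slices, b_slices)]
stacked = torch.stack(outs, dim=1)  # (B, D, m, n)

assert torch.allclose(stacked, torch.matmul(a, b), atol=1e-6)
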
@@ -52,6 +52,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.fu
     FunctionalLayerNorm
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_linear import \
     FunctionalLinear
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.matmul_decomposition import \
+    MatMulDecomposition
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \
     pytorch_linear_collapsing
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
@@ -264,6 +266,7 @@ class PytorchImplementation(FrameworkImplementation):
         return [ReshapeWithStaticShapes(),
                 MultiHeadAttentionDecomposition(),
                 ScaledDotProductDecomposition(),
+                MatMulDecomposition(),
                 TransformFunctionCallMethod(),
                 FunctionalConvSubstitution(fw_info),
                 FunctionalBatchNorm(),
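
Note (not part of the diff): the kind of call the newly registered substitution targets - a torch.matmul (or @) whose operands have more than 3 dimensions, e.g. attention-style score computation. The module below is a toy example, not MCT code.

import torch

class AttentionScores(torch.nn.Module):
    def forward(self, q, k):
        # q: (B, heads, tokens, dim), k: (B, heads, dim, tokens) -> 4-D matmul
        return torch.matmul(q, k)

scores = AttentionScores()(torch.randn(1, 4, 16, 8), torch.randn(1, 4, 8, 16))
print(scores.shape)  # torch.Size([1, 4, 16, 16])
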
@@ -17,19 +17,21 @@ from typing import Callable
 
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
-    from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch
 
     from model_compression_toolkit import get_target_platform_capabilities
 
@@ -39,7 +41,7 @@ if FOUND_TORCH:
     def pytorch_resource_utilization_data(in_model: Module,
                                           representative_data_gen: Callable,
                                           core_config: CoreConfig = CoreConfig(),
-                                          target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC
+                                          target_platform_capabilities: TargetPlatformModel= PYTORCH_DEFAULT_TPC
                                           ) -> ResourceUtilization:
         """
         Computes resource utilization data that can be used to calculate the desired target resource utilization for mixed-precision quantization.
@@ -80,6 +82,12 @@ if FOUND_TORCH:
 
         fw_impl = PytorchImplementation()
 
+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = (
+            attach2pytorch.attach(target_platform_capabilities,
+                                  custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer))
+
         return compute_resource_utilization_data(in_model,
                                                  representative_data_gen,
                                                  core_config,
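
Note (not part of the diff): the PyTorch counterpart of the Keras sketch earlier, assuming the public entry points mct.core.pytorch_resource_utilization_data and mct.get_target_platform_capabilities. Here the TPC is passed explicitly; per the hunks above it is a schema-level TargetPlatformModel that the facade attaches to PyTorch via AttachTpcToPytorch.

import numpy as np
import torch
import model_compression_toolkit as mct
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def representative_data_gen():
    yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

tp_model = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
ru = mct.core.pytorch_resource_utilization_data(model, representative_data_gen,
                                                target_platform_capabilities=tp_model)
print(ru)
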
@@ -22,6 +22,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
     LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.verify_packages import FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -47,6 +48,8 @@ if FOUND_TF:
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.keras.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2keras import \
+        AttachTpcToKeras
 
     # As from TF2.9 optimizers package is changed
     if version.parse(tf.__version__) < version.parse("2.9"):
@@ -152,7 +155,7 @@ if FOUND_TF:
                                                   gptq_representative_data_gen: Callable = None,
                                                   target_resource_utilization: ResourceUtilization = None,
                                                   core_config: CoreConfig = CoreConfig(),
-                                                  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
+                                                  target_platform_capabilities: TargetPlatformModel = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -237,6 +240,12 @@ if FOUND_TF:
 
         fw_impl = GPTQKerasImplemantation()
 
+        # Attach tpc model to framework
+        attach2keras = AttachTpcToKeras()
+        target_platform_capabilities = attach2keras.attach(
+            target_platform_capabilities,
+            custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
+
         tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                    representative_data_gen=representative_data_gen,
                                                                                    core_config=core_config,
@@ -31,6 +31,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR
 from model_compression_toolkit.gptq.runner import gptq_runner
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.metadata import create_model_metadata
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.verify_packages import FOUND_TORCH
 
@@ -47,6 +48,9 @@ if FOUND_TORCH:
     from torch.optim import Adam, Optimizer
     from model_compression_toolkit import get_target_platform_capabilities
     from mct_quantizers.pytorch.metadata import add_metadata
+    from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2pytorch import \
+        AttachTpcToPytorch
+
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
     def get_pytorch_gptq_config(n_epochs: int,
@@ -140,7 +144,7 @@ if FOUND_TORCH:
                                                     core_config: CoreConfig = CoreConfig(),
                                                     gptq_config: GradientPTQConfig = None,
                                                     gptq_representative_data_gen: Callable = None,
-                                                    target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
+                                                    target_platform_capabilities: TargetPlatformModel = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -209,6 +213,11 @@ if FOUND_TORCH:
 
         fw_impl = GPTQPytorchImplemantation()
 
+        # Attach tpc model to framework
+        attach2pytorch = AttachTpcToPytorch()
+        target_platform_capabilities = attach2pytorch.attach(target_platform_capabilities,
+                                                             core_config.quantization_config.custom_tpc_opset_to_layer)
+
         # ---------------------- #
         # Core Runner
         # ---------------------- #
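
Note (not part of the diff): a minimal GPTQ usage sketch for the PyTorch facade shown above. get_pytorch_gptq_config appears in the hunk; the entry point name mct.gptq.pytorch_gradient_post_training_quantization and the (quantized module, quantization info) return pair mirror the Keras facade's Tuple[Model, UserInformation] and are assumptions here. The default TPC (a TargetPlatformModel) is attached to PyTorch inside the facade, as the last hunk shows.

import numpy as np
import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def representative_data_gen():
    yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model,
    representative_data_gen,
    gptq_config=gptq_config)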