mct-nightly 1.10.0.20231204.post420__py3-none-any.whl → 1.10.0.20231206.post417__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.10.0.20231204.post420.dist-info → mct_nightly-1.10.0.20231206.post417.dist-info}/METADATA +1 -1
- {mct_nightly-1.10.0.20231204.post420.dist-info → mct_nightly-1.10.0.20231206.post417.dist-info}/RECORD +19 -18
- model_compression_toolkit/core/common/framework_implementation.py +8 -0
- model_compression_toolkit/core/common/graph/base_node.py +10 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_folding.py +2 -2
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +1 -1
- model_compression_toolkit/core/common/substitutions/linear_collapsing.py +82 -5
- model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py +3 -0
- model_compression_toolkit/core/common/substitutions/residual_collapsing.py +1 -3
- model_compression_toolkit/core/graph_prep_runner.py +1 -0
- model_compression_toolkit/core/keras/constants.py +2 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py +72 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py +108 -0
- model_compression_toolkit/core/keras/keras_implementation.py +10 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -0
- {mct_nightly-1.10.0.20231204.post420.dist-info → mct_nightly-1.10.0.20231206.post417.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.10.0.20231204.post420.dist-info → mct_nightly-1.10.0.20231206.post417.dist-info}/WHEEL +0 -0
- {mct_nightly-1.10.0.20231204.post420.dist-info → mct_nightly-1.10.0.20231206.post417.dist-info}/top_level.txt +0 -0
|
@@ -4,14 +4,14 @@ model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6
|
|
|
4
4
|
model_compression_toolkit/core/__init__.py,sha256=qnBA6aaojI7RpEQZU2vXWiELHfVJf-MnAP-4T0tcFDY,2008
|
|
5
5
|
model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
|
|
6
6
|
model_compression_toolkit/core/exporter.py,sha256=U_-ea-zYHsnIt2ydameMLZ_gzDaCMI1dRa5IjA8RUuc,4233
|
|
7
|
-
model_compression_toolkit/core/graph_prep_runner.py,sha256=
|
|
7
|
+
model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
|
|
8
8
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=npv55-QsJFR7bnbHj4tBMf13Y18Ns7QGa-UDSI6WJRE,6554
|
|
9
9
|
model_compression_toolkit/core/runner.py,sha256=Cb8_TWAOBz4SO1O48ehxqC9PpaR4KifbCs0nV724zMM,10454
|
|
10
10
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
|
11
11
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
|
12
12
|
model_compression_toolkit/core/common/data_loader.py,sha256=7YF5Mqz64Xb4rVwY3knrdIZ4JEHybXxiQqx0deR_c5k,4017
|
|
13
13
|
model_compression_toolkit/core/common/defaultdict.py,sha256=P2WOZbWQTfVKtMfpGhGOS_1_5YWfYQWiJ5pBCn6F-3k,2182
|
|
14
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
|
14
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=3oFMtvGkUKPtNxAAiXISmNM8XyccR3DyFQbOioBE4b4,21094
|
|
15
15
|
model_compression_toolkit/core/common/framework_info.py,sha256=hwmstv7IuBRfa6IxDbeG4y-7AxKx4bwCyI_Exi2C7mo,6424
|
|
16
16
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
|
17
17
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
|
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
|
33
33
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=tIsWFYc771o59uvq5fxAaBmOCnd_gd-_xMbQI9SupQA,5479
|
|
34
34
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
|
35
35
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=-YBWWl3ZS7FJvZldGqT4SetlKI8j60f4sS0rYyFQpbI,30059
|
|
36
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
|
36
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=csIgi5ex7EquQsF34w5waRIHzbg7XitvIqQgCC29azs,21118
|
|
37
37
|
model_compression_toolkit/core/common/graph/edge.py,sha256=K6Wc2hBcIqig5PbbLhbjtTgYtkyZEohfgj4Wn_J5yEA,3733
|
|
38
38
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=0TpYNa2ODZ0M9lQ2z_GsStqAbrg1Muwdni74LjphAh0,2922
|
|
39
39
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=kQ14uXW6ecsj7IarjRLAXUzDBmakD_v6Ck7-u24_nxg,4732
|
|
@@ -118,12 +118,12 @@ model_compression_toolkit/core/common/statistics_correction/compute_bias_correct
|
|
|
118
118
|
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=KFWY8jERabXwKm-qzQFc2V7v-fM1dqOlwRaOQ8UIiQA,5584
|
|
119
119
|
model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
120
120
|
model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
|
|
121
|
-
model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=
|
|
122
|
-
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=
|
|
123
|
-
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=
|
|
124
|
-
model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=
|
|
125
|
-
model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=
|
|
126
|
-
model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=
|
|
121
|
+
model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
|
|
122
|
+
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=lYUZobQKydFyE3nRS-CBsYF3r4YlFirLp3-EmYa9qHM,5859
|
|
123
|
+
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=eTDmac2OvqQgJMCg_dkCGFCmkvvO6mdYjsBui9HLymY,9929
|
|
124
|
+
model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
|
|
125
|
+
model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
|
|
126
|
+
model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=doErjlMq-uSObYMSjA6IywSHb3Hz3QCc0HKU68ccrQ4,4767
|
|
127
127
|
model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=nmb5QC_YiQJRbsEIq6uF50y1IRWhmRAUKaeUE9hnoNw,10978
|
|
128
128
|
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=QbT6LMt4Eit4i1bLGIizHhE6R_tLeJf2Ix2qVod2bcw,28749
|
|
129
129
|
model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
|
|
@@ -134,10 +134,10 @@ model_compression_toolkit/core/common/visualization/final_config_visualizer.py,s
|
|
|
134
134
|
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=6EjZj_KE1tICTQ0XSKIx5ivsRFpRktFywda7pW7YnNQ,5955
|
|
135
135
|
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=954742gUTrrKmcVjcuBJaKR-EfMMsrWZ7PXd07unA6E,21939
|
|
136
136
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
137
|
-
model_compression_toolkit/core/keras/constants.py,sha256=
|
|
137
|
+
model_compression_toolkit/core/keras/constants.py,sha256=OVa9yHaIlTKU4WatwTw_1dANk1-7ocQxDCluQwnwGy0,3094
|
|
138
138
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
139
139
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=cMdt9KvJMqOmWjFtUiEejzOe77mCpnnd3GzERgNh8Zk,4970
|
|
140
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
140
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=NYv0gHvv7wHs7grpTjh2SnBgu5OUb5r-fUBpQgb0PT4,28835
|
|
141
141
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
142
142
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=f6o5Fmpw0aDrO704_A-SqBrKSO1iNEOyofP9pm3g8yg,3936
|
|
143
143
|
model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=rArrfMtxWGR1P4nhKKxqh6fo7pauRDzkRsZIh_SXxO4,8502
|
|
@@ -157,7 +157,8 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
|
|
|
157
157
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
|
|
158
158
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
|
|
159
159
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=Yj59BMBrITJnXJHH-7de91LJwH_1l1WhY1udSQjdoi4,5598
|
|
160
|
-
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=
|
|
160
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=mNLcAjSYzht_-mKh_fdBs4H4YYcQSLBJBFr_k1owF3s,8473
|
|
161
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=QzzKXC_WhojIjIjpqeHxI171DKXcZMdr0hNcf_78o-s,4523
|
|
161
162
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=aAG5wrcnnydn1pPYqvH56LWsQXjSODbsoNbX_jtQGP4,26759
|
|
162
163
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=IdKOg6AWZWMcmDbOuNdxetS5_zTarXIIffdYL7JTdvk,3872
|
|
163
164
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py,sha256=cJQTDzTDQKAJ7EQ20tfsmReGA_OoTIN793MwVe1Ok8g,2387
|
|
@@ -196,7 +197,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
|
196
197
|
model_compression_toolkit/core/pytorch/constants.py,sha256=Kt_GDwe3yX9oMS1DI2eXYuUT25_lpjeCkxpstsAiXCI,2472
|
|
197
198
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=qee8TFcDro2lfyXe_fujjX2OlxELTyKSsLlZ7QkzeXU,4200
|
|
198
199
|
model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=J0IDOtFMVFSFyBXDzNGbwJfHu89iRBJFdid1_wFB-xQ,8482
|
|
199
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
200
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=oTAd6_XYtyvTX2fRXx0BzajvgqbYreXGKD7ij8iL2SY,26482
|
|
200
201
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
|
|
201
202
|
model_compression_toolkit/core/pytorch/utils.py,sha256=rBQMAbWluyIMjVfeghzq6FZv3sR_khszSRpbWvwussw,2959
|
|
202
203
|
model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
|
|
@@ -448,8 +449,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
448
449
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
449
450
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
450
451
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
|
|
451
|
-
mct_nightly-1.10.0.
|
|
452
|
-
mct_nightly-1.10.0.
|
|
453
|
-
mct_nightly-1.10.0.
|
|
454
|
-
mct_nightly-1.10.0.
|
|
455
|
-
mct_nightly-1.10.0.
|
|
452
|
+
mct_nightly-1.10.0.20231206.post417.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
453
|
+
mct_nightly-1.10.0.20231206.post417.dist-info/METADATA,sha256=kcsyt4UEdtqaUF2t42UBYPuB1oav069aO3eg3MjBrII,16232
|
|
454
|
+
mct_nightly-1.10.0.20231206.post417.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
455
|
+
mct_nightly-1.10.0.20231206.post417.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
456
|
+
mct_nightly-1.10.0.20231206.post417.dist-info/RECORD,,
|
|
@@ -235,6 +235,14 @@ class FrameworkImplementation(ABC):
|
|
|
235
235
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
236
236
|
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover
|
|
237
237
|
|
|
238
|
+
@abstractmethod
|
|
239
|
+
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
|
|
240
|
+
"""
|
|
241
|
+
Returns: conv2d add const collapsing substitution
|
|
242
|
+
"""
|
|
243
|
+
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
244
|
+
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover
|
|
245
|
+
|
|
238
246
|
@abstractmethod
|
|
239
247
|
def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
|
|
240
248
|
List[common.BaseSubstitution]:
|
|
@@ -79,7 +79,8 @@ class BaseNode:
|
|
|
79
79
|
def type(self):
|
|
80
80
|
"""
|
|
81
81
|
A function to get the node's layer_class op for convenient comparison
|
|
82
|
-
:
|
|
82
|
+
Returns:
|
|
83
|
+
the node's layer_class
|
|
83
84
|
"""
|
|
84
85
|
return self.layer_class
|
|
85
86
|
|
|
@@ -130,6 +131,14 @@ class BaseNode:
|
|
|
130
131
|
"""
|
|
131
132
|
return f'{self.type.__name__}:{self.name}'
|
|
132
133
|
|
|
134
|
+
def is_reused(self) -> bool:
|
|
135
|
+
"""
|
|
136
|
+
Check whether the node is reused or not
|
|
137
|
+
Returns:
|
|
138
|
+
True if node is reused, else False
|
|
139
|
+
"""
|
|
140
|
+
return self.reuse or self.reuse_group is not None
|
|
141
|
+
|
|
133
142
|
def get_weights_by_keys(self, name: str) -> np.ndarray:
|
|
134
143
|
"""
|
|
135
144
|
Get a node's weight by its name.
|
|
@@ -93,7 +93,7 @@ class BatchNormalizationFolding(common.BaseSubstitution):
|
|
|
93
93
|
|
|
94
94
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
95
95
|
# we should skip the substitution.
|
|
96
|
-
if conv_node.
|
|
96
|
+
if conv_node.is_reused():
|
|
97
97
|
return graph
|
|
98
98
|
|
|
99
99
|
bn_node = edge_nodes[1]
|
|
@@ -230,7 +230,7 @@ class BatchNormalizationForwardFolding(common.BaseSubstitution):
|
|
|
230
230
|
|
|
231
231
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
232
232
|
# we should skip the substitution.
|
|
233
|
-
if conv_node.
|
|
233
|
+
if conv_node.is_reused() or bn_node.is_reused():
|
|
234
234
|
return graph
|
|
235
235
|
|
|
236
236
|
if len(graph.get_next_nodes(bn_node)) > 1 or len(graph.get_prev_nodes(conv_node)) > 1:
|
|
@@ -79,7 +79,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
|
|
|
79
79
|
|
|
80
80
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
81
81
|
# we should skip the substitution.
|
|
82
|
-
if source_node.
|
|
82
|
+
if source_node.is_reused():
|
|
83
83
|
for qc in source_node.candidates_quantization_cfg:
|
|
84
84
|
qc.weights_quantization_cfg.weights_second_moment_correction = False
|
|
85
85
|
return graph
|
|
@@ -102,7 +102,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
|
|
|
102
102
|
|
|
103
103
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
104
104
|
# we should skip the substitution.
|
|
105
|
-
if source_node.
|
|
105
|
+
if source_node.is_reused():
|
|
106
106
|
Logger.exception("If the linear operator is part of a reused group we should skip the the BN folding "
|
|
107
107
|
"substitution and SMC feature") # pragma: no cover
|
|
108
108
|
|
|
@@ -91,14 +91,11 @@ class Conv2DCollapsing(common.BaseSubstitution):
|
|
|
91
91
|
Graph after applying the substitution.
|
|
92
92
|
"""
|
|
93
93
|
|
|
94
|
-
first_node = edge_nodes
|
|
95
|
-
second_node = edge_nodes[1]
|
|
94
|
+
first_node, second_node, _ = edge_nodes
|
|
96
95
|
|
|
97
96
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
98
97
|
# we should skip the substitution.
|
|
99
|
-
if first_node.
|
|
100
|
-
return graph
|
|
101
|
-
if second_node.reuse or second_node.reuse_group is not None:
|
|
98
|
+
if first_node.is_reused() or second_node.is_reused():
|
|
102
99
|
return graph
|
|
103
100
|
|
|
104
101
|
# If there is an extra connection between these two nodes skip the substitution
|
|
@@ -182,3 +179,83 @@ class Conv2DCollapsing(common.BaseSubstitution):
|
|
|
182
179
|
assert num_edges_before_substition - len(graph.edges) == 1
|
|
183
180
|
|
|
184
181
|
return graph
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class Op2DAddConstCollapsing(common.BaseSubstitution):
|
|
185
|
+
"""
|
|
186
|
+
Collapse Add-const into preceding Op2D (Not non-linear activation between them)
|
|
187
|
+
"""
|
|
188
|
+
def __init__(self,
|
|
189
|
+
first_node: NodeOperationMatcher,
|
|
190
|
+
second_node: NodeOperationMatcher,
|
|
191
|
+
op2d_collapsing_fn: Callable,
|
|
192
|
+
bias_str: str,
|
|
193
|
+
use_bias_str: str,
|
|
194
|
+
layer_name_str: str = None):
|
|
195
|
+
"""
|
|
196
|
+
Collapsing Add-const node (2nd node) to Op2D node (first node).
|
|
197
|
+
Args:
|
|
198
|
+
first_node: Node matcher for Op2d type nodes.
|
|
199
|
+
second_node: Node matcher for add type nodes.
|
|
200
|
+
op2d_collapsing_fn: Function for updating the convolution kernel and bias
|
|
201
|
+
bias_str: The framework specific attribute name of the convolution layer's bias.
|
|
202
|
+
use_bias_str: The framework specific attribute name of the convolution layer's bias flag.
|
|
203
|
+
layer_name_str: The framework specific attribute name of layer's name.
|
|
204
|
+
"""
|
|
205
|
+
super().__init__(matcher_instance=EdgeMatcher(first_node, second_node))
|
|
206
|
+
self.op2d_collapsing_fn = op2d_collapsing_fn
|
|
207
|
+
self.bias_str = bias_str
|
|
208
|
+
self.use_bias_str = use_bias_str
|
|
209
|
+
self.layer_name_str = layer_name_str
|
|
210
|
+
|
|
211
|
+
def substitute(self,
|
|
212
|
+
graph: Graph,
|
|
213
|
+
edge_nodes: Tuple[BaseNode, BaseNode]) -> Graph:
|
|
214
|
+
"""
|
|
215
|
+
Collapse linear layer into preceding linear layers.
|
|
216
|
+
Convolution condition:
|
|
217
|
+
|-------------------------| |------|
|
|
218
|
+
| Op2D | ---> | Add-const | -> | Op2D |
|
|
219
|
+
|-------------------------| |------|
|
|
220
|
+
Args:
|
|
221
|
+
graph: Graph we apply the substitution on.
|
|
222
|
+
edge_nodes: Tuple of linear node and add nodes
|
|
223
|
+
Returns:
|
|
224
|
+
Graph after applying the substitution.
|
|
225
|
+
"""
|
|
226
|
+
|
|
227
|
+
first_node, second_node, _ = edge_nodes
|
|
228
|
+
|
|
229
|
+
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
230
|
+
# we should skip the substitution.
|
|
231
|
+
if first_node.is_reused() or second_node.is_reused():
|
|
232
|
+
return graph
|
|
233
|
+
|
|
234
|
+
# If there is an extra connection between these two nodes skip the substitution
|
|
235
|
+
if len(graph.get_next_nodes(first_node)) > 1 or len(graph.get_prev_nodes(second_node)) > 1:
|
|
236
|
+
return graph
|
|
237
|
+
|
|
238
|
+
# New collapsed bias
|
|
239
|
+
bias = self.op2d_collapsing_fn(first_node, second_node, self.bias_str)
|
|
240
|
+
|
|
241
|
+
# New collapsed node
|
|
242
|
+
op2d_collapsed = copy.deepcopy(first_node)
|
|
243
|
+
op2d_collapsed_name = first_node.name + '_collapsed'
|
|
244
|
+
op2d_collapsed.name = op2d_collapsed_name
|
|
245
|
+
op2d_collapsed.framework_attr[self.use_bias_str] = True
|
|
246
|
+
op2d_collapsed.set_weights_by_keys(self.bias_str, bias)
|
|
247
|
+
|
|
248
|
+
if self.layer_name_str is not None:
|
|
249
|
+
op2d_collapsed.framework_attr[self.layer_name_str] = op2d_collapsed_name
|
|
250
|
+
|
|
251
|
+
# Update graph
|
|
252
|
+
graph.add_node(op2d_collapsed)
|
|
253
|
+
graph.reconnect_out_edges(current_node=second_node, new_node=op2d_collapsed)
|
|
254
|
+
graph.reconnect_in_edges(current_node=first_node, new_node=op2d_collapsed)
|
|
255
|
+
graph.replace_output_node(current_node=second_node, new_node=op2d_collapsed)
|
|
256
|
+
|
|
257
|
+
graph.remove_edge(first_node, second_node)
|
|
258
|
+
graph.remove_node(first_node)
|
|
259
|
+
graph.remove_node(second_node)
|
|
260
|
+
|
|
261
|
+
return graph
|
|
@@ -30,6 +30,9 @@ def linear_collapsing_substitute(graph: common.Graph,
|
|
|
30
30
|
Returns:
|
|
31
31
|
Transformed graph after applying all linear collapsing substitutions.
|
|
32
32
|
"""
|
|
33
|
+
# TODO: remove this if after adding Op2d-add_const collapse substitution in PyTorch
|
|
34
|
+
if linear_collapsing_substitution is None:
|
|
35
|
+
return graph
|
|
33
36
|
matched_nodes = graph.filter(linear_collapsing_substitution.matcher_instance)
|
|
34
37
|
matched_nodes_list = []
|
|
35
38
|
match_indicator = True
|
|
@@ -63,9 +63,7 @@ class ResidualCollapsing(common.BaseSubstitution):
|
|
|
63
63
|
|
|
64
64
|
# If the linear operator is part of a reused group (it is the "base" node, or a reused node),
|
|
65
65
|
# we should skip the substitution.
|
|
66
|
-
if first_node.
|
|
67
|
-
return graph
|
|
68
|
-
if second_node.reuse or second_node.reuse_group is not None:
|
|
66
|
+
if first_node.is_reused() or second_node.is_reused():
|
|
69
67
|
return graph
|
|
70
68
|
|
|
71
69
|
# Check if convolution and residual satisfy the collapsing conditions, otherwise skip substitution
|
|
@@ -129,6 +129,7 @@ def get_finalized_graph(initial_graph: Graph,
|
|
|
129
129
|
transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection(quant_config))
|
|
130
130
|
if quant_config.linear_collapsing:
|
|
131
131
|
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_linear_collapsing_substitution())
|
|
132
|
+
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_op2d_add_const_collapsing_substitution())
|
|
132
133
|
if quant_config.residual_collapsing:
|
|
133
134
|
transformed_graph = substitute(transformed_graph, fw_impl.get_residual_collapsing_substitution())
|
|
134
135
|
|
|
@@ -53,6 +53,8 @@ DIMS = 'dims'
|
|
|
53
53
|
TARGET_SHAPE = 'target_shape'
|
|
54
54
|
TRANSPOSE_A = 'transpose_a'
|
|
55
55
|
TRANSPOSE_B = 'transpose_b'
|
|
56
|
+
ADJOINT_A = 'adjoint_a'
|
|
57
|
+
ADJOINT_B = 'adjoint_b'
|
|
56
58
|
DEPTH_MULTIPLIER = 'depth_multiplier'
|
|
57
59
|
DEPTHWISE_INITIALIZER = 'depthwise_initializer'
|
|
58
60
|
DEPTHWISE_REGULARIZER = 'depthwise_regularizer'
|
|
@@ -15,10 +15,14 @@
|
|
|
15
15
|
from typing import Tuple
|
|
16
16
|
import numpy as np
|
|
17
17
|
import tensorflow as tf
|
|
18
|
-
|
|
18
|
+
if tf.__version__ < "2.6":
|
|
19
|
+
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense
|
|
20
|
+
else:
|
|
21
|
+
from keras.layers import Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense
|
|
22
|
+
|
|
19
23
|
from model_compression_toolkit.core.common import BaseNode
|
|
20
24
|
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher, NodeFrameworkAttrMatcher
|
|
21
|
-
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing
|
|
25
|
+
from model_compression_toolkit.core.common.substitutions.linear_collapsing import Conv2DCollapsing, Op2DAddConstCollapsing
|
|
22
26
|
from model_compression_toolkit.core.keras.constants import KERNEL, KERNEL_SIZE, STRIDES, DILATIONS, LINEAR, \
|
|
23
27
|
ACTIVATION, BIAS, USE_BIAS, LAYER_NAME, FILTERS, PADDING, GROUPS, DATA_FORMAT
|
|
24
28
|
from model_compression_toolkit.logger import Logger
|
|
@@ -123,3 +127,69 @@ def keras_linear_collapsing() -> Conv2DCollapsing:
|
|
|
123
127
|
FILTERS,
|
|
124
128
|
data_format_str=DATA_FORMAT,
|
|
125
129
|
layer_name_str=LAYER_NAME)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def op2d_add_const_collapsing_node_matchers() -> Tuple[NodeOperationMatcher, NodeOperationMatcher]:
|
|
133
|
+
"""
|
|
134
|
+
Function generates matchers for matching:
|
|
135
|
+
(Op2D, Add(const)) -> Op2D. (Op2D is one of [DepthwiseConv2D, Conv2D, Conv2DTranspose, Dense)
|
|
136
|
+
Returns:
|
|
137
|
+
Matcher for Op2D followed by Add const
|
|
138
|
+
"""
|
|
139
|
+
first_node = NodeOperationMatcher(DepthwiseConv2D) | \
|
|
140
|
+
NodeOperationMatcher(Conv2D) | \
|
|
141
|
+
NodeOperationMatcher(Conv2DTranspose) | \
|
|
142
|
+
NodeOperationMatcher(Dense)
|
|
143
|
+
second_node = NodeOperationMatcher(tf.math.add)
|
|
144
|
+
return first_node, second_node
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def op2d_add_const_collapsing_fn(op2d_node: BaseNode,
|
|
148
|
+
add_node: BaseNode,
|
|
149
|
+
bias_str: str) -> np.ndarray:
|
|
150
|
+
"""
|
|
151
|
+
Collapsing Add-Const to previous node's bias
|
|
152
|
+
Args:
|
|
153
|
+
op2d_node: Op2d layer node
|
|
154
|
+
add_node: Add layer to collapse
|
|
155
|
+
bias_str: The framework specific attribute name of the convolution layer's bias.
|
|
156
|
+
Returns:
|
|
157
|
+
The modified conv layer node's bias
|
|
158
|
+
"""
|
|
159
|
+
bias = op2d_node.get_weights_by_keys(bias_str)
|
|
160
|
+
|
|
161
|
+
# read constant from add node
|
|
162
|
+
if len(add_node.op_call_args) > 0:
|
|
163
|
+
const = add_node.op_call_args[0]
|
|
164
|
+
elif 'y' in add_node.op_call_kwargs:
|
|
165
|
+
const = add_node.op_call_kwargs['y']
|
|
166
|
+
else:
|
|
167
|
+
Logger.error(f'Unable to read constant from add node: {add_node.name}') # pragma: no cover
|
|
168
|
+
|
|
169
|
+
# convert constant to numpy array
|
|
170
|
+
if isinstance(const, tf.Tensor):
|
|
171
|
+
const = const.numpy()
|
|
172
|
+
elif isinstance(const, list):
|
|
173
|
+
const = np.array(const)
|
|
174
|
+
else:
|
|
175
|
+
Logger.error(f'Unable to convert constant to numpy array: {add_node.name}') # pragma: no cover
|
|
176
|
+
|
|
177
|
+
# return new bias
|
|
178
|
+
if bias is None:
|
|
179
|
+
return const
|
|
180
|
+
else:
|
|
181
|
+
return const + bias
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def keras_op2d_add_const_collapsing() -> Op2DAddConstCollapsing:
|
|
185
|
+
"""
|
|
186
|
+
Returns:
|
|
187
|
+
An Op2DCollapsing initialized for Keras models.
|
|
188
|
+
"""
|
|
189
|
+
first_node, second_node = op2d_add_const_collapsing_node_matchers()
|
|
190
|
+
return Op2DAddConstCollapsing(first_node,
|
|
191
|
+
second_node,
|
|
192
|
+
op2d_add_const_collapsing_fn,
|
|
193
|
+
BIAS,
|
|
194
|
+
USE_BIAS,
|
|
195
|
+
layer_name_str=LAYER_NAME)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import tensorflow as tf
|
|
18
|
+
from model_compression_toolkit.logger import Logger
|
|
19
|
+
from model_compression_toolkit.core import common
|
|
20
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
21
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
22
|
+
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
23
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
|
24
|
+
from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \
|
|
25
|
+
ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MatmulToDenseSubstitution(common.BaseSubstitution):
|
|
29
|
+
"""
|
|
30
|
+
Replace a linear layer that has an activation function, with two nodes: same linear layer without
|
|
31
|
+
an activation function, and a new activation layer to replace the function the linear node had.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def __init__(self):
|
|
35
|
+
"""
|
|
36
|
+
Matches: tf.linalg.matmul
|
|
37
|
+
"""
|
|
38
|
+
super().__init__(matcher_instance=NodeOperationMatcher(tf.linalg.matmul))
|
|
39
|
+
|
|
40
|
+
def substitute(self,
|
|
41
|
+
graph: Graph,
|
|
42
|
+
matmul_node: FunctionalNode) -> Graph:
|
|
43
|
+
"""
|
|
44
|
+
Replace tf.linalg.matmul with Tensor and const with Dense layer
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
graph: Graph we apply the substitution on.
|
|
48
|
+
matmul_node: Node to replace.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
Graph after applying the substitution.
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
if len(graph.get_prev_nodes(matmul_node)) > 1:
|
|
55
|
+
# matmul of 2 activation tensors -> can't replace with Dense layer
|
|
56
|
+
return graph
|
|
57
|
+
|
|
58
|
+
if matmul_node.framework_attr.get(ADJOINT_A, False) or matmul_node.framework_attr.get(ADJOINT_B, False):
|
|
59
|
+
# MCT doesn't support complex tensors
|
|
60
|
+
return graph
|
|
61
|
+
|
|
62
|
+
if matmul_node.framework_attr.get(TRANSPOSE_A, False):
|
|
63
|
+
# first input should be an activation tensor with batch axis, that shouldn't be transposed
|
|
64
|
+
return graph
|
|
65
|
+
|
|
66
|
+
# read const from matmul inputs
|
|
67
|
+
if len(matmul_node.op_call_args) > 0:
|
|
68
|
+
w = matmul_node.op_call_args[0]
|
|
69
|
+
elif 'b' in matmul_node.op_call_kwargs:
|
|
70
|
+
w = matmul_node.op_call_kwargs['b']
|
|
71
|
+
else:
|
|
72
|
+
Logger.error(f"Matmul substitution: can't locate weight for node {matmul_node.name}") # pragma: no cover
|
|
73
|
+
|
|
74
|
+
# Convert weight const to numpy array
|
|
75
|
+
if isinstance(w, tf.Tensor):
|
|
76
|
+
w = w.numpy()
|
|
77
|
+
elif isinstance(w, list):
|
|
78
|
+
w = np.array(w)
|
|
79
|
+
elif not isinstance(w, np.ndarray):
|
|
80
|
+
Logger.error(f'Unable to convert constant to numpy array: {matmul_node.name}') # pragma: no cover
|
|
81
|
+
|
|
82
|
+
if len(w.shape) != 2:
|
|
83
|
+
# weight tensor should be of shape (Cin, Cout)
|
|
84
|
+
return graph
|
|
85
|
+
|
|
86
|
+
# transpose const if "transpose_b" flag is True
|
|
87
|
+
if matmul_node.op_call_kwargs.get(TRANSPOSE_B, False) or (
|
|
88
|
+
len(matmul_node.op_call_args) >= 3 and matmul_node.op_call_args[2]):
|
|
89
|
+
w = w.transpose()
|
|
90
|
+
|
|
91
|
+
dense_node = BaseNode(matmul_node.name,
|
|
92
|
+
{UNITS: w.shape[1], USE_BIAS: False},
|
|
93
|
+
matmul_node.input_shape, matmul_node.output_shape,
|
|
94
|
+
{KERNEL: w}, tf.keras.layers.Dense,
|
|
95
|
+
reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)
|
|
96
|
+
|
|
97
|
+
graph.add_node(dense_node)
|
|
98
|
+
graph.reconnect_in_edges(current_node=matmul_node,
|
|
99
|
+
new_node=dense_node)
|
|
100
|
+
graph.reconnect_out_edges(current_node=matmul_node,
|
|
101
|
+
new_node=dense_node)
|
|
102
|
+
graph.replace_output_node(current_node=matmul_node,
|
|
103
|
+
new_node=dense_node)
|
|
104
|
+
graph.remove_node(matmul_node)
|
|
105
|
+
|
|
106
|
+
return graph
|
|
107
|
+
|
|
108
|
+
|
|
@@ -68,6 +68,8 @@ from model_compression_toolkit.core.common.user_info import UserInformation
|
|
|
68
68
|
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
|
69
69
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
|
|
70
70
|
ActivationDecomposition
|
|
71
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
|
|
72
|
+
MatmulToDenseSubstitution
|
|
71
73
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
|
|
72
74
|
keras_softmax_shift
|
|
73
75
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_folding import \
|
|
@@ -75,7 +77,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batc
|
|
|
75
77
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_refusing import \
|
|
76
78
|
keras_batchnorm_refusing
|
|
77
79
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.linear_collapsing import \
|
|
78
|
-
keras_linear_collapsing
|
|
80
|
+
keras_linear_collapsing, keras_op2d_add_const_collapsing
|
|
79
81
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
|
|
80
82
|
keras_residual_collapsing
|
|
81
83
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
|
@@ -260,6 +262,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
260
262
|
|
|
261
263
|
"""
|
|
262
264
|
return [SeparableConvDecomposition(),
|
|
265
|
+
MatmulToDenseSubstitution(),
|
|
263
266
|
MultiHeadAttentionDecomposition(),
|
|
264
267
|
ActivationDecomposition(),
|
|
265
268
|
DwconvToConv()]
|
|
@@ -311,6 +314,12 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
311
314
|
"""
|
|
312
315
|
return keras_linear_collapsing()
|
|
313
316
|
|
|
317
|
+
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
|
|
318
|
+
"""
|
|
319
|
+
Returns: Op2d add-const collapsing substitution
|
|
320
|
+
"""
|
|
321
|
+
return keras_op2d_add_const_collapsing()
|
|
322
|
+
|
|
314
323
|
def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) \
|
|
315
324
|
-> List[common.BaseSubstitution]:
|
|
316
325
|
"""
|
|
@@ -289,6 +289,12 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
289
289
|
"""
|
|
290
290
|
return pytorch_linear_collapsing()
|
|
291
291
|
|
|
292
|
+
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
|
|
293
|
+
"""
|
|
294
|
+
Returns: None, as Op2d add-const substitution is not supported in torch yet
|
|
295
|
+
"""
|
|
296
|
+
return None
|
|
297
|
+
|
|
292
298
|
def get_substitutions_post_statistics_collection(self,
|
|
293
299
|
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
|
294
300
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|