mct-nightly 2.0.0.20240409.404__py3-none-any.whl → 2.0.0.20240410.422__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/RECORD +12 -10
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/quantization_config.py +3 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py +66 -0
- model_compression_toolkit/core/keras/keras_implementation.py +6 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py +69 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -0
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/top_level.txt +0 -0
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/RECORD
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=c33LV9Kt6hpVEoLixt_I5rqhtSzRBPSrdmFEifg-VHU,1573
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
|
@@ -100,8 +100,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
|
100
100
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
|
|
101
101
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
102
102
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
|
103
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
|
104
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
|
103
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=TCgpvtfyzFUedv4sZ6sKzsTyikaVl2ixLj_aHPSC2r0,27014
|
|
104
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BieZDv9oc-Mc78S_LRMGo-s_2acbqiLE0ewaSE1v2VY,6818
|
|
105
105
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
|
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
|
107
107
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
|
@@ -148,7 +148,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
|
148
148
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
|
149
149
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
150
150
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
|
|
151
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
151
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=RS2UEtZ_anZeDxz7Zv6sNv7v9tFVct6d9KVrUlxTGpo,29309
|
|
152
152
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
153
153
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
|
|
154
154
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
|
|
@@ -166,6 +166,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/activatio
|
|
|
166
166
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=9YCNPiK5BD7tLs1meabPhzfb2VsyPxrZM17zMFsW_Fo,8158
|
|
167
167
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
|
|
168
168
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
|
|
169
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
|
|
169
170
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
|
|
170
171
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
|
|
171
172
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=dyhqZrxSTclXyarT2JYnI5WPX0OvWR_CQiwddIr632U,8143
|
|
@@ -209,7 +210,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
|
209
210
|
model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
|
|
210
211
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
|
211
212
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
|
|
212
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
213
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=mT4jd8E1saCpAgrsClufQbnVJ0eYn1xaTQ3teALu4jk,27117
|
|
213
214
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
|
|
214
215
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
|
215
216
|
model_compression_toolkit/core/pytorch/utils.py,sha256=dRPiteBg2dBNsHwZyYzXiCIAjnelSoeZZsDXlsTw5JQ,2880
|
|
@@ -228,6 +229,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/__init_
|
|
|
228
229
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py,sha256=j3q5DzbH3ys5MPFfSOVnAXdD7-g4XEKj2ADrdihVr30,8292
|
|
229
230
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
|
|
230
231
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
|
|
232
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
|
|
231
233
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=dYGyb5ebnoeFBF0EaHPQU7CkXvoARdznEEe0laM47LA,3919
|
|
232
234
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=iX8bLHtw2osP42-peNLTRmbpX3cUxdGsAbEfw7NLpx0,3935
|
|
233
235
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=zKSgtVw_P9fUvdq4e7P9yaLDPG_vZ0cecM9sVPtm1ns,3799
|
|
@@ -469,8 +471,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
469
471
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
470
472
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
471
473
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
|
|
472
|
-
mct_nightly-2.0.0.20240409.404.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
473
|
-
mct_nightly-2.0.0.20240409.404.dist-info/METADATA,sha256=
|
|
474
|
-
mct_nightly-2.0.0.20240409.404.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
475
|
-
mct_nightly-2.0.0.20240409.404.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
476
|
-
mct_nightly-2.0.0.20240409.404.dist-info/RECORD,,
|
|
474
|
+
mct_nightly-2.0.0.20240410.422.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
475
|
+
mct_nightly-2.0.0.20240410.422.dist-info/METADATA,sha256=Xx2HTbZkpp4O8bS07IXSnaYSh9ZZTxe61I47ovv9fzE,18795
|
|
476
|
+
mct_nightly-2.0.0.20240410.422.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
477
|
+
mct_nightly-2.0.0.20240410.422.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
478
|
+
mct_nightly-2.0.0.20240410.422.dist-info/RECORD,,
|
model_compression_toolkit/__init__.py
CHANGED
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.0.0.20240409.000404"
|
|
30
|
+
__version__ = "2.0.0.20240410.000422"
|
model_compression_toolkit/core/common/quantization/node_quantization_config.py
CHANGED
|
@@ -106,6 +106,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
106
106
|
self.z_threshold = qc.z_threshold
|
|
107
107
|
self.shift_negative_ratio = qc.shift_negative_ratio
|
|
108
108
|
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
|
109
|
+
self.concat_threshold_update = qc.concat_threshold_update
|
|
109
110
|
|
|
110
111
|
def quantize_node_output(self,
|
|
111
112
|
tensors: Any) -> Any:
|
|
@@ -219,7 +220,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
219
220
|
self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
|
|
220
221
|
self.z_threshold == other.z_threshold and \
|
|
221
222
|
self.shift_negative_ratio == other.shift_negative_ratio and \
|
|
222
|
-
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
|
+
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
224
|
|
|
224
225
|
def __hash__(self):
|
|
225
226
|
return hash((self.activation_quantization_fn,
|
model_compression_toolkit/core/common/quantization/quantization_config.py
CHANGED
|
@@ -62,7 +62,8 @@ class QuantizationConfig:
|
|
|
62
62
|
residual_collapsing: bool = True,
|
|
63
63
|
shift_negative_ratio: float = 0.05,
|
|
64
64
|
shift_negative_threshold_recalculation: bool = False,
|
|
65
|
-
shift_negative_params_search: bool = False
|
|
65
|
+
shift_negative_params_search: bool = False,
|
|
66
|
+
concat_threshold_update: bool = False):
|
|
66
67
|
"""
|
|
67
68
|
Class to wrap all different parameters the library quantize the input model according to.
|
|
68
69
|
|
|
@@ -117,6 +118,7 @@ class QuantizationConfig:
|
|
|
117
118
|
self.shift_negative_ratio = shift_negative_ratio
|
|
118
119
|
self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
|
|
119
120
|
self.shift_negative_params_search = shift_negative_params_search
|
|
121
|
+
self.concat_threshold_update = concat_threshold_update
|
|
120
122
|
|
|
121
123
|
def __repr__(self):
|
|
122
124
|
return str(self.__dict__)
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from tensorflow.keras.layers import Concatenate
|
|
18
|
+
import tensorflow as tf
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def __init__(self):
|
|
35
|
+
"""
|
|
36
|
+
Initialize a threshold_updater object.
|
|
37
|
+
"""
|
|
38
|
+
concatination_node = NodeOperationMatcher(Concatenate) | \
|
|
39
|
+
NodeOperationMatcher(tf.concat)
|
|
40
|
+
super().__init__(matcher_instance=concatination_node)
|
|
41
|
+
|
|
42
|
+
def substitute(self,
|
|
43
|
+
graph: Graph,
|
|
44
|
+
node: BaseNode) -> Graph:
|
|
45
|
+
"""
|
|
46
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
47
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
48
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
graph: Graph we apply the substitution on.
|
|
53
|
+
node: Node refference to edit previous nodes thresholds.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
Graph after applying the substitution.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
60
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
61
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
62
|
+
for prev_node in prev_nodes:
|
|
63
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat:
|
|
64
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
65
|
+
|
|
66
|
+
return graph
|
model_compression_toolkit/core/keras/keras_implementation.py
CHANGED
|
@@ -80,7 +80,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
|
|
|
80
80
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
|
|
81
81
|
keras_residual_collapsing
|
|
82
82
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
|
83
|
-
InputScalingWithPad
|
|
83
|
+
InputScalingWithPad
|
|
84
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
|
|
84
85
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
|
|
85
86
|
ReLUBoundToPowerOfTwo
|
|
86
87
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.multi_head_attention_decomposition import \
|
|
@@ -300,8 +301,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
300
301
|
"""
|
|
301
302
|
return keras_op2d_add_const_collapsing()
|
|
302
303
|
|
|
303
|
-
def get_substitutions_post_statistics_collection(self,
|
|
304
|
-
|
|
304
|
+
def get_substitutions_post_statistics_collection(self,
|
|
305
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
|
305
306
|
"""
|
|
306
307
|
Return a list of the framework substitutions used after we collect statistics.
|
|
307
308
|
|
|
@@ -317,6 +318,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
317
318
|
if quant_config.input_scaling:
|
|
318
319
|
substitutions_list.append(InputScaling())
|
|
319
320
|
substitutions_list.append(InputScalingWithPad())
|
|
321
|
+
if quant_config.concat_threshold_update:
|
|
322
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
320
323
|
return substitutions_list
|
|
321
324
|
|
|
322
325
|
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import List
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
24
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
"""
|
|
29
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def __init__(self):
|
|
34
|
+
"""
|
|
35
|
+
Initialize a threshold_updater object.
|
|
36
|
+
"""
|
|
37
|
+
concatination_node = NodeOperationMatcher(torch.cat) | \
|
|
38
|
+
NodeOperationMatcher(torch.concat)
|
|
39
|
+
super().__init__(matcher_instance=concatination_node)
|
|
40
|
+
|
|
41
|
+
def substitute(self,
|
|
42
|
+
graph: Graph,
|
|
43
|
+
node: BaseNode) -> Graph:
|
|
44
|
+
"""
|
|
45
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
46
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
47
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
graph: Graph we apply the substitution on.
|
|
52
|
+
node: Node refference to edit previous nodes thresholds.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Graph after applying the substitution.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
59
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
60
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
61
|
+
for prev_node in prev_nodes:
|
|
62
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat:
|
|
63
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
64
|
+
|
|
65
|
+
return graph
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py
CHANGED
|
@@ -73,6 +73,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.vi
|
|
|
73
73
|
VirtualActivationWeightsComposition
|
|
74
74
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.weights_activation_split import \
|
|
75
75
|
WeightsActivationSplit
|
|
76
|
+
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.concat_threshold_update import \
|
|
77
|
+
ConcatThresholdUpdate
|
|
76
78
|
from model_compression_toolkit.core.pytorch.hessian.activation_trace_hessian_calculator_pytorch import \
|
|
77
79
|
ActivationTraceHessianCalculatorPytorch
|
|
78
80
|
from model_compression_toolkit.core.pytorch.hessian.weights_trace_hessian_calculator_pytorch import \
|
|
@@ -302,6 +304,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
302
304
|
substitutions_list.append(pytorch_softmax_shift())
|
|
303
305
|
if quant_config.input_scaling:
|
|
304
306
|
Logger.critical('Input scaling is currently not supported for Pytorch.')
|
|
307
|
+
if quant_config.concat_threshold_update:
|
|
308
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
305
309
|
return substitutions_list
|
|
306
310
|
|
|
307
311
|
|
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/LICENSE.md
RENAMED
|
File without changes
|
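{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/WHEEL
RENAMED
|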
File without changes
|
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240410.422.dist-info}/top_level.txt
RENAMED
|
File without changes
|