mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240305.post352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/METADATA +5 -5
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/RECORD +32 -30
- model_compression_toolkit/core/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
- model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
- model_compression_toolkit/core/common/quantization/core_config.py +3 -3
- model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
- model_compression_toolkit/core/pytorch/constants.py +3 -0
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
- model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +4 -4
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -3
- model_compression_toolkit/pruning/__init__.py +1 -0
- model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
- model_compression_toolkit/ptq/keras/quantization_facade.py +4 -7
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
- model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240305.post352.dist-info}/top_level.txt +0 -0
@@ -44,9 +44,9 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         Prunes the entry node of a model in Keras.
 
         Args:
-            node: The entry node to be pruned.
-            output_mask: A numpy array representing the mask to be applied to the output channels.
-            fw_info: Framework-specific information object.
+            node (BaseNode): The entry node to be pruned.
+            output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         return _prune_keras_edge_node(node=node,
@@ -63,10 +63,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         Prunes an intermediate node in a Keras model.
 
         Args:
-            node: The intermediate node to be pruned.
-            input_mask: A numpy array representing the mask to be applied to the input channels.
-            output_mask: A numpy array representing the mask to be applied to the output channels.
-            fw_info: Framework-specific information object.
+            node (BaseNode): The intermediate node to be pruned.
+            input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
+            output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         _edit_node_input_shape(input_mask, node)
@@ -85,9 +85,9 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         Prunes the exit node of a model in Keras.
 
         Args:
-            node: The exit node to be pruned.
-            input_mask: A numpy array representing the mask to be applied to the input channels.
-            fw_info: Framework-specific information object.
+            node (BaseNode): The exit node to be pruned.
+            input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
 
         """
         return _prune_keras_edge_node(node=node,
@@ -100,10 +100,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         Determines whether a node is an entry node in a Keras model.
 
         Args:
-            node: The node to be checked.
+            node (BaseNode): The node to be checked.
 
         Returns:
-            Boolean indicating if the node is an entry node.
+            bool: Boolean indicating if the node is an entry node.
         """
         return _is_keras_node_pruning_section_edge(node)
 
@@ -115,26 +115,26 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
         Determines whether a node is an exit node in a Keras model.
 
         Args:
-            node: The node to be checked.
-            corresponding_entry_node: The entry node of the pruning section that is checked.
-            fw_info: Framework-specific information object.
+            node (BaseNode): The node to be checked.
+            corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
+            fw_info (FrameworkInfo): Framework-specific information object.
 
         Returns:
-            Boolean indicating if the node is an exit node.
+            bool: Boolean indicating if the node is an exit node.
         """
         return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
                                                                                                        corresponding_entry_node,
                                                                                                        fw_info)
 
-    def is_node_intermediate_pruning_section(self, node) -> bool:
+    def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
         """
         Determines whether a node is part of the intermediate section in the pruning process of a Keras model.
 
         Args:
-            node: The node to be checked.
+            node (BaseNode): The node to be checked.
 
         Returns:
-            Boolean indicating if the node is part of the intermediate pruning section.
+            bool: Boolean indicating if the node is part of the intermediate pruning section.
         """
         # Nodes that are not Conv2D, Conv2DTranspose, DepthwiseConv2D, or Dense are considered intermediate.
         return node.type not in [keras.layers.DepthwiseConv2D,
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.constants import FOUND_TORCH
 
 if FOUND_TORCH:
@@ -38,7 +38,7 @@ if FOUND_TORCH:
 
     def pytorch_kpi_data(in_model: Module,
                          representative_data_gen: Callable,
-                         core_config: CoreConfig = CoreConfig(),
+                         core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
                          fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                          target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
         """
@@ -75,9 +75,9 @@ if FOUND_TORCH:
 
         """
 
-        if not isinstance(core_config.mixed_precision_config,
-            Logger.error("KPI data computation can't be executed without
-                         "Given quant_config is not of type
+        if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
+            Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
+                         "Given quant_config is not of type MixedPrecisionQuantizationConfig.")
 
         fw_impl = PytorchImplementation()
 
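Taken together, the two kpi_data_facade.py hunks above mean that `pytorch_kpi_data` now defaults to a `CoreConfig` that already carries a `MixedPrecisionQuantizationConfig`, and rejects any other mixed-precision config type. A minimal usage sketch, assuming the `mct.core.pytorch_kpi_data` entry point and a generator-style `representative_data_gen` as in MCT's other facades; the toy model and random data below are illustrative only:

```python
import numpy as np
import torch
import model_compression_toolkit as mct

# Toy model for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Flatten())

def representative_data_gen():
    # Yield a few batches shaped like the model input (random images stand in for real data).
    for _ in range(2):
        yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

# Passing the mixed-precision config explicitly mirrors the new default in this diff.
core_config = mct.core.CoreConfig(
    mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))

kpi_data = mct.core.pytorch_kpi_data(model, representative_data_gen, core_config)
```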
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -0,0 +1,315 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Tuple, Dict
+
+from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
+    PruningFrameworkImplementation
+from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
+from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common import BaseNode
+from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
+    IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
+import torch
+
+import numpy as np
+
+from model_compression_toolkit.logger import Logger
+
+
+class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
+    """
+    Implementation of the PruningFramework for the Pytorch framework. This class provides
+    concrete implementations of the abstract methods defined in PruningFrameworkImplementation
+    for the Pytorch framework.
+    """
+
+    def prune_entry_node(self,
+                         node: BaseNode,
+                         output_mask: np.ndarray,
+                         fw_info: FrameworkInfo):
+        """
+        Prunes the entry node of a model in Pytorch.
+
+        Args:
+            node (BaseNode): The entry node to be pruned.
+            output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
+
+        """
+        return _prune_pytorch_edge_node(node=node,
+                                        mask=output_mask,
+                                        fw_info=fw_info,
+                                        is_exit_node=False)
+
+    def prune_intermediate_node(self,
+                                node: BaseNode,
+                                input_mask: np.ndarray,
+                                output_mask: np.ndarray,
+                                fw_info: FrameworkInfo):
+        """
+        Prunes an intermediate node in a Pytorch model.
+
+        Args:
+            node (BaseNode): The intermediate node to be pruned.
+            input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
+            output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
+
+        """
+        # TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
+        pruning_en = True
+        _edit_node_input_shape(node, input_mask, fw_info)
+        pruned_parameters = {}
+        mask_bool = output_mask.astype(bool)
+        node.weights = pruned_parameters
+        if node.type == torch.nn.BatchNorm2d:
+            node.framework_attr[NUM_FEATURES] = int(np.sum(input_mask))
+        elif node.type == torch.nn.PReLU:
+            if node.framework_attr[NUM_PARAMETERS] > 1:
+                node.framework_attr[NUM_PARAMETERS] = int(np.sum(input_mask))
+            else:
+                pruning_en = False
+
+        if pruning_en:
+            for k, v in node.weights.items():
+                # Apply the mask to the weights.
+                pruned_parameters[k] = v.compress(mask_bool, axis=-1)
+
+    def prune_exit_node(self,
+                        node: BaseNode,
+                        input_mask: np.ndarray,
+                        fw_info: FrameworkInfo):
+        """
+        Prunes the exit node of a model in Pytorch.
+
+        Args:
+            node (BaseNode): The exit node to be pruned.
+            input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
+            fw_info (FrameworkInfo): Framework-specific information object.
+
+        """
+        return _prune_pytorch_edge_node(node=node,
+                                        mask=input_mask,
+                                        fw_info=fw_info,
+                                        is_exit_node=True)
+
+    def is_node_entry_node(self, node: BaseNode) -> bool:
+        """
+        Determines whether a node is an entry node in a Pytorch model.
+
+        Args:
+            node (BaseNode): The node to be checked.
+
+        Returns:
+            bool: Boolean indicating if the node is an entry node.
+        """
+        return _is_pytorch_node_pruning_section_edge(node)
+
+    def is_node_exit_node(self,
+                          node: BaseNode,
+                          corresponding_entry_node: BaseNode,
+                          fw_info: FrameworkInfo) -> bool:
+        """
+        Determines whether a node is an exit node in a Pytorch model.
+
+        Args:
+            node (BaseNode): The node to be checked.
+            corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
+            fw_info (FrameworkInfo) Framework-specific information object.
+
+        Returns:
+            bool: Boolean indicating if the node is an exit node.
+        """
+        return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
+                                                                                                         corresponding_entry_node,
+                                                                                                         fw_info)
+
+    def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
+        """
+        Determines whether a node is part of the intermediate section in the pruning process of a Pytorch model.
+
+        Args:
+            node (BaseNode): The node to be checked.
+
+        Returns:
+            bool: Boolean indicating if the node is part of the intermediate pruning section.
+        """
+        # Nodes that are not Conv2d, ConvTranspose2d, or Linear are considered intermediate.
+        # For PReLU prune attributes only if there is a parameter per channel
+        return node.type not in [torch.nn.Conv2d,
+                                 torch.nn.ConvTranspose2d,
+                                 torch.nn.Linear]
+
+    def attrs_oi_channels_info_for_pruning(self,
+                                           node: BaseNode,
+                                           fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
+        """
+        Retrieves the attributes of a given node along with the output/input (OI) channel axis
+        for each attribute used to prune these attributes.
+
+        Not all attributes of a node are directly associated with both input and output channels.
+        For example, bias vectors in convolutional layers are solely related to the number of output
+        channels and do not have a corresponding input channel dimension.
+        In cases like that, None is returned in the tuple of axis for such attributes.
+
+        For kernel operations (like convolutions), the function identifies the output and input
+        channel axis based on framework-specific information.
+        For non-kernel operations, it defaults to setting the last axis as the output
+        channel axis, assuming no specific input channel axis.
+
+        Args:
+            node (BaseNode): The node from the computational graph.
+            fw_info (FrameworkInfo): Contains framework-specific information and utilities.
+
+        Returns:
+            Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
+            and each value is a tuple representing the output and input channel axis indices respectively.
+        """
+
+        attributes_with_axis = {}
+        if fw_info.is_kernel_op(node.type):
+            kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
+            if kernel_attributes is None or len(kernel_attributes) == 0:
+                Logger.error(f"Expected to find attributes but found {kernel_attributes}")
+
+            for attr in kernel_attributes:
+                attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
+
+            # Bias is a vector at the length of the number of output channels.
+            # For this reason, input channel axis is irrelevant to the bias attribute.
+            attributes_with_axis[BIAS] = (0, None)
+        else:
+            # We have several assumptions here:
+            # 1. For intermediate nodes, we prune all nodes' weights.
+            # 2. The output channel axis is the last axis of this attribute.
+            # 3. The input channel axis is irrelevant since these attributes are pruned only by
+            # their output channels.
+            for attr in list(node.weights.keys()):
+                # If the number of float parameters is 1 or less - is the case where
+                # we have one parameter for all channels. For this case, we don't
+                # want to prune the parameter.
+                if node.get_num_parameters(fw_info)[1] <= 1:
+                    attributes_with_axis[attr] = (None, None)
+                else:
+                    attributes_with_axis[attr] = (-1, None)
+
+        return attributes_with_axis
+
+
+def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
+    """
+    Determines if a Pytorch node is an edge of a pruning section.
+
+    In the context of pruning, an 'edge' node is a layer that can potentially be pruned.
+    This function identifies such nodes based on their type and attributes. Specifically,
+    Conv2d and ConvTranspose2d layers with 'groups' attribute set to 1, and Linear layers
+    are considered as edges for pruning sections.
+
+    Args:
+        node (BaseNode): The node to be evaluated.
+
+    Returns:
+        bool: True if the node is an edge of a pruning section, False otherwise.
+    """
+
+    # Check if the node is a Conv2D or Conv2DTranspose layer with groups set to 1.
+    if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+        return node.framework_attr[GROUPS] == 1
+    return node.type == torch.nn.Linear
+
+
+def _prune_pytorch_edge_node(node: BaseNode,
+                             mask: np.ndarray,
+                             fw_info: FrameworkInfo,
+                             is_exit_node: bool):
+    """
+    Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
+    This function can handle both entry and exit nodes by specifying the is_exit_node parameter.
+
+    Args:
+        node (BaseNode): The node to be pruned.
+        mask (np.ndarray): The pruning mask to be applied.
+        fw_info (FrameworkInfo): Framework-specific information object.
+        is_exit_node (bool): A boolean indicating whether the node is an exit node.
+
+    """
+
+    # Retrieve the kernel attribute and the axes to prune.
+    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
+    io_axis = fw_info.kernel_channels_mapping.get(node.type)
+    axis_to_prune = io_axis[int(is_exit_node)]
+    kernel = node.get_weights_by_keys(kernel_attr)
+    # Convert mask to boolean.
+    mask_bool = mask.astype(bool)
+
+    pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
+    node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
+
+    if not is_exit_node and node.framework_attr[BIAS]:
+        # Prune the bias if applicable and it's an entry node.
+        bias = node.get_weights_by_keys(BIAS)
+        pruned_bias = bias.compress(mask_bool)
+        node.set_weights_by_keys(name=BIAS, tensor=pruned_bias)
+
+    if not is_exit_node:
+        # Update 'out_channels' or 'out_features' attributes for entry nodes
+        # Conv2d,ConvTranspose2d / Linear layers.
+        if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+            node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
+        elif node.type == torch.nn.Linear:
+            node.framework_attr[OUT_FEATURES] = int(np.sum(mask))
+        else:
+            Logger.exception(f"{node.type} is currently not supported"
+                             f"as an edge node in a pruning section")
+
+    if is_exit_node:
+        if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+            node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
+        elif node.type == torch.nn.Linear:
+            node.framework_attr[IN_FEATURES] = int(np.sum(mask))
+        else:
+            Logger.exception(f"{node.type} is currently not supported"
+                             f"as an edge node in a pruning section")
+        # Adjust the input shape for the last node in the section.
+        _edit_node_input_shape(node, mask_bool, fw_info)
+
+
+def _edit_node_input_shape(node: BaseNode,
+                           input_mask: np.ndarray,
+                           fw_info: FrameworkInfo):
+    """
+    Adjusts the input shape of a node based on the given input mask.
+
+    This function modifies the input shape of the given node to reflect the pruning
+    that has taken place. It updates the last dimension of the node's input shape
+    to match the number of channels that remain after pruning.
+
+    Args:
+        node (BaseNode): The node whose input shape needs to be adjusted.
+        input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
+        fw_info (FrameworkInfo): Framework-specific information object.
+    """
+    # Start with the current input shape of the node.
+    new_input_shape = list(node.input_shape)
+
+    # Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
+    # This is done by summing the mask, as each '1' in the mask represents a retained channel.
+    channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
+    new_input_shape[0][channel_axis] = int(np.sum(input_mask))
+
+    # Update the node's input shape with the new dimensions.
+    node.input_shape = tuple(new_input_shape)
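The masking mechanic used throughout the new PruningPytorchImplementation (in `_prune_pytorch_edge_node` and in `prune_intermediate_node`) is `numpy.ndarray.compress` with a boolean channel mask along the relevant axis; for PyTorch `Conv2d` kernels, shaped (out_channels, in_channels, kH, kW), the output axis is 0 and the input axis is 1. A standalone numpy sketch of that step, with made-up shapes:

```python
import numpy as np

# Fake Conv2d kernel and bias: 8 output channels, 3 input channels, 3x3 kernels.
kernel = np.random.randn(8, 3, 3, 3)
bias = np.random.randn(8)

# Keep channels 0, 2, 5 and 7; a 0/1 mask like the one the pruning search produces.
output_mask = np.array([1, 0, 1, 0, 0, 1, 0, 1])
mask_bool = output_mask.astype(bool)

# Entry node: prune along the output-channel axis and prune the per-channel bias.
pruned_kernel = kernel.compress(mask_bool, axis=0)
pruned_bias = bias.compress(mask_bool)
assert pruned_kernel.shape == (4, 3, 3, 3) and pruned_bias.shape == (4,)

# Exit node of the same pruning section: the following layer loses input channels instead.
next_kernel = np.random.randn(16, 8, 3, 3)
pruned_next_kernel = next_kernel.compress(mask_bool, axis=1)
assert pruned_next_kernel.shape == (16, 4, 3, 3)

# The out_channels / in_channels framework attributes are then set to int(np.sum(output_mask)) == 4.
```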
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax
 
 import model_compression_toolkit.core.pytorch.constants as pytorch_constants
 from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
-from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig,
+from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -332,7 +332,7 @@ class PytorchImplementation(FrameworkImplementation):
 
     def get_sensitivity_evaluator(self,
                                   graph: Graph,
-                                  quant_config:
+                                  quant_config: MixedPrecisionQuantizationConfig,
                                   representative_data_gen: Callable,
                                   fw_info: FrameworkInfo,
                                   disable_activation_for_metric: bool = False,
@@ -24,7 +24,7 @@ from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import
+from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.runner import core_runner
 from model_compression_toolkit.gptq.runner import gptq_runner
@@ -177,7 +177,7 @@ if FOUND_TF:
         with different bitwidths for different layers.
         The candidates bitwidth for quantization should be defined in the target platform model:
 
-        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.
+        >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
 
         For mixed-precision set a target KPI object:
         Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
@@ -199,9 +199,9 @@
                              fw_info=fw_info).validate()
 
         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config,
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.") # pragma: no cover
 
         tb_w = init_tensorboard_writer(fw_info)
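The guard above only changes which class the check and the error message spell out: when mixed precision is enabled, `core_config.mixed_precision_config` must be a `MixedPrecisionQuantizationConfig`. A small sketch of the condition the facade now enforces, using the public `mct.core` aliases rather than the internal import path (illustrative only):

```python
import model_compression_toolkit as mct

mp_config = mct.core.MixedPrecisionQuantizationConfig(num_of_images=1)
core_config = mct.core.CoreConfig(mixed_precision_config=mp_config)

# Essentially the isinstance check the facade runs before a mixed-precision search.
assert isinstance(core_config.mixed_precision_config,
                  mct.core.MixedPrecisionQuantizationConfig)
```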
@@ -29,7 +29,7 @@ from model_compression_toolkit.core.exporter import export_model
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
-
+    MixedPrecisionQuantizationConfig
 
 LR_DEFAULT = 1e-4
 LR_REST_DEFAULT = 1e-4
@@ -157,9 +157,9 @@ if FOUND_TORCH:
         """
 
         if core_config.mixed_precision_enable:
-            if not isinstance(core_config.mixed_precision_config,
+            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.error("Given quantization config to mixed-precision facade is not of type "
-                             "
+                             "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.") # pragma: no cover
 
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
@@ -16,4 +16,5 @@
 from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
 from model_compression_toolkit.core.common.pruning.pruning_config import ImportanceMetric, PruningConfig, ChannelsFilteringStrategy
 from model_compression_toolkit.pruning.keras.pruning_facade import keras_pruning_experimental
+from model_compression_toolkit.pruning.pytorch.pruning_facade import pytorch_pruning_experimental
 
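With this export, structured pruning becomes reachable for PyTorch as `mct.pruning.pytorch_pruning_experimental`. Its signature lives in the new pruning_facade.py, which this diff lists (+166 lines) but does not expand, so the sketch below assumes it mirrors `keras_pruning_experimental`: a model, a target KPI on weights memory, a representative data generator, and an optional `PruningConfig`. The keyword argument names and the `mct.core.KPI` constructor are assumptions, not the definitive API:

```python
import numpy as np
import torch
import model_compression_toolkit as mct

# Toy model for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU(), torch.nn.Conv2d(16, 8, 3))

def representative_data_gen():
    for _ in range(2):
        yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

# Target roughly half of the dense weights memory (float32 -> 4 bytes per parameter).
dense_nparams = sum(p.numel() for p in model.parameters())
target_kpi = mct.core.KPI(weights_memory=dense_nparams * 4 * 0.5)  # assumed constructor

pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(
    model=model,                                                      # assumed kwargs below
    target_kpi=target_kpi,
    representative_data_gen=representative_data_gen,
    pruning_config=mct.pruning.PruningConfig(num_score_approximations=1))
```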
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================