mct-nightly 1.11.0.20240304.post404__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +5 -5
  2. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +42 -40
  3. model_compression_toolkit/core/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +2 -2
  5. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +4 -70
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -2
  7. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -2
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +19 -1
  9. model_compression_toolkit/core/common/quantization/core_config.py +3 -3
  10. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +0 -3
  11. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +0 -3
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +0 -1
  13. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  14. model_compression_toolkit/core/keras/kpi_data_facade.py +5 -6
  15. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +19 -19
  16. model_compression_toolkit/core/pytorch/constants.py +3 -0
  17. model_compression_toolkit/core/pytorch/kpi_data_facade.py +5 -5
  18. model_compression_toolkit/core/pytorch/pruning/__init__.py +14 -0
  19. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +315 -0
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -2
  21. model_compression_toolkit/gptq/__init__.py +1 -1
  22. model_compression_toolkit/gptq/common/gptq_config.py +5 -72
  23. model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +19 -33
  25. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
  26. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
  27. model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
  28. model_compression_toolkit/gptq/pytorch/quantization_facade.py +14 -31
  29. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
  30. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
  31. model_compression_toolkit/gptq/runner.py +3 -3
  32. model_compression_toolkit/pruning/__init__.py +1 -0
  33. model_compression_toolkit/pruning/pytorch/__init__.py +14 -0
  34. model_compression_toolkit/pruning/pytorch/pruning_facade.py +166 -0
  35. model_compression_toolkit/ptq/__init__.py +2 -2
  36. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -30
  37. model_compression_toolkit/ptq/pytorch/quantization_facade.py +12 -30
  38. model_compression_toolkit/qat/keras/quantization_facade.py +6 -9
  39. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -7
  40. model_compression_toolkit/core/common/quantization/quantization_params_generation/kmeans_params.py +0 -64
  41. model_compression_toolkit/core/common/quantization/quantizers/kmeans_quantizer.py +0 -53
  42. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
  43. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
  44. {mct_nightly-1.11.0.20240304.post404.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
@@ -44,9 +44,9 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
44
44
  Prunes the entry node of a model in Keras.
45
45
 
46
46
  Args:
47
- node: The entry node to be pruned.
48
- output_mask: A numpy array representing the mask to be applied to the output channels.
49
- fw_info: Framework-specific information object.
47
+ node (BaseNode): The entry node to be pruned.
48
+ output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
49
+ fw_info (FrameworkInfo): Framework-specific information object.
50
50
 
51
51
  """
52
52
  return _prune_keras_edge_node(node=node,
@@ -63,10 +63,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
63
63
  Prunes an intermediate node in a Keras model.
64
64
 
65
65
  Args:
66
- node: The intermediate node to be pruned.
67
- input_mask: A numpy array representing the mask to be applied to the input channels.
68
- output_mask: A numpy array representing the mask to be applied to the output channels.
69
- fw_info: Framework-specific information object.
66
+ node (BaseNode): The intermediate node to be pruned.
67
+ input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
68
+ output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
69
+ fw_info (FrameworkInfo): Framework-specific information object.
70
70
 
71
71
  """
72
72
  _edit_node_input_shape(input_mask, node)
@@ -85,9 +85,9 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
85
85
  Prunes the exit node of a model in Keras.
86
86
 
87
87
  Args:
88
- node: The exit node to be pruned.
89
- input_mask: A numpy array representing the mask to be applied to the input channels.
90
- fw_info: Framework-specific information object.
88
+ node (BaseNode): The exit node to be pruned.
89
+ input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
90
+ fw_info (FrameworkInfo): Framework-specific information object.
91
91
 
92
92
  """
93
93
  return _prune_keras_edge_node(node=node,
@@ -100,10 +100,10 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
100
100
  Determines whether a node is an entry node in a Keras model.
101
101
 
102
102
  Args:
103
- node: The node to be checked.
103
+ node (BaseNode): The node to be checked.
104
104
 
105
105
  Returns:
106
- Boolean indicating if the node is an entry node.
106
+ bool: Boolean indicating if the node is an entry node.
107
107
  """
108
108
  return _is_keras_node_pruning_section_edge(node)
109
109
 
@@ -115,26 +115,26 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
115
115
  Determines whether a node is an exit node in a Keras model.
116
116
 
117
117
  Args:
118
- node: The node to be checked.
119
- corresponding_entry_node: The entry node of the pruning section that is checked.
120
- fw_info: Framework-specific information object.
118
+ node (BaseNode): The node to be checked.
119
+ corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
120
+ fw_info (FrameworkInfo): Framework-specific information object.
121
121
 
122
122
  Returns:
123
- Boolean indicating if the node is an exit node.
123
+ bool: Boolean indicating if the node is an exit node.
124
124
  """
125
125
  return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
126
126
  corresponding_entry_node,
127
127
  fw_info)
128
128
 
129
- def is_node_intermediate_pruning_section(self, node) -> bool:
129
+ def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
130
130
  """
131
131
  Determines whether a node is part of the intermediate section in the pruning process of a Keras model.
132
132
 
133
133
  Args:
134
- node: The node to be checked.
134
+ node (BaseNode): The node to be checked.
135
135
 
136
136
  Returns:
137
- Boolean indicating if the node is part of the intermediate pruning section.
137
+ bool: Boolean indicating if the node is part of the intermediate pruning section.
138
138
  """
139
139
  # Nodes that are not Conv2D, Conv2DTranspose, DepthwiseConv2D, or Dense are considered intermediate.
140
140
  return node.type not in [keras.layers.DepthwiseConv2D,
@@ -62,6 +62,9 @@ DIM = 'dim'
62
62
  IN_CHANNELS = 'in_channels'
63
63
  OUT_CHANNELS = 'out_channels'
64
64
  NUM_FEATURES = 'num_features'
65
+ NUM_PARAMETERS = 'num_parameters'
66
+ IN_FEATURES = 'in_features'
67
+ OUT_FEATURES = 'out_features'
65
68
 
66
69
  # torch devices
67
70
  CUDA = 'cuda'
@@ -22,7 +22,7 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import
22
22
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_data import compute_kpi_data
24
24
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
25
- from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
25
+ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
26
26
  from model_compression_toolkit.constants import FOUND_TORCH
27
27
 
28
28
  if FOUND_TORCH:
@@ -38,7 +38,7 @@ if FOUND_TORCH:
38
38
 
39
39
  def pytorch_kpi_data(in_model: Module,
40
40
  representative_data_gen: Callable,
41
- core_config: CoreConfig = CoreConfig(), # TODO: Why pytorch is initilized and keras not?
41
+ core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
42
42
  fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
43
43
  target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
44
44
  """
@@ -75,9 +75,9 @@ if FOUND_TORCH:
75
75
 
76
76
  """
77
77
 
78
- if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfigV2):
79
- Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfigV2 object."
80
- "Given quant_config is not of type MixedPrecisionQuantizationConfigV2.")
78
+ if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
79
+ Logger.error("KPI data computation can't be executed without MixedPrecisionQuantizationConfig object."
80
+ "Given quant_config is not of type MixedPrecisionQuantizationConfig.")
81
81
 
82
82
  fw_impl = PytorchImplementation()
83
83
 
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,315 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Tuple, Dict
17
+
18
+ from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
19
+ PruningFrameworkImplementation
20
+ from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
21
+ from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
22
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
+ from model_compression_toolkit.core.common import BaseNode
24
+ from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
25
+ IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
26
+ import torch
27
+
28
+ import numpy as np
29
+
30
+ from model_compression_toolkit.logger import Logger
31
+
32
+
33
+ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplementation):
34
+ """
35
+ Implementation of the PruningFramework for the Pytorch framework. This class provides
36
+ concrete implementations of the abstract methods defined in PruningFrameworkImplementation
37
+ for the Pytorch framework.
38
+ """
39
+
40
+ def prune_entry_node(self,
41
+ node: BaseNode,
42
+ output_mask: np.ndarray,
43
+ fw_info: FrameworkInfo):
44
+ """
45
+ Prunes the entry node of a model in Pytorch.
46
+
47
+ Args:
48
+ node (BaseNode): The entry node to be pruned.
49
+ output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
50
+ fw_info (FrameworkInfo): Framework-specific information object.
51
+
52
+ """
53
+ return _prune_pytorch_edge_node(node=node,
54
+ mask=output_mask,
55
+ fw_info=fw_info,
56
+ is_exit_node=False)
57
+
58
+ def prune_intermediate_node(self,
59
+ node: BaseNode,
60
+ input_mask: np.ndarray,
61
+ output_mask: np.ndarray,
62
+ fw_info: FrameworkInfo):
63
+ """
64
+ Prunes an intermediate node in a Pytorch model.
65
+
66
+ Args:
67
+ node (BaseNode): The intermediate node to be pruned.
68
+ input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
69
+ output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
70
+ fw_info (FrameworkInfo): Framework-specific information object.
71
+
72
+ """
73
+ # TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
74
+ pruning_en = True
75
+ _edit_node_input_shape(node, input_mask, fw_info)
76
+ pruned_parameters = {}
77
+ mask_bool = output_mask.astype(bool)
78
+ node.weights = pruned_parameters
79
+ if node.type == torch.nn.BatchNorm2d:
80
+ node.framework_attr[NUM_FEATURES] = int(np.sum(input_mask))
81
+ elif node.type == torch.nn.PReLU:
82
+ if node.framework_attr[NUM_PARAMETERS] > 1:
83
+ node.framework_attr[NUM_PARAMETERS] = int(np.sum(input_mask))
84
+ else:
85
+ pruning_en = False
86
+
87
+ if pruning_en:
88
+ for k, v in node.weights.items():
89
+ # Apply the mask to the weights.
90
+ pruned_parameters[k] = v.compress(mask_bool, axis=-1)
91
+
92
+ def prune_exit_node(self,
93
+ node: BaseNode,
94
+ input_mask: np.ndarray,
95
+ fw_info: FrameworkInfo):
96
+ """
97
+ Prunes the exit node of a model in Pytorch.
98
+
99
+ Args:
100
+ node (BaseNode): The exit node to be pruned.
101
+ input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
102
+ fw_info (FrameworkInfo): Framework-specific information object.
103
+
104
+ """
105
+ return _prune_pytorch_edge_node(node=node,
106
+ mask=input_mask,
107
+ fw_info=fw_info,
108
+ is_exit_node=True)
109
+
110
+ def is_node_entry_node(self, node: BaseNode) -> bool:
111
+ """
112
+ Determines whether a node is an entry node in a Pytorch model.
113
+
114
+ Args:
115
+ node (BaseNode): The node to be checked.
116
+
117
+ Returns:
118
+ bool: Boolean indicating if the node is an entry node.
119
+ """
120
+ return _is_pytorch_node_pruning_section_edge(node)
121
+
122
+ def is_node_exit_node(self,
123
+ node: BaseNode,
124
+ corresponding_entry_node: BaseNode,
125
+ fw_info: FrameworkInfo) -> bool:
126
+ """
127
+ Determines whether a node is an exit node in a Pytorch model.
128
+
129
+ Args:
130
+ node (BaseNode): The node to be checked.
131
+ corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
132
+ fw_info (FrameworkInfo): Framework-specific information object.
133
+
134
+ Returns:
135
+ bool: Boolean indicating if the node is an exit node.
136
+ """
137
+ return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
138
+ corresponding_entry_node,
139
+ fw_info)
140
+
141
+ def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
142
+ """
143
+ Determines whether a node is part of the intermediate section in the pruning process of a Pytorch model.
144
+
145
+ Args:
146
+ node (BaseNode): The node to be checked.
147
+
148
+ Returns:
149
+ bool: Boolean indicating if the node is part of the intermediate pruning section.
150
+ """
151
+ # Nodes that are not Conv2d, ConvTranspose2d, or Linear are considered intermediate.
152
+ # For PReLU prune attributes only if there is a parameter per channel
153
+ return node.type not in [torch.nn.Conv2d,
154
+ torch.nn.ConvTranspose2d,
155
+ torch.nn.Linear]
156
+
157
+ def attrs_oi_channels_info_for_pruning(self,
158
+ node: BaseNode,
159
+ fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
160
+ """
161
+ Retrieves the attributes of a given node along with the output/input (OI) channel axis
162
+ for each attribute used to prune these attributes.
163
+
164
+ Not all attributes of a node are directly associated with both input and output channels.
165
+ For example, bias vectors in convolutional layers are solely related to the number of output
166
+ channels and do not have a corresponding input channel dimension.
167
+ In cases like that, None is returned in the tuple of axis for such attributes.
168
+
169
+ For kernel operations (like convolutions), the function identifies the output and input
170
+ channel axis based on framework-specific information.
171
+ For non-kernel operations, it defaults to setting the last axis as the output
172
+ channel axis, assuming no specific input channel axis.
173
+
174
+ Args:
175
+ node (BaseNode): The node from the computational graph.
176
+ fw_info (FrameworkInfo): Contains framework-specific information and utilities.
177
+
178
+ Returns:
179
+ Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
180
+ and each value is a tuple representing the output and input channel axis indices respectively.
181
+ """
182
+
183
+ attributes_with_axis = {}
184
+ if fw_info.is_kernel_op(node.type):
185
+ kernel_attributes = fw_info.get_kernel_op_attributes(node.type)
186
+ if kernel_attributes is None or len(kernel_attributes) == 0:
187
+ Logger.error(f"Expected to find attributes but found {kernel_attributes}")
188
+
189
+ for attr in kernel_attributes:
190
+ attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
191
+
192
+ # Bias is a vector at the length of the number of output channels.
193
+ # For this reason, input channel axis is irrelevant to the bias attribute.
194
+ attributes_with_axis[BIAS] = (0, None)
195
+ else:
196
+ # We have several assumptions here:
197
+ # 1. For intermediate nodes, we prune all nodes' weights.
198
+ # 2. The output channel axis is the last axis of this attribute.
199
+ # 3. The input channel axis is irrelevant since these attributes are pruned only by
200
+ # their output channels.
201
+ for attr in list(node.weights.keys()):
202
+ # If the number of float parameters is 1 or less, the node has a single
203
+ # parameter shared by all channels. In that case, we don't
204
+ # want to prune the parameter.
205
+ if node.get_num_parameters(fw_info)[1] <= 1:
206
+ attributes_with_axis[attr] = (None, None)
207
+ else:
208
+ attributes_with_axis[attr] = (-1, None)
209
+
210
+ return attributes_with_axis
211
+
212
+
213
+ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
214
+ """
215
+ Determines if a Pytorch node is an edge of a pruning section.
216
+
217
+ In the context of pruning, an 'edge' node is a layer that can potentially be pruned.
218
+ This function identifies such nodes based on their type and attributes. Specifically,
219
+ Conv2d and ConvTranspose2d layers with 'groups' attribute set to 1, and Linear layers
220
+ are considered as edges for pruning sections.
221
+
222
+ Args:
223
+ node (BaseNode): The node to be evaluated.
224
+
225
+ Returns:
226
+ bool: True if the node is an edge of a pruning section, False otherwise.
227
+ """
228
+
229
+ # Check if the node is a Conv2d or ConvTranspose2d layer with groups set to 1.
230
+ if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
231
+ return node.framework_attr[GROUPS] == 1
232
+ return node.type == torch.nn.Linear
233
+
234
+
235
+ def _prune_pytorch_edge_node(node: BaseNode,
236
+ mask: np.ndarray,
237
+ fw_info: FrameworkInfo,
238
+ is_exit_node: bool):
239
+ """
240
+ Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
241
+ This function can handle both entry and exit nodes by specifying the is_exit_node parameter.
242
+
243
+ Args:
244
+ node (BaseNode): The node to be pruned.
245
+ mask (np.ndarray): The pruning mask to be applied.
246
+ fw_info (FrameworkInfo): Framework-specific information object.
247
+ is_exit_node (bool): A boolean indicating whether the node is an exit node.
248
+
249
+ """
250
+
251
+ # Retrieve the kernel attribute and the axes to prune.
252
+ kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
253
+ io_axis = fw_info.kernel_channels_mapping.get(node.type)
254
+ axis_to_prune = io_axis[int(is_exit_node)]
255
+ kernel = node.get_weights_by_keys(kernel_attr)
256
+ # Convert mask to boolean.
257
+ mask_bool = mask.astype(bool)
258
+
259
+ pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
260
+ node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
261
+
262
+ if not is_exit_node and node.framework_attr[BIAS]:
263
+ # Prune the bias if applicable and it's an entry node.
264
+ bias = node.get_weights_by_keys(BIAS)
265
+ pruned_bias = bias.compress(mask_bool)
266
+ node.set_weights_by_keys(name=BIAS, tensor=pruned_bias)
267
+
268
+ if not is_exit_node:
269
+ # Update 'out_channels' or 'out_features' attributes for entry nodes
270
+ # Conv2d,ConvTranspose2d / Linear layers.
271
+ if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
272
+ node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
273
+ elif node.type == torch.nn.Linear:
274
+ node.framework_attr[OUT_FEATURES] = int(np.sum(mask))
275
+ else:
276
+ Logger.exception(f"{node.type} is currently not supported"
277
+ f"as an edge node in a pruning section")
278
+
279
+ if is_exit_node:
280
+ if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
281
+ node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
282
+ elif node.type == torch.nn.Linear:
283
+ node.framework_attr[IN_FEATURES] = int(np.sum(mask))
284
+ else:
285
+ Logger.exception(f"{node.type} is currently not supported"
286
+ f"as an edge node in a pruning section")
287
+ # Adjust the input shape for the last node in the section.
288
+ _edit_node_input_shape(node, mask_bool, fw_info)
289
+
290
+
291
+ def _edit_node_input_shape(node: BaseNode,
292
+ input_mask: np.ndarray,
293
+ fw_info: FrameworkInfo):
294
+ """
295
+ Adjusts the input shape of a node based on the given input mask.
296
+
297
+ This function modifies the input shape of the given node to reflect the pruning
298
+ that has taken place. It updates the channel dimension of the node's input shape
299
+ to match the number of channels that remain after pruning.
300
+
301
+ Args:
302
+ node (BaseNode): The node whose input shape needs to be adjusted.
303
+ input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
304
+ fw_info (FrameworkInfo): Framework-specific information object.
305
+ """
306
+ # Start with the current input shape of the node.
307
+ new_input_shape = list(node.input_shape)
308
+
309
+ # Adjust the channel dimension of the shape to match the number of unpruned (retained) channels.
310
+ # This is done by summing the mask, as each '1' in the mask represents a retained channel.
311
+ channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
312
+ new_input_shape[0][channel_axis] = int(np.sum(input_mask))
313
+
314
+ # Update the node's input shape with the new dimensions.
315
+ node.input_shape = tuple(new_input_shape)
@@ -26,7 +26,7 @@ from torch.nn import Module, Sigmoid, Softmax
26
26
 
27
27
  import model_compression_toolkit.core.pytorch.constants as pytorch_constants
28
28
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
29
- from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfigV2
29
+ from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
30
30
  from model_compression_toolkit.core import common
31
31
  from model_compression_toolkit.core.common import Graph, BaseNode
32
32
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -332,7 +332,7 @@ class PytorchImplementation(FrameworkImplementation):
332
332
 
333
333
  def get_sensitivity_evaluator(self,
334
334
  graph: Graph,
335
- quant_config: MixedPrecisionQuantizationConfigV2,
335
+ quant_config: MixedPrecisionQuantizationConfig,
336
336
  representative_data_gen: Callable,
337
337
  fw_info: FrameworkInfo,
338
338
  disable_activation_for_metric: bool = False,
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2, GPTQHessianScoresConfig
16
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfig, GPTQHessianScoresConfig
17
17
  from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
18
18
  from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
19
19
  from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
@@ -61,8 +61,8 @@ class GradientPTQConfig:
61
61
  """
62
62
  Configuration to use for quantization with GradientPTQ.
63
63
  """
64
-
65
- def __init__(self, n_iter: int,
64
+ def __init__(self,
65
+ n_epochs: int,
66
66
  optimizer: Any,
67
67
  optimizer_rest: Any = None,
68
68
  loss: Callable = None,
@@ -79,7 +79,7 @@ class GradientPTQConfig:
79
79
  Initialize a GradientPTQConfig.
80
80
 
81
81
  Args:
82
- n_iter (int): Number of iterations to train.
82
+ n_epochs (int): Number of representative dataset epochs to train.
83
83
  optimizer (Any): Optimizer to use.
84
84
  optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
85
85
  loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
@@ -96,7 +96,8 @@ class GradientPTQConfig:
96
96
  gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
97
97
 
98
98
  """
99
- self.n_iter = n_iter
99
+
100
+ self.n_epochs = n_epochs
100
101
  self.optimizer = optimizer
101
102
  self.optimizer_rest = optimizer_rest
102
103
  self.loss = loss
@@ -114,71 +115,3 @@ class GradientPTQConfig:
114
115
  else gptq_quantizer_params_override
115
116
 
116
117
 
117
- class GradientPTQConfigV2(GradientPTQConfig):
118
- """
119
- Configuration to use for quantization with GradientPTQV2.
120
- """
121
- def __init__(self, n_epochs: int,
122
- optimizer: Any,
123
- optimizer_rest: Any = None,
124
- loss: Callable = None,
125
- log_function: Callable = None,
126
- train_bias: bool = True,
127
- rounding_type: RoundingType = RoundingType.SoftQuantizer,
128
- use_hessian_based_weights: bool = True,
129
- optimizer_quantization_parameter: Any = None,
130
- optimizer_bias: Any = None,
131
- regularization_factor: float = REG_DEFAULT,
132
- hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(),
133
- gptq_quantizer_params_override: Dict[str, Any] = None):
134
- """
135
- Initialize a GradientPTQConfigV2.
136
-
137
- Args:
138
- n_epochs (int): Number of representative dataset epochs to train.
139
- optimizer (Any): Optimizer to use.
140
- optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
141
- loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
142
- the 3rd is a list of quantized weights, the 4th is a list of float weights, the 5th and 6th lists are the mean and std of the tensors
143
- accordingly. see example in multiple_tensors_mse_loss
144
- log_function (Callable): Function to log information about the GPTQ process.
145
- train_bias (bool): Whether to update the bias during the training or not.
146
- rounding_type (RoundingType): An enum that defines the rounding type.
147
- use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
148
- optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
149
- optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
150
- regularization_factor (float): A floating point number that defines the regularization factor.
151
- hessian_weights_config (GPTQHessianScoresConfig): A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss.
152
- gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
153
-
154
- """
155
-
156
- super().__init__(n_iter=None,
157
- optimizer=optimizer,
158
- optimizer_rest=optimizer_rest,
159
- loss=loss,
160
- log_function=log_function,
161
- train_bias=train_bias,
162
- rounding_type=rounding_type,
163
- use_hessian_based_weights=use_hessian_based_weights,
164
- optimizer_quantization_parameter=optimizer_quantization_parameter,
165
- optimizer_bias=optimizer_bias,
166
- regularization_factor=regularization_factor,
167
- hessian_weights_config=hessian_weights_config,
168
- gptq_quantizer_params_override=gptq_quantizer_params_override)
169
- self.n_epochs = n_epochs
170
-
171
- @classmethod
172
- def from_v1(cls, n_ptq_iter: int, config_v1: GradientPTQConfig):
173
- """
174
- Initialize a GradientPTQConfigV2 from GradientPTQConfig instance.
175
-
176
- Args:
177
- n_ptq_iter (int): Number of PTQ calibration iters (length of representative dataset).
178
- config_v1 (GradientPTQConfig): A GPTQ config to convert to V2.
179
-
180
- """
181
- n_epochs = int(round(config_v1.n_iter) / n_ptq_iter)
182
- v1_params = config_v1.__dict__
183
- v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
184
- return cls(n_epochs, **v1_params)
@@ -37,7 +37,7 @@ else:
37
37
  from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
38
38
  from model_compression_toolkit.core import common
39
39
  from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
40
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
40
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
41
41
  from model_compression_toolkit.core.common import Graph
42
42
  from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
43
43
  from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
@@ -56,7 +56,7 @@ class KerasGPTQTrainer(GPTQTrainer):
56
56
  def __init__(self,
57
57
  graph_float: Graph,
58
58
  graph_quant: Graph,
59
- gptq_config: GradientPTQConfigV2,
59
+ gptq_config: GradientPTQConfig,
60
60
  fw_impl: FrameworkImplementation,
61
61
  fw_info: FrameworkInfo,
62
62
  representative_data_gen: Callable,