mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
- model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
- model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,6 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
|
|
19
19
|
PruningFrameworkImplementation
|
20
20
|
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
|
21
21
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
22
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
22
|
from model_compression_toolkit.core.common import BaseNode
|
24
23
|
from model_compression_toolkit.core.pytorch.constants import BIAS, GROUPS, OUT_CHANNELS, OUT_FEATURES, NUM_FEATURES, \
|
25
24
|
IN_CHANNELS, IN_FEATURES, NUM_PARAMETERS
|
@@ -39,27 +38,23 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
39
38
|
|
40
39
|
def prune_entry_node(self,
|
41
40
|
node: BaseNode,
|
42
|
-
output_mask: np.ndarray,
|
43
|
-
fw_info: FrameworkInfo):
|
41
|
+
output_mask: np.ndarray):
|
44
42
|
"""
|
45
43
|
Prunes the entry node of a model in Pytorch.
|
46
44
|
|
47
45
|
Args:
|
48
46
|
node (BaseNode): The entry node to be pruned.
|
49
47
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
50
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
51
48
|
|
52
49
|
"""
|
53
50
|
return _prune_pytorch_edge_node(node=node,
|
54
51
|
mask=output_mask,
|
55
|
-
fw_info=fw_info,
|
56
52
|
is_exit_node=False)
|
57
53
|
|
58
54
|
def prune_intermediate_node(self,
|
59
55
|
node: BaseNode,
|
60
56
|
input_mask: np.ndarray,
|
61
|
-
output_mask: np.ndarray,
|
62
|
-
fw_info: FrameworkInfo):
|
57
|
+
output_mask: np.ndarray):
|
63
58
|
"""
|
64
59
|
Prunes an intermediate node in a Pytorch model.
|
65
60
|
|
@@ -67,12 +62,11 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
67
62
|
node (BaseNode): The intermediate node to be pruned.
|
68
63
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
69
64
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
70
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
71
65
|
|
72
66
|
"""
|
73
67
|
# TODO (reuvenp/liord): Address handling of node parameters that can be either a single value across all channels or distinct per channel, e.g., PReLU. Consider developing a structured approach.
|
74
68
|
pruning_en = True
|
75
|
-
_edit_node_input_shape(node, input_mask, fw_info)
|
69
|
+
_edit_node_input_shape(node, input_mask)
|
76
70
|
pruned_parameters = {}
|
77
71
|
mask_bool = output_mask.astype(bool)
|
78
72
|
node.weights = pruned_parameters
|
@@ -91,20 +85,17 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
91
85
|
|
92
86
|
def prune_exit_node(self,
|
93
87
|
node: BaseNode,
|
94
|
-
input_mask: np.ndarray,
|
95
|
-
fw_info: FrameworkInfo):
|
88
|
+
input_mask: np.ndarray):
|
96
89
|
"""
|
97
90
|
Prunes the exit node of a model in Pytorch.
|
98
91
|
|
99
92
|
Args:
|
100
93
|
node (BaseNode): The exit node to be pruned.
|
101
94
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
102
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
103
95
|
|
104
96
|
"""
|
105
97
|
return _prune_pytorch_edge_node(node=node,
|
106
98
|
mask=input_mask,
|
107
|
-
fw_info=fw_info,
|
108
99
|
is_exit_node=True)
|
109
100
|
|
110
101
|
def is_node_entry_node(self, node: BaseNode) -> bool:
|
@@ -121,22 +112,19 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
121
112
|
|
122
113
|
def is_node_exit_node(self,
|
123
114
|
node: BaseNode,
|
124
|
-
corresponding_entry_node: BaseNode,
|
125
|
-
fw_info: FrameworkInfo) -> bool:
|
115
|
+
corresponding_entry_node: BaseNode) -> bool:
|
126
116
|
"""
|
127
117
|
Determines whether a node is an exit node in a Pytorch model.
|
128
118
|
|
129
119
|
Args:
|
130
120
|
node (BaseNode): The node to be checked.
|
131
121
|
corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
|
132
|
-
fw_info (FrameworkInfo) Framework-specific information object.
|
133
122
|
|
134
123
|
Returns:
|
135
124
|
bool: Boolean indicating if the node is an exit node.
|
136
125
|
"""
|
137
126
|
return _is_pytorch_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
|
138
|
-
corresponding_entry_node,
|
139
|
-
fw_info)
|
127
|
+
corresponding_entry_node)
|
140
128
|
|
141
129
|
def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
|
142
130
|
"""
|
@@ -155,8 +143,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
155
143
|
torch.nn.Linear]
|
156
144
|
|
157
145
|
def attrs_oi_channels_info_for_pruning(self,
|
158
|
-
node: BaseNode,
|
159
|
-
fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
|
146
|
+
node: BaseNode) -> Dict[str, Tuple[int, int]]:
|
160
147
|
"""
|
161
148
|
Retrieves the attributes of a given node along with the output/input (OI) channel axis
|
162
149
|
for each attribute used to prune these attributes.
|
@@ -173,7 +160,6 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
173
160
|
|
174
161
|
Args:
|
175
162
|
node (BaseNode): The node from the computational graph.
|
176
|
-
fw_info (FrameworkInfo): Contains framework-specific information and utilities.
|
177
163
|
|
178
164
|
Returns:
|
179
165
|
Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'weight' or 'bias')
|
@@ -181,13 +167,8 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
181
167
|
"""
|
182
168
|
|
183
169
|
attributes_with_axis = {}
|
184
|
-
if
|
185
|
-
|
186
|
-
if kernel_attributes is None or len(kernel_attributes) == 0:
|
187
|
-
Logger.critical(f"Expected to find kernel attributes but none were identified for node '{node.name}' of type {node.type}.")
|
188
|
-
|
189
|
-
for attr in kernel_attributes:
|
190
|
-
attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
|
170
|
+
if node.is_kernel_op:
|
171
|
+
attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
|
191
172
|
|
192
173
|
# Bias is a vector at the length of the number of output channels.
|
193
174
|
# For this reason, input channel axis is irrelevant to the bias attribute.
|
@@ -202,7 +183,7 @@ class PruningPytorchImplementation(PytorchImplementation, PruningFrameworkImplem
|
|
202
183
|
# If the number of float parameters is 1 or less - is the case where
|
203
184
|
# we have one parameter for all channels. For this case, we don't
|
204
185
|
# want to prune the parameter.
|
205
|
-
if node.get_num_parameters(fw_info)[1] <= 1:
|
186
|
+
if node.get_num_parameters()[1] <= 1:
|
206
187
|
attributes_with_axis[attr] = (None, None)
|
207
188
|
else:
|
208
189
|
attributes_with_axis[attr] = (-1, None)
|
@@ -234,7 +215,6 @@ def _is_pytorch_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
234
215
|
|
235
216
|
def _prune_pytorch_edge_node(node: BaseNode,
|
236
217
|
mask: np.ndarray,
|
237
|
-
fw_info: FrameworkInfo,
|
238
218
|
is_exit_node: bool):
|
239
219
|
"""
|
240
220
|
Prunes the given Pytorch node by applying the mask to the node's weights (weights and biases).
|
@@ -243,21 +223,18 @@ def _prune_pytorch_edge_node(node: BaseNode,
|
|
243
223
|
Args:
|
244
224
|
node (BaseNode): The node to be pruned.
|
245
225
|
mask (np.ndarray): The pruning mask to be applied.
|
246
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
247
226
|
is_exit_node (bool): A boolean indicating whether the node is an exit node.
|
248
227
|
|
249
228
|
"""
|
250
229
|
|
251
230
|
# Retrieve the kernel attribute and the axes to prune.
|
252
|
-
|
253
|
-
|
254
|
-
axis_to_prune = io_axis[int(is_exit_node)]
|
255
|
-
kernel = node.get_weights_by_keys(kernel_attr)
|
231
|
+
axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
|
232
|
+
kernel = node.get_weights_by_keys(node.kernel_attr)
|
256
233
|
# Convert mask to boolean.
|
257
234
|
mask_bool = mask.astype(bool)
|
258
235
|
|
259
236
|
pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
|
260
|
-
node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
|
237
|
+
node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
|
261
238
|
|
262
239
|
if not is_exit_node and node.framework_attr[BIAS]:
|
263
240
|
# Prune the bias if applicable and it's an entry node.
|
@@ -285,12 +262,11 @@ def _prune_pytorch_edge_node(node: BaseNode,
|
|
285
262
|
Logger.critical(f"{node.type} is currently not supported"
|
286
263
|
f"as an edge node in a pruning section")
|
287
264
|
# Adjust the input shape for the last node in the section.
|
288
|
-
_edit_node_input_shape(node, mask_bool, fw_info)
|
265
|
+
_edit_node_input_shape(node, mask_bool)
|
289
266
|
|
290
267
|
|
291
268
|
def _edit_node_input_shape(node: BaseNode,
|
292
|
-
input_mask: np.ndarray,
|
293
|
-
fw_info: FrameworkInfo):
|
269
|
+
input_mask: np.ndarray):
|
294
270
|
"""
|
295
271
|
Adjusts the input shape of a node based on the given input mask.
|
296
272
|
|
@@ -301,14 +277,13 @@ def _edit_node_input_shape(node: BaseNode,
|
|
301
277
|
Args:
|
302
278
|
node (BaseNode): The node whose input shape needs to be adjusted.
|
303
279
|
input_mask (np.ndarray): A binary array where 1 indicates the channel is kept and 0 means pruned.
|
304
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
305
280
|
"""
|
306
281
|
# Start with the current input shape of the node.
|
307
282
|
new_input_shape = list(node.input_shape)
|
308
283
|
|
309
284
|
# Adjust the last dimension of the shape to match the number of unpruned (retained) channels.
|
310
285
|
# This is done by summing the mask, as each '1' in the mask represents a retained channel.
|
311
|
-
channel_axis =
|
286
|
+
channel_axis = node.out_channel_axis
|
312
287
|
new_input_shape[0][channel_axis] = int(np.sum(input_mask))
|
313
288
|
|
314
289
|
# Update the node's input shape with the new dimensions.
|
@@ -37,7 +37,6 @@ from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
|
37
37
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
|
38
38
|
from model_compression_toolkit.core.pytorch.back2framework import get_pytorch_model_builder
|
39
39
|
from model_compression_toolkit.core.pytorch.data_util import data_gen_to_dataloader
|
40
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
41
40
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_folding import \
|
42
41
|
pytorch_batchnorm_folding, pytorch_batchnorm_forward_folding
|
43
42
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.batchnorm_reconstruction import \
|
@@ -178,7 +177,6 @@ class PytorchImplementation(FrameworkImplementation):
|
|
178
177
|
graph: Graph,
|
179
178
|
mode: ModelBuilderMode,
|
180
179
|
append2output: List[Any] = None,
|
181
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
182
180
|
return_float_outputs: bool = False) -> Tuple:
|
183
181
|
"""
|
184
182
|
Build a Pytorch module from a graph.
|
@@ -189,7 +187,6 @@ class PytorchImplementation(FrameworkImplementation):
|
|
189
187
|
graph: Graph to build the module from it.
|
190
188
|
mode: Mode for how to build the module.
|
191
189
|
append2output: List of Nodes to set as the module's outputs.
|
192
|
-
fw_info: FrameworkInfo object with information about the specific framework's module
|
193
190
|
return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
|
194
191
|
|
195
192
|
Returns:
|
@@ -198,7 +195,6 @@ class PytorchImplementation(FrameworkImplementation):
|
|
198
195
|
pytorch_model_builder = get_pytorch_model_builder(mode)
|
199
196
|
return pytorch_model_builder(graph=graph,
|
200
197
|
append2output=append2output,
|
201
|
-
fw_info=fw_info,
|
202
198
|
return_float_outputs=return_float_outputs).build_model()
|
203
199
|
|
204
200
|
def run_model_inference(self,
|
@@ -232,63 +228,55 @@ class PytorchImplementation(FrameworkImplementation):
|
|
232
228
|
|
233
229
|
def shift_negative_correction(self,
|
234
230
|
graph: Graph,
|
235
|
-
core_config: CoreConfig,
|
236
|
-
fw_info: FrameworkInfo) -> Graph:
|
231
|
+
core_config: CoreConfig) -> Graph:
|
237
232
|
"""
|
238
233
|
Apply shift negative correction (SNC) on a graph.
|
239
234
|
|
240
235
|
Args:
|
241
236
|
graph: Graph to apply SNC on.
|
242
237
|
core_config: Quantization configuration.
|
243
|
-
fw_info: FrameworkInfo object with information about the specific framework's module.
|
244
238
|
|
245
239
|
Returns:
|
246
240
|
Graph after SNC.
|
247
241
|
"""
|
248
242
|
return pytorch_apply_shift_negative_correction(graph,
|
249
|
-
core_config,
|
250
|
-
fw_info)
|
243
|
+
core_config)
|
251
244
|
|
252
245
|
def compute_activation_bias_correction(self,
|
253
246
|
graph: Graph,
|
254
|
-
quant_config: QuantizationConfig,
|
255
|
-
fw_info: FrameworkInfo):
|
247
|
+
quant_config: QuantizationConfig):
|
256
248
|
"""
|
257
249
|
Compute activation bias correction on a graph.
|
258
250
|
|
259
251
|
Args:
|
260
252
|
graph: Graph to apply activation bias correction on.
|
261
253
|
quant_config: QuantizationConfig of how the model should be quantized.
|
262
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
263
254
|
|
264
255
|
Returns:
|
265
256
|
Graph after activation bias correction computing.
|
266
257
|
"""
|
267
258
|
return pytorch_compute_activation_bias_correction_of_graph(graph=graph,
|
268
259
|
quant_config=quant_config,
|
269
|
-
fw_info=fw_info,
|
270
260
|
fw_impl=self)
|
271
261
|
|
272
262
|
def get_substitutions_channel_equalization(self,
|
273
|
-
quant_config: QuantizationConfig,
|
274
|
-
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
263
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
275
264
|
"""
|
276
265
|
Return a list of the framework substitutions used for channel equalization.
|
277
266
|
|
278
267
|
Args:
|
279
268
|
quant_config: QuantizationConfig to determine which substitutions to return.
|
280
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
281
269
|
|
282
270
|
Returns:
|
283
271
|
A list of the framework substitutions used after we collect statistics.
|
284
272
|
"""
|
285
273
|
substitutions_list = []
|
286
274
|
if quant_config.activation_channel_equalization:
|
287
|
-
substitutions_list.extend([ScaleEqualization(quant_config, fw_info),
|
288
|
-
ScaleEqualizationWithPad(quant_config, fw_info)])
|
275
|
+
substitutions_list.extend([ScaleEqualization(quant_config),
|
276
|
+
ScaleEqualizationWithPad(quant_config)])
|
289
277
|
return substitutions_list
|
290
278
|
|
291
|
-
def get_substitutions_prepare_graph(self
|
279
|
+
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
|
292
280
|
"""
|
293
281
|
|
294
282
|
Returns: A list of the framework substitutions used before we collect the prior information.
|
@@ -299,7 +287,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
299
287
|
ScaledDotProductDecomposition(),
|
300
288
|
MatMulDecomposition(),
|
301
289
|
TransformFunctionCallMethod(),
|
302
|
-
FunctionalConvSubstitution(fw_info),
|
290
|
+
FunctionalConvSubstitution(),
|
303
291
|
FunctionalBatchNorm(),
|
304
292
|
FunctionalLayerNorm(),
|
305
293
|
FunctionalLinear(),
|
@@ -401,20 +389,17 @@ class PytorchImplementation(FrameworkImplementation):
|
|
401
389
|
|
402
390
|
def get_node_prior_info(self,
|
403
391
|
node: BaseNode,
|
404
|
-
fw_info: FrameworkInfo,
|
405
392
|
graph: Graph) -> NodePriorInfo:
|
406
393
|
"""
|
407
394
|
Get a NodePriorInfo object for a node that represents a Pytorch layer.
|
408
395
|
Args:
|
409
396
|
node: Node to get its prior info.
|
410
|
-
fw_info: Framework specific information needed to create the prior info of the node.
|
411
397
|
graph: Graph to check the next node type.
|
412
398
|
Returns:
|
413
399
|
NodePriorInfo with information about the node.
|
414
400
|
"""
|
415
401
|
|
416
402
|
return create_node_prior_info(node=node,
|
417
|
-
fw_info=fw_info,
|
418
403
|
graph=graph)
|
419
404
|
|
420
405
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
@@ -476,23 +461,19 @@ class PytorchImplementation(FrameworkImplementation):
|
|
476
461
|
return node.layer_class not in [argmax, softmax, Softmax]
|
477
462
|
|
478
463
|
def get_node_mac_operations(self,
|
479
|
-
node: BaseNode,
|
480
|
-
fw_info: FrameworkInfo) -> float:
|
464
|
+
node: BaseNode) -> float:
|
481
465
|
"""
|
482
466
|
Gets the MAC operation count for a given operation.
|
483
467
|
|
484
468
|
Args:
|
485
469
|
node: A graph node that wraps the operation for which the MAC count is computed.
|
486
|
-
fw_info: FrameworkInfo object with information about the Pytorch model.
|
487
470
|
|
488
471
|
Returns: The MAC count of the operation
|
489
472
|
"""
|
490
|
-
|
491
|
-
if not kernels or kernels[0] is None:
|
473
|
+
if node.kernel_attr is None:
|
492
474
|
return 0
|
493
475
|
|
494
|
-
|
495
|
-
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
|
476
|
+
kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
|
496
477
|
|
497
478
|
if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
|
498
479
|
h, w = node.get_output_shapes_list()[0][-2:]
|
@@ -500,8 +481,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
500
481
|
|
501
482
|
if node.is_match_type(Linear):
|
502
483
|
# IN * OUT * (all previous dims[:-1])
|
503
|
-
|
504
|
-
return node.get_total_output_params() * kernel_shape[input_channel_axis]
|
484
|
+
return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
|
505
485
|
|
506
486
|
return 0
|
507
487
|
|
@@ -23,23 +23,19 @@ from model_compression_toolkit.core.pytorch.constants import MOVING_MEAN, MOVING
|
|
23
23
|
|
24
24
|
|
25
25
|
def create_node_prior_info(node: BaseNode,
|
26
|
-
fw_info: FrameworkInfo,
|
27
26
|
graph: Graph):
|
28
27
|
"""
|
29
28
|
Create a NodePriorInfo object for a given node.
|
30
29
|
|
31
30
|
Args:
|
32
31
|
node: Node to create its prior info.
|
33
|
-
fw_info: Information about a specific framework the node was generated from.
|
34
32
|
graph: Graph to check the next node type.
|
35
33
|
|
36
34
|
Returns:
|
37
35
|
NodePriorInfo object with info about the node.
|
38
36
|
"""
|
39
37
|
|
40
|
-
min_output, max_output =
|
41
|
-
if fw_info.layers_has_min_max(node.type):
|
42
|
-
min_output, max_output = fw_info.layer_min_max_mapping[node.type]
|
38
|
+
min_output, max_output = node.minmax
|
43
39
|
mean_output, std_output = _get_mean_std_outputs(node=node,
|
44
40
|
graph=graph)
|
45
41
|
return NodePriorInfo(min_output=min_output,
|
@@ -27,7 +27,7 @@ from model_compression_toolkit.target_platform_capabilities.tpc_io_handler impor
|
|
27
27
|
from model_compression_toolkit.verify_packages import FOUND_TORCH
|
28
28
|
|
29
29
|
if FOUND_TORCH:
|
30
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
30
|
+
from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
|
31
31
|
from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
|
32
32
|
from torch.nn import Module
|
33
33
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
|
@@ -38,6 +38,7 @@ if FOUND_TORCH:
|
|
38
38
|
PYTORCH_DEFAULT_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
|
39
39
|
|
40
40
|
|
41
|
+
@set_pytorch_info
|
41
42
|
def pytorch_resource_utilization_data(in_model: Module,
|
42
43
|
representative_data_gen: Callable,
|
43
44
|
core_config: CoreConfig = CoreConfig(),
|
@@ -93,7 +94,6 @@ if FOUND_TORCH:
|
|
93
94
|
representative_data_gen,
|
94
95
|
core_config,
|
95
96
|
target_platform_capabilities,
|
96
|
-
DEFAULT_PYTORCH_INFO,
|
97
97
|
fw_impl)
|
98
98
|
|
99
99
|
else:
|
@@ -33,7 +33,6 @@ def activation_bias_correction_node_matchers():
|
|
33
33
|
|
34
34
|
def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
35
35
|
quant_config: QuantizationConfig,
|
36
|
-
fw_info: FrameworkInfo,
|
37
36
|
fw_impl: FrameworkImplementation) -> Graph:
|
38
37
|
"""
|
39
38
|
Compute the activation bias correction term for graph based on a PyTorch model.
|
@@ -41,7 +40,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
41
40
|
Args:
|
42
41
|
graph: Graph with nodes to compute the activation bias correction.
|
43
42
|
quant_config: QuantizationConfig of how the model should be quantized.
|
44
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
45
43
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
46
44
|
|
47
45
|
Returns:
|
@@ -49,7 +47,6 @@ def pytorch_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
49
47
|
"""
|
50
48
|
graph = compute_activation_bias_correction_of_graph(graph=graph,
|
51
49
|
quant_config=quant_config,
|
52
|
-
fw_info=fw_info,
|
53
50
|
fw_impl=fw_impl,
|
54
51
|
activation_bias_correction_node_matchers=
|
55
52
|
activation_bias_correction_node_matchers,
|
@@ -37,7 +37,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
|
|
37
37
|
def quantization_preparation_runner(graph: Graph,
|
38
38
|
representative_data_gen: Callable,
|
39
39
|
core_config: CoreConfig,
|
40
|
-
fw_info: FrameworkInfo,
|
41
40
|
fw_impl: FrameworkImplementation,
|
42
41
|
tb_w: TensorboardWriter = None,
|
43
42
|
hessian_info_service: HessianInfoService = None, ) -> Graph:
|
@@ -53,8 +52,6 @@ def quantization_preparation_runner(graph: Graph,
|
|
53
52
|
graph: A graph representation of the model to be quantized.
|
54
53
|
representative_data_gen: Dataset used for calibration.
|
55
54
|
core_config: CoreConfig containing parameters of how the model should be quantized
|
56
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
57
|
-
groups of layers by how they should be quantized, etc.).
|
58
55
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
59
56
|
tb_w: TensorboardWriter object for logging
|
60
57
|
hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
|
@@ -68,7 +65,6 @@ def quantization_preparation_runner(graph: Graph,
|
|
68
65
|
######################################
|
69
66
|
mi = ModelCollector(graph,
|
70
67
|
fw_impl,
|
71
|
-
fw_info,
|
72
68
|
hessian_info_service,
|
73
69
|
core_config.quantization_config) # Mark points for statistics collection
|
74
70
|
|
@@ -85,7 +81,7 @@ def quantization_preparation_runner(graph: Graph,
|
|
85
81
|
# Notice that not all actions affect at this stage (for example, actions that edit the final configuration as
|
86
82
|
# there are no final configurations at this stage of the optimization). For this reason we edit the graph
|
87
83
|
# again at the end of the optimization process.
|
88
|
-
edit_network_graph(graph, fw_info, core_config.debug_config.network_editor)
|
84
|
+
edit_network_graph(graph, core_config.debug_config.network_editor)
|
89
85
|
|
90
86
|
######################################
|
91
87
|
# Calculate quantization params
|
@@ -109,8 +105,7 @@ def quantization_preparation_runner(graph: Graph,
|
|
109
105
|
######################################
|
110
106
|
if core_config.quantization_config.shift_negative_activation_correction:
|
111
107
|
transformed_graph = fw_impl.shift_negative_correction(transformed_graph,
|
112
|
-
core_config,
|
113
|
-
fw_info)
|
108
|
+
core_config)
|
114
109
|
if tb_w is not None:
|
115
110
|
tb_w.add_graph(transformed_graph, 'after_shift_negative_correction')
|
116
111
|
tb_w.add_all_statistics(transformed_graph, 'after_shift_negative_correction')
|
@@ -122,9 +117,9 @@ def quantization_preparation_runner(graph: Graph,
|
|
122
117
|
######################################
|
123
118
|
# Statistics Correction
|
124
119
|
######################################
|
125
|
-
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_info, fw_impl, tb_w)
|
120
|
+
tg_with_bias = statistics_correction_runner(transformed_graph, core_config, fw_impl, tb_w)
|
126
121
|
|
127
122
|
for n in tg_with_bias.nodes:
|
128
123
|
assert n.final_weights_quantization_cfg is None
|
129
124
|
|
130
|
-
return tg_with_bias
|
125
|
+
return tg_with_bias
|
@@ -16,7 +16,6 @@
|
|
16
16
|
import copy
|
17
17
|
from typing import Callable, Any, List, Optional
|
18
18
|
|
19
|
-
from model_compression_toolkit.core.common import FrameworkInfo
|
20
19
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
20
|
from model_compression_toolkit.core.common.fusion.graph_fuser import GraphFuser
|
22
21
|
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
@@ -46,7 +45,6 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
|
|
46
45
|
def core_runner(in_model: Any,
|
47
46
|
representative_data_gen: Callable,
|
48
47
|
core_config: CoreConfig,
|
49
|
-
fw_info: FrameworkInfo,
|
50
48
|
fw_impl: FrameworkImplementation,
|
51
49
|
fqc: FrameworkQuantizationCapabilities,
|
52
50
|
target_resource_utilization: ResourceUtilization = None,
|
@@ -65,7 +63,6 @@ def core_runner(in_model: Any,
|
|
65
63
|
in_model: Model to quantize.
|
66
64
|
representative_data_gen: Dataset used for calibration.
|
67
65
|
core_config: CoreConfig containing parameters of how the model should be quantized
|
68
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
69
66
|
groups of layers by how they should be quantized, etc.).
|
70
67
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
71
68
|
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
|
@@ -99,7 +96,6 @@ def core_runner(in_model: Any,
|
|
99
96
|
graph = graph_preparation_runner(in_model,
|
100
97
|
representative_data_gen,
|
101
98
|
core_config.quantization_config,
|
102
|
-
fw_info,
|
103
99
|
fw_impl,
|
104
100
|
fqc,
|
105
101
|
core_config.bit_width_config,
|
@@ -112,7 +108,6 @@ def core_runner(in_model: Any,
|
|
112
108
|
tg = quantization_preparation_runner(graph=graph,
|
113
109
|
representative_data_gen=representative_data_gen,
|
114
110
|
core_config=core_config,
|
115
|
-
fw_info=fw_info,
|
116
111
|
fw_impl=fw_impl,
|
117
112
|
tb_w=tb_w,
|
118
113
|
hessian_info_service=hessian_info_service)
|
@@ -123,9 +118,8 @@ def core_runner(in_model: Any,
|
|
123
118
|
if core_config.is_mixed_precision_enabled:
|
124
119
|
if core_config.mixed_precision_config.configuration_overwrite is None:
|
125
120
|
|
126
|
-
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fw_info, fqc)
|
121
|
+
filter_candidates_for_mixed_precision(graph, target_resource_utilization, fqc)
|
127
122
|
bit_widths_config = search_bit_width(tg,
|
128
|
-
fw_info,
|
129
123
|
fw_impl,
|
130
124
|
target_resource_utilization,
|
131
125
|
core_config.mixed_precision_config,
|
@@ -153,22 +147,20 @@ def core_runner(in_model: Any,
|
|
153
147
|
######################################
|
154
148
|
if core_config.quantization_config.activation_bias_correction:
|
155
149
|
tg = fw_impl.compute_activation_bias_correction(graph=tg,
|
156
|
-
quant_config=core_config.quantization_config,
|
157
|
-
fw_info=fw_info)
|
150
|
+
quant_config=core_config.quantization_config)
|
158
151
|
|
159
152
|
# Edit the graph again after finalizing the configurations.
|
160
153
|
# This is since some actions regard the final configuration and should be edited.
|
161
|
-
edit_network_graph(tg, fw_info, core_config.debug_config.network_editor)
|
154
|
+
edit_network_graph(tg, core_config.debug_config.network_editor)
|
162
155
|
|
163
156
|
_set_final_resource_utilization(graph=tg,
|
164
157
|
final_bit_widths_config=bit_widths_config,
|
165
158
|
target_resource_utilization=target_resource_utilization,
|
166
|
-
fw_info=fw_info,
|
167
159
|
fw_impl=fw_impl)
|
168
160
|
|
169
161
|
if core_config.is_mixed_precision_enabled:
|
170
162
|
# Retrieve lists of tuples (node, node's final weights/activation bitwidth)
|
171
|
-
weights_conf_nodes_bitwidth = tg.get_final_weights_config(fw_info)
|
163
|
+
weights_conf_nodes_bitwidth = tg.get_final_weights_config()
|
172
164
|
activation_conf_nodes_bitwidth = tg.get_final_activation_config()
|
173
165
|
|
174
166
|
if len(weights_conf_nodes_bitwidth) > 0:
|
@@ -200,7 +192,6 @@ def core_runner(in_model: Any,
|
|
200
192
|
def _set_final_resource_utilization(graph: Graph,
|
201
193
|
final_bit_widths_config: List[int],
|
202
194
|
target_resource_utilization: Optional[ResourceUtilization],
|
203
|
-
fw_info: FrameworkInfo,
|
204
195
|
fw_impl: FrameworkImplementation):
|
205
196
|
"""
|
206
197
|
Computing the resource utilization of the model according to the final bit-width configuration,
|
@@ -210,14 +201,13 @@ def _set_final_resource_utilization(graph: Graph,
|
|
210
201
|
graph: Graph to compute the resource utilization for.
|
211
202
|
final_bit_widths_config: The final bit-width configuration to quantize the model accordingly.
|
212
203
|
target_resource_utilization: Requested target resource utilization if relevant.
|
213
|
-
fw_info: A FrameworkInfo object.
|
214
204
|
fw_impl: FrameworkImplementation object with specific framework methods implementation.
|
215
205
|
|
216
206
|
"""
|
217
207
|
ru_targets = target_resource_utilization.get_restricted_targets() if target_resource_utilization else None
|
218
208
|
final_ru = None
|
219
209
|
if ru_targets:
|
220
|
-
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
|
210
|
+
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
|
221
211
|
w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
|
222
212
|
a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
|
223
213
|
final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused,
|
@@ -56,7 +56,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
56
56
|
super(ReduceLROnPlateau, self).__init__()
|
57
57
|
|
58
58
|
if factor >= 1.0:
|
59
|
-
Logger.critical('Factor should be < 1.0.')
|
59
|
+
Logger.critical('Factor should be < 1.0.') # pragma: no cover
|
60
60
|
self.factor = factor
|
61
61
|
|
62
62
|
self.optimizer = optimizer
|
@@ -101,7 +101,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
101
101
|
else:
|
102
102
|
self.num_bad_epochs += 1
|
103
103
|
|
104
|
-
if self.in_cooldown:
|
104
|
+
if self.in_cooldown: # pragma: no cover
|
105
105
|
self.cooldown_counter -= 1
|
106
106
|
self.num_bad_epochs = 0 # Ignore any bad epochs in cooldown
|
107
107
|
|
@@ -122,7 +122,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
122
122
|
new_lr = max(old_lr * self.factor, self.min_lr)
|
123
123
|
if old_lr - new_lr > self.eps:
|
124
124
|
tf.keras.backend.set_value(self.optimizer.learning_rate, new_lr)
|
125
|
-
if self.verbose:
|
125
|
+
if self.verbose: # pragma: no cover
|
126
126
|
print(f'Epoch {epoch:05d}: reducing learning rate to {new_lr:.4e}.')
|
127
127
|
|
128
128
|
@property
|
@@ -152,13 +152,13 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
152
152
|
if self.mode == 'min' and self.threshold_mode == 'rel':
|
153
153
|
rel_epsilon = 1. - self.threshold
|
154
154
|
return a < best * rel_epsilon
|
155
|
-
elif self.mode == 'min' and self.threshold_mode == 'abs':
|
155
|
+
elif self.mode == 'min' and self.threshold_mode == 'abs': # pragma: no cover
|
156
156
|
return a < best - self.threshold
|
157
|
-
elif self.mode == 'max' and self.threshold_mode == 'rel':
|
157
|
+
elif self.mode == 'max' and self.threshold_mode == 'rel': # pragma: no cover
|
158
158
|
rel_epsilon = self.threshold + 1.
|
159
159
|
return a > best * rel_epsilon
|
160
160
|
else: # mode == 'max' and threshold_mode == 'abs':
|
161
|
-
return a > best + self.threshold
|
161
|
+
return a > best + self.threshold # pragma: no cover
|
162
162
|
|
163
163
|
def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
|
164
164
|
"""
|
@@ -186,7 +186,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
186
186
|
self.threshold = threshold
|
187
187
|
self.threshold_mode = threshold_mode
|
188
188
|
|
189
|
-
def get_config(self) -> Dict:
|
189
|
+
def get_config(self) -> Dict: # pragma: no cover
|
190
190
|
"""
|
191
191
|
Return the configuration of the scheduler as a dictionary.
|
192
192
|
|
@@ -207,7 +207,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
|
|
207
207
|
base_config = super(ReduceLROnPlateau, self).get_config()
|
208
208
|
return {**base_config, **config}
|
209
209
|
|
210
|
-
def set_config(self, config: Dict) -> None:
|
210
|
+
def set_config(self, config: Dict) -> None: # pragma: no cover
|
211
211
|
"""
|
212
212
|
Set the configuration of the scheduler from a dictionary.
|
213
213
|
|