mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED
@@ -227,15 +227,13 @@ def is_padding_node_and_node_has_padding(pad_node_to_consider: BaseNode,
|
|
227
227
|
|
228
228
|
|
229
229
|
def keras_apply_shift_negative_correction(graph: Graph,
|
230
|
-
core_config: CoreConfig
|
231
|
-
fw_info: FrameworkInfo) -> Graph:
|
230
|
+
core_config: CoreConfig) -> Graph:
|
232
231
|
"""
|
233
232
|
Apply shift negative correction (SNC) on a graph built from a Keras model.
|
234
233
|
|
235
234
|
Args:
|
236
235
|
graph: Graph to apply SNC on.
|
237
236
|
core_config: Quantization configuration.
|
238
|
-
fw_info: FrameworkInfo object with information about the specific framework's module.
|
239
237
|
|
240
238
|
Returns:
|
241
239
|
Graph after SNC.
|
@@ -244,7 +242,6 @@ def keras_apply_shift_negative_correction(graph: Graph,
|
|
244
242
|
|
245
243
|
return apply_shift_negative_correction(graph,
|
246
244
|
core_config,
|
247
|
-
fw_info,
|
248
245
|
snc_node,
|
249
246
|
linear_node,
|
250
247
|
bypass_node,
|
@@ -22,7 +22,6 @@ from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS, MIN_HESS
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
23
23
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
|
24
24
|
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
|
25
|
-
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
26
25
|
from model_compression_toolkit.core.keras.hessian.hessian_scores_calculator_keras import HessianScoresCalculatorKeras
|
27
26
|
from model_compression_toolkit.logger import Logger
|
28
27
|
|
@@ -95,20 +94,11 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
|
|
95
94
|
for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
|
96
95
|
|
97
96
|
# Check if the target node's layer type is supported.
|
98
|
-
if not
|
97
|
+
if not ipt_node.is_kernel_op:
|
99
98
|
Logger.critical(f"Hessian information with respect to weights is not supported for "
|
100
99
|
f"{ipt_node.type} layers.") # pragma: no cover
|
101
100
|
|
102
|
-
|
103
|
-
weight_attributes = DEFAULT_KERAS_INFO.get_kernel_op_attributes(ipt_node.type)
|
104
|
-
|
105
|
-
# Get the weight tensor for the target node
|
106
|
-
if len(weight_attributes) != 1: # pragma: no cover
|
107
|
-
Logger.critical(
|
108
|
-
f"Hessian-based scoring with respect to weights is currently supported only for nodes with "
|
109
|
-
f"a single weight attribute. Found {len(weight_attributes)} attributes.")
|
110
|
-
|
111
|
-
weight_tensor = getattr(model.get_layer(ipt_node.name), weight_attributes[0])
|
101
|
+
weight_tensor = getattr(model.get_layer(ipt_node.name), ipt_node.kernel_attr)
|
112
102
|
|
113
103
|
if j == 0:
|
114
104
|
# On the first iteration we store the weight_tensor shape for later reshaping the results
|
@@ -116,7 +106,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
|
|
116
106
|
tensors_original_shape.append(weight_tensor.shape)
|
117
107
|
|
118
108
|
# Get the output channel index (needed for HessianInfoGranularity.PER_OUTPUT_CHANNEL case)
|
119
|
-
output_channel_axis
|
109
|
+
output_channel_axis = ipt_node.channel_axis.output
|
120
110
|
|
121
111
|
# Get number of scores that should be calculated by the granularity.
|
122
112
|
num_of_scores = self._get_num_scores_by_granularity(weight_tensor,
|
@@ -65,7 +65,6 @@ from model_compression_toolkit.core.common import Graph, BaseNode
|
|
65
65
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
66
66
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
67
67
|
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
|
68
|
-
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
69
68
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
|
70
69
|
ActivationDecomposition
|
71
70
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
|
@@ -175,18 +174,16 @@ class KerasImplementation(FrameworkImplementation):
|
|
175
174
|
graph: Graph,
|
176
175
|
mode: ModelBuilderMode,
|
177
176
|
append2output: List[Any] = None,
|
178
|
-
fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
|
179
177
|
return_float_outputs: bool = False) -> Tuple:
|
180
178
|
"""
|
181
179
|
Build a Keras model from a graph.
|
182
|
-
The mode determines how the model should be
|
180
|
+
The mode determines how the model should be built. append2output is a list of Nodes
|
183
181
|
to set as the model outputs.
|
184
182
|
|
185
183
|
Args:
|
186
184
|
graph: Graph to build the model from it.
|
187
185
|
mode: Mode for how to build the model.
|
188
186
|
append2output: List of Nodes to set as the model's outputs.
|
189
|
-
fw_info: FrameworkInfo object with information about the specific framework's model
|
190
187
|
return_float_outputs (bool): whether to return outputs before or after quantization nodes (default)
|
191
188
|
Returns:
|
192
189
|
A tuple with the model and additional relevant supporting objects.
|
@@ -195,7 +192,6 @@ class KerasImplementation(FrameworkImplementation):
|
|
195
192
|
keras_model_builder = get_keras_model_builder(mode)
|
196
193
|
return keras_model_builder(graph=graph,
|
197
194
|
append2output=append2output,
|
198
|
-
fw_info=fw_info,
|
199
195
|
return_float_outputs=return_float_outputs).build_model()
|
200
196
|
|
201
197
|
def run_model_inference(self,
|
@@ -227,65 +223,57 @@ class KerasImplementation(FrameworkImplementation):
|
|
227
223
|
|
228
224
|
def shift_negative_correction(self,
|
229
225
|
graph: Graph,
|
230
|
-
core_config: CoreConfig
|
231
|
-
fw_info: FrameworkInfo) -> Graph:
|
226
|
+
core_config: CoreConfig) -> Graph:
|
232
227
|
"""
|
233
228
|
Apply shift negative correction (SNC) on a graph.
|
234
229
|
|
235
230
|
Args:
|
236
231
|
graph: Graph to apply SNC on.
|
237
232
|
core_config: Quantization configuration.
|
238
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
239
233
|
|
240
234
|
Returns:
|
241
235
|
Graph after SNC.
|
242
236
|
"""
|
243
237
|
return keras_apply_shift_negative_correction(graph,
|
244
|
-
core_config
|
245
|
-
fw_info)
|
238
|
+
core_config)
|
246
239
|
|
247
240
|
def compute_activation_bias_correction(self,
|
248
241
|
graph: Graph,
|
249
|
-
quant_config: QuantizationConfig
|
250
|
-
fw_info: FrameworkInfo):
|
242
|
+
quant_config: QuantizationConfig):
|
251
243
|
"""
|
252
244
|
Compute activation bias correction on a graph.
|
253
245
|
|
254
246
|
Args:
|
255
247
|
graph: Graph to apply activation bias correction on.
|
256
248
|
quant_config: QuantizationConfig of how the model should be quantized.
|
257
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
258
249
|
|
259
250
|
Returns:
|
260
251
|
Graph after activation bias correction computing.
|
261
252
|
"""
|
262
253
|
return keras_compute_activation_bias_correction_of_graph(graph=graph,
|
263
254
|
quant_config=quant_config,
|
264
|
-
fw_info=fw_info,
|
265
255
|
fw_impl=self)
|
266
256
|
|
267
257
|
def get_substitutions_channel_equalization(self,
|
268
|
-
quant_config: QuantizationConfig
|
269
|
-
fw_info: FrameworkInfo) -> List[common.BaseSubstitution]:
|
258
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
270
259
|
"""
|
271
260
|
Return a list of the framework substitutions used for channel equalization.
|
272
261
|
|
273
262
|
Args:
|
274
263
|
quant_config: QuantizationConfig to determine which substitutions to return.
|
275
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
276
264
|
|
277
265
|
Returns:
|
278
266
|
A list of the framework substitutions used after we collect statistics.
|
279
267
|
"""
|
280
268
|
substitutions_list = []
|
281
269
|
if quant_config.activation_channel_equalization:
|
282
|
-
substitutions_list.extend([ScaleEqualization(quant_config
|
283
|
-
ScaleEqualizationWithPad(quant_config
|
284
|
-
ScaleEqualizationMidActivation(quant_config
|
285
|
-
ScaleEqualizationMidActivationWithPad(quant_config
|
270
|
+
substitutions_list.extend([ScaleEqualization(quant_config),
|
271
|
+
ScaleEqualizationWithPad(quant_config),
|
272
|
+
ScaleEqualizationMidActivation(quant_config),
|
273
|
+
ScaleEqualizationMidActivationWithPad(quant_config)])
|
286
274
|
return substitutions_list
|
287
275
|
|
288
|
-
def get_substitutions_prepare_graph(self
|
276
|
+
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
|
289
277
|
"""
|
290
278
|
|
291
279
|
Returns: A list of the framework substitutions used to prepare the graph.
|
@@ -402,22 +390,19 @@ class KerasImplementation(FrameworkImplementation):
|
|
402
390
|
|
403
391
|
def get_node_prior_info(self,
|
404
392
|
node: BaseNode,
|
405
|
-
fw_info: FrameworkInfo,
|
406
393
|
graph: Graph) -> NodePriorInfo:
|
407
394
|
"""
|
408
395
|
Get a NodePriorInfo object for a node that represents a Keras layer.
|
409
396
|
|
410
397
|
Args:
|
411
398
|
node: Node to get its prior info.
|
412
|
-
fw_info: Framework specific information needed to create the prior info of the node.
|
413
399
|
graph: Graph to check the next node type.
|
414
400
|
|
415
401
|
Returns:
|
416
402
|
NodePriorInfo with information about the node.
|
417
403
|
"""
|
418
404
|
|
419
|
-
return create_node_prior_info(node=node,
|
420
|
-
fw_info=fw_info, graph=graph)
|
405
|
+
return create_node_prior_info(node=node, graph=graph)
|
421
406
|
|
422
407
|
def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
|
423
408
|
"""
|
@@ -530,23 +515,19 @@ class KerasImplementation(FrameworkImplementation):
|
|
530
515
|
return True
|
531
516
|
|
532
517
|
def get_node_mac_operations(self,
|
533
|
-
node: BaseNode
|
534
|
-
fw_info: FrameworkInfo) -> float:
|
518
|
+
node: BaseNode) -> float:
|
535
519
|
"""
|
536
520
|
Gets the MAC operation count for a given operation.
|
537
521
|
|
538
522
|
Args:
|
539
523
|
node: A graph node that wraps the operation for which the MAC count is computed.
|
540
|
-
fw_info: FrameworkInfo object with information about the Keras model.
|
541
524
|
|
542
525
|
Returns: The MAC count og the operation
|
543
526
|
"""
|
544
|
-
|
545
|
-
if not kernels or kernels[0] is None:
|
527
|
+
if node.kernel_attr is None:
|
546
528
|
return 0
|
547
529
|
|
548
|
-
|
549
|
-
kernel_shape = node.get_weights_by_keys(kernels[0]).shape
|
530
|
+
kernel_shape = node.get_weights_by_keys(node.kernel_attr).shape
|
550
531
|
|
551
532
|
if node.is_match_type(Conv2D) or node.is_match_type(Conv2DTranspose) or node.is_match_type(DepthwiseConv2D):
|
552
533
|
h, w = node.get_output_shapes_list()[0][-3:-1]
|
@@ -554,8 +535,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
554
535
|
|
555
536
|
if node.is_match_type(Dense):
|
556
537
|
# IN * OUT * (all previous dims[:-1])
|
557
|
-
|
558
|
-
return node.get_total_output_params() * kernel_shape[input_channel_axis]
|
538
|
+
return node.get_total_output_params() * kernel_shape[node.channel_axis.input]
|
559
539
|
|
560
540
|
return 0
|
561
541
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from tensorflow.keras.models import Model
|
2
2
|
|
3
|
-
from model_compression_toolkit.core import
|
3
|
+
from model_compression_toolkit.core.common.framework_info import get_fw_info
|
4
4
|
from model_compression_toolkit.core.common.framework_info import ChannelAxis
|
5
5
|
from model_compression_toolkit.core.common.model_validation import ModelValidation
|
6
6
|
from model_compression_toolkit.core.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
|
@@ -11,17 +11,15 @@ class KerasModelValidation(ModelValidation):
|
|
11
11
|
Class to define validation methods in order to validate the received Keras model to quantize.
|
12
12
|
"""
|
13
13
|
|
14
|
-
def __init__(self, model: Model
|
14
|
+
def __init__(self, model: Model):
|
15
15
|
"""
|
16
16
|
Initialize a KerasModelValidation object.
|
17
17
|
|
18
18
|
Args:
|
19
19
|
model: Keras model to check its validity.
|
20
|
-
fw_info: Information about the framework of the model (Keras).
|
21
20
|
"""
|
22
21
|
|
23
|
-
super(KerasModelValidation, self).__init__(model=model
|
24
|
-
fw_info=fw_info)
|
22
|
+
super(KerasModelValidation, self).__init__(model=model)
|
25
23
|
|
26
24
|
def validate_output_channel_consistency(self):
|
27
25
|
"""
|
@@ -30,9 +28,10 @@ class KerasModelValidation(ModelValidation):
|
|
30
28
|
If the model has layers with different output channels index, an exception is thrown.
|
31
29
|
|
32
30
|
"""
|
31
|
+
fw_info = get_fw_info()
|
33
32
|
for layer in self.model.layers:
|
34
33
|
data_format = layer.get_config().get(CHANNELS_FORMAT)
|
35
34
|
if data_format is not None:
|
36
|
-
assert (data_format == CHANNELS_FORMAT_LAST and
|
37
|
-
or data_format == CHANNELS_FORMAT_FIRST and
|
35
|
+
assert (data_format == CHANNELS_FORMAT_LAST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NHWC.value
|
36
|
+
or data_format == CHANNELS_FORMAT_FIRST and fw_info.get_out_channel_axis(layer) == ChannelAxis.NCHW.value), \
|
38
37
|
f'Model can not have layers with different data formats.'
|
@@ -17,22 +17,19 @@ from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
17
17
|
|
18
18
|
|
19
19
|
def create_node_prior_info(node: BaseNode,
|
20
|
-
fw_info: FrameworkInfo,
|
21
20
|
graph: Graph):
|
22
21
|
"""
|
23
22
|
Create a NodePriorInfo object for a given node.
|
24
23
|
|
25
24
|
Args:
|
26
25
|
node: Node to create its prior info.
|
27
|
-
fw_info: Information about a specific framework the node was generated from.
|
28
26
|
graph: Graph to check the next node type.
|
29
27
|
|
30
28
|
Returns:
|
31
29
|
NodePriorInfo object with info about the node.
|
32
30
|
"""
|
33
31
|
|
34
|
-
min_output, max_output = _get_min_max_outputs(node=node
|
35
|
-
fw_info=fw_info)
|
32
|
+
min_output, max_output = _get_min_max_outputs(node=node)
|
36
33
|
|
37
34
|
mean_output, std_output = _get_mean_std_outputs(node=node,
|
38
35
|
graph=graph)
|
@@ -42,14 +39,12 @@ def create_node_prior_info(node: BaseNode,
|
|
42
39
|
std_output=std_output)
|
43
40
|
|
44
41
|
|
45
|
-
def _get_min_max_outputs(node: BaseNode,
|
46
|
-
fw_info: FrameworkInfo) -> Tuple[Any, Any]:
|
42
|
+
def _get_min_max_outputs(node: BaseNode) -> Tuple[Any, Any]:
|
47
43
|
"""
|
48
44
|
Return the min/max output values of a node if known.
|
49
45
|
If one of them (or both of them) is unknown - return None instead of a value.
|
50
46
|
Args:
|
51
47
|
node: Node to create its prior info.
|
52
|
-
fw_info: Information about a specific framework the node was generated from.
|
53
48
|
|
54
49
|
Returns:
|
55
50
|
Min/max output values if known.
|
@@ -58,12 +53,8 @@ def _get_min_max_outputs(node: BaseNode,
|
|
58
53
|
|
59
54
|
if node.is_match_type(ReLU):
|
60
55
|
min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
|
61
|
-
|
62
|
-
|
63
|
-
min_output, max_output = fw_info.layer_min_max_mapping[node.type]
|
64
|
-
|
65
|
-
elif node.is_match_type(Activation) and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
|
66
|
-
min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
|
56
|
+
else:
|
57
|
+
min_output, max_output = node.minmax
|
67
58
|
|
68
59
|
return min_output, max_output
|
69
60
|
|
@@ -19,7 +19,6 @@ from model_compression_toolkit.core.common.pruning.pruning_framework_implementat
|
|
19
19
|
PruningFrameworkImplementation
|
20
20
|
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
|
21
21
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
22
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
23
22
|
from model_compression_toolkit.core.common import BaseNode
|
24
23
|
from model_compression_toolkit.core.keras.constants import BIAS, GROUPS, FILTERS, UNITS, USE_BIAS
|
25
24
|
import keras
|
@@ -38,27 +37,23 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
38
37
|
|
39
38
|
def prune_entry_node(self,
|
40
39
|
node: BaseNode,
|
41
|
-
output_mask: np.ndarray
|
42
|
-
fw_info: FrameworkInfo):
|
40
|
+
output_mask: np.ndarray):
|
43
41
|
"""
|
44
42
|
Prunes the entry node of a model in Keras.
|
45
43
|
|
46
44
|
Args:
|
47
45
|
node (BaseNode): The entry node to be pruned.
|
48
46
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
49
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
50
47
|
|
51
48
|
"""
|
52
49
|
return _prune_keras_edge_node(node=node,
|
53
50
|
mask=output_mask,
|
54
|
-
fw_info=fw_info,
|
55
51
|
is_exit_node=False)
|
56
52
|
|
57
53
|
def prune_intermediate_node(self,
|
58
54
|
node: BaseNode,
|
59
55
|
input_mask: np.ndarray,
|
60
|
-
output_mask: np.ndarray
|
61
|
-
fw_info: FrameworkInfo):
|
56
|
+
output_mask: np.ndarray):
|
62
57
|
"""
|
63
58
|
Prunes an intermediate node in a Keras model.
|
64
59
|
|
@@ -66,7 +61,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
66
61
|
node (BaseNode): The intermediate node to be pruned.
|
67
62
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
68
63
|
output_mask (np.ndarray): A numpy array representing the mask to be applied to the output channels.
|
69
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
70
64
|
|
71
65
|
"""
|
72
66
|
_edit_node_input_shape(input_mask, node)
|
@@ -79,20 +73,17 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
79
73
|
|
80
74
|
def prune_exit_node(self,
|
81
75
|
node: BaseNode,
|
82
|
-
input_mask: np.ndarray
|
83
|
-
fw_info: FrameworkInfo):
|
76
|
+
input_mask: np.ndarray):
|
84
77
|
"""
|
85
78
|
Prunes the exit node of a model in Keras.
|
86
79
|
|
87
80
|
Args:
|
88
81
|
node (BaseNode): The exit node to be pruned.
|
89
82
|
input_mask (np.ndarray): A numpy array representing the mask to be applied to the input channels.
|
90
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
91
83
|
|
92
84
|
"""
|
93
85
|
return _prune_keras_edge_node(node=node,
|
94
86
|
mask=input_mask,
|
95
|
-
fw_info=fw_info,
|
96
87
|
is_exit_node=True)
|
97
88
|
|
98
89
|
def is_node_entry_node(self, node: BaseNode) -> bool:
|
@@ -109,22 +100,19 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
109
100
|
|
110
101
|
def is_node_exit_node(self,
|
111
102
|
node: BaseNode,
|
112
|
-
corresponding_entry_node: BaseNode
|
113
|
-
fw_info: FrameworkInfo) -> bool:
|
103
|
+
corresponding_entry_node: BaseNode) -> bool:
|
114
104
|
"""
|
115
105
|
Determines whether a node is an exit node in a Keras model.
|
116
106
|
|
117
107
|
Args:
|
118
108
|
node (BaseNode): The node to be checked.
|
119
109
|
corresponding_entry_node (BaseNode): The entry node of the pruning section that is checked.
|
120
|
-
fw_info (FrameworkInfo): Framework-specific information object.
|
121
110
|
|
122
111
|
Returns:
|
123
112
|
bool: Boolean indicating if the node is an exit node.
|
124
113
|
"""
|
125
114
|
return _is_keras_node_pruning_section_edge(node) and PruningSection.has_matching_channel_count(node,
|
126
|
-
corresponding_entry_node
|
127
|
-
fw_info)
|
115
|
+
corresponding_entry_node)
|
128
116
|
|
129
117
|
def is_node_intermediate_pruning_section(self, node: BaseNode) -> bool:
|
130
118
|
"""
|
@@ -143,8 +131,7 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
143
131
|
keras.layers.Dense]
|
144
132
|
|
145
133
|
def attrs_oi_channels_info_for_pruning(self,
|
146
|
-
node: BaseNode,
|
147
|
-
fw_info: FrameworkInfo) -> Dict[str, Tuple[int, int]]:
|
134
|
+
node: BaseNode) -> Dict[str, Tuple[int, int]]:
|
148
135
|
"""
|
149
136
|
Retrieves the attributes of a given node along with the output/input (OI) channel axis
|
150
137
|
for each attribute used to prune these attributes.
|
@@ -161,7 +148,6 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
161
148
|
|
162
149
|
Args:
|
163
150
|
node (BaseNode): The node from the computational graph.
|
164
|
-
fw_info (FrameworkInfo): Contains framework-specific information and utilities.
|
165
151
|
|
166
152
|
Returns:
|
167
153
|
Dict[str, Tuple[int, int]]: A dictionary where each key is an attribute name (like 'kernel' or 'bias')
|
@@ -169,13 +155,8 @@ class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementa
|
|
169
155
|
"""
|
170
156
|
|
171
157
|
attributes_with_axis = {}
|
172
|
-
if
|
173
|
-
|
174
|
-
if kernel_attributes is None or len(kernel_attributes)==0:
|
175
|
-
Logger.critical(f"Expected kernel attributes for operation for node type {node.type}, found None or empty.")
|
176
|
-
|
177
|
-
for attr in kernel_attributes:
|
178
|
-
attributes_with_axis[attr] = fw_info.kernel_channels_mapping.get(node.type)
|
158
|
+
if node.is_kernel_op:
|
159
|
+
attributes_with_axis[node.kernel_attr] = (node.channel_axis.output, node.channel_axis.input)
|
179
160
|
|
180
161
|
# Bias is a vector at the length of the number of output channels.
|
181
162
|
# For this reason, input channel axis is irrelevant to the bias attribute.
|
@@ -216,7 +197,6 @@ def _is_keras_node_pruning_section_edge(node: BaseNode) -> bool:
|
|
216
197
|
|
217
198
|
def _prune_keras_edge_node(node: BaseNode,
|
218
199
|
mask: np.ndarray,
|
219
|
-
fw_info: FrameworkInfo,
|
220
200
|
is_exit_node: bool):
|
221
201
|
"""
|
222
202
|
Prunes the given Keras node by applying the mask to the node's weights (kernels and biases).
|
@@ -225,21 +205,18 @@ def _prune_keras_edge_node(node: BaseNode,
|
|
225
205
|
Args:
|
226
206
|
node: The node to be pruned.
|
227
207
|
mask: The pruning mask to be applied.
|
228
|
-
fw_info: Framework-specific information object.
|
229
208
|
is_exit_node: A boolean indicating whether the node is an exit node.
|
230
209
|
|
231
210
|
"""
|
232
211
|
|
233
212
|
# Retrieve the kernel attribute and the axes to prune.
|
234
|
-
|
235
|
-
|
236
|
-
axis_to_prune = io_axis[int(is_exit_node)]
|
237
|
-
kernel = node.get_weights_by_keys(kernel_attr)
|
213
|
+
axis_to_prune = node.channel_axis.input if is_exit_node else node.channel_axis.output
|
214
|
+
kernel = node.get_weights_by_keys(node.kernel_attr)
|
238
215
|
# Convert mask to boolean.
|
239
216
|
mask_bool = mask.astype(bool)
|
240
217
|
|
241
218
|
pruned_kernel = kernel.compress(mask_bool, axis=axis_to_prune)
|
242
|
-
node.set_weights_by_keys(name=kernel_attr, tensor=pruned_kernel)
|
219
|
+
node.set_weights_by_keys(name=node.kernel_attr, tensor=pruned_kernel)
|
243
220
|
|
244
221
|
if not is_exit_node and node.framework_attr[USE_BIAS]:
|
245
222
|
# Prune the bias if applicable and it's an entry node.
|
@@ -27,7 +27,6 @@ if FOUND_TF:
|
|
27
27
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
|
28
28
|
AttachTpcToKeras
|
29
29
|
from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
|
30
|
-
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
|
31
30
|
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
|
32
31
|
from tensorflow.keras.models import Model
|
33
32
|
|
@@ -93,7 +92,6 @@ if FOUND_TF:
|
|
93
92
|
representative_data_gen,
|
94
93
|
core_config,
|
95
94
|
target_platform_capabilities,
|
96
|
-
DEFAULT_KERAS_INFO,
|
97
95
|
fw_impl)
|
98
96
|
|
99
97
|
else:
|
@@ -43,7 +43,6 @@ def activation_bias_correction_node_matchers():
|
|
43
43
|
|
44
44
|
def keras_compute_activation_bias_correction_of_graph(graph: Graph,
|
45
45
|
quant_config: QuantizationConfig,
|
46
|
-
fw_info: FrameworkInfo,
|
47
46
|
fw_impl: FrameworkImplementation) -> Graph:
|
48
47
|
"""
|
49
48
|
Compute the activation bias correction term for graph based on a Keras model.
|
@@ -51,7 +50,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
51
50
|
Args:
|
52
51
|
graph: Graph with nodes to compute the activation bias correction.
|
53
52
|
quant_config: QuantizationConfig of how the model should be quantized.
|
54
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
55
53
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
56
54
|
|
57
55
|
Returns:
|
@@ -59,7 +57,6 @@ def keras_compute_activation_bias_correction_of_graph(graph: Graph,
|
|
59
57
|
"""
|
60
58
|
graph = compute_activation_bias_correction_of_graph(graph=graph,
|
61
59
|
quant_config=quant_config,
|
62
|
-
fw_info=fw_info,
|
63
60
|
fw_impl=fw_impl,
|
64
61
|
activation_bias_correction_node_matchers=
|
65
62
|
activation_bias_correction_node_matchers,
|
@@ -24,8 +24,6 @@ from model_compression_toolkit.core.common.user_info import UserInformation
|
|
24
24
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder, \
|
25
25
|
PytorchModel
|
26
26
|
|
27
|
-
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
|
28
|
-
|
29
27
|
|
30
28
|
class FloatPyTorchModel(PytorchModel):
|
31
29
|
"""
|
@@ -34,19 +32,16 @@ class FloatPyTorchModel(PytorchModel):
|
|
34
32
|
|
35
33
|
def __init__(self,
|
36
34
|
graph: common.Graph,
|
37
|
-
append2output=None
|
38
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO):
|
35
|
+
append2output=None):
|
39
36
|
"""
|
40
37
|
|
41
38
|
Args:
|
42
39
|
graph: Graph to build its corresponding Pytorch model.
|
43
40
|
append2output: List of nodes or OutTensor objects.
|
44
|
-
fw_info: Framework information (e.g., mapping from layers to their attributes to quantize).
|
45
41
|
"""
|
46
42
|
|
47
43
|
super().__init__(graph,
|
48
|
-
append2output
|
49
|
-
fw_info)
|
44
|
+
append2output)
|
50
45
|
|
51
46
|
def _quantize_node_activations(self,
|
52
47
|
node: BaseNode,
|
@@ -71,20 +66,17 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
|
|
71
66
|
def __init__(self,
|
72
67
|
graph: common.Graph,
|
73
68
|
append2output=None,
|
74
|
-
fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
|
75
69
|
return_float_outputs: bool = False):
|
76
70
|
"""
|
77
71
|
|
78
72
|
Args:
|
79
73
|
graph: Graph to build the model from.
|
80
74
|
append2output: Nodes to append to model's output.
|
81
|
-
fw_info: Information about the specific framework of the model that is built.
|
82
75
|
return_float_outputs: Whether the model returns float tensors or not.
|
83
76
|
"""
|
84
77
|
|
85
78
|
super().__init__(graph,
|
86
79
|
append2output,
|
87
|
-
fw_info,
|
88
80
|
return_float_outputs)
|
89
81
|
|
90
82
|
def build_model(self) -> Tuple[PytorchModel, UserInformation]:
|
@@ -94,5 +86,4 @@ class FloatPyTorchModelBuilder(PyTorchModelBuilder):
|
|
94
86
|
|
95
87
|
"""
|
96
88
|
return FloatPyTorchModel(self.graph,
|
97
|
-
self.append2output,
|
98
|
-
self.fw_info), self.graph.user_info
|
89
|
+
self.append2output), self.graph.user_info
|